Unlocking Low-Level Control: Customizing Keras Training Loops with JAX

Published April 30, 2026, 23:00
Do you want the speed and functional power of JAX without losing the high-level convenience of model.fit? In this video, Google ML Developer Advocate Yufeng Guo (@yufengg) explains how Keras implements the principle of Progressive Disclosure of Complexity.

Learn how to take full control of your learning algorithms by overriding the train_step() and test_step() methods while keeping access to built-in callbacks, distribution support, and evaluation tools.
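
For a taste of what that looks like: under the JAX backend, train_step() is written as a pure function — every variable is passed in through a state tuple and updated copies are returned rather than mutated in place. A minimal skeleton of that signature (the class name MyModel is illustrative, not from the video):

```python
import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import keras


class MyModel(keras.Model):  # illustrative name
    def train_step(self, state, data):
        # Under the JAX backend, train_step is stateless: every piece of
        # state arrives as an argument and leaves as a return value.
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y = data
        # ... compute loss, gradients, and updated variables here
        # (a fleshed-out sketch follows the list below) ...
        logs = {}  # maps metric names to values for progress bars/callbacks
        return logs, state  # return logs plus the (updated) state tuple
```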

What You’ll Learn (fleshed out in the code sketch after this list):
- Why override train_step instead of writing a loop from scratch?
- Understanding how to handle trainable variables, non-trainable variables, and optimizer states in a functional environment.
- Creating a compute_loss_and_updates function to manage forward passes and auxiliary data.
- Using jax.value_and_grad to compute gradients and losses simultaneously.
- Updating evaluation metrics using stateless_update_state.
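
Here is a condensed sketch pulling those pieces together, patterned on Keras 3's stateless APIs (stateless_call, optimizer.stateless_apply, metric.stateless_update_state). Details like the class name and loss are placeholders; the authoritative version is the complete code example linked under Resources:

```python
import os
os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras
import numpy as np


class MyModel(keras.Model):
    def compute_loss_and_updates(
        self, trainable_variables, non_trainable_variables, x, y, training=False
    ):
        # Stateless forward pass: returns predictions plus the updated
        # non-trainable variables (e.g. BatchNorm statistics) as aux data.
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables, non_trainable_variables, x, training=training
        )
        loss = self.compute_loss(x, y, y_pred)
        return loss, (y_pred, non_trainable_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y = data

        # has_aux=True lets the function return (loss, aux) while still
        # differentiating with respect to the first argument only.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables, non_trainable_variables, x, y, training=True
        )

        # Functionally apply the optimizer: returns updated copies of both
        # the trainable variables and the optimizer's own state.
        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Update each metric without mutation, slicing its variables out of
        # the flat metrics_variables list.
        new_metrics_vars, logs = [], {}
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
            else:
                this_metric_vars = metric.stateless_update_state(this_metric_vars, y, y_pred)
            logs[metric.name] = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state
```

Because only train_step() changed, everything else model.fit provides — callbacks, progress bars, evaluation — keeps working. A toy usage example (shapes and hyperparameters are purely illustrative):

```python
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = MyModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
model.fit(np.random.random((64, 32)), np.random.random((64, 1)), epochs=2)
```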


Chapters:
0:00 - Introduction & The Default model.fit()
0:18 - Customizing Keras Training Loops
0:46 - Overriding train_step()
1:14 - Setting up the JAX Backend
1:26 - The Stateless train_step
2:11 - Stateless Loss Computation
3:04 - Taking Gradients in train_step
4:06 - How to Pass Around Non-Trainable Variables
4:43 - Updating the Model Weights
5:05 - Handling Metrics
5:21 - Custom Evaluation Loops (overriding test_step; sketched below)
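
The evaluation side mirrors train_step minus the optimizer: its state tuple carries no optimizer variables and no gradients are taken. A sketch, continuing the same illustrative MyModel class from above:

```python
import keras  # assumes KERAS_BACKEND="jax", as above


class MyModel(keras.Model):
    # compute_loss_and_updates and train_step as sketched earlier...

    def test_step(self, state, data):
        # Evaluation state carries no optimizer variables.
        trainable_variables, non_trainable_variables, metrics_variables = state
        x, y = data

        # Inference-mode forward pass; no gradient computation needed.
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables, non_trainable_variables, x, training=False
        )
        loss = self.compute_loss(x, y, y_pred)

        # Same functional metric update as in train_step.
        new_metrics_vars, logs = [], {}
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
            else:
                this_metric_vars = metric.stateless_update_state(this_metric_vars, y, y_pred)
            logs[metric.name] = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        return logs, (trainable_variables, non_trainable_variables, new_metrics_vars)
```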


Resources:
Complete Code Example → goo.gle/4eeSvlD
Keras Documentation → goo.gle/42Ebpv0
Keras Developer Guides → goo.gle/4um97N3

Subscribe to Google for Developers → goo.gle/developers

Speaker: Yufeng Guo
Products Mentioned: Google AI