Optimizing Flax NNX Models with Optax (Part 1)

Published December 4, 2025, 5:00
This is the first episode of our two-part series on optimizing models with Optax. Let's explore how to optimize Flax NNX neural network models using Optax, the primary optimization library in the JAX ecosystem. This presentation is designed especially for those of you familiar with PyTorch, so we'll draw parallels and highlight differences to help you transition smoothly. This first episode covers the core workflow, from defining a model to running a complete, JIT-compiled training loop. It establishes the foundational knowledge we'll need for the more advanced topics in the next episode.

Resources:
Learn more → goo.gle/learning-jax

Subscribe to Google for Developers → goo.gle/developers

Speaker: Robert Crowe