Checkpointing Flax NNX Models with Orbax (Part 1)

Published December 4, 2025, 5:00
Orbax is the standard checkpointing library in the JAX ecosystem. In part one of this two-episode series on checkpointing with Orbax, we'll explore how to effectively save and restore your Flax NNX models. As we've seen, Flax NNX offers a Pythonic, stateful approach that feels closer to frameworks like PyTorch while keeping all the advantages of JAX. Across the series we'll cover how NNX manages state and how Orbax interacts with it, starting with the basics and moving on to advanced techniques such as handling distributed training. This episode covers the core concepts: what NNX state is, how Orbax is structured, and the complete workflow for saving and restoring a single, basic NNX model.

Resources:
Learn more → goo.gle/learning-jax

Subscribe to Google for Developers → goo.gle/developers

Speaker: Robert Crowe