Scaling Up (Part 1)

Published December 4, 2025, 5:01
In part one of this three-part series on sharding and parallelism, we explore how to scale your Flax NNX models using JAX's powerful distributed computing capabilities, specifically its SPMD (single program, multiple data) paradigm. If you're coming from PyTorch and have started using JAX and Flax NNX, you know that modern models often outgrow a single accelerator. We discuss JAX's approach to parallelism and how NNX integrates with it seamlessly. This episode covers the "why" and "what" of distributed training, introducing the fundamental concepts of parallelism and the core JAX primitives needed to implement them.
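As a taste of the primitives the episode introduces, here is a minimal sketch of JAX's SPMD sharding API (`Mesh`, `NamedSharding`, `PartitionSpec`). The array names and shapes are illustrative, not taken from the video, and the example scales the batch to however many devices are available so it also runs on a single CPU:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over whatever devices are available (1 CPU also works).
devices = jax.devices()
mesh = Mesh(np.array(devices), axis_names=("data",))

# Illustrative input: batch size is a multiple of the device count so the
# leading axis divides evenly across the "data" mesh axis.
n = len(devices)
x = jnp.arange(4 * n * 2.0).reshape(4 * n, 2)

# Shard rows across devices; columns are replicated.
sharding = NamedSharding(mesh, P("data", None))
xs = jax.device_put(x, sharding)

# jit compiles one program that every device runs on its own shard (SPMD);
# JAX inserts any needed communication automatically.
@jax.jit
def scale(v):
    return v * 2.0

y = scale(xs)
```

In the SPMD model you write the computation once, as if for a single device, and the sharding annotations tell the compiler how to partition it; later parts of the series show how Flax NNX attaches these annotations to model parameters.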

Resources:
Learn more → goo.gle/learning-jax

Subscribe to Google for Developers → goo.gle/developers

Speaker: Robert Crowe