Training our most capable Gemini models relies heavily on our JAX software stack + Google's TPU hardware platforms.
If you want to learn more, see this awesome book "How to Scale Your Model":
jax-ml.github.io/scaling-book/
Put together by several of my Google DeepMind colleagues listed below 🎉.
We want this to be a living book, so please ask questions and give us feedback. We'll continue adding to it as time goes on. Without further ado, here’s a link to the beginning: jax-ml.github.io/scaling-book/ 11/11
The book was co-written with @sholtodouglas.bsky.social, @froystig.bsky.social, @levskaya.bsky.social, @reinerpope.bsky.social, Albert Webson, Charlie Chen, Vinay Ramasesh, and Federico Lebron 10/n
LLM systems programming is super fun! It's hard to do good ML research without it these days, and you don't need much compute to work on it. I hope this book will make it easier for more people (esp. academics) to work on this stuff 9/n
The rest of the book is a set of practical guides: how to write and profile parallel JAX code, and how to apply the previous two sections to real models like LLaMA-3. We also have worked problems at the end of each section if you like homework: jax-ml.github.io/scaling-book... 8/n
Now that we’ve talked about training, we need to talk about serving. How expensive should a model be to serve? What kind of latency can we expect? What are prefill and generation? How do we build an efficient inference service? We talk about this here: jax-ml.github.io/scaling-book... 7/n
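As a taste of the serving math: generation tends to be memory-bandwidth bound, so a quick floor on per-token latency falls out of a one-liner. All numbers below are illustrative assumptions (a generic 8B model, an assumed 800 GB/s of HBM bandwidth), not any particular chip's specs, and real serving adds KV caches, batching, and more:

```python
# Hedged back-of-envelope: decode latency floor for a memory-bound model.
params = 8e9                 # assumed model size
bytes_per_param = 2          # bf16 weights
hbm_bw = 8e11                # 800 GB/s, assumed per-chip HBM bandwidth

# Each generated token must stream every weight from HBM at least once:
latency_s = params * bytes_per_param / hbm_bw
print(f"~{latency_s*1e3:.0f} ms/token floor, ~{1/latency_s:.0f} tokens/s")
# → ~20 ms/token floor, ~50 tokens/s
```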
Now for the good stuff! You may have heard of data or tensor parallelism, FSDP or pipelining. But why choose one over the other? Short answer: each adds communication, and the one with the lowest cost depends on the model. Part 5 dives into this: jax-ml.github.io/scaling-book... 6/n
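To see why the cheapest strategy depends on the model, here's a toy comparison of per-step communication volume for two strategies. The formulas and numbers are my illustrative assumptions for a generic Transformer (not taken from the book), ignoring overlap, topology, and many real-world details:

```python
# Toy per-step communication comparison (illustrative assumptions only).
params = 8e9                               # assumed model size
bf16 = 2                                   # bytes per value
batch, seq, d_model, n_layers = 32, 4096, 4096, 32

# Data parallelism: all-reduce the full gradient each step.
dp_bytes = 2 * params * bf16               # ring all-reduce moves ~2x the data

# Tensor parallelism: all-reduce activations ~twice per layer (attn + MLP).
tp_bytes = 2 * n_layers * 2 * (batch * seq * d_model) * bf16

print(f"DP: {dp_bytes/1e9:.0f} GB, TP: {tp_bytes/1e9:.0f} GB per step")
# → DP: 32 GB, TP: 137 GB per step
```

Shrink the batch or grow the model and the comparison flips, which is exactly the kind of trade-off Part 5 works through.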
5 years ago, there were many ML architectures, but today, there is (mostly) only one. _You should know the Transformer inside and out!_ How many FLOPs or params in LLaMA-3? How expensive is attention vs. a feed-forward block? You'll know after reading jax-ml.github.io/scaling-book... 5/n
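For a flavor of the counting involved, here's a rough parameter tally for a LLaMA-3-8B-style decoder. The config values are my approximations of the published architecture (norms and biases ignored), and the 6ND training-FLOPs rule is the standard back-of-envelope estimate:

```python
# Rough parameter count for a LLaMA-3-8B-style decoder (assumed config).
d_model, n_layers, d_ff = 4096, 32, 14336
n_heads, n_kv_heads, head_dim = 32, 8, 128
vocab = 128_256

attn = d_model * head_dim * (n_heads + 2 * n_kv_heads + n_heads)  # Wq, Wk, Wv, Wo
mlp = 3 * d_model * d_ff                                          # SwiGLU: gate, up, down
embed = 2 * vocab * d_model                                       # input + output (untied)
params = n_layers * (attn + mlp) + embed
print(f"{params/1e9:.1f}B params")                                # → 8.0B params

# Training FLOPs rule of thumb: ~6 * params * tokens
tokens = 1e12
print(f"{6 * params * tokens:.3g} FLOPs for 1T tokens")
```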
Scaling an LLM involves distributing — a.k.a. "sharding" — its weights across multiple TPUs. To run it, we have to add cross-chip communication. Part 3 describes the TPU's communication primitives, and simple rules for multiplying sharded matrices: jax-ml.github.io/scaling-book... 4/n
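A minimal sketch of what that looks like in JAX (assumes JAX is installed; we fake 4 devices on a single CPU host via an XLA flag purely for illustration). The shapes and mesh here are made up for the example:

```python
# Multiply a matrix sharded along its contracting dimension across 4 devices.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

A = jnp.ones((128, 256))
B = jnp.ones((256, 64))
A_sharded = jax.device_put(A, NamedSharding(mesh, P(None, "x")))  # shard cols of A
B_sharded = jax.device_put(B, NamedSharding(mesh, P("x", None)))  # shard rows of B

# jit + XLA insert the cross-device communication (here an all-reduce of
# partial sums) automatically; the math is unchanged.
C = jax.jit(jnp.dot)(A_sharded, B_sharded)
print(C.shape)  # (128, 64)
```

The book's rules let you predict which communication primitive XLA will insert for a given pair of shardings, and what it costs.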
A big chunk of this book is dedicated to understanding the hardware that provides those system resources. We emphasize TPUs in this book, but the principles and math can be adapted to GPUs too. Part 2 explains the TPU in detail: jax-ml.github.io/scaling-book... 3/n
The secret is to think in terms of basic system resources — compute, memory, and bandwidth — and calculate which one limits our performance. From this we can estimate the cost, runtime, and optimal parallelism strategy for any given LLM: jax-ml.github.io/scaling-book/ 2/n
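As a taste of that kind of arithmetic, here's a roofline-style estimate for a single matmul. The chip numbers are illustrative assumptions, not any specific TPU's specs:

```python
# Back-of-envelope roofline: is a matmul compute- or bandwidth-bound?
PEAK_FLOPS = 2e14        # 200 TFLOP/s (assumed accelerator peak, bf16)
HBM_BW = 8e11            # 800 GB/s (assumed memory bandwidth)

def matmul_time_s(b, d, f, bytes_per_param=2):
    """Estimate runtime of a (b, d) x (d, f) matmul as max(compute, memory)."""
    flops = 2 * b * d * f                                 # multiply-adds
    bytes_moved = bytes_per_param * (b * d + d * f + b * f)
    t_compute = flops / PEAK_FLOPS
    t_memory = bytes_moved / HBM_BW
    return max(t_compute, t_memory)

big = matmul_time_s(1024, 8192, 8192)   # large batch: compute-bound
gen = matmul_time_s(1, 8192, 8192)      # batch 1 (generation): bandwidth-bound
print(f"batch 1024: {big*1e3:.2f} ms, batch 1: {gen*1e3:.2f} ms")
```

Whichever term wins the `max` tells you which resource to optimize.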
Making LLMs run efficiently can feel scary, but scaling isn’t magic, it’s math! We wanted to demystify the “systems view” of LLMs, so we wrote a little textbook called “How To Scale Your Model”, which we’re releasing today. 1/n