Fix double-compilation in train_step by matching input sharding. #4174
igorts-git wants to merge 1 commit
… AOT compilation shaped batch.
Previously, train_step compiled twice in various training loops (pre-train, SFT): first Ahead-Of-Time (AOT) using an unsharded shaped_batch dummy, and then again on the first step execution using a sharded example_batch from the data pipeline. This caused a slow first training step.
This change propagates the input data sharding to get_shaped_batch in train.py, train_compile.py, and train_sft_native.py so the dummy AOT batch has the same sharding annotation as the runtime batch, matching their signatures and enabling JAX JIT compilation cache hits.
Also added:
- Unit tests for get_shaped_batch sharding in maxtext_utils_test.py.
- CPU compilation cache regression test in compile_cache_test.py.
Description
This PR ensures that the input data sharding is propagated to the AOT dummy batch so its JAX signature matches the runtime batch signature.
Problem
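In MaxText, the first step of training (step 0) was taking an excessively long time because JAX compiled the train_step function twice: first Ahead-Of-Time (AOT) using an unsharded shaped_batch dummy, and then again on the first step execution using a sharded example_batch from the data pipeline.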
Because the AOT batch did not have the correct sharding annotations, JAX encountered a cache miss on the compiled executable key at runtime, triggering a full re-compilation.
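The idea of matching signatures can be sketched in plain JAX. This is a minimal illustration, not the MaxText code: `make_shaped_batch` is a hypothetical stand-in for `get_shaped_batch`, the `"data"` mesh axis name and the toy `train_step` are assumptions, and the batch keys are invented for the example. It only shows that a `jax.ShapeDtypeStruct` can carry the same sharding as the runtime batch, so the AOT-compiled executable is built for the inputs it will actually receive.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def make_shaped_batch(global_batch, seq_len, sharding=None):
  """Hypothetical stand-in for get_shaped_batch: shapes only, no real data."""
  shape = (global_batch, seq_len)
  return {
      "inputs": jax.ShapeDtypeStruct(shape, jnp.int32, sharding=sharding),
      "targets": jax.ShapeDtypeStruct(shape, jnp.int32, sharding=sharding),
  }


def train_step(batch):
  # Toy computation standing in for the real training step.
  diff = (batch["inputs"] - batch["targets"]).astype(jnp.float32)
  return jnp.mean(diff**2)


# One-dimensional mesh over all available devices (assumed axis name "data").
devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))
data_sharding = NamedSharding(mesh, P("data"))  # shard the batch dimension

# Keep the batch dimension divisible by the device count.
global_batch = 2 * devices.size
seq_len = 8

# AOT compile against a dummy batch that carries the runtime sharding.
shaped_batch = make_shaped_batch(global_batch, seq_len, sharding=data_sharding)
compiled = jax.jit(train_step).lower(shaped_batch).compile()

# Runtime batch placed with the same sharding as the dummy batch.
example_batch = {
    "inputs": jax.device_put(
        jnp.ones((global_batch, seq_len), jnp.int32), data_sharding
    ),
    "targets": jax.device_put(
        jnp.zeros((global_batch, seq_len), jnp.int32), data_sharding
    ),
}
loss = compiled(example_batch)
```

Calling `make_shaped_batch` without `sharding` reproduces the unsharded dummy described above, whose signature would differ from that of the sharded runtime batch.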
Unfortunately, the PR is much larger than necessary, because pre-commit insisted on reformatting a bunch of unrelated code and refused to commit without it.
FIXES: b/440499679
Tests
Ran a small model with and without the fix to confirm that the double-compilation issue is fixed.
Added a unit test that is supposed to catch double compilation regressions in the future.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.