Fix double-compilation in train_step by matching input sharding. #4174

Open
igorts-git wants to merge 1 commit into main from igorts/b440499679-train_step-hlo

igorts-git (Collaborator) commented Jun 16, 2026

Description

This PR ensures that the input data sharding is propagated to the AOT dummy batch so its JAX signature matches the runtime batch signature.

Problem

In MaxText, the first step of training (step 0) was taking an excessively long time due to a double compilation of train_step. JAX was compiling the train_step function twice:

  • Ahead-Of-Time (AOT) compilation using a dummy shaped_batch (unsharded by default) during initialization.
  • Runtime compilation on step 0 when executing with the actual example_batch from the data pipeline (which is sharded).

Because the AOT batch did not have the correct sharding annotations, JAX encountered a cache miss on the compiled executable key at runtime, triggering a full re-compilation.
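The mismatch described above can be illustrated with a minimal JAX sketch. It is not MaxText code: the 1-D mesh, the `"data"` axis name, and the batch layout are assumptions chosen only for illustration. It shows that an abstract placeholder built without a sharding agrees with the runtime array on shape and dtype but not on sharding, which is the kind of difference that can make the runtime call miss the AOT-compiled executable.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical 1-D mesh over all local devices; the axis name "data" is an assumption.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
data_sharding = NamedSharding(mesh, PartitionSpec("data"))

# Keep the batch dimension divisible by the number of devices in the mesh.
batch_size = 4 * jax.device_count()

# AOT side: an abstract placeholder created without any sharding annotation.
shaped_batch = {"inputs": jax.ShapeDtypeStruct((batch_size, 16), jnp.int32)}

# Runtime side: a concrete array placed on devices with an explicit sharding,
# standing in for what a data pipeline would deliver.
example_batch = {
    "inputs": jax.device_put(jnp.zeros((batch_size, 16), jnp.int32), data_sharding)
}

aot_spec = shaped_batch["inputs"]
runtime_arr = example_batch["inputs"]

# Shape and dtype agree between the two batches ...
same_shape_dtype = (
    aot_spec.shape == runtime_arr.shape and aot_spec.dtype == runtime_arr.dtype
)
# ... but the placeholder carries no sharding (None), while the runtime array does.
same_sharding = aot_spec.sharding == runtime_arr.sharding

print("same shape/dtype:", same_shape_dtype)
print("same sharding:", same_sharding)
```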

Unfortunately, the PR is much larger than necessary, because pre-commit insisted on reformatting a lot of unrelated code and refused to commit without it.

FIXES: b/440499679

Tests

Ran a small model with and without the fix to confirm that the double-compilation issue is fixed.
Added a unit test intended to catch double-compilation regressions in the future.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 16, 2026

Codecov Report

❌ Patch coverage is 75.00000% with 10 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/trainers/pre_train/train_compile.py | 73.07% | 7 Missing ⚠️ |
| src/maxtext/utils/maxtext_utils.py | 70.00% | 3 Missing ⚠️ |

… AOT compilation shaped batch.

Previously, train_step compiled twice in various training loops (pre-train, SFT): first Ahead-Of-Time (AOT) using an unsharded shaped_batch dummy, and then again on the first step execution using a sharded example_batch from the data pipeline. This caused a slow first training step.

This change propagates the input data sharding to get_shaped_batch in train.py, train_compile.py, and train_sft_native.py so the dummy AOT batch has the same sharding annotation as the runtime batch, matching their signatures and enabling JAX JIT compilation cache hits.
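As a rough sketch of the idea, assuming a simplified stand-in for the real helper: `make_shaped_batch`, the toy `train_step`, the feature names, and the `"data"` mesh axis below are all hypothetical and do not reflect MaxText's actual `get_shaped_batch` signature. The point is only that a `jax.ShapeDtypeStruct` can carry a `sharding`, so the AOT dummy batch can be given the same annotation the data pipeline uses.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical mesh and data sharding; the axis name "data" is an assumption.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
data_sharding = NamedSharding(mesh, PartitionSpec("data"))
batch_size = 4 * jax.device_count()


def make_shaped_batch(feature_shapes, dtype, sharding=None):
  """Hypothetical helper: build an abstract batch, optionally with a sharding."""
  return {
      name: jax.ShapeDtypeStruct(shape, dtype, sharding=sharding)
      for name, shape in feature_shapes.items()
  }


def train_step(batch):
  """Toy stand-in for a real training step."""
  return jnp.sum(batch["inputs"] * batch["targets"])


feature_shapes = {"inputs": (batch_size, 16), "targets": (batch_size, 16)}

# Dummy AOT batch that carries the same sharding as the runtime data.
shaped_batch = make_shaped_batch(feature_shapes, jnp.int32, sharding=data_sharding)

# Ahead-of-time lowering and compilation against the sharded abstract batch.
compiled = jax.jit(train_step).lower(shaped_batch).compile()

# Runtime batch placed with the same sharding, then run through the AOT executable.
example_batch = {
    name: jax.device_put(jnp.ones(spec.shape, spec.dtype), data_sharding)
    for name, spec in shaped_batch.items()
}
loss = compiled(example_batch)
```

Passing `sharding=None` in this sketch reproduces the unsharded placeholder from before the change, so the same helper can be used to compare both configurations.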

Also added:
- Unit tests for get_shaped_batch sharding in maxtext_utils_test.py.
- CPU compilation cache regression test in compile_cache_test.py.
@igorts-git igorts-git force-pushed the igorts/b440499679-train_step-hlo branch from 48c763b to 1463b83 Compare June 16, 2026 17:42
@igorts-git igorts-git changed the title from "[DRAFT don't review yet] Fix double-compilation in train_step by matching input sharding." to "Fix double-compilation in train_step by matching input sharding." Jun 16, 2026
@github-actions

🤖 Hi @igorts-git, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @igorts-git, but I was unable to process your request. Please see the logs for more details.
