Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/maxtext/trainers/post_train/sft/train_sft_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def train_loop(config, recorder, state=None):
)

with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
shaped_batch = maxtext_utils.get_shaped_batch(config)
data_sharding = sharding.get_input_data_sharding(config, mesh)
shaped_batch = maxtext_utils.get_shaped_batch(config, batch_sharding=data_sharding)
compiled = p_train_step.lower(state, shaped_batch, init_rng).compile()
compiled_stats = compiled.memory_analysis()
max_utils.print_compiled_memory_stats(compiled_stats)
Expand Down
3 changes: 2 additions & 1 deletion src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,8 @@ def train_loop(config, recorder, state=None):
)

with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
shaped_batch = maxtext_utils.get_shaped_batch(config)
data_sharding = sharding.get_input_data_sharding(config, mesh)
shaped_batch = maxtext_utils.get_shaped_batch(config, batch_sharding=data_sharding)
if config.shard_optimizer_over_data and isinstance(model, nn.Module):
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
elif config.shard_optimizer_over_data:
Expand Down
3 changes: 2 additions & 1 deletion src/maxtext/trainers/pre_train/train_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def create_train_state_fn():
logical_annotations = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn)

# Shaped batch
shaped_batch = maxtext_utils.get_shaped_batch(config)
data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
shaped_batch = maxtext_utils.get_shaped_batch(config, batch_sharding=data_sharding)

if config.pure_nnx:
shaped_train_args = (
Expand Down
20 changes: 10 additions & 10 deletions src/maxtext/utils/maxtext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def get_reorder_callable(cp_size, shard_mode, reorder_strategy=ReorderStrategy.D
)


def get_shaped_batch(config):
def get_shaped_batch(config, batch_sharding=None):
"""Return the shape of the batch - this is what eval_shape would return for the
output of create_data_iterator, but eval_shape doesn't work, see b/306901078."""
if config.enable_diloco:
Expand All @@ -160,21 +160,21 @@ def get_shaped_batch(config):
else:
batch_shape = (config.global_batch_size_to_load, config.max_target_length)
shaped_batch = {}
shaped_batch["inputs"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32)
shaped_batch["inputs_position"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32)
shaped_batch["inputs_segmentation"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32)
shaped_batch["targets"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32)
shaped_batch["targets_position"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32)
shaped_batch["targets_segmentation"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32)
shaped_batch["inputs"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32, sharding=batch_sharding)
shaped_batch["inputs_position"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32, sharding=batch_sharding)
shaped_batch["inputs_segmentation"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32, sharding=batch_sharding)
shaped_batch["targets"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32, sharding=batch_sharding)
shaped_batch["targets_position"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32, sharding=batch_sharding)
shaped_batch["targets_segmentation"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32, sharding=batch_sharding)
if config.use_multimodal:
image_shape = mm_processor.get_dummy_image_shape_for_init(
config.model_name, batch_size=config.micro_batch_size_to_train_on
)
shaped_batch["images"] = jax.ShapeDtypeStruct(image_shape, jnp.int32)
shaped_batch["image_masks"] = jax.ShapeDtypeStruct(image_shape[:2], jnp.int32)
shaped_batch["images"] = jax.ShapeDtypeStruct(image_shape, jnp.int32, sharding=batch_sharding)
shaped_batch["image_masks"] = jax.ShapeDtypeStruct(image_shape[:2], jnp.int32, sharding=batch_sharding)
if config.use_audio:
audio_shape = mm_processor.get_dummy_audio_shape_for_init(config)
shaped_batch["audios"] = jax.ShapeDtypeStruct(audio_shape, jnp.float32)
shaped_batch["audios"] = jax.ShapeDtypeStruct(audio_shape, jnp.float32, sharding=batch_sharding)
return shaped_batch


Expand Down
123 changes: 123 additions & 0 deletions tests/unit/compile_cache_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for JAX compilation cache hits in train.py.

This test ensures that the `train_step` function is compiled only once.
It verifies that the Ahead-Of-Time (AOT) compilation signature (which uses
a dummy `shaped_batch` constructed in `train.py`) matches the runtime
compilation signature (which uses the actual `example_batch` from the data pipeline).

If this test fails, it likely means a regression was introduced where the AOT
batch sharding/shape does not match the runtime batch sharding/shape. This causes
JAX to recompile `train_step` at step 0, leading to a "double compilation"
and a very slow first step.

To debug:
1. Verify that `maxtext_utils.get_shaped_batch` in `train.py` is called with the
correct `sharding` argument (matching the data pipeline sharding).
2. Check if there are differences in shapes or dtypes between the AOT dummy batch
and the runtime batch.
"""

import os
import tempfile
import shutil
import pytest
import subprocess
import sys

from tests.utils.test_helpers import (
get_test_config_path,
get_test_base_output_directory,
)


@pytest.mark.cpu_only
def test_train_step_cache_hit():
temp_dir = tempfile.mkdtemp()
_base_output_directory = get_test_base_output_directory()

try:
small_model_overrides = [
"base_emb_dim=16",
"base_num_query_heads=4",
"base_num_kv_heads=4",
"base_mlp_dim=16",
"base_num_decoder_layers=1",
"head_dim=64",
"max_target_length=64",
"vocab_size=32",
"sharding_tolerance=0.1",
]

cmd = [
sys.executable,
"-m",
"maxtext.trainers.pre_train.train",
get_test_config_path(),
f"base_output_directory={_base_output_directory}",
"run_name=compile_cache_test_cpu",
"steps=2",
"enable_checkpointing=False",
"enable_goodput_recording=False",
"dataset_type=synthetic",
"hardware=cpu",
"skip_jax_distributed_system=True",
f"jax_cache_dir={temp_dir}",
] + small_model_overrides

env = os.environ.copy()
env["JAX_PLATFORMS"] = "cpu"
env["JAX_ENABLE_COMPILATION_CACHE"] = "true"
env["JAX_COMPILATION_CACHE_DIR"] = temp_dir
env["JAX_LOG_COMPILES"] = "1"

print("Running CPU training subprocess:", " ".join(cmd))
result = subprocess.run(cmd, env=env, capture_output=True, text=True, check=True)

captured_logs = result.stderr

# Print captured logs for debugging (will be shown by pytest if assert fails)
print("=== Captured Subprocess Stderr ===")
print(captured_logs)
print("===================================")

# Check if cache dir has files
cache_files = os.listdir(temp_dir)
print("=== Cache Directory Content ===")
print(f"Path: {temp_dir}")
print(f"Files: {cache_files}")
print("===============================")

assert len(cache_files) > 0, (
"JAX compilation cache directory is empty. This suggests the compilation "
"cache was not writeable or the JAX cache configuration was ignored."
)

assert len(cache_files) == 1, (
f"Expected exactly 1 JAX compilation cache file, but found {len(cache_files)}: {cache_files}. "
"This indicates a cache miss where AOT compilation and runtime execution generated different keys, "
"causing train_step to be compiled twice (double-compilation regression)."
)

assert "Persistent compilation cache hit for 'jit_train_step'" in captured_logs, (
"Did not find 'Persistent compilation cache hit for 'jit_train_step'' in logs. "
"This means the runtime execution of train_step did not hit the cache populated by the AOT compilation. "
"Check if the AOT input batch signature (shape/dtype/sharding) matches the runtime input batch."
)

finally:
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
20 changes: 20 additions & 0 deletions tests/unit/maxtext_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,7 @@ def test_linen_in_shardings_includes_rng(self):
self.assertEqual(len(in_shardings), 3)


@pytest.mark.cpu_only
class TestGetShapedBatch(unittest.TestCase):
"""Tests for get_shaped_batch."""

Expand Down Expand Up @@ -1159,6 +1160,25 @@ def test_all_values_are_shape_dtype_struct(self):
for v in batch.values():
self.assertIsInstance(v, jax.ShapeDtypeStruct)

def test_get_shaped_batch_unsharded(self):
"""Verify that get_shaped_batch returns unsharded ShapeDtypeStructs by default."""
cfg = self._make_cfg()
shaped_batch = maxtext_utils.get_shaped_batch(cfg)
self.assertIn("inputs", shaped_batch)
self.assertIsNone(shaped_batch["inputs"].sharding)

def test_get_shaped_batch_sharded(self):
"""Verify that get_shaped_batch applies the passed sharding to ShapeDtypeStructs."""
cfg = self._make_cfg()
devices = np.array(jax.local_devices()[:1]).reshape(
1,
)
mesh = Mesh(devices, ("x",))
sharding_spec = NamedSharding(mesh, PartitionSpec("x"))
shaped_batch = maxtext_utils.get_shaped_batch(cfg, batch_sharding=sharding_spec)
self.assertIn("inputs", shaped_batch)
self.assertEqual(shaped_batch["inputs"].sharding, sharding_spec)


class TestShouldPreventCseInRemat(unittest.TestCase):
"""Tests for should_prevent_cse_in_remat."""
Expand Down
Loading