diff --git a/src/maxtext/trainers/post_train/sft/train_sft_native.py b/src/maxtext/trainers/post_train/sft/train_sft_native.py index 54596618ee..1291052497 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft_native.py +++ b/src/maxtext/trainers/post_train/sft/train_sft_native.py @@ -80,7 +80,8 @@ def train_loop(config, recorder, state=None): ) with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): - shaped_batch = maxtext_utils.get_shaped_batch(config) + data_sharding = sharding.get_input_data_sharding(config, mesh) + shaped_batch = maxtext_utils.get_shaped_batch(config, batch_sharding=data_sharding) compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index fd2cc7b56c..84539c045f 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -640,7 +640,8 @@ def train_loop(config, recorder, state=None): ) with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - shaped_batch = maxtext_utils.get_shaped_batch(config) + data_sharding = sharding.get_input_data_sharding(config, mesh) + shaped_batch = maxtext_utils.get_shaped_batch(config, batch_sharding=data_sharding) if config.shard_optimizer_over_data and isinstance(model, nn.Module): state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) elif config.shard_optimizer_over_data: diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index 3675fcb211..836c425f09 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -172,7 +172,8 @@ def create_train_state_fn(): logical_annotations = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn) # Shaped 
batch - shaped_batch = maxtext_utils.get_shaped_batch(config) + data_sharding = sharding.get_input_data_sharding(config, topology_mesh) + shaped_batch = maxtext_utils.get_shaped_batch(config, batch_sharding=data_sharding) if config.pure_nnx: shaped_train_args = ( diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 16b022c3a4..238758da92 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -148,7 +148,7 @@ def get_reorder_callable(cp_size, shard_mode, reorder_strategy=ReorderStrategy.D ) -def get_shaped_batch(config): +def get_shaped_batch(config, batch_sharding=None): """Return the shape of the batch - this is what eval_shape would return for the output of create_data_iterator, but eval_shape doesn't work, see b/306901078.""" if config.enable_diloco: @@ -160,21 +160,21 @@ def get_shaped_batch(config): else: batch_shape = (config.global_batch_size_to_load, config.max_target_length) shaped_batch = {} - shaped_batch["inputs"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch["inputs_position"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch["inputs_segmentation"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch["targets"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch["targets_position"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch["targets_segmentation"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["inputs"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32, sharding=batch_sharding) + shaped_batch["inputs_position"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32, sharding=batch_sharding) + shaped_batch["inputs_segmentation"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32, sharding=batch_sharding) + shaped_batch["targets"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32, sharding=batch_sharding) + shaped_batch["targets_position"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32, sharding=batch_sharding) + 
shaped_batch["targets_segmentation"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32, sharding=batch_sharding) if config.use_multimodal: image_shape = mm_processor.get_dummy_image_shape_for_init( config.model_name, batch_size=config.micro_batch_size_to_train_on ) - shaped_batch["images"] = jax.ShapeDtypeStruct(image_shape, jnp.int32) - shaped_batch["image_masks"] = jax.ShapeDtypeStruct(image_shape[:2], jnp.int32) + shaped_batch["images"] = jax.ShapeDtypeStruct(image_shape, jnp.int32, sharding=batch_sharding) + shaped_batch["image_masks"] = jax.ShapeDtypeStruct(image_shape[:2], jnp.int32, sharding=batch_sharding) if config.use_audio: audio_shape = mm_processor.get_dummy_audio_shape_for_init(config) - shaped_batch["audios"] = jax.ShapeDtypeStruct(audio_shape, jnp.float32) + shaped_batch["audios"] = jax.ShapeDtypeStruct(audio_shape, jnp.float32, sharding=batch_sharding) return shaped_batch diff --git a/tests/unit/compile_cache_test.py b/tests/unit/compile_cache_test.py new file mode 100644 index 0000000000..b94cd0dec6 --- /dev/null +++ b/tests/unit/compile_cache_test.py @@ -0,0 +1,123 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for JAX compilation cache hits in train.py. + +This test ensures that the `train_step` function is compiled only once. 
+It verifies that the Ahead-Of-Time (AOT) compilation signature (which uses +a dummy `shaped_batch` constructed in `train.py`) matches the runtime +compilation signature (which uses the actual `example_batch` from the data pipeline). + +If this test fails, it likely means a regression was introduced where the AOT +batch sharding/shape does not match the runtime batch sharding/shape. This causes +JAX to recompile `train_step` at step 0, leading to a "double compilation" +and a very slow first step. + +To debug: +1. Verify that `maxtext_utils.get_shaped_batch` in `train.py` is called with the + correct `batch_sharding` argument (matching the data pipeline sharding). +2. Check if there are differences in shapes or dtypes between the AOT dummy batch + and the runtime batch. +""" + +import os +import tempfile +import shutil +import pytest +import subprocess +import sys + +from tests.utils.test_helpers import ( + get_test_config_path, + get_test_base_output_directory, +) + + +@pytest.mark.cpu_only +def test_train_step_cache_hit(): + temp_dir = tempfile.mkdtemp() + _base_output_directory = get_test_base_output_directory() + + try: + small_model_overrides = [ + "base_emb_dim=16", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=16", + "base_num_decoder_layers=1", + "head_dim=64", + "max_target_length=64", + "vocab_size=32", + "sharding_tolerance=0.1", + ] + + cmd = [ + sys.executable, + "-m", + "maxtext.trainers.pre_train.train", + get_test_config_path(), + f"base_output_directory={_base_output_directory}", + "run_name=compile_cache_test_cpu", + "steps=2", + "enable_checkpointing=False", + "enable_goodput_recording=False", + "dataset_type=synthetic", + "hardware=cpu", + "skip_jax_distributed_system=True", + f"jax_cache_dir={temp_dir}", + ] + small_model_overrides + + env = os.environ.copy() + env["JAX_PLATFORMS"] = "cpu" + env["JAX_ENABLE_COMPILATION_CACHE"] = "true" + env["JAX_COMPILATION_CACHE_DIR"] = temp_dir + env["JAX_LOG_COMPILES"] = "1" +
print("Running CPU training subprocess:", " ".join(cmd)) + result = subprocess.run(cmd, env=env, capture_output=True, text=True, check=True) + + captured_logs = result.stderr + + # Print captured logs for debugging (will be shown by pytest if assert fails) + print("=== Captured Subprocess Stderr ===") + print(captured_logs) + print("===================================") + + # Check if cache dir has files + cache_files = os.listdir(temp_dir) + print("=== Cache Directory Content ===") + print(f"Path: {temp_dir}") + print(f"Files: {cache_files}") + print("===============================") + + assert len(cache_files) > 0, ( + "JAX compilation cache directory is empty. This suggests the compilation " + "cache was not writeable or the JAX cache configuration was ignored." + ) + + assert len(cache_files) == 1, ( + f"Expected exactly 1 JAX compilation cache file, but found {len(cache_files)}: {cache_files}. " + "This indicates a cache miss where AOT compilation and runtime execution generated different keys, " + "causing train_step to be compiled twice (double-compilation regression)." + ) + + assert "Persistent compilation cache hit for 'jit_train_step'" in captured_logs, ( + "Did not find 'Persistent compilation cache hit for 'jit_train_step'' in logs. " + "This means the runtime execution of train_step did not hit the cache populated by the AOT compilation. " + "Check if the AOT input batch signature (shape/dtype/sharding) matches the runtime input batch." 
+ ) + + finally: + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index 3d4e983281..2e90880a83 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -1104,6 +1104,7 @@ def test_linen_in_shardings_includes_rng(self): self.assertEqual(len(in_shardings), 3) +@pytest.mark.cpu_only class TestGetShapedBatch(unittest.TestCase): """Tests for get_shaped_batch.""" @@ -1159,6 +1160,25 @@ def test_all_values_are_shape_dtype_struct(self): for v in batch.values(): self.assertIsInstance(v, jax.ShapeDtypeStruct) + def test_get_shaped_batch_unsharded(self): + """Verify that get_shaped_batch returns unsharded ShapeDtypeStructs by default.""" + cfg = self._make_cfg() + shaped_batch = maxtext_utils.get_shaped_batch(cfg) + self.assertIn("inputs", shaped_batch) + self.assertIsNone(shaped_batch["inputs"].sharding) + + def test_get_shaped_batch_sharded(self): + """Verify that get_shaped_batch applies the passed sharding to ShapeDtypeStructs.""" + cfg = self._make_cfg() + devices = np.array(jax.local_devices()[:1]).reshape( + 1, + ) + mesh = Mesh(devices, ("x",)) + sharding_spec = NamedSharding(mesh, PartitionSpec("x")) + shaped_batch = maxtext_utils.get_shaped_batch(cfg, batch_sharding=sharding_spec) + self.assertIn("inputs", shaped_batch) + self.assertEqual(shaped_batch["inputs"].sharding, sharding_spec) + class TestShouldPreventCseInRemat(unittest.TestCase): """Tests for should_prevent_cse_in_remat."""