diff --git a/docs/tutorials/posttraining/knowledge_distillation.md b/docs/tutorials/posttraining/knowledge_distillation.md index 2f759d75e6..bfadefebc2 100644 --- a/docs/tutorials/posttraining/knowledge_distillation.md +++ b/docs/tutorials/posttraining/knowledge_distillation.md @@ -234,7 +234,7 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \ The online distillation trainer depends on Tunix. The XPK launcher script ([`scripts/run_distill_xpk.sh`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh)) contains a `prep_image` step that layers Tunix on top of the MaxText base image. For local runs, install the same pin used by the launcher — the default `TUNIX_SOURCE` in `run_distill_xpk.sh` is the source of truth. As of this writing: ```bash -pip install "git+https://github.com/google/tunix@110932a8395086511228483312131841521695c1" +pip install "git+https://github.com/google/tunix@44af800726dd5b2c5779a1987a9294f9a3eec9ef" ``` > **Note:** The commit pin above will drift as the launcher is updated. Before installing, check the `TUNIX_SOURCE` default in [`run_distill_xpk.sh`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh) and use that spec. Once a Tunix PyPI release ships, this will become a versioned `google-tunix==` install. 
diff --git a/src/dependencies/extra_deps/post_train_github_deps.txt b/src/dependencies/extra_deps/post_train_github_deps.txt index 02730e98e7..bcf5bae01a 100644 --- a/src/dependencies/extra_deps/post_train_github_deps.txt +++ b/src/dependencies/extra_deps/post_train_github_deps.txt @@ -1,3 +1,3 @@ -google-tunix @ https://github.com/google/tunix/archive/683256db1a0919b5cfd46cee52cebc96331494fb.zip +google-tunix @ https://github.com/google/tunix/archive/44af800726dd5b2c5779a1987a9294f9a3eec9ef.zip tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/4d08971683a64fd796a1b9fd0bb71128188882d5.zip vllm @ git+https://github.com/vllm-project/vllm@2131b597b18d051dced4c4a605d362fa37f46ed1 diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 73f475bb39..53825a0f0f 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -14,6 +14,7 @@ """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer.""" +import asyncio import time from typing import Any, Optional @@ -168,6 +169,44 @@ class GrainCheckpointRestore(ocp.args.CheckpointArgs): process_count: Optional[int] = None +class GrainCheckpointable(ocp_v1.StatefulCheckpointable): + """Adapts `GrainCheckpointHandler` to Orbax v1's `StatefulCheckpointable`.""" + + def __init__( + self, + *, + save_args: GrainCheckpointSave | None = None, + restore_args: GrainCheckpointRestore | None = None, + ): + self._handler = GrainCheckpointHandler() + self._save_args = save_args + self._restore_args = restore_args + + async def save(self, directory): + """Saves the Grain iterator state to the given directory.""" + # `GrainCheckpointHandler.save` snapshots iterator state (`get_state`) AND + # writes it; both must happen in this (blocking) save phase, NOT in the + # returned background coroutine. 
+ path = await directory.await_creation() + self._handler.save(path, args=self._save_args) + + async def _committed(): # nothing left for the background commit phase + return None + + return _committed() + + async def load(self, directory: epath.Path): + """Loads the Grain iterator state from the given directory.""" + handler, args = self._handler, self._restore_args + + # This will be run to completion, so the asynchronous portion is just for + # compatibility with the Orbax v1 API. + async def _background_load(): + await asyncio.to_thread(handler.restore, directory, args=args) + + return _background_load() + + def _default_for_sds(sds): """Returns a deterministic value matching `sds` shape/dtype/sharding. diff --git a/src/maxtext/trainers/post_train/distillation/distillation_utils.py b/src/maxtext/trainers/post_train/distillation/distillation_utils.py index f063cdb23a..659b2eb66d 100644 --- a/src/maxtext/trainers/post_train/distillation/distillation_utils.py +++ b/src/maxtext/trainers/post_train/distillation/distillation_utils.py @@ -31,8 +31,7 @@ from maxtext.utils import max_logging from maxtext.utils import maxtext_utils -# Reuse MaxText's native checkpointing logic. -from maxtext.common.checkpointing import GrainCheckpointHandler, GrainCheckpointSave, GrainCheckpointRestore +from maxtext.common import checkpointing from tunix.sft import checkpoint_manager as tunix_checkpoint_manager from tunix.sft import peft_trainer @@ -647,8 +646,9 @@ def create_labels(self, targets, targets_segmentation=None, **kwargs): class MaxTextCheckpointManager(tunix_checkpoint_manager.CheckpointManager): """Custom CheckpointManager that uses MaxText's native handlers. - This manager extends Tunix to support saving/restoring the MaxText input pipeline - (Grain) alongside the model and optimizer. + Model and optimizer are delegated to Tunix's v1 ``Checkpointer`` unchanged. 
+ The Grain input pipeline is added as an extra ``"iter"`` checkpointable via + ``GrainCheckpointable``, which wraps MaxText's ``GrainCheckpointHandler``. """ def __init__( @@ -662,32 +662,6 @@ def __init__( self.student_config = student_config self._iterator = raw_iterator - # Re-initialize internal Orbax manager with MaxText's Grain handler - # pylint: disable=access-member-before-definition - # pytype: disable=attribute-error - if self._checkpoint_manager is not None: - root_directory = self._checkpoint_manager.directory - - if options is None: - options = getattr(self._checkpoint_manager, "options", None) - - item_handlers = { - "model_params": checkpoint.PyTreeCheckpointHandler(), - "optimizer_state": checkpoint.PyTreeCheckpointHandler(), - "custom_metadata": checkpoint.JsonCheckpointHandler(), - # Use MaxText's handler for the iterator - "iter": GrainCheckpointHandler(), - } - - self._checkpoint_manager.close() - self._checkpoint_manager = checkpoint.CheckpointManager( - root_directory, - item_handlers=item_handlers, - options=options, - ) - # pytype: enable=attribute-error - # pylint: enable=access-member-before-definition - def save( self, step, @@ -697,10 +671,8 @@ def save( force=False, custom_metadata=None, ): - """Saves the checkpoint including the input pipeline state (if available).""" - if self._checkpoint_manager is None: - return False - if not force and not self._checkpoint_manager.should_save(step): + """Saves model, optimizer and the Grain input pipeline state.""" + if self._checkpointer is None: return False # Standard Tunix Logic for Model/Optimizer. @@ -711,21 +683,12 @@ def save( else: params = nnx.state(target_model) - # Define standard SaveArgs once to reuse - default_save_args = checkpoint.SaveArgs() - cp_save_args = { - "model_params": checkpoint.args.PyTreeSave( - item=params, save_args=jax.tree.map(lambda _: default_save_args, params) - ), - } - # Exclude optimizer state if the flag is set OR if learn_to_init_mode is active. 
+ checkpointables: dict[str, Any] = {"model_params": params} + # Exclude optimizer state when learn_to_init_mode is active. exclude_opt = self.student_config.learn_to_init_mode if optimizer is not None and not exclude_opt: - optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState) - cp_save_args["optimizer_state"] = checkpoint.args.PyTreeSave( - item=optimizer_state, save_args=jax.tree.map(lambda _: default_save_args, optimizer_state) - ) + checkpointables["optimizer_state"] = nnx.state(optimizer, nnx.optimizer.OptState) if self._iterator is not None: # Follow MaxText's logic to handle multi-process saving @@ -743,15 +706,11 @@ def save( local_iter = data_iter.local_iterator if hasattr(data_iter, "local_iterator") else data_iter grain_iters_to_save.append((local_iter, process_index, process_count_total)) - # Use GrainCheckpointSave wrapper - cp_save_args["iter"] = GrainCheckpointSave(item=grain_iters_to_save) + checkpointables["iter"] = checkpointing.GrainCheckpointable( + save_args=checkpointing.GrainCheckpointSave(item=grain_iters_to_save) + ) - return self._checkpoint_manager.save( - step, - args=checkpoint.args.Composite(**cp_save_args), - custom_metadata=custom_metadata or {}, - force=force, - ) + return self._save_checkpointables(step, checkpointables, force, custom_metadata) def maybe_restore( self, @@ -766,12 +725,12 @@ def maybe_restore( Returns: (restored step, custom_metadata dict). Step is 0 if no checkpoint exists. 
""" - if self._checkpoint_manager is None: + if self._checkpointer is None: return 0, {} target_model = getattr(model, "student_model", model) - step, _ = super().maybe_restore( + step, custom_metadata = super().maybe_restore( model=target_model, optimizer=optimizer, restore_only_lora_params=restore_only_lora_params, @@ -781,20 +740,14 @@ def maybe_restore( max_logging.log(f"Restored from checkpoint step {step}.") - metadata = self._checkpoint_manager.metadata(step) - if metadata and hasattr(metadata, "custom_metadata") and metadata.custom_metadata is not None: - custom_metadata = metadata.custom_metadata - else: - custom_metadata = {} - - return step, dict(custom_metadata) + return step, dict(custom_metadata or {}) def restore_iterator(self): """Restores the iterator using MaxText's logic.""" - if self._checkpoint_manager is None or self._iterator is None: + if self._checkpointer is None or self._iterator is None: return None - step = self._checkpoint_manager.latest_step() + step = self.latest_step() if step is None: return None @@ -804,9 +757,10 @@ def restore_iterator(self): data_iter = self._iterator local_iter = data_iter.local_iterator if hasattr(data_iter, "local_iterator") else data_iter - restore_args = GrainCheckpointRestore(item=local_iter) - - self._checkpoint_manager.restore(step, args=checkpoint.args.Composite(iter=restore_args)) + self._checkpointer.load_checkpointables( + step, + {"iter": checkpointing.GrainCheckpointable(restore_args=checkpointing.GrainCheckpointRestore(item=local_iter))}, + ) # Since Grain restores in-place via set_state(), we return the original object return self._iterator @@ -816,5 +770,5 @@ def restore_iterator(self): def wait_until_finished(self): """Blocks until all outstanding checkpoint operations are complete.""" - if self._checkpoint_manager is not None: - self._checkpoint_manager.wait_until_finished() + if self._checkpointer is not None: + self._checkpointer.wait() diff --git 
a/src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh b/src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh index 91078508d9..57487abb9f 100644 --- a/src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh +++ b/src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh @@ -105,7 +105,7 @@ # # Image pinning (used by prep_image): # TUNIX_SOURCE pip-installable spec for tunix. -# default: git+https://github.com/google/tunix@110932a8395086511228483312131841521695c1 +# default: git+https://github.com/google/tunix@44af800726dd5b2c5779a1987a9294f9a3eec9ef # Use "google-tunix==" once a pypi release ships with the # multi-host shard_input fix. # JAX_PIN default: 0.10.0 — version to pin back after tunix deps resolve. @@ -164,7 +164,7 @@ require_env() { : "${DISTILL_LAYER_INDICES:=[0,1,2,3,4,5,6,7]}" # Image pinning (used by prep_image). -: "${TUNIX_SOURCE:=git+https://github.com/google/tunix@110932a8395086511228483312131841521695c1}" +: "${TUNIX_SOURCE:=git+https://github.com/google/tunix@44af800726dd5b2c5779a1987a9294f9a3eec9ef}" : "${JAX_PIN:=0.10.0}" : "${JAXLIB_PIN:=0.10.0}" : "${LIBTPU_PIN:=0.0.39}"