Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/tutorials/posttraining/knowledge_distillation.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \
The online distillation trainer depends on Tunix. The XPK launcher script ([`scripts/run_distill_xpk.sh`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh)) contains a `prep_image` step that layers Tunix on top of the MaxText base image. For local runs, install the same pin used by the launcher — the default `TUNIX_SOURCE` in `run_distill_xpk.sh` is the source of truth. As of this writing:

```bash
pip install "git+https://github.com/google/tunix@110932a8395086511228483312131841521695c1"
pip install "git+https://github.com/google/tunix@44af800726dd5b2c5779a1987a9294f9a3eec9ef"
```

> **Note:** The commit pin above will drift as the launcher is updated. Before installing, check the `TUNIX_SOURCE` default in [`run_distill_xpk.sh`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh) and use that spec. Once a Tunix PyPI release ships, this will become a versioned `google-tunix==<ver>` install.
Expand Down
2 changes: 1 addition & 1 deletion src/dependencies/extra_deps/post_train_github_deps.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
google-tunix @ https://github.com/google/tunix/archive/683256db1a0919b5cfd46cee52cebc96331494fb.zip
google-tunix @ https://github.com/google/tunix/archive/44af800726dd5b2c5779a1987a9294f9a3eec9ef.zip
tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/4d08971683a64fd796a1b9fd0bb71128188882d5.zip
vllm @ git+https://github.com/vllm-project/vllm@2131b597b18d051dced4c4a605d362fa37f46ed1
39 changes: 39 additions & 0 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""

import asyncio
import time
from typing import Any, Optional

Expand Down Expand Up @@ -168,6 +169,44 @@ class GrainCheckpointRestore(ocp.args.CheckpointArgs):
process_count: Optional[int] = None


class GrainCheckpointable(ocp_v1.StatefulCheckpointable):
"""Adapts `GrainCheckpointHandler` to Orbax v1's `StatefulCheckpointable`."""

def __init__(
self,
*,
save_args: GrainCheckpointSave | None = None,
restore_args: GrainCheckpointRestore | None = None,
):
self._handler = GrainCheckpointHandler()
self._save_args = save_args
self._restore_args = restore_args

async def save(self, directory):
"""Saves the Grain iterator state to the given directory."""
# `GrainCheckpointHandler.save` snapshots iterator state (`get_state`) AND
# writes it; both must happen in this (blocking) save phase, NOT in the
# returned background coroutine.
path = await directory.await_creation()
self._handler.save(path, args=self._save_args)

async def _committed(): # nothing left for the background commit phase
return None

return _committed()

async def load(self, directory: epath.Path):
"""Loads the Grain iterator state from the given directory."""
handler, args = self._handler, self._restore_args

# This will be ran to completion so asynchronous portion is just for
# compatibility with Orbax v1 API.
async def _background_load():
await asyncio.to_thread(handler.restore, directory, args=args)

return _background_load()


def _default_for_sds(sds):
"""Returns a deterministic value matching `sds` shape/dtype/sharding.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@

from maxtext.utils import max_logging
from maxtext.utils import maxtext_utils
# Reuse MaxText's native checkpointing logic.
from maxtext.common.checkpointing import GrainCheckpointHandler, GrainCheckpointSave, GrainCheckpointRestore
from maxtext.common import checkpointing
from tunix.sft import checkpoint_manager as tunix_checkpoint_manager
from tunix.sft import peft_trainer

Expand Down Expand Up @@ -647,8 +646,9 @@ def create_labels(self, targets, targets_segmentation=None, **kwargs):
class MaxTextCheckpointManager(tunix_checkpoint_manager.CheckpointManager):
"""Custom CheckpointManager that uses MaxText's native handlers.

This manager extends Tunix to support saving/restoring the MaxText input pipeline
(Grain) alongside the model and optimizer.
Model and optimizer are delegated to Tunix's v1 ``Checkpointer`` unchanged.
The Grain input pipeline is added as an extra ``"iter"`` checkpointable via
``GrainCheckpointable``, which wraps MaxText's ``GrainCheckpointHandler``.
"""

def __init__(
Expand All @@ -662,32 +662,6 @@ def __init__(
self.student_config = student_config
self._iterator = raw_iterator

# Re-initialize internal Orbax manager with MaxText's Grain handler
# pylint: disable=access-member-before-definition
# pytype: disable=attribute-error
if self._checkpoint_manager is not None:
root_directory = self._checkpoint_manager.directory

if options is None:
options = getattr(self._checkpoint_manager, "options", None)

item_handlers = {
"model_params": checkpoint.PyTreeCheckpointHandler(),
"optimizer_state": checkpoint.PyTreeCheckpointHandler(),
"custom_metadata": checkpoint.JsonCheckpointHandler(),
# Use MaxText's handler for the iterator
"iter": GrainCheckpointHandler(),
}

self._checkpoint_manager.close()
self._checkpoint_manager = checkpoint.CheckpointManager(
root_directory,
item_handlers=item_handlers,
options=options,
)
# pytype: enable=attribute-error
# pylint: enable=access-member-before-definition

def save(
self,
step,
Expand All @@ -697,10 +671,8 @@ def save(
force=False,
custom_metadata=None,
):
"""Saves the checkpoint including the input pipeline state (if available)."""
if self._checkpoint_manager is None:
return False
if not force and not self._checkpoint_manager.should_save(step):
"""Saves model, optimizer and the Grain input pipeline state."""
if self._checkpointer is None:
return False

# Standard Tunix Logic for Model/Optimizer.
Expand All @@ -711,21 +683,12 @@ def save(
else:
params = nnx.state(target_model)

# Define standard SaveArgs once to reuse
default_save_args = checkpoint.SaveArgs()
cp_save_args = {
"model_params": checkpoint.args.PyTreeSave(
item=params, save_args=jax.tree.map(lambda _: default_save_args, params)
),
}
# Exclude optimizer state if the flag is set OR if learn_to_init_mode is active.
checkpointables: dict[str, Any] = {"model_params": params}
# Exclude optimizer state when learn_to_init_mode is active.
exclude_opt = self.student_config.learn_to_init_mode

if optimizer is not None and not exclude_opt:
optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState)
cp_save_args["optimizer_state"] = checkpoint.args.PyTreeSave(
item=optimizer_state, save_args=jax.tree.map(lambda _: default_save_args, optimizer_state)
)
checkpointables["optimizer_state"] = nnx.state(optimizer, nnx.optimizer.OptState)

if self._iterator is not None:
# Follow MaxText's logic to handle multi-process saving
Expand All @@ -743,15 +706,11 @@ def save(
local_iter = data_iter.local_iterator if hasattr(data_iter, "local_iterator") else data_iter
grain_iters_to_save.append((local_iter, process_index, process_count_total))

# Use GrainCheckpointSave wrapper
cp_save_args["iter"] = GrainCheckpointSave(item=grain_iters_to_save)
checkpointables["iter"] = checkpointing.GrainCheckpointable(
save_args=checkpointing.GrainCheckpointSave(item=grain_iters_to_save)
)

return self._checkpoint_manager.save(
step,
args=checkpoint.args.Composite(**cp_save_args),
custom_metadata=custom_metadata or {},
force=force,
)
return self._save_checkpointables(step, checkpointables, force, custom_metadata)

def maybe_restore(
self,
Expand All @@ -766,12 +725,12 @@ def maybe_restore(
Returns:
(restored step, custom_metadata dict). Step is 0 if no checkpoint exists.
"""
if self._checkpoint_manager is None:
if self._checkpointer is None:
return 0, {}

target_model = getattr(model, "student_model", model)

step, _ = super().maybe_restore(
step, custom_metadata = super().maybe_restore(
model=target_model,
optimizer=optimizer,
restore_only_lora_params=restore_only_lora_params,
Expand All @@ -781,20 +740,14 @@ def maybe_restore(

max_logging.log(f"Restored from checkpoint step {step}.")

metadata = self._checkpoint_manager.metadata(step)
if metadata and hasattr(metadata, "custom_metadata") and metadata.custom_metadata is not None:
custom_metadata = metadata.custom_metadata
else:
custom_metadata = {}

return step, dict(custom_metadata)
return step, dict(custom_metadata or {})

def restore_iterator(self):
"""Restores the iterator using MaxText's logic."""
if self._checkpoint_manager is None or self._iterator is None:
if self._checkpointer is None or self._iterator is None:
return None

step = self._checkpoint_manager.latest_step()
step = self.latest_step()
if step is None:
return None

Expand All @@ -804,9 +757,10 @@ def restore_iterator(self):
data_iter = self._iterator
local_iter = data_iter.local_iterator if hasattr(data_iter, "local_iterator") else data_iter

restore_args = GrainCheckpointRestore(item=local_iter)

self._checkpoint_manager.restore(step, args=checkpoint.args.Composite(iter=restore_args))
self._checkpointer.load_checkpointables(
step,
{"iter": checkpointing.GrainCheckpointable(restore_args=checkpointing.GrainCheckpointRestore(item=local_iter))},
)
# Since Grain restores in-place via set_state(), we return the original object
return self._iterator

Expand All @@ -816,5 +770,5 @@ def restore_iterator(self):

def wait_until_finished(self):
"""Blocks until all outstanding checkpoint operations are complete."""
if self._checkpoint_manager is not None:
self._checkpoint_manager.wait_until_finished()
if self._checkpointer is not None:
self._checkpointer.wait()
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
#
# Image pinning (used by prep_image):
# TUNIX_SOURCE pip-installable spec for tunix.
# default: git+https://github.com/google/tunix@110932a8395086511228483312131841521695c1
# default: git+https://github.com/google/tunix@44af800726dd5b2c5779a1987a9294f9a3eec9ef
# Use "google-tunix==<ver>" once a pypi release ships with the
# multi-host shard_input fix.
# JAX_PIN default: 0.10.0 — version to pin back after tunix deps resolve.
Expand Down Expand Up @@ -164,7 +164,7 @@ require_env() {
: "${DISTILL_LAYER_INDICES:=[0,1,2,3,4,5,6,7]}"

# Image pinning (used by prep_image).
: "${TUNIX_SOURCE:=git+https://github.com/google/tunix@110932a8395086511228483312131841521695c1}"
: "${TUNIX_SOURCE:=git+https://github.com/google/tunix@44af800726dd5b2c5779a1987a9294f9a3eec9ef}"
: "${JAX_PIN:=0.10.0}"
: "${JAXLIB_PIN:=0.10.0}"
: "${LIBTPU_PIN:=0.0.39}"
Expand Down
Loading