Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions src/maxtext/checkpoint_conversion/to_maxtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,22 +319,20 @@ def get_maxtext_model_info(config):
# Get abstract model structure (name, shape) without materializing the weights to save memory
abstract_params_tree = maxtext_utils.get_abstract_param(maxtext_model_flax, config)["params"]

abstract_params_flat, _ = jax.tree_util.tree_flatten_with_path(abstract_params_tree)
# Standardize abstract tree for later unflattening
abstract_params_tree = jax.tree.map(
lambda _: 0,
abstract_params_tree,
is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned),
abstract_params_flat, abstract_params_treedef = jax.tree_util.tree_flatten_with_path(
abstract_params_tree, is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned)
)
abstract_params_treedef = jax.tree_util.tree_structure(abstract_params_tree)

max_logging.log("MaxText abstract model and state initialized.")

# preprocess state
maxtext_abstract_dict = {}
for mt_target_idx, (path_tuple, abstract_leaf_value) in enumerate(abstract_params_flat):
mt_param_key = "params-" + "-".join(param_key_parts_from_path(path_tuple))
mt_target_shape = abstract_leaf_value.shape
if isinstance(abstract_leaf_value, nn.LogicallyPartitioned):
mt_target_shape = abstract_leaf_value.value.shape
else:
mt_target_shape = abstract_leaf_value.shape
maxtext_abstract_dict[mt_param_key] = (mt_target_idx, mt_target_shape)

return maxtext_abstract_dict, abstract_params_treedef
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/checkpoint_conversion/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,7 @@ def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]:
for path_tuple, leaf_value in leaves_with_paths:
path_keys = param_key_parts_from_path(path_tuple)
# Skip NNX RNG state variables (not model weights)
if "to_nnx__rngs" in path_keys or any(k.endswith("_rngs") for k in path_keys):
if "to_nnx__rngs" in path_keys or any(k == "rngs" or k.endswith("_rngs") for k in path_keys):
continue
maxtext_param_key = "params-" + "-".join(path_keys)
if not isinstance(leaf_value, (jax.Array, np.ndarray)):
Expand Down
41 changes: 22 additions & 19 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""

import time
Expand Down Expand Up @@ -355,11 +356,16 @@ def combine_sharding(sds, shardings):
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
)
# NNX checkpoints are saved as pure dicts (see maybe_save_checkpoint); the
# restore target must match — a boxed nnx.State wouldn't.
restore_target = abstract_unboxed_pre_state
if isinstance(abstract_unboxed_pre_state, nnx.State):
restore_target = abstract_unboxed_pre_state.to_pure_dict()
# Provide sharding info to ensure restoration returns JAX arrays (not NumPy arrays).
restore_args = jax.tree_util.tree_map(
lambda x: ocp.type_handlers.ArrayRestoreArgs(sharding=x.sharding), abstract_unboxed_pre_state
lambda x: ocp.type_handlers.ArrayRestoreArgs(sharding=x.sharding), restore_target
)
return ocp.Checkpointer(handler).restore(p, abstract_unboxed_pre_state, restore_args=restore_args)
return ocp.Checkpointer(handler).restore(p, restore_target, restore_args=restore_args)


def create_orbax_checkpoint_manager(
Expand Down Expand Up @@ -838,9 +844,7 @@ def map_to_pspec(data):
(EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager),
):
checkpoint_path = str(checkpoint_manager.directory / str(step) / "items")
with handle_checkpoint_mismatch(
"restore NNX checkpoint", checkpoint_path
):
with handle_checkpoint_mismatch("restore NNX checkpoint", checkpoint_path):
restored_nnx = _load_linen_checkpoint_into_nnx(
checkpoint_path,
abstract_unboxed_pre_state,
Expand Down Expand Up @@ -876,9 +880,7 @@ def map_to_pspec(data):
EmergencyReplicatorCheckpointManager,
),
):
restored = checkpoint_manager.restore(
step, args=Composite(state=checkpoint_args)
).state
restored = checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state
_assert_no_shaped_dtype_struct(restored)
return (
restored,
Expand Down Expand Up @@ -906,9 +908,7 @@ def map_to_pspec(data):
# Case 3: Default/Fallback case.
# This case acts as a wildcard ('_') and matches if none of the preceding cases were met.
case _:
restored = checkpoint_manager.restore(
step, args=Composite(items=checkpoint_args)
)
restored = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args))
_assert_no_shaped_dtype_struct(restored)
return (restored, None)

Expand All @@ -918,9 +918,7 @@ def map_to_pspec(data):
else:
params = abstract_unboxed_pre_state.params

with handle_checkpoint_mismatch(
"load parameters", load_parameters_from_path
):
with handle_checkpoint_mismatch("load parameters", load_parameters_from_path):
restored_params = load_params_from_path(
load_parameters_from_path,
params,
Expand All @@ -932,9 +930,7 @@ def map_to_pspec(data):
return None, restored_params
elif load_full_state_from_path != "":
max_logging.log(f"Loading full state from path: {load_full_state_from_path}")
with handle_checkpoint_mismatch(
"load full state", load_full_state_from_path
):
with handle_checkpoint_mismatch("load full state", load_full_state_from_path):
restored_state = _load_full_state_from_path(
path=load_full_state_from_path,
abstract_unboxed_pre_state=abstract_unboxed_pre_state,
Expand Down Expand Up @@ -1034,7 +1030,8 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
actual_step = int(step)
else:
if config.pure_nnx:
actual_step = int(state.optimizer.step) - 1
# Under DiLoCo the step lives on the DiLoCoTrainState; otherwise on the optimizer.
actual_step = int(state.step if config.enable_diloco else state.optimizer.step) - 1
else:
# Linen TrainState has .step attribute
actual_step = int(state.step) - 1
Expand All @@ -1045,7 +1042,13 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step

if config.pure_nnx:
# Save in the Linen on-disk layout so pure_nnx and Linen checkpoints are interchangeable.
state = train_state_nnx.to_linen_checkpoint_dict(state.to_pure_dict())
if config.enable_diloco:
# DiLoCoTrainState: persist the synchronized global model (outer params).
# The per-replica inner optimizer / outer-momentum state is not checkpointed.
step_value = state.step.get_value() if hasattr(state.step, "get_value") else state.step
state = train_state_nnx.to_linen_checkpoint_dict({"model": state.params, "optimizer": {"step": step_value}})
else:
state = train_state_nnx.to_linen_checkpoint_dict(state.to_pure_dict())

# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
# This occurs if this function was called:
Expand Down
6 changes: 3 additions & 3 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1184,9 +1184,9 @@ position_id_per_seconds: 25
subslice_shape: ""

# NNX
enable_nnx: false
pure_nnx_decoder: false
pure_nnx: false
enable_nnx: true
pure_nnx_decoder: true
pure_nnx: true

################################## Qwen3-Next Specific Configs ##################################
# Kernel size for the 1D convolution in the Gated Delta Net
Expand Down
9 changes: 9 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2789,6 +2789,15 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de

self.using_pipeline_parallelism = self.ici_pipeline_parallelism > 1 or self.dcn_pipeline_parallelism > 1
if self.using_pipeline_parallelism:
if self.pure_nnx:
# The NNX decoder has no pipeline path yet, so the scanned-layers axis ends up
# sharded by 'stage' and fails with a cryptic IndivisibleError at state init.
# Fail fast with a clear message instead. NNX pipeline support is tracked as PR11.5.
raise NotImplementedError(
"Pipeline parallelism is not yet supported on the NNX path. Set "
"ici_pipeline_parallelism=1 and dcn_pipeline_parallelism=1, or use the Linen path "
"(pure_nnx=False enable_nnx=False)."
)
num_stages = int(self.ici_pipeline_parallelism * self.dcn_pipeline_parallelism)
if self.num_pipeline_repeats == -1:
num_pipeline_repeats, remainder = divmod(
Expand Down
26 changes: 10 additions & 16 deletions src/maxtext/layers/nnx_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from flax.core import FrozenDict
from flax.core import meta
from flax.nnx import graph
from flax.nnx import tracers as nnx_tracers
from flax.nnx import variablelib
from flax.nnx.bridge import module as bdg_module
from flax.nnx.module import Module
Expand Down Expand Up @@ -180,19 +179,6 @@ def is_linen_initializing() -> bool:
return False


def _refresh_variable_trace_state(module: Module) -> None:
  """Gives stale Variables in ``module`` a fresh trace state.

  When ``nnx.update`` receives JAX tracer values it runs with
  ``_unsafe_bypass_check=True``; the Variables it touches keep the
  ``_trace_state`` of the outer Python context, and a later ``nnx.split`` then
  fails with "Cannot extract graph node from different trace level".

  Every graph node that is a ``Variable`` and reports ``_can_update`` as False
  gets its ``_trace_state`` replaced by a new ``TraceState``.

  Args:
    module: The NNX module whose graph is walked; it is modified in place.
  """
  stale_variables = (
      node
      for _, node in nnx.graph.iter_graph(module)
      if isinstance(node, variablelib.Variable) and not node._can_update  # pylint: disable=protected-access
  )
  for variable in stale_variables:
    # Variable guards normal attribute assignment, so write through object directly.
    object.__setattr__(variable, "_trace_state", nnx_tracers.TraceState())


class ToNNX(Module):
"""A wrapper to turn any Linen module into an NNX module.

Expand Down Expand Up @@ -511,9 +497,17 @@ def maybe_unbox(x):

warnings.warn(f"Found unknown module paths in incoming state:{paths_str}")

# Filter out unknown paths so we don't try to assign them to static attributes
filtered_state_flat = {k: v for k, v in new_state_flat.items() if k not in unknown_state_flat}
new_state = nnx.State(nnx.traversals.unflatten_mapping(filtered_state_flat))

# Rebind the module to the current trace via split / update / merge.
# nnx.update directly on the live module can leave stale tracers.
graphdef, full_state = nnx.split(module)
nnx.update(full_state, new_state)
module = nnx.merge(graphdef, full_state)

_fix_for_qwix_quantization(module)
nnx.update(module, new_state)
_refresh_variable_trace_state(module)
method_fn = _get_module_method(module, nnx_method)
out = method_fn(module, *args, **kwargs)
self._update_variables(module)
Expand Down
59 changes: 43 additions & 16 deletions src/maxtext/trainers/diloco/diloco.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@
import drjax
from flax import nnx
from flax import struct
from flax.training import train_state
import jax
import jax.numpy as jnp
from jaxtyping import Array, Int32, Key, PyTree, UInt32
import optax

from maxtext.configs import pyconfig
from maxtext.common.train_state_nnx import TrainStateNNX

Batch = Any
Params = PyTree
Expand Down Expand Up @@ -157,8 +157,10 @@ def add_diloco_dim(x):
# For NNX, model params (Param variables only) live under abstract_state.model;
# for Linen under abstract_state.params.
if config.pure_nnx:
model_params = abstract_state.model.filter(nnx.Param)
model_params_sharding = state_mesh_shardings.model.filter(nnx.Param)
_, model_params, _ = nnx.split(abstract_state.model, nnx.Param, ...)
model_params = model_params.to_pure_dict()
_, model_params_sharding, _ = nnx.split(state_mesh_shardings.model, nnx.Param, ...)
model_params_sharding = model_params_sharding.to_pure_dict()
else:
model_params = abstract_state.params
model_params_sharding = state_mesh_shardings.params
Expand Down Expand Up @@ -216,7 +218,11 @@ def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]:
# Outer state retains a single copy of the model parameters and optimizer state.
# For NNX, model params (Param variables only) live under state.model;
# for Linen under state.params.
outer_params = state.model.filter(nnx.Param) if config.pure_nnx else state.params
if config.pure_nnx:
_, outer_params, _ = nnx.split(state.model, nnx.Param, ...)
outer_params = outer_params.to_pure_dict()
else:
outer_params = state.params
outer_opt_state = outer_optimizer.init(outer_params)
outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state)
# For NNX, the step counter lives at state.optimizer.step; for Linen at state.step.
Expand Down Expand Up @@ -258,9 +264,11 @@ def synchronize(state):
# state (since last synchronization).
broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh)
# For NNX, model Param vars live under inner_state.model; for Linen under inner_state.params.
inner_model_params = (
nnx.filter_state(state.inner_state.model, nnx.Param) if config.pure_nnx else state.inner_state.params
)
if config.pure_nnx:
_, inner_model_params, _ = nnx.split(state.inner_state.model, nnx.Param, ...)
inner_model_params = inner_model_params.to_pure_dict()
else:
inner_model_params = state.inner_state.params
model_delta = jax.tree.map(lambda x, y: y - x, inner_model_params, broadcast_outer_params)
# Treat the average delta as the outer optimizer's gradient and apply to
# the global (outer) model params.
Expand All @@ -273,15 +281,34 @@ def synchronize(state):
if config.pure_nnx:
# For NNX: merge new Param vars back with the non-Param model vars (e.g. RNG state).
def replace_nnx_model_params(s, new_params):
non_param_model = nnx.filter_state(s.model, nnx.Not(nnx.Param))
new_model = nnx.merge_state(non_param_model, new_params)
# Assign via __setitem__ so nested States are stored as plain dicts (matching
# nnx.state()'s pytree structure). The dict-literal constructor keeps them as
# State objects, which makes jax.lax.cond see mismatched pytree structures.
result = type(s)({})
result["model"] = new_model
result["optimizer"] = s["optimizer"]
return result
s_model = s["model"] if hasattr(s, "keys") else s.model
s_opt = s["optimizer"] if hasattr(s, "keys") else s.optimizer

graphdef, _, non_param_state = nnx.split(s_model, nnx.Param, ...)
new_model = nnx.merge(graphdef, new_params, non_param_state)

if type(s_model).__name__ == "State":
new_model = nnx.state(new_model)
elif isinstance(s_model, dict):
new_model = nnx.to_pure_dict(new_model)

if hasattr(s, "keys"):
# Replace only the leaves under "model", matched by path, and keep s's treedef.
# Selecting them by position (leaves[N:]) breaks if any key sorts before "model";
# rebuilding via type(s)({...}) breaks the lax.cond pytree-structure match, because
# the nnx.State constructor recursively wraps nested dicts as State objects.
leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(s)
new_model_iter = iter(jax.tree_util.tree_leaves(new_model))

def _is_model_leaf(path):
if not path:
return False
k = path[0]
return getattr(k, "key", None) == "model" or getattr(k, "name", None) == "model"

new_leaves = [next(new_model_iter) if _is_model_leaf(p) else leaf for p, leaf in leaves_with_paths]
return jax.tree_util.tree_unflatten(treedef, new_leaves)
else:
return TrainStateNNX(new_model, s_opt)

new_inner_state = drjax.map_fn(
lambda s: replace_nnx_model_params(s, new_outer_params),
Expand Down
27 changes: 22 additions & 5 deletions src/maxtext/trainers/post_train/sft/train_sft_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import tensorflow as tf
import jax

from flax import nnx
from flax.linen import partitioning as nn_partitioning

from maxtext.configs import pyconfig
Expand Down Expand Up @@ -75,13 +76,24 @@ def train_loop(config, recorder, state=None):

params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)

# NNX jits over the GraphDef + a flat nnx.State, so split the TrainStateNNX
# here (mirrors trainers/pre_train/train.py). Linen jits over the module.
if config.pure_nnx:
jit_model, state = nnx.split(state)
else:
jit_model = model

p_train_step, p_eval_step = train_utils.jit_train_and_eval_step(
config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator, params_shardings
config, jit_model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator, params_shardings
)

# Only the Linen step takes a dropout rng; pass it only there so the args
# match the jitted in_shardings (see get_functional_train_with_signature).
rng_args = () if config.pure_nnx else (init_rng,)

with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
shaped_batch = maxtext_utils.get_shaped_batch(config)
compiled = p_train_step.lower(state, shaped_batch, init_rng).compile()
compiled = p_train_step.lower(state, shaped_batch, *rng_args).compile()
compiled_stats = compiled.memory_analysis()
max_utils.print_compiled_memory_stats(compiled_stats)

Expand All @@ -91,7 +103,11 @@ def train_loop(config, recorder, state=None):
metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

# Write train config params, num model params, and XLA flags to tensorboard
metric_logger.write_setup_info_to_tensorboard(state.params)
if config.pure_nnx:
_, setup_params, _ = nnx.split(state.model, nnx.Param, ...)
else:
setup_params = state.params
metric_logger.write_setup_info_to_tensorboard(setup_params)

_job_completed_gracefully = False
try:
Expand All @@ -103,9 +119,10 @@ def train_loop(config, recorder, state=None):
example_batch = data_loader.load_next_batch()
# pylint: disable=not-callable
nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
step_rng_args = () if config.pure_nnx else (nextrng,)
with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
state, metrics = p_train_step(state, example_batch, nextrng)
state, metrics = p_train_step(state, example_batch, *step_rng_args)

step_time_delta = datetime.datetime.now() - last_step_completion

Expand Down Expand Up @@ -134,7 +151,7 @@ def train_loop(config, recorder, state=None):
if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
break
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, nextrng)
eval_metrics = p_eval_step(state, eval_batch, *step_rng_args)
eval_step_time_delta = datetime.datetime.now() - last_eval_step_completion
last_eval_step_completion = datetime.datetime.now()
metric_logger.buffer_and_write_metrics(
Expand Down
Loading
Loading