From 6deaacce46600988d34b4015f98dcdf936a2a665 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Fri, 12 Jun 2026 19:59:54 +0000 Subject: [PATCH 1/2] feat(models): Integrate DeepSeek V4 architecture and routing This commit introduces full support for DeepSeek V4 by integrating its compressed attention mechanisms, MoE routing, and architectural layers. Key changes: - Add `deepseek4.yml` configuration and `DeepSeek4DecoderLayer` implementation. - Implement hybrid Hash Routing and Token Routing for MoE layers. - Add prefix/suffix layer unrolling for non-uniform compression blocks. - Fix Pydantic validation for base MLP dimensions. - Bypass MLA instantiation in favor of native CompressedAttention (CSA/HCA). --- src/maxtext/common/common_types.py | 1 + src/maxtext/configs/models/deepseek4.yml | 63 +++++ src/maxtext/configs/types.py | 2 +- src/maxtext/layers/attention_compressed.py | 7 +- src/maxtext/layers/attentions.py | 8 + src/maxtext/layers/decoders.py | 130 +++++++++- src/maxtext/layers/moe.py | 19 +- src/maxtext/models/deepseek.py | 66 ++--- src/maxtext/models/deepseek4.py | 255 ++++++++++++++++++++ tests/unit/deepseek_v4_vs_reference_test.py | 9 +- 10 files changed, 515 insertions(+), 45 deletions(-) create mode 100644 src/maxtext/configs/models/deepseek4.yml create mode 100644 src/maxtext/models/deepseek4.py diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index d4b52207fc..71dbc105d4 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -113,6 +113,7 @@ class DecoderBlockType(enum.Enum): SIMPLE_MLP = "simple_mlp" LLAMA4 = "llama4" OLMO3 = "olmo3" + DEEPSEEK4 = "deepseek4" class AttentionType(enum.Enum): diff --git a/src/maxtext/configs/models/deepseek4.yml b/src/maxtext/configs/models/deepseek4.yml new file mode 100644 index 0000000000..521e8c4871 --- /dev/null +++ b/src/maxtext/configs/models/deepseek4.yml @@ -0,0 +1,63 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# model config for DeepSeek V4 + +base_emb_dim: 4096 +base_num_query_heads: 64 +base_num_kv_heads: 1 +base_num_decoder_layers: 43 +base_mlp_dim: 2048 +base_moe_mlp_dim: 2048 +vocab_size: 129280 + +# --- Standard Defaults --- +enable_dropout: false +logits_via_embedding: false +normalization_layer_epsilon: 1.0e-6 + +# --- V4 Specific Architectural Keys --- +decoder_block: "deepseek4" +mhc_expansion_rate: 4 +first_num_hash_layers: 3 +indexer_head_dim: 128 +indexer_n_heads: 64 +indexer_topk: 512 + +# Note: Layers (0,1) are not compressed. +# The 44th layer (MTP module with compress_ratio=0) has been explicitly dropped for now. +# This leaves exactly 43 layers: 2 prefix [0,0] + 40 scanned + 1 suffix [4]. +compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4] + +# --- MoE configuration --- +mlp_activations: ["silu", "linear"] +num_experts: 256 +num_experts_per_tok: 6 +shared_experts: 1 +routed_score_func: "sqrtsoftplus" + +# --- Attention configuration --- +attention: 'dot_product' +attention_type: 'compressed' +q_lora_rank: 1024 +o_groups: 8 +o_lora_rank: 1024 +sliding_window_size: 128 + +# --- RoPE --- + +rope_type: "default" +rope_max_timescale: 10000 # Main RoPE theta +compressed_rope_max_timescale: 160000 # Compressed RoPE theta +max_position_embeddings: 1048576 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index b70b7238d3..e6cbdf666f 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -553,7 +553,7 @@ class Attention(BaseModel): "autoselected", description="The attention algorithm to use (dot_product, flash, etc).", ) - attention_type: Literal["global", "local_sliding", "chunk", "mla", "full"] = Field( + attention_type: Literal["global", "local_sliding", "chunk", "mla", "full", "compressed"] = Field( "global", description="The variant of attention to use." ) share_kv_projections: bool = Field( diff --git a/src/maxtext/layers/attention_compressed.py b/src/maxtext/layers/attention_compressed.py index e9a25f46b5..2b84915beb 100644 --- a/src/maxtext/layers/attention_compressed.py +++ b/src/maxtext/layers/attention_compressed.py @@ -694,10 +694,6 @@ def __init__( self.q_lora_rank = q_lora_rank self.compress_ratio = compress_ratio - # Determine the correct underlying attention type based on the compress_ratio - if self.compress_ratio == 0: - attention_type = AttentionType.LOCAL_SLIDING - super().__init__( config=config, num_query_heads=num_query_heads, @@ -727,6 +723,7 @@ def __init__( use_bias_in_projections=use_bias_in_projections, name=name, rngs=rngs, + rope_type="deepseek4", **kwargs, ) @@ -1047,7 +1044,7 @@ def __call__( # -> [batch, q_length, emb_dim] final_out = self.o_b_proj(grouped_flat) - return final_out + return final_out, None def compressed_attention( diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 93c54e25a6..f0c7c90dbb 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -58,6 +58,7 @@ Qwen3OmniMoeVisionRotaryEmbedding, RotaryEmbedding, YarnRotaryEmbedding, + DeepSeekV4RotaryEmbedding, PartialRotaryEmbedding, Gemma4PartialRotaryEmbedding, ) @@ -850,6 +851,13 @@ def init_rotary_embedding(self): shard_mode=self.config.shard_mode, rngs=self.rngs, ) + elif rope_type == "deepseek4": + rotary_embedding = DeepSeekV4RotaryEmbedding( + head_dim=rope_embedding_dims, + partial_rotary_factor=self.partial_rotary_factor if self.partial_rotary_factor is not None else 1.0, + rope_theta=self.rope_max_timescale, + dtype=self.dtype, + ) elif self.is_qwen3_hybrid: rotary_embedding = PartialRotaryEmbedding( min_timescale=self.config.rope_min_timescale, diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 5429169f8a..91d5f93fdb 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -40,6 +40,7 @@ from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.models import ( deepseek, + deepseek4, deepseek_batchsplit, deepseek_batchsplit_fp8, gemma, @@ -457,6 +458,10 @@ def get_decoder_layers(self): deepseek.DeepSeekDenseLayerToLinen, deepseek.DeepSeekMoELayerToLinen, ] + case DecoderBlockType.DEEPSEEK4: + return ( + [deepseek4.DeepSeek4ScannableBlockToLinen] if self.config.scan_layers else [deepseek4.DeepSeek4LayerToLinen] + ) case DecoderBlockType.GEMMA: return [gemma.GemmaDecoderLayerToLinen] case DecoderBlockType.GEMMA2: @@ -534,6 +539,7 @@ def get_norm_layer(self, num_features: int): DecoderBlockType.MISTRAL, DecoderBlockType.MIXTRAL, DecoderBlockType.DEEPSEEK, + DecoderBlockType.DEEPSEEK4, DecoderBlockType.GEMMA, DecoderBlockType.GEMMA2, DecoderBlockType.GEMMA3, @@ -999,6 +1005,17 @@ def __call__( previous_chunk, slot, ) + elif cfg.decoder_block == DecoderBlockType.DEEPSEEK4: + y = self._apply_deepseek4_scanned_blocks( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + slot, + decoder_input_tokens, + ) else: RemattedBlockLayer = RemattedBlockLayers[0] scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) @@ -1133,7 +1150,7 @@ def __call__( "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval), "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step), } - if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): + if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5, DecoderBlockType.DEEPSEEK4): layer_kwargs = {"layer_idx": lyr} kv_cache = None if kv_caches is not None: @@ -1355,6 +1372,117 @@ def _apply_gemma4_scanned_blocks( return y + def _apply_deepseek4_scanned_blocks( + self, + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + slot, + decoder_input_tokens, + ): + """Applies DeepSeek V4 scanned decoder blocks. + + DeepSeek V4 natively has 44 layers, but we explicitly drop the final MTP module (Layer 43) + for now, strictly evaluating the 43 standard hidden layers. + + The layout perfectly maps to the 43 remaining elements in the config array: + - 2 Prefix layers (Indices 0, 1): compress_ratio = [0, 0] + - 40 Scanned layers (Indices 2 to 41): 20 perfectly repeating chunks of [4, 128] + - 1 Suffix layer (Index 42): compress_ratio = [4] + + Total evaluated layers: 2 + 40 + 1 = 43 layers. + """ + + cfg = self.config + mesh = self.mesh + + broadcast_args = ( + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + slot, + previous_chunk, + ) + + layer_call_kwargs = { + "previous_chunk": previous_chunk, + "slot": slot, + "decoder_input_tokens": decoder_input_tokens, + } + + # 1. Prefix Unrolling (Layers 0, 1) + # These layers use Hash Routing and compress_ratio=0. + for layer_idx in range(2): + prefix_layer = deepseek4.DeepSeek4LayerToLinen( + config=cfg, + mesh=mesh, + name=f"layers_{layer_idx}", + quant=self.quant, + model_mode=self.model_mode, + layer_idx=layer_idx, + ) + y, _ = prefix_layer( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + **layer_call_kwargs, + ) + + # 2. Chunked Scanning (Layers 2 to 41) + # These 40 layers perfectly alternate CSA (4) and HCA (128). + num_full_blocks = 20 + if num_full_blocks > 0: + ScannableBlockToLinen = deepseek4.DeepSeek4ScannableBlockToLinen + policy = self.get_remat_policy() + RemattedDeepSeek4Block = self.set_remat_policy([ScannableBlockToLinen], policy)[0] + + y, _ = nn.scan( + RemattedDeepSeek4Block, + variable_axes={ + "params": cfg.param_scan_axis, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, + split_rngs={"params": True, "dropout": cfg.enable_dropout}, + in_axes=(nn.broadcast,) * len(broadcast_args), + length=num_full_blocks, + metadata_params={ + nn.PARTITION_NAME: "layers", + "abstract_init": False, + }, + )(config=cfg, mesh=mesh, quant=self.quant, model_mode=model_mode, name="scanned_blocks",)(y, *broadcast_args) + + # 3. Suffix Unrolling (Layer 42) + # Layer 42 is the final CSA block (4). + # The 44th MTP layer (compress_ratio=0) is dropped/excluded. + for layer_idx in range(42, 43): + suffix_layer = deepseek4.DeepSeek4LayerToLinen( + config=cfg, + mesh=mesh, + name=f"layers_{layer_idx}", + quant=self.quant, + model_mode=self.model_mode, + layer_idx=layer_idx, + ) + y, _ = suffix_layer( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + **layer_call_kwargs, + ) + + return y + def _apply_gemma4_small_layers( self, y, diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index df7ba653c9..0d27a1ec8f 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -208,6 +208,10 @@ def calculate_load_balance_updates(top_k_indices, num_experts, rate): return output +class Tid2EidVar(nnx.Variable): + """Custom variable to hold tid2eid without trainable param overhead.""" + + class GateLogit(nnx.Module): """A layer used to compute gate logits, allowing to return the pre bias values for DeepSeek routing.""" @@ -399,8 +403,11 @@ def __init__( # DeepSeek V4 Hash Routing if self.is_hash_routing: # Token-ID to Expert-ID lookup table for static routing - self.tid2eid = nnx.Variable( - jnp.zeros((self.config.vocab_size, self.num_experts_per_tok), dtype=jnp.int32), + # Must be stored as float32 because MaxText passes the entire variable tree + # through jax.value_and_grad, which strictly requires all leaves to be inexact types + # (even if they receive no gradients). We cast to int32 dynamically during routing. + self.tid2eid = Tid2EidVar( + jnp.zeros((self.config.vocab_size, self.num_experts_per_tok), dtype=jnp.float32), out_sharding=None, # Replicated across shards for local lookup ) else: @@ -665,7 +672,13 @@ def get_topk(self, gate_logits, pre_bias_logits, rngs=None, input_ids=None): return top_k_weights, top_k_indices if self.is_hash_routing: - top_k_indices = self.tid2eid[input_ids] + if input_ids is None: + raise ValueError("input_ids cannot be None when is_hash_routing is True") + # Access the static routing table + tid2eid_int = self.tid2eid.value + # Cast the float32 array to int32 (JAX automatically assigns 0.0 gradients to integer casts) + tid2eid_int = tid2eid_int.astype(jnp.int32) + top_k_indices = tid2eid_int[input_ids] top_k_weights = jnp.take_along_axis(pre_bias_logits, top_k_indices, axis=-1) # NOTE: deepseek2 has a different pattern elif self.config.model_name.startswith(("deepseek3", "deepseek4")): diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index 27e1a6f7ad..a860ee1c74 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -138,37 +138,39 @@ def __init__( self.engram_layer_norm = None self.engram = None - self.self_attention = attention_mla.MLA( - config=self.config, - num_query_heads=self.config.num_query_heads, - num_kv_heads=self.config.num_kv_heads, - head_dim=self.config.head_dim, - max_target_length=self.config.max_target_length, - max_prefill_predict_length=self.config.max_prefill_predict_length, - attention_kernel=self.config.attention, - attention_type=self.config.attention_type, - inputs_q_shape=self.dummy_inputs_shape, - inputs_kv_shape=self.dummy_inputs_shape, - mesh=mesh, - dtype=self.config.dtype, - weight_dtype=self.config.weight_dtype, - dropout_rate=self.config.dropout_rate, - name="self_attention", - quant=quant, - kv_quant=quantizations.configure_kv_quant(config), - q_lora_rank=self.config.q_lora_rank, - kv_lora_rank=self.config.kv_lora_rank, - qk_nope_head_dim=self.config.qk_nope_head_dim, - qk_rope_head_dim=self.config.qk_rope_head_dim, - v_head_dim=self.config.v_head_dim, - max_position_embeddings=self.config.max_position_embeddings, - original_max_position_embeddings=self.config.original_max_position_embeddings, - mscale=self.config.mscale, - rope_factor=self.config.rope_factor, - model_mode=model_mode, - rngs=rngs, - attn_logits_soft_cap=self.config.attn_logits_soft_cap, - ) + # DeepSeek V4 natively overrides this block with CompressedAttention. + if self.config.decoder_block.value != "deepseek4": + self.self_attention = attention_mla.MLA( + config=self.config, + num_query_heads=self.config.num_query_heads, + num_kv_heads=self.config.num_kv_heads, + head_dim=self.config.head_dim, + max_target_length=self.config.max_target_length, + max_prefill_predict_length=self.config.max_prefill_predict_length, + attention_kernel=self.config.attention, + attention_type=self.config.attention_type, + inputs_q_shape=self.dummy_inputs_shape, + inputs_kv_shape=self.dummy_inputs_shape, + mesh=mesh, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + dropout_rate=self.config.dropout_rate, + name="self_attention", + quant=quant, + kv_quant=quantizations.configure_kv_quant(self.config), + q_lora_rank=self.config.q_lora_rank, + kv_lora_rank=self.config.kv_lora_rank, + qk_nope_head_dim=self.config.qk_nope_head_dim, + qk_rope_head_dim=self.config.qk_rope_head_dim, + v_head_dim=self.config.v_head_dim, + max_position_embeddings=self.config.max_position_embeddings, + original_max_position_embeddings=self.config.original_max_position_embeddings, + mscale=self.config.mscale, + rope_factor=self.config.rope_factor, + model_mode=model_mode, + rngs=rngs, + attn_logits_soft_cap=self.config.attn_logits_soft_cap, + ) self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) if self.is_mhc_enabled: @@ -333,7 +335,7 @@ def __init__( rngs=self.rngs, ) - def mlp_op(self, x, deterministic): + def mlp_op(self, x, deterministic, *args, **kwargs): mlp = self.mlp(x, deterministic, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding) return self.with_logical_constraint(mlp) diff --git a/src/maxtext/models/deepseek4.py b/src/maxtext/models/deepseek4.py new file mode 100644 index 0000000000..adee819cd5 --- /dev/null +++ b/src/maxtext/models/deepseek4.py @@ -0,0 +1,255 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DeepSeek-V4 model definition.""" + +from typing import Optional + +from flax import nnx +import flax.linen as nn +from jax.sharding import Mesh + +from maxtext.common.common_types import Config, AttentionType +from maxtext.common.common_types import HyperConnectionType +from maxtext.layers import attention_compressed +from maxtext.layers import initializers +from maxtext.layers import moe +from maxtext.layers import nnx_wrappers +from maxtext.layers import quantizations +from maxtext.models import deepseek +from jax.ad_checkpoint import checkpoint_name + + +class DeepSeek4DecoderLayer(deepseek.DeepSeekGenericLayer): + """DeepSeek-V4 specific decoder layer. + + Note: V4 does not utilize purely dense layers in the initial transformer blocks. + Every layer is a Sparse MoE layer (which internally contains shared dense experts). + """ + + def __init__( + self, + config: Config, + model_mode: str, + mesh: Mesh, + rngs: nnx.Rngs, + quant: Optional[quantizations.AqtQuantization] = None, + layer_idx: int = -1, + ratio: Optional[int] = None, + is_hash_routing: Optional[bool] = None, + ) -> None: + super().__init__( + config=config, + model_mode=model_mode, + mesh=mesh, + rngs=rngs, + quant=quant, + layer_idx=layer_idx, + ) + + # Determine if this layer uses Hash Routing based on first_num_hash_layers. + if is_hash_routing is None: + is_hash_routing = layer_idx < config.first_num_hash_layers + self.mlp = moe.RoutedAndSharedMoE( + config=self.config, + mesh=self.mesh, + kernel_init=initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"), + kernel_axes=("embed", None), + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + quant=quant, + is_hash_routing=is_hash_routing, + rngs=rngs, + ) + + if ratio is None: + ratio = config.compress_ratios[layer_idx] + + # Route to LOCAL_SLIDING if compression is disabled for this layer, + # otherwise default to the globally configured attention type (e.g., COMPRESSED). + layer_attention_type = AttentionType.LOCAL_SLIDING if ratio == 0 else AttentionType(self.config.attention_type) + + self.self_attention = attention_compressed.CompressedAttention( + config=self.config, + compress_ratio=ratio, + num_query_heads=self.config.num_query_heads, + num_kv_heads=self.config.num_kv_heads, + head_dim=self.config.head_dim, + max_target_length=self.config.max_target_length, + max_prefill_predict_length=self.config.max_prefill_predict_length, + attention_kernel=self.config.attention, + attention_type=layer_attention_type, + inputs_q_shape=self.dummy_inputs_shape, + inputs_kv_shape=self.dummy_inputs_shape, + mesh=self.mesh, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + dropout_rate=self.config.dropout_rate, + sliding_window_size=self.config.sliding_window_size, + q_lora_rank=self.config.q_lora_rank, + name=f"compressed_attention_layer_{layer_idx}", + quant=quant, + kv_quant=quantizations.configure_kv_quant(config), + model_mode=model_mode, + rngs=rngs, + ) + + # pylint: disable=arguments-differ + def mlp_op(self, inputs, deterministic, *args, **kwargs): + input_ids = kwargs.get("input_ids") + mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp( + inputs=inputs, + input_ids=input_ids, + ) + return self.with_logical_constraint(mlp_lnx), load_balance_loss, moe_bias_updates + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=None, + slot: None | int = None, + kv_cache=None, + attention_metadata=None, + decoder_input_tokens=None, + ): + if isinstance(inputs, tuple): + inputs = inputs[0] + + x = self.with_logical_constraint(inputs) + x = checkpoint_name(x, "decoder_layer_input") + + _, intermediate_inputs = self.self_attention_with_norm_op( + x, + decoder_segment_ids, + decoder_positions, + deterministic, + previous_chunk, + slot, + ) + + layer_output, metadata = self.mhc_mlp( + self.post_attention_norm_op, + self.mlp_op, + x=intermediate_inputs, + mhc_type=HyperConnectionType.MLP_MOE, + deterministic=deterministic, + input_ids=decoder_input_tokens, + ) + load_balance_loss = metadata.get("load_balance_loss", None) + moe_bias_updates = metadata.get("moe_bias_updates", None) + + layer_output = self.dropout_op(layer_output, deterministic=deterministic) + return self.post_process(layer_output, load_balance_loss, moe_bias_updates, kv_cache) + + +class DeepSeek4ScannableBlock(nnx.Module): + """A scannable block containing exactly two DeepSeek V4 layers (CSA and HCA). + + DeepSeek V4 layers alternate `compress_ratio=4` (CSA) and `compress_ratio=128` (HCA) + throughout the middle of the network. This block encapsulates one full `[4, 128]` + cycle so it can be perfectly scanned using JAX `nn.scan`. + """ + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + rngs: nnx.Rngs, + quant: None | quantizations.AqtQuantization = None, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.rngs = rngs + + # Layer 0 in the block: CSA (compress_ratio=4) with Standard MoE (is_hash_routing=False) + self.layers_0 = DeepSeek4DecoderLayer( + config=self.config, + mesh=self.mesh, + model_mode=self.model_mode, + rngs=self.rngs, + quant=self.quant, + ratio=4, + is_hash_routing=False, + ) + + # Layer 1 in the block: HCA (compress_ratio=128) with Standard MoE (is_hash_routing=False) + self.layers_1 = DeepSeek4DecoderLayer( + config=self.config, + mesh=self.mesh, + model_mode=self.model_mode, + rngs=self.rngs, + quant=self.quant, + ratio=128, + is_hash_routing=False, + ) + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + slot=None, + previous_chunk=None, + attention_metadata=None, + kv_cache=None, + ): + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) + inputs = checkpoint_name(inputs, "decoder_layer_input") + y = inputs + + y, _ = self.layers_0( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + + y, _ = self.layers_1( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + + return y, None + + +DeepSeek4LayerToLinen = nnx_wrappers.to_linen_class( + DeepSeek4DecoderLayer, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) + +DeepSeek4ScannableBlockToLinen = nnx_wrappers.to_linen_class( + DeepSeek4ScannableBlock, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) diff --git a/tests/unit/deepseek_v4_vs_reference_test.py b/tests/unit/deepseek_v4_vs_reference_test.py index 1da95a184e..0233655964 100644 --- a/tests/unit/deepseek_v4_vs_reference_test.py +++ b/tests/unit/deepseek_v4_vs_reference_test.py @@ -524,9 +524,13 @@ def _run_e2e_test(self, layer_type, is_packed=False): "compressed_sparse_attention": self.pt_config.compress_rates["compressed_sparse_attention"], "heavily_compressed_attention": self.pt_config.compress_rates["heavily_compressed_attention"], } + compress_ratio = compress_ratio_map[layer_type] + layer_attention_type = AttentionType.LOCAL_SLIDING if compress_ratio == 0 else AttentionType.COMPRESSED + mt_attn = CompressedAttention( config=mt_config, - compress_ratio=compress_ratio_map[layer_type], + compress_ratio=compress_ratio, + attention_type=layer_attention_type, num_query_heads=self.num_heads, num_kv_heads=1, head_dim=self.head_dim, @@ -652,8 +656,7 @@ def _run_e2e_test(self, layer_type, is_packed=False): print(f"top_k_indices mismatches: {num_mismatches}") # 6. Execute MaxText - - mt_out = mt_attn(x_mt, x_mt, segs_mt, pos_mt, deterministic=True, model_mode=MODEL_MODE_TRAIN) + mt_out, _ = mt_attn(x_mt, x_mt, segs_mt, pos_mt, deterministic=True, model_mode=MODEL_MODE_TRAIN) # 7. Asserts if not is_packed: From fa86b18c7e930c932ec163daa271eff61a1a797c Mon Sep 17 00:00:00 2001 From: Dipak Gaikwad Date: Sun, 14 Jun 2026 16:40:16 +0000 Subject: [PATCH 2/2] Enabled auxilillary loss free load balancing and sequence wise load balancing for both Deepseek v3 and V4. Tested by running training loop with new tiny Deeepseek V4 model added as part of the commit, here are the logs for testing Without load balancing active logs : https://paste.googleplex.com/6421399878107136 with load balancing logs : https://paste.googleplex.com/6551357300539392 Here are the results actived for reducing the varience : 1 === Load Balancing Variance Analysis (Step 0 vs Step 20) === 2 3 | Layer Index | Step 0 Var (Baseline) | Step 20 Var (Run A) | Step 20 Var (Run B) | Improvement (A vs B) | 4 |-------------|-----------------------|---------------------|---------------------|----------------------| 5 | 0 | 774432.00 | 8746.00 | 3763.00 | 56.97% | 6 | 10 | 780344.00 | 13780.50 | 8604.25 | 37.56% | 7 | 20 | 756392.00 | 5593.12 | 1832.38 | 67.24% | 8 | 30 | 784392.00 | 8013.00 | 3920.38 | 51.07% | 9 | 42 | 594376.00 | 2286.50 | 436.38 | 80.92% | 10 |-------------|-----------------------|---------------------|---------------------|----------------------| 11 | TOTAL/AVG | 31439160.00 | 384273.50 | 202623.25 | 47.27% | Raw data collected for this analysis: https://paste.googleplex.com/6097992598814720 https://paste.googleplex.com/4811890982256640 --- src/maxtext/common/metric_logger.py | 4 +- src/maxtext/configs/models/deepseek4-tiny.yml | 68 +++++++++++++++++++ src/maxtext/configs/models/deepseek4.yml | 4 ++ src/maxtext/configs/types.py | 3 +- src/maxtext/layers/moe.py | 6 +- src/maxtext/optimizers/optimizers.py | 15 ++++ src/maxtext/trainers/pre_train/train.py | 66 +++++++++++++----- tests/unit/deepseek_routed_bias_test.py | 65 ++++++++++++++++++ tests/unit/optimizers_test.py | 39 +++++++++++ tests/unit/train_nnx_test.py | 18 ++++- 10 files changed, 265 insertions(+), 23 deletions(-) create mode 100644 src/maxtext/configs/models/deepseek4-tiny.yml create mode 100644 tests/unit/deepseek_routed_bias_test.py diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py index 44771ecb05..a36d56da84 100644 --- a/src/maxtext/common/metric_logger.py +++ b/src/maxtext/common/metric_logger.py @@ -197,7 +197,7 @@ def _log_training_metrics(self, metrics, step): if self.config.num_experts > 1: moe_lb_loss = scalars.get("learning/moe_lb_loss", 0.0) - log_parts.append(f"moe_lb_loss: {moe_lb_loss:.3f}") + log_parts.append(f"moe_lb_loss: {moe_lb_loss:.6f}") if self.config.mtp_num_layers > 0: mtp_loss = scalars.get("learning/mtp_loss", 0.0) @@ -217,7 +217,7 @@ def _log_eval_metrics(self, metrics, step): f"avg_z_loss={scalars.get('eval/avg_z_loss', 0.0):.3f}", ] if self.config.num_experts > 1: - log_parts.append(f"avg_moe_lb_loss={scalars['eval/avg_moe_lb_loss']:.3f}") + log_parts.append(f"avg_moe_lb_loss={scalars['eval/avg_moe_lb_loss']:.6f}") if self.config.mtp_num_layers > 0: log_parts.extend( [ diff --git a/src/maxtext/configs/models/deepseek4-tiny.yml b/src/maxtext/configs/models/deepseek4-tiny.yml new file mode 100644 index 0000000000..3bfa0973c0 --- /dev/null +++ b/src/maxtext/configs/models/deepseek4-tiny.yml @@ -0,0 +1,68 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Tiny model config for DeepSeek V4 for CPU execution and testing + +base_emb_dim: 64 +base_num_query_heads: 4 +base_num_kv_heads: 1 +base_num_decoder_layers: 43 +base_mlp_dim: 64 +base_moe_mlp_dim: 64 +vocab_size: 129280 +head_dim: 32 + +# --- Standard Defaults --- +enable_dropout: false +logits_via_embedding: false +normalization_layer_epsilon: 1.0e-6 + +# --- V4 Specific Architectural Keys --- +decoder_block: "deepseek4" +mhc_expansion_rate: 4 +first_num_hash_layers: 3 +indexer_head_dim: 32 +indexer_n_heads: 4 +indexer_topk: 16 + +# Note: Layers (0,1) are not compressed. +# The 44th layer (MTP module with compress_ratio=0) has been explicitly dropped for now. +# This leaves exactly 43 layers: 2 prefix [0,0] + 40 scanned + 1 suffix [4]. +compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4] + +# --- MoE configuration --- +mlp_activations: ["silu", "linear"] +num_experts: 16 +num_experts_per_tok: 4 +shared_experts: 1 +routed_score_func: "sqrtsoftplus" +routed_bias: true +routed_bias_update_rate: 0.001 +load_balance_loss_weight: 0.0001 +adamw_mask: [".*gate.*bias.*"] + +# --- Attention configuration --- +attention: 'dot_product' +attention_type: 'compressed' +q_lora_rank: 16 +o_groups: 4 +o_lora_rank: 16 +sliding_window_size: 32 + +# --- RoPE --- + +rope_type: "default" +rope_max_timescale: 10000 # Main RoPE theta +compressed_rope_max_timescale: 160000 # Compressed RoPE theta +max_position_embeddings: 4096 diff --git a/src/maxtext/configs/models/deepseek4.yml b/src/maxtext/configs/models/deepseek4.yml index 521e8c4871..29cf6063d7 100644 --- a/src/maxtext/configs/models/deepseek4.yml +++ b/src/maxtext/configs/models/deepseek4.yml @@ -46,6 +46,10 @@ num_experts: 256 num_experts_per_tok: 6 shared_experts: 1 routed_score_func: "sqrtsoftplus" +routed_bias: true +routed_bias_update_rate: 0.001 +load_balance_loss_weight: 0.0001 +adamw_mask: [".*gate.*bias.*"] # --- Attention configuration --- attention: 'dot_product' diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index e6cbdf666f..5f0d2e18fe 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -226,6 +226,7 @@ class ProfilerType(str, Enum): "deepseek3-671b-batchsplit", "deepseek3-test", "deepseek3-tiny", + "deepseek4-tiny", "deepseek3.2-671b", "deepseek4", "deepseek-custom", @@ -2980,7 +2981,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de ) if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1: raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.") - if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block != DecoderBlockType.DEEPSEEK: + if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block not in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK4): raise ValueError("Loss-free load balancing is only supported for the DeepSeek decoder block.") if self.model_name.startswith("deepseek4") and self.first_num_hash_layers > 0 and self.use_ring_of_experts: raise ValueError("DeepSeek V4 hash routing is currently not supported with ring of experts.") diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 0d27a1ec8f..5d92642d50 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -348,8 +348,11 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax. pre_bias_logits = output if self.use_bias: + # Architectural Note: Bias is an nnx.Param rather than nnx.Variable due to Linen/NNX state + # management transitions otherwise we will have to manage the overhead. We use jax.lax.stop_gradient + # here to mathematically enforce the Auxiliary-Loss-Free constraint, isolating it from sequence-wise loss leaks. bias = jnp.asarray(self.bias[...], self.dtype) - output += bias + output += jax.lax.stop_gradient(bias) return output, pre_bias_logits @@ -2162,7 +2165,6 @@ def dense_matmul( lb_loss = ( self.load_balance_loss(top_k_indices, softmax_probs) if self.config.load_balance_loss_weight > 0.0 else None ) - # TODO(dipakg-lang, b/521990776): Add sequence-wise balance loss * 0.0001 else: lb_loss = None diff --git a/src/maxtext/optimizers/optimizers.py b/src/maxtext/optimizers/optimizers.py index 9992d7674f..4200504927 100644 --- a/src/maxtext/optimizers/optimizers.py +++ b/src/maxtext/optimizers/optimizers.py @@ -238,6 +238,21 @@ def get_optimizer(config, learning_rate_schedule, model=None): lambda params: jax.tree_util.tree_map(lambda x: "frozen" if x else "trainable", freeze_mask_fn(params)), ) + if getattr(config, "routed_bias", False): + import re + from flax import traverse_util + bias_regex = re.compile(".*gate.*bias.*") + # Architectural Note: Optax's Muon implementation correctly routes 2D+ matrices to the + # Newton-Schulz algorithm, but its fallback logic for 1D vectors (like our GateLogit bias) + # routes them to a standard AdamW optimizer *without* exposing a weight decay mask. + # To prevent the Muon optimizer from decaying our auxiliary-loss-free bias to zero, + # we apply a global optax.set_to_zero() mask here. + def bias_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + mask = {k: bool(bias_regex.match("/".join(map(str, k)))) for k in flat_params.keys()} + return traverse_util.unflatten_dict(mask) + base_opt = optax.chain(base_opt, optax.masked(optax.set_to_zero(), bias_mask_fn)) + return base_opt diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 3be6baff8c..6867033113 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -277,12 +277,6 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr else: max_logging.debug("\nNo MoE load balance loss found. Defaulting to 0.0.") - # get MoE routed bias term updates - moe_bias_updates = None - if config.routed_bias and config.routed_bias_update_rate > 0.0: - nested_key = ("intermediates", "decoder", "moe_layers", "moe_bias_updates") - moe_bias_updates = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, None) - # Add the model's primary output to the intermediates dict so it can be used # by the acceptance rate calculation in eval_step. intermediate_outputs["logits"] = logits @@ -294,7 +288,6 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr "total_weights": total_weights, "moe_lb_loss": moe_lb_loss, "indexer_loss": indexer_loss, - "moe_bias_updates": moe_bias_updates, "mtp_loss": mtp_loss, "batch_stats": (intermediate_outputs.get("batch_stats", None) if hasattr(intermediate_outputs, "get") else None), } @@ -408,9 +401,9 @@ def diff_wrapper(param, rest, config, data): moe_lb_loss = aux["moe_lb_loss"] indexer_loss = aux.get("indexer_loss", 0.0) z_loss = aux.get("z_loss", 0.0) - moe_bias_updates = aux.get("moe_bias_updates") mtp_loss = aux.get("mtp_loss", 0.0) new_opt_state = None + bias_metrics = {} if isinstance(model, nn.Module): if config.gradient_clipping_threshold > 0: @@ -467,12 +460,30 @@ def move(path, value): else: new_state = state.apply_gradients(grads=full_grads) - # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") - # Updates the shape to be aligned with state. - moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() - new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) + # Apply updates for Auxiliary-Loss-Free load balancing for the DeepSeek family. + # We dynamically traverse the PyTree to apply updates because the topology varies drastically: + # 1. DeepSeek V3 mixes dense layers (no bias updates) with MoE layers. + # 2. DeepSeek V4 introduces Hash Routing in early layers (which lack a learnable bias entirely). + # 3. DeepSeek V4 groups alternating attention topologies into nested `ScannableBlocks`. + # Dynamic traversal ensures we only target the correct `gate.bias` parameters without hardcoded, brittle paths. + if config.routed_bias and config.routed_bias_update_rate > 0.0: + from flax import traverse_util + flat_intermediates = traverse_util.flatten_dict(aux.get("intermediate_outputs", {})) + flat_params = traverse_util.flatten_dict(new_state.params) + new_flat_params = dict(flat_params) + + for path, update in flat_intermediates.items(): + if path[-1] == "moe_bias_updates": + prefix = path[1:-1] if path[0] == "intermediates" else path[:-1] + for param_path in flat_params.keys(): + param_prefix = param_path[1:] if param_path[0] == "params" else param_path + if len(param_prefix) >= len(prefix) and param_prefix[:len(prefix)] == prefix and param_path[-2:] == ("gate", "bias"): + update_val = update[0] if isinstance(update, (tuple, list)) else update + bias_metrics[f"learning/moe_bias_before_norm_{'-'.join(map(str, param_path))}"] = jnp.linalg.norm(new_flat_params[param_path]) + new_flat_params[param_path] = new_flat_params[param_path] + jnp.array(update_val).transpose() + bias_metrics[f"learning/moe_bias_update_norm_{'-'.join(map(str, param_path))}"] = jnp.linalg.norm(jnp.array(update_val)) + + new_state = new_state.replace(params=traverse_util.unflatten_dict(new_flat_params)) else: if config.gradient_clipping_threshold > 0: grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) @@ -493,9 +504,27 @@ def move(path, value): new_state = state # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_bias = new_state.model.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias - target_bias.value = target_bias.value + jnp.array(moe_bias_updates[0]).transpose() + if config.routed_bias and config.routed_bias_update_rate > 0.0: + from flax import traverse_util + flat_intermediates = traverse_util.flatten_dict(aux.get("intermediate_outputs", {})) + jax.debug.print("FLAT_INTERMEDIATE_KEYS_NNX: {}", flat_intermediates.keys()) + for path, update in flat_intermediates.items(): + if path[-1] == "moe_bias_updates": + target = new_state.model + for key in path[:-1]: + if hasattr(target, key): + target = getattr(target, key) + elif isinstance(target, dict) and key in target: + target = target[key] + else: + break + else: + for _, node in nnx.iter_graph(target): + if type(node).__name__ == "GateLogit" and hasattr(node, "bias") and node.bias is not None: + update_val = update[0] if isinstance(update, (tuple, list)) else update + bias_metrics[f"learning/moe_bias_before_norm_{'-'.join(map(str, path[:-1]))}"] = jnp.linalg.norm(node.bias.value) + node.bias.value = node.bias.value + jnp.array(update_val).transpose() + bias_metrics[f"learning/moe_bias_update_norm_{'-'.join(map(str, path[:-1]))}"] = jnp.linalg.norm(jnp.array(update_val)) lm_loss = xent_sum / (total_weights + EPS) scalar_metrics = { @@ -508,6 +537,9 @@ def move(path, value): "learning/mtp_loss": mtp_loss, "learning/total_weights": total_weights, } + scalar_metrics.update(bias_metrics) + if bias_metrics: + jax.debug.print("--- ROUTED BIAS METRICS --- {}", bias_metrics) if config.use_qk_clip: if isinstance(model, nn.Module): new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) diff --git a/tests/unit/deepseek_routed_bias_test.py b/tests/unit/deepseek_routed_bias_test.py new file mode 100644 index 0000000000..9e12a7da9a --- /dev/null +++ b/tests/unit/deepseek_routed_bias_test.py @@ -0,0 +1,65 @@ +import unittest +import jax +import jax.numpy as jnp +import optax +from flax.training import train_state +from maxtext.configs import pyconfig +from maxtext.models import models +from maxtext.trainers.pre_train import train as pre_train +class DeepSeekRoutedBiasTest(unittest.TestCase): + def setUp(self): + self.mesh = jax.sharding.Mesh(jax.devices(), ('data',)) + def _make_dummy_data(self, batch=1, seq=16): + return { + "inputs": jnp.zeros((batch, seq), dtype=jnp.int32), + "inputs_position": jnp.broadcast_to(jnp.arange(seq), (batch, seq)), + "inputs_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + "targets": jnp.zeros((batch, seq), dtype=jnp.int32), + "targets_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + } + def _create_and_run_train_step(self, config_args): + config = pyconfig.initialize(config_args) + rngs = jax.nnx.Rngs(0) if hasattr(jax, 'nnx') else __import__('flax.nnx', fromlist=['Rngs']).Rngs(0) + import flax.nnx as nnx + from maxtext.common import train_state_nnx + rngs = nnx.Rngs(0) + model = models.Transformer(config, self.mesh, quant=None, rngs=rngs) + data = self._make_dummy_data(batch=config.micro_batch_size_to_train_on, seq=config.max_target_length) + optimizer = nnx.Optimizer(model, optax.sgd(0.01), wrt=nnx.Param) + ts = train_state_nnx.TrainStateNNX(model, optimizer) + state_graphdef, state_pure = nnx.split(ts) + new_state, metrics = pre_train.train_step( + state_graphdef, config, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + return new_state, metrics + def test_deepseek_v3_dense_routed_bias_success(self): + """Proves that a DeepSeek V3 model with dense layers (no moe_layers attribute) + successfully traverses the state tree and updates routed bias without crashing. + """ + config_args = [ + "", + "src/maxtext/configs/base.yml", + "model_name=deepseek3-tiny", + "decoder_block=deepseek", + "num_decoder_layers=2", + "per_device_batch_size=1", + "max_target_length=16", + "routed_bias=True", + "routed_bias_update_rate=0.001", + "skip_jax_distributed_system=True", + "base_emb_dim=64", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_query_heads=1", + "base_num_kv_heads=1", + "num_experts=2", + "num_experts_per_tok=2", + "first_num_dense_layers=1", + "sparse_matmul=False", + "override_model_config=True", + ] + new_state, metrics = self._create_and_run_train_step(config_args) + self.assertIsNotNone(new_state) + self.assertIn("learning/loss", metrics["scalar"]) +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/optimizers_test.py b/tests/unit/optimizers_test.py index b8eab1061e..4b9fe305eb 100644 --- a/tests/unit/optimizers_test.py +++ b/tests/unit/optimizers_test.py @@ -622,5 +622,44 @@ def __init__(self, rngs: nnx.Rngs): self.assertEqual(result.self_attention.out.kernel.value, mdn((0, -2), (-1,))) +class TestGetOptimizerGlobalMask(unittest.TestCase): + """Tests that the global optimizer cleanly masks out the routed bias.""" + def test_routed_bias_global_mask(self): + config = pyconfig.initialize(["", "src/maxtext/configs/base.yml", "routed_bias=True", "opt_type=sgd"]) + # We define a dummy params dict containing a routed bias and a regular weight. + # The routed bias must be completely ignored by the optimizer. + params = { + "decoder": { + "moe_layers": { + "MoeBlock_0": { + "gate": { + "bias": jnp.array([1.0]), + "kernel": jnp.array([1.0]) + } + } + } + } + } + grads = { + "decoder": { + "moe_layers": { + "MoeBlock_0": { + "gate": { + "bias": jnp.array([0.5]), + "kernel": jnp.array([0.5]) + } + } + } + } + } + # We use sgd because it's simple to test updates, but the mask logic applies + # cleanly to any base optimizer returned by get_optimizer. + opt = optimizers.get_optimizer(config, learning_rate_schedule=0.1) + opt_state = opt.init(params) + updates, _ = opt.update(grads, opt_state, params) + # The routed bias update should be exactly 0.0 (masked by set_to_zero) + self.assertEqual(updates["decoder"]["moe_layers"]["MoeBlock_0"]["gate"]["bias"].item(), 0.0) + # The kernel should receive the SGD gradient update (-0.1 * 0.5) + self.assertTrue(updates["decoder"]["moe_layers"]["MoeBlock_0"]["gate"]["kernel"].item() < 0.0) if __name__ == "__main__": unittest.main() diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index ebeededbd7..b31bc4a5dc 100644 --- a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -61,8 +61,12 @@ class _Cfg: shard_mode: int = 0 # ShardMode.AUTO weight_sparsity_n: int = 0 weight_sparsity_m: int = 0 + decoder_block: str = "default" +class _DummyDecoder(nnx.Module): + pass + class _TinyDecoder(nnx.Module): """Mimics NNXDecoder.__call__ enough for loss_fn to run end-to-end. @@ -73,6 +77,7 @@ class _TinyDecoder(nnx.Module): def __init__(self, vocab_size: int, hidden: int, rngs: nnx.Rngs): self.embed = nnx.Embed(vocab_size, hidden, rngs=rngs) self.proj = nnx.Linear(hidden, vocab_size, rngs=rngs) + self.decoder = _DummyDecoder() def __call__( self, @@ -125,7 +130,6 @@ def test_returns_loss_and_full_aux_dict(self): "total_weights", "moe_lb_loss", "indexer_loss", - "moe_bias_updates", "mtp_loss", ): self.assertIn(key, aux) @@ -194,6 +198,18 @@ def test_train_step_with_gradient_clipping(self): self.assertIsInstance(new_state, nnx.State) self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) + def test_train_step_deepseek_aux_loss(self): + cfg, ts = _build_state() + cfg.routed_bias = True + cfg.routed_bias_update_rate = 0.001 + cfg.decoder_block = "deepseek" + state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + # The robust trainer logic will correctly traverse and NOT crash, ignoring the hardcoded path + new_state, metrics = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + self.assertIsInstance(new_state, nnx.State) class TestEvalStepNNX(unittest.TestCase): """Cover the NNX branch of eval_step (lines 568-570)."""