From 631941c0c4d70c98d23d8d59cb58dc87e7ec12d6 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Wed, 17 Jun 2026 18:10:16 +0000 Subject: [PATCH] feat(models): Integrate DeepSeek V4 architecture and routing --- src/maxtext/common/common_types.py | 1 + src/maxtext/configs/models/deepseek4-284b.yml | 64 ++++ src/maxtext/configs/types.py | 6 +- src/maxtext/layers/attention_compressed.py | 37 +-- src/maxtext/layers/attentions.py | 1 + src/maxtext/layers/decoders.py | 110 ++++++- src/maxtext/layers/embeddings.py | 17 +- src/maxtext/layers/moe.py | 19 +- src/maxtext/models/deepseek.py | 68 ++--- src/maxtext/models/deepseek4.py | 274 ++++++++++++++++++ src/maxtext/utils/globals.py | 1 + tests/unit/deepseek_v4_vs_reference_test.py | 39 +-- tests/unit/train_compile_test.py | 20 ++ 13 files changed, 578 insertions(+), 79 deletions(-) create mode 100644 src/maxtext/configs/models/deepseek4-284b.yml create mode 100644 src/maxtext/models/deepseek4.py diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index d4b52207fc..71dbc105d4 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -113,6 +113,7 @@ class DecoderBlockType(enum.Enum): SIMPLE_MLP = "simple_mlp" LLAMA4 = "llama4" OLMO3 = "olmo3" + DEEPSEEK4 = "deepseek4" class AttentionType(enum.Enum): diff --git a/src/maxtext/configs/models/deepseek4-284b.yml b/src/maxtext/configs/models/deepseek4-284b.yml new file mode 100644 index 0000000000..5ba2dd062f --- /dev/null +++ b/src/maxtext/configs/models/deepseek4-284b.yml @@ -0,0 +1,64 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Model config for DeepSeek-V4-Flash 284B (https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash) + +base_emb_dim: 4096 +base_num_query_heads: 64 +base_num_kv_heads: 1 +base_num_decoder_layers: 43 +base_mlp_dim: 2048 +base_moe_mlp_dim: 2048 +vocab_size: 129280 +head_dim: 512 + +# --- Standard Defaults --- +enable_dropout: false +logits_via_embedding: false +normalization_layer_epsilon: 1.0e-6 + +# --- V4 Specific Architectural Keys --- +decoder_block: "deepseek4" +mhc_expansion_rate: 4 +first_num_hash_layers: 3 +indexer_head_dim: 128 +indexer_n_heads: 64 +indexer_topk: 512 + +# Note: Layers (0, 1, 2) are prefix layers. +# The 44th layer (MTP module with compress_ratio=0) has been explicitly dropped for now. +# This leaves exactly 43 layers: 3 prefix [0,0,4] + 40 scanned. +compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4] + +# --- MoE configuration --- +mlp_activations: ["silu", "linear"] +num_experts: 256 +num_experts_per_tok: 6 +mlp_activations_limit: 10 +shared_experts: 1 +routed_score_func: "sqrtsoftplus" + +# --- Attention configuration --- +attention_type: 'compressed' +q_lora_rank: 1024 +o_groups: 8 +o_lora_rank: 1024 +sliding_window_size: 128 + +# --- RoPE --- + +rope_type: "default" +rope_max_timescale: 10000 # Main RoPE theta +compressed_rope_max_timescale: 160000 # Compressed RoPE theta +max_position_embeddings: 1048576 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index e43f34f247..d1f293aae8 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -227,7 +227,7 @@ class ProfilerType(str, Enum): "deepseek3-test", "deepseek3-tiny", "deepseek3.2-671b", - "deepseek4", + "deepseek4-284b", "deepseek-custom", "kimi-k2-1t", "gemma-7b", @@ -553,7 +553,7 @@ class Attention(BaseModel): "autoselected", description="The attention algorithm to use (dot_product, flash, etc).", ) - attention_type: Literal["global", "local_sliding", "chunk", "mla", "full"] = Field( + attention_type: Literal["global", "local_sliding", "chunk", "mla", "full", "compressed"] = Field( "global", description="The variant of attention to use." ) share_kv_projections: bool = Field( @@ -2925,6 +2925,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de raise ValueError("`local_checkpoint_period` must be > 0 for emergency checkpointing.") if self.moba and self.attention not in ("dot_product"): raise ValueError("MoBA is only supported with dot_product attention.") + if self.decoder_block == DecoderBlockType.DEEPSEEK4 and self.attention != "dot_product": + raise ValueError("DeepSeek4 decoder block currently only supports dot_product attention.") if self.use_indexer: if self.q_lora_rank == 0: raise NotImplementedError("Sparse indexer has not implemented for q_lora_rank = 0.") diff --git a/src/maxtext/layers/attention_compressed.py b/src/maxtext/layers/attention_compressed.py index e9a25f46b5..391ec6cedd 100644 --- a/src/maxtext/layers/attention_compressed.py +++ b/src/maxtext/layers/attention_compressed.py @@ -680,24 +680,23 @@ def __init__( rngs: Optional[nnx.Rngs] = None, **kwargs, ): - """Initializes the CompressedAttention layer. + """Inherits all standard Attention hyperparameters and selectively instantiates + an underlying HCA or CSA compressor based on the provided `compress_ratio`. - Inherits all standard Attention hyperparameters and selectively instantiates - an underlying HCA or CSA compressor based on the provided `layer_type`. + Highlights of DeepSeek-V4 attention integration: + - Shared-KV: The layer supports decoupling Q and KV heads for heavy compression. + - MQA: Multi-Query Attention used alongside heavy KV compression. + - 3 Different Attention Modes: Sliding Window (prefix), HCA (128x), and CSA (4x). + - Dual RoPE Theta: Uses 10000 for standard uncompressed tokens and 160000 for compressed. Args: (See maxtext.layers.attentions.Attention for standard attention arguments) q_lora_rank: The rank for the LoRA projection in the compressed query. - compress_ratio: The compression ratio for the compressor. + compress_ratio: The compression ratio (0, 4, or 128) for the compressor. """ - """Initializes the Compressed Attention module.""" self.q_lora_rank = q_lora_rank self.compress_ratio = compress_ratio - # Determine the correct underlying attention type based on the compress_ratio - if self.compress_ratio == 0: - attention_type = AttentionType.LOCAL_SLIDING - super().__init__( config=config, num_query_heads=num_query_heads, @@ -809,20 +808,22 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No rngs=self.rngs, ) - # DeepSeek-V4 uses a separate RoPE theta (160000) for compressed tokens. - # We must instantiate a dedicated rotary embedding for the compressors - self.compress_rotary_embedding = DeepSeekV4RotaryEmbedding( + # Override the base rotary embedding with the correct theta for this layer. + # CSA / HCA layers use compressed_rope_max_timescale (160000). + # Sliding window prefix layers use rope_max_timescale (10000). + rope_theta = self.config.compressed_rope_max_timescale if self.compress_ratio > 0 else self.config.rope_max_timescale + self.rotary_embedding = DeepSeekV4RotaryEmbedding( head_dim=self.config.head_dim, - partial_rotary_factor=1.0, - rope_theta=self.config.compressed_rope_max_timescale, - dtype=self.dtype, + partial_rotary_factor=self.config.qk_rope_head_dim / self.config.head_dim, + rope_theta=rope_theta, + fprop_dtype=self.dtype, ) if self.compress_ratio > 4: self.hca_compressor = DeepseekV4HCACompressor( config=self.config, compress_ratio=self.compress_ratio, - rotary_embedding=self.compress_rotary_embedding, + rotary_embedding=self.rotary_embedding, kernel_init=self.kernel_init, quant=self.quant, model_mode=self.model_mode, @@ -832,7 +833,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No self.csa_compressor = DeepseekV4CSACompressor( config=self.config, compress_ratio=self.compress_ratio, - rotary_embedding=self.compress_rotary_embedding, + rotary_embedding=self.rotary_embedding, kernel_init=self.kernel_init, quant=self.quant, model_mode=self.model_mode, @@ -1047,7 +1048,7 @@ def __call__( # -> [batch, q_length, emb_dim] final_out = self.o_b_proj(grouped_flat) - return final_out + return final_out, None def compressed_attention( diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 679c891360..ab7673d1d4 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -850,6 +850,7 @@ def init_rotary_embedding(self): shard_mode=self.config.shard_mode, rngs=self.rngs, ) + elif self.is_qwen3_hybrid: rotary_embedding = PartialRotaryEmbedding( min_timescale=self.config.rope_min_timescale, diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index b28b6dcb7a..0150c7b401 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -41,6 +41,7 @@ from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.models import ( deepseek, + deepseek4, deepseek_batchsplit, deepseek_batchsplit_fp8, gemma, @@ -467,6 +468,10 @@ def get_decoder_layers(self): deepseek.DeepSeekDenseLayerToLinen, deepseek.DeepSeekMoELayerToLinen, ] + case DecoderBlockType.DEEPSEEK4: + return ( + [deepseek4.DeepSeek4ScannableBlockToLinen] if self.config.scan_layers else [deepseek4.DeepSeek4LayerToLinen] + ) case DecoderBlockType.GEMMA: return [gemma.GemmaDecoderLayerToLinen] case DecoderBlockType.GEMMA2: @@ -632,6 +637,7 @@ def get_norm_layer(self, num_features: int): DecoderBlockType.MISTRAL, DecoderBlockType.MIXTRAL, DecoderBlockType.DEEPSEEK, + DecoderBlockType.DEEPSEEK4, DecoderBlockType.GEMMA, DecoderBlockType.GEMMA2, DecoderBlockType.GEMMA3, @@ -1061,6 +1067,17 @@ def __call__( previous_chunk, slot, ) + elif cfg.decoder_block == DecoderBlockType.DEEPSEEK4: + y = self._apply_deepseek4_scanned_blocks( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + slot, + decoder_input_tokens, + ) else: RemattedBlockLayer = RemattedBlockLayers[0] scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) @@ -1195,7 +1212,7 @@ def __call__( "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval), "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step), } - if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): + if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5, DecoderBlockType.DEEPSEEK4): layer_kwargs = {"layer_idx": lyr} kv_cache = None if kv_caches is not None: @@ -1423,6 +1440,97 @@ def _apply_gemma4_scanned_blocks( return y + def _apply_deepseek4_scanned_blocks( + self, + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + slot, + decoder_input_tokens, + ): + """Applies DeepSeek V4 scanned decoder blocks. + + DeepSeek V4 has some number of prefix layers (defined by `first_num_hash_layers`) + that use static Hash Routing. The remaining layers alternate `compress_ratio=128` (HCA) + and `compress_ratio=4` (CSA) and are evaluated in a single `nn.scan` block. + + For DeepSeek4-Flash (43 hidden layers total): + - 3 Prefix layers (Indices 0, 1, 2) + - 40 Scanned layers: 20 perfectly repeating chunks of [128, 4] + """ + + cfg = self.config + mesh = self.mesh + + broadcast_args = ( + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + slot, + previous_chunk, + ) + + layer_call_kwargs = { + "previous_chunk": previous_chunk, + "slot": slot, + "decoder_input_tokens": decoder_input_tokens, + } + + # 1. Prefix Unrolling + # These layers use Hash Routing. + num_hash_layers = cfg.first_num_hash_layers + for layer_idx in range(num_hash_layers): + prefix_layer = deepseek4.DeepSeek4LayerToLinen( + config=cfg, + mesh=mesh, + name=f"layers_{layer_idx}", + quant=self.quant, + model_mode=self.model_mode, + layer_idx=layer_idx, + ) + y, _ = prefix_layer( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + **layer_call_kwargs, + ) + + # 2. Chunked Scanning + # The remaining layers perfectly alternate HCA (128) and CSA (4). + num_remaining_layers = cfg.num_decoder_layers - num_hash_layers + num_full_blocks = num_remaining_layers // 2 + + if num_full_blocks > 0: + ScannableBlockToLinen = deepseek4.DeepSeek4ScannableBlockToLinen + policy = self.get_remat_policy() + RemattedDeepSeek4Block = self.set_remat_policy([ScannableBlockToLinen], policy)[0] + + y, _ = nn.scan( + RemattedDeepSeek4Block, + variable_axes={ + "params": cfg.param_scan_axis, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, + split_rngs={"params": True, "dropout": cfg.enable_dropout}, + in_axes=(nn.broadcast,) * len(broadcast_args), + length=num_full_blocks, + metadata_params={ + nn.PARTITION_NAME: "layers", + "abstract_init": False, + }, + )(config=cfg, mesh=mesh, quant=self.quant, model_mode=model_mode, name="scanned_blocks",)(y, *broadcast_args) + + return y + def _apply_gemma4_small_layers( self, y, diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py index 86b6723bd5..ad6b171f2f 100644 --- a/src/maxtext/layers/embeddings.py +++ b/src/maxtext/layers/embeddings.py @@ -1803,7 +1803,7 @@ def qwen3_omni_mrope_embedding_as_linen( ) -class DeepSeekV4RotaryEmbedding(nnx.Module): +class DeepSeekV4RotaryEmbedding(RotaryEmbedding): """DeepSeek-V4 partial rotary embedding with interleaved frequencies. DeepSeek-V4 uses an interleaved positional encoding where consecutive channels @@ -1822,12 +1822,23 @@ def __init__( head_dim: int, partial_rotary_factor: float = 64.0 / 512.0, rope_theta: float = 10000.0, - dtype: Any = jnp.float32, + fprop_dtype: Any = jnp.float32, + min_timescale: int = 10000, + max_timescale: int = 10000, + mesh: Any = None, + **kwargs, ): + super().__init__( + min_timescale=min_timescale, + max_timescale=max_timescale, + mesh=mesh, + fprop_dtype=fprop_dtype, + **kwargs, + ) self.head_dim = head_dim self.partial_rotary_factor = partial_rotary_factor self.rope_theta = rope_theta - self.dtype = dtype + self.fprop_dtype = fprop_dtype # Compute the partial rotary dimension (rope_head_dim) self.dim = int(head_dim * partial_rotary_factor) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 020956098c..4bb7cc7c08 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -208,6 +208,10 @@ def calculate_load_balance_updates(top_k_indices, num_experts, rate): return output +class Tid2EidVar(nnx.Variable): + """Custom variable to hold tid2eid without trainable param overhead.""" + + class GateLogit(nnx.Module): """A layer used to compute gate logits, allowing to return the pre bias values for DeepSeek routing.""" @@ -399,8 +403,11 @@ def __init__( # DeepSeek V4 Hash Routing if self.is_hash_routing: # Token-ID to Expert-ID lookup table for static routing - self.tid2eid = nnx.Variable( - jnp.zeros((self.config.vocab_size, self.num_experts_per_tok), dtype=jnp.int32), + # Must be stored as float32 because MaxText passes the entire variable tree + # through jax.value_and_grad, which strictly requires all leaves to be inexact types + # (even if they receive no gradients). We cast to int32 dynamically during routing. + self.tid2eid = Tid2EidVar( + jnp.zeros((self.config.vocab_size, self.num_experts_per_tok), dtype=jnp.float32), out_sharding=None, # Replicated across shards for local lookup ) else: @@ -665,7 +672,13 @@ def get_topk(self, gate_logits, pre_bias_logits, rngs=None, input_ids=None): return top_k_weights, top_k_indices if self.is_hash_routing: - top_k_indices = self.tid2eid[input_ids] + if input_ids is None: + raise ValueError("input_ids cannot be None when is_hash_routing is True") + # Access the static routing table + tid2eid_int = self.tid2eid.value + # Cast the float32 array to int32 (JAX automatically assigns 0.0 gradients to integer casts) + tid2eid_int = tid2eid_int.astype(jnp.int32) + top_k_indices = tid2eid_int[input_ids] top_k_weights = jnp.take_along_axis(pre_bias_logits, top_k_indices, axis=-1) # NOTE: deepseek2 has a different pattern elif self.config.model_name.startswith(("deepseek3", "deepseek4")): diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index 27e1a6f7ad..d3a72b31bf 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -25,7 +25,7 @@ import jax.numpy as jnp from jax.sharding import Mesh from maxtext.common.common_types import Config -from maxtext.common.common_types import HyperConnectionType, MODEL_MODE_PREFILL +from maxtext.common.common_types import HyperConnectionType, MODEL_MODE_PREFILL, DecoderBlockType from maxtext.layers import attention_mla from maxtext.layers import initializers from maxtext.layers import linears @@ -138,37 +138,39 @@ def __init__( self.engram_layer_norm = None self.engram = None - self.self_attention = attention_mla.MLA( - config=self.config, - num_query_heads=self.config.num_query_heads, - num_kv_heads=self.config.num_kv_heads, - head_dim=self.config.head_dim, - max_target_length=self.config.max_target_length, - max_prefill_predict_length=self.config.max_prefill_predict_length, - attention_kernel=self.config.attention, - attention_type=self.config.attention_type, - inputs_q_shape=self.dummy_inputs_shape, - inputs_kv_shape=self.dummy_inputs_shape, - mesh=mesh, - dtype=self.config.dtype, - weight_dtype=self.config.weight_dtype, - dropout_rate=self.config.dropout_rate, - name="self_attention", - quant=quant, - kv_quant=quantizations.configure_kv_quant(config), - q_lora_rank=self.config.q_lora_rank, - kv_lora_rank=self.config.kv_lora_rank, - qk_nope_head_dim=self.config.qk_nope_head_dim, - qk_rope_head_dim=self.config.qk_rope_head_dim, - v_head_dim=self.config.v_head_dim, - max_position_embeddings=self.config.max_position_embeddings, - original_max_position_embeddings=self.config.original_max_position_embeddings, - mscale=self.config.mscale, - rope_factor=self.config.rope_factor, - model_mode=model_mode, - rngs=rngs, - attn_logits_soft_cap=self.config.attn_logits_soft_cap, - ) + # DeepSeek V4 natively overrides this block with CompressedAttention. + if self.config.decoder_block != DecoderBlockType.DEEPSEEK4: + self.self_attention = attention_mla.MLA( + config=self.config, + num_query_heads=self.config.num_query_heads, + num_kv_heads=self.config.num_kv_heads, + head_dim=self.config.head_dim, + max_target_length=self.config.max_target_length, + max_prefill_predict_length=self.config.max_prefill_predict_length, + attention_kernel=self.config.attention, + attention_type=self.config.attention_type, + inputs_q_shape=self.dummy_inputs_shape, + inputs_kv_shape=self.dummy_inputs_shape, + mesh=mesh, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + dropout_rate=self.config.dropout_rate, + name="self_attention", + quant=quant, + kv_quant=quantizations.configure_kv_quant(self.config), + q_lora_rank=self.config.q_lora_rank, + kv_lora_rank=self.config.kv_lora_rank, + qk_nope_head_dim=self.config.qk_nope_head_dim, + qk_rope_head_dim=self.config.qk_rope_head_dim, + v_head_dim=self.config.v_head_dim, + max_position_embeddings=self.config.max_position_embeddings, + original_max_position_embeddings=self.config.original_max_position_embeddings, + mscale=self.config.mscale, + rope_factor=self.config.rope_factor, + model_mode=model_mode, + rngs=rngs, + attn_logits_soft_cap=self.config.attn_logits_soft_cap, + ) self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) if self.is_mhc_enabled: @@ -333,7 +335,7 @@ def __init__( rngs=self.rngs, ) - def mlp_op(self, x, deterministic): + def mlp_op(self, x, deterministic, *args, **kwargs): mlp = self.mlp(x, deterministic, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding) return self.with_logical_constraint(mlp) diff --git a/src/maxtext/models/deepseek4.py b/src/maxtext/models/deepseek4.py new file mode 100644 index 0000000000..12b0b83823 --- /dev/null +++ b/src/maxtext/models/deepseek4.py @@ -0,0 +1,274 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DeepSeek-V4 model definition.""" + +from typing import Optional + +from flax import nnx +import flax.linen as nn +from jax.sharding import Mesh + +from maxtext.common.common_types import Config, AttentionType +from maxtext.common.common_types import HyperConnectionType +from maxtext.layers import attention_compressed +from maxtext.layers import initializers +from maxtext.layers import moe +from maxtext.layers import nnx_wrappers +from maxtext.layers import quantizations +from maxtext.models import deepseek +from jax.ad_checkpoint import checkpoint_name + + +class DeepSeek4DecoderLayer(deepseek.DeepSeekGenericLayer): + """DeepSeek-V4 specific decoder layer. + + Note: V4 does not utilize purely dense layers in the initial transformer blocks. + Every layer is a Sparse MoE layer (which internally contains shared dense experts). + + Args: + config: Configuration for the model. + model_mode: The mode of the model (e.g. 'train', 'inference'). + mesh: JAX sharding mesh. + rngs: NNX Rngs. + quant: Optional AQT quantization config. + layer_idx: The index of the layer. + compress_ratio: DeepSeek V4 specific parameter defining the KV cache compression + ratio. Expected values are 0 (no compression, sliding window), 4 (CSA), or 128 (HCA). + is_hash_routing: DeepSeek V4 specific parameter defining if this layer uses + static deterministic hash routing (used in prefix layers). + """ + + def __init__( + self, + config: Config, + model_mode: str, + mesh: Mesh, + rngs: nnx.Rngs, + quant: Optional[quantizations.AqtQuantization] = None, + layer_idx: int = -1, + compress_ratio: Optional[int] = None, + is_hash_routing: Optional[bool] = None, + ) -> None: + super().__init__( + config=config, + model_mode=model_mode, + mesh=mesh, + rngs=rngs, + quant=quant, + layer_idx=layer_idx, + ) + + # DeepSeek V4 applies Hash Routing to the first `config.first_num_hash_layers` layers. + # For the unscannable prefix layers, we can safely determine this using `layer_idx`. + # However, for layers inside `nn.scan` blocks, `layer_idx` is a dynamic JAX tracer + # and cannot be evaluated as a boolean condition. Since all scannable layers occur + # after the hash-routed prefix, the scannable block explicitly passes + # `is_hash_routing=False` to safely bypass this check. + if is_hash_routing is None: + is_hash_routing = layer_idx < config.first_num_hash_layers + self.mlp = moe.RoutedAndSharedMoE( + config=self.config, + mesh=self.mesh, + kernel_init=initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"), + kernel_axes=("embed", None), + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + quant=quant, + is_hash_routing=is_hash_routing, + rngs=rngs, + ) + + if compress_ratio is None: + compress_ratio = config.compress_ratios[layer_idx] + + # Route to LOCAL_SLIDING if compression is disabled for this layer, + # otherwise default to the globally configured attention type (e.g., COMPRESSED). + layer_attention_type = ( + AttentionType.LOCAL_SLIDING if compress_ratio == 0 else AttentionType(self.config.attention_type) + ) + + self.self_attention = attention_compressed.CompressedAttention( + config=self.config, + compress_ratio=compress_ratio, + num_query_heads=self.config.num_query_heads, + num_kv_heads=self.config.num_kv_heads, + head_dim=self.config.head_dim, + max_target_length=self.config.max_target_length, + max_prefill_predict_length=self.config.max_prefill_predict_length, + attention_kernel=self.config.attention, + attention_type=layer_attention_type, + inputs_q_shape=self.dummy_inputs_shape, + inputs_kv_shape=self.dummy_inputs_shape, + mesh=self.mesh, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + dropout_rate=self.config.dropout_rate, + sliding_window_size=self.config.sliding_window_size, + q_lora_rank=self.config.q_lora_rank, + name=f"compressed_attention_layer_{layer_idx}", + quant=quant, + kv_quant=quantizations.configure_kv_quant(config), + model_mode=model_mode, + rngs=rngs, + ) + + # pylint: disable=arguments-differ + def mlp_op(self, inputs, deterministic, *args, **kwargs): + input_ids = kwargs.get("input_ids") + mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp( + inputs=inputs, + input_ids=input_ids, + ) + return self.with_logical_constraint(mlp_lnx), load_balance_loss, moe_bias_updates + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=None, + slot: None | int = None, + kv_cache=None, + attention_metadata=None, + decoder_input_tokens=None, + ): + if isinstance(inputs, tuple): + inputs = inputs[0] + + x = self.with_logical_constraint(inputs) + x = checkpoint_name(x, "decoder_layer_input") + + _, intermediate_inputs = self.self_attention_with_norm_op( + x, + decoder_segment_ids, + decoder_positions, + deterministic, + previous_chunk, + slot, + ) + + layer_output, metadata = self.mhc_mlp( + self.post_attention_norm_op, + self.mlp_op, + x=intermediate_inputs, + mhc_type=HyperConnectionType.MLP_MOE, + deterministic=deterministic, + input_ids=decoder_input_tokens, + ) + load_balance_loss = metadata.get("load_balance_loss", None) + moe_bias_updates = metadata.get("moe_bias_updates", None) + + layer_output = self.dropout_op(layer_output, deterministic=deterministic) + return self.post_process(layer_output, load_balance_loss, moe_bias_updates, kv_cache) + + +class DeepSeek4ScannableBlock(nnx.Module): + """A scannable block containing exactly two DeepSeek V4 layers (HCA and CSA). + + DeepSeek V4 layers alternate `compress_ratio=128` (HCA) and `compress_ratio=4` (CSA) + throughout the middle of the network. This block encapsulates one full `[128, 4]` + cycle so it can be perfectly scanned using JAX `nn.scan`. + """ + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + rngs: nnx.Rngs, + quant: None | quantizations.AqtQuantization = None, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.rngs = rngs + + # Layer 0 in the block: HCA (compress_ratio=128) with Standard MoE (is_hash_routing=False) + self.layers_0 = DeepSeek4DecoderLayer( + config=self.config, + mesh=self.mesh, + model_mode=self.model_mode, + rngs=self.rngs, + quant=self.quant, + compress_ratio=128, + is_hash_routing=False, + ) + + # Layer 1 in the block: CSA (compress_ratio=4) with Standard MoE (is_hash_routing=False) + self.layers_1 = DeepSeek4DecoderLayer( + config=self.config, + mesh=self.mesh, + model_mode=self.model_mode, + rngs=self.rngs, + quant=self.quant, + compress_ratio=4, + is_hash_routing=False, + ) + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + slot=None, + previous_chunk=None, + attention_metadata=None, + kv_cache=None, + ): + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) + inputs = checkpoint_name(inputs, "decoder_layer_input") + y = inputs + + y, _ = self.layers_0( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + + y, _ = self.layers_1( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + + return y, None + + +DeepSeek4LayerToLinen = nnx_wrappers.to_linen_class( + DeepSeek4DecoderLayer, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) + +DeepSeek4ScannableBlockToLinen = nnx_wrappers.to_linen_class( + DeepSeek4ScannableBlock, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) diff --git a/src/maxtext/utils/globals.py b/src/maxtext/utils/globals.py index e3b3aadf2d..48caa91ef1 100644 --- a/src/maxtext/utils/globals.py +++ b/src/maxtext/utils/globals.py @@ -75,6 +75,7 @@ "deepseek2-16b": "deepseek-ai/DeepSeek-V2-Lite", "deepseek3-671b": "deepseek-ai/DeepSeek-V3", "deepseek3.2-671b": "deepseek-ai/DeepSeek-V3.2", + "deepseek4": "deepseek-ai/DeepSeek-V4-Flash", "gpt-oss-20b": "openai/gpt-oss-20b", "gpt-oss-120b": "openai/gpt-oss-120b", "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct", diff --git a/tests/unit/deepseek_v4_vs_reference_test.py b/tests/unit/deepseek_v4_vs_reference_test.py index 1da95a184e..0b75aa9ff4 100644 --- a/tests/unit/deepseek_v4_vs_reference_test.py +++ b/tests/unit/deepseek_v4_vs_reference_test.py @@ -57,13 +57,13 @@ # Tests # ============================================================================== -# HuggingFace reference: https://huggingface.co/deepseek-ai/DeepSeek-V4/blob/main/modeling_deepseek_v4.py +# HuggingFace reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py # pylint: disable=line-too-long from jax.experimental import mesh_utils from jax.sharding import Mesh from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.configs import pyconfig from maxtext.layers.attention_compressed import CompressedAttention -from maxtext.layers.embeddings import DeepSeekV4RotaryEmbedding as MTRope + from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.models.deepseek_v4.modeling_deepseek_v4 import DeepseekV4Attention from transformers.models.deepseek_v4.modeling_deepseek_v4 import DeepseekV4RotaryEmbedding as PTRope @@ -75,7 +75,7 @@ class DeepSeekV4RotaryEmbeddingTest(unittest.TestCase): def setUp(self): self.batch_size = 2 - self.seq_len = 16 + self.seq_len = 4096 self.head_dim = 128 self.num_heads = 4 self.main_rope_theta = 10000.0 @@ -408,6 +408,8 @@ def setUp(self): self.q_lora_rank = 32 self.o_groups = 2 self.o_lora_rank = 64 + self.qk_rope_head_dim = 64 + self.partial_rotary_factor = self.qk_rope_head_dim / self.head_dim self.rngs = nnx.Rngs(0) @@ -431,8 +433,12 @@ def setUp(self): layer_types=["sliding_attention"], num_hidden_layers=1, rope_parameters={ - "main": {"rope_type": "default", "rope_theta": 10000.0, "partial_rotary_factor": 1.0}, - "compress": {"rope_type": "default", "rope_theta": 160000.0, "partial_rotary_factor": 1.0}, + "main": {"rope_type": "default", "rope_theta": 10000.0, "partial_rotary_factor": self.partial_rotary_factor}, + "compress": { + "rope_type": "default", + "rope_theta": 160000.0, + "partial_rotary_factor": self.partial_rotary_factor, + }, }, sliding_window=2048, attention_dropout=0.0, @@ -524,9 +530,13 @@ def _run_e2e_test(self, layer_type, is_packed=False): "compressed_sparse_attention": self.pt_config.compress_rates["compressed_sparse_attention"], "heavily_compressed_attention": self.pt_config.compress_rates["heavily_compressed_attention"], } + compress_ratio = compress_ratio_map[layer_type] + layer_attention_type = AttentionType.LOCAL_SLIDING if compress_ratio == 0 else AttentionType.COMPRESSED + mt_attn = CompressedAttention( config=mt_config, - compress_ratio=compress_ratio_map[layer_type], + compress_ratio=compress_ratio, + attention_type=layer_attention_type, num_query_heads=self.num_heads, num_kv_heads=1, head_dim=self.head_dim, @@ -540,14 +550,6 @@ def _run_e2e_test(self, layer_type, is_packed=False): rngs=self.rngs, ) self.mt_attn = mt_attn - if layer_type == "sliding_attention": - rope_factor = self.pt_config.rope_parameters["main"]["partial_rotary_factor"] - mt_rope = MTRope(head_dim=self.head_dim, partial_rotary_factor=rope_factor, rope_theta=10000.0) - else: - rope_factor = self.pt_config.rope_parameters["compress"]["partial_rotary_factor"] - mt_rope = MTRope(head_dim=self.head_dim, partial_rotary_factor=rope_factor, rope_theta=160000.0) - - mt_attn.rotary_embedding = mt_rope # 3. Copy Weights self._copy_linear(mt_attn.wq_a, ref_attn.q_a_proj) @@ -652,8 +654,7 @@ def _run_e2e_test(self, layer_type, is_packed=False): print(f"top_k_indices mismatches: {num_mismatches}") # 6. Execute MaxText - - mt_out = mt_attn(x_mt, x_mt, segs_mt, pos_mt, deterministic=True, model_mode=MODEL_MODE_TRAIN) + mt_out, _ = mt_attn(x_mt, x_mt, segs_mt, pos_mt, deterministic=True, model_mode=MODEL_MODE_TRAIN) # 7. Asserts if not is_packed: @@ -771,7 +772,7 @@ def setUp(self): "vocab_size": self.vocab_size, "first_num_hash_layers": 3, "decoder_block": "deepseek", - "model_name": "deepseek4", + "model_name": "deepseek4-284b", "attention": "dot_product", "base_mlp_dim": 256, "base_moe_mlp_dim": 256, @@ -809,7 +810,7 @@ def test_hash_router(self): ) # Sync weights - mx_moe.tid2eid.value = jnp.array(pt_router.tid2eid.numpy()) + mx_moe.tid2eid.value = jnp.array(pt_router.tid2eid.numpy(), dtype=jnp.float32) mx_moe.gate.kernel.value = jnp.array(pt_router.weight.detach().numpy()).T hidden_states = torch.randn(self.batch_size, self.seq_len, self.hidden_dim) @@ -910,7 +911,7 @@ def test_swiglu_clamp(self): "topk_routing_group": 1, "mlp_activations_limit": limit, "decoder_block": "deepseek", - "model_name": "deepseek4", + "model_name": "deepseek4-284b", "attention": "dot_product", "base_mlp_dim": 256, "base_moe_mlp_dim": 256, diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index 1975ad1abf..41557c8c3c 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -804,6 +804,26 @@ def test_deepseek32(self): ) ) + def test_deepseek4(self): + # test deepseek4 compile + compiled_trainstep_file = "/tmp/test_deepseek4.pickle" + train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-256", + "use_iota_embed=true", + "compile_topology_num_slices=1", + "model_name=deepseek4-284b", + "per_device_batch_size=1", + "max_target_length=1024", + "attention=dot_product", + "dtype=bfloat16", + "weight_dtype=bfloat16", + ) + ) + @pytest.mark.cpu_only def test_indexer_dense_warmup(self): # test deepseek3.2 with sparse attention