Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/common/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class DecoderBlockType(enum.Enum):
SIMPLE_MLP = "simple_mlp"
LLAMA4 = "llama4"
OLMO3 = "olmo3"
DEEPSEEK4 = "deepseek4"


class AttentionType(enum.Enum):
Expand Down
64 changes: 64 additions & 0 deletions src/maxtext/configs/models/deepseek4-284b.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Model config for DeepSeek-V4-Flash 284B (https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash)

base_emb_dim: 4096
base_num_query_heads: 64
base_num_kv_heads: 1
base_num_decoder_layers: 43
base_mlp_dim: 2048
base_moe_mlp_dim: 2048
vocab_size: 129280
head_dim: 512

# --- Standard Defaults ---
enable_dropout: false
logits_via_embedding: false
normalization_layer_epsilon: 1.0e-6

# --- V4 Specific Architectural Keys ---
decoder_block: "deepseek4"
mhc_expansion_rate: 4
first_num_hash_layers: 3
indexer_head_dim: 128
indexer_n_heads: 64
indexer_topk: 512

# Note: layers 0, 1 and 2 are the prefix layers (compress_ratios [0, 0, 4]).
# The 44th layer (the MTP module, compress_ratio=0) is intentionally omitted for now.
# That leaves 43 layers: the 3 prefix layers plus 40 scanned layers alternating [128, 4].
compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4]
Comment thread
parambole marked this conversation as resolved.

# --- MoE configuration ---
mlp_activations: ["silu", "linear"]
num_experts: 256
num_experts_per_tok: 6
mlp_activations_limit: 10
shared_experts: 1
routed_score_func: "sqrtsoftplus"

Comment thread
parambole marked this conversation as resolved.
# --- Attention configuration ---
attention_type: 'compressed'
q_lora_rank: 1024
o_groups: 8
o_lora_rank: 1024
sliding_window_size: 128

# --- RoPE ---

rope_type: "default"
rope_max_timescale: 10000 # Main RoPE theta
compressed_rope_max_timescale: 160000 # Compressed RoPE theta
max_position_embeddings: 1048576
6 changes: 4 additions & 2 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class ProfilerType(str, Enum):
"deepseek3-test",
"deepseek3-tiny",
"deepseek3.2-671b",
"deepseek4",
"deepseek4-284b",
"deepseek-custom",
"kimi-k2-1t",
"gemma-7b",
Expand Down Expand Up @@ -553,7 +553,7 @@ class Attention(BaseModel):
"autoselected",
description="The attention algorithm to use (dot_product, flash, etc).",
)
attention_type: Literal["global", "local_sliding", "chunk", "mla", "full"] = Field(
attention_type: Literal["global", "local_sliding", "chunk", "mla", "full", "compressed"] = Field(
"global", description="The variant of attention to use."
)
share_kv_projections: bool = Field(
Expand Down Expand Up @@ -2925,6 +2925,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
raise ValueError("`local_checkpoint_period` must be > 0 for emergency checkpointing.")
if self.moba and self.attention not in ("dot_product"):
raise ValueError("MoBA is only supported with dot_product attention.")
if self.decoder_block == DecoderBlockType.DEEPSEEK4 and self.attention != "dot_product":
raise ValueError("DeepSeek4 decoder block currently only supports dot_product attention.")
if self.use_indexer:
if self.q_lora_rank == 0:
raise NotImplementedError("Sparse indexer has not implemented for q_lora_rank = 0.")
Expand Down
37 changes: 19 additions & 18 deletions src/maxtext/layers/attention_compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,24 +680,23 @@ def __init__(
rngs: Optional[nnx.Rngs] = None,
**kwargs,
):
"""Initializes the CompressedAttention layer.
"""Inherits all standard Attention hyperparameters and selectively instantiates
an underlying HCA or CSA compressor based on the provided `compress_ratio`.

Inherits all standard Attention hyperparameters and selectively instantiates
an underlying HCA or CSA compressor based on the provided `layer_type`.
Highlights of DeepSeek-V4 attention integration:
- Shared-KV: The layer supports decoupling Q and KV heads for heavy compression.
- MQA: Multi-Query Attention used alongside heavy KV compression.
- 3 Different Attention Modes: Sliding Window (prefix), HCA (128x), and CSA (4x).
- Dual RoPE Theta: Uses 10000 for standard uncompressed tokens and 160000 for compressed.

Args:
(See maxtext.layers.attentions.Attention for standard attention arguments)
q_lora_rank: The rank for the LoRA projection in the compressed query.
compress_ratio: The compression ratio for the compressor.
compress_ratio: The compression ratio (0, 4, or 128) for the compressor.
"""
"""Initializes the Compressed Attention module."""
self.q_lora_rank = q_lora_rank
self.compress_ratio = compress_ratio

# Determine the correct underlying attention type based on the compress_ratio
if self.compress_ratio == 0:
attention_type = AttentionType.LOCAL_SLIDING
Comment thread
parambole marked this conversation as resolved.

super().__init__(
config=config,
num_query_heads=num_query_heads,
Expand Down Expand Up @@ -809,20 +808,22 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
rngs=self.rngs,
)

# DeepSeek-V4 uses a separate RoPE theta (160000) for compressed tokens.
# We must instantiate a dedicated rotary embedding for the compressors
self.compress_rotary_embedding = DeepSeekV4RotaryEmbedding(
# Override the base rotary embedding with the correct theta for this layer.
# CSA / HCA layers use compressed_rope_max_timescale (160000).
# Sliding window prefix layers use rope_max_timescale (10000).
rope_theta = self.config.compressed_rope_max_timescale if self.compress_ratio > 0 else self.config.rope_max_timescale
self.rotary_embedding = DeepSeekV4RotaryEmbedding(
head_dim=self.config.head_dim,
partial_rotary_factor=1.0,
rope_theta=self.config.compressed_rope_max_timescale,
dtype=self.dtype,
partial_rotary_factor=self.config.qk_rope_head_dim / self.config.head_dim,
rope_theta=rope_theta,
fprop_dtype=self.dtype,
)

if self.compress_ratio > 4:
self.hca_compressor = DeepseekV4HCACompressor(
config=self.config,
compress_ratio=self.compress_ratio,
rotary_embedding=self.compress_rotary_embedding,
rotary_embedding=self.rotary_embedding,
kernel_init=self.kernel_init,
quant=self.quant,
model_mode=self.model_mode,
Expand All @@ -832,7 +833,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
self.csa_compressor = DeepseekV4CSACompressor(
config=self.config,
compress_ratio=self.compress_ratio,
rotary_embedding=self.compress_rotary_embedding,
rotary_embedding=self.rotary_embedding,
kernel_init=self.kernel_init,
quant=self.quant,
model_mode=self.model_mode,
Expand Down Expand Up @@ -1047,7 +1048,7 @@ def __call__(
# -> [batch, q_length, emb_dim]
final_out = self.o_b_proj(grouped_flat)

return final_out
return final_out, None


def compressed_attention(
Expand Down
10 changes: 10 additions & 0 deletions src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
Qwen3OmniMoeVisionRotaryEmbedding,
RotaryEmbedding,
YarnRotaryEmbedding,
DeepSeekV4RotaryEmbedding,
PartialRotaryEmbedding,
Gemma4PartialRotaryEmbedding,
)
Expand Down Expand Up @@ -850,6 +851,15 @@ def init_rotary_embedding(self):
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
elif self.config.decoder_block == DecoderBlockType.DEEPSEEK4:
# DeepSeek models apply RoPE only to the first `qk_rope_head_dim` channels out of the full `head_dim`.
# We explicitly calculate this ratio so DeepSeekV4RotaryEmbedding can internally slice the correct dimension.
rotary_embedding = DeepSeekV4RotaryEmbedding(
head_dim=self.config.head_dim,
partial_rotary_factor=self.config.qk_rope_head_dim / self.config.head_dim,
rope_theta=self.rope_max_timescale,
fprop_dtype=self.dtype,
)
elif self.is_qwen3_hybrid:
rotary_embedding = PartialRotaryEmbedding(
min_timescale=self.config.rope_min_timescale,
Expand Down
110 changes: 109 additions & 1 deletion src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.models import (
deepseek,
deepseek4,
deepseek_batchsplit,
deepseek_batchsplit_fp8,
gemma,
Expand Down Expand Up @@ -467,6 +468,10 @@ def get_decoder_layers(self):
deepseek.DeepSeekDenseLayerToLinen,
deepseek.DeepSeekMoELayerToLinen,
]
case DecoderBlockType.DEEPSEEK4:
return (
[deepseek4.DeepSeek4ScannableBlockToLinen] if self.config.scan_layers else [deepseek4.DeepSeek4LayerToLinen]
)
case DecoderBlockType.GEMMA:
return [gemma.GemmaDecoderLayerToLinen]
case DecoderBlockType.GEMMA2:
Expand Down Expand Up @@ -632,6 +637,7 @@ def get_norm_layer(self, num_features: int):
DecoderBlockType.MISTRAL,
DecoderBlockType.MIXTRAL,
DecoderBlockType.DEEPSEEK,
DecoderBlockType.DEEPSEEK4,
DecoderBlockType.GEMMA,
DecoderBlockType.GEMMA2,
DecoderBlockType.GEMMA3,
Expand Down Expand Up @@ -1061,6 +1067,17 @@ def __call__(
previous_chunk,
slot,
)
elif cfg.decoder_block == DecoderBlockType.DEEPSEEK4:
y = self._apply_deepseek4_scanned_blocks(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
previous_chunk,
slot,
decoder_input_tokens,
)
else:
RemattedBlockLayer = RemattedBlockLayers[0]
scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval)
Expand Down Expand Up @@ -1195,7 +1212,7 @@ def __call__(
"is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval),
"is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step),
}
if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5, DecoderBlockType.DEEPSEEK4):
layer_kwargs = {"layer_idx": lyr}
kv_cache = None
if kv_caches is not None:
Expand Down Expand Up @@ -1423,6 +1440,97 @@ def _apply_gemma4_scanned_blocks(

return y

def _apply_deepseek4_scanned_blocks(
    self,
    y,
    decoder_segment_ids,
    decoder_positions,
    deterministic,
    model_mode,
    previous_chunk,
    slot,
    decoder_input_tokens,
):
  """Runs the DeepSeek V4 decoder stack: unrolled prefix layers, then one scan.

  The first `cfg.first_num_hash_layers` layers (the Hash Routing prefix layers)
  are instantiated and applied one at a time as `layers_{idx}`. The layers that
  remain alternate `compress_ratio=128` (HCA) and `compress_ratio=4` (CSA); they
  are grouped two at a time and evaluated by a single `nn.scan` over
  `DeepSeek4ScannableBlockToLinen`, registered under the name `scanned_blocks`.

  Example, DeepSeek4-Flash (43 hidden layers):
    - 3 prefix layers (indices 0, 1, 2)
    - 40 scanned layers, i.e. 20 repetitions of the [128, 4] pair

  Args:
    y: Activations entering the first decoder layer.
    decoder_segment_ids: Forwarded unchanged to every layer.
    decoder_positions: Forwarded unchanged to every layer.
    deterministic: Forwarded unchanged to every layer.
    model_mode: Forwarded to every layer call and used to construct the scanned
      block (prefix layers are constructed with `self.model_mode`).
    previous_chunk: Forwarded to every layer.
    slot: Forwarded to every layer.
    decoder_input_tokens: Forwarded by keyword to the prefix layers only.

  Returns:
    The activations after the prefix layers and, if any, the scanned blocks.
  """
  cfg = self.config
  mesh = self.mesh
  prefix_count = cfg.first_num_hash_layers

  # Stage 1: prefix layers, unrolled. Each one receives the extra inputs by keyword.
  prefix_call_kwargs = {
      "previous_chunk": previous_chunk,
      "slot": slot,
      "decoder_input_tokens": decoder_input_tokens,
  }
  for layer_idx in range(prefix_count):
    y, _ = deepseek4.DeepSeek4LayerToLinen(
        config=cfg,
        mesh=mesh,
        name=f"layers_{layer_idx}",
        quant=self.quant,
        model_mode=self.model_mode,
        layer_idx=layer_idx,
    )(
        y,
        decoder_segment_ids,
        decoder_positions,
        deterministic,
        model_mode,
        **prefix_call_kwargs,
    )

  # Stage 2: the alternating HCA (128) / CSA (4) layers, scanned in pairs.
  # NOTE(review): floor division means an odd number of remaining layers would
  # leave the last one unapplied -- presumably configs always keep this even; confirm.
  scan_length = (cfg.num_decoder_layers - prefix_count) // 2
  if scan_length <= 0:
    return y

  # Positional inputs shared (broadcast) across every scan iteration.
  # NOTE(review): `slot` precedes `previous_chunk` here; verify that this matches
  # the positional order of the scannable block's __call__ signature.
  scan_broadcast = (
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
      slot,
      previous_chunk,
  )

  remat_policy = self.get_remat_policy()
  remat_block = self.set_remat_policy([deepseek4.DeepSeek4ScannableBlockToLinen], remat_policy)[0]

  scan_factory = nn.scan(
      remat_block,
      variable_axes={
          "params": cfg.param_scan_axis,
          "cache": 0,
          "intermediates": 0,
          "aqt": 0,
          "_overwrite_with_gradient": 0,
      },
      split_rngs={"params": True, "dropout": cfg.enable_dropout},
      in_axes=(nn.broadcast,) * len(scan_broadcast),
      length=scan_length,
      metadata_params={
          nn.PARTITION_NAME: "layers",
          "abstract_init": False,
      },
  )
  scanned_stack = scan_factory(
      config=cfg,
      mesh=mesh,
      quant=self.quant,
      model_mode=model_mode,
      name="scanned_blocks",
  )
  y, _ = scanned_stack(y, *scan_broadcast)

  return y

def _apply_gemma4_small_layers(
self,
y,
Expand Down
17 changes: 14 additions & 3 deletions src/maxtext/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1803,7 +1803,7 @@ def qwen3_omni_mrope_embedding_as_linen(
)


class DeepSeekV4RotaryEmbedding(nnx.Module):
class DeepSeekV4RotaryEmbedding(RotaryEmbedding):
"""DeepSeek-V4 partial rotary embedding with interleaved frequencies.

DeepSeek-V4 uses an interleaved positional encoding where consecutive channels
Expand All @@ -1822,12 +1822,23 @@ def __init__(
head_dim: int,
partial_rotary_factor: float = 64.0 / 512.0,
rope_theta: float = 10000.0,
dtype: Any = jnp.float32,
fprop_dtype: Any = jnp.float32,
min_timescale: int = 10000,
max_timescale: int = 10000,
mesh: Any = None,
**kwargs,
):
super().__init__(
min_timescale=min_timescale,
max_timescale=max_timescale,
mesh=mesh,
fprop_dtype=fprop_dtype,
**kwargs,
)
self.head_dim = head_dim
self.partial_rotary_factor = partial_rotary_factor
self.rope_theta = rope_theta
self.dtype = dtype
self.fprop_dtype = fprop_dtype

# Compute the partial rotary dimension (rope_head_dim)
self.dim = int(head_dim * partial_rotary_factor)
Expand Down
Loading
Loading