Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/common/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class DecoderBlockType(enum.Enum):
SIMPLE_MLP = "simple_mlp"
LLAMA4 = "llama4"
OLMO3 = "olmo3"
DEEPSEEK4 = "deepseek4"


class AttentionType(enum.Enum):
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/common/metric_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _log_training_metrics(self, metrics, step):

if self.config.num_experts > 1:
moe_lb_loss = scalars.get("learning/moe_lb_loss", 0.0)
log_parts.append(f"moe_lb_loss: {moe_lb_loss:.3f}")
log_parts.append(f"moe_lb_loss: {moe_lb_loss:.6f}")

if self.config.mtp_num_layers > 0:
mtp_loss = scalars.get("learning/mtp_loss", 0.0)
Expand All @@ -217,7 +217,7 @@ def _log_eval_metrics(self, metrics, step):
f"avg_z_loss={scalars.get('eval/avg_z_loss', 0.0):.3f}",
]
if self.config.num_experts > 1:
log_parts.append(f"avg_moe_lb_loss={scalars['eval/avg_moe_lb_loss']:.3f}")
log_parts.append(f"avg_moe_lb_loss={scalars['eval/avg_moe_lb_loss']:.6f}")
if self.config.mtp_num_layers > 0:
log_parts.extend(
[
Expand Down
68 changes: 68 additions & 0 deletions src/maxtext/configs/models/deepseek4-tiny.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Tiny DeepSeek V4 model config for CPU execution and testing.

base_emb_dim: 64
base_num_query_heads: 4
base_num_kv_heads: 1
base_num_decoder_layers: 43
base_mlp_dim: 64
base_moe_mlp_dim: 64
vocab_size: 129280
head_dim: 32

# --- Standard Defaults ---
enable_dropout: false
logits_via_embedding: false
normalization_layer_epsilon: 1.0e-6

# --- V4 Specific Architectural Keys ---
decoder_block: "deepseek4"
mhc_expansion_rate: 4
first_num_hash_layers: 3
indexer_head_dim: 32
indexer_n_heads: 4
indexer_topk: 16

# Note: layers 0 and 1 are not compressed (their compress_ratios entries are 0).
# The 44th layer (the MTP module, compress_ratio=0) has been explicitly dropped for now.
# This leaves exactly 43 layers: 2 prefix [0, 0] + 40 scanned (alternating 4 and 128) + 1 suffix [4].
compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4]

# --- MoE configuration ---
mlp_activations: ["silu", "linear"]
num_experts: 16
num_experts_per_tok: 4
shared_experts: 1
routed_score_func: "sqrtsoftplus"
routed_bias: true
routed_bias_update_rate: 0.001
load_balance_loss_weight: 0.0001
adamw_mask: [".*gate.*bias.*"]

# --- Attention configuration ---
attention: 'dot_product'
attention_type: 'compressed'
q_lora_rank: 16
o_groups: 4
o_lora_rank: 16
sliding_window_size: 32

# --- RoPE ---

rope_type: "default"
rope_max_timescale: 10000 # Main RoPE theta
compressed_rope_max_timescale: 160000 # Compressed RoPE theta
max_position_embeddings: 4096
67 changes: 67 additions & 0 deletions src/maxtext/configs/models/deepseek4.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Model config for DeepSeek V4.

base_emb_dim: 4096
base_num_query_heads: 64
base_num_kv_heads: 1
base_num_decoder_layers: 43
base_mlp_dim: 2048
base_moe_mlp_dim: 2048
vocab_size: 129280

# --- Standard Defaults ---
enable_dropout: false
logits_via_embedding: false
normalization_layer_epsilon: 1.0e-6

# --- V4 Specific Architectural Keys ---
decoder_block: "deepseek4"
mhc_expansion_rate: 4
first_num_hash_layers: 3
indexer_head_dim: 128
indexer_n_heads: 64
indexer_topk: 512

# Note: layers 0 and 1 are not compressed (their compress_ratios entries are 0).
# The 44th layer (the MTP module, compress_ratio=0) has been explicitly dropped for now.
# This leaves exactly 43 layers: 2 prefix [0, 0] + 40 scanned (alternating 4 and 128) + 1 suffix [4].
compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4]

# --- MoE configuration ---
mlp_activations: ["silu", "linear"]
num_experts: 256
num_experts_per_tok: 6
shared_experts: 1
routed_score_func: "sqrtsoftplus"
routed_bias: true
routed_bias_update_rate: 0.001
load_balance_loss_weight: 0.0001
adamw_mask: [".*gate.*bias.*"]

# --- Attention configuration ---
attention: 'dot_product'
attention_type: 'compressed'
q_lora_rank: 1024
o_groups: 8
o_lora_rank: 1024
sliding_window_size: 128

# --- RoPE ---

rope_type: "default"
rope_max_timescale: 10000 # Main RoPE theta
compressed_rope_max_timescale: 160000 # Compressed RoPE theta
max_position_embeddings: 1048576
5 changes: 3 additions & 2 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ class ProfilerType(str, Enum):
"deepseek3-671b-batchsplit",
"deepseek3-test",
"deepseek3-tiny",
"deepseek4-tiny",
"deepseek3.2-671b",
"deepseek4",
"deepseek-custom",
Expand Down Expand Up @@ -553,7 +554,7 @@ class Attention(BaseModel):
"autoselected",
description="The attention algorithm to use (dot_product, flash, etc).",
)
attention_type: Literal["global", "local_sliding", "chunk", "mla", "full"] = Field(
attention_type: Literal["global", "local_sliding", "chunk", "mla", "full", "compressed"] = Field(
"global", description="The variant of attention to use."
)
share_kv_projections: bool = Field(
Expand Down Expand Up @@ -2980,7 +2981,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
)
if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1:
raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.")
if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block != DecoderBlockType.DEEPSEEK:
if self.routed_bias and self.routed_bias_update_rate > 0.0 and self.decoder_block not in (DecoderBlockType.DEEPSEEK, DecoderBlockType.DEEPSEEK4):
raise ValueError("Loss-free load balancing is only supported for the DeepSeek decoder block.")
if self.model_name.startswith("deepseek4") and self.first_num_hash_layers > 0 and self.use_ring_of_experts:
raise ValueError("DeepSeek V4 hash routing is currently not supported with ring of experts.")
Expand Down
7 changes: 2 additions & 5 deletions src/maxtext/layers/attention_compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,10 +694,6 @@ def __init__(
self.q_lora_rank = q_lora_rank
self.compress_ratio = compress_ratio

# Determine the correct underlying attention type based on the compress_ratio
if self.compress_ratio == 0:
attention_type = AttentionType.LOCAL_SLIDING

super().__init__(
config=config,
num_query_heads=num_query_heads,
Expand Down Expand Up @@ -727,6 +723,7 @@ def __init__(
use_bias_in_projections=use_bias_in_projections,
name=name,
rngs=rngs,
rope_type="deepseek4",
**kwargs,
)

Expand Down Expand Up @@ -1047,7 +1044,7 @@ def __call__(
# -> [batch, q_length, emb_dim]
final_out = self.o_b_proj(grouped_flat)

return final_out
return final_out, None


def compressed_attention(
Expand Down
8 changes: 8 additions & 0 deletions src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
Qwen3OmniMoeVisionRotaryEmbedding,
RotaryEmbedding,
YarnRotaryEmbedding,
DeepSeekV4RotaryEmbedding,
PartialRotaryEmbedding,
Gemma4PartialRotaryEmbedding,
)
Expand Down Expand Up @@ -850,6 +851,13 @@ def init_rotary_embedding(self):
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
elif rope_type == "deepseek4":
rotary_embedding = DeepSeekV4RotaryEmbedding(
head_dim=rope_embedding_dims,
partial_rotary_factor=self.partial_rotary_factor if self.partial_rotary_factor is not None else 1.0,
rope_theta=self.rope_max_timescale,
dtype=self.dtype,
)
elif self.is_qwen3_hybrid:
rotary_embedding = PartialRotaryEmbedding(
min_timescale=self.config.rope_min_timescale,
Expand Down
Loading