Skip to content

[DeepSeek-V4] Implement model integration, decoders, and configuration stack#4153

Open
parambole wants to merge 4 commits into
mainfrom
dsv4_model_integrate
Open

[DeepSeek-V4] Implement model integration, decoders, and configuration stack#4153
parambole wants to merge 4 commits into
mainfrom
dsv4_model_integrate

Conversation

@parambole

@parambole parambole commented Jun 11, 2026

Copy link
Copy Markdown
Collaborator

Description

This PR introduces native architectural and routing support for the DeepSeek V4 model in MaxText.

Why & What: DeepSeek V4 introduces non-uniform architectural features that require explicit configuration unrolling. This PR solves the integration by implementing:

  • Compressed Attention (CSA/HCA): Bypasses standard MLA instantiation and natively integrates DeepSeek V4's alternating CSA and HCA attention blocks.
  • Hybrid Routing: Implements DeepSeek's transition from fixed Hash Routing (early layers) to learned Token Routing (later layers) natively within the MoE framework.
  • Architectural Scanning: Unrolls the 44-layer configuration to properly handle the [0, 0] prefix compression ratios, the perfectly alternating [4, 128] scanned middle layers, and the [4, 0] suffix layers.

Tests

  • Unit Tests: Verified mathematical parity against reference implementations using tests/unit/deepseek_v4_vs_reference_test.py.
  • E2E Compilation: Successfully compiled the full DeepSeek V4 model on a simulated v5p-512 mesh to guarantee memory constraints and HLO generation.

Compile Command to Reproduce:

python3  -m  maxtext.trainers.pre_train.train_compile  src/maxtext/configs/base.yml
  base_output_directory=/tmp/maxtext_logs
  run_name=dsv4_v5p512_compile
  per_device_batch_size=1
  enable_checkpointing=false
  model_name=deepseek4
  compile_topology=v5p-512
  compile_topology_num_slices=1
  ici_fsdp_parallelism=-1
  steps=1
  max_target_length=4096
  async_checkpointing=false
  tokenizer_type=huggingface
  tokenizer_path=deepseek-ai/DeepSeek-V3
  attention=dot_product
  dtype=bfloat16
  weight_dtype=bfloat16
  megablox=False
  sparse_matmul=False
  dataset_type=synthetic
  scan_layers=true

Proof of Compilation:

Memory analysis: CompiledMemoryStats(generated_code_size_in_bytes=260855808, argument_size_in_bytes=18401962496, output_size_in_bytes=18401889280, alias_size_in_bytes=18401880576, temp_size_in_bytes=94892786400, host_generated_code_size_in_bytes=0, host_argument_size_in_bytes=0, host_output_size_in_bytes=0, host_alias_size_in_bytes=0, host_temp_size_in_bytes=0)

Checklist

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 11, 2026

Copy link
Copy Markdown

@parambole parambole force-pushed the dsv4_model_integrate branch 2 times, most recently from 2a19018 to 23adce0 Compare June 12, 2026 20:00
@parambole parambole marked this pull request as ready for review June 12, 2026 20:09
@parambole parambole changed the title Add DeepSeek V4 architecture support [DeepSeek-V4] Implement model integration, decoders, and configuration stack Jun 12, 2026
@github-actions

Copy link
Copy Markdown

🤖 Hi @parambole, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @parambole, but I was unable to process your request. Please see the logs for more details.

Comment thread src/maxtext/configs/models/deepseek4.yml Outdated
@parambole parambole force-pushed the dsv4_model_integrate branch from 23adce0 to 6deaacc Compare June 12, 2026 21:17
Comment thread src/maxtext/configs/models/deepseek4-284b.yml

@entrpn entrpn left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just one comment, everything else looks good.

@RissyRan

Copy link
Copy Markdown
Collaborator

Are you able to have a real run and check profile to see if the scan blocks order as expected? Compile test won't be able to verify a RunTime error.

@github-actions

Copy link
Copy Markdown

🤖 Hi @parambole, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @parambole, but I was unable to process your request. Please see the logs for more details.

This commit introduces full support for DeepSeek V4 by integrating its
compressed attention mechanisms, MoE routing, and architectural layers.

Key changes:
- Add `deepseek4.yml` configuration and `DeepSeek4DecoderLayer` implementation.
- Implement hybrid Hash Routing and Token Routing for MoE layers.
- Add prefix/suffix layer unrolling for non-uniform compression blocks.
- Fix Pydantic validation for base MLP dimensions.
- Bypass MLA instantiation in favor of native CompressedAttention (CSA/HCA).
@parambole parambole force-pushed the dsv4_model_integrate branch from 6deaacc to 5953a73 Compare June 15, 2026 19:40
base_mlp_dim: 2048
base_moe_mlp_dim: 2048
vocab_size: 129280

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add head_dim: 512 here

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

partial_rotary_factor=self.config.qk_rope_head_dim / self.config.head_dim

@shuningjin shuningjin left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the integration! Main suggestions:

  1. Update scan logic.
  • Note first_num_hash_layers=3 for the prefix layers. Followed by [HCA-128, CSA-4] cycles. There is no suffix. This is true for both flash and pro.
  1. Check two RoPE theta for sliding_window and hca_or_csa.
  • hca_or_csa should use theta2, rather than a mix of theta1 & theta2
  1. Add unit test for train compile. example

See in line for details and other minor comments.

# See the License for the specific language governing permissions and
# limitations under the License.

# model config for DeepSeek V4

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which version? Seems like flash: https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash/blob/main/config.json?

For the file name deepseek4.yml, perhaps deepseek4-284b.yml?

routed_score_func: "sqrtsoftplus"

# --- Attention configuration ---
attention: 'dot_product'

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We typically do not put attention in config. Suggest to remove here, and add check in types.py. E.g.,

if self.moba and self.attention not in ("dot_product"):
raise ValueError("MoBA is only supported with dot_product attention.")

Comment thread src/maxtext/layers/attentions.py Outdated
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
elif rope_type == "deepseek4":

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change rope_type == "deepseek4" to

if self.config.model_name.startswith("deepseek4"): 
  1. Currently, we only have three rope_type

    rope_type: "default" # one of "default", "llama3.1" or "yarn"

    Recall you set rope_type: default in deepseek4.yml. I notice you attempt to override function arg to rope_type=deepseek4. This is confusing and unnecessary.

  2. most models use the name to differentiate model-specific rope. e.g.,

if self.config.model_name.startswith("gemma4"):

use_bias_in_projections=use_bias_in_projections,
name=name,
rngs=rngs,
rope_type="deepseek4",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can remove, see the other comment


# Note: Layers (0,1) are not compressed.
# The 44th layer (MTP module with compress_ratio=0) has been explicitly dropped for now.
# This leaves exactly 43 layers: 2 prefix [0,0] + 40 scanned + 1 suffix [4].

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be: 3 prefix [0,0,4] + 40 scanned. Note first_num_hash_layers=3 correspond to layer 0, 1, 2.

q_lora_rank: The rank for the LoRA projection in the compressed query.
compress_ratio: The compression ratio for the compressor.
"""
"""Initializes the Compressed Attention module."""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

an underlying HCA or CSA compressor based on the provided layer_type.

currently seems using compress_ratio (0, =4, >4). although I would prefer if we can use layer_type instead


"""Initializes the Compressed Attention module."""

duplicate line, can remove


Also, might worth adding more docstring.

  • highlight: Shared-KV, MQA, 3 different attention (sliding, hca, csa), different rope theta
  • like HF

elif rope_type == "deepseek4":
rotary_embedding = DeepSeekV4RotaryEmbedding(
head_dim=rope_embedding_dims,
partial_rotary_factor=self.partial_rotary_factor if self.partial_rotary_factor is not None else 1.0,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you clarify when we use partial_rotary_factor vs. 1.0?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The usage of DeepSeekV4RotaryEmbedding with two theta seems different from HF.

  1. HF:
  1. In this PR, there are two rotary embedding:
  • self.compress_rotary_embedding, initialized here with config.compressed_rope_max_timescale
  • self.rotary_embedding=init_rotary_embedding() in attention.py, initialized via base class that uses config.rope_max_timescale
  • elif rope_type == "deepseek4":
    rotary_embedding = DeepSeekV4RotaryEmbedding(
    head_dim=rope_embedding_dims,
    partial_rotary_factor=self.partial_rotary_factor if self.partial_rotary_factor is not None else 1.0,
    rope_theta=self.rope_max_timescale,
    dtype=self.dtype,
    )
  1. This PR:
  • sliding window: main_rope
  • CSA / HCA: compressed_rope + main_rope [inconsistent]
  1. The unit test for attention is passing as you override self.rotary_embedding for CSA / HCA.
  • The test logic is different from how we actually attention in model.
  • rope_factor = self.pt_config.rope_parameters["compress"]["partial_rotary_factor"]
    mt_rope = MTRope(head_dim=self.head_dim, partial_rotary_factor=rope_factor, rope_theta=160000.0)
    mt_attn.rotary_embedding = mt_rope

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For RoPE, it is safer to test with longer length. e.g., seq_len=4096

self.seq_len = 4096

for seqlen in [128, 5000, 10000]:

Comment thread src/maxtext/layers/attentions.py Outdated
head_dim=rope_embedding_dims,
partial_rotary_factor=self.partial_rotary_factor if self.partial_rotary_factor is not None else 1.0,
rope_theta=self.rope_max_timescale,
dtype=self.dtype,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you make DeepSeekV4RotaryEmbedding inherits RotaryEmbedding, to be consistent as other RoPE classes?

  • dtype=self.dtype perhaps should be fprop_dtype # The dtype of the output ?
  • class DeepSeekV4RotaryEmbedding(nnx.Module):
  • class RotaryEmbedding(nnx.Module):
    """Rotary Position Embedding."""
    def __init__(
    self,
    min_timescale: int,
    max_timescale: int,
    mesh: Mesh,
    embedding_dims: int = 0,
    cast_as_fprop_dtype: bool = True,
    fprop_dtype: DType = jnp.bfloat16,
    shard_mode: ShardMode = ShardMode.AUTO,
    # Not used in RotaryEmbedding but passed in by nnx.bridge.to_linen.
    # TODO: Remove when bridge no longer needed
    rope_linear_scaling_factor: float = 1.0,
    rngs: nnx.Rngs = None,

@parambole parambole force-pushed the dsv4_model_integrate branch from 5eb3336 to 0b1a9a5 Compare June 16, 2026 17:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants