Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/common/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,4 @@ class CustomRule(enum.Enum):
EP_AS_CP = "ep-as-cp" # Support EP only
PIPELINE_LARGE_MOE = "pipeline-large-moe"
FSDP_2D = "2d-fsdp"
EP_AS_DP = "ep-as-dp"
85 changes: 85 additions & 0 deletions src/maxtext/configs/custom_mesh_and_rule/ep-as-dp.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This rule uses data, FSDP, FSDP_transpose and expert. Expert axis acts as
# data parallelism in components except core dMoE part (between EP all2all).

mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'expert']
data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'expert']]
logical_axis_rules: [
# ==========================================
# Vocabulary Embedding
# ==========================================
# Vocab Activations
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_embed_and_logits_batch_sequence', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_vocab', []],
# Vocab Weights
['vocab', []],
['embed_vocab', ['fsdp', 'fsdp_transpose']],
# ==========================================
# Attention
# ==========================================
# Attention Activations
['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_heads', []],
['activation_kv_heads', []],
['activation_length_attn', ['context']],
['activation_q_length', ['context']],
['activation_kv_length', []],
['activation_embed_attn', ['tensor', 'tensor_transpose']],
['activation_kv', []],
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_kv_head_dim', []],
# Attention Weights
['q_lora', ['fsdp']],
["q_lora_up_proj", []],
['kv_lora', ['fsdp']],
["kv_lora_up_proj", []],
# ==========================================
# Mixture of Experts (MoE)
# ==========================================
# MoE Activations
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_length_moe', []],
['activation_norm_length_moe', []],
['activation_embed_moe', []],
['activation_mlp_moe', []],
['activation_exp', ['expert']],
# MoE Weights
['exp', 'expert'],
['mlp_moe', ['fsdp_transpose']],
['embed_moe', ['fsdp', 'fsdp_transpose']],
# ==========================================
# Standard MLP / Dense Layers / Model Structure
# ==========================================
# Dense Activations
['activation_mlp', []],
# Note activation batch and length also get used in vocab
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_length', []],
['activation_norm_length', []],
['activation_embed', []],
['activation_stage', []],
# General Weights
['mlp', ['fsdp_transpose']],
# GDN (linear-attention) projections shard like 'mlp' during training; the
# vLLM serving config overrides this to match tpu-inference's ATTN_HEAD order.
['embed', ['fsdp', 'fsdp_transpose']],
['embed', ['fsdp']],
# ==========================================
# Deprecated / Scheduled for Removal
# ==========================================
['exp_with_fsdp', 'fsdp'],
]
Loading