From 686947bf30657f0a3e390a78b7604afd727a9531 Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Mon, 15 Jun 2026 17:21:56 +0000 Subject: [PATCH] add ep as dp --- src/maxtext/common/common_types.py | 1 + .../configs/custom_mesh_and_rule/ep-as-dp.yml | 85 +++++++++++++++++++ 2 files changed, 86 insertions(+) create mode 100644 src/maxtext/configs/custom_mesh_and_rule/ep-as-dp.yml diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index d4b52207fc..344b85a2f8 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -152,3 +152,4 @@ class CustomRule(enum.Enum): EP_AS_CP = "ep-as-cp" # Support EP only PIPELINE_LARGE_MOE = "pipeline-large-moe" FSDP_2D = "2d-fsdp" + EP_AS_DP = "ep-as-dp" diff --git a/src/maxtext/configs/custom_mesh_and_rule/ep-as-dp.yml b/src/maxtext/configs/custom_mesh_and_rule/ep-as-dp.yml new file mode 100644 index 0000000000..038f84a6de --- /dev/null +++ b/src/maxtext/configs/custom_mesh_and_rule/ep-as-dp.yml @@ -0,0 +1,85 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This rule uses data, FSDP, FSDP_transpose and expert. Expert axis acts as +# data parallelism in components except core dMoE part (between EP all2all). + +mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'expert'] +data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'expert']] +logical_axis_rules: [ + # ========================================== + # Vocabulary Embedding + # ========================================== + # Vocab Activations + ['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_embed_and_logits_batch_sequence', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_vocab', []], + # Vocab Weights + ['vocab', []], + ['embed_vocab', ['fsdp', 'fsdp_transpose']], + # ========================================== + # Attention + # ========================================== + # Attention Activations + ['activation_batch_attn', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_heads', []], + ['activation_kv_heads', []], + ['activation_length_attn', ['context']], + ['activation_q_length', ['context']], + ['activation_kv_length', []], + ['activation_embed_attn', ['tensor', 'tensor_transpose']], + ['activation_kv', []], + ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_kv_head_dim', []], + # Attention Weights + ['q_lora', ['fsdp']], + ["q_lora_up_proj", []], + ['kv_lora', ['fsdp']], + ["kv_lora_up_proj", []], + # ========================================== + # Mixture of Experts (MoE) + # ========================================== + # MoE Activations + ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_length_moe', []], + ['activation_norm_length_moe', []], + ['activation_embed_moe', []], + ['activation_mlp_moe', []], + ['activation_exp', ['expert']], + # MoE Weights + ['exp', 'expert'], + ['mlp_moe', ['fsdp_transpose']], + ['embed_moe', ['fsdp', 'fsdp_transpose']], + # ========================================== + # Standard MLP / Dense Layers / Model Structure + # ========================================== + # Dense Activations + ['activation_mlp', []], + # Note activation batch and length also get used in vocab + ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_length', []], + ['activation_norm_length', []], + ['activation_embed', []], + ['activation_stage', []], + # General Weights + ['mlp', ['fsdp_transpose']], + # GDN (linear-attention) projections shard like 'mlp' during training; the + # vLLM serving config overrides this to match tpu-inference's ATTN_HEAD order. + ['embed', ['fsdp', 'fsdp_transpose']], + ['embed', ['fsdp']], + # ========================================== + # Deprecated / Scheduled for Removal + # ========================================== + ['exp_with_fsdp', 'fsdp'], + ]