Make MoE dispatch/MLP expert-axis batch sharding configurable (fix Mixtral EP throughput) #4179

Open
gulsumgudukbay wants to merge 1 commit into AI-Hypercomputer:main from ROCm:fix-moe-expert-parallel-sharding
@gulsumgudukbay

Collaborator

Description

PR #4007 added 'expert' to activation_batch_moe to fix a DeepSeek MoE throughput regression. That change also applies to the post-dispatch core MoE activations (dispatch_axis/mlp_axis), where the tensor's expert dim E is already sharded via activation_exp. For models with few large experts (e.g. Mixtral-8x7b), mapping the batch dim onto 'expert' as well puts two tensor dims on one mesh axis, so GSPMD abandons the expert-local layout and falls back to FSDP-style AllGather+ReduceScatter instead of expert-parallel AllToAll, creating a large throughput regression under single-node expert parallelism (ici_expert_parallelism=-1).
This PR adds a config flag moe_dispatch_no_expert_sharding (default false) that selects, for the training dispatch/MLP batch axis only, a new activation_batch_no_exp rule (['data','fsdp','fsdp_transpose'], i.e. without 'expert'). The mixtral-8x7b config sets it to true.
This is the per-model knob anticipated in the #4007 review discussion (sharding core MoE activations by the expert physical axes rather than the batch dimension), without changing the default for any other model.
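
A minimal sketch of the double-mapping described above, in plain Python rather than the actual MaxText code. Only the contents of activation_batch_no_exp are quoted from this description; the contents of activation_batch_moe and activation_exp, the helper names, and the two-dim layout are assumptions made for illustration.

```python
# Illustrative logical-axis rules. Only activation_batch_no_exp is quoted from
# the PR description; the other two entries are assumptions for this sketch.
LOGICAL_AXIS_RULES = {
    "activation_batch_moe": ("data", "fsdp", "fsdp_transpose", "expert"),
    "activation_batch_no_exp": ("data", "fsdp", "fsdp_transpose"),
    "activation_exp": ("expert",),
}


def dispatch_batch_axis(moe_dispatch_no_expert_sharding: bool) -> str:
    """Hypothetical helper: pick the logical name for the dispatch/MLP batch dim."""
    if moe_dispatch_no_expert_sharding:
        return "activation_batch_no_exp"
    return "activation_batch_moe"


def double_mapped_mesh_axes(logical_dims):
    """Return the mesh axes that more than one tensor dim would be sharded over."""
    seen, clashes = set(), set()
    for logical_name in logical_dims:
        for mesh_axis in LOGICAL_AXIS_RULES.get(logical_name, ()):
            if mesh_axis in seen:
                clashes.add(mesh_axis)
            seen.add(mesh_axis)
    return clashes


# Post-dispatch activation, reduced to its (expert, batch) dims for the sketch.
for flag in (False, True):
    dims = ("activation_exp", dispatch_batch_axis(flag))
    print(flag, sorted(double_mapped_mesh_axes(dims)))
```

With the flag off, both dims claim the 'expert' mesh axis; with it on, only the expert dim does, which is the layout the description says lets GSPMD keep expert-parallel AllToAll.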

Behavior

  • Default false is byte-identical to current main for every model.
  • Only mixtral-8x7b opts in; no other config inherits it.
  • For mixtral-8x7b, the flag only changes sharding when the expert mesh axis size is > 1 (expert parallelism active). When the expert axis has size 1 (TPU/FSDP-primary, or GPU without expert parallelism), the two axis rules are identical, so those paths are unaffected.
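
The last point can be illustrated with a small check, assuming that a mesh axis of size 1 never splits a tensor dim. The mesh shapes and the effective_axes helper below are hypothetical, and the with-expert rule is assumed to be the no-expert rule plus 'expert'.

```python
# Rule contents: the second is quoted from the PR description; the first is an
# assumption (same axes plus 'expert') made for this sketch.
BATCH_WITH_EXPERT = ("data", "fsdp", "fsdp_transpose", "expert")
BATCH_NO_EXPERT = ("data", "fsdp", "fsdp_transpose")


def effective_axes(rule, mesh_shape):
    """Keep only the mesh axes that actually split data (size > 1)."""
    return tuple(axis for axis in rule if mesh_shape.get(axis, 1) > 1)


# Hypothetical 8-device meshes: FSDP-primary vs. expert-parallel.
fsdp_mesh = {"data": 1, "fsdp": 8, "fsdp_transpose": 1, "expert": 1}
ep_mesh = {"data": 1, "fsdp": 1, "fsdp_transpose": 1, "expert": 8}

for name, mesh in (("fsdp", fsdp_mesh), ("ep", ep_mesh)):
    same = effective_axes(BATCH_WITH_EXPERT, mesh) == effective_axes(BATCH_NO_EXPERT, mesh)
    print(name, "rules identical:", same)
```

Under these assumptions the two rules coincide on the FSDP-primary mesh and differ only on the expert-parallel one.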

Tests

Measured on 1x MI355X node (8 GPUs), JAX 0.9.1, ici_expert_parallelism=-1, capacity_factor>0:

| Model | bs | tok/s/device (before) | tok/s/device (this PR) | train-step all-to-all |
|---|---|---|---|---|
| Mixtral-8x7b | 11 | ~7,400 | ~10,900 | 0 -> 5 |
| Mixtral-8x7b | 6 | ~7,500 | ~11,400 | 0 -> 5 |
| DeepSeek-v2-lite-16b | 8 | ~17,800 | ~17,800 (unchanged) | 0 (unchanged) |

Mixtral throughput improves by ~47% (bs=11) to ~52% (bs=6) via restored expert-parallel AllToAll; DeepSeek (flag off by default) is unchanged.
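
The relative gain can be checked from the approximate Mixtral figures in the table. This is plain arithmetic on the reported numbers, not a new measurement; the helper name is illustrative.

```python
# (batch size, tok/s/device before, tok/s/device with this PR), copied from the
# table above; all values are approximate.
MIXTRAL_RUNS = [(11, 7400, 10900), (6, 7500, 11400)]


def relative_gain_percent(before, after):
    """Throughput gain of `after` over `before`, in percent."""
    return (after - before) / before * 100.0


for bs, before, after in MIXTRAL_RUNS:
    print(f"bs={bs}: +{relative_gain_percent(before, after):.1f}%")
```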

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

The dispatch/MLP MoE activations are already expert-sharded via
activation_exp. Since AI-Hypercomputer#4007, their batch dim also maps to
activation_batch_moe, which includes 'expert'. Under single-node expert
parallelism (ici_expert_parallelism=-1) this double-maps two tensor dims
onto the 'expert' mesh axis, so GSPMD falls back from expert-parallel
AllToAll to FSDP-style AllGather+ReduceScatter, regressing throughput for
few-large-expert models (e.g. Mixtral-8x7b: ~7.4k tok/s/device with the
regression vs. ~10.9k with this fix, at bs=11 on 8x MI355X).

Add a config flag moe_dispatch_no_expert_sharding (default false) that
selects a new activation_batch_no_exp rule ([data, fsdp, fsdp_transpose],
no 'expert') for the training dispatch/MLP batch axis. Enable it for
mixtral-8x7b. Default-false keeps every other model and all TPU/non-EP
paths byte-identical; the flag only changes sharding when the 'expert'
mesh axis size > 1.
@codecov

codecov Bot commented Jun 16, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
