Make MoE dispatch/MLP expert-axis batch sharding configurable (fix Mixtral EP throughput)#4179
Open
gulsumgudukbay wants to merge 1 commit into
Open
Conversation
The dispatch/MLP MoE activations are already expert-sharded via activation_exp. Since AI-Hypercomputer#4007, their batch dim also maps to activation_batch_moe, which includes 'expert'. Under single-node expert parallelism (ici_expert_parallelism=-1) this double-maps two tensor dims onto the 'expert' mesh axis, so GSPMD falls back from expert-parallel AllToAll to FSDP-style AllGather+ReduceScatter, regressing throughput for few-large-expert models (e.g. Mixtral-8x7b: ~7.4k -> ~10.9k tok/s/device at bs=11 on 8x MI355X). Add a config flag moe_dispatch_no_expert_sharding (default false) that selects a new activation_batch_no_exp rule ([data, fsdp, fsdp_transpose], no 'expert') for the training dispatch/MLP batch axis. Enable it for mixtral-8x7b. Default-false keeps every other model and all TPU/non-EP paths byte-identical; the flag only changes sharding when the 'expert' mesh axis size > 1.
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
PR #4007 added
'expert'toactivation_batch_moeto fix a DeepSeek MoE throughput regression. That change is applied to the post-dispatch core MoE activations (dispatch_axis/mlp_axis) as well, where the tensor's expert dimEis alreadysharded via
activation_exp. For models with few large experts (e.g. Mixtral-8x7b), mapping the batch dim onto'expert'too double-maps two tensor dims onto one mesh axis, so GSPMD abandons the expert-local layout and falls back to FSDP-styleAllGather+ReduceScatter instead of expert-parallel AllToAll, creatiing a large throughput regression under single-node expert parallelism (
ici_expert_parallelism=-1).This PR adds a config flag
moe_dispatch_no_expert_sharding(defaultfalse) that selects, for the training dispatch/MLP batch axis only, a newactivation_batch_no_exprule (['data','fsdp','fsdp_transpose'], i.e. without'expert'). Mixtral-8x7b uses it as true.This is the per-model knob anticipated in the #4007 review discussion (sharding core MoE activations by the expert physical axes rather than the batch dimension), without changing the default for any other model.
Behavior
falseis byte-identical to currentmainfor every model.mixtral-8x7bopts in; no other config inherits it.mixtral-8x7b, the flag only changes sharding when theexpertmesh axis size > 1 (expert parallelism active). Whenexpertis size 1 (TPU/FSDP-primary, or non-expert parallelism GPU) the two axis rules are identical, so those paths are unaffected.Tests
Measured on 1x MI355X node (8 GPUs), JAX 0.9.1,
ici_expert_parallelism=-1,capacity_factor>0:Mixtral recovers ~47% throughput via restored expert-parallel AllToAll; DeepSeek (flag off by default) is unchanged.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.