Skip to content

Refactor moe.p: gmm and a2a unsort#4170

Open
Shuwen-Fang wants to merge 1 commit into
mainfrom
refactor
Open

Refactor moe.p: gmm and a2a unsort#4170
Shuwen-Fang wants to merge 1 commit into
mainfrom
refactor

Conversation

@Shuwen-Fang

@Shuwen-Fang Shuwen-Fang commented Jun 15, 2026

Copy link
Copy Markdown
Collaborator

Description

This is the second PR refactoring sparse_matmul to make chunking activations and future features easier to implement.

Major changes:

  • Extracted logic for customizing GMM up & gate per sharding strategy and routing config into its own helper function get_gmm_for_local_experts
  • Extracted token unsort & comms between EP shards into unsort_output_with_ra2a

Tests

Verified loss and perplexity is identical on main vs refactor branch after 20 train steps: loss: 12.259, perplexity: 210794.859 for both.

commands to reproduce:

export LIBTPU_FLAGS="\
  --xla_tpu_dvfs_p_state=7 \
  --xla_tpu_scoped_vmem_limit_kib=65536 \
  --xla_tpu_num_sparse_cores_for_gather_offloading=1 \
  --xla_tpu_bf16_emission_mode=NATIVE_EMISSION \
  --xla_tpu_enable_sparse_core_reduce_scatter_v2=true \
  --xla_tpu_enable_sparse_core_collective_offload_all_gather=true \
  --xla_tpu_enable_sparse_core_collective_offload_2d_all_gather=true \
  --xla_tpu_use_tc_device_shape_on_sc=True \
  --xla_sc_disable_megacore_partitioning=True \
  --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false \
  --xla_enable_async_all_gather=true \
  --xla_tpu_prefer_async_allgather_to_allreduce=true \
  --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true \
  --xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true \
  --xla_tpu_enable_sparse_core_collective_offload_3d_all_gather=true \
  --xla_tpu_use_single_sparse_core_for_all_gather_offload=true \
  --xla_tpu_enable_concurrent_sparse_core_offloading=true \
  --xla_tpu_enable_offloading_gather_to_sparsecore=true \
  --xla_tpu_sparse_core_all_gather_latency_multiplier=1 \
  --xla_tpu_sparse_core_reduce_scatter_latency_multiplier=3 \
  --xla_tpu_enable_sparse_core_collective_aggregator=true \
  --xla_tpu_enable_latency_hiding_layer_scheduler=true \
  --xla_tpu_scheduler_percent_shared_memory_limit=150 \
  --xla_tpu_enable_layer_scheduler_for_dependent_collectives=true \
  --xla_tpu_enable_sparse_core_collective_offload_nd_reduce_scatter=true \
  --xla_tpu_pcie_bandwidth_multiplier=0.03 \
  --xla_tpu_enable_sparse_core_offload_queuing_in_lhs=true \
  --xla_tpu_enable_multi_compute_overlap_in_layer_scheduler=false \
  --xla_tpu_enable_3d_reduce_scatter_decomposer=false \
  --xla_tpu_enable_ici_ag_pipelining=true \
  --xla_tpu_enable_ici_rs_pipelining=true"

python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
  model_name=deepseek3-tiny \
  override_model_config=True \
  base_num_decoder_layers=1 \
  first_num_dense_layers=0 \
  per_device_batch_size=1.0 \
  max_target_length=1024 \
  dcn_pipeline_parallelism=1 \
  dcn_data_parallelism=-1 \
  ici_pipeline_parallelism=1 \
  ici_fsdp_transpose_parallelism=1 \
  ici_fsdp_parallelism=1 \
  ici_expert_parallelism=-1 \
  allow_split_physical_axes=True \
  use_iota_embed=True \
  remat_policy=custom \
  decoder_layer_input=offload \
  opt_type=adamw \
  mu_dtype=bfloat16 \
  grad_dtype=bfloat16 \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  use_random_routing=True \
  megablox=True \
  sparse_matmul=True \
  use_custom_sort_vjp=True \
  use_ring_of_experts=True \
  use_ragged_sort=False \
  shard_exp_on_fsdp=False \
  sa_use_fused_bwd_kernel=True \
  sa_block_q=2048 \
  sa_block_kv=2048 \
  sa_block_q_dkv=512 \
  sa_block_kv_dkv=512 \
  sa_block_kv_dkv_compute=512 \
  sa_block_kv_dq=512 \
  sa_block_q_dq=512 \
  attention=flash \
  use_tokamax_splash=False \
  use_max_logit_estimate=-1 \
  cost_estimate_flops_fwd=5000000000000 \
  cost_estimate_flops_bwd=5000000000000 \
  float32_weight_sum=False \
  use_tokamax_gmm=False \
  tokenizer_type=huggingface \
  tokenizer_path=deepseek-ai/DeepSeek-V3 \
  dataset_type=synthetic \
  dataset_path=gs://max-datasets-rogue \
  enable_checkpointing=False \
  check_vma=False \
  steps=20 \
  attention=dot_product \
  log_config=true \
  base_output_directory=${BASE_OUTPUT_DIR} \
  run_name=${WORKLOAD_NAME}

Full table of correctness test results:

| Mode | Ring of Experts | EP Size | FSDP Size | Loss (main) | Loss (refactor) | Perp (main) | Perp (refactor) | Tok/s/Dev
(main) | Tok/s/Dev (refactor) | Command Differences |
|---|---|---|---|---|---|---|---|---|---|---|
| Sparse | True | 8 | 1 | 12.259 | 12.259 | 210794.859 | 210794.859 | 134,312 | 135,845 | sparse_matmul=True use_ring_of_experts=True ici_expert_parallelism=-1 |
| Sparse | False | 8 | 1 | 12.259 | 12.259 | 210800.438 | 210800.438 | 165,695 | 166,883 | sparse_matmul=True use_ring_of_experts=False ici_expert_parallelism=-1 |
| Sparse | False | 1 | 8 | 12.259 | 12.259 | 210867.609 | 210867.609 | 166,802 | 179,712 | sparse_matmul=True use_ring_of_experts=False ici_expert_parallelism=1 ici_fsdp_parallelism=-1 |
| Dense | False | 1 | 8 | 12.259 | 12.259 | 210843.078 | 210843.078 | 178,025 | 161,234 | sparse_matmul=False use_ring_of_experts=False ici_expert_parallelism=1 ici_fsdp_parallelism=-1 |
| Dense | False | 8 | 1 | 12.259 | 12.259 | 210809.906 | 210809.906 | 177,469 | 164,076 | sparse_matmul=False use_ring_of_experts=False ici_expert_parallelism=-1 ici_fsdp_parallelism=1 |

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 15, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 93.75000% with 2 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/layers/moe.py 93.75% 1 Missing and 1 partial ⚠️

📢 Thoughts on this report? Let us know!

Comment thread src/maxtext/layers/moe.py
num_ep = self.get_expert_parallelism_size()
num_experts_per_shard = self.config.num_experts // num_ep
use_truncated_buffer = self.config.use_ring_of_experts and x.shape[0] < routing.sorted_selected_experts.shape[0]
if use_truncated_buffer:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can merge line 1647 and 1648?

Comment thread src/maxtext/layers/moe.py
group_offset=experts_start,
)

def unsort_output_with_ra2a(intermediate_output, routing, route_metadata, output_shape, is_batch_sharded_by_expert):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function includes both unsort and ra2a.. Maybe we should name it unsort_output_and_ra2a?

Comment thread src/maxtext/layers/moe.py
output_offsets,
recv_sizes,
axis_name=self._expert_parallelism_name,
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

emmm why ragged_all_to_all show up twice, one in a function and one outside the function?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants