Skip to content

Add weight for ragged gather kernel and enable fan out in bwd ragged sort#4166

Merged
copybara-service[bot] merged 1 commit into
mainfrom
chengnuojin-fix-ragged
Jun 17, 2026
Merged

Add weight for ragged gather kernel and enable fan out in bwd ragged sort#4166
copybara-service[bot] merged 1 commit into
mainfrom
chengnuojin-fix-ragged

Conversation

@NuojCheng

@NuojCheng NuojCheng commented Jun 15, 2026

Copy link
Copy Markdown
Collaborator

Description

This PR fixes the usage of ragged_gather kernel under _ring_ragged_unsort_bwd. As a result, it

  • Removes the bitcast_add in backward pass
  • Reduce HBM usage

Previously, this kernel was applied in non-fan-out way, since the backward pass the unsort requires weight balance, while the ragged gather kernel did not support. This PR adds this support and enable a fan-out style of this kernel, successfully avoid materializing the full padded-shape tensor in fp32.

FIXES: b/522526813

Tests

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 15, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 77.77778% with 12 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/kernels/ragged/ragged_gather.py 71.42% 11 Missing and 1 partial ⚠️

📢 Thoughts on this report? Let us know!

@NuojCheng NuojCheng force-pushed the chengnuojin-fix-ragged branch from ee734a4 to 2de0a9c Compare June 16, 2026 01:42
@NuojCheng NuojCheng changed the title Add fan out for scatter bwd Add weight for ragged gather kernel and enable fan out in bwd ragged sort Jun 16, 2026
@NuojCheng NuojCheng marked this pull request as ready for review June 16, 2026 02:27
@NuojCheng NuojCheng force-pushed the chengnuojin-fix-ragged branch from 2de0a9c to 9172c06 Compare June 16, 2026 05:00
Comment thread tests/unit/moe_test.py
model_name="mixtral-8x7b",
override_model_config=True,
base_emb_dim=256,
base_emb_dim=2048, # we want emb dim being multiple of 1024 for fully using the kernel

@gobbleturk gobbleturk Jun 16, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a bit strange, why does emb dim matter?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ragged gather reduce kernel hardcoded num of column partitions being 8. Since the partition size should be multiple of 128, it means emb size should be multiple of 1024 for using this kernel. I updated the emb dim in test case for better demonstrating this knowledge, see context in

# This kernel partitions the output's columns into `num_column_partitions` and
# partition the output's rows into `num_row_partitions` and run each
# {row_partition} x {column_partition} combination on a separate SC subcore
# for parallelism. With such work partitioning, we guarantee that there won't
# be write collision (from different subcores) to the any output row X column.
#
# Each column partition should be multiple of 128 (number of lanes) due to
# DMA requirements. Unless requiring padding on the column dimension, larger
# column partitions (thus smaller row partitions given fixed num_cores) is
# more preferable because large row partition may lead to imbalanced load
# (valid_rows_mask may have more rows in some partitions than others).
# Most LLM's hidden size is multiple of 1024, `num_column_partitions=8` should
# work well in practice without requiring padding on the column size.
num_column_partitions = 8
assert num_cores % num_column_partitions == 0
num_rows_partitions = num_cores // num_column_partitions
aligned_hidden_size = _align_to(hidden_size, 128 * num_column_partitions)
col_size = aligned_hidden_size // num_column_partitions
row_tile_size = num_simd_lanes
padded_input_size = _align_to(
input_size,
math.lcm(num_rows_partitions * row_tile_size, reduce_group_size),
)
pad_input_size = padded_input_size - input_size
.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm this is a hard constraint for a lot of models, would this fail fast with a clear error message on these sizes? Ideally we can somehow support

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would create a bug to track

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, added in b/524661949

@Shuwen-Fang Shuwen-Fang requested review from Shuwen-Fang and removed request for jesselu-google June 16, 2026 16:52
@copybara-service copybara-service Bot merged commit 83412ca into main Jun 17, 2026
125 checks passed
@copybara-service copybara-service Bot deleted the chengnuojin-fix-ragged branch June 17, 2026 01:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants