Add weight for ragged gather kernel and enable fan out in bwd ragged sort #4166
  model_name="mixtral-8x7b",
  override_model_config=True,
- base_emb_dim=256,
+ base_emb_dim=2048,  # we want emb dim being multiple of 1024 for fully using the kernel
this is a bit strange, why does emb dim matter?
The ragged gather reduce kernel hardcodes the number of column partitions to 8. Since the partition size must be a multiple of 128, the emb size must be a multiple of 1024 to use this kernel. I updated the emb dim in the test case to make this constraint visible; see context in
maxtext/src/maxtext/kernels/ragged/ragged_gather_reduce.py
Lines 434 to 458 in 61c225f
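The arithmetic behind the 1024 constraint, together with the kind of fail-fast check discussed in this thread, can be sketched as below. This is only an illustration: the constant names and the `check_emb_dim` helper are hypothetical and are not taken from `ragged_gather_reduce.py`; the values 8 and 128 come from the comment above.

```python
# Hypothetical names; the values 8 and 128 are the ones quoted in the review comment.
NUM_COLUMN_PARTITIONS = 8
PARTITION_SIZE_MULTIPLE = 128
REQUIRED_EMB_MULTIPLE = NUM_COLUMN_PARTITIONS * PARTITION_SIZE_MULTIPLE  # 8 * 128 = 1024


def check_emb_dim(emb_dim: int) -> int:
    """Return the per-partition column width, or raise early if emb_dim is unsupported."""
    if emb_dim <= 0 or emb_dim % REQUIRED_EMB_MULTIPLE != 0:
        raise ValueError(
            f"emb_dim={emb_dim} is not a positive multiple of {REQUIRED_EMB_MULTIPLE} "
            f"({NUM_COLUMN_PARTITIONS} column partitions, each a multiple of "
            f"{PARTITION_SIZE_MULTIPLE} wide)."
        )
    # Each of the 8 partitions gets an equal slice of the embedding columns.
    return emb_dim // NUM_COLUMN_PARTITIONS
```

With the values from the diff, `check_emb_dim(2048)` passes and gives 256 columns per partition, while the old test value `check_emb_dim(256)` raises `ValueError`.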
hmm this is a hard constraint for a lot of models, would this fail fast with a clear error message on these sizes? Ideally we can somehow support
I would create a bug to track
Agree, added in b/524661949
Description
This PR fixes the usage of the `ragged_gather` kernel under `_ring_ragged_unsort_bwd`. As a result, it [...] `bitcast_add` in the backward pass.

Previously, this kernel was applied in a non-fan-out way, because the backward pass of the unsort requires weight balance, which the ragged gather kernel did not support. This PR adds that support and enables a fan-out style of this kernel, which avoids materializing the full padded-shape tensor in fp32.
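As a rough illustration of the general idea (not the kernel in this PR), the NumPy sketch below contrasts gathering first and weighting afterwards with applying the weight during the gather. All function and variable names are hypothetical. The sketch assumes each fanned-out row reads one source row and one scalar weight, and it does not model raggedness, padding, or the ring communication.

```python
import numpy as np


def gather_then_scale(grad_out, token_idx, weights):
    """Reference path: materialize every gathered row in fp32, then scale."""
    gathered = grad_out[token_idx].astype(np.float32)  # full fan-out tensor held in fp32
    scaled = gathered * weights[:, None].astype(np.float32)
    return scaled.astype(grad_out.dtype)


def weighted_fan_out_gather(grad_out, token_idx, weights):
    """Weighted path: apply the weight while gathering, one row at a time."""
    out = np.empty((token_idx.shape[0], grad_out.shape[1]), dtype=grad_out.dtype)
    for row, (src, w) in enumerate(zip(token_idx, weights)):
        # Only a single row is ever upcast to fp32 before being written back.
        out[row] = (grad_out[src].astype(np.float32) * np.float32(w)).astype(grad_out.dtype)
    return out
```

Both functions return the same values; the second never holds the whole fanned-out tensor in fp32, which is the property the description refers to.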
FIXES: b/522526813
Tests
ragged_buffer_factor=3.0: xprof
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.