refactor(dsv4): use pl.spmd for compressor_ratio128 dispatch sites #359
bumble0918 wants to merge 1 commit
Replace `pl.parallel + with pl.at(level=CORE_GROUP)` with `pl.spmd(...)` at four sites in compressor (kv_score_proj, softmax_pool, kv_write, state_scatter_next). The latter two sit inside if/else branches and required the SSA scope fix in pypto #1414. On a2a3 with --enable-l2-swimlane, Total Test Time drops from 409.7 us to 372.8 us (-9%); avg Head OH per task drops from 2.29 us to 0.57 us. All four output tensors continue to match the torch golden.
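The relative improvements implied by the figures above can be checked with a few lines of arithmetic. This is only a sketch: the four timings are taken from the description, while the helper name `relative_change` is made up for illustration.

```python
def relative_change(before: float, after: float) -> float:
    """Signed relative change from `before` to `after` (negative means faster)."""
    return (after - before) / before


# Timings in microseconds, as reported in the PR description.
total_test_time = relative_change(409.7, 372.8)
head_overhead = relative_change(2.29, 0.57)

print(f"Total Test Time: {total_test_time:.1%}")
print(f"Avg Head OH per task: {head_overhead:.1%}")
```

The total test time change rounds to the reported -9%, and the per-task head overhead falls by roughly three quarters.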
Walkthrough: the decode compressor kernel's four computational stages were refactored to an SPMD iteration pattern.
Estimated code review effort: 3 (Moderate), ~25 minutes
Pre-merge checks: 4 passed, 1 failed (1 warning)
Code Review
This pull request refactors the compressor function in models/deepseek/v4/decode_compressor_ratio128.py by replacing several pl.parallel loops and pl.at context managers with pl.spmd loops. These changes affect the KV/score projection, softmax pooling, KV writing, and state scattering logic. A review comment identifies a potential issue where the use of integer division for the SPMD loop count could lead to unprocessed tokens if the total number of tokens is not a multiple of the batch chunk size, and suggests using ceiling division instead.
    cmp128_kv_proj_scratch = pl.assemble(cmp128_kv_proj_scratch, kv_acc, [b_idx, o0])
    cmp128_score_proj_scratch = pl.assemble(cmp128_score_proj_scratch, score_acc, [b_idx, o0])
    for bi in pl.spmd(B * S // BATCH_CHUNK_0, name_hint="kv_score_proj"):
The integer division `B * S // BATCH_CHUNK_0` in the `pl.spmd` loop count assumes that the total number of tokens is always a multiple of the batch chunk size. This holds for the current `DECODE_BATCH=64` and `DECODE_SEQ=2` configuration, but work may be skipped if the batch size is changed to a value that is not a multiple of 64 (e.g., 48), and if `B * S` is less than 64 the loop will not run at all. Consider using ceiling division or adding a check to ensure all tokens are processed.
Suggested change:

    - for bi in pl.spmd(B * S // BATCH_CHUNK_0, name_hint="kv_score_proj"):
    + for bi in pl.spmd((B * S + BATCH_CHUNK_0 - 1) // BATCH_CHUNK_0, name_hint="kv_score_proj"):
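To make the difference between the two loop counts concrete, here is a minimal plain-Python sketch that does not use pypto at all. It assumes `BATCH_CHUNK_0` is 64, as the comment above implies; the helper names `floor_chunks`, `ceil_chunks`, and `tokens_covered` are hypothetical and exist only for this illustration.

```python
BATCH_CHUNK_0 = 64  # assumed chunk size, inferred from the review comment


def floor_chunks(total: int, chunk: int) -> int:
    """Loop count with floor division (the form currently in the PR)."""
    return total // chunk


def ceil_chunks(total: int, chunk: int) -> int:
    """Loop count with ceiling division (the suggested form)."""
    return (total + chunk - 1) // chunk


def tokens_covered(total: int, chunk: int, n_chunks: int) -> int:
    """Number of tokens touched when n_chunks chunks of size chunk are processed."""
    return min(total, n_chunks * chunk)


# (B, S) pairs: the current config, a non-multiple case, and a small case.
for b, s in [(64, 2), (48, 2), (16, 2)]:
    total = b * s
    n_floor = floor_chunks(total, BATCH_CHUNK_0)
    n_ceil = ceil_chunks(total, BATCH_CHUNK_0)
    print(
        f"B={b} S={s} total={total}: "
        f"floor -> {n_floor} chunks ({tokens_covered(total, BATCH_CHUNK_0, n_floor)} tokens), "
        f"ceil -> {n_ceil} chunks ({tokens_covered(total, BATCH_CHUNK_0, n_ceil)} tokens)"
    )
```

With ceiling division the last chunk can be partial, so the loop body would also need to guard or mask the tail rather than assume a full `BATCH_CHUNK_0` rows. Alternatively, a check that `B * S` is a multiple of `BATCH_CHUNK_0` keeps the floor form safe for the supported configurations.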
#314