Update: reduce DeepSeek V4 sparse attention tasks #301
📝 Walkthrough
The PR refactors sparse attention accumulation from explicit stepwise loop mutations to a modern
Changes: Sparse Attention Refactoring
Estimated Code Review Effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request refactors the attention mechanisms in models/deepseek/v4/attention_swa.py and models/deepseek/v4/sparse_attn.py by removing unnecessary batch chunking loops and optimizing the pl.at block structures. A critical issue was identified in sparse_attn.py where the variables used for final normalization are not defined if the sparse_k loop is empty (e.g., when sparse_k is 1), which would cause a runtime error. I recommend initializing these variables with their initial values before the loop to ensure they are always defined.
+for kk, (mi_loop, li_loop, oi_loop) in pl.range(
+    1,
+    sparse_k,
+    init_values=(mi_init, li_init, oi_init),
+):
+    cur_kv_batch = pl.col_expand(
         pl.full([MATMUL_ROW_PAD, HEAD_DIM], dtype=pl.FP32, value=0.0),
         pl.cast(kv_topk_batch[kk : kk + 1, 0 : HEAD_DIM], target_type=pl.FP32),
     )
-    cur_score = pl.row_sum(pl.mul(q_batch, kv_batch))
+    cur_score = pl.row_sum(pl.mul(q_batch, cur_kv_batch))
     cur_mi = pl.mul(cur_score, SOFTMAX_SCALE)
-    mi_new = pl.maximum(mi, cur_mi)
-    alpha = pl.exp(pl.sub(mi, mi_new))
+    mi_new = pl.maximum(mi_loop, cur_mi)
+    alpha = pl.exp(pl.sub(mi_loop, mi_new))
     beta = pl.exp(pl.sub(cur_mi, mi_new))
-    li = pl.add(pl.mul(alpha, li), beta)
-    oi = pl.add(
-        pl.row_expand_mul(oi, alpha),
-        pl.row_expand_mul(kv_batch, beta),
+    li_new = pl.add(pl.mul(alpha, li_loop), beta)
+    oi_new = pl.add(
+        pl.row_expand_mul(oi_loop, alpha),
+        pl.row_expand_mul(cur_kv_batch, beta),
     )
-    mi = mi_new
+    (mi_final, li_final, oi_final) = pl.yield_(mi_new, li_new, oi_new)
The variables mi_final, li_final, and oi_final are only assigned within the pl.range loop body. If sparse_k is 1 (which occurs during decode when only the current token is valid in the window and no compressed tokens are selected), the loop range pl.range(1, 1) will be empty. Consequently, these variables will remain undefined when they are accessed for the final normalization and output assembly at lines 218-220, leading to a runtime error. They should be initialized with the _init values before the loop to ensure correctness for all sparse_k values.
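The failure mode described in this comment can be reproduced in ordinary Python, where a name assigned only inside a loop body stays unbound if the loop runs zero times. The sketch below uses the built-in `range` and hypothetical function names. Whether `pl.range` with `init_values` and `pl.yield_` behaves the same way once traced is not confirmed by anything in this thread, so treat it as an analogy rather than a statement about the DSL.

```python
def accumulate_unguarded(sparse_k, init):
    """Binds `final` only inside the loop body."""
    state = init
    for kk in range(1, sparse_k):
        state = state + kk
        final = state
    # Raises UnboundLocalError when range(1, sparse_k) is empty.
    return final


def accumulate_guarded(sparse_k, init):
    """Pre-binds `final` to the initial value before the loop."""
    final = init
    state = init
    for kk in range(1, sparse_k):
        state = state + kk
        final = state
    return final
```

With `sparse_k = 1` the range is empty: the guarded version returns `init`, while the unguarded version raises `UnboundLocalError`.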
mi_final, li_final, oi_final = mi_init, li_init, oi_init
for kk, (mi_loop, li_loop, oi_loop) in pl.range(
    1,
    sparse_k,
    init_values=(mi_init, li_init, oi_init),
):
    cur_kv_batch = pl.col_expand(
        pl.full([MATMUL_ROW_PAD, HEAD_DIM], dtype=pl.FP32, value=0.0),
        pl.cast(kv_topk_batch[kk : kk + 1, 0 : HEAD_DIM], target_type=pl.FP32),
    )
    cur_score = pl.row_sum(pl.mul(q_batch, cur_kv_batch))
    cur_mi = pl.mul(cur_score, SOFTMAX_SCALE)
    mi_new = pl.maximum(mi_loop, cur_mi)
    alpha = pl.exp(pl.sub(mi_loop, mi_new))
    beta = pl.exp(pl.sub(cur_mi, mi_new))
    li_new = pl.add(pl.mul(alpha, li_loop), beta)
    oi_new = pl.add(
        pl.row_expand_mul(oi_loop, alpha),
        pl.row_expand_mul(cur_kv_batch, beta),
    )
    mi_final, li_final, oi_final = mi_new, li_new, oi_new
    pl.yield_(mi_new, li_new, oi_new)

- Fuse sparse attention online softmax into one device task per head tile
- Pack grouped output rows per batch instead of per head
- Keep attention_swa tensor setup chunked for larger T scaling
317296d to 2c02f00
Summary
Related Issues
None