
refactor(dsv4): use pl.spmd for compressor_ratio128 dispatch sites #359

Open
bumble0918 wants to merge 1 commit into hw-native-sys:main from bumble0918:feature/2026-05-22



@bumble0918
Contributor

#314
Replace `pl.parallel` + `with pl.at(level=CORE_GROUP)` with `pl.spmd(...)` at four sites in compressor (`kv_score_proj`, `softmax_pool`, `kv_write`, `state_scatter_next`). The latter two sit inside if/else branches and required the SSA scope fix in pypto #1414.

On a2a3 with `--enable-l2-swimlane`, Total Test Time drops from 409.7 us to 372.8 us (-9%); avg Head OH per task drops from 2.29 us to 0.57 us. All four output tensors continue to match the torch golden.
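For readers who want to check the quoted percentages, here is a minimal sketch that recomputes them from the four timings above. The helper name `relative_change` is made up for illustration and is not part of the PR or of pypto.

```python
# Timings quoted in the PR description, in microseconds.
total_before_us, total_after_us = 409.7, 372.8
head_oh_before_us, head_oh_after_us = 2.29, 0.57


def relative_change(before: float, after: float) -> float:
    """Signed relative change; negative values mean a reduction."""
    return (after - before) / before


total_delta = relative_change(total_before_us, total_after_us)
head_oh_delta = relative_change(head_oh_before_us, head_oh_after_us)

print(f"Total Test Time: {total_delta:+.1%}")
print(f"avg Head OH per task: {head_oh_delta:+.1%}")
```

Under these numbers the total time falls by about 9% and the per-task Head OH by about 75%, so most of the per-task overhead is removed even though the end-to-end gain is smaller.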

@coderabbitai

coderabbitai Bot commented May 22, 2026


No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 0e88282b-2b3d-4d19-ae8a-bf95ee58b8f8

📥 Commits

Reviewing files that changed from the base of the PR and between be3c794 and fa8b524.

📒 Files selected for processing (1)
  • models/deepseek/v4/decode_compressor_ratio128.py

📝 Walkthrough


The decode compressor kernel's four computational stages were systematically refactored to use pl.spmd(...)-driven iteration instead of pl.parallel(...)/pl.at(...) patterns, maintaining identical online softmax, pooling, writeback, and state-scatter semantics across batch, head-block, and output-column dimensions.

Changes

SPMD Iteration Pattern Conversion

| Layer / File(s) | Summary |
| --- | --- |
| Projection scratch batch iteration / models/deepseek/v4/decode_compressor_ratio128.py | Batch-parallel projection scratch computation (cmp128_kv_proj_scratch, cmp128_score_proj_scratch) refactored from pl.parallel/pl.at to pl.spmd over B * S // BATCH_CHUNK_0, preserving inner K-block pipelined matmul and tile assembly. |
| Softmax pooling per-head iteration / models/deepseek/v4/decode_compressor_ratio128.py | Head-block softmax pooling loop converted from pl.parallel(...)/pl.at(...) to pl.spmd(...), maintaining last-slot initialization, pre-token seed adjustment, online softmax recurrence (mi/li/oi), and per-head pooled assembly. |
| Non-rotate KV output writeback / models/deepseek/v4/decode_compressor_ratio128.py | KV final output writeback changed from pl.parallel/pl.at with core-group placement to pl.spmd tile iteration, assembling kv_final from normalized KV output tiles. |
| Next state scatter and update / models/deepseek/v4/decode_compressor_ratio128.py | Post-pooling state scatter refactored to nested pl.spmd/pl.range iteration, updating kv_state_flat and score_state_flat for s in [pre_tokens, S) with preserved dep_zero handling and APE index computation. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#243: Implements the initial ratio128 decode compressor kernel that this PR refactors.
  • hw-native-sys/pypto-lib#342: Earlier tiling refactor of the same compressor stages (projection, pooling, writeback, state scatter) that restructures KV/score scratch handling and output dimensions.

Poem

🐰 From parallel dreams to spmd flows,
The kernel reshapes as structure grows,
Four stages glide through softer streams,
Where iteration patterns weave new schemes.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title accurately describes the main refactoring work: replacing pl.parallel + pl.at dispatches with pl.spmd across multiple sites in the compressor_ratio128 kernel. |
| Description check | ✅ Passed | The description is directly related to the changeset, detailing the four specific dispatch sites refactored, the underlying pypto fix dependency, and measurable performance improvements with validation results. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request refactors the compressor function in models/deepseek/v4/decode_compressor_ratio128.py by replacing several pl.parallel loops and pl.at context managers with pl.spmd loops. These changes affect the KV/score projection, softmax pooling, KV writing, and state scattering logic. A review comment identifies a potential issue where the use of integer division for the SPMD loop count could lead to unprocessed tokens if the total number of tokens is not a multiple of the batch chunk size, and suggests using ceiling division instead.


cmp128_kv_proj_scratch = pl.assemble(cmp128_kv_proj_scratch, kv_acc, [b_idx, o0])
cmp128_score_proj_scratch = pl.assemble(cmp128_score_proj_scratch, score_acc, [b_idx, o0])
for bi in pl.spmd(B * S // BATCH_CHUNK_0, name_hint="kv_score_proj"):

medium

The use of integer division B * S // BATCH_CHUNK_0 for the pl.spmd loop count assumes that the total number of tokens is always a multiple of the batch chunk size. While this holds for the current DECODE_BATCH=64 and DECODE_SEQ=2 configuration, it may lead to missing work if the batch size is changed to a value that is not a multiple of 64 (e.g., 48), or if B * S is less than 64 (in which case the loop won't run at all). Consider using a ceiling division or adding a check to ensure all tokens are processed.

Suggested change
- for bi in pl.spmd(B * S // BATCH_CHUNK_0, name_hint="kv_score_proj"):
+ for bi in pl.spmd((B * S + BATCH_CHUNK_0 - 1) // BATCH_CHUNK_0, name_hint="kv_score_proj"):
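To make the difference between the two loop counts concrete, the following plain-Python sketch compares floor and ceiling division for the configurations mentioned in the comment. It does not use pypto; the helper names are hypothetical, and `BATCH_CHUNK_0 = 64` is an assumption inferred from the "multiple of 64" wording above rather than read from the source file.

```python
BATCH_CHUNK_0 = 64  # assumed value, inferred from the review comment


def floor_chunks(total_tokens: int, chunk: int) -> int:
    """Chunk count as currently written: any partial tail chunk is dropped."""
    return total_tokens // chunk


def ceil_chunks(total_tokens: int, chunk: int) -> int:
    """Chunk count with ceiling division: a partial tail gets its own chunk."""
    return (total_tokens + chunk - 1) // chunk


def coverage(b: int, s: int, chunk: int = BATCH_CHUNK_0) -> tuple[int, int, int]:
    """Return (floor count, ceil count, tokens left unprocessed by floor)."""
    total = b * s
    n_floor = floor_chunks(total, chunk)
    n_ceil = ceil_chunks(total, chunk)
    dropped = total - n_floor * chunk
    return n_floor, n_ceil, dropped


# (B, S) pairs: the current config, a non-multiple batch, and B * S < chunk.
for b, s in [(64, 2), (48, 2), (16, 2)]:
    n_floor, n_ceil, dropped = coverage(b, s)
    print(f"B={b} S={s}: floor={n_floor} ceil={n_ceil} dropped_by_floor={dropped}")
```

Note that ceiling division alone only helps if the loop body can handle a partial last chunk, for example by clamping or masking the tail; whether this kernel supports that is not stated here. If it does not, an explicit check that `B * S` is a multiple of `BATCH_CHUNK_0` may be the safer alternative.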
