Fuse final RMSNorm + LM-head into qwen3_decode_all #295
📝 Walkthrough
This PR refactors the Qwen3-14B L3 generation pipeline to fuse final RMSNorm and LM-head computation into the decode kernel.
Changes
Qwen3-14B L3 Decode Fusion
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/qwen3/14b/qwen3_14b_l3_generate.py`:
- Around line 757-766: The fused tail in qwen3_decode_all only processes indices
0..BATCH_TILE (its loops at the end are limited to BATCH_TILE), while the rest
of the function operates on batch_padded / USER_BATCH_DYN. When
model.runtime.max_batch_size > BATCH_TILE, later tiles never get RMSNorm/logits.
Fix this in one of two ways: reject builds with batch_padded > BATCH_TILE at
build time (validate model.runtime.max_batch_size and raise), or extend the
fused tail to tile its loops and output buffers the same way as the main decode
path, iterating over all tiles up to batch_padded with the same tiling math and
writing rms_normed/out/logits_padded for every tile. In the second case, update
the branches that reference BATCH_TILE-limited ranges and make the output
buffers (rms_normed, out, logits_padded) tile-aware so every batch element
receives final RMSNorm and logits.
📒 Files selected for processing (2)
- llm/core/pypto_executor.py
- models/qwen3/14b/qwen3_14b_l3_generate.py
     final_norm_weight: pl.Tensor[[1, hidden], pl.FP32],
     lm_head_weight_t: pl.Tensor[[padded_vocab, hidden], pl.BF16],
     out: pl.Out[pl.Tensor[[USER_BATCH_DYN, hidden], pl.BF16]],
- ) -> pl.Tensor[[USER_BATCH_DYN, hidden], pl.BF16]:
+    rms_normed: pl.Out[pl.Tensor[[BATCH_TILE, hidden], pl.BF16]],
+    logits_padded: pl.Out[pl.Tensor[[BATCH_TILE, padded_vocab], pl.FP32]],
+ ) -> pl.Tuple[
+    pl.Tensor[[USER_BATCH_DYN, hidden], pl.BF16],
+    pl.Tensor[[BATCH_TILE, hidden], pl.BF16],
+    pl.Tensor[[BATCH_TILE, padded_vocab], pl.FP32],
+ ]:
Guard the fused tail against batches larger than one tile.
Lines 1317 and 1355 only process 0..BATCH_TILE, but the rest of qwen3_decode_all still handles batch_padded. Since the executor forwards model.runtime.max_batch_size into this builder, compiling L3 with batch > 16 will leave later tiles without final RMSNorm/logits even though the decode path ran for them. Please either reject batch > BATCH_TILE at build time or tile the fused tail/output buffers the same way as the main decode path.
🛡️ Minimal safe fix
 def build_qwen3_14b_l3_generate_program(
     num_layers: int = 40,
     batch: int = BATCH,
@@
     page_size: int = SEQ_TILE,
 ):
+    if batch > BATCH_TILE:
+        raise ValueError(
+            f"fused qwen3_decode_all tail currently supports batch <= {BATCH_TILE}, got {batch}"
+        )
     if page_size != SEQ_TILE:
         raise ValueError(
             f"page_size={page_size} must equal SEQ_TILE={SEQ_TILE} "
Also applies to: 1310-1372
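The review's second option (tiling the fused tail instead of rejecting large batches) can be sketched in plain Python. This is only an illustration of the loop structure, not the PR's kernel code: `rms_norm_row`, `fused_tail_tiled`, and the `eps` value are hypothetical names and choices, nested lists stand in for the `pl` tensors, and `BATCH_TILE = 16` is an assumption inferred from the "batch > 16" remark above.

```python
import math

BATCH_TILE = 16  # assumed tile size (inferred from the "batch > 16" remark)


def rms_norm_row(row, weight, eps=1e-6):
    """RMSNorm for one hidden-state row: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(v * v for v in row) / len(row)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(row, weight)]


def fused_tail_tiled(hidden_states, norm_weight, lm_head_weight_t,
                     batch_tile=BATCH_TILE):
    """Apply final RMSNorm + LM head to every batch row, one tile at a time.

    lm_head_weight_t is laid out as [vocab, hidden], so each logit is the dot
    product of a normalized row with one vocabulary row.
    """
    batch_padded = len(hidden_states)
    rms_normed = [None] * batch_padded
    logits = [None] * batch_padded
    # Visit tile starts 0, batch_tile, 2 * batch_tile, ... so that rows beyond
    # the first tile are also normalized and projected.
    for tile_start in range(0, batch_padded, batch_tile):
        tile_end = min(tile_start + batch_tile, batch_padded)
        for b in range(tile_start, tile_end):
            normed = rms_norm_row(hidden_states[b], norm_weight)
            rms_normed[b] = normed
            logits[b] = [
                sum(n * w for n, w in zip(normed, vocab_row))
                for vocab_row in lm_head_weight_t
            ]
    return rms_normed, logits
```

With a batch of 40 rows, the outer loop runs for tile starts 0, 16, and 32, which is the coverage the review asks for. A tail that only handled the first tile would leave rows 16 through 39 unwritten.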
Merge qwen3_final_rms and qwen3_lm_head into the L2 qwen3_decode_all kernel, eliminating one chip-task dispatch per decode step (~20 ms validate + ~1 ms prepare_ctx). The fused kernel returns (decode_out, rms_normed, logits_padded) as a pl.Tuple so all three Out params satisfy InOutUseDiscipline at the call site. Drop the rms_x aliased view of decode_out in PyptoQwen14BExecutor; the fused kernel reads the last-layer hidden state from its own scratch buffer, so decode_out can be allocated directly at the user batch shape.
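As a rough back-of-the-envelope check on the dispatch overhead quoted above, the saving scales linearly with the number of decode steps. The sketch below assumes the ~20 ms and ~1 ms figures are constant per dispatch and that nothing else about the step changes; `estimated_dispatch_savings_ms` is a hypothetical helper, not part of the PR.

```python
# Approximate per-dispatch overheads quoted in the commit message.
VALIDATE_MS = 20.0
PREPARE_CTX_MS = 1.0


def estimated_dispatch_savings_ms(num_decode_steps, dispatches_removed_per_step=1):
    """Estimate total overhead removed, in milliseconds, over a generation run."""
    per_dispatch_ms = VALIDATE_MS + PREPARE_CTX_MS
    return num_decode_steps * dispatches_removed_per_step * per_dispatch_ms
```

Under these assumptions, generating 256 tokens would avoid about 256 × 21 ms ≈ 5.4 s of dispatch overhead.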
05c0c58 to 60fcbc5
Summary
- Merge `qwen3_final_rms` and `qwen3_lm_head` into the L2 `qwen3_decode_all` kernel, eliminating one chip-task dispatch per decode step (~20 ms validate + ~1 ms prepare_ctx).
- The fused kernel returns `(decode_out, rms_normed, logits_padded)` as a `pl.Tuple` so all three Out params satisfy `InOutUseDiscipline` at the call site.
- Drop the `rms_x` aliased view of `decode_out` in `PyptoQwen14BExecutor`; the fused kernel reads the last-layer hidden state from its own scratch, so `decode_out` can be allocated directly at the user batch shape.
Related Issues