Fuse final RMSNorm + LM-head into qwen3_decode_all #295

Open
wangqin1723-max wants to merge 1 commit into hw-native-sys:main from wangqin1723-max:fused_rms_lmhead_to_decode

@wangqin1723-max wangqin1723-max commented May 15, 2026

Summary

  • Merge qwen3_final_rms and qwen3_lm_head into the L2 qwen3_decode_all kernel, eliminating one chip-task dispatch per decode step (~20 ms validate + ~1 ms prepare_ctx).
  • Fused kernel returns (decode_out, rms_normed, logits_padded) as a pl.Tuple so all three Out params satisfy InOutUseDiscipline at the call site.
  • Drop the rms_x aliased view of decode_out in PyptoQwen14BExecutor; the fused kernel reads the last-layer hidden state from its own scratch, so decode_out can be allocated directly at the user batch shape.
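
To make the fused return shape concrete, here is a minimal pure-Python reference of what the fused tail computes, assuming the standard RMSNorm formula and the `[padded_vocab, hidden]` weight layout shown in the kernel signature. The helper names (`rms_norm`, `lm_head`, `fused_decode_tail`) and the `eps` value are hypothetical; this is not the pypto kernel, only a sketch of its semantics.

```python
import math


def rms_norm(rows, weight, eps=1e-6):
    """RMSNorm per row: x * rsqrt(mean(x^2) + eps) * weight (eps is an assumed value)."""
    out = []
    for row in rows:
        mean_sq = sum(v * v for v in row) / len(row)
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        out.append([v * inv_rms * w for v, w in zip(row, weight)])
    return out


def lm_head(rows, weight_t):
    """logits[b][v] = dot(rows[b], weight_t[v]); weight_t has shape [vocab, hidden]."""
    return [[sum(x * w for x, w in zip(row, w_row)) for w_row in weight_t]
            for row in rows]


def fused_decode_tail(hidden_state, final_norm_weight, lm_head_weight_t):
    """Hypothetical reference: return all three outputs as one tuple."""
    decode_out = [list(row) for row in hidden_state]  # last-layer hidden state
    rms_normed = rms_norm(hidden_state, final_norm_weight)
    logits = lm_head(rms_normed, lm_head_weight_t)
    return decode_out, rms_normed, logits


# Tiny example: batch of 2, hidden size 2, vocab size 3.
hidden_state = [[3.0, 4.0], [1.0, 1.0]]
decode_out, rms_normed, logits = fused_decode_tail(
    hidden_state,
    [1.0, 1.0],
    [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
)
```

Returning every output buffer from one call mirrors the `pl.Tuple` design described above: the caller receives each `Out` value explicitly instead of reading an aliased buffer after the fact.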

Related Issues

coderabbitai Bot commented May 15, 2026

📝 Walkthrough


This PR refactors the Qwen3-14B L3 generation pipeline to fuse final RMSNorm and LM-head computation into the decode kernel. The qwen3_decode_all function now computes rms_normed and logits_padded internally and returns them alongside decode_out, eliminating the separate RMSNorm/LM-head dispatch. The orchestrator and executor are rewired to consume the new tuple returns and match the updated buffer layout.

Changes

Qwen3-14B L3 Decode Fusion

| Layer / File(s) | Summary |
|---|---|
| K_CHUNK tuning and qwen3_decode_all fusion (models/qwen3/14b/qwen3_14b_l3_generate.py) | K_CHUNK constant increases from 128 to 256. qwen3_decode_all signature is extended to accept final_norm_weight and lm_head_weight_t and output buffers for rms_normed and logits_padded, with return type changed to a tuple. A fused tail section computes RMSNorm and LM-head GEMM, producing all three outputs directly. |
| Host orchestrator wiring for fused decode outputs (models/qwen3/14b/qwen3_14b_l3_generate.py) | host_orch signature is extended to accept the final-RMS and LM-head weights plus new output buffers. Step0 and main decode loop are updated to destructure the tuple (out_d, out_rms, out_lm) from qwen3_decode_all and pass the new arguments, removing the prior two-stage decode-then-qwen3_rms_lmhead flow. |
| Executor buffer allocation and tensor-dict binding (llm/core/pypto_executor.py) | decode_out is allocated as an independent shared-memory tensor of shape (actual_batch, hidden_size). The tensor-dict values list for the l3_generate dispatch removes the shared rms_x entry, ensuring strict-parameter binding aligns with the new orchestrator interface. |
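
The orchestrator change summarized above can be illustrated with stub kernels that only record dispatches. This is a sketch under stated assumptions: the stub functions and the `dispatch_log` list are hypothetical stand-ins, not the real `host_orch` code; only the kernel names and the `(out_d, out_rms, out_lm)` destructuring come from the walkthrough.

```python
def run_unfused(steps):
    """Prior flow (as described): decode, then a separate RMSNorm/LM-head dispatch."""
    dispatch_log = []
    for _ in range(steps):
        dispatch_log.append("qwen3_decode_all")
        dispatch_log.append("qwen3_rms_lmhead")
    return dispatch_log


def qwen3_decode_all_stub(step):
    """Hypothetical stand-in for the fused kernel: one call, three outputs."""
    return (f"decode_out@{step}", f"rms_normed@{step}", f"logits_padded@{step}")


def run_fused(steps):
    """New flow: one dispatch per step, tuple destructured at the call site."""
    dispatch_log = []
    last_logits = None
    for step in range(steps):
        out_d, out_rms, out_lm = qwen3_decode_all_stub(step)
        dispatch_log.append("qwen3_decode_all")
        last_logits = out_lm
    return dispatch_log, last_logits


unfused_log = run_unfused(steps=4)
fused_log, final_logits = run_fused(steps=4)
saved_dispatches = len(unfused_log) - len(fused_log)
```

With these stubs the fused loop issues one dispatch per step instead of two, which is the per-step saving the PR description attributes to the fusion.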

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#258: Refactors Qwen3-14B L3 decode path around final RMSNorm + LM-head fusion, with host_orch directly consuming the fused tuple outputs instead of stepwise dispatch calls.
  • hw-native-sys/pypto-lib#245: Introduces the unified L3 generation flow that this PR refactors; modifies the same PyptoQwen14BExecutor L3 buffer/tensor-dict bindings and qwen3_decode_all/host_orch orchestrator interfaces.

Poem

🐰 A tail of decode so fused and bright,
RMSNorm and logits merged in flight,
Tensors dance in shared-memory grace,
No dispatch needed—fusion's the pace!
One kernel's work now spans the whole race.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 33.33% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately and concisely summarizes the main change: fusing final RMSNorm and LM-head operations into the qwen3_decode_all function. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The pull request description accurately describes the main changes: fusing RMSNorm and LM-head into qwen3_decode_all, removing an extra dispatch per step, and simplifying buffer allocation in PyptoQwen14BExecutor. |


@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@models/qwen3/14b/qwen3_14b_l3_generate.py`:
- Around line 757-766: The fused tail in qwen3_decode_all currently only
processes indices 0..BATCH_TILE (e.g., uses BATCH_TILE-limited loops at the end)
while the rest of the function operates on batch_padded / USER_BATCH_DYN, so
when model.runtime.max_batch_size > BATCH_TILE later tiles never get
RMSNorm/logits; fix by either rejecting builds with batch_padded > BATCH_TILE at
build time (validate model.runtime.max_batch_size and raise) or extend the fused
tail to tile the output buffers and loops the same way as the main decode path
(process all tiles: iterate over tiles up to batch_padded using the same tiling
math and write to rms_normed/out/logits_padded for every tile), updating the
branches that reference BATCH_TILE-limited ranges and the output buffers
(rms_normed, out, logits_padded) to be tile-aware so every batch element
receives final RMSNorm and logits.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 7a372ad8-6016-40ac-b949-b63611e23957

📥 Commits

Reviewing files that changed from the base of the PR and between aee5258 and 05c0c58.

📒 Files selected for processing (2)
  • llm/core/pypto_executor.py
  • models/qwen3/14b/qwen3_14b_l3_generate.py

Comment on lines +757 to +766
+    final_norm_weight: pl.Tensor[[1, hidden], pl.FP32],
+    lm_head_weight_t: pl.Tensor[[padded_vocab, hidden], pl.BF16],
     out: pl.Out[pl.Tensor[[USER_BATCH_DYN, hidden], pl.BF16]],
-) -> pl.Tensor[[USER_BATCH_DYN, hidden], pl.BF16]:
+    rms_normed: pl.Out[pl.Tensor[[BATCH_TILE, hidden], pl.BF16]],
+    logits_padded: pl.Out[pl.Tensor[[BATCH_TILE, padded_vocab], pl.FP32]],
+) -> pl.Tuple[
+    pl.Tensor[[USER_BATCH_DYN, hidden], pl.BF16],
+    pl.Tensor[[BATCH_TILE, hidden], pl.BF16],
+    pl.Tensor[[BATCH_TILE, padded_vocab], pl.FP32],
+]:
⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard the fused tail against batches larger than one tile.

Lines 1317 and 1355 only process 0..BATCH_TILE, but the rest of qwen3_decode_all still handles batch_padded. Since the executor forwards model.runtime.max_batch_size into this builder, compiling L3 with batch > 16 will leave later tiles without final RMSNorm/logits even though the decode path ran for them. Please either reject batch > BATCH_TILE at build time or tile the fused tail/output buffers the same way as the main decode path.

🛡️ Minimal safe fix
 def build_qwen3_14b_l3_generate_program(
     num_layers: int = 40,
     batch: int = BATCH,
@@
     page_size: int = SEQ_TILE,
 ):
+    if batch > BATCH_TILE:
+        raise ValueError(
+            f"fused qwen3_decode_all tail currently supports batch <= {BATCH_TILE}, got {batch}"
+        )
     if page_size != SEQ_TILE:
         raise ValueError(
             f"page_size={page_size} must equal SEQ_TILE={SEQ_TILE} "

Also applies to: 1310-1372
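
The second option in the comment (tiling the fused tail like the main decode path) comes down to iterating over every batch tile rather than only the first. The sketch below shows that tiling math in plain Python. It assumes `BATCH_TILE = 16` (inferred from "batch > 16" above) and that `batch_padded` is a multiple of the tile size; the function names are hypothetical and do not appear in the PR.

```python
BATCH_TILE = 16  # assumption: inferred from "batch > 16" in the review comment


def tile_ranges(batch_padded, tile=BATCH_TILE):
    """Return [start, end) row ranges covering every tile of the padded batch."""
    if batch_padded <= 0 or batch_padded % tile != 0:
        raise ValueError(
            f"batch_padded={batch_padded} must be a positive multiple of {tile}"
        )
    return [(start, start + tile) for start in range(0, batch_padded, tile)]


def rows_missing_from_single_tile_tail(batch_padded, tile=BATCH_TILE):
    """Rows that never receive RMSNorm/logits if the tail only handles rows 0..tile."""
    covered = set(range(min(tile, batch_padded)))
    return [row for row in range(batch_padded) if row not in covered]


ranges_for_32 = tile_ranges(32)
missing_for_16 = rows_missing_from_single_tile_tail(16)
missing_for_32 = rows_missing_from_single_tile_tail(32)
```

A tiled tail would run the RMSNorm and LM-head body once per range from `tile_ranges`, which also implies sizing `rms_normed` and `logits_padded` by `batch_padded` instead of `BATCH_TILE`.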

Merge qwen3_final_rms and qwen3_lm_head into the L2 qwen3_decode_all
kernel, eliminating one chip-task dispatch per decode step (~20 ms
validate + ~1 ms prepare_ctx). The fused kernel returns
(decode_out, rms_normed, logits_padded) as a pl.Tuple so all three Out
params satisfy InOutUseDiscipline at the call site.

Drop the rms_x aliased view of decode_out in PyptoQwen14BExecutor; the
fused kernel reads the last-layer hidden state from its own scratch
buffer, so decode_out can be allocated directly at the user batch
shape.
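
As a rough illustration of the buffer layout this implies, the sketch below computes the byte size of each output buffer from the shapes and dtypes in the fused signature. The concrete dimensions passed in (`user_batch=4`, `hidden=5120`, `padded_vocab=152064`, `batch_tile=16`) are hypothetical example values, not figures taken from this PR, and the helper names are illustrative only.

```python
DTYPE_BYTES = {"BF16": 2, "FP32": 4}


def buffer_bytes(shape, dtype):
    """Bytes needed for a dense tensor of the given shape and dtype."""
    count = 1
    for dim in shape:
        count *= dim
    return count * DTYPE_BYTES[dtype]


def fused_output_buffers(user_batch, batch_tile, hidden, padded_vocab):
    """Shapes and dtypes of the three outputs, following the fused signature."""
    return {
        "decode_out": ((user_batch, hidden), "BF16"),
        "rms_normed": ((batch_tile, hidden), "BF16"),
        "logits_padded": ((batch_tile, padded_vocab), "FP32"),
    }


# Hypothetical example dimensions; substitute the real model/runtime values.
buffers = fused_output_buffers(
    user_batch=4, batch_tile=16, hidden=5120, padded_vocab=152064
)
sizes = {name: buffer_bytes(shape, dtype) for name, (shape, dtype) in buffers.items()}
```

Only `decode_out` scales with the user batch here; the other two stay fixed at one tile, which matches the signature in the reviewed hunk.
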

@wangqin1723-max force-pushed the fused_rms_lmhead_to_decode branch from 05c0c58 to 60fcbc5 on May 15, 2026 04:25