[AIROCMLIR-707] Fix split-kv attention masking and sweep RMS for attention configs #2371
Open
bogdan-petkovic wants to merge 13 commits into
Signed-off-by: bogdan-petkovic <bogdan.petkovic@htecgroup.com>
Pull request overview
This PR aims to make split-KV attention performance sweeps reliable by fixing NaN production in the split-KV kernel softmax update path, aligning rocmlir-gen’s host-side split-KV masking/combine behavior with causal/prefix-causal semantics, and updating the attention sweep driver’s default RMS threshold when none is explicitly provided.
Changes:
- Add `-inf - (-inf)` guards in split-KV softmax state updates to prevent NaNs in fully-masked/empty KV partitions.
- Update rocmlir-gen split-KV validity masking to support per-row causal masking and do split-KV combine math in f32 for fp16/bf16 storage.
- Change attention sweep default `-RMS_threshold` injection to `0.005` (for attention sweeps without an explicit threshold).
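
The `-inf - (-inf)` hazard can be reproduced outside the compiler with a small Python sketch of a base-2 online-softmax update for one row. This is not the code in `GridwiseGemmToBlockwise.cpp`: the helper names `exp2_shifted` and `update_row` are hypothetical, and substituting `0` for a `-inf` max is one assumed way to realize the guard so that a fully masked partition contributes zero.

```python
NEG_INF = float("-inf")


def exp2_shifted(x, m):
    """Return 2**(x - m), using 0 in place of m when m is -inf.

    Assumed guard: with m == -inf and x == -inf, the plain difference
    x - m is NaN; shifting by 0 instead gives 2**(-inf) == 0.0.
    """
    shift = 0.0 if m == NEG_INF else m
    return 2.0 ** (x - shift)


def update_row(row_max, row_sum, scores):
    """Fold one KV partition's scores into the running (max, sum) of a row."""
    new_max = max([row_max] + list(scores))
    # Rescale the old sum to the new max, then add this partition's terms.
    rescaled = row_sum * exp2_shifted(row_max, new_max)
    new_sum = rescaled + sum(exp2_shifted(s, new_max) for s in scores)
    return new_max, new_sum


# A fully masked partition leaves the state at (-inf, 0.0) instead of NaN.
state = update_row(NEG_INF, 0.0, [NEG_INF, NEG_INF])
# A later partition with real scores then updates the state normally.
state = update_row(state[0], state[1], [1.0, 0.0])
```

Without the guard, the first call would evaluate `2.0 ** (NEG_INF - NEG_INF)`, which is NaN, and the NaN would carry into every later sum for that row.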
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp` | Prevent NaNs in split-KV softmax updates by guarding `-inf - (-inf)` cases. |
| `mlir/tools/rocmlir-gen/rocmlir-gen.cpp` | Rework split-KV valid-split masking for causal cases and do f32 combine for fp16/bf16. |
| `mlir/utils/performance/parameterSweeps.py` | Set a new default attention sweep RMS threshold (0.005) when none is specified. |
Signed-off-by: bogdan-petkovic <bpetkovi@amd.com>
Motivation
Attention performance sweeps were failing on split-KV configurations (`split_kv > 1`) with RMS validation errors, NaNs, OOM crashes, or invalid results. The failures showed up across causal masking, GQA, KV-cache style `current_seqlen`, and mixed dtypes (`f16`, `bf16`, `i8`), including cases with `trans_q` and bias.

The goal is to make split-KV attention sweeps reliable without weakening the default kernel verifier policy. Fixes target:

- … (`scaleFinalOutput`) for small `seq_len_k` with large `split_kv`.
- … `rocmlir-gen` that did not match split-KV behavior under causal masking.

`rocmlir-gen`'s default `-RMS_threshold` for fp16/bf16 without an explicit override stays `0.001`. Only the attention sweep driver gets a separate default when the config does not set a threshold.

Technical Details
Kernel path: guard `-inf - (-inf)` in split-KV softmax updates

In `mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp`, split-KV attention rewrites update row-wise softmax state across KV partitions. When a partition has no valid keys, both the running row max and the partition max can be `-inf`, so `exp2(score - max)` and `exp2(old_max - new_max)` can see `-inf - (-inf)` and become NaN. That poisons row sums and downstream combine.

The change detects the case where both operands are `-inf` and uses `0` before `exp2`, so empty or fully masked partitions contribute zero instead of NaN. This is a correctness fix on the kernel path, not a verifier relaxation.

Kernel path: non-causal split-KV iteration math + `scaleFinalOutput` 0/0 guard

The non-causal / non-KV-cache split-KV branch in `GridwiseAttentionAccelRewritePattern` computed per-split iterations as `gemm0M / (gemm0MPerBlock * splitKV)` using truncating integer division. When `gemm0MBlocks < splitKV` (small `seq_len_k` with large `split_kv`), this evaluates to `0` for every split, so every split-block skips the softmax loop entirely. The kernel then divides the (zero) output by `sum = 0` in `scaleFinalOutput`, producing NaNs that propagate through the host combine stage.

This is fixed by:

- … `end` to `gemm0MBlocks` so trailing splits where `start >= gemm0MBlocks` become cleanly empty.
- … `0/0` guard in `scaleFinalOutput`: when a row's sum is exactly zero, the per-split output stores `0` instead of `NaN`. The host combine stage already tolerates `-inf` max, but only when each per-split partial output is finite.

Host reference: causal valid-split masking
endtogemm0MBlocksso trailing splits wherestart >= gemm0MBlocksbecome cleanly empty.0/0guard inscaleFinalOutput: when a row's sum is exactly zero, the per-split output stores0instead ofNaN. The host combine stage already tolerates-infmax, but only when each per-split partial output is finite.Host reference: causal valid-split masking
In
mlir/tools/rocmlir-gen/rocmlir-gen.cpp,computeValidSplitKV()used a causal branch that forcedcurrSeqLen = 0, which made every split invalid and did not reflect per-query causal reach. Causal masking now builds a per-(batch-head, query-row) valid split count from each row's effective key length (includingprefix_offsetwhen present). Non-causal configs keep the per-batch-head path. TheusePerRowMaskpredicate also triggers for non-emptyprefix_offset, mirroring how the kernel treats prefix-causal as causal.createMaskSplitKV()accepts either per-batch-head or per-(batch-head, query-row)validSplitKVlayouts and broadcasts the threshold tensor accordingly.computeFinalAttentionStage()still masks invalid splits on both the partial output and LSE tensors before the split-KV combine; its size assertion now accepts either layout.Host reference: f32 combine for narrow floats
For `f16` and `bf16`, the split-KV combine stage (reduce-max, exp, weighted sum, normalization) runs in `f32` and casts the final combined result back to storage type. That reduces accumulation error in the reference path when comparing against the GPU kernel during sweeps.

Sweep policy: default attention RMS 0.005
In `mlir/utils/performance/parameterSweeps.py`, attention sweeps without an explicit `-RMS_threshold` now append `-RMS_threshold 0.005` for all attention dtypes (`bf16` keeps `0.01`). The band covers observed sweep disagreement (including i8 with `trans_q` and high `split_kv`) while staying tighter than the old bf16-only default.

Configs that set their own `-RMS_threshold` are unchanged.

Sweep policy: widen RMS band for unscaled large-`head_dim` attention
Sampled configs with `with_attn_scale=False` and `head_dim_qk > 64` saturate softmax (`|QK| ~ O(sqrt(d))` collapses to near one-hot), so CPU vs GPU float-arithmetic ordering inside `exp`/accumulate dominates the diff (observed RMS up to ~6% in bf16, ~1% in f16/i8) independent of `split_kv`. For this regime only, `test_config` widens `-RMS_threshold` to `0.15` so the verifier still catches NaN/crash regressions but no longer false-fails on known float associativity in saturated softmax. All other configs are unaffected.

Sweep generator: device-memory-aware split-KV prefilter (folds in draft #2366)
In `mlir/utils/performance/attentionSweeps.py`, the sweep generator now estimates extra split-KV temporary storage for each sampled shape and rejects samples above a budget before generating MLIR. The budget defaults to `deviceMem / 8`, clamped to `[1 GiB, 8 GiB]`, with a `1.5 GiB` fallback if the HIP query fails. A `--splitkv-extra-bytes-limit` CLI override is available. Filter-out reasons are now tracked separately (`MAX_TOKENS` vs `splitKV extra-storage`) and reported cumulatively across initial and refill batches.

No compiler/verifier logic changes for this part; scope is limited to sweep sampling / filtering behavior.
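
The budget rule for the split-KV prefilter can be sketched in a few lines of Python. This is an illustration under assumptions, not the code in `attentionSweeps.py`: the function name and parameters are hypothetical, `None` stands in for a failed HIP device-memory query, and letting the CLI override take precedence over the computed default is assumed.

```python
GIB = 1024 ** 3


def splitkv_extra_bytes_budget(device_mem_bytes=None, override_bytes=None):
    """Return the split-KV extra-storage budget in bytes.

    device_mem_bytes: total device memory, or None if the query failed.
    override_bytes: explicit limit (assumed to win over the default).
    """
    if override_bytes is not None:
        return override_bytes
    if device_mem_bytes is None:
        # Fallback described in the text: 1.5 GiB.
        return int(1.5 * GIB)
    # deviceMem / 8, clamped to [1 GiB, 8 GiB].
    return min(max(device_mem_bytes // 8, 1 * GIB), 8 * GIB)


# A 16 GiB device gets a 2 GiB budget under this rule.
budget_16g = splitkv_extra_bytes_budget(16 * GIB)
```

Under this rule a 16 GiB device such as the `gfx1201` in the test plan gets a 2 GiB budget, while very small or very large devices hit the 1 GiB and 8 GiB clamps.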
Test Plan
- … `AttentionSweeps` job.
- … `gfx1201` (16 GiB) with all changes in place. 6/13 PASS on the kernel path. Of the remaining 7: 6 are correctly filtered out by the new split-KV prefilter (they were OOM-bound on a 16 GiB GPU because `split_kv = 128` × large `g × num_heads_q × head_dim_v` blew past VRAM), and 1 (`split_kv=1`, bf16 trans-all) trips only the per-element `relDiff` check on near-zero outputs while passing both RMS and abs-diff; that case is out of scope here and tracked separately.

Test Result
Submission Checklist