[AIROCMLIR-707] Fix split-kv attention masking and sweep RMS for attention configs #2371

Open
bogdan-petkovic wants to merge 13 commits into ROCm:develop from bogdan-petkovic:bogdan-petkovic/attn-splitkv-sweep-fix
@bogdan-petkovic bogdan-petkovic commented May 11, 2026

Motivation

Attention performance sweeps were failing on split-KV configurations (split_kv > 1) with RMS validation errors, NaNs, OOM crashes, or invalid results. The failures showed up across causal masking, GQA, KV-cache style current_seqlen, and mixed dtypes (f16, bf16, i8), including cases with trans_q and bias.

The goal is to make split-KV attention sweeps reliable without weakening the default kernel verifier policy. Fixes target:

  • GPU split-KV softmax updates that produced NaNs on fully masked or empty key partitions.
  • GPU split-KV iteration math that yielded zero iterations per split (and downstream 0/0 in scaleFinalOutput) for small seq_len_k with large split_kv.
  • Host reference masking and combine logic in rocmlir-gen that did not match split-KV behavior under causal masking.
  • Sweep validation tolerance so attention configs are checked against a realistic numeric band instead of failing on reference-path noise.
  • Sweep generator behavior so memory-heavy split-KV cases don't surface in CI as OOM-crashed sweep samples.

rocmlir-gen's default -RMS_threshold for fp16/bf16 without an explicit override stays 0.001. Only the attention sweep driver gets a separate default when the config does not set a threshold.

Technical Details

Kernel path: guard -inf - (-inf) in split-KV softmax updates
In mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp, split-KV attention rewrites update row-wise softmax state across KV partitions. When a partition has no valid keys, both the running row max and the partition max can be -inf, so exp2(score - max) and exp2(old_max - new_max) can see -inf - (-inf) and become NaN. That poisons row sums and downstream combine.

The change detects the case where both operands are -inf and uses 0 before exp2, so empty or fully masked partitions contribute zero instead of NaN. This is a correctness fix on the kernel path, not a verifier relaxation.
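
A minimal Python sketch of why the guard matters. It is not the Rock dialect rewrite itself: the function names are hypothetical, and the guard shown (substituting 0 for a -inf max before the subtraction) is one common way to get a zero contribution, assumed here for illustration.

```python
import math

NEG_INF = float("-inf")


def exp2(x):
    # 2**x; yields 0.0 for -inf and NaN for NaN.
    return 2.0 ** x


def naive_update(row_max, row_sum, scores):
    """Online-softmax update without a guard (hypothetical helper)."""
    new_max = max([row_max] + list(scores))
    # -inf - (-inf) is NaN, so a fully masked partition poisons the sum.
    rescaled = row_sum * exp2(row_max - new_max)
    return new_max, rescaled + sum(exp2(s - new_max) for s in scores)


def guarded_update(row_max, row_sum, scores):
    """Same update, but a -inf max is replaced by 0 before subtracting."""
    new_max = max([row_max] + list(scores))
    safe_max = 0.0 if new_max == NEG_INF else new_max
    # With safe_max == 0, every term becomes exp2(-inf) == 0.0.
    rescaled = row_sum * exp2(row_max - safe_max)
    return new_max, rescaled + sum(exp2(s - safe_max) for s in scores)


# Fully masked partition on a fresh row: NaN without the guard, 0.0 with it.
print(naive_update(NEG_INF, 0.0, [NEG_INF, NEG_INF]))
print(guarded_update(NEG_INF, 0.0, [NEG_INF, NEG_INF]))
```

A later partition with valid keys still updates the state normally, because the guard only fires when the new max is -inf.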

Kernel path: non-causal split-KV iteration math + scaleFinalOutput 0/0 guard
The non-causal / non-KV-cache split-KV branch in GridwiseAttentionAccelRewritePattern computed per-split iterations as gemm0M / (gemm0MPerBlock * splitKV) using truncating integer division. When gemm0MBlocks < splitKV (small seq_len_k with large split_kv), this evaluates to 0 for every split, so every split-block skips the softmax loop entirely. The kernel then divides the (zero) output by sum = 0 in scaleFinalOutput, producing NaNs that propagate through the host combine stage.

This is fixed by:

  • Using ceil-division to compute iterations per split, then clamping end to gemm0MBlocks so trailing splits where start >= gemm0MBlocks become cleanly empty.
  • A defensive 0/0 guard in scaleFinalOutput: when a row's sum is exactly zero, the per-split output stores 0 instead of NaN. The host combine stage already tolerates -inf max, but only when each per-split partial output is finite.
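
The arithmetic behind both bullets can be sketched in Python. The function and parameter names below are hypothetical stand-ins for the values named above (gemm0MBlocks, splitKV), and the code only mirrors the integer math, not the actual rewrite pattern.

```python
def iterations_truncating(gemm0_m_blocks, split_kv):
    """Old behavior: truncating division, identical count for every split."""
    return [gemm0_m_blocks // split_kv] * split_kv


def iterations_ceil_clamped(gemm0_m_blocks, split_kv):
    """Ceil-division per split, with end clamped to the block count."""
    per_split = -(-gemm0_m_blocks // split_kv)  # ceil division on ints
    counts = []
    for split in range(split_kv):
        start = split * per_split
        end = min(start + per_split, gemm0_m_blocks)
        # Trailing splits with start >= gemm0_m_blocks become empty.
        counts.append(max(0, end - start))
    return counts


def scale_final_output(acc, row_sum):
    """Defensive 0/0 guard: an empty split stores 0 instead of NaN."""
    return 0.0 if row_sum == 0.0 else acc / row_sum


# 2 key blocks spread over 4 splits.
print(iterations_truncating(2, 4))    # [0, 0, 0, 0]: no split does any work
print(iterations_ceil_clamped(2, 4))  # [1, 1, 0, 0]: all blocks are covered
```

With ceil-division the per-split counts always sum to the number of blocks, so no keys are silently dropped, and the guard in `scale_final_output` keeps the empty trailing splits finite for the host combine.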

Host reference: causal valid-split masking
In mlir/tools/rocmlir-gen/rocmlir-gen.cpp, computeValidSplitKV() used a causal branch that forced currSeqLen = 0, which made every split invalid and did not reflect per-query causal reach. Causal masking now builds a per-(batch-head, query-row) valid split count from each row's effective key length (including prefix_offset when present). Non-causal configs keep the per-batch-head path. The usePerRowMask predicate also triggers for non-empty prefix_offset, mirroring how the kernel treats prefix-causal as causal.

createMaskSplitKV() accepts either per-batch-head or per-(batch-head, query-row) validSplitKV layouts and broadcasts the threshold tensor accordingly. computeFinalAttentionStage() still masks invalid splits on both the partial output and LSE tensors before the split-KV combine; its size assertion now accepts either layout.
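
To illustrate the per-row idea, here is a hedged Python sketch. The partitioning rule (keys divided evenly into split_kv chunks) and the effective-length formula (row index + 1 + prefix_offset, capped at seq_len_k) are assumptions made for this example, and the function names are hypothetical; rocmlir-gen's actual layout may differ.

```python
def valid_splits_per_row(seq_len_q, seq_len_k, split_kv, prefix_offset=0):
    """Count, for each query row, how many splits hold at least one visible key.

    Assumes keys are divided evenly into split_kv chunks and that causal
    row r sees keys [0, r + prefix_offset]. Both are illustrative assumptions.
    """
    keys_per_split = -(-seq_len_k // split_kv)  # ceil division
    counts = []
    for row in range(seq_len_q):
        effective_len = min(seq_len_k, row + 1 + prefix_offset)
        needed = -(-effective_len // keys_per_split)
        counts.append(min(split_kv, needed))
    return counts


def split_mask(counts, split_kv):
    """Broadcast the per-row threshold: split s is valid iff s < count."""
    return [[s < count for s in range(split_kv)] for count in counts]


counts = valid_splits_per_row(seq_len_q=4, seq_len_k=4, split_kv=2)
print(counts)                 # [1, 1, 2, 2]
print(split_mask(counts, 2))
```

Under these assumptions early query rows see only the first split, which is the per-row information a single per-batch-head count cannot express.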

Host reference: f32 combine for narrow floats
For f16 and bf16, the split-KV combine stage (reduce-max, exp, weighted sum, normalization) runs in f32 and casts the final combined result back to storage type. That reduces accumulation error in the reference path when comparing against the GPU kernel during sweeps.
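
A small Python sketch of the combine arithmetic, assuming each split yields a normalized partial output plus its log-sum-exp (LSE). Python floats (f64) stand in for the wide accumulator here, and `to_f16` and `combine_splits` are hypothetical helpers, not functions from rocmlir-gen.

```python
import math
import struct


def to_f16(x):
    """Round a Python float to IEEE half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def combine_splits(partials, lses):
    """Combine per-split partial outputs for one output element.

    All intermediate math runs in the wide type; only the final result
    is cast to the storage type. Partials must be finite.
    """
    m = max(lses)                                # reduce-max over splits
    if m == float("-inf"):
        return to_f16(0.0)                       # every split was masked
    weights = [math.exp(l - m) for l in lses]    # masked splits get weight 0
    total = sum(weights)
    acc = sum(w * o for w, o in zip(weights, partials))
    return to_f16(acc / total)                   # single cast at the end


# Split 1: scores [0, 1], values [1, 2]. Split 2: score [2], value [4].
e = math.e
o1, lse1 = (1.0 + 2.0 * e) / (1.0 + e), math.log(1.0 + e)
o2, lse2 = 4.0, 2.0
print(combine_splits([o1, o2], [lse1, lse2]))
```

The result matches a direct softmax over all three keys up to f16 rounding, and a split whose LSE is -inf drops out of the weighted sum.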

Sweep policy: default attention RMS 0.005
In mlir/utils/performance/parameterSweeps.py, attention sweeps without an explicit -RMS_threshold now append -RMS_threshold 0.005 for all attention dtypes (bf16 keeps 0.01). The band covers observed sweep disagreement (including i8 with trans_q and high split_kv) while staying tighter than the old bf16-only default.

Configs that set their own -RMS_threshold are unchanged.
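
The "append only when absent" policy can be sketched as follows. This is an illustration rather than the code in parameterSweeps.py: the helper name is hypothetical, the argument list is a placeholder, and accepting a `-RMS_threshold=value` spelling is an assumption.

```python
RMS_FLAG = "-RMS_threshold"


def with_default_rms(args, default="0.005"):
    """Return args with a default RMS threshold appended if none is set."""
    explicit = any(a == RMS_FLAG or a.startswith(RMS_FLAG + "=") for a in args)
    if explicit:
        return list(args)  # configs with their own threshold are unchanged
    return list(args) + [RMS_FLAG, default]


print(with_default_rms(["--placeholder-config"]))
print(with_default_rms(["--placeholder-config", RMS_FLAG, "0.02"]))
```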

Sweep policy: widen RMS band for unscaled large-head_dim attention
Sampled configs with with_attn_scale=False and head_dim_qk > 64 saturate softmax: unscaled scores grow as |QK| ~ O(sqrt(d)), so each row collapses to a near one-hot distribution. In that regime, CPU vs GPU floating-point ordering inside exp/accumulate dominates the diff (observed RMS up to ~6% in bf16, ~1% in f16/i8), independent of split_kv. For this regime only, test_config widens -RMS_threshold to 0.15 so the verifier still catches NaN/crash regressions but no longer false-fails on known float non-associativity in saturated softmax. All other configs are unaffected.
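
The regime check reduces to a small predicate. The sketch below uses hypothetical function names and a 0.005 base value taken from the sweep default described earlier; how test_config wires this in is not shown here.

```python
def is_saturated_softmax_regime(with_attn_scale, head_dim_qk):
    """True for unscaled attention with a large QK head dimension."""
    return (not with_attn_scale) and head_dim_qk > 64


def rms_threshold_for(with_attn_scale, head_dim_qk, base=0.005):
    """Widen the RMS band to 0.15 only in the saturated regime."""
    if is_saturated_softmax_regime(with_attn_scale, head_dim_qk):
        return 0.15
    return base


print(rms_threshold_for(with_attn_scale=False, head_dim_qk=128))  # 0.15
print(rms_threshold_for(with_attn_scale=True, head_dim_qk=128))   # 0.005
print(rms_threshold_for(with_attn_scale=False, head_dim_qk=64))   # 0.005
```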

Sweep generator: device-memory-aware split-KV prefilter (folds in draft #2366)
In mlir/utils/performance/attentionSweeps.py, the sweep generator now estimates extra split-KV temporary storage for each sampled shape and rejects samples above a budget before generating MLIR. The budget defaults to deviceMem / 8, clamped to [1 GiB, 8 GiB], with a 1.5 GiB fallback if the HIP query fails. A --splitkv-extra-bytes-limit CLI override is available. Filter-out reasons are now tracked separately (MAX_TOKENS vs splitKV extra-storage) and reported cumulatively across initial and refill batches.
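
The budget rule can be written out as a short Python sketch. The function names are hypothetical, the device memory is passed in rather than queried through HIP, and the per-shape storage estimate is treated as an input because its formula is not given here.

```python
GIB = 1 << 30


def splitkv_extra_budget(device_mem_bytes=None, override_bytes=None):
    """Budget for extra split-KV temporaries, in bytes.

    device_mem_bytes=None models a failed device-memory query;
    override_bytes models a value passed via --splitkv-extra-bytes-limit.
    """
    if override_bytes is not None:
        return override_bytes
    if device_mem_bytes is None:
        return int(1.5 * GIB)                       # fallback
    return max(1 * GIB, min(8 * GIB, device_mem_bytes // 8))


def passes_prefilter(estimated_extra_bytes, budget_bytes):
    """Keep a sampled shape only if its estimate fits within the budget."""
    return estimated_extra_bytes <= budget_bytes


print(splitkv_extra_budget(16 * GIB) // GIB)   # 2 on a 16 GiB device
print(splitkv_extra_budget(4 * GIB) // GIB)    # 1 (clamped up)
print(splitkv_extra_budget(128 * GIB) // GIB)  # 8 (clamped down)
```

On the 16 GiB gfx1201 device mentioned in the test plan, this rule gives a 2 GiB budget.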

No compiler/verifier logic changes for this part; scope is limited to sweep sampling / filtering behavior.

Test Plan

  • Rely on PR CI for the full build, lit, Python performance script tests, and the AttentionSweeps job.
  • Locally re-ran the original 8 failing May-3 sweep configs plus 5 newer May-14 sweep failures on gfx1201 (16 GiB) with all changes in place. 6/13 PASS on the kernel path. Of the remaining 7, 6 are correctly filtered out by the new split-KV prefilter (they were OOM-bound on a 16 GiB GPU because the product of split_kv = 128, a large g, num_heads_q, and head_dim_v exceeded VRAM). The last one (split_kv=1, bf16 trans-all) trips only the per-element relDiff check on near-zero outputs while passing both RMS and abs-diff; it is out of scope here and tracked separately.

Test Result

  • PR CI (build, lit, Python format/lint, Python performance script tests, AttentionSweeps)

Submission Checklist

Signed-off-by: bogdan-petkovic <bogdan.petkovic@htecgroup.com>

Copilot AI left a comment

Pull request overview

This PR aims to make split-KV attention performance sweeps reliable by fixing NaN production in the split-KV kernel softmax update path, aligning rocmlir-gen’s host-side split-KV masking/combine behavior with causal/prefix-causal semantics, and updating the attention sweep driver’s default RMS threshold when none is explicitly provided.

Changes:

  • Add -inf - (-inf) guards in split-KV softmax state updates to prevent NaNs in fully-masked/empty KV partitions.
  • Update rocmlir-gen split-KV validity masking to support per-row causal masking and do split-KV combine math in f32 for fp16/bf16 storage.
  • Change attention sweep default -RMS_threshold injection to 0.005 (for attention sweeps without an explicit threshold).

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | Prevent NaNs in split-KV softmax updates by guarding -inf - (-inf) cases. |
| mlir/tools/rocmlir-gen/rocmlir-gen.cpp | Rework split-KV valid-split masking for causal cases and do f32 combine for fp16/bf16. |
| mlir/utils/performance/parameterSweeps.py | Set a new default attention sweep RMS threshold (0.005) when none is specified. |

Comment thread mlir/tools/rocmlir-gen/rocmlir-gen.cpp Outdated
Comment thread mlir/tools/rocmlir-gen/rocmlir-gen.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp Outdated
Comment thread mlir/tools/rocmlir-gen/rocmlir-gen.cpp
bogdan-petkovic and others added 7 commits May 14, 2026 12:40
Signed-off-by: bogdan-petkovic <bogdan.petkovic@htecgroup.com>
Signed-off-by: bogdan-petkovic <bpetkovi@amd.com>
Signed-off-by: bogdan-petkovic <bogdan.petkovic@htecgroup.com>
Signed-off-by: bogdan-petkovic <bpetkovi@amd.com>
Signed-off-by: bogdan-petkovic <bpetkovi@amd.com>
@bogdan-petkovic bogdan-petkovic marked this pull request as ready for review May 15, 2026 12:18
@bogdan-petkovic bogdan-petkovic requested a review from causten as a code owner May 15, 2026 12:18