Improve SDPA to handle GQA and update models to use native GQA #18513

mergennachin merged 2 commits into main from
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18513
⏳ No Failures, 75 Pending as of commit 3c4a653 with merge base 0527015.
Pull request overview
This PR extends the CUDA Triton SDPA implementation to support Grouped Query Attention (GQA/MQA) end-to-end, and updates the Qwen3.5 MoE example to rely on native GQA handling instead of explicitly expanding KV heads.
Changes:
- Update Qwen3.5 MoE attention to call PyTorch SDPA with enable_gqa=True (removing explicit KV head expansion).
- Add GQA support to the Triton SDPA kernel, including an optional "pack GQA" optimization with a utilization heuristic.
- Add a new CUDA test suite covering MHA/GQA/MQA correctness across shapes, masks, causal/non-causal, and pack-gqa behavior.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| examples/models/qwen3_5_moe/model.py | Switches attention to rely on SDPA’s native GQA support (enable_gqa=True) instead of repeating KV heads. |
| backends/cuda/triton/kernels/sdpa.py | Adds GQA shape validation, head remapping/packing logic, and propagates enable_gqa through kernel launches. |
| backends/cuda/tests/test_triton_sdpa.py | Introduces comprehensive correctness tests for Triton SDPA across MHA/GQA/MQA configurations. |
@claude review this PR
Claude finished @JacobSzwejbka's task in 2m 5s. Review: Improve SDPA to handle GQA

This PR adds GQA/MQA support to the Triton SDPA kernel with a "pack GQA" optimization adapted from FlashAttention, and updates the Qwen 3.5 MoE model to use native GQA. The approach is sound and the test coverage is thorough. A few issues to address: Bug:
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
backends/cuda/triton/kernels/sdpa.py (outdated)

```python
# Mirror the kernel selection logic: M32 when CTAs are sparse, M64 otherwise.
total_ctas_m64 = ((L_q * num_groups + 63) // 64) * (B * H_kv)
block_m = 32 if total_ctas_m64 < 4 * 84 else 64
```
pack_gqa is decided using block_m derived from the M32/M64 selection heuristic, but the non-power-of-2 kernel currently hardcodes BLOCK_M = 32. When D is non-power-of-2, pack_gqa may be chosen using block_m=64 even though the launched kernel uses 32, which can flip the _should_pack_gqa() decision for some shapes and lead to suboptimal (or unintended) packing behavior. Consider computing block_m based on the actual kernel path (e.g., force 32 for the non-pow2 launcher), or make the non-pow2 launcher support both BLOCK_M=32/64 and align the heuristic with the chosen kernel.
Suggested change:

```python
# For power-of-2 D, mirror the kernel selection logic: M32 when CTAs are sparse,
# M64 otherwise. For non-power-of-2 D, the non-pow2 kernel uses BLOCK_M = 32,
# so force block_m = 32 here to keep the heuristic aligned with the actual kernel.
if _is_power_of_2(D):
    total_ctas_m64 = ((L_q * num_groups + 63) // 64) * (B * H_kv)
    block_m = 32 if total_ctas_m64 < 4 * 84 else 64
else:
    block_m = 32
```
@claude review this PR
Claude finished @mergennachin's task in 1m 58s. Review: Improve SDPA to handle GQA

This PR adds GQA/MQA support to the Triton SDPA kernel using a "pack GQA" optimization from FlashAttention, and updates both the Qwen 3.5 MoE and Voxtral Realtime models to use native GQA. Bug:
@claude review this PR
Claude finished @mergennachin's task in 2m 20s. Review: Improve SDPA to handle GQA

This PR adds GQA/MQA support to the Triton SDPA kernel using a "pack GQA" optimization from FlashAttention, and updates the Qwen 3.5 MoE and Voxtral Realtime models to use native GQA.

Correctness: Pack GQA indexing is correct
Both kernels (pow2 and non-pow2) correctly decompose the packed index into

Correctness: Mask indexing with pack GQA

Issue:
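The packed-index decomposition the review refers to can be illustrated with a small plain-Python sketch (the interleaving order and names here are assumptions for illustration, not the kernel's actual indexing):

```python
def unpack_row(packed_idx: int, num_groups: int) -> tuple[int, int]:
    """Split a packed M-dimension row index into
    (q_head_within_group, query_position), assuming the Q heads of a
    KV group are interleaved fastest along M."""
    q_group = packed_idx % num_groups   # which Q head within the KV group
    seq_pos = packed_idx // num_groups  # which query token
    return q_group, seq_pos

# Example: 4 Q heads share one KV head (num_groups = 4).
print(unpack_row(0, 4))  # (0, 0)
print(unpack_row(5, 4))  # (1, 1)
```

With this layout, one tile of packed rows touches a single KV head, so K/V are loaded once per KV head rather than once per Q head.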
```python
def test_qwen35_moe_config(self):
    """Exact Qwen 3.5 MoE attention config: H_q=16, H_kv=2, D=256."""
    B, H_q, H_kv, D = 1, 16, 2, 256
```
Ok so we were just wasting memory and compute without this?
```python
# ------------------------------------------------------------------

def test_pack_gqa_heuristic(self):
    """Verify _should_pack_gqa matches expected behavior."""
```
why are we testing this? Isn't this an implementation specific fact?
backends/cuda/triton/kernels/sdpa.py (outdated)

```python
for start_n in tl.range(0, LK, BLOCK_N, num_stages=2):
    kn = start_n + tl.arange(0, BLOCK_N)  # previously: kn = start_n + offs_n
```
nit: keep off_n for readability?
@claude what do you think of @digantdesai's comment/feedback?
Claude finished @mergennachin's task in 1m 32s. Analysis of @digantdesai's feedback

@digantdesai left three comments plus an approval. Here are my thoughts on each:

1. "Ok so we were just wasting memory and compute without this?" (on the Qwen 3.5 MoE config test)

This comment seems directed at the model changes rather than the test itself. The answer is: yes, the previous approach was suboptimal. Before this PR, both Qwen 3.5 MoE and Voxtral Realtime expanded KV heads with repeat_interleave, materializing redundant copies of K and V. This is a valid observation, and the PR is a clear improvement.

2. "why are we testing this? Isn't this an implementation specific fact?" (on the pack-GQA heuristic test)

This is a reasonable concern. The heuristic is an internal implementation detail, and asserting its exact decisions couples the test to it. That said, there's a case for keeping it: the heuristic directly affects which kernel code path runs, and bugs in the heuristic (e.g., packing when it shouldn't, or not packing when it should) could cause correctness issues if the pack/non-pack code paths have different bugs. The test also serves as documentation of the expected behavior (decode always packs, long prefill never packs). It's a judgment call; I'd lean toward keeping it but wouldn't object to removing it either. If kept, a brief comment explaining why the heuristic's decisions are asserted would help.

3. "nit: keep off_n for readability?" (on the kn = start_n + tl.arange(0, BLOCK_N) change)

At line 243 the variable is offs_n, so restoring that name rather than recomputing tl.arange inline would be a small readability win.

Summary: All three comments are valid. Comment 1 is an observation confirming the PR's value. Comment 2 raises a legitimate question about testing implementation details, worth a brief discussion but not blocking. Comment 3 is a minor naming consistency nit that could be addressed.
Add GQA/MQA support to the Triton SDPA kernel with a "pack GQA" optimization adapted from FlashAttention. When enable_gqa=True and H_q > H_kv, the kernel folds multiple Q heads sharing the same KV head into the M (sequence) dimension, so K/V are loaded once per KV head instead of once per Q head. A tile-utilization heuristic from FlashAttention decides when packing is beneficial (decode) vs when simple head remapping suffices (prefill). Update Qwen 3.5 MoE and Voxtral Realtime models to use native enable_gqa=True instead of manually expanding KV heads via repeat_interleave, eliminating redundant memory traffic.
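The tile-utilization heuristic described above can be sketched roughly like this (a simplification; the function name, the BLOCK_M default, and the exact comparison are assumptions, not the kernel's actual code):

```python
import math

def should_pack_gqa(L_q: int, num_groups: int, block_m: int = 64) -> bool:
    """Illustrative pack-GQA decision: pack when folding the num_groups
    Q heads of a KV group into the M dimension improves how full each
    BLOCK_M tile is. Decode (L_q == 1) wastes almost a whole tile per
    head unpacked, so packing wins; long prefills already fill tiles."""
    util_unpacked = L_q / (math.ceil(L_q / block_m) * block_m)
    packed_rows = L_q * num_groups
    util_packed = packed_rows / (math.ceil(packed_rows / block_m) * block_m)
    return util_packed > util_unpacked

print(should_pack_gqa(1, 8))     # True: decode, packing fills the tile better
print(should_pack_gqa(4096, 8))  # False: prefill tiles are already full
```

This matches the behavior the tests encode (decode always packs, long prefill never packs), under the stated assumptions.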
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
```python
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.n_rep = n_heads // n_kv_heads
self.enable_gqa = n_heads != n_kv_heads
```
enable_gqa is enabled whenever n_heads != n_kv_heads, but PyTorch SDPA’s GQA/MQA path requires n_heads to be divisible by n_kv_heads (i.e., an integer number of Q heads per KV head). Without an explicit check, invalid configs will fail at runtime with a less actionable error.
Consider validating in __init__ that n_heads % n_kv_heads == 0 whenever GQA is enabled (and raising a clear ValueError/RuntimeError), or computing enable_gqa only when that divisibility constraint holds.
Suggested change:

```python
self.enable_gqa = n_heads != n_kv_heads
if self.enable_gqa:
    if n_kv_heads == 0 or n_heads % n_kv_heads != 0:
        raise ValueError(
            "StandardSDPA GQA configuration invalid: PyTorch's GQA/MQA path "
            "requires n_heads to be divisible by n_kv_heads, "
            f"but got n_heads={n_heads}, n_kv_heads={n_kv_heads}."
        )
```
```python
if attn_mask.dtype != torch.bool:
    raise RuntimeError("attn_mask must have dtype torch.bool")
if not attn_mask.is_cuda:
    raise RuntimeError("attn_mask must be a CUDA tensor")
if attn_mask.shape[1] != 1:
    raise RuntimeError(
        f"attn_mask head dimension must be 1 (broadcast over heads); "
        f"per-head masks are not supported. Got attn_mask.shape={attn_mask.shape}"
    )
```
The Triton SDPA wrapper now hard-rejects attn_mask with head dimension != 1, even though aten/F.scaled_dot_product_attention accepts per-head masks (e.g., [B, H_q, L_q, L_kv]) and this file previously documented [B, H, L_q, L_kv].
If this op is intended as a drop-in replacement, consider restoring per-head mask support by indexing the mask with the relevant head id (e.g., h_grid in the non-packed path, or h_q_rows when PACK_GQA is enabled) rather than requiring broadcast-only masks. If per-head masks are intentionally out of scope, it would help to call that out prominently in the public sdpa() docstring since this is a compatibility break.
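To make the compatibility gap concrete, here is a small sketch of the per-head mask shape that aten SDPA accepts but the wrapper now rejects (shapes are made up; CPU tensors are used purely for illustration):

```python
import torch
import torch.nn.functional as F

B, H, L, D = 1, 4, 6, 32
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# Broadcast mask ([B, 1, L_q, L_kv]): the only shape the Triton wrapper accepts.
broadcast_mask = torch.ones(B, 1, L, L, dtype=torch.bool)
out_bcast = F.scaled_dot_product_attention(q, k, v, attn_mask=broadcast_mask)

# Per-head mask ([B, H, L_q, L_kv]): accepted by aten SDPA, rejected by the
# wrapper. Here head 0 gets a causal pattern while the other heads are full.
per_head_mask = torch.ones(B, H, L, L, dtype=torch.bool)
per_head_mask[:, 0] = torch.tril(torch.ones(L, L, dtype=torch.bool))
out_per_head = F.scaled_dot_product_attention(q, k, v, attn_mask=per_head_mask)

assert out_bcast.shape == (B, H, L, D)
assert out_per_head.shape == (B, H, L, D)
```

A caller porting from aten SDPA with a mask like per_head_mask above would hit the new RuntimeError, which is the compatibility break the comment flags.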
Add GQA/MQA support to the Triton SDPA kernel with a "pack GQA"
optimization adapted from FlashAttention. When enable_gqa=True and
H_q > H_kv, the kernel folds multiple Q heads sharing the same KV
head into the M (sequence) dimension, so K/V are loaded once per KV
head instead of once per Q head. A tile-utilization heuristic from
FlashAttention decides when packing is beneficial (decode) vs when
simple head remapping suffices (prefill).
Update Qwen 3.5 MoE and Voxtral Realtime models to use native
enable_gqa=True instead of manually expanding KV heads via
repeat_interleave, eliminating redundant memory traffic.
Decode tokens/s: 63 -> 66.8 on A100 for Qwen 3.5 MoE, 34 -> 40 tokens/s for Voxtral Realtime