Improve SDPA to handle GQA and update models to use native GQA #18513

Merged
mergennachin merged 2 commits into main from mergennachin/gqa
Mar 31, 2026

Conversation

@mergennachin
Contributor

@mergennachin mergennachin commented Mar 26, 2026

Add GQA/MQA support to the Triton SDPA kernel with a "pack GQA"
optimization adapted from FlashAttention. When enable_gqa=True and
H_q > H_kv, the kernel folds multiple Q heads sharing the same KV
head into the M (sequence) dimension, so K/V are loaded once per KV
head instead of once per Q head. A tile-utilization heuristic from
FlashAttention decides when packing is beneficial (decode) vs when
simple head remapping suffices (prefill).
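
The packed-index decomposition described above can be sketched in plain Python (names are illustrative, not the kernel's actual symbols):

```python
# Illustrative sketch of the "pack GQA" row decomposition: Q heads that
# share one KV head are folded into the M (sequence) dimension.
def unpack_row(packed_idx, num_groups):
    """Map a packed M-dimension row back to (seq_pos, head_within_group)."""
    return packed_idx // num_groups, packed_idx % num_groups

# With H_q=16, H_kv=2 -> num_groups=8: packed rows 0..7 cover sequence
# position 0 across the 8 Q heads of one KV group, rows 8..15 position 1.
assert unpack_row(7, 8) == (0, 7)
assert unpack_row(8, 8) == (1, 0)
```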

Update Qwen 3.5 MoE and Voxtral Realtime models to use native
enable_gqa=True instead of manually expanding KV heads via
repeat_interleave, eliminating redundant memory traffic.
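
As a rough sketch of what the removed expansion was doing (plain Python, hypothetical names):

```python
# Before this change, K/V heads were materialized n_rep times to match Q.
H_q, H_kv = 16, 2              # Qwen 3.5 MoE attention config
n_rep = H_q // H_kv            # 8 Q heads per KV head
kv_heads = list(range(H_kv))   # stand-in for the KV head dimension
expanded = [h for h in kv_heads for _ in range(n_rep)]  # repeat_interleave
assert len(expanded) == H_q    # 8x the K/V memory traffic

# With enable_gqa=True the kernel performs this mapping implicitly,
# without materializing any copies:
assert [q // n_rep for q in range(H_q)] == expanded
```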

Decode tokens/s: 63 -> 66.8 on A100 for Qwen 3.5 MoE, 34 -> 40 tokens/s for Voxtral Realtime

@mergennachin mergennachin requested a review from lucylq as a code owner March 26, 2026 02:38
Copilot AI review requested due to automatic review settings March 26, 2026 02:38
@pytorch-bot

pytorch-bot bot commented Mar 26, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18513

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 75 Pending

As of commit 3c4a653 with merge base 0527015:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 26, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment


Pull request overview

This PR extends the CUDA Triton SDPA implementation to support Grouped Query Attention (GQA/MQA) end-to-end, and updates the Qwen3.5 MoE example to rely on native GQA handling instead of explicitly expanding KV heads.

Changes:

  • Update Qwen3.5 MoE attention to call PyTorch SDPA with enable_gqa=True (removing explicit KV head expansion).
  • Add GQA support to the Triton SDPA kernel, including an optional “pack GQA” optimization with a utilization heuristic.
  • Add a new CUDA test suite covering MHA/GQA/MQA correctness across shapes, masks, causal/non-causal, and pack-gqa behavior.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

File Description
examples/models/qwen3_5_moe/model.py Switches attention to rely on SDPA’s native GQA support (enable_gqa=True) instead of repeating KV heads.
backends/cuda/triton/kernels/sdpa.py Adds GQA shape validation, head remapping/packing logic, and propagates enable_gqa through kernel launches.
backends/cuda/tests/test_triton_sdpa.py Introduces comprehensive correctness tests for Triton SDPA across MHA/GQA/MQA configurations.


@JacobSzwejbka
Contributor

@claude review this PR

@claude

claude bot commented Mar 26, 2026

Claude finished @JacobSzwejbka's task in 2m 5s


Review: Improve SDPA to handle GQA

  • Read PR diff and all changed files
  • Review backends/cuda/triton/kernels/sdpa.py changes
  • Review backends/cuda/tests/test_triton_sdpa.py (new test file)
  • Review examples/models/qwen3_5_moe/model.py changes
  • Post review feedback

This PR adds GQA/MQA support to the Triton SDPA kernel with a "pack GQA" optimization adapted from FlashAttention, and updates the Qwen 3.5 MoE model to use native GQA. The approach is sound and the test coverage is thorough. A few issues to address:

Bug: _sdpa_abstract typo fix is good but may break existing callers

sdpa.py:1010 — The rename from enable_gq to enable_gqa in the @sdpa.register_fake function is a correct fix for what was clearly a typo in the original code. However, if any existing exported models or cached graph modules reference the old parameter name, this could cause breakage. Since the old name was never functional (it was never passed through to _validate_qkv_shapes), this is likely fine — just worth noting.

Issue: Causal masking is overly restrictive for GQA decode

sdpa.py:942-945 — The check if is_causal and L_q != L_kv raises an error, but this prevents using is_causal=True during decode, where L_q=1 and L_kv=cache_len. This matches the existing behavior (not introduced by this PR), but it's worth noting that standard SDPA implementations (PyTorch, FlashAttention) handle causal with L_q != L_kv by applying the causal mask relative to the bottom-right corner. The Qwen model works around this by using an explicit bool mask instead, so this isn't blocking, but it does limit the kernel's generality.
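
For reference, bottom-right-aligned causal masking (the behavior PyTorch and FlashAttention use when L_q != L_kv) can be sketched as follows; this is a hypothetical illustration, not the kernel's code:

```python
# A key position k is visible to query row q when it does not lie past the
# query's absolute position; the offset L_kv - L_q aligns the two lengths
# at the bottom-right corner.
def causal_allowed(q, k, L_q, L_kv):
    return k <= q + (L_kv - L_q)

# Decode: L_q=1, L_kv=5 -> the single query attends to all cached keys.
assert all(causal_allowed(0, k, 1, 5) for k in range(5))
# Prefill (L_q == L_kv) reduces to the standard lower-triangular mask.
assert causal_allowed(2, 2, 4, 4) and not causal_allowed(2, 3, 4, 4)
```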

Issue: _should_pack_gqa heuristic uses hardcoded block_m=64 regardless of actual kernel selection

sdpa.py:949 — The pack decision is made with block_m=64, but the actual kernel may use BLOCK_M=32 (selected at line 769 based on CTA count threshold). When BLOCK_M=32 is selected, the packing heuristic computed with block_m=64 may be suboptimal. For decode with L_q=1 and num_groups >= 2, packing is always beneficial regardless of block size, so this doesn't affect the critical path. But for intermediate L_q values (e.g., 16-32), the heuristic could make the wrong decision.
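
A minimal sketch of a FlashAttention-style pack decision based on M-tile utilization (the function name and exact condition are illustrative, not the PR's implementation):

```python
def should_pack_gqa(L_q, num_groups, block_m):
    # Ceil-division tile counts: one grid per Q head when unpacked,
    # one grid per KV head (with a num_groups-wider M) when packed.
    tiles_unpacked = -(-L_q // block_m) * num_groups
    tiles_packed = -(-(L_q * num_groups) // block_m)
    # Pack when it launches strictly fewer M-tiles, i.e. fills tiles better.
    return tiles_packed < tiles_unpacked

assert should_pack_gqa(1, 8, 64)        # decode: 8 rows share one tile
assert not should_pack_gqa(512, 8, 64)  # long prefill: tiles already full
```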

Issue: Mask handling doesn't account for packed GQA row positions

sdpa.py:256-262 (non-pow2 kernel) and sdpa.py:423-433 (pow2 kernel) — The mask is indexed by seq_pos (the sequence position component of the packed index), which is correct since the mask shape is [B, 1, L_q, L_kv] and broadcasts over heads. This is consistent and works because the mask is explicitly required to have head dim = 1. Good that _prepare_mask_params rejects attn_mask.shape[1] != 1 (line 706). Copilot's review comment about this was a valid concern, but the code handles it correctly by requiring broadcast masks.

Suggestion: Non-pow2 kernel creates a dummy tensor when no mask is present

sdpa.py:854 — torch.empty((1,), device=query.device, dtype=torch.bool) allocates a small dummy tensor just to have a valid pointer. This allocation happens on every call without a mask. Consider using 0 (like the pow2 path at line 777) instead, which avoids the allocation entirely.

Nit: SEQLEN_PAIRS_LONG is defined but never used

test_triton_sdpa.py:105-108 — The SEQLEN_PAIRS_LONG list is defined but never referenced in any test. Either remove it or add a test using it (e.g., a long-sequence stress test).

Nit: _max_relative_error is defined but never used

test_triton_sdpa.py:71-74 — The helper function _max_relative_error is defined but all tests use _max_abs_error instead. Remove if unneeded.

Observation: Model change removes explicit KV head expansion

model.py:267-278 — Replacing the explicit repeat_interleave with enable_gqa=True is clean and correct. This eliminates the memory overhead of expanding K/V heads for Qwen's 8:1 ratio (16 Q heads, 2 KV heads). The performance improvement (66.8 tok/s on A100) likely comes partly from this reduced memory traffic.

Summary

The core kernel changes are well-structured. The pack GQA optimization is a known technique from FlashAttention and the implementation follows the same pattern. The two-path strategy (pack for decode, remap for prefill) with the tile-utilization heuristic is reasonable. The test coverage is comprehensive across MHA/GQA/MQA configurations with various head dims, sequence lengths, masks, and edge cases.

Main things to address:

  1. The hardcoded block_m=64 for the packing heuristic vs actual BLOCK_M=32 selection
  2. The dummy tensor allocation in the non-pow2 path (minor perf)
  3. Clean up unused test constants/helpers

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.



Comment on lines +949 to +951
# Mirror the kernel selection logic: M32 when CTAs are sparse, M64 otherwise.
total_ctas_m64 = ((L_q * num_groups + 63) // 64) * (B * H_kv)
block_m = 32 if total_ctas_m64 < 4 * 84 else 64

Copilot AI Mar 30, 2026


pack_gqa is decided using block_m derived from the M32/M64 selection heuristic, but the non-power-of-2 kernel currently hardcodes BLOCK_M = 32. When D is non-power-of-2, pack_gqa may be chosen using block_m=64 even though the launched kernel uses 32, which can flip the _should_pack_gqa() decision for some shapes and lead to suboptimal (or unintended) packing behavior. Consider computing block_m based on the actual kernel path (e.g., force 32 for the non-pow2 launcher), or make the non-pow2 launcher support both BLOCK_M=32/64 and align the heuristic with the chosen kernel.

Suggested change
# Mirror the kernel selection logic: M32 when CTAs are sparse, M64 otherwise.
total_ctas_m64 = ((L_q * num_groups + 63) // 64) * (B * H_kv)
block_m = 32 if total_ctas_m64 < 4 * 84 else 64
# For power-of-2 D, mirror the kernel selection logic: M32 when CTAs are sparse,
# M64 otherwise. For non-power-of-2 D, the non-pow2 kernel uses BLOCK_M = 32,
# so force block_m = 32 here to keep the heuristic aligned with the actual kernel.
if _is_power_of_2(D):
    total_ctas_m64 = ((L_q * num_groups + 63) // 64) * (B * H_kv)
    block_m = 32 if total_ctas_m64 < 4 * 84 else 64
else:
    block_m = 32

@mergennachin
Contributor Author

@claude review this PR

@claude

claude bot commented Mar 30, 2026

Claude finished @mergennachin's task in 1m 58s


Review: Improve SDPA to handle GQA

  • Read PR diff and all changed files
  • Review backends/cuda/triton/kernels/sdpa.py — kernel changes
  • Review backends/cuda/tests/test_triton_sdpa.py — new test file
  • Review examples/models/qwen3_5_moe/model.py — model changes
  • Review examples/models/voxtral_realtime/model.py — model changes
  • Post review feedback

This PR adds GQA/MQA support to the Triton SDPA kernel using a "pack GQA" optimization from FlashAttention, and updates both the Qwen 3.5 MoE and Voxtral Realtime models to use native enable_gqa=True instead of manually expanding KV heads via repeat_interleave. The approach is solid and the test coverage is thorough. Here are my findings:

Bug: _should_pack_gqa heuristic ignores the non-pow2 kernel path

sdpa.py:948-952 — The pack decision computes block_m using the pow2 M32/M64 selection heuristic, but this logic runs unconditionally for both pow2 and non-pow2 D. The non-pow2 kernel (_sdpa_fwd_kernel_non_pow2) always uses BLOCK_M=32 (line 836), so the heuristic should force block_m=32 when D is non-pow2. Currently, for a non-pow2 D with enough CTAs, the heuristic will compute with block_m=64 while the actual kernel uses 32, potentially making the wrong packing decision for intermediate L_q values.

Copilot flagged this same issue. The fix is straightforward — gate the M32/M64 selection on _is_power_of_2(D):

if _is_power_of_2(D):
    total_ctas_m64 = ((L_q * num_groups + 63) // 64) * (B * H_kv)
    block_m = 32 if total_ctas_m64 < 4 * 84 else 64
else:
    block_m = 32


Issue: _sdpa_abstract doesn't forward enable_gqa to validation

sdpa.py:1025 — _sdpa_abstract now correctly passes enable_gqa to _validate_qkv_shapes, but note that _validate_sdpa_inputs (called at line 934 for the real kernel) no longer validates enable_gqa — the old check if enable_gqa is not False: raise was removed. This is intentional, but _sdpa_abstract doesn't call _validate_sdpa_inputs at all, meaning it won't catch invalid dropout_p, non-CUDA tensors, or wrong dtypes during export tracing. This is a pre-existing gap, not introduced by this PR.

Observation: _prepare_mask_params now enforces attn_mask.shape[1] == 1

sdpa.py:706-710 — This is a good defensive addition. The kernel indexes masks by seq_pos only (not by head), so per-head masks (H > 1) would silently apply head 0's mask to all heads. The new validation catches this upfront with a clear error message. The docstring at line 926 is also updated to specify [B, 1, L_q, L_kv].

Observation: Non-pow2 kernel now uses mask_ptr = 0 instead of dummy tensor

sdpa.py:854 — The non-pow2 no-mask path previously allocated torch.empty((1,), ...) as a dummy pointer. It now uses 0, consistent with the pow2 path at line 777. This eliminates a small per-call allocation.

Correctness: Causal masking with pack GQA

sdpa.py:252-254 (non-pow2) and sdpa.py:435-439 (pow2) — The causal mask comparison kn[None, :] > seq_pos[:, None] uses seq_pos (derived from the packed index decomposition) rather than offs_packed. This is correct: in pack mode, seq_pos = offs_packed // NUM_GROUPS extracts the actual sequence position, so the causal relationship is between the true Q sequence position and the K sequence position. In non-pack mode, seq_pos = offs_packed which is just the raw sequence offset.

However, note the restriction at line 942: is_causal requires L_q == L_kv, so causal + pack GQA only fires for prefill (where L_q = L_kv). For decode (L_q=1, L_kv=cache_len), users must provide an explicit bool mask. This is documented in the error message, which is good.

Correctness: Mask indexing with pack GQA

sdpa.py:256-262 (non-pow2) and sdpa.py:423-433 (pow2) — The mask is indexed using seq_pos[:, None] * stride_mq, not the packed offset. Since attn_mask.shape[1] is enforced to be 1, the mask broadcasts correctly over all Q heads within a group. Different heads sharing the same KV head and the same sequence position will load the same mask row, which is correct behavior for a broadcast mask.
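
This broadcast behavior can be checked with a tiny stand-in (illustrative Python, not kernel code):

```python
# Every Q head in a group at the same sequence position reads the same
# mask row, because the mask's head dimension is 1.
L_q, num_groups = 4, 8
mask_rows = [f"row{t}" for t in range(L_q)]  # stand-in for [B, 1, L_q, L_kv]
seen = [mask_rows[p // num_groups] for p in range(L_q * num_groups)]
assert seen[:num_groups] == ["row0"] * num_groups  # heads 0..7 at position 0
assert seen[num_groups] == "row1"                  # next group of packed rows
```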

Model changes look correct

Qwen 3.5 MoE (model.py:270-274) — The repeat_interleave expansion is removed and enable_gqa=True is passed instead. The mask is [B, 1, L_q, L_kv] (via unsqueeze(0).unsqueeze(0) on a [T, max_len] mask), which is compatible with the new head-dim-1 requirement.

Voxtral Realtime (model.py:514-557) — StandardSDPA now stores self.enable_gqa = n_heads != n_kv_heads and passes it through. The repeat_interleave call is removed. The mask comes from _build_causal_mask_bool which returns [1, 1, L_q, L_kv], compatible with the broadcast requirement.

Test coverage

The new test file is comprehensive — it covers MHA, GQA (2x, 4x, 8x), MQA, pow2 and non-pow2 head dims, causal/non-causal, bool masks, decode (L_q=1) and prefill (L_q=L_kv), batch sizes, the Qwen config specifically, edge cases (all-masked, per-head mask rejection, causal L_q!=L_kv rejection), and the heuristic itself. The 0.05 max-abs-error tolerance against float32 reference is reasonable for bf16.

Summary

The implementation is well-structured and the pack GQA optimization follows the established FlashAttention pattern. The main actionable item is the block_m heuristic bug for non-pow2 head dims. The rest of the changes are clean and correct.



@mergennachin mergennachin changed the title Improve SDPA to handle GQA Improve SDPA to handle GQA and update models to use native GQA Mar 30, 2026
@mergennachin
Contributor Author

@claude review this PR

@claude

claude bot commented Mar 30, 2026

Claude finished @mergennachin's task in 2m 20s


Review: Improve SDPA to handle GQA

  • Read PR diff and all changed files
  • Review backends/cuda/triton/kernels/sdpa.py — kernel changes
  • Review backends/cuda/tests/test_triton_sdpa.py — new test file
  • Review examples/models/qwen3_5_moe/model.py — model changes
  • Review examples/models/voxtral_realtime/model.py — model changes
  • Post review feedback

This PR adds GQA/MQA support to the Triton SDPA kernel using a "pack GQA" optimization from FlashAttention, and updates Qwen 3.5 MoE and Voxtral Realtime models to use native enable_gqa=True. The implementation is solid — the two previous reviews covered the ground well, and I can see the non-pow2 block_m heuristic issue has been fixed since the first review. Here are my findings on the current state:

Correctness: Pack GQA indexing is correct

Both kernels (pow2 and non-pow2) correctly decompose the packed index into seq_pos = offs_packed // NUM_GROUPS and h_within = offs_packed % NUM_GROUPS. The interleaving order [h0_pos0, h1_pos0, ..., h(G-1)_pos0, h0_pos1, ...] ensures that consecutive rows within a tile access the same sequence position across different Q heads, which is optimal for mask and causal comparisons (they share the same seq_pos). The K/V loads use h_kv = h_grid (pack mode) or h_kv = h_grid // NUM_GROUPS (non-pack mode), both correct.

Correctness: Mask indexing with pack GQA

sdpa.py:256-262 (non-pow2) and sdpa.py:423-433 (pow2) — The mask is indexed by seq_pos, not offs_packed. Since attn_mask.shape[1] is enforced to be 1 (line 706), different Q heads at the same sequence position correctly load the same mask row. This is the right behavior for a broadcast mask.

Issue: _validate_sdpa_inputs accepts but ignores enable_gqa

sdpa.py:664-689 — The function takes enable_gqa as a parameter but doesn't use it (the old rejection check was removed). This parameter should either be removed from the signature or validated. Currently it serves no purpose in this function — the actual GQA validation happens in _validate_qkv_shapes. Minor cleanup, not a bug.

Issue: l_i initialization differs between pow2 and non-pow2 kernels

sdpa.py:233 (non-pow2): l_i = tl.full((BLOCK_M,), 1.0, dtype=tl.float32)
sdpa.py:403 (pow2): l_i = tl.zeros([BLOCK_M], dtype=tl.float32)

The non-pow2 kernel initializes l_i to 1.0 while the pow2 kernel initializes it to 0.0. This is a pre-existing difference (not introduced by this PR), but it means the two kernels handle the all-masked case differently. With l_i=1.0, the non-pow2 kernel will output acc/1.0 = 0.0 for all-masked rows. With l_i=0.0, the pow2 kernel relies on inv_l_i = tl.where(l_i > 0, 1.0 / l_i, 0.0) at line 465 to produce 0.0. Both are correct for producing zeros on all-masked blocks, but the inconsistency could be confusing for future maintainers. Not blocking.
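
A quick numeric check of the two conventions described above (plain Python, mirroring the scalar math only):

```python
# All-masked row: the accumulator never receives any probability mass.
acc = 0.0

# Non-pow2 convention: l_i starts at 1.0, so the final divide yields 0.
l_i_nonpow2 = 1.0
assert acc / l_i_nonpow2 == 0.0

# Pow2 convention: l_i stays 0.0; the guarded reciprocal maps it to 0.
l_i_pow2 = 0.0
inv = 1.0 / l_i_pow2 if l_i_pow2 > 0 else 0.0
assert acc * inv == 0.0
```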

Issue: scale=0.0 sentinel is fragile

sdpa.py:960 — sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale. Using 0.0 as a sentinel to mean "use default" is fragile — a caller who genuinely wants scale=0.0 (unusual but valid) would get 1/sqrt(D) instead. This is pre-existing behavior, but worth noting since the PR touches this path. Consider using None as the sentinel if this API is ever revised.
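
The None-sentinel alternative suggested above could look like this (hypothetical sketch, not the PR's API):

```python
import math

def resolve_scale(D, scale=None):
    # None means "use the default 1/sqrt(D)"; an explicit 0.0 is preserved.
    return 1.0 / math.sqrt(D) if scale is None else scale

assert resolve_scale(256) == 1.0 / 16.0
assert resolve_scale(256, scale=0.0) == 0.0  # no longer silently overridden
```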

Observation: Heuristic block_m computation is now correct for non-pow2

sdpa.py:948-956 — The current code correctly forces block_m = 32 for non-pow2 D (matching the non-pow2 kernel's hardcoded BLOCK_M=32), and only runs the M32/M64 selection heuristic for pow2 D. This addresses the issue flagged by both Copilot and the previous Claude review.

Observation: Heuristic CTA count is consistent across pack/non-pack

The heuristic at line 955 computes total_ctas_m64 = ((L_q * num_groups + 63) // 64) * (B * H_kv). This assumes the packed configuration, but L_q * num_groups * H_kv == L_q * H_q, so the total CTA count equals ((L_q * H_q / H_kv + 63) // 64) * (B * H_kv). The actual kernel at line 766 computes ((Lq_packed + 63) // 64) * (B * H_grid) which is identical when packing (and different only in the ceiling division rounding when not packing). This is acceptable since the heuristic is conservative — when packing is not chosen, the kernel selection at line 766 re-derives the actual CTA count with the correct (non-packed) dimensions.
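
The identity this paragraph relies on is easy to verify (illustrative check):

```python
# num_groups = H_q // H_kv, and H_kv divides H_q, so
# L_q * num_groups * H_kv == L_q * H_q for any valid GQA config.
for H_q, H_kv, L_q in [(16, 2, 1), (16, 2, 128), (8, 1, 64), (32, 8, 7)]:
    num_groups = H_q // H_kv
    assert L_q * num_groups * H_kv == L_q * H_q
```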

Model changes are clean

Qwen 3.5 MoE (model.py:270-274) — Removes repeat_interleave and passes enable_gqa=True. The mask shape [1, 1, T, max_len] (via unsqueeze(0).unsqueeze(0)) has head dim 1, compatible with the new enforcement.

Voxtral Realtime (model.py:514-557) — Stores self.enable_gqa = n_heads != n_kv_heads and passes it through. Removes repeat_interleave. The mask from _build_causal_mask_bool returns [1, 1, L_q, L_kv], compatible with the broadcast requirement.

Test coverage is thorough

The test file covers MHA, GQA (2x, 4x, 8x), MQA, pow2 and non-pow2 head dims, causal/non-causal, bool masks, decode and prefill, batch sizes, the Qwen config specifically, edge cases (all-masked, per-head mask rejection, causal L_q!=L_kv rejection), and the heuristic. The 0.05 max-abs-error tolerance against float32 reference is reasonable for bf16. The import approach has been fixed per Copilot's suggestion (using executorch.backends.cuda.triton.kernels.sdpa instead of sys.path mutation).

Summary

The implementation is well-structured and correct. The pack GQA optimization follows the established FlashAttention pattern and the two-path strategy (pack for decode, remap for prefill) is sound. The main issues from previous reviews (non-pow2 block_m heuristic, dummy tensor allocation, sys.path mutation) have all been addressed. Remaining nits are minor: the unused enable_gqa parameter in _validate_sdpa_inputs and the l_i initialization inconsistency between the two kernels.




def test_qwen35_moe_config(self):
    """Exact Qwen 3.5 MoE attention config: H_q=16, H_kv=2, D=256."""
    B, H_q, H_kv, D = 1, 16, 2, 256
Contributor


Ok so we were just wasting memory and compute without this?

# ------------------------------------------------------------------

def test_pack_gqa_heuristic(self):
    """Verify _should_pack_gqa matches expected behavior."""
Contributor


why are we testing this? Isn't this an implementation specific fact?


for start_n in tl.range(0, LK, BLOCK_N, num_stages=2):
    kn = start_n + offs_n
    kn = start_n + tl.arange(0, BLOCK_N)
Contributor


nit: keep off_n for readability?

@mergennachin
Contributor Author

@claude what do you think of @digantdesai 's comment/feedback?

@claude

claude bot commented Mar 31, 2026

Claude finished @mergennachin's task in 1m 32s


Analysis of @digantdesai's feedback

  • Read the specific code locations referenced by digantdesai's comments
  • Analyze each comment and provide thoughts

@digantdesai left three comments plus an approval. Here are my thoughts on each:

1. "Ok so we were just wasting memory and compute without this?" (on test_triton_sdpa.py:414, referencing the Qwen config test)

This comment seems directed at the model changes rather than the test itself. The answer is: yes, the previous approach was suboptimal. Before this PR, both Qwen 3.5 MoE and Voxtral Realtime used repeat_interleave to expand K/V heads to match Q heads before calling SDPA. For Qwen's 8:1 ratio (16 Q heads, 2 KV heads), that meant materializing 8x copies of K and V tensors — extra memory allocation and bandwidth for data the kernel could handle implicitly. The performance numbers in the PR description (63 → 66.8 tok/s for Qwen, 34 → 40 tok/s for Voxtral on A100) confirm meaningful wasted bandwidth. The waste scales with the GQA ratio, so models with higher ratios (like MQA with only 1 KV head) would benefit even more.

This is a valid observation and the PR is a clear improvement.

2. "why are we testing this? Isn't this an implementation specific fact?" (on test_triton_sdpa.py:439, about test_pack_gqa_heuristic)

This is a reasonable concern. The _should_pack_gqa heuristic is an internal optimization decision — whether to fold Q heads into the sequence dimension or use simple head remapping. Both paths produce identical results; only performance differs. Testing it pins an implementation detail that could change without affecting correctness.

That said, there's a case for keeping it: the heuristic directly affects which kernel code path runs, and bugs in the heuristic (e.g., packing when it shouldn't, or not packing when it should) could cause correctness issues if the pack/non-pack code paths have different bugs. The test also serves as documentation of the expected behavior (decode always packs, long prefill never packs). It's a judgment call — I'd lean toward keeping it but wouldn't object to removing it either. If kept, a comment like # Pin heuristic behavior to catch regressions in the pack/non-pack code path selection would help future readers understand why an implementation detail is being tested.

3. "nit: keep off_n for readability?" (on sdpa.py:243, about renaming off_n to kn)

At line 243 the variable is kn = start_n + tl.arange(0, BLOCK_N), which was previously called something like off_n or offs_n. The rename to kn is more compact but less self-descriptive — off_n or offs_n follows the established Triton convention (matching offs_d, offs_packed used elsewhere in the same kernels). Keeping offs_n would be more consistent with the surrounding naming. This is a minor nit but I agree with digantdesai that the original name was clearer.


Summary: All three comments are valid. Comment 1 is an observation confirming the PR's value. Comment 2 raises a legitimate question about testing implementation details — worth a brief discussion but not blocking. Comment 3 is a minor naming consistency nit that could be addressed.


Copilot AI review requested due to automatic review settings March 31, 2026 01:31
@mergennachin mergennachin merged commit 99d7b77 into main Mar 31, 2026
186 of 189 checks passed
@mergennachin mergennachin deleted the mergennachin/gqa branch March 31, 2026 01:35
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.



self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.n_rep = n_heads // n_kv_heads
self.enable_gqa = n_heads != n_kv_heads

Copilot AI Mar 31, 2026


enable_gqa is enabled whenever n_heads != n_kv_heads, but PyTorch SDPA’s GQA/MQA path requires n_heads to be divisible by n_kv_heads (i.e., an integer number of Q heads per KV head). Without an explicit check, invalid configs will fail at runtime with a less actionable error.

Consider validating in __init__ that n_heads % n_kv_heads == 0 whenever GQA is enabled (and raising a clear ValueError/RuntimeError), or computing enable_gqa only when that divisibility constraint holds.

Suggested change
self.enable_gqa = n_heads != n_kv_heads
self.enable_gqa = n_heads != n_kv_heads
if self.enable_gqa:
    if n_kv_heads == 0 or n_heads % n_kv_heads != 0:
        raise ValueError(
            "StandardSDPA GQA configuration invalid: PyTorch's GQA/MQA path "
            "requires n_heads to be divisible by n_kv_heads, "
            f"but got n_heads={n_heads}, n_kv_heads={n_kv_heads}."
        )

Comment on lines 704 to +712
if attn_mask.dtype != torch.bool:
    raise RuntimeError("attn_mask must have dtype torch.bool")
if not attn_mask.is_cuda:
    raise RuntimeError("attn_mask must be a CUDA tensor")
if attn_mask.shape[1] != 1:
    raise RuntimeError(
        f"attn_mask head dimension must be 1 (broadcast over heads); "
        f"per-head masks are not supported. Got attn_mask.shape={attn_mask.shape}"
    )

Copilot AI Mar 31, 2026


The Triton SDPA wrapper now hard-rejects attn_mask with head dimension != 1, even though aten/F.scaled_dot_product_attention accepts per-head masks (e.g., [B, H_q, L_q, L_kv]) and this file previously documented [B, H, L_q, L_kv].

If this op is intended as a drop-in replacement, consider restoring per-head mask support by indexing the mask with the relevant head id (e.g., h_grid in the non-packed path, or h_q_rows when PACK_GQA is enabled) rather than requiring broadcast-only masks. If per-head masks are intentionally out of scope, it would help to call that out prominently in the public sdpa() docstring since this is a compatibility break.

rascani pushed a commit to rascani/executorch that referenced this pull request Apr 1, 2026
…ch#18513)

