diff --git a/backends/cuda/tests/test_triton_sdpa.py b/backends/cuda/tests/test_triton_sdpa.py new file mode 100644 index 00000000000..233a21001d2 --- /dev/null +++ b/backends/cuda/tests/test_triton_sdpa.py @@ -0,0 +1,515 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Comprehensive tests for the Triton SDPA kernel. + +Tests MHA, GQA, MQA with various head dims, sequence lengths, causal/non-causal, +and bool masks. Reference outputs are computed using torch SDPA with expanded KV +heads (for GQA/MQA) in float32 for numerical stability. + +Test parametrization adapted from FlashAttention (tests/cute/test_flash_attn.py). +""" + +import itertools +import unittest + +import torch +import torch.nn.functional as F + + +def _skip_if_no_cuda(): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA not available") + if not torch.cuda.is_bf16_supported(): + raise unittest.SkipTest("BF16 not supported on this GPU") + + +def _import_sdpa(): + from executorch.backends.cuda.triton.kernels.sdpa import sdpa + + return sdpa + + +def _reference_sdpa(q, k, v, attn_mask=None, is_causal=False, scale=None): + """Compute reference SDPA in float32 with expanded KV heads for GQA. + + Adapted from FlashAttention's testing.py: expand KV heads via + repeat_interleave, upcast to float32, use torch SDPA. 
+ """ + H_q = q.shape[1] + H_kv = k.shape[1] + num_groups = H_q // H_kv + + # Expand KV heads for GQA/MQA + if num_groups > 1: + k = k.repeat_interleave(num_groups, dim=1) + v = v.repeat_interleave(num_groups, dim=1) + + # Expand mask head dim if needed + if attn_mask is not None and attn_mask.shape[1] == 1 and H_q > 1: + attn_mask = attn_mask.expand(-1, H_q, -1, -1) + + # Upcast to float32 for reference accuracy + return F.scaled_dot_product_attention( + q.float(), + k.float(), + v.float(), + attn_mask=attn_mask, + is_causal=is_causal, + scale=scale, + ) + + +def _max_abs_error(out, ref): + return (out.float() - ref.float()).abs().max().item() + + +# --------------------------------------------------------------------------- +# Test configurations adapted from FlashAttention +# --------------------------------------------------------------------------- + +# Head dimensions: power-of-2 and non-power-of-2 +HEAD_DIMS_POW2 = [64, 128, 256] +HEAD_DIMS_NON_POW2 = [80, 96] + +# Sequence length pairs (seqlen_q, seqlen_kv) adapted from FlashAttention +# Note: Lk must be >= BLOCK_N (32) to avoid unmasked zero-padding in K +# that dilutes softmax. This is a pre-existing kernel limitation for very +# short KV sequences. 
SEQLEN_PAIRS_BASIC = [
    (1, 64),
    (1, 128),
    (1, 512),
    (4, 128),
    (64, 64),
    (64, 128),
    (128, 128),
    (128, 256),
    (256, 256),
]

# GQA configurations: (H_q, H_kv, label)
GQA_CONFIGS = [
    (4, 4, "mha"),  # MHA: 1:1
    (6, 3, "gqa_2x"),  # GQA: 2 Q heads per KV head
    (8, 2, "gqa_4x"),  # GQA: 4 Q heads per KV head
    (16, 2, "gqa_8x"),  # GQA: 8 Q heads per KV head (Qwen 3.5 MoE config)
    (6, 1, "mqa"),  # MQA: all Q heads share 1 KV head
]


class TestTritonSdpa(unittest.TestCase):
    """Test Triton SDPA kernel correctness against PyTorch reference.

    Every numeric test draws BF16 q/k/v on the CUDA device after seeding the
    RNG, runs the kernel, and compares against ``_reference_sdpa`` using a
    max-absolute-error bound.
    """

    # Maximum allowed |out - ref| between the BF16 kernel output and the
    # float32 reference.
    MAX_ABS_ERROR = 0.05

    @classmethod
    def setUpClass(cls):
        _skip_if_no_cuda()
        # staticmethod keeps the kernel entry point from being bound to the
        # test instance (i.e. receiving ``self`` as its first argument) should
        # it be a plain Python function.
        cls.sdpa = staticmethod(_import_sdpa())

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _make_qkv(B, H_q, H_kv, Lq, Lk, D, seed=42):
        """Return random BF16 CUDA tensors (q, k, v).

        q is [B, H_q, Lq, D]; k and v are [B, H_kv, Lk, D]. When ``seed`` is
        not None the global RNG is seeded first, and the tensors are always
        drawn in q, k, v order.
        """
        if seed is not None:
            torch.manual_seed(seed)
        q = torch.randn(B, H_q, Lq, D, dtype=torch.bfloat16, device="cuda")
        k = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda")
        v = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda")
        return q, k, v

    @staticmethod
    def _prefix_mask(B, Lq, Lk, visible):
        """Return a [B, 1, Lq, Lk] bool mask with the first ``visible`` keys kept."""
        mask = torch.zeros(B, 1, Lq, Lk, dtype=torch.bool, device="cuda")
        mask[:, :, :, :visible] = True
        return mask

    def _assert_matches_reference(self, q, k, v, msg=None, **sdpa_kwargs):
        """Run the kernel and compare it against the float32 reference.

        ``sdpa_kwargs`` are passed to the kernel as-is. The reference receives
        the same arguments except ``enable_gqa``, because ``_reference_sdpa``
        expands KV heads itself. Asserts that the output has q's shape,
        contains no NaN, and is within ``MAX_ABS_ERROR`` of the reference.
        """
        out = self.sdpa(q, k, v, **sdpa_kwargs)
        ref_kwargs = {
            name: value for name, value in sdpa_kwargs.items() if name != "enable_gqa"
        }
        ref = _reference_sdpa(q, k, v, **ref_kwargs)

        self.assertEqual(tuple(out.shape), tuple(q.shape))
        self.assertFalse(torch.isnan(out).any(), "NaN in output")
        self.assertLess(_max_abs_error(out, ref), self.MAX_ABS_ERROR, msg)

    # ------------------------------------------------------------------
    # MHA tests (no GQA, backwards compatibility)
    # ------------------------------------------------------------------

    def test_mha_basic(self):
        """MHA with various seqlens, pow2 head dims, no mask."""
        for D, (Lq, Lk) in itertools.product(HEAD_DIMS_POW2, SEQLEN_PAIRS_BASIC):
            with self.subTest(D=D, Lq=Lq, Lk=Lk):
                B, H = 2, 4
                q, k, v = self._make_qkv(B, H, H, Lq, Lk, D)
                self._assert_matches_reference(
                    q, k, v, msg=f"D={D} Lq={Lq} Lk={Lk}"
                )

    def test_mha_causal(self):
        """MHA with causal masking."""
        for D in HEAD_DIMS_POW2:
            for L in [64, 128, 256]:
                with self.subTest(D=D, L=L):
                    B, H = 2, 4
                    q, k, v = self._make_qkv(B, H, H, L, L, D)
                    self._assert_matches_reference(q, k, v, is_causal=True)

    def test_mha_bool_mask(self):
        """MHA with explicit bool attention mask."""
        B, H, D = 1, 4, 64
        for Lq, Lk in [(4, 128), (64, 64), (1, 256)]:
            with self.subTest(Lq=Lq, Lk=Lk):
                q, k, v = self._make_qkv(B, H, H, Lq, Lk, D)
                # Sparse mask: only first half of KV visible
                mask = self._prefix_mask(B, Lq, Lk, Lk // 2)
                self._assert_matches_reference(q, k, v, attn_mask=mask)

    def test_mha_non_pow2_head_dim(self):
        """MHA with non-power-of-2 head dimensions."""
        for D in HEAD_DIMS_NON_POW2:
            for Lq, Lk in [(1, 64), (4, 128), (64, 64), (128, 128)]:
                with self.subTest(D=D, Lq=Lq, Lk=Lk):
                    B, H = 1, 4
                    q, k, v = self._make_qkv(B, H, H, Lq, Lk, D)
                    self._assert_matches_reference(q, k, v)

    def test_mha_non_pow2_causal(self):
        """MHA with non-pow2 head dim and causal masking."""
        for D in HEAD_DIMS_NON_POW2:
            for L in [64, 128]:
                with self.subTest(D=D, L=L):
                    B, H = 1, 4
                    q, k, v = self._make_qkv(B, H, H, L, L, D)
                    self._assert_matches_reference(q, k, v, is_causal=True)

    # ------------------------------------------------------------------
    # GQA tests
    # ------------------------------------------------------------------

    def test_gqa_decode(self):
        """GQA decode (seqlen_q=1)."""
        for (H_q, H_kv, label), D, Lk in itertools.product(
            GQA_CONFIGS, [64, 128, 256], [64, 128, 512]
        ):
            if H_q == H_kv:
                continue  # skip MHA, tested above
            with self.subTest(label=label, D=D, Lk=Lk):
                B, Lq = 1, 1
                q, k, v = self._make_qkv(B, H_q, H_kv, Lq, Lk, D)
                self._assert_matches_reference(
                    q, k, v, msg=f"{label} D={D} Lk={Lk}", enable_gqa=True
                )

    def test_gqa_decode_with_mask(self):
        """GQA decode with bool attention mask."""
        for H_q, H_kv, label in GQA_CONFIGS:
            if H_q == H_kv:
                continue  # skip MHA, tested above
            with self.subTest(label=label):
                B, Lq, Lk, D = 1, 1, 256, 128
                q, k, v = self._make_qkv(B, H_q, H_kv, Lq, Lk, D)
                # Mask: only first 100 positions visible
                mask = self._prefix_mask(B, Lq, Lk, 100)
                self._assert_matches_reference(
                    q, k, v, attn_mask=mask, enable_gqa=True
                )

    def test_gqa_short_seqlen(self):
        """GQA with short seqlen_q (2-8)."""
        for Lq in [2, 4, 8]:
            for H_q, H_kv, label in [(8, 2, "gqa_4x"), (16, 2, "gqa_8x")]:
                with self.subTest(label=label, Lq=Lq):
                    B, Lk, D = 1, 256, 128
                    q, k, v = self._make_qkv(B, H_q, H_kv, Lq, Lk, D)
                    self._assert_matches_reference(q, k, v, enable_gqa=True)

    def test_gqa_prefill(self):
        """GQA prefill (long seqlen_q)."""
        for (H_q, H_kv, label), L in itertools.product(
            [(8, 2, "gqa_4x"), (16, 2, "gqa_8x"), (6, 1, "mqa")],
            [64, 128, 256],
        ):
            with self.subTest(label=label, L=L):
                B, D = 1, 128
                q, k, v = self._make_qkv(B, H_q, H_kv, L, L, D)
                self._assert_matches_reference(
                    q, k, v, is_causal=True, enable_gqa=True
                )

    def test_gqa_non_pow2_head_dim(self):
        """GQA with non-power-of-2 head dimensions."""
        for D in HEAD_DIMS_NON_POW2:
            for Lq, Lk in [(1, 128), (4, 200), (64, 64)]:
                with self.subTest(D=D, Lq=Lq, Lk=Lk):
                    B, H_q, H_kv = 1, 8, 2
                    q, k, v = self._make_qkv(B, H_q, H_kv, Lq, Lk, D)
                    self._assert_matches_reference(
                        q, k, v, msg=f"D={D} Lq={Lq} Lk={Lk}", enable_gqa=True
                    )

    def test_gqa_causal_prefill(self):
        """GQA with causal masking during prefill."""
        for H_q, H_kv, label in [(8, 2, "gqa_4x"), (6, 1, "mqa")]:
            for L in [64, 128]:
                with self.subTest(label=label, L=L):
                    B, D = 2, 128
                    q, k, v = self._make_qkv(B, H_q, H_kv, L, L, D)
                    self._assert_matches_reference(
                        q, k, v, is_causal=True, enable_gqa=True
                    )

    def test_gqa_causal_decode_with_mask(self):
        """GQA decode with causal-like bool mask (simulating KV cache)."""
        H_q, H_kv, D = 16, 2, 256
        for cache_len in [64, 256, 512]:
            with self.subTest(cache_len=cache_len):
                B, Lq = 1, 1
                q, k, v = self._make_qkv(B, H_q, H_kv, Lq, cache_len, D)
                # KV cache mask: first `pos` entries are valid
                pos = cache_len * 3 // 4
                mask = self._prefix_mask(B, Lq, cache_len, pos)
                self._assert_matches_reference(
                    q, k, v, attn_mask=mask, enable_gqa=True
                )

    def test_gqa_batch_size(self):
        """GQA with batch_size > 1."""
        for B in [2, 4]:
            with self.subTest(B=B):
                H_q, H_kv, Lq, Lk, D = 8, 2, 1, 128, 128
                q, k, v = self._make_qkv(B, H_q, H_kv, Lq, Lk, D)
                self._assert_matches_reference(q, k, v, enable_gqa=True)

    # ------------------------------------------------------------------
    # Qwen 3.5 MoE configuration
    # ------------------------------------------------------------------

    def test_qwen35_moe_config(self):
        """Exact Qwen 3.5 MoE attention config: H_q=16, H_kv=2, D=256."""
        B, H_q, H_kv, D = 1, 16, 2, 256
        for Lq, Lk in [(1, 128), (1, 512), (1, 1024), (4, 512)]:
            with self.subTest(Lq=Lq, Lk=Lk):
                q, k, v = self._make_qkv(B, H_q, H_kv, Lq, Lk, D)
                # Simulate KV cache mask (every position visible)
                mask = torch.ones(B, 1, Lq, Lk, dtype=torch.bool, device="cuda")
                self._assert_matches_reference(
                    q,
                    k,
                    v,
                    msg=f"Qwen config Lq={Lq} Lk={Lk}",
                    attn_mask=mask,
                    enable_gqa=True,
                )

    # ------------------------------------------------------------------
    # Edge cases and validation
    # ------------------------------------------------------------------

    def test_output_shape(self):
        """Output shape is always [B, H_q, L_q, D] and dtype is bfloat16."""
        B, D = 1, 64
        for H_q, H_kv in [(4, 4), (8, 2), (6, 1)]:
            for Lq, Lk in [(1, 64), (32, 64)]:
                with self.subTest(H_q=H_q, H_kv=H_kv, Lq=Lq, Lk=Lk):
                    q, k, v = self._make_qkv(B, H_q, H_kv, Lq, Lk, D, seed=None)
                    enable = H_q != H_kv
                    out = self.sdpa(q, k, v, enable_gqa=enable)
                    self.assertEqual(out.shape, (B, H_q, Lq, D))
                    self.assertEqual(out.dtype, torch.bfloat16)

    def test_custom_scale(self):
        """Custom attention scale."""
        B, H, Lq, Lk, D = 1, 4, 1, 64, 128
        q, k, v = self._make_qkv(B, H, H, Lq, Lk, D)
        self._assert_matches_reference(q, k, v, scale=0.05)

    def test_all_masked(self):
        """An all-masked input must not produce NaN or Inf."""
        B, H, D = 1, 4, 64
        Lq, Lk = 4, 128
        q, k, v = self._make_qkv(B, H, H, Lq, Lk, D)

        # All-False mask: every entry is masked
        mask = torch.zeros(B, 1, Lq, Lk, dtype=torch.bool, device="cuda")
        out = self.sdpa(q, k, v, attn_mask=mask)

        self.assertFalse(torch.isnan(out).any(), "All-masked should not NaN")
        self.assertFalse(torch.isinf(out).any(), "All-masked should not Inf")

    def test_gqa_validation_errors(self):
        """Invalid GQA configs should raise RuntimeError."""
        B, D = 1, 64

        # H_q not divisible by H_kv
        q, k, v = self._make_qkv(B, 5, 3, 1, 64, D, seed=None)
        with self.assertRaises(RuntimeError):
            self.sdpa(q, k, v, enable_gqa=True)

        # H_q != H_kv without enable_gqa
        q, k, v = self._make_qkv(B, 8, 2, 1, 64, D, seed=None)
        with self.assertRaises(RuntimeError):
            self.sdpa(q, k, v, enable_gqa=False)

    def test_per_head_mask_rejected(self):
        """Per-head masks (H>1) should be rejected since the kernel broadcasts."""
        B, H, Lq, Lk, D = 1, 4, 4, 64, 64
        q, k, v = self._make_qkv(B, H, H, Lq, Lk, D, seed=None)
        mask = torch.ones(B, H, Lq, Lk, dtype=torch.bool, device="cuda")
        with self.assertRaises(RuntimeError):
            self.sdpa(q, k, v, attn_mask=mask)

    def test_gqa_all_masked_decode(self):
        """GQA decode with an all-masked input should not produce NaN or Inf."""
        B, H_q, H_kv, Lq, Lk, D = 1, 8, 2, 1, 128, 64
        q, k, v = self._make_qkv(B, H_q, H_kv, Lq, Lk, D)

        mask = torch.zeros(B, 1, Lq, Lk, dtype=torch.bool, device="cuda")
        out = self.sdpa(q, k, v, attn_mask=mask, enable_gqa=True)

        self.assertFalse(torch.isnan(out).any())
        self.assertFalse(torch.isinf(out).any())

    def test_causal_lq_ne_lkv_rejected(self):
        """is_causal=True with L_q != L_kv should raise RuntimeError."""
        B, H, D = 1, 4, 64
        q, k, v = self._make_qkv(B, H, H, 1, 128, D, seed=None)
        with self.assertRaises(RuntimeError):
            self.sdpa(q, k, v, is_causal=True)

    def test_non_pow2_no_mask(self):
        """Non-pow2 head dim without a mask should work."""
        B, H, Lq, Lk, D = 1, 4, 4, 64, 40  # D=40 is not pow2
        q, k, v = self._make_qkv(B, H, H, Lq, Lk, D)
        self._assert_matches_reference(q, k, v)


if __name__ == "__main__":
    unittest.main()
a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py index d4597ff4197..d83f8e0557a 100644 --- a/backends/cuda/triton/kernels/sdpa.py +++ b/backends/cuda/triton/kernels/sdpa.py @@ -3,6 +3,19 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# +# The GQA "pack GQA" optimization is adapted from FlashAttention +# (Tri Dao, 2023-2025): +# https://github.com/Dao-AILab/flash-attention +# flash_attn/cute/pack_gqa.py — PackGQA class +# hopper/heuristics.h — should_pack_gqa() tile-utilization heuristic +# Licensed under BSD-3-Clause. +# +# Pack GQA folds multiple Q heads that share the same KV head into the M +# (sequence) dimension of a single tile, so K/V are loaded once per KV head +# instead of once per Q head. The tile-utilization heuristic decides when +# packing is beneficial (short seqlen_q, e.g. decode) vs. when simple head +# remapping suffices (long seqlen_q, e.g. prefill). """ Triton SDPA Kernel for ExecuTorch CUDA Backend. @@ -11,6 +24,11 @@ that can replace the default ATen/Edge SDPA operator during graph transformation to allow us export the model without decomposing the SDPA operator under libtorch free environment and have better performance. + +GQA support: when enable_gqa=True and H_q > H_kv, the kernel uses "pack GQA" +(adapted from FlashAttention) to fold multiple Q heads sharing the same KV head +into the M (sequence) dimension of a single tile. This avoids redundant K/V reads +and improves tile utilization, especially during decode (seqlen_q=1). """ import math @@ -40,35 +58,70 @@ def _next_power_of_2(x: int) -> int: return 256 +def _should_pack_gqa(L_q: int, num_groups: int, block_m: int) -> bool: + """Decide whether to use pack GQA based on tile utilization. + + Pack GQA folds multiple Q heads into the M dimension so they share + the same K/V loads. This helps when seqlen_q is small relative to + BLOCK_M (e.g., decode with seqlen_q=1). 
+ + Heuristic from FlashAttention (hopper/heuristics.h, should_pack_gqa): + compare tile utilization with and without packing; pack if it + improves efficiency by >10%. + + Reference: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/heuristics.h + """ + if num_groups <= 1: + return False + + def round_up(a, b): + return ((a + b - 1) // b) * b + + nopack_eff = L_q / round_up(L_q, block_m) + pack_eff = (L_q * num_groups) / round_up(L_q * num_groups, block_m) + return nopack_eff < 0.9 * pack_eff + + def _validate_qkv_shapes( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, -) -> tuple[int, int, int, int, int, int]: + enable_gqa: bool = False, +) -> tuple[int, int, int, int, int, int, int]: """ Validate dimensions and return shape info. Args: - query: Query tensor [B, H, L_q, D] - key: Key tensor [B, H, L_kv, D] - value: Value tensor [B, H, L_kv, D] + query: Query tensor [B, H_q, L_q, D] + key: Key tensor [B, H_kv, L_kv, D] + value: Value tensor [B, H_kv, L_kv, D] + enable_gqa: If True, H_q must be a multiple of H_kv (GQA/MQA). Returns: - Tuple of (B, H, L_q, L_kv, D_q, D_kv) + Tuple of (B, H_q, H_kv, L_q, L_kv, D_q, D_kv) Raises: RuntimeError: If dimensions are incompatible """ B_q, H_q, L_q, D_q = query.shape B_k, H_k, L_kv_k, D_k = key.shape B_v, H_v, L_kv_v, D_v = value.shape - # Validate batch and head dimensions + # Validate batch dimensions if not (B_q == B_k == B_v): raise RuntimeError( f"Batch dimension must match; got B_q={B_q}, B_k={B_k}, B_v={B_v}." ) - - if not (H_q == H_k == H_v): - raise RuntimeError( - f"Head dimension must match; got H_q={H_q}, H_k={H_k}, H_v={H_v}." - ) + # Validate head dimensions + if not (H_k == H_v): + raise RuntimeError(f"K and V head counts must match; got H_k={H_k}, H_v={H_v}.") + if enable_gqa: + if H_q % H_k != 0: + raise RuntimeError( + f"GQA requires H_q divisible by H_kv; got H_q={H_q}, H_kv={H_k}." 
+ ) + else: + if not (H_q == H_k): + raise RuntimeError( + f"Head counts must match (or use enable_gqa=True); " + f"got H_q={H_q}, H_k={H_k}." + ) # Head dimension must match if not (D_q == D_k == D_v): raise RuntimeError( @@ -79,7 +132,7 @@ def _validate_qkv_shapes( raise RuntimeError( f"Key and Value must have the same sequence length; got L_k={L_kv_k}, L_v={L_kv_v}." ) - return B_q, H_q, L_q, L_kv_k, D_q, D_k + return B_q, H_q, H_k, L_q, L_kv_k, D_q, D_k # ============================================================================== @@ -93,7 +146,7 @@ def _sdpa_fwd_kernel_non_pow2( o_ptr, mask_ptr, B, - H, + H_grid, LQ, LK, HEAD_DIM, @@ -123,31 +176,57 @@ def _sdpa_fwd_kernel_non_pow2( BLOCK_D: tl.constexpr, HAS_MASK: tl.constexpr, IS_CAUSAL: tl.constexpr, + NUM_GROUPS: tl.constexpr, + PACK_GQA: tl.constexpr, ): """ SDPA forward kernel for non-power-of-2 HEAD_DIM. Uses dynamic masking to handle arbitrary head dimensions. + + PACK_GQA: when True, multiple Q heads sharing the same KV head are + folded into the M dimension. The grid iterates over H_kv heads and + each tile processes up to BLOCK_M rows from the packed (head, seq) + space. K/V are loaded once per KV head. 
""" pid_m = tl.program_id(axis=0) pid_bh = tl.program_id(axis=1) - b = pid_bh // H - h = pid_bh % H + b = pid_bh // H_grid + h_grid = pid_bh % H_grid - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) + offs_packed = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_D) - d_mask = offs_d < HEAD_DIM - q_row_mask = offs_m < LQ - q_base = q_ptr + b * stride_qb + h * stride_qh - k_base = k_ptr + b * stride_kb + h * stride_kh - v_base = v_ptr + b * stride_vb + h * stride_vh - o_base = o_ptr + b * stride_ob + h * stride_oh + if PACK_GQA: + seq_pos = offs_packed // NUM_GROUPS + h_within = offs_packed % NUM_GROUPS + h_q_rows = h_grid * NUM_GROUPS + h_within + h_kv = h_grid + row_valid = seq_pos < LQ + q_ptrs = ( + q_ptr + + b * stride_qb + + h_q_rows[:, None] * stride_qh + + seq_pos[:, None] * stride_ql + + offs_d[None, :] * stride_qd + ) + else: + seq_pos = offs_packed + h_kv = h_grid // NUM_GROUPS + row_valid = offs_packed < LQ + q_ptrs = ( + q_ptr + + b * stride_qb + + h_grid * stride_qh + + offs_packed[:, None] * stride_ql + + offs_d[None, :] * stride_qd + ) - q_ptrs = q_base + (offs_m[:, None] * stride_ql + offs_d[None, :] * stride_qd) - q = tl.load(q_ptrs, mask=q_row_mask[:, None] & d_mask[None, :], other=0.0) + q = tl.load(q_ptrs, mask=row_valid[:, None] & d_mask[None, :], other=0.0) + + k_base = k_ptr + b * stride_kb + h_kv * stride_kh + v_base = v_ptr + b * stride_vb + h_kv * stride_vh acc = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32) m_i = tl.full((BLOCK_M,), -float("inf"), dtype=tl.float32) @@ -161,27 +240,27 @@ def _sdpa_fwd_kernel_non_pow2( NEG_INF: tl.constexpr = float("-inf") for start_n in tl.range(0, LK, BLOCK_N, num_stages=2): - kn = start_n + offs_n - kv_col_mask = kn < LK + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_col_mask = offs_n < LK - k_ptrs = k_base + (kn[:, None] * stride_kl + offs_d[None, :] * stride_kd) + k_ptrs = k_base + (offs_n[:, None] * stride_kl + offs_d[None, :] * stride_kd) k = 
tl.load(k_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0) qk = tl.dot(q, tl.trans(k)) qk = (qk * qk_scale_log2).to(tl.float32) if IS_CAUSAL: - row_abs = offs_m[:, None] - col_abs = kn[None, :] - causal_mask = col_abs > row_abs + causal_mask = offs_n[None, :] > seq_pos[:, None] qk = tl.where(causal_mask, tl.full(qk.shape, NEG_INF, dtype=tl.float32), qk) if HAS_MASK: - mask_ptrs = ( - mask_b_base + offs_m[:, None] * stride_mlq + kn[None, :] * stride_mlk + m_ptrs = ( + mask_b_base + + seq_pos[:, None] * stride_mlq + + offs_n[None, :] * stride_mlk ) - tile_valid = q_row_mask[:, None] & kv_col_mask[None, :] - keep = tl.load(mask_ptrs, mask=tile_valid, other=False) + tile_valid = row_valid[:, None] & kv_col_mask[None, :] + keep = tl.load(m_ptrs, mask=tile_valid, other=False) qk = tl.where(keep, qk, tl.full(qk.shape, NEG_INF, dtype=tl.float32)) qk = tl.where( @@ -189,8 +268,6 @@ def _sdpa_fwd_kernel_non_pow2( ) m_ij = tl.maximum(m_i, tl.max(qk, 1).to(tl.float32)) - # Guard against all-masked blocks: when m_ij == -inf, qk - m_ij = NaN. - # Use 0.0 for p in that case (no contribution to output). 
safe_diff = tl.where( m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf") ) @@ -201,7 +278,7 @@ def _sdpa_fwd_kernel_non_pow2( acc = (acc * alpha[:, None]).to(tl.float32) - v_ptrs = v_base + (kn[:, None] * stride_vl + offs_d[None, :] * stride_vd) + v_ptrs = v_base + (offs_n[:, None] * stride_vl + offs_d[None, :] * stride_vd) v = tl.load(v_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0) acc = tl.dot(p.to(v.dtype), v, acc).to(tl.float32) @@ -210,8 +287,24 @@ def _sdpa_fwd_kernel_non_pow2( m_i = m_ij out = acc / l_i[:, None] - o_ptrs = o_base + (offs_m[:, None] * stride_ol + offs_d[None, :] * stride_od) - tl.store(o_ptrs, out.to(tl.bfloat16), mask=q_row_mask[:, None] & d_mask[None, :]) + + if PACK_GQA: + o_ptrs = ( + o_ptr + + b * stride_ob + + h_q_rows[:, None] * stride_oh + + seq_pos[:, None] * stride_ol + + offs_d[None, :] * stride_od + ) + else: + o_ptrs = ( + o_ptr + + b * stride_ob + + h_grid * stride_oh + + offs_packed[:, None] * stride_ol + + offs_d[None, :] * stride_od + ) + tl.store(o_ptrs, out.to(tl.bfloat16), mask=row_valid[:, None] & d_mask[None, :]) # ============================================================================== @@ -225,7 +318,7 @@ def _sdpa_fwd_kernel_body( O_ptr, Mask_ptr, B, - H, + H_grid, Lq, Lk, stride_qb, @@ -253,38 +346,74 @@ def _sdpa_fwd_kernel_body( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, HEAD_DIM: tl.constexpr, + NUM_GROUPS: tl.constexpr, + PACK_GQA: tl.constexpr, ): """ Shared kernel body for SDPA forward pass. + + PACK_GQA: when True, multiple Q heads sharing the same KV head are + folded into the M dimension (adapted from FlashAttention's pack_gqa). + The grid iterates over H_kv heads; each tile processes rows from the + packed (head, seq) space. K/V are loaded once per KV head, eliminating + redundant HBM reads across Q heads in a group. + + When False, the grid iterates over H_q heads and each program handles + one Q head with simple h_kv = h_q // NUM_GROUPS remapping. 
""" pid_m = tl.program_id(axis=0) pid_bh = tl.program_id(axis=1) - b = pid_bh // H - h = pid_bh % H + b = pid_bh // H_grid + h_grid = pid_bh % H_grid - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n_init = tl.arange(0, BLOCK_N) + offs_packed = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_DIM) - q_ptrs = Q_ptr + ( - b * stride_qb - + h * stride_qh - + (offs_m[:, None] * stride_qm) - + (offs_d[None, :] * stride_qd) - ) - q_mask = (offs_m[:, None] < Lq) & (offs_d[None, :] < HEAD_DIM) + if PACK_GQA: + # Decompose packed index: heads interleaved with positions + # [h0_pos0, h1_pos0, ..., h(G-1)_pos0, h0_pos1, h1_pos1, ...] + seq_pos = offs_packed // NUM_GROUPS + h_within = offs_packed % NUM_GROUPS + h_q_rows = h_grid * NUM_GROUPS + h_within # [BLOCK_M] vector + h_kv = h_grid + row_valid = seq_pos < Lq + + # Scattered Q load: each row may be a different Q head + q_ptrs = Q_ptr + ( + b * stride_qb + + h_q_rows[:, None] * stride_qh + + seq_pos[:, None] * stride_qm + + offs_d[None, :] * stride_qd + ) + else: + seq_pos = offs_packed + h_kv = h_grid // NUM_GROUPS + row_valid = offs_packed < Lq + + # Uniform Q load: all rows are the same Q head + q_ptrs = Q_ptr + ( + b * stride_qb + + h_grid * stride_qh + + offs_packed[:, None] * stride_qm + + offs_d[None, :] * stride_qd + ) + + q_mask = row_valid[:, None] & (offs_d[None, :] < HEAD_DIM) q = tl.load(q_ptrs, mask=q_mask, other=0.0).to(tl.bfloat16) m_i = tl.full([BLOCK_M], -float("inf"), dtype=tl.float32) l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + offs_n_init = tl.arange(0, BLOCK_N) + for start_n in tl.range(0, Lk, BLOCK_N): offs_n = start_n + offs_n_init + # K load: uniform (single KV head, shared across all Q heads in tile) k_ptrs = K_ptr + ( b * stride_kb - + h * stride_kh + + h_kv * stride_kh + (offs_n[:, None] * stride_kn) + (offs_d[None, :] * stride_kd) ) @@ -296,26 +425,22 @@ def _sdpa_fwd_kernel_body( if HAS_MASK: mask_ptrs = 
Mask_ptr + ( b * stride_mb - + (offs_m[:, None] * stride_mq) + + (seq_pos[:, None] * stride_mq) + (offs_n[None, :] * stride_mk) ) - mn_mask = (offs_m[:, None] < Lq) & (offs_n[None, :] < Lk) + mn_mask = row_valid[:, None] & (offs_n[None, :] < Lk) mask_block = tl.load(mask_ptrs, mask=mn_mask, other=False) qk = tl.where( mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32) ) if IS_CAUSAL: - abs_m = offs_m[:, None] - abs_n = offs_n[None, :] - causal = abs_n > abs_m + causal = offs_n[None, :] > seq_pos[:, None] qk = tl.where( causal, tl.full(qk.shape, -float("inf"), dtype=tl.float32), qk ) m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32)) - # Guard against all-masked blocks: when m_ij == -inf, qk - m_ij = NaN. - # Use 0.0 for p in that case (no contribution to output). safe_diff = tl.where( m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf") ) @@ -324,9 +449,10 @@ def _sdpa_fwd_kernel_body( safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0) alpha = tl.exp(safe_alpha_diff).to(tl.float32) + # V load: uniform (single KV head) v_ptrs = V_ptr + ( b * stride_vb - + h * stride_vh + + h_kv * stride_vh + (offs_n[:, None] * stride_vn) + (offs_d[None, :] * stride_vd) ) @@ -341,13 +467,22 @@ def _sdpa_fwd_kernel_body( inv_l_i = tl.where(l_i > 0, 1.0 / l_i, 0.0) acc = acc * inv_l_i[:, None] - o_ptrs = O_ptr + ( - b * stride_ob - + h * stride_oh - + (offs_m[:, None] * stride_om) - + (offs_d[None, :] * stride_od) - ) - o_mask = (offs_m[:, None] < Lq) & (offs_d[None, :] < HEAD_DIM) + # O store: scattered when PACK_GQA, uniform otherwise + if PACK_GQA: + o_ptrs = O_ptr + ( + b * stride_ob + + h_q_rows[:, None] * stride_oh + + seq_pos[:, None] * stride_om + + offs_d[None, :] * stride_od + ) + else: + o_ptrs = O_ptr + ( + b * stride_ob + + h_grid * stride_oh + + offs_packed[:, None] * stride_om + + offs_d[None, :] * stride_od + ) + o_mask = row_valid[:, None] & (offs_d[None, :] < HEAD_DIM) tl.store(o_ptrs, acc.to(tl.bfloat16), mask=o_mask) 
@@ -359,7 +494,7 @@ def _sdpa_fwd_kernel_body( triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3), triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2), ], - key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL"], + key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], ) @triton.jit def _sdpa_fwd_kernel_m64( @@ -369,7 +504,7 @@ def _sdpa_fwd_kernel_m64( O_ptr, Mask_ptr, B, - H, + H_grid, Lq, Lk, stride_qb, @@ -395,12 +530,11 @@ def _sdpa_fwd_kernel_m64( HAS_MASK: tl.constexpr, IS_CAUSAL: tl.constexpr, HEAD_DIM: tl.constexpr, + NUM_GROUPS: tl.constexpr, + PACK_GQA: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - """ - SDPA kernel with BLOCK_M=64 optimizations. - """ _sdpa_fwd_kernel_body( Q_ptr, K_ptr, @@ -408,7 +542,7 @@ def _sdpa_fwd_kernel_m64( O_ptr, Mask_ptr, B, - H, + H_grid, Lq, Lk, stride_qb, @@ -436,6 +570,8 @@ def _sdpa_fwd_kernel_m64( BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM, + NUM_GROUPS=NUM_GROUPS, + PACK_GQA=PACK_GQA, ) @@ -446,7 +582,7 @@ def _sdpa_fwd_kernel_m64( triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2), ], - key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL"], + key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], ) @triton.jit def _sdpa_fwd_kernel_m32( @@ -456,7 +592,7 @@ def _sdpa_fwd_kernel_m32( O_ptr, Mask_ptr, B, - H, + H_grid, Lq, Lk, stride_qb, @@ -482,12 +618,11 @@ def _sdpa_fwd_kernel_m32( HAS_MASK: tl.constexpr, IS_CAUSAL: tl.constexpr, HEAD_DIM: tl.constexpr, + NUM_GROUPS: tl.constexpr, + PACK_GQA: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - """ - SDPA kernel with BLOCK_M=32 optimizations for small workloads. 
- """ _sdpa_fwd_kernel_body( Q_ptr, K_ptr, @@ -495,7 +630,7 @@ def _sdpa_fwd_kernel_m32( O_ptr, Mask_ptr, B, - H, + H_grid, Lq, Lk, stride_qb, @@ -523,6 +658,8 @@ def _sdpa_fwd_kernel_m32( BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM, + NUM_GROUPS=NUM_GROUPS, + PACK_GQA=PACK_GQA, ) @@ -552,10 +689,6 @@ def _validate_sdpa_inputs( raise RuntimeError( "dropout_p must be 0.0 (not supported in this implementation)." ) - if enable_gqa is not False: - raise RuntimeError( - "enable_gqa must be False (not supported in this implementation)." - ) def _prepare_mask_params( @@ -572,13 +705,18 @@ def _prepare_mask_params( raise RuntimeError("attn_mask must have dtype torch.bool") if not attn_mask.is_cuda: raise RuntimeError("attn_mask must be a CUDA tensor") + if attn_mask.shape[1] != 1: + raise RuntimeError( + f"attn_mask head dimension must be 1 (broadcast over heads); " + f"per-head masks are not supported. Got attn_mask.shape={attn_mask.shape}" + ) if ( attn_mask.shape[0] != B or attn_mask.shape[2] != L_q or attn_mask.shape[3] != L_kv ): raise RuntimeError( - f"attn_mask shape mismatch: expected [B={B}, H, L_q={L_q}, L_kv={L_kv}], " + f"attn_mask shape mismatch: expected [B={B}, 1, L_q={L_q}, L_kv={L_kv}], " f"got {attn_mask.shape}" ) return ( @@ -596,7 +734,8 @@ def _launch_pow2_kernel( value: torch.Tensor, out: torch.Tensor, B: int, - H: int, + H_q: int, + H_kv: int, L_q: int, L_kv: int, D: int, @@ -607,6 +746,8 @@ def _launch_pow2_kernel( stride_mq: int, stride_mk: int, is_causal: bool, + num_groups: int, + pack_gqa: bool, ) -> None: """Launch power-of-2 optimized SDPA kernel.""" stride_qb, stride_qh, stride_qm, stride_qd = query.stride() @@ -614,10 +755,17 @@ def _launch_pow2_kernel( stride_vb, stride_vh, stride_vn, stride_vd = value.stride() stride_ob, stride_oh, stride_om, stride_od = out.stride() + if pack_gqa: + H_grid = H_kv + Lq_packed = L_q * num_groups + else: + H_grid = H_q + Lq_packed = L_q + def grid(meta): - return (triton.cdiv(L_q, meta["BLOCK_M"]), B 
* H) + return (triton.cdiv(Lq_packed, meta["BLOCK_M"]), B * H_grid) - total_ctas_m64 = ((L_q + 63) // 64) * (B * H) + total_ctas_m64 = ((Lq_packed + 63) // 64) * (B * H_grid) threshold = 4 * 84 kernel = ( _sdpa_fwd_kernel_m32 if total_ctas_m64 < threshold else _sdpa_fwd_kernel_m64 @@ -630,7 +778,7 @@ def grid(meta): out, Mask_ptr if HAS_MASK else 0, B, - H, + H_grid, L_q, L_kv, stride_qb, @@ -656,6 +804,8 @@ def grid(meta): HAS_MASK=HAS_MASK, IS_CAUSAL=is_causal, HEAD_DIM=D, + NUM_GROUPS=num_groups, + PACK_GQA=pack_gqa, ) @@ -666,13 +816,16 @@ def _launch_non_pow2_kernel( out: torch.Tensor, attn_mask: Optional[torch.Tensor], B: int, - H: int, + H_q: int, + H_kv: int, L_q: int, L_kv: int, D: int, sm_scale: float, HAS_MASK: bool, is_causal: bool, + num_groups: int, + pack_gqa: bool, ) -> None: """Launch non-power-of-2 SDPA kernel with dynamic HEAD_DIM masking.""" stride_qb, stride_qh, stride_qm, stride_qd = query.stride() @@ -686,6 +839,13 @@ def _launch_non_pow2_kernel( num_warps = 4 num_stages = 2 + if pack_gqa: + H_grid = H_kv + Lq_packed = L_q * num_groups + else: + H_grid = H_q + Lq_packed = L_q + if HAS_MASK: mask_ptr = attn_mask stride_mb_np2 = attn_mask.stride(0) @@ -693,11 +853,11 @@ def _launch_non_pow2_kernel( stride_mlq_np2 = attn_mask.stride(2) stride_mlk_np2 = attn_mask.stride(3) else: - mask_ptr = torch.empty((1,), device=query.device, dtype=torch.bool) + mask_ptr = 0 stride_mb_np2 = stride_mh_np2 = stride_mlq_np2 = stride_mlk_np2 = 0 def grid_non_pow2(meta): - return (triton.cdiv(L_q, meta["BLOCK_M"]), B * H) + return (triton.cdiv(Lq_packed, meta["BLOCK_M"]), B * H_grid) wrap_triton(_sdpa_fwd_kernel_non_pow2)[grid_non_pow2]( query, @@ -706,7 +866,7 @@ def grid_non_pow2(meta): out, mask_ptr, B, - H, + H_grid, L_q, L_kv, D, @@ -736,6 +896,8 @@ def grid_non_pow2(meta): BLOCK_D=BLOCK_D, HAS_MASK=HAS_MASK, IS_CAUSAL=is_causal, + NUM_GROUPS=num_groups, + PACK_GQA=pack_gqa, num_warps=num_warps, num_stages=num_stages, ) @@ -753,31 +915,50 @@ def sdpa( 
enable_gqa: bool = False, ) -> torch.Tensor: """ - Triton fused Scaled Dot-Product Attention with optimized dual-kernel approach. + Triton fused Scaled Dot-Product Attention with GQA pack optimization. + + When enable_gqa=True and H_q > H_kv, this kernel automatically decides + whether to use "pack GQA" (folding Q heads into the M dimension so they + share K/V loads) based on a tile-utilization heuristic from FlashAttention. Args: - query: Query tensor with size [B, H, L_q, D] and dtype torch.bfloat16 - key: Key tensor [B, H, L_kv, D] and dtype torch.bfloat16 - value: Value tensor [B, H, L_kv, D] and dtype torch.bfloat16 - attn_mask: Optional attention mask [B, H, L_q, L_kv] with dtype torch.bool - dropout_p: must be 0.0 (others are not supported) - is_causal: whether to apply causal masking + query: Query tensor [B, H_q, L_q, D], dtype torch.bfloat16 + key: Key tensor [B, H_kv, L_kv, D], dtype torch.bfloat16 + value: Value tensor [B, H_kv, L_kv, D], dtype torch.bfloat16 + attn_mask: Optional bool mask [B, 1, L_q, L_kv] (broadcast over heads) + dropout_p: must be 0.0 + is_causal: apply causal masking scale: attention scale (default: 1/sqrt(D)) - enable_gqa: must be False (True is not supported) + enable_gqa: allow H_q != H_kv (GQA/MQA) Returns: - Output tensor [B, H, L_q, D] with dtype torch.bfloat16 + Output tensor [B, H_q, L_q, D], dtype torch.bfloat16 """ _validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa) - B, H, L_q, L_kv, D_q, _ = _validate_qkv_shapes(query, key, value) + B, H_q, H_kv, L_q, L_kv, D_q, _ = _validate_qkv_shapes( + query, key, value, enable_gqa + ) D = D_q + num_groups = H_q // H_kv if is_causal and L_q != L_kv: raise RuntimeError( - f"Causal masking requires L_q == L_kv; got L_q={L_q}, L_kv={L_kv}." + f"Causal masking requires L_q == L_kv; got L_q={L_q}, L_kv={L_kv}. " + "For decode (L_q < L_kv), use an explicit bool mask instead." 
) - out = torch.empty((B, H, L_q, D), device=query.device, dtype=query.dtype) + # Decide whether to pack GQA based on tile utilization heuristic. + # Use the actual BLOCK_M that the launched kernel will use: + # - non-pow2 path always uses BLOCK_M=32 + # - pow2 path selects M32 or M64 based on CTA occupancy + if not _is_power_of_2(D): + block_m = 32 + else: + total_ctas_m64 = ((L_q * num_groups + 63) // 64) * (B * H_kv) + block_m = 32 if total_ctas_m64 < 4 * 84 else 64 + pack_gqa = _should_pack_gqa(L_q, num_groups, block_m) + + out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype) sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale HAS_MASK, Mask_ptr, stride_mb, stride_mq, stride_mk = _prepare_mask_params( attn_mask, B, L_q, L_kv @@ -790,7 +971,8 @@ def sdpa( value, out, B, - H, + H_q, + H_kv, L_q, L_kv, D, @@ -801,6 +983,8 @@ def sdpa( stride_mq, stride_mk, is_causal, + num_groups, + pack_gqa, ) else: _launch_non_pow2_kernel( @@ -810,13 +994,16 @@ def sdpa( out, attn_mask, B, - H, + H_q, + H_kv, L_q, L_kv, D, sm_scale, HAS_MASK, is_causal, + num_groups, + pack_gqa, ) return out @@ -833,7 +1020,7 @@ def _sdpa_abstract( dropout_p: float = 0.0, is_causal: bool = False, scale: float = 0.0, - enable_gq: bool = False, + enable_gqa: bool = False, ) -> torch.Tensor: """ Abstract/fake implementation for torch.export. 
@@ -842,6 +1029,6 @@ def _sdpa_abstract( # Validate dtypes match assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype" # Validate kqv's shape and get the output shape - B, H, L_q, _, D_q, _ = _validate_qkv_shapes(query, key, value) + B, H_q, _H_kv, L_q, _, D_q, _ = _validate_qkv_shapes(query, key, value, enable_gqa) - return torch.empty(B, H, L_q, D_q, dtype=query.dtype, device=query.device) + return torch.empty(B, H_q, L_q, D_q, dtype=query.dtype, device=query.device) diff --git a/examples/models/qwen3_5_moe/model.py b/examples/models/qwen3_5_moe/model.py index 25be4c5f8c1..1e2abba2b9f 100644 --- a/examples/models/qwen3_5_moe/model.py +++ b/examples/models/qwen3_5_moe/model.py @@ -267,14 +267,11 @@ def forward(self, x, input_pos): # KV cache k, v = self.kv_cache.update(input_pos, k, v) - # GQA expansion - if self.n_kv_groups > 1: - k = k.repeat_interleave(self.n_kv_groups, dim=1) - v = v.repeat_interleave(self.n_kv_groups, dim=1) - - # SDPA with bool mask + # SDPA with GQA — kernel maps Q heads to KV heads internally attn_mask = self.mask[input_pos].unsqueeze(0).unsqueeze(0) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, enable_gqa=True + ) y = y.transpose(1, 2).contiguous().view(B, T, -1) diff --git a/examples/models/voxtral_realtime/model.py b/examples/models/voxtral_realtime/model.py index cc92710bb33..4a5981caf68 100644 --- a/examples/models/voxtral_realtime/model.py +++ b/examples/models/voxtral_realtime/model.py @@ -514,9 +514,10 @@ def forward( class StandardSDPA(nn.Module): """Scaled dot-product attention using F.scaled_dot_product_attention. - Supports GQA via repeat_interleave when n_heads != n_kv_heads. - Expects Q in [B, S, H, D]; K/V in [B, H, S, D] by default - (set transpose_kv=True if K/V arrive in [B, S, H, D]). 
+ Supports GQA via enable_gqa=True — the kernel maps Q heads to KV heads + internally, avoiding redundant K/V memory expansion. + Expects Q in [B, S, H_q, D]; K/V in [B, H_kv, S, D] by default + (set transpose_kv=True if K/V arrive in [B, S, H_kv, D]). """ def __init__( @@ -525,7 +526,7 @@ def __init__( super().__init__() self.n_heads = n_heads self.n_kv_heads = n_kv_heads - self.n_rep = n_heads // n_kv_heads + self.enable_gqa = n_heads != n_kv_heads self.head_dim = head_dim self.dim = n_heads * head_dim self.transpose_kv = transpose_kv @@ -545,15 +546,11 @@ def forward( k = k.transpose(1, 2) v = v.transpose(1, 2) - if self.n_rep > 1: - k = k.repeat_interleave(self.n_rep, dim=1) - v = v.repeat_interleave(self.n_rep, dim=1) - if attn_mask is None: attn_mask = _build_causal_mask_bool(input_pos, k.shape[2], q.device) y = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, is_causal=False + q, k, v, attn_mask=attn_mask, is_causal=False, enable_gqa=self.enable_gqa ) y = y.transpose(1, 2).contiguous()