
Commit e43437f

Enable GQA on voxtral_realtime

1 parent 4d908bf

5 files changed: 188 additions & 24 deletions


.claude/scheduled_tasks.lock

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+{"sessionId":"2321a402-406a-4717-8d8b-5d17b8c5210a","pid":921768,"acquiredAt":1774525472191}
backends/cuda/tests/bench_sdpa_gqa.py

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Benchmark: Pack GQA SDPA vs repeat_interleave + MHA SDPA.
+
+Compares two approaches for GQA attention on consumer GPUs:
+1. repeat_interleave: expand K/V to H_q heads, then call SDPA with H_q==H_kv
+2. pack_gqa: call SDPA with enable_gqa=True (kernel handles head mapping)
+
+Usage:
+    LD_LIBRARY_PATH=/home/mnachin/local/miniconda3/envs/executorch/lib:$LD_LIBRARY_PATH \
+    python3 backends/cuda/tests/bench_sdpa_gqa.py
+"""
+
+import sys
+import os
+import time
+
+import torch
+import torch.nn.functional as F
+
+# Import Triton SDPA
+kernels_dir = os.path.join(os.path.dirname(__file__), "..", "triton", "kernels")
+sys.path.insert(0, os.path.abspath(kernels_dir))
+from sdpa import sdpa
+
+
+def _benchmark_fn(fn, warmup=10, repeats=100):
+    """Benchmark a function, return median time in microseconds."""
+    # Warmup
+    for _ in range(warmup):
+        fn()
+    torch.cuda.synchronize()
+
+    times = []
+    for _ in range(repeats):
+        torch.cuda.synchronize()
+        start = time.perf_counter()
+        fn()
+        torch.cuda.synchronize()
+        end = time.perf_counter()
+        times.append((end - start) * 1e6)  # microseconds
+
+    times.sort()
+    return times[len(times) // 2]  # median
+
+
+def bench_config(B, H_q, H_kv, L_q, L_kv, D, has_mask=False):
+    """Benchmark one configuration, return (repeat_interleave_us, pack_gqa_us)."""
+    num_groups = H_q // H_kv
+
+    torch.manual_seed(42)
+    q = torch.randn(B, H_q, L_q, D, dtype=torch.bfloat16, device="cuda")
+    k = torch.randn(B, H_kv, L_kv, D, dtype=torch.bfloat16, device="cuda")
+    v = torch.randn(B, H_kv, L_kv, D, dtype=torch.bfloat16, device="cuda")
+
+    if has_mask:
+        mask = torch.ones(B, 1, L_q, L_kv, dtype=torch.bool, device="cuda")
+    else:
+        mask = None
+
+    # Approach 1: repeat_interleave + MHA SDPA
+    def fn_repeat():
+        k_exp = k.repeat_interleave(num_groups, dim=1)
+        v_exp = v.repeat_interleave(num_groups, dim=1)
+        if mask is not None:
+            mask_exp = mask.expand(B, H_q, L_q, L_kv)
+            return sdpa(q, k_exp, v_exp, attn_mask=mask_exp)
+        return sdpa(q, k_exp, v_exp)
+
+    # Approach 2: pack GQA SDPA
+    def fn_pack_gqa():
+        return sdpa(q, k, v, attn_mask=mask, enable_gqa=True)
+
+    t_repeat = _benchmark_fn(fn_repeat)
+    t_pack = _benchmark_fn(fn_pack_gqa)
+
+    return t_repeat, t_pack
+
+
+def main():
+    if not torch.cuda.is_available():
+        print("CUDA not available")
+        return
+
+    gpu_name = torch.cuda.get_device_name(0)
+    print(f"GPU: {gpu_name}")
+    print()
+
+    configs = [
+        # Decode configs (L_q=1) — pack GQA should dominate
+        {"B": 1, "H_q": 16, "H_kv": 2, "L_q": 1, "L_kv": 128, "D": 256, "label": "Qwen3.5 decode, ctx=128"},
+        {"B": 1, "H_q": 16, "H_kv": 2, "L_q": 1, "L_kv": 512, "D": 256, "label": "Qwen3.5 decode, ctx=512"},
+        {"B": 1, "H_q": 16, "H_kv": 2, "L_q": 1, "L_kv": 1024, "D": 256, "label": "Qwen3.5 decode, ctx=1024"},
+        {"B": 1, "H_q": 16, "H_kv": 2, "L_q": 1, "L_kv": 2048, "D": 256, "label": "Qwen3.5 decode, ctx=2048"},
+        {"B": 1, "H_q": 16, "H_kv": 2, "L_q": 1, "L_kv": 4096, "D": 256, "label": "Qwen3.5 decode, ctx=4096"},
+
+        # Decode with mask
+        {"B": 1, "H_q": 16, "H_kv": 2, "L_q": 1, "L_kv": 1024, "D": 256, "label": "Qwen3.5 decode+mask, ctx=1024", "has_mask": True},
+
+        # Decode with different GQA ratios
+        {"B": 1, "H_q": 32, "H_kv": 8, "L_q": 1, "L_kv": 2048, "D": 128, "label": "Llama-style 4:1 decode, ctx=2048"},
+        {"B": 1, "H_q": 8, "H_kv": 1, "L_q": 1, "L_kv": 2048, "D": 128, "label": "MQA 8:1 decode, ctx=2048"},
+
+        # Short seqlen (pack GQA should help)
+        {"B": 1, "H_q": 16, "H_kv": 2, "L_q": 4, "L_kv": 1024, "D": 256, "label": "Qwen3.5 short L_q=4, ctx=1024"},
+        {"B": 1, "H_q": 16, "H_kv": 2, "L_q": 8, "L_kv": 1024, "D": 256, "label": "Qwen3.5 short L_q=8, ctx=1024"},
+
+        # Prefill configs (L_q=L_kv) — repeat_interleave should be comparable
+        {"B": 1, "H_q": 16, "H_kv": 2, "L_q": 128, "L_kv": 128, "D": 256, "label": "Qwen3.5 prefill, L=128"},
+        {"B": 1, "H_q": 16, "H_kv": 2, "L_q": 512, "L_kv": 512, "D": 256, "label": "Qwen3.5 prefill, L=512"},
+        {"B": 1, "H_q": 16, "H_kv": 2, "L_q": 1024, "L_kv": 1024, "D": 256, "label": "Qwen3.5 prefill, L=1024"},
+
+        # Batch > 1
+        {"B": 4, "H_q": 16, "H_kv": 2, "L_q": 1, "L_kv": 1024, "D": 256, "label": "Qwen3.5 B=4 decode, ctx=1024"},
+    ]
+
+    header = f"{'Config':<45} {'repeat_interleave':>18} {'pack_gqa':>12} {'Speedup':>10}"
+    print(header)
+    print("-" * len(header))
+
+    for cfg in configs:
+        label = cfg.pop("label")
+        has_mask = cfg.pop("has_mask", False)
+        t_repeat, t_pack = bench_config(**cfg, has_mask=has_mask)
+        speedup = t_repeat / t_pack
+        print(
+            f"{label:<45} {t_repeat:>14.1f} us {t_pack:>8.1f} us {speedup:>9.2f}x"
+        )
+
+    print()
+    print("Speedup > 1.0 means pack_gqa is faster.")
+    print("Speedup < 1.0 means repeat_interleave is faster.")
+
+
+if __name__ == "__main__":
+    main()
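
For context, the same comparison can be written against PyTorch's stock F.scaled_dot_product_attention, which has accepted an enable_gqa flag since PyTorch 2.5. A minimal sketch (CPU, float32; the shapes are illustrative, not taken from the benchmark):

    import torch
    import torch.nn.functional as F

    B, H_q, H_kv, L, D = 1, 16, 2, 128, 64
    q = torch.randn(B, H_q, L, D)
    k = torch.randn(B, H_kv, L, D)
    v = torch.randn(B, H_kv, L, D)

    # Approach 1: materialize expanded K/V (extra memory traffic).
    k_exp = k.repeat_interleave(H_q // H_kv, dim=1)
    v_exp = v.repeat_interleave(H_q // H_kv, dim=1)
    out_repeat = F.scaled_dot_product_attention(q, k_exp, v_exp)

    # Approach 2: let the kernel map query heads onto KV heads.
    out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

    torch.testing.assert_close(out_repeat, out_gqa)

Both compute the same attention; the benchmark above measures whether skipping the K/V expansion pays off in the Triton kernel.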

backends/cuda/tests/test_triton_sdpa.py

Lines changed: 33 additions & 11 deletions
@@ -68,12 +68,6 @@ def _reference_sdpa(q, k, v, attn_mask=None, is_causal=False, scale=None):
     )
 
 
-def _max_relative_error(out, ref):
-    """Max absolute error normalized by reference magnitude."""
-    diff = (out.float() - ref.float()).abs()
-    return (diff / ref.float().abs().clamp(min=1e-6)).max().item()
-
-
 def _max_abs_error(out, ref):
     return (out.float() - ref.float()).abs().max().item()
 
@@ -102,11 +96,6 @@ def _max_abs_error(out, ref):
     (256, 256),
 ]
 
-SEQLEN_PAIRS_LONG = [
-    (512, 512),
-    (1024, 1024),
-]
-
 # GQA configurations: (H_q, H_kv, label)
 GQA_CONFIGS = [
     (4, 4, "mha"),  # MHA: 1:1
@@ -533,6 +522,16 @@ def test_gqa_validation_errors(self):
         with self.assertRaises(RuntimeError):
             self.sdpa(q, k, v, enable_gqa=False)
 
+    def test_per_head_mask_rejected(self):
+        """Per-head masks (H>1) should be rejected since the kernel broadcasts."""
+        B, H, Lq, Lk, D = 1, 4, 4, 64, 64
+        q = torch.randn(B, H, Lq, D, dtype=torch.bfloat16, device="cuda")
+        k = torch.randn(B, H, Lk, D, dtype=torch.bfloat16, device="cuda")
+        v = torch.randn(B, H, Lk, D, dtype=torch.bfloat16, device="cuda")
+        mask = torch.ones(B, H, Lq, Lk, dtype=torch.bool, device="cuda")
+        with self.assertRaises(RuntimeError):
+            self.sdpa(q, k, v, attn_mask=mask)
+
     def test_gqa_all_masked_decode(self):
         """GQA decode with all-masked block should not NaN."""
         B, H_q, H_kv, Lq, Lk, D = 1, 8, 2, 1, 128, 64
@@ -547,6 +546,29 @@ def test_gqa_all_masked_decode(self):
         self.assertFalse(torch.isnan(out).any())
         self.assertFalse(torch.isinf(out).any())
 
+    def test_causal_lq_ne_lkv_rejected(self):
+        """is_causal=True with L_q != L_kv should raise RuntimeError."""
+        B, H, D = 1, 4, 64
+        q = torch.randn(B, H, 1, D, dtype=torch.bfloat16, device="cuda")
+        k = torch.randn(B, H, 128, D, dtype=torch.bfloat16, device="cuda")
+        v = torch.randn(B, H, 128, D, dtype=torch.bfloat16, device="cuda")
+        with self.assertRaises(RuntimeError):
+            self.sdpa(q, k, v, is_causal=True)
+
+    def test_non_pow2_no_mask(self):
+        """Non-pow2 head dim without mask should work (mask_ptr=0 path)."""
+        B, H, Lq, Lk, D = 1, 4, 4, 64, 40  # D=40 is not pow2
+        torch.manual_seed(42)
+        q = torch.randn(B, H, Lq, D, dtype=torch.bfloat16, device="cuda")
+        k = torch.randn(B, H, Lk, D, dtype=torch.bfloat16, device="cuda")
+        v = torch.randn(B, H, Lk, D, dtype=torch.bfloat16, device="cuda")
+
+        out = self.sdpa(q, k, v)
+        ref = _reference_sdpa(q, k, v)
+
+        self.assertFalse(torch.isnan(out).any())
+        self.assertLess(_max_abs_error(out, ref), 0.05)
+
 
 if __name__ == "__main__":
     unittest.main()
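
The new mask test pins down the kernel's contract: boolean masks must broadcast over heads. A small sketch of the accepted versus rejected shapes (assuming the same `sdpa` import path as the benchmark above; requires a CUDA device):

    import torch
    from sdpa import sdpa  # backends/cuda/triton/kernels/sdpa.py

    B, H, Lq, Lk, D = 1, 4, 4, 64, 64
    q = torch.randn(B, H, Lq, D, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(B, H, Lk, D, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(B, H, Lk, D, dtype=torch.bfloat16, device="cuda")

    # Accepted: a head dim of 1 broadcasts across all H query heads.
    mask_ok = torch.ones(B, 1, Lq, Lk, dtype=torch.bool, device="cuda")
    out = sdpa(q, k, v, attn_mask=mask_ok)

    # Rejected: a true per-head mask raises RuntimeError (see test above).
    mask_bad = torch.ones(B, H, Lq, Lk, dtype=torch.bool, device="cuda")
    # sdpa(q, k, v, attn_mask=mask_bad)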

backends/cuda/triton/kernels/sdpa.py

Lines changed: 7 additions & 4 deletions
@@ -851,7 +851,7 @@ def _launch_non_pow2_kernel(
         stride_mlq_np2 = attn_mask.stride(2)
         stride_mlk_np2 = attn_mask.stride(3)
     else:
-        mask_ptr = torch.empty((1,), device=query.device, dtype=torch.bool)
+        mask_ptr = 0
         stride_mb_np2 = stride_mh_np2 = stride_mlq_np2 = stride_mlk_np2 = 0
 
     def grid_non_pow2(meta):
@@ -941,12 +941,15 @@ def sdpa(
 
     if is_causal and L_q != L_kv:
         raise RuntimeError(
-            f"Causal masking requires L_q == L_kv; got L_q={L_q}, L_kv={L_kv}."
+            f"Causal masking requires L_q == L_kv; got L_q={L_q}, L_kv={L_kv}. "
+            "For decode (L_q < L_kv), use an explicit bool mask instead."
         )
 
     # Decide whether to pack GQA based on tile utilization heuristic.
-    # Use 64 as the reference BLOCK_M for the heuristic (the common case).
-    pack_gqa = _should_pack_gqa(L_q, num_groups, 64)
+    # Mirror the kernel selection logic: M32 when CTAs are sparse, M64 otherwise.
+    total_ctas_m64 = ((L_q * num_groups + 63) // 64) * (B * H_kv)
+    block_m = 32 if total_ctas_m64 < 4 * 84 else 64
+    pack_gqa = _should_pack_gqa(L_q, num_groups, block_m)
 
     out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype)
     sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale
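
To make the new heuristic concrete, here is the arithmetic pulled out as a standalone sketch. The 4 * 84 threshold is taken from the diff; reading 84 as an approximate SM count is an assumption, not something the diff states:

    # Sketch of the BLOCK_M selection above: count how many CTAs a
    # BLOCK_M=64 launch would produce, and drop to BLOCK_M=32 when the
    # grid is too sparse to fill the GPU.
    def pick_block_m(B, H_kv, L_q, num_groups):
        total_ctas_m64 = ((L_q * num_groups + 63) // 64) * (B * H_kv)
        return 32 if total_ctas_m64 < 4 * 84 else 64

    # Qwen3.5-style decode: B=1, H_kv=2, L_q=1, num_groups=8.
    # ceil(1 * 8 / 64) * (1 * 2) = 2 CTAs, far below 4 * 84 = 336,
    # so the smaller tile wins and GQA packing sees BLOCK_M=32.
    print(pick_block_m(1, 2, 1, 8))  # -> 32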

examples/models/voxtral_realtime/model.py

Lines changed: 6 additions & 9 deletions
@@ -514,9 +514,10 @@ def forward(
 class StandardSDPA(nn.Module):
     """Scaled dot-product attention using F.scaled_dot_product_attention.
 
-    Supports GQA via repeat_interleave when n_heads != n_kv_heads.
-    Expects Q in [B, S, H, D]; K/V in [B, H, S, D] by default
-    (set transpose_kv=True if K/V arrive in [B, S, H, D]).
+    Supports GQA via enable_gqa=True — the kernel maps Q heads to KV heads
+    internally, avoiding redundant K/V memory expansion.
+    Expects Q in [B, S, H_q, D]; K/V in [B, H_kv, S, D] by default
+    (set transpose_kv=True if K/V arrive in [B, S, H_kv, D]).
     """
 
     def __init__(
@@ -525,7 +526,7 @@ def __init__(
         super().__init__()
         self.n_heads = n_heads
         self.n_kv_heads = n_kv_heads
-        self.n_rep = n_heads // n_kv_heads
+        self.enable_gqa = n_heads != n_kv_heads
         self.head_dim = head_dim
         self.dim = n_heads * head_dim
         self.transpose_kv = transpose_kv
@@ -545,15 +546,11 @@ def forward(
             k = k.transpose(1, 2)
             v = v.transpose(1, 2)
 
-        if self.n_rep > 1:
-            k = k.repeat_interleave(self.n_rep, dim=1)
-            v = v.repeat_interleave(self.n_rep, dim=1)
-
         if attn_mask is None:
             attn_mask = _build_causal_mask_bool(input_pos, k.shape[2], q.device)
 
         y = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=attn_mask, is_causal=False
+            q, k, v, attn_mask=attn_mask, is_causal=False, enable_gqa=self.enable_gqa
         )
 
         y = y.transpose(1, 2).contiguous()
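
The net effect of the change is that forward no longer materializes expanded K/V. A hypothetical standalone version of the new path (module plumbing and _build_causal_mask_bool omitted; requires PyTorch 2.5+ for enable_gqa):

    import torch
    import torch.nn.functional as F

    def standard_sdpa(q, k, v, enable_gqa=True):
        # q: [B, S, H_q, D]; k/v: [B, H_kv, S, D], H_q a multiple of H_kv.
        q = q.transpose(1, 2)  # -> [B, H_q, S, D]
        y = F.scaled_dot_product_attention(
            q, k, v, is_causal=False, enable_gqa=enable_gqa
        )
        return y.transpose(1, 2).contiguous()  # -> [B, S, H_q, D]

    B, S, H_q, H_kv, D = 1, 8, 16, 2, 64
    out = standard_sdpa(
        torch.randn(B, S, H_q, D),
        torch.randn(B, H_kv, S, D),
        torch.randn(B, H_kv, S, D),
    )
    print(out.shape)  # torch.Size([1, 8, 16, 64])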
