Commit 9a997e8
Improve SDPA to handle GQA and update models to use native GQA (pytorch#18513)
Add GQA/MQA support to the Triton SDPA kernel with a "pack GQA"
optimization adapted from FlashAttention. When enable_gqa=True and
H_q > H_kv, the kernel folds multiple Q heads sharing the same KV
head into the M (sequence) dimension, so K/V are loaded once per KV
head instead of once per Q head. A tile-utilization heuristic adapted from
FlashAttention decides when packing pays off (decode, where each head
contributes only a few query rows per tile) and when simple head remapping
suffices (prefill, where per-head tiles are already well utilized).
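The packing can be illustrated at the PyTorch level. The sketch below only models the math the kernel implements, not the Triton code in this commit; the function name and shapes are made up for illustration, and no causal mask is applied (the real kernel handles masking per original query row).

```python
# PyTorch-level sketch of the "pack GQA" math (not the Triton kernel itself);
# sdpa_pack_gqa and its shapes are illustrative, not names from this commit.
import torch
import torch.nn.functional as F

def sdpa_pack_gqa(q, k, v):
    # q: (B, H_q, S_q, D); k, v: (B, H_kv, S_kv, D); H_q a multiple of H_kv.
    B, H_q, S_q, D = q.shape
    H_kv = k.shape[1]
    G = H_q // H_kv                      # Q heads sharing one KV head
    # Fold the G query heads per KV head into the row (M) dimension, so each
    # K/V head is read once for G * S_q query rows instead of G separate times.
    q_packed = q.view(B, H_kv, G, S_q, D).reshape(B, H_kv, G * S_q, D)
    out = F.scaled_dot_product_attention(q_packed, k, v)
    # Unfold the packed rows back to (B, H_q, S_q, D).
    return out.reshape(B, H_kv, G, S_q, D).reshape(B, H_q, S_q, D)

# Equivalence check against naive KV-head expansion (decode step, S_q = 1,
# which is the case where packing fills the query tile).
q = torch.randn(2, 8, 1, 64)
k = torch.randn(2, 2, 128, 64)
v = torch.randn(2, 2, 128, 64)
ref = F.scaled_dot_product_attention(
    q, k.repeat_interleave(4, dim=1), v.repeat_interleave(4, dim=1))
assert torch.allclose(sdpa_pack_gqa(q, k, v), ref, atol=1e-5)
```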
Update the Qwen 3.5 MoE and Voxtral Realtime models to use native
enable_gqa=True instead of manually expanding KV heads via
repeat_interleave, eliminating the redundant memory traffic of the
duplicated K/V (see the sketch below).
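A minimal before/after sketch of that call-site change, assuming a standard SDPA attention block; the shapes and the is_causal flag are illustrative, not the exact code in qwen3_5_moe or voxtral_realtime, and enable_gqa requires a PyTorch version that supports it.

```python
import torch
import torch.nn.functional as F

B, H_q, H_kv, S, D = 1, 32, 8, 16, 64            # illustrative shapes
q = torch.randn(B, H_q, S, D)
k = torch.randn(B, H_kv, S, D)
v = torch.randn(B, H_kv, S, D)

# Before: KV heads materialized to match Q heads, duplicating K/V memory traffic.
k_rep = k.repeat_interleave(H_q // H_kv, dim=1)   # (B, H_kv, S, D) -> (B, H_q, S, D)
v_rep = v.repeat_interleave(H_q // H_kv, dim=1)
out_old = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)

# After: pass the un-expanded K/V and let SDPA group heads natively.
out_new = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

assert torch.allclose(out_old, out_new, atol=1e-5)
```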
Decode throughput on A100: 63 -> 66.8 tokens/s for Qwen 3.5 MoE and
34 -> 40 tokens/s for Voxtral Realtime.
1 parent 19fb497
4 files changed
Lines changed: 822 additions & 126 deletions
File tree
- backends/cuda
  - tests
  - triton/kernels
- examples/models
  - qwen3_5_moe
  - voxtral_realtime