Commit 9a997e8

mergennachin authored and rascani committed
Improve SDPA to handle GQA and update models to use native GQA (pytorch#18513)
Add GQA/MQA support to the Triton SDPA kernel with a "pack GQA" optimization adapted from FlashAttention. When enable_gqa=True and H_q > H_kv, the kernel folds the query heads that share a KV head into the M (sequence) dimension, so K/V are loaded once per KV head instead of once per Q head. A tile-utilization heuristic from FlashAttention decides when packing is beneficial (decode) and when simple head remapping suffices (prefill).

Update the Qwen 3.5 MoE and Voxtral Realtime models to use native enable_gqa=True instead of manually expanding KV heads via repeat_interleave, eliminating redundant memory traffic.

Decode throughput: 63 -> 66.8 tokens/s on A100 for Qwen 3.5 MoE; 34 -> 40 tokens/s for Voxtral Realtime.
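A rough sketch of the pack-GQA layout in plain PyTorch may help. This is only the non-causal intuition, with shapes I've made up for illustration; the actual Triton kernel does the folding with index arithmetic rather than a materialized reshape, and additionally remaps packed row indices so a causal mask still applies per original query position:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes (assumed, not the models' real configs):
# 8 query heads grouped over 2 KV heads.
B, H_q, H_kv, S, D = 1, 8, 2, 128, 64
G = H_q // H_kv  # query heads per KV head
q = torch.randn(B, H_q, S, D)
k = torch.randn(B, H_kv, S, D)
v = torch.randn(B, H_kv, S, D)

# Fold each group of G query heads that share a KV head into the
# M (sequence) dimension: (B, H_q, S, D) -> (B, H_kv, G*S, D).
q_packed = q.reshape(B, H_kv, G, S, D).reshape(B, H_kv, G * S, D)

# One attention pass per KV head now covers all G of its query heads,
# so each K/V tile is read once per KV head, not once per query head.
out_packed = F.scaled_dot_product_attention(q_packed, k, v)

# Unfold back to the per-query-head layout.
out = out_packed.reshape(B, H_kv, G, S, D).reshape(B, H_q, S, D)

# Matches native GQA attention in the non-causal case.
ref = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)
```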
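The model-side change, sketched with the same assumed shapes (F.scaled_dot_product_attention accepts enable_gqa in recent PyTorch versions; the models' real head counts may differ):

```python
import torch
import torch.nn.functional as F

B, H_q, H_kv, S, D = 1, 8, 2, 128, 64
q = torch.randn(B, H_q, S, D)
k = torch.randn(B, H_kv, S, D)
v = torch.randn(B, H_kv, S, D)

# Before: expand K/V so every query head gets its own copy, which
# multiplies K/V memory traffic by the group size H_q // H_kv.
k_rep = k.repeat_interleave(H_q // H_kv, dim=1)
v_rep = v.repeat_interleave(H_q // H_kv, dim=1)
out_manual = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)

# After: let SDPA group query heads onto shared KV heads natively.
out_native = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

torch.testing.assert_close(out_manual, out_native, rtol=1e-3, atol=1e-3)
```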
1 parent 19fb497

4 files changed: 822 additions & 126 deletions

