77"""Comprehensive tests for the Triton SDPA kernel.
88
99Tests MHA, GQA, MQA with various head dims, sequence lengths, causal/non-causal,
10- bool masks, and the pack_gqa optimization. Reference outputs are computed using
11- torch SDPA with expanded KV heads (for GQA/MQA) in float32 for numerical
12- stability.
10+ and bool masks. Reference outputs are computed using torch SDPA with expanded KV
11+ heads (for GQA/MQA) in float32 for numerical stability.
1312
1413Test parametrization adapted from FlashAttention (tests/cute/test_flash_attn.py).
1514"""
@@ -31,11 +30,7 @@ def _skip_if_no_cuda():
3130def _import_sdpa ():
3231 from executorch .backends .cuda .triton .kernels .sdpa import sdpa
3332
34- try :
35- from executorch .backends .cuda .triton .kernels .sdpa import _should_pack_gqa
36- except ImportError :
37- _should_pack_gqa = None
38- return sdpa , _should_pack_gqa
33+ return sdpa
3934
4035
4136def _reference_sdpa (q , k , v , attn_mask = None , is_causal = False , scale = None ):
@@ -112,11 +107,7 @@ class TestTritonSdpa(unittest.TestCase):
112107 @classmethod
113108 def setUpClass (cls ):
114109 _skip_if_no_cuda ()
115- cls .sdpa , cls .should_pack_gqa = _import_sdpa ()
116-
117- @staticmethod
118- def _should_pack_gqa (L_q , num_groups , block_m ):
119- return TestTritonSdpa .should_pack_gqa (L_q , num_groups , block_m )
110+ cls .sdpa = _import_sdpa ()
120111
121112 # ------------------------------------------------------------------
122113 # MHA tests (no GQA, backwards compatibility)
@@ -220,12 +211,7 @@ def test_mha_non_pow2_causal(self):
220211 # ------------------------------------------------------------------
221212
222213 def test_gqa_decode (self ):
223- """GQA decode (seqlen_q=1): exercises pack_gqa path.
224-
225- This is the critical test for the pack_gqa optimization. With
226- seqlen_q=1, the heuristic should choose pack_gqa, folding all
227- Q heads into a single tile.
228- """
214+ """GQA decode (seqlen_q=1)."""
229215 for (H_q , H_kv , label ), D , Lk in itertools .product (
230216 GQA_CONFIGS , [64 , 128 , 256 ], [64 , 128 , 512 ]
231217 ):
@@ -238,13 +224,6 @@ def test_gqa_decode(self):
238224 k = torch .randn (B , H_kv , Lk , D , dtype = torch .bfloat16 , device = "cuda" )
239225 v = torch .randn (B , H_kv , Lk , D , dtype = torch .bfloat16 , device = "cuda" )
240226
241- # Verify heuristic chooses pack_gqa for decode
242- num_groups = H_q // H_kv
243- self .assertTrue (
244- self ._should_pack_gqa (Lq , num_groups , 64 ),
245- "Heuristic should choose pack_gqa for decode" ,
246- )
247-
248227 out = self .sdpa (q , k , v , enable_gqa = True )
249228 ref = _reference_sdpa (q , k , v )
250229
@@ -277,49 +256,35 @@ def test_gqa_decode_with_mask(self):
277256 self .assertLess (_max_abs_error (out , ref ), 0.05 )
278257
279258 def test_gqa_short_seqlen (self ):
280- """GQA with short seqlen_q (2-8): pack_gqa should still activate ."""
259+ """GQA with short seqlen_q (2-8)."""
281260 for Lq in [2 , 4 , 8 ]:
282261 for H_q , H_kv , label in [(8 , 2 , "gqa_4x" ), (16 , 2 , "gqa_8x" )]:
283262 with self .subTest (label = label , Lq = Lq ):
284263 B , Lk , D = 1 , 256 , 128
285- num_groups = H_q // H_kv
286264 torch .manual_seed (42 )
287265 q = torch .randn (B , H_q , Lq , D , dtype = torch .bfloat16 , device = "cuda" )
288266 k = torch .randn (B , H_kv , Lk , D , dtype = torch .bfloat16 , device = "cuda" )
289267 v = torch .randn (B , H_kv , Lk , D , dtype = torch .bfloat16 , device = "cuda" )
290268
291- # Verify heuristic activates pack_gqa
292- self .assertTrue (
293- self ._should_pack_gqa (Lq , num_groups , 64 ),
294- f"Should pack for Lq={ Lq } , groups={ num_groups } " ,
295- )
296-
297269 out = self .sdpa (q , k , v , enable_gqa = True )
298270 ref = _reference_sdpa (q , k , v )
299271
300272 self .assertFalse (torch .isnan (out ).any ())
301273 self .assertLess (_max_abs_error (out , ref ), 0.05 )
302274
303275 def test_gqa_prefill (self ):
304- """GQA prefill (long seqlen_q): should NOT use pack_gqa ."""
276+ """GQA prefill (long seqlen_q)."""
305277 for (H_q , H_kv , label ), L in itertools .product (
306278 [(8 , 2 , "gqa_4x" ), (16 , 2 , "gqa_8x" ), (6 , 1 , "mqa" )],
307279 [64 , 128 , 256 ],
308280 ):
309281 with self .subTest (label = label , L = L ):
310282 B , D = 1 , 128
311- num_groups = H_q // H_kv
312283 torch .manual_seed (42 )
313284 q = torch .randn (B , H_q , L , D , dtype = torch .bfloat16 , device = "cuda" )
314285 k = torch .randn (B , H_kv , L , D , dtype = torch .bfloat16 , device = "cuda" )
315286 v = torch .randn (B , H_kv , L , D , dtype = torch .bfloat16 , device = "cuda" )
316287
317- # Verify heuristic does NOT pack for long seqlen
318- self .assertFalse (
319- self ._should_pack_gqa (L , num_groups , 64 ),
320- f"Should NOT pack for L={ L } " ,
321- )
322-
323288 out = self .sdpa (q , k , v , is_causal = True , enable_gqa = True )
324289 ref = _reference_sdpa (q , k , v , is_causal = True )
325290
@@ -431,30 +396,6 @@ def test_qwen35_moe_config(self):
431396 _max_abs_error (out , ref ), 0.05 , f"Qwen config Lq={ Lq } Lk={ Lk } "
432397 )
433398
434- # ------------------------------------------------------------------
435- # Pack GQA heuristic tests
436- # ------------------------------------------------------------------
437-
438- def test_pack_gqa_heuristic (self ):
439- """Verify _should_pack_gqa matches expected behavior."""
440- # MHA: never pack
441- self .assertFalse (self ._should_pack_gqa (1 , 1 , 64 ))
442- self .assertFalse (self ._should_pack_gqa (128 , 1 , 64 ))
443-
444- # GQA decode (seqlen=1): always pack
445- self .assertTrue (self ._should_pack_gqa (1 , 8 , 64 ))
446- self .assertTrue (self ._should_pack_gqa (1 , 4 , 64 ))
447- self .assertTrue (self ._should_pack_gqa (1 , 2 , 64 ))
448-
449- # GQA short seqlen: pack when utilization improves
450- self .assertTrue (self ._should_pack_gqa (4 , 8 , 64 ))
451- self .assertTrue (self ._should_pack_gqa (8 , 8 , 64 ))
452-
453- # GQA long seqlen: don't pack (tiles already full)
454- self .assertFalse (self ._should_pack_gqa (64 , 8 , 64 ))
455- self .assertFalse (self ._should_pack_gqa (128 , 4 , 64 ))
456- self .assertFalse (self ._should_pack_gqa (256 , 2 , 64 ))
457-
458399 # ------------------------------------------------------------------
459400 # Edge cases and validation
460401 # ------------------------------------------------------------------
0 commit comments