Skip to content

Commit 3c4a653

Browse files
committed
Addressing digantdesai's feedback
1 parent 9ee3a11 commit 3c4a653

2 files changed

Lines changed: 15 additions & 72 deletions

File tree

backends/cuda/tests/test_triton_sdpa.py

Lines changed: 7 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
"""Comprehensive tests for the Triton SDPA kernel.
88
99
Tests MHA, GQA, MQA with various head dims, sequence lengths, causal/non-causal,
10-
bool masks, and the pack_gqa optimization. Reference outputs are computed using
11-
torch SDPA with expanded KV heads (for GQA/MQA) in float32 for numerical
12-
stability.
10+
and bool masks. Reference outputs are computed using torch SDPA with expanded KV
11+
heads (for GQA/MQA) in float32 for numerical stability.
1312
1413
Test parametrization adapted from FlashAttention (tests/cute/test_flash_attn.py).
1514
"""
@@ -31,11 +30,7 @@ def _skip_if_no_cuda():
3130
def _import_sdpa():
3231
from executorch.backends.cuda.triton.kernels.sdpa import sdpa
3332

34-
try:
35-
from executorch.backends.cuda.triton.kernels.sdpa import _should_pack_gqa
36-
except ImportError:
37-
_should_pack_gqa = None
38-
return sdpa, _should_pack_gqa
33+
return sdpa
3934

4035

4136
def _reference_sdpa(q, k, v, attn_mask=None, is_causal=False, scale=None):
@@ -112,11 +107,7 @@ class TestTritonSdpa(unittest.TestCase):
112107
@classmethod
113108
def setUpClass(cls):
114109
_skip_if_no_cuda()
115-
cls.sdpa, cls.should_pack_gqa = _import_sdpa()
116-
117-
@staticmethod
118-
def _should_pack_gqa(L_q, num_groups, block_m):
119-
return TestTritonSdpa.should_pack_gqa(L_q, num_groups, block_m)
110+
cls.sdpa = _import_sdpa()
120111

121112
# ------------------------------------------------------------------
122113
# MHA tests (no GQA, backwards compatibility)
@@ -220,12 +211,7 @@ def test_mha_non_pow2_causal(self):
220211
# ------------------------------------------------------------------
221212

222213
def test_gqa_decode(self):
223-
"""GQA decode (seqlen_q=1): exercises pack_gqa path.
224-
225-
This is the critical test for the pack_gqa optimization. With
226-
seqlen_q=1, the heuristic should choose pack_gqa, folding all
227-
Q heads into a single tile.
228-
"""
214+
"""GQA decode (seqlen_q=1)."""
229215
for (H_q, H_kv, label), D, Lk in itertools.product(
230216
GQA_CONFIGS, [64, 128, 256], [64, 128, 512]
231217
):
@@ -238,13 +224,6 @@ def test_gqa_decode(self):
238224
k = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda")
239225
v = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda")
240226

241-
# Verify heuristic chooses pack_gqa for decode
242-
num_groups = H_q // H_kv
243-
self.assertTrue(
244-
self._should_pack_gqa(Lq, num_groups, 64),
245-
"Heuristic should choose pack_gqa for decode",
246-
)
247-
248227
out = self.sdpa(q, k, v, enable_gqa=True)
249228
ref = _reference_sdpa(q, k, v)
250229

@@ -277,49 +256,35 @@ def test_gqa_decode_with_mask(self):
277256
self.assertLess(_max_abs_error(out, ref), 0.05)
278257

279258
def test_gqa_short_seqlen(self):
280-
"""GQA with short seqlen_q (2-8): pack_gqa should still activate."""
259+
"""GQA with short seqlen_q (2-8)."""
281260
for Lq in [2, 4, 8]:
282261
for H_q, H_kv, label in [(8, 2, "gqa_4x"), (16, 2, "gqa_8x")]:
283262
with self.subTest(label=label, Lq=Lq):
284263
B, Lk, D = 1, 256, 128
285-
num_groups = H_q // H_kv
286264
torch.manual_seed(42)
287265
q = torch.randn(B, H_q, Lq, D, dtype=torch.bfloat16, device="cuda")
288266
k = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda")
289267
v = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda")
290268

291-
# Verify heuristic activates pack_gqa
292-
self.assertTrue(
293-
self._should_pack_gqa(Lq, num_groups, 64),
294-
f"Should pack for Lq={Lq}, groups={num_groups}",
295-
)
296-
297269
out = self.sdpa(q, k, v, enable_gqa=True)
298270
ref = _reference_sdpa(q, k, v)
299271

300272
self.assertFalse(torch.isnan(out).any())
301273
self.assertLess(_max_abs_error(out, ref), 0.05)
302274

303275
def test_gqa_prefill(self):
304-
"""GQA prefill (long seqlen_q): should NOT use pack_gqa."""
276+
"""GQA prefill (long seqlen_q)."""
305277
for (H_q, H_kv, label), L in itertools.product(
306278
[(8, 2, "gqa_4x"), (16, 2, "gqa_8x"), (6, 1, "mqa")],
307279
[64, 128, 256],
308280
):
309281
with self.subTest(label=label, L=L):
310282
B, D = 1, 128
311-
num_groups = H_q // H_kv
312283
torch.manual_seed(42)
313284
q = torch.randn(B, H_q, L, D, dtype=torch.bfloat16, device="cuda")
314285
k = torch.randn(B, H_kv, L, D, dtype=torch.bfloat16, device="cuda")
315286
v = torch.randn(B, H_kv, L, D, dtype=torch.bfloat16, device="cuda")
316287

317-
# Verify heuristic does NOT pack for long seqlen
318-
self.assertFalse(
319-
self._should_pack_gqa(L, num_groups, 64),
320-
f"Should NOT pack for L={L}",
321-
)
322-
323288
out = self.sdpa(q, k, v, is_causal=True, enable_gqa=True)
324289
ref = _reference_sdpa(q, k, v, is_causal=True)
325290

@@ -431,30 +396,6 @@ def test_qwen35_moe_config(self):
431396
_max_abs_error(out, ref), 0.05, f"Qwen config Lq={Lq} Lk={Lk}"
432397
)
433398

434-
# ------------------------------------------------------------------
435-
# Pack GQA heuristic tests
436-
# ------------------------------------------------------------------
437-
438-
def test_pack_gqa_heuristic(self):
439-
"""Verify _should_pack_gqa matches expected behavior."""
440-
# MHA: never pack
441-
self.assertFalse(self._should_pack_gqa(1, 1, 64))
442-
self.assertFalse(self._should_pack_gqa(128, 1, 64))
443-
444-
# GQA decode (seqlen=1): always pack
445-
self.assertTrue(self._should_pack_gqa(1, 8, 64))
446-
self.assertTrue(self._should_pack_gqa(1, 4, 64))
447-
self.assertTrue(self._should_pack_gqa(1, 2, 64))
448-
449-
# GQA short seqlen: pack when utilization improves
450-
self.assertTrue(self._should_pack_gqa(4, 8, 64))
451-
self.assertTrue(self._should_pack_gqa(8, 8, 64))
452-
453-
# GQA long seqlen: don't pack (tiles already full)
454-
self.assertFalse(self._should_pack_gqa(64, 8, 64))
455-
self.assertFalse(self._should_pack_gqa(128, 4, 64))
456-
self.assertFalse(self._should_pack_gqa(256, 2, 64))
457-
458399
# ------------------------------------------------------------------
459400
# Edge cases and validation
460401
# ------------------------------------------------------------------

backends/cuda/triton/kernels/sdpa.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -240,22 +240,24 @@ def _sdpa_fwd_kernel_non_pow2(
240240
NEG_INF: tl.constexpr = float("-inf")
241241

242242
for start_n in tl.range(0, LK, BLOCK_N, num_stages=2):
243-
kn = start_n + tl.arange(0, BLOCK_N)
244-
kv_col_mask = kn < LK
243+
offs_n = start_n + tl.arange(0, BLOCK_N)
244+
kv_col_mask = offs_n < LK
245245

246-
k_ptrs = k_base + (kn[:, None] * stride_kl + offs_d[None, :] * stride_kd)
246+
k_ptrs = k_base + (offs_n[:, None] * stride_kl + offs_d[None, :] * stride_kd)
247247
k = tl.load(k_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0)
248248

249249
qk = tl.dot(q, tl.trans(k))
250250
qk = (qk * qk_scale_log2).to(tl.float32)
251251

252252
if IS_CAUSAL:
253-
causal_mask = kn[None, :] > seq_pos[:, None]
253+
causal_mask = offs_n[None, :] > seq_pos[:, None]
254254
qk = tl.where(causal_mask, tl.full(qk.shape, NEG_INF, dtype=tl.float32), qk)
255255

256256
if HAS_MASK:
257257
m_ptrs = (
258-
mask_b_base + seq_pos[:, None] * stride_mlq + kn[None, :] * stride_mlk
258+
mask_b_base
259+
+ seq_pos[:, None] * stride_mlq
260+
+ offs_n[None, :] * stride_mlk
259261
)
260262
tile_valid = row_valid[:, None] & kv_col_mask[None, :]
261263
keep = tl.load(m_ptrs, mask=tile_valid, other=False)
@@ -276,7 +278,7 @@ def _sdpa_fwd_kernel_non_pow2(
276278

277279
acc = (acc * alpha[:, None]).to(tl.float32)
278280

279-
v_ptrs = v_base + (kn[:, None] * stride_vl + offs_d[None, :] * stride_vd)
281+
v_ptrs = v_base + (offs_n[:, None] * stride_vl + offs_d[None, :] * stride_vd)
280282
v = tl.load(v_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0)
281283

282284
acc = tl.dot(p.to(v.dtype), v, acc).to(tl.float32)

0 commit comments

Comments (0)