Request
Please add CK Tile FMHA / batch-prefill support for the serving-friendly mixed-dtype GQA attention contract:
- Q: BF16 or FP16
- K/V cache: FP8 E4M3, with descales
- O: BF16 or FP16
- GQA: arbitrary num_q_heads / num_kv_heads ratios, including 12
- Head dim: 128
- Decode-like use cases: q_len=1, paged/batch-prefill metadata, context length 1K+
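To make the requested contract concrete, here is a minimal pure-Python reference for one sequence with q_len=1. It is only a sketch of the intended math, not CK Tile or AITER code: the function name `gqa_decode_reference` is hypothetical, plain floats stand in for BF16/FP16 Q and for the stored FP8 values, the descales are assumed to be per-tensor scalars, and paging is ignored (the cache is a flat list of tokens).

```python
import math


def gqa_decode_reference(q, k_cache, v_cache, k_descale, v_descale, sm_scale=None):
    """Reference decode attention for one sequence (q_len=1). Hypothetical helper.

    q:        [num_q_heads][head_dim] floats (stand-in for BF16/FP16 Q)
    k_cache:  [ctx_len][num_kv_heads][head_dim] floats (stand-in for stored FP8 K)
    v_cache:  same layout as k_cache (stand-in for stored FP8 V)
    k_descale, v_descale: assumed per-tensor scalars that rescale stored K/V
    """
    num_q_heads = len(q)
    num_kv_heads = len(k_cache[0])
    head_dim = len(q[0])
    assert num_q_heads % num_kv_heads == 0, "GQA needs an integer head ratio"
    gqa_ratio = num_q_heads // num_kv_heads  # 96 // 8 == 12 for the GLM shape
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)

    out = []
    for h in range(num_q_heads):
        kv_h = h // gqa_ratio  # consecutive Q heads share one KV head
        # Q stays in high precision; K is rescaled by k_descale inside the dot product.
        scores = [
            sm_scale * k_descale * sum(qi * ki for qi, ki in zip(q[h], k_cache[t][kv_h]))
            for t in range(len(k_cache))
        ]
        # Numerically stable softmax over the context.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        denom = sum(exps)
        o = [0.0] * head_dim
        for t, e in enumerate(exps):
            p = e / denom
            for d in range(head_dim):
                o[d] += p * v_descale * v_cache[t][kv_h][d]
        out.append(o)
    return out
```

The point of the sketch is that Q is never quantized: only K and V carry a descale, and the output stays in the Q/O precision.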
Motivation
We are serving GLM-4.5-Air FP8 on AMD MI300X (gfx942) through SGLang/AITER. The model is non-MLA GQA:
- num_attention_heads=96
- num_key_value_heads=8
- head_dim=128
- GQA ratio: 12
The desired decode path is BF16/FP16 activation Q with FP8 KV cache. The model weights and KV cache are FP8, but the inter-layer activation / attention Q remains BF16 or FP16 in the serving stack. Quantizing Q to FP8 just to call attention adds extra overhead and an additional accuracy/perf tuning surface.
Current behavior observed
On MI300X (gfx942), AITER paged_attention_ragged supports this mixed contract:
- BF16 Q/O
- FP8 K/V cache
- gqa_ratio=12
- head_size=128
- page/block size 1
It compiles a paged-attention specialization with kv_dtype=uint8_t and fp8_kv_dtype=fp8_e4m3.
CK Tile mha_batch_prefill_func currently supports the same GLM GQA shape only if Q, K, and V are all FP8. For example, a probe with batch 4, q_len=1, ctx=1024, num_q_heads=96, num_kv_heads=8, head_dim=128, all-FP8 Q/K/V, and BF16 output compiled this kernel:
mha_batch_prefill_fp8bf16_nlogits_nbias_nmask_nlse_ndropout_pertensor
The same probe with BF16 Q and FP8 K/V fails after compiling a BF16 query path:
mha_batch_prefill_bf16_nlogits_nbias_nmask_nlse_ndropout_nqscale
RuntimeError: query and key must have the same dtype
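As an illustration of the behavior being requested, the dtype gate could accept a high-precision Q paired with an FP8 K/V cache instead of requiring identical dtypes. The sketch below is hypothetical: the function name, the dtype strings, and the returned labels are assumptions made for illustration and do not reflect the actual AITER or CK Tile check.

```python
# Hypothetical dtype names, used only for this illustration.
HIGH_PRECISION_Q_DTYPES = {"bf16", "fp16"}
FP8_KV_DTYPES = {"fp8_e4m3"}


def check_attention_dtypes(q_dtype, k_dtype, v_dtype):
    """Return which kernel family a (Q, K, V) dtype triple would select.

    "uniform": Q, K, and V share one dtype (the path that works today).
    "mixed":   BF16/FP16 Q with an FP8 K/V cache (the path requested here).
    """
    if k_dtype != v_dtype:
        raise RuntimeError("key and value must have the same dtype")
    if q_dtype == k_dtype:
        return "uniform"
    if q_dtype in HIGH_PRECISION_Q_DTYPES and k_dtype in FP8_KV_DTYPES:
        return "mixed"
    # Any other combination keeps the error observed in the probe.
    raise RuntimeError("query and key must have the same dtype")
```

With a gate like this, the failing probe (BF16 Q, FP8 K/V) would route to a mixed-dtype kernel, while unsupported pairs such as FP8 Q with BF16 K/V would still be rejected.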
Desired support
Please support CK Tile FMHA / batch-prefill for:
- BF16/FP16 Q with FP8 K/V
- BF16/FP16 output
- K/V descales
- arbitrary GQA ratios such as 12
- decode-style q_len=1 and paged/batch-prefill metadata
This would let AITER/SGLang use CK Tile for GLM-style FP8 KV cache decode without adding an explicit Q quantization path before attention.
Reference
The probe script and run notes are being tracked in the HW_Optimization GLM MI300X branch:
amd/scripts/aiter_ck_gqa_fp8_probe.py