
CK Tile FMHA/batch-prefill: support BF16/FP16 Q with FP8 KV for GQA decode #3744

@ThomasNing

Description


Request

Please add CK Tile FMHA / batch-prefill support for the serving-friendly mixed-dtype GQA attention contract:

  • Q: BF16 or FP16
  • K/V cache: FP8 E4M3, with descales
  • O: BF16 or FP16
  • GQA: arbitrary num_q_heads / num_kv_heads ratios, including 12
  • Head dim: 128
  • Decode-like use cases: q_len=1, paged/batch-prefill metadata, context length 1K+
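For a decode step, "paged/batch-prefill metadata" usually means ragged, CSR-style offsets: one query row per sequence, plus a per-sequence range of pages in the KV cache. The sketch below builds such offsets for page size 1. The field names (`qo_indptr`, `kv_indptr`, `last_page_len`) follow a FlashInfer-style convention and are assumptions, not the actual arguments of `mha_batch_prefill_func`.

```python
from itertools import accumulate


def decode_metadata(context_lens, page_size=1):
    """Ragged offsets for one decode step (q_len = 1 per sequence).

    Returns (qo_indptr, kv_indptr, last_page_len); the names are hypothetical.
    """
    q_lens = [1] * len(context_lens)                    # one new query token each
    qo_indptr = [0, *accumulate(q_lens)]                # query row offsets
    pages = [-(-n // page_size) for n in context_lens]  # ceil(n / page_size)
    kv_indptr = [0, *accumulate(pages)]                 # offsets into the page list
    last_page_len = [n - (p - 1) * page_size if p else 0
                     for n, p in zip(context_lens, pages)]
    return qo_indptr, kv_indptr, last_page_len


qo, kv, last = decode_metadata([1024] * 4)  # 4 sequences of 1024 tokens, page size 1
print(qo)    # [0, 1, 2, 3, 4]
print(kv)    # [0, 1024, 2048, 3072, 4096]
print(last)  # [1, 1, 1, 1]
```

With page size 1, every `last_page_len` is 1 and `kv_indptr` is simply the running sum of context lengths.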

Motivation

We are serving GLM-4.5-Air FP8 on AMD MI300X (gfx942) through SGLang/AITER. The model is non-MLA GQA:

  • num_attention_heads=96
  • num_key_value_heads=8
  • head_dim=128
  • GQA ratio: 12
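With 96 query heads and 8 KV heads, each KV head serves 12 query heads. Below is a minimal sketch of the usual GQA mapping. It assumes the common convention that consecutive query heads share a KV head; how CK Tile lays out heads internally is not stated in this issue, and the helper names are hypothetical.

```python
def gqa_ratio(num_q_heads, num_kv_heads):
    """Number of query heads that share one KV head."""
    if num_q_heads <= 0 or num_kv_heads <= 0 or num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a positive multiple of num_kv_heads")
    return num_q_heads // num_kv_heads


def kv_head_for(q_head, num_q_heads, num_kv_heads):
    """KV head read by a query head, assuming consecutive query heads are grouped."""
    return q_head // gqa_ratio(num_q_heads, num_kv_heads)


print(gqa_ratio(96, 8))                                  # 12
print([kv_head_for(h, 96, 8) for h in (0, 11, 12, 95)])  # [0, 0, 1, 7]
```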

The desired decode path is BF16/FP16 activation Q with FP8 KV cache. The model weights and KV cache are FP8, but the inter-layer activations, including attention Q, remain BF16 or FP16 in the serving stack. Quantizing Q to FP8 just to call attention adds overhead (an extra pass over Q before every attention call) and another accuracy/performance tuning surface.
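To make that overhead concrete, the explicit path would need something like the following before every attention call: a reduction over Q for its absolute maximum, a rescale into FP8 range, and a descale handed to the kernel. This is a hypothetical per-tensor sketch. The choice of granularity (per-tensor vs. per-head or per-token) is exactly the kind of tuning decision a mixed-dtype kernel avoids. The FP8 maxima are format constants (gfx942 natively uses the FNUZ encoding), but which encoding a given stack uses is an assumption here, and real FP8 rounding is not modeled.

```python
FP8_E4M3FNUZ_MAX = 240.0  # largest finite E4M3 value in the FNUZ encoding (gfx942)
FP8_E4M3FN_MAX = 448.0    # largest finite E4M3 value in the OCP E4M3FN encoding


def quantize_q_per_tensor(q, fp8_max=FP8_E4M3FNUZ_MAX):
    """Scale a flat list of Q values into FP8 range; return (scaled, descale).

    Hypothetical helper: only the scaling is modeled, not FP8 rounding.
    """
    amax = max((abs(x) for x in q), default=0.0)  # extra reduction over Q
    scale = fp8_max / amax if amax > 0 else 1.0
    # Clamp is a safety net against floating-point overshoot after scaling.
    q_scaled = [max(-fp8_max, min(fp8_max, x * scale)) for x in q]
    return q_scaled, 1.0 / scale  # the kernel would multiply the descale back in


q_scaled, q_descale = quantize_q_per_tensor([0.5, -2.0, 1.0])
print(q_scaled)  # [60.0, -240.0, 120.0]
```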

Current behavior observed

On MI300X (gfx942), AITER paged_attention_ragged supports this mixed contract:

  • BF16 Q/O
  • FP8 K/V cache
  • gqa_ratio=12
  • head_size=128
  • page/block size 1

It compiles a paged-attention specialization with kv_dtype=uint8_t and fp8_kv_dtype=fp8_e4m3.
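With a page/block size of 1, paged addressing reduces to one physical slot per token. Below is a generic sketch of block-table addressing. It assumes a vLLM-style per-sequence block table and is not taken from AITER's `paged_attention_ragged` source; `physical_slot` is a hypothetical name.

```python
def physical_slot(block_table, pos, block_size):
    """Map a logical token position to a physical token slot in the KV cache.

    block_table: list of physical block ids for one sequence (assumed layout).
    """
    block = block_table[pos // block_size]
    return block * block_size + pos % block_size


# block_size = 1: the slot is just block_table[pos].
bt1 = [7, 3, 9, 0]
print([physical_slot(bt1, p, 1) for p in range(4)])  # [7, 3, 9, 0]

# block_size = 16 for comparison: positions 16..31 live in physical block 5.
bt16 = [2, 5]
print(physical_slot(bt16, 0, 16), physical_slot(bt16, 17, 16))  # 32 81
```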

CK Tile mha_batch_prefill_func currently supports the same GLM GQA shape when Q/K/V are all FP8. For example, a probe with batch 4, q_len=1, ctx=1024, num_q_heads=96, num_kv_heads=8, head_dim=128, all-FP8 Q/K/V, and BF16 output compiled the following kernel:

mha_batch_prefill_fp8bf16_nlogits_nbias_nmask_nlse_ndropout_pertensor

The same probe with BF16 Q and FP8 K/V fails after compiling a BF16 query path:

mha_batch_prefill_bf16_nlogits_nbias_nmask_nlse_ndropout_nqscale
RuntimeError: query and key must have the same dtype
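The error message points to a check that requires the query and key dtypes to match. The following hypothetical sketch shows a relaxed contract check that keeps the existing uniform-dtype cases and also admits the requested BF16/FP16-Q, FP8-KV case. Dtypes are plain strings, and the function name is illustrative rather than CK Tile or AITER API.

```python
HALF_TYPES = {"bf16", "fp16"}
FP8_KV_TYPES = {"fp8_e4m3"}


def check_attention_dtypes(q, k, v, out):
    """Return which contract (q, k, v, out) matches, or raise ValueError."""
    if k != v:
        raise ValueError("key and value must have the same dtype")
    if out not in HALF_TYPES:
        raise ValueError(f"unsupported output dtype {out}")
    if q == k:
        return "uniform"        # e.g. all-BF16, or all-FP8 Q/K/V with BF16 output
    if q in HALF_TYPES and k in FP8_KV_TYPES:
        return "mixed_q16_kv8"  # the case requested in this issue
    raise ValueError(f"unsupported combination q={q}, k={k}")


print(check_attention_dtypes("fp8_e4m3", "fp8_e4m3", "fp8_e4m3", "bf16"))  # uniform
print(check_attention_dtypes("bf16", "fp8_e4m3", "fp8_e4m3", "bf16"))      # mixed_q16_kv8
```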

Desired support

Please support CK Tile FMHA / batch-prefill for:

  • BF16/FP16 Q with FP8 K/V
  • BF16/FP16 output
  • K/V descales
  • arbitrary GQA ratios such as 12
  • decode-style q_len=1 and paged/batch-prefill metadata
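For reference, here is the math the requested kernel would compute for one decode step, written as a slow pure-Python sketch. It assumes per-tensor K/V descales and models FP8 values as floats that still need their descale applied. Because a per-tensor descale is a scalar, the K descale can be applied to each logit and the V descale once to the output, so Q never needs to be quantized. This illustrates the arithmetic only and does not reflect CK Tile's kernel structure.

```python
import math


def decode_attention(q, k_codes, v_codes, k_descale, v_descale, num_kv_heads):
    """One decode step (q_len = 1) with GQA and per-tensor K/V descales.

    q:       [num_q_heads][head_dim] floats (BF16/FP16 activations in practice)
    k_codes: [ctx][num_kv_heads][head_dim] FP8 values, modeled as floats
    v_codes: same layout as k_codes
    Returns: [num_q_heads][head_dim] floats.
    """
    num_q_heads, head_dim = len(q), len(q[0])
    ratio = num_q_heads // num_kv_heads
    sm_scale = 1.0 / math.sqrt(head_dim)
    out = []
    for h in range(num_q_heads):
        kvh = h // ratio  # query head h reads KV head h // ratio
        # The scalar K descale factors out of each dot product.
        logits = [k_descale * sm_scale * sum(a * b for a, b in zip(q[h], tok[kvh]))
                  for tok in k_codes]
        m = max(logits)  # subtract the max for a numerically stable softmax
        p = [math.exp(x - m) for x in logits]
        z = sum(p)
        # The scalar V descale is applied once to the normalized weighted sum.
        out.append([v_descale * sum(pj * tok[kvh][i] for pj, tok in zip(p, v_codes)) / z
                    for i in range(head_dim)])
    return out


# Tiny example: 2 query heads sharing 1 KV head, ctx = 2, head_dim = 2.
o = decode_attention(q=[[1.0, 0.0], [0.0, 1.0]],
                     k_codes=[[[0.0, 0.0]], [[0.0, 0.0]]],
                     v_codes=[[[2.0, 4.0]], [[4.0, 8.0]]],
                     k_descale=1.0, v_descale=0.5, num_kv_heads=1)
print(o)  # [[1.5, 3.0], [1.5, 3.0]]
```

Doubling every K code with `k_descale=1.0` gives the same result as keeping the codes with `k_descale=2.0`. This property is what lets the descales stay outside the inner loops.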

This would let AITER/SGLang use CK Tile for GLM-style FP8 KV cache decode without adding an explicit Q quantization path before attention.

Reference

The probe script and run notes are being tracked in the HW_Optimization GLM MI300X branch:
