Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions dispatcher/codegen/fmha/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
# Ensure this directory and parent codegen/ are on sys.path for sibling imports
_THIS_DIR = Path(__file__).resolve().parent
_CODEGEN_DIR = _THIS_DIR.parent
_DISPATCHER_PYTHON_DIR = _CODEGEN_DIR.parent / "python"
sys.path.insert(0, str(_THIS_DIR))
sys.path.insert(0, str(_CODEGEN_DIR))
sys.path.insert(0, str(_DISPATCHER_PYTHON_DIR))

from symbol_map import ( # noqa: E402
BWD_DTYPE_MAP,
Expand All @@ -39,6 +41,10 @@
canonical_mask,
canonical_qscale,
)
from fmha_dtype_contract import ( # noqa: E402
FmhaDTypeContractKind,
dtype_contract_from_signature,
)

# Import shared hardware data from parent arch_specs_generated (generated from
# arch_specs.json by generate_arch_specs.py). Falls back to inline defaults if
Expand Down Expand Up @@ -872,6 +878,15 @@ def validate_config(

# --- Family-specific rules ---
if family == "batch_prefill":
dtype_contract = dtype_contract_from_signature(sig)
if dtype_contract.kind == FmhaDTypeContractKind.MIXED_Q_FP8_KV:
result.add_error(
"batch_prefill mixed activation/FP8-KV dtype contract is not implemented "
f"(Q={dtype_contract.q_dtype}, K={dtype_contract.k_dtype}, "
f"V={dtype_contract.v_dtype}, O={dtype_contract.o_dtype}); "
"current generated kernels use one data_type token for Q/K/V, and fp8bf16 "
"means FP8 Q/K/V with BF16 output"
)
if sig.get("vlayout", "r") != "r":
result.add_error("batch_prefill only supports row-major V layout")
if not sig.get("paged_kv", False):
Expand Down
144 changes: 144 additions & 0 deletions dispatcher/python/fmha_dtype_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#!/usr/bin/env python3

# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

from dataclasses import dataclass
from enum import Enum
from typing import Mapping, Optional


class FmhaDTypeContractKind(Enum):
HOMOGENEOUS = "homogeneous"
ALL_FP8_WITH_BF16_OUTPUT = "all_fp8_with_bf16_output"
ALL_FP8_WITH_FP32_OUTPUT = "all_fp8_with_fp32_output"
MIXED_Q_FP8_KV = "mixed_q_fp8_kv"
UNSUPPORTED = "unsupported"


@dataclass(frozen=True)
class FmhaDTypeContract:
data_type: str
q_dtype: str
k_dtype: str
v_dtype: str
o_dtype: str
kind: FmhaDTypeContractKind

@property
def uses_fp8_kv(self) -> bool:
return _is_fp8(self.k_dtype) and _is_fp8(self.v_dtype)


_TOKEN_CONTRACTS = {
"fp16": ("fp16", "fp16", "fp16", "fp16"),
"bf16": ("bf16", "bf16", "bf16", "bf16"),
"fp32": ("fp32", "fp32", "fp32", "fp32"),
"fp8": ("fp8", "fp8", "fp8", "fp8"),
"bf8": ("bf8", "bf8", "bf8", "bf8"),
"fp8bf16": ("fp8", "fp8", "fp8", "bf16"),
"fp8fp32": ("fp8", "fp8", "fp8", "fp32"),
"fp8fp16": ("fp8", "fp8", "fp8", "fp16"),
"mxfp8": ("fp8", "fp8", "fp8", "fp32"),
}


def _normalize_dtype(dtype: Optional[str]) -> Optional[str]:
if dtype is None:
return None

normalized = str(dtype).lower()
aliases = {
"float16": "fp16",
"half": "fp16",
"uint16": "bf16",
"bfloat16": "bf16",
"float32": "fp32",
"uint8": "fp8",
"fp8_e4m3": "fp8",
"fp8_e4m3fnuz": "fp8",
"float8_e4m3fnuz": "fp8",
}
return aliases.get(normalized, normalized)


def _is_fp8(dtype: str) -> bool:
return _normalize_dtype(dtype) in {"fp8", "bf8", "mxfp8"}


def _classify(
q_dtype: str, k_dtype: str, v_dtype: str, o_dtype: str
) -> FmhaDTypeContractKind:
q_dtype = _normalize_dtype(q_dtype) or ""
k_dtype = _normalize_dtype(k_dtype) or ""
v_dtype = _normalize_dtype(v_dtype) or ""
o_dtype = _normalize_dtype(o_dtype) or ""

if q_dtype == k_dtype == v_dtype == o_dtype:
return FmhaDTypeContractKind.HOMOGENEOUS
if _is_fp8(q_dtype) and _is_fp8(k_dtype) and _is_fp8(v_dtype):
if o_dtype == "bf16":
return FmhaDTypeContractKind.ALL_FP8_WITH_BF16_OUTPUT
if o_dtype == "fp32":
return FmhaDTypeContractKind.ALL_FP8_WITH_FP32_OUTPUT
if (
q_dtype in {"fp16", "bf16"}
and _is_fp8(k_dtype)
and _is_fp8(v_dtype)
and o_dtype in {"fp16", "bf16"}
):
return FmhaDTypeContractKind.MIXED_Q_FP8_KV
return FmhaDTypeContractKind.UNSUPPORTED


def dtype_contract_from_components(
data_type: str,
q_dtype: str,
k_dtype: str,
v_dtype: str,
o_dtype: str,
) -> FmhaDTypeContract:
data_type = _normalize_dtype(data_type) or data_type
q_dtype = _normalize_dtype(q_dtype) or ""
k_dtype = _normalize_dtype(k_dtype) or ""
v_dtype = _normalize_dtype(v_dtype) or ""
o_dtype = _normalize_dtype(o_dtype) or ""
return FmhaDTypeContract(
data_type=data_type,
q_dtype=q_dtype,
k_dtype=k_dtype,
v_dtype=v_dtype,
o_dtype=o_dtype,
kind=_classify(q_dtype, k_dtype, v_dtype, o_dtype),
)


def dtype_contract_from_data_type(data_type: str) -> FmhaDTypeContract:
data_type = _normalize_dtype(data_type) or data_type
q_dtype, k_dtype, v_dtype, o_dtype = _TOKEN_CONTRACTS.get(
data_type, (data_type, data_type, data_type, data_type)
)
return dtype_contract_from_components(data_type, q_dtype, k_dtype, v_dtype, o_dtype)


def dtype_contract_from_signature(signature: Mapping[str, object]) -> FmhaDTypeContract:
data_type = str(signature.get("data_type", "fp16"))
inferred = dtype_contract_from_data_type(data_type)
kv_dtype = signature.get("kv_data_type", signature.get("kv_dtype"))

q_dtype = signature.get("q_data_type", signature.get("q_dtype", inferred.q_dtype))
k_dtype = signature.get(
"k_data_type", signature.get("k_dtype", kv_dtype or inferred.k_dtype)
)
v_dtype = signature.get(
"v_data_type", signature.get("v_dtype", kv_dtype or inferred.v_dtype)
)
o_dtype = signature.get("o_data_type", signature.get("o_dtype", inferred.o_dtype))

return dtype_contract_from_components(
data_type,
str(q_dtype),
str(k_dtype),
str(v_dtype),
str(o_dtype),
)
63 changes: 63 additions & 0 deletions dispatcher/python/fmha_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@

import numpy as np

from fmha_dtype_contract import (
FmhaDTypeContract,
FmhaDTypeContractKind,
dtype_contract_from_components,
dtype_contract_from_data_type,
)


# =============================================================================
# Utility helpers
Expand Down Expand Up @@ -350,6 +357,60 @@ def _bf16_to_float32(arr: np.ndarray) -> np.ndarray:
return (arr.astype(np.uint32) << 16).view(np.float32)


def _array_contract_dtype(arr: np.ndarray, fallback: str) -> str:
if arr.dtype == np.uint8:
return "fp8"
if arr.dtype == np.uint16:
return "bf16"
if arr.dtype == np.float16:
return "fp16"
if arr.dtype == np.float32:
return "fp32"
return fallback


def get_batch_prefill_dtype_contract(
data_type: str,
Q: np.ndarray,
K: np.ndarray,
V: np.ndarray,
) -> FmhaDTypeContract:
"""Classify the public batch_prefill dtype contract requested by the caller."""
inferred = dtype_contract_from_data_type(data_type)
return dtype_contract_from_components(
data_type=data_type,
q_dtype=_array_contract_dtype(Q, inferred.q_dtype),
k_dtype=_array_contract_dtype(K, inferred.k_dtype),
v_dtype=_array_contract_dtype(V, inferred.v_dtype),
o_dtype=inferred.o_dtype,
)


def _validate_batch_prefill_input_dtypes(
api_family: str,
data_type: str,
Q: np.ndarray,
K: np.ndarray,
V: np.ndarray,
) -> None:
"""Reject dtype contracts that CK Tile batch_prefill cannot dispatch yet."""
if api_family != "batch_prefill":
return

contract = get_batch_prefill_dtype_contract(data_type, Q, K, V)
if contract.kind == FmhaDTypeContractKind.MIXED_Q_FP8_KV:
raise ValueError(
"CK Tile batch_prefill does not yet support the mixed activation/FP8-KV "
"dtype contract "
f"(data_type={data_type}, Q={contract.q_dtype}, K={contract.k_dtype}, "
f"V={contract.v_dtype}, O={contract.o_dtype}). "
"The current dispatcher and generated kernels select Q/K/V types from a single "
"data_type token; fp8bf16 means FP8 Q/K/V with BF16 output. Use AITER "
"paged_attention_ragged or another fallback for BF16/FP16 Q with FP8 KV until "
"CK Tile has mixed Q/KV kernel instances."
)


def cpu_attention_fwd(
Q: np.ndarray,
K: np.ndarray,
Expand Down Expand Up @@ -811,6 +872,8 @@ def run(
Returns:
FmhaResult with output array, timing, TFLOPS
"""
_validate_batch_prefill_input_dtypes(api_family, data_type, Q, K, V)

# Map CK dtype to numpy dtype for buffer allocation.
# bf16 is stored as uint16 (upper 16 bits of float32).
# fp8 uses uint8 (1 byte per element).
Expand Down
38 changes: 38 additions & 0 deletions dispatcher/tests/test_fmha_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,44 @@ def test_batch_prefill_valid_group(self):
r = validate_config(cfg, SPECS)
self.assertTrue(r.valid, r.errors)

def test_batch_prefill_rejects_mixed_activation_fp8_kv_contract(self):
cfg = _base_config(
family="batch_prefill",
dtype="bf16",
pipeline="qr_async",
mode="group",
paged_kv=True,
page_size=16,
q_data_type="bf16",
kv_data_type="fp8",
o_data_type="bf16",
)

r = validate_config(cfg, SPECS)

self.assertFalse(r.valid)
self.assertTrue(
any("mixed activation/FP8-KV dtype contract" in e for e in r.errors),
r.errors,
)

def test_batch_prefill_keeps_all_fp8_bf16_output_contract_valid(self):
cfg = _base_config(
family="batch_prefill",
dtype="fp8bf16",
pipeline="qr_async",
mode="group",
paged_kv=True,
page_size=16,
q_data_type="fp8",
kv_data_type="fp8",
o_data_type="bf16",
)

r = validate_config(cfg, SPECS)

self.assertTrue(r.valid, r.errors)

def test_splitkv_combine_bn1_must_be_32(self):
cfg = _base_config(family="fwd_splitkv_combine", pipeline="qr")
cfg["algorithm"]["tile"][3] = 64
Expand Down
81 changes: 81 additions & 0 deletions dispatcher/tests/test_fmha_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#!/usr/bin/env python3

# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

import sys
import unittest
from pathlib import Path

import numpy as np

ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "python"))

from fmha_dtype_contract import FmhaDTypeContractKind # noqa: E402
from fmha_utils import ( # noqa: E402
_validate_batch_prefill_input_dtypes,
get_batch_prefill_dtype_contract,
)


class TestBatchPrefillDtypeValidation(unittest.TestCase):
def test_mixed_bf16_q_fp8_kv_gqa_decode_reports_unsupported(self):
batch = 4
q_len = 1
ctx_len = 1024
num_q_heads = 96
num_kv_heads = 8
head_dim = 128

q = np.zeros((batch, num_q_heads, q_len, head_dim), dtype=np.uint16)
k = np.zeros((batch, num_kv_heads, ctx_len, head_dim), dtype=np.uint8)
v = np.zeros((batch, num_kv_heads, ctx_len, head_dim), dtype=np.uint8)

with self.assertRaisesRegex(
ValueError,
"mixed activation/FP8-KV dtype contract",
):
_validate_batch_prefill_input_dtypes("batch_prefill", "bf16", q, k, v)

def test_mixed_fp16_q_fp8_kv_reports_unsupported(self):
q = np.zeros((4, 96, 1, 128), dtype=np.float16)
k = np.zeros((4, 8, 1024, 128), dtype=np.uint8)
v = np.zeros((4, 8, 1024, 128), dtype=np.uint8)

contract = get_batch_prefill_dtype_contract("fp16", q, k, v)

self.assertEqual(contract.kind, FmhaDTypeContractKind.MIXED_Q_FP8_KV)
with self.assertRaisesRegex(ValueError, "AITER paged_attention_ragged"):
_validate_batch_prefill_input_dtypes("batch_prefill", "fp16", q, k, v)

def test_all_fp8_bf16_output_path_remains_allowed(self):
q = np.zeros((1, 96, 1, 128), dtype=np.uint8)
k = np.zeros((1, 8, 128, 128), dtype=np.uint8)
v = np.zeros((1, 8, 128, 128), dtype=np.uint8)

contract = get_batch_prefill_dtype_contract("fp8bf16", q, k, v)

self.assertEqual(contract.kind, FmhaDTypeContractKind.ALL_FP8_WITH_BF16_OUTPUT)
_validate_batch_prefill_input_dtypes("batch_prefill", "fp8bf16", q, k, v)

def test_all_bf16_batch_prefill_path_remains_allowed(self):
q = np.zeros((1, 96, 1, 128), dtype=np.uint16)
k = np.zeros((1, 8, 128, 128), dtype=np.uint16)
v = np.zeros((1, 8, 128, 128), dtype=np.uint16)

contract = get_batch_prefill_dtype_contract("bf16", q, k, v)

self.assertEqual(contract.kind, FmhaDTypeContractKind.HOMOGENEOUS)
_validate_batch_prefill_input_dtypes("batch_prefill", "bf16", q, k, v)

def test_non_batch_prefill_paths_are_unchanged(self):
q = np.zeros((1, 96, 1, 128), dtype=np.uint16)
k = np.zeros((1, 8, 128, 128), dtype=np.uint8)
v = np.zeros((1, 8, 128, 128), dtype=np.uint8)

_validate_batch_prefill_input_dtypes("fwd", "bf16", q, k, v)


if __name__ == "__main__":
unittest.main()