diff --git a/dispatcher/codegen/fmha/validation.py b/dispatcher/codegen/fmha/validation.py index 20b3a00540..2fb791a8ba 100644 --- a/dispatcher/codegen/fmha/validation.py +++ b/dispatcher/codegen/fmha/validation.py @@ -29,8 +29,10 @@ # Ensure this directory and parent codegen/ are on sys.path for sibling imports _THIS_DIR = Path(__file__).resolve().parent _CODEGEN_DIR = _THIS_DIR.parent +_DISPATCHER_PYTHON_DIR = _CODEGEN_DIR.parent / "python" sys.path.insert(0, str(_THIS_DIR)) sys.path.insert(0, str(_CODEGEN_DIR)) +sys.path.insert(0, str(_DISPATCHER_PYTHON_DIR)) from symbol_map import ( # noqa: E402 BWD_DTYPE_MAP, @@ -39,6 +41,10 @@ canonical_mask, canonical_qscale, ) +from fmha_dtype_contract import ( # noqa: E402 + FmhaDTypeContractKind, + dtype_contract_from_signature, +) # Import shared hardware data from parent arch_specs_generated (generated from # arch_specs.json by generate_arch_specs.py). Falls back to inline defaults if @@ -872,6 +878,15 @@ def validate_config( # --- Family-specific rules --- if family == "batch_prefill": + dtype_contract = dtype_contract_from_signature(sig) + if dtype_contract.kind == FmhaDTypeContractKind.MIXED_Q_FP8_KV: + result.add_error( + "batch_prefill mixed activation/FP8-KV dtype contract is not implemented " + f"(Q={dtype_contract.q_dtype}, K={dtype_contract.k_dtype}, " + f"V={dtype_contract.v_dtype}, O={dtype_contract.o_dtype}); " + "current generated kernels use one data_type token for Q/K/V, and fp8bf16 " + "means FP8 Q/K/V with BF16 output" + ) if sig.get("vlayout", "r") != "r": result.add_error("batch_prefill only supports row-major V layout") if not sig.get("paged_kv", False): diff --git a/dispatcher/python/fmha_dtype_contract.py b/dispatcher/python/fmha_dtype_contract.py new file mode 100644 index 0000000000..21b6e93296 --- /dev/null +++ b/dispatcher/python/fmha_dtype_contract.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +from dataclasses import dataclass +from enum import Enum +from typing import Mapping, Optional + + +class FmhaDTypeContractKind(Enum): + HOMOGENEOUS = "homogeneous" + ALL_FP8_WITH_BF16_OUTPUT = "all_fp8_with_bf16_output" + ALL_FP8_WITH_FP32_OUTPUT = "all_fp8_with_fp32_output" + MIXED_Q_FP8_KV = "mixed_q_fp8_kv" + UNSUPPORTED = "unsupported" + + +@dataclass(frozen=True) +class FmhaDTypeContract: + data_type: str + q_dtype: str + k_dtype: str + v_dtype: str + o_dtype: str + kind: FmhaDTypeContractKind + + @property + def uses_fp8_kv(self) -> bool: + return _is_fp8(self.k_dtype) and _is_fp8(self.v_dtype) + + +_TOKEN_CONTRACTS = { + "fp16": ("fp16", "fp16", "fp16", "fp16"), + "bf16": ("bf16", "bf16", "bf16", "bf16"), + "fp32": ("fp32", "fp32", "fp32", "fp32"), + "fp8": ("fp8", "fp8", "fp8", "fp8"), + "bf8": ("bf8", "bf8", "bf8", "bf8"), + "fp8bf16": ("fp8", "fp8", "fp8", "bf16"), + "fp8fp32": ("fp8", "fp8", "fp8", "fp32"), + "fp8fp16": ("fp8", "fp8", "fp8", "fp16"), + "mxfp8": ("fp8", "fp8", "fp8", "fp32"), +} + + +def _normalize_dtype(dtype: Optional[str]) -> Optional[str]: + if dtype is None: + return None + + normalized = str(dtype).lower() + aliases = { + "float16": "fp16", + "half": "fp16", + "uint16": "bf16", + "bfloat16": "bf16", + "float32": "fp32", + "uint8": "fp8", + "fp8_e4m3": "fp8", + "fp8_e4m3fnuz": "fp8", + "float8_e4m3fnuz": "fp8", + } + return aliases.get(normalized, normalized) + + +def _is_fp8(dtype: str) -> bool: + return _normalize_dtype(dtype) in {"fp8", "bf8", "mxfp8"} + + +def _classify( + q_dtype: str, k_dtype: str, v_dtype: str, o_dtype: str +) -> FmhaDTypeContractKind: + q_dtype = _normalize_dtype(q_dtype) or "" + k_dtype = _normalize_dtype(k_dtype) or "" + v_dtype = _normalize_dtype(v_dtype) or "" + o_dtype = _normalize_dtype(o_dtype) or "" + + if q_dtype == k_dtype == v_dtype == o_dtype: + return FmhaDTypeContractKind.HOMOGENEOUS + if _is_fp8(q_dtype) and _is_fp8(k_dtype) and _is_fp8(v_dtype): + if o_dtype == "bf16": + return FmhaDTypeContractKind.ALL_FP8_WITH_BF16_OUTPUT + if o_dtype == "fp32": + return FmhaDTypeContractKind.ALL_FP8_WITH_FP32_OUTPUT + if ( + q_dtype in {"fp16", "bf16"} + and _is_fp8(k_dtype) + and _is_fp8(v_dtype) + and o_dtype in {"fp16", "bf16"} + ): + return FmhaDTypeContractKind.MIXED_Q_FP8_KV + return FmhaDTypeContractKind.UNSUPPORTED + + +def dtype_contract_from_components( + data_type: str, + q_dtype: str, + k_dtype: str, + v_dtype: str, + o_dtype: str, +) -> FmhaDTypeContract: + data_type = _normalize_dtype(data_type) or data_type + q_dtype = _normalize_dtype(q_dtype) or "" + k_dtype = _normalize_dtype(k_dtype) or "" + v_dtype = _normalize_dtype(v_dtype) or "" + o_dtype = _normalize_dtype(o_dtype) or "" + return FmhaDTypeContract( + data_type=data_type, + q_dtype=q_dtype, + k_dtype=k_dtype, + v_dtype=v_dtype, + o_dtype=o_dtype, + kind=_classify(q_dtype, k_dtype, v_dtype, o_dtype), + ) + + +def dtype_contract_from_data_type(data_type: str) -> FmhaDTypeContract: + data_type = _normalize_dtype(data_type) or data_type + q_dtype, k_dtype, v_dtype, o_dtype = _TOKEN_CONTRACTS.get( + data_type, (data_type, data_type, data_type, data_type) + ) + return dtype_contract_from_components(data_type, q_dtype, k_dtype, v_dtype, o_dtype) + + +def dtype_contract_from_signature(signature: Mapping[str, object]) -> FmhaDTypeContract: + data_type = str(signature.get("data_type", "fp16")) + inferred = dtype_contract_from_data_type(data_type) + kv_dtype = signature.get("kv_data_type", signature.get("kv_dtype")) + + q_dtype = signature.get("q_data_type", signature.get("q_dtype", inferred.q_dtype)) + k_dtype = signature.get( + "k_data_type", signature.get("k_dtype", kv_dtype or inferred.k_dtype) + ) + v_dtype = signature.get( + "v_data_type", signature.get("v_dtype", kv_dtype or inferred.v_dtype) + ) + o_dtype = signature.get("o_data_type", signature.get("o_dtype", inferred.o_dtype)) + + return dtype_contract_from_components( + data_type, + str(q_dtype), + str(k_dtype), + str(v_dtype), + str(o_dtype), + ) diff --git a/dispatcher/python/fmha_utils.py b/dispatcher/python/fmha_utils.py index 5d3d085496..0c30d823b7 100644 --- a/dispatcher/python/fmha_utils.py +++ b/dispatcher/python/fmha_utils.py @@ -28,6 +28,13 @@ import numpy as np +from fmha_dtype_contract import ( + FmhaDTypeContract, + FmhaDTypeContractKind, + dtype_contract_from_components, + dtype_contract_from_data_type, +) + # ============================================================================= # Utility helpers @@ -350,6 +357,60 @@ def _bf16_to_float32(arr: np.ndarray) -> np.ndarray: return (arr.astype(np.uint32) << 16).view(np.float32) +def _array_contract_dtype(arr: np.ndarray, fallback: str) -> str: + if arr.dtype == np.uint8: + return "fp8" + if arr.dtype == np.uint16: + return "bf16" + if arr.dtype == np.float16: + return "fp16" + if arr.dtype == np.float32: + return "fp32" + return fallback + + +def get_batch_prefill_dtype_contract( + data_type: str, + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, +) -> FmhaDTypeContract: + """Classify the public batch_prefill dtype contract requested by the caller.""" + inferred = dtype_contract_from_data_type(data_type) + return dtype_contract_from_components( + data_type=data_type, + q_dtype=_array_contract_dtype(Q, inferred.q_dtype), + k_dtype=_array_contract_dtype(K, inferred.k_dtype), + v_dtype=_array_contract_dtype(V, inferred.v_dtype), + o_dtype=inferred.o_dtype, + ) + + +def _validate_batch_prefill_input_dtypes( + api_family: str, + data_type: str, + Q: np.ndarray, + K: np.ndarray, + V: np.ndarray, +) -> None: + """Reject dtype contracts that CK Tile batch_prefill cannot dispatch yet.""" + if api_family != "batch_prefill": + return + + contract = get_batch_prefill_dtype_contract(data_type, Q, K, V) + if contract.kind == FmhaDTypeContractKind.MIXED_Q_FP8_KV: + raise ValueError( + "CK Tile batch_prefill does not yet support the mixed activation/FP8-KV " + "dtype contract " + f"(data_type={data_type}, Q={contract.q_dtype}, K={contract.k_dtype}, " + f"V={contract.v_dtype}, O={contract.o_dtype}). " + "The current dispatcher and generated kernels select Q/K/V types from a single " + "data_type token; fp8bf16 means FP8 Q/K/V with BF16 output. Use AITER " + "paged_attention_ragged or another fallback for BF16/FP16 Q with FP8 KV until " + "CK Tile has mixed Q/KV kernel instances." + ) + + def cpu_attention_fwd( Q: np.ndarray, K: np.ndarray, @@ -811,6 +872,8 @@ def run( Returns: FmhaResult with output array, timing, TFLOPS """ + _validate_batch_prefill_input_dtypes(api_family, data_type, Q, K, V) + # Map CK dtype to numpy dtype for buffer allocation. # bf16 is stored as uint16 (upper 16 bits of float32). # fp8 uses uint8 (1 byte per element). diff --git a/dispatcher/tests/test_fmha_rules.py b/dispatcher/tests/test_fmha_rules.py index b2bcd99c09..29b316b978 100644 --- a/dispatcher/tests/test_fmha_rules.py +++ b/dispatcher/tests/test_fmha_rules.py @@ -133,6 +133,44 @@ def test_batch_prefill_valid_group(self): r = validate_config(cfg, SPECS) self.assertTrue(r.valid, r.errors) + def test_batch_prefill_rejects_mixed_activation_fp8_kv_contract(self): + cfg = _base_config( + family="batch_prefill", + dtype="bf16", + pipeline="qr_async", + mode="group", + paged_kv=True, + page_size=16, + q_data_type="bf16", + kv_data_type="fp8", + o_data_type="bf16", + ) + + r = validate_config(cfg, SPECS) + + self.assertFalse(r.valid) + self.assertTrue( + any("mixed activation/FP8-KV dtype contract" in e for e in r.errors), + r.errors, + ) + + def test_batch_prefill_keeps_all_fp8_bf16_output_contract_valid(self): + cfg = _base_config( + family="batch_prefill", + dtype="fp8bf16", + pipeline="qr_async", + mode="group", + paged_kv=True, + page_size=16, + q_data_type="fp8", + kv_data_type="fp8", + o_data_type="bf16", + ) + + r = validate_config(cfg, SPECS) + + self.assertTrue(r.valid, r.errors) + def test_splitkv_combine_bn1_must_be_32(self): cfg = _base_config(family="fwd_splitkv_combine", pipeline="qr") cfg["algorithm"]["tile"][3] = 64 diff --git a/dispatcher/tests/test_fmha_utils.py b/dispatcher/tests/test_fmha_utils.py new file mode 100644 index 0000000000..622a487cc7 --- /dev/null +++ b/dispatcher/tests/test_fmha_utils.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +import sys +import unittest +from pathlib import Path + +import numpy as np + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT / "python")) + +from fmha_dtype_contract import FmhaDTypeContractKind # noqa: E402 +from fmha_utils import ( # noqa: E402 + _validate_batch_prefill_input_dtypes, + get_batch_prefill_dtype_contract, +) + + +class TestBatchPrefillDtypeValidation(unittest.TestCase): + def test_mixed_bf16_q_fp8_kv_gqa_decode_reports_unsupported(self): + batch = 4 + q_len = 1 + ctx_len = 1024 + num_q_heads = 96 + num_kv_heads = 8 + head_dim = 128 + + q = np.zeros((batch, num_q_heads, q_len, head_dim), dtype=np.uint16) + k = np.zeros((batch, num_kv_heads, ctx_len, head_dim), dtype=np.uint8) + v = np.zeros((batch, num_kv_heads, ctx_len, head_dim), dtype=np.uint8) + + with self.assertRaisesRegex( + ValueError, + "mixed activation/FP8-KV dtype contract", + ): + _validate_batch_prefill_input_dtypes("batch_prefill", "bf16", q, k, v) + + def test_mixed_fp16_q_fp8_kv_reports_unsupported(self): + q = np.zeros((4, 96, 1, 128), dtype=np.float16) + k = np.zeros((4, 8, 1024, 128), dtype=np.uint8) + v = np.zeros((4, 8, 1024, 128), dtype=np.uint8) + + contract = get_batch_prefill_dtype_contract("fp16", q, k, v) + + self.assertEqual(contract.kind, FmhaDTypeContractKind.MIXED_Q_FP8_KV) + with self.assertRaisesRegex(ValueError, "AITER paged_attention_ragged"): + _validate_batch_prefill_input_dtypes("batch_prefill", "fp16", q, k, v) + + def test_all_fp8_bf16_output_path_remains_allowed(self): + q = np.zeros((1, 96, 1, 128), dtype=np.uint8) + k = np.zeros((1, 8, 128, 128), dtype=np.uint8) + v = np.zeros((1, 8, 128, 128), dtype=np.uint8) + + contract = get_batch_prefill_dtype_contract("fp8bf16", q, k, v) + + self.assertEqual(contract.kind, FmhaDTypeContractKind.ALL_FP8_WITH_BF16_OUTPUT) + _validate_batch_prefill_input_dtypes("batch_prefill", "fp8bf16", q, k, v) + + def test_all_bf16_batch_prefill_path_remains_allowed(self): + q = np.zeros((1, 96, 1, 128), dtype=np.uint16) + k = np.zeros((1, 8, 128, 128), dtype=np.uint16) + v = np.zeros((1, 8, 128, 128), dtype=np.uint16) + + contract = get_batch_prefill_dtype_contract("bf16", q, k, v) + + self.assertEqual(contract.kind, FmhaDTypeContractKind.HOMOGENEOUS) + _validate_batch_prefill_input_dtypes("batch_prefill", "bf16", q, k, v) + + def test_non_batch_prefill_paths_are_unchanged(self): + q = np.zeros((1, 96, 1, 128), dtype=np.uint16) + k = np.zeros((1, 8, 128, 128), dtype=np.uint8) + v = np.zeros((1, 8, 128, 128), dtype=np.uint8) + + _validate_batch_prefill_input_dtypes("fwd", "bf16", q, k, v) + + +if __name__ == "__main__": + unittest.main()