From 3e0d0c31d878b922bf693055cb9aca669c0a19e2 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 12 Sep 2025 04:08:12 +0000 Subject: [PATCH 1/3] feat: sage attention support --- CMakeLists.txt | 77 +- aphrodite/engine/args_tools.py | 2 + aphrodite/platforms/cuda.py | 4 + aphrodite/platforms/interface.py | 1 + aphrodite/v1/attention/backends/sage_attn.py | 466 ++++++ kernels/attention/sage_attn/cp_async.cuh | 141 ++ kernels/attention/sage_attn/dispatch_utils.h | 112 ++ kernels/attention/sage_attn/fused/fused.cu | 1083 +++++++++++++ kernels/attention/sage_attn/math.cuh | 155 ++ kernels/attention/sage_attn/mma.cuh | 722 +++++++++ .../sage_attn/numeric_conversion.cuh | 149 ++ kernels/attention/sage_attn/permuted_smem.cuh | 196 +++ .../attention/sage_attn/qattn/attn_utils.cuh | 992 ++++++++++++ .../qattn/qk_int_sv_f16_cuda_sm80.cu | 1380 +++++++++++++++++ .../qattn/qk_int_sv_f8_cuda_sm89.cuh | 710 +++++++++ .../sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu | 916 +++++++++++ ...9_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu | 180 +++ ...f8_accum_f16_fuse_v_scale_attn_inst_buf.cu | 187 +++ .../sm89_qk_int8_sv_f8_accum_f32_attn.cu | 180 +++ ...9_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu | 179 +++ ..._int8_sv_f8_accum_f32_fuse_v_scale_attn.cu | 187 +++ ...f8_accum_f32_fuse_v_scale_attn_inst_buf.cu | 187 +++ ...accum_f32_fuse_v_scale_fuse_v_mean_attn.cu | 192 +++ .../attention/sage_attn/reduction_utils.cuh | 194 +++ kernels/attention/sage_attn/utils.cuh | 38 + kernels/attention/sage_attn/wgmma.cuh | 300 ++++ kernels/ops.h | 231 ++- kernels/torch_bindings.cpp | 271 ++++ tools/generate_torch_registration.py | 407 +++++ 29 files changed, 9835 insertions(+), 4 deletions(-) create mode 100644 aphrodite/v1/attention/backends/sage_attn.py create mode 100644 kernels/attention/sage_attn/cp_async.cuh create mode 100644 kernels/attention/sage_attn/dispatch_utils.h create mode 100644 kernels/attention/sage_attn/fused/fused.cu create mode 100644 kernels/attention/sage_attn/math.cuh create mode 100644 kernels/attention/sage_attn/mma.cuh create mode 100644 kernels/attention/sage_attn/numeric_conversion.cuh create mode 100644 kernels/attention/sage_attn/permuted_smem.cuh create mode 100644 kernels/attention/sage_attn/qattn/attn_utils.cuh create mode 100644 kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu create mode 100644 kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm89.cuh create mode 100644 kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu create mode 100644 kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu create mode 100644 kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu create mode 100644 kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu create mode 100644 kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu create mode 100644 kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu create mode 100644 kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu create mode 100644 kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu create mode 100644 kernels/attention/sage_attn/reduction_utils.cuh create mode 100644 kernels/attention/sage_attn/utils.cuh create mode 100644 kernels/attention/sage_attn/wgmma.cuh create mode 100755 tools/generate_torch_registration.py diff --git a/CMakeLists.txt b/CMakeLists.txt index c555a5ed44..9a0dd3dd50 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -774,7 +774,7 @@ set(APHRODITE_EXT_SRC SRCS "${SRCS}" CUDA_ARCHS "${W4A8_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND APHRODITE_EXT_SRC "${SRCS}") message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}") else() @@ -790,6 +790,77 @@ set(APHRODITE_EXT_SRC endif() endif() + # + # SageAttention kernels + # + + # Only build SageAttention kernels if we are building for at least SM 8.0 compatible archs + cuda_archs_loose_intersection(SAGE_ATTN_ARCHS "8.0;8.6;8.7;8.9;9.0+PTX" "${CUDA_ARCHS}") + if (SAGE_ATTN_ARCHS) + + # Base SageAttention sources (always included) + set(SAGE_ATTN_BASE_SRCS + "kernels/attention/sage_attn/fused/fused.cu") + + # SM 8.0 specific kernels + cuda_archs_loose_intersection(SAGE_ATTN_SM80_ARCHS "8.0;8.6;8.7" "${CUDA_ARCHS}") + set(SAGE_ATTN_SM80_SRCS) + if (SAGE_ATTN_SM80_ARCHS) + list(APPEND SAGE_ATTN_SM80_SRCS + "kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu") + set_gencode_flags_for_srcs( + SRCS "${SAGE_ATTN_SM80_SRCS}" + CUDA_ARCHS "${SAGE_ATTN_SM80_ARCHS}") + message(STATUS "Building SageAttention SM80 kernels for archs: ${SAGE_ATTN_SM80_ARCHS}") + endif() + + # SM 8.9 specific kernels + cuda_archs_loose_intersection(SAGE_ATTN_SM89_ARCHS "8.9" "${CUDA_ARCHS}") + set(SAGE_ATTN_SM89_SRCS) + if (SAGE_ATTN_SM89_ARCHS) + list(APPEND SAGE_ATTN_SM89_SRCS + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu") + set_gencode_flags_for_srcs( + SRCS "${SAGE_ATTN_SM89_SRCS}" + CUDA_ARCHS "${SAGE_ATTN_SM89_ARCHS}") + message(STATUS "Building SageAttention SM89 kernels for archs: ${SAGE_ATTN_SM89_ARCHS}") + endif() + + # SM 9.0 specific kernels + cuda_archs_loose_intersection(SAGE_ATTN_SM90_ARCHS "9.0+PTX" "${CUDA_ARCHS}") + set(SAGE_ATTN_SM90_SRCS) + if (SAGE_ATTN_SM90_ARCHS) + list(APPEND SAGE_ATTN_SM90_SRCS + "kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu") + set_gencode_flags_for_srcs( + SRCS "${SAGE_ATTN_SM90_SRCS}" + CUDA_ARCHS "${SAGE_ATTN_SM90_ARCHS}") + message(STATUS "Building SageAttention SM90 kernels for archs: ${SAGE_ATTN_SM90_ARCHS}") + endif() + + set(SAGE_ATTN_SRCS ${SAGE_ATTN_BASE_SRCS}) + list(APPEND SAGE_ATTN_SRCS ${SAGE_ATTN_SM80_SRCS}) + list(APPEND SAGE_ATTN_SRCS ${SAGE_ATTN_SM89_SRCS}) + list(APPEND SAGE_ATTN_SRCS ${SAGE_ATTN_SM90_SRCS}) + + set_gencode_flags_for_srcs( + SRCS "${SAGE_ATTN_BASE_SRCS}" + CUDA_ARCHS "${SAGE_ATTN_ARCHS}") + + list(APPEND APHRODITE_EXT_SRC "${SAGE_ATTN_SRCS}") + + message(STATUS "Building SageAttention kernels for archs: ${SAGE_ATTN_ARCHS}") + else() + message(STATUS "Not building SageAttention kernels as no compatible archs found" + " in CUDA target architectures (requires SM 8.0 or above)") + endif() + # if CUDA endif endif() @@ -953,10 +1024,10 @@ if (APHRODITE_GPU_LANG STREQUAL "CUDA") include(cmake/external_project/flashmla.cmake) # Only build flash attention if not disabled - if (NOT DEFINED ENV{APHRODITE_DISABLE_FLASH_ATTN} OR NOT $ENV{APHRODITE_DISABLE_FLASH_ATTN}) + if (NOT DEFINED ENV{APHRODITE_DISABLE_FLASH_ATTN_COMPILE} OR NOT $ENV{APHRODITE_DISABLE_FLASH_ATTN_COMPILE}) # vllm-flash-attn should be last as it overwrites some CMake functions include(cmake/external_project/vllm_flash_attn.cmake) else() - message(STATUS "Flash attention compilation disabled by APHRODITE_DISABLE_FLASH_ATTN") + message(STATUS "Flash attention compilation disabled by APHRODITE_DISABLE_FLASH_ATTN_COMPILE") endif() endif() diff --git a/aphrodite/engine/args_tools.py b/aphrodite/engine/args_tools.py index 9a6e232d04..bf36fa57a0 100644 --- a/aphrodite/engine/args_tools.py +++ b/aphrodite/engine/args_tools.py @@ -1560,6 +1560,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "TREE_ATTN", "XFORMERS_APHRODITE_V1", "XFORMERS", + "SAGE_ATTN", + "SAGE_ATTN_APHRODITE_V1", ] if (envs.is_set("APHRODITE_ATTENTION_BACKEND") and envs.APHRODITE_ATTENTION_BACKEND not in V1_BACKENDS): diff --git a/aphrodite/platforms/cuda.py b/aphrodite/platforms/cuda.py index acae3c677e..21727f0533 100644 --- a/aphrodite/platforms/cuda.py +++ b/aphrodite/platforms/cuda.py @@ -265,6 +265,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, FLASH_ATTN_V1 = "aphrodite.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 TREE_ATTN_V1 = "aphrodite.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 XFORMERS_APHRODITE_V1 = "aphrodite.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501 + SAGE_ATTN_V1 = "aphrodite.v1.attention.backends.sage_attn.SageAttentionBackend" # noqa: E501 if selected_backend == _Backend.FLEX_ATTENTION: log_once("INFO", "Using FlexAttention backend on V1 engine.") @@ -281,6 +282,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, elif selected_backend == _Backend.XFORMERS: log_once("INFO", "Using XFormers backend on V1 engine.") return XFORMERS_APHRODITE_V1 + elif selected_backend == _Backend.SAGE_ATTN: + log_once("INFO", "Using SageAttention backend on V1 engine.") + return SAGE_ATTN_V1 from aphrodite.attention.selector import is_attn_backend_supported diff --git a/aphrodite/platforms/interface.py b/aphrodite/platforms/interface.py index da3a8f7392..a5fd155ec6 100644 --- a/aphrodite/platforms/interface.py +++ b/aphrodite/platforms/interface.py @@ -60,6 +60,7 @@ class _Backend(enum.Enum): FLEX_ATTENTION = enum.auto() TREE_ATTN = enum.auto() XFORMERS_APHRODITE_V1 = enum.auto() + SAGE_ATTN = enum.auto() class PlatformEnum(enum.Enum): diff --git a/aphrodite/v1/attention/backends/sage_attn.py b/aphrodite/v1/attention/backends/sage_attn.py new file mode 100644 index 0000000000..4865fd70b4 --- /dev/null +++ b/aphrodite/v1/attention/backends/sage_attn.py @@ -0,0 +1,466 @@ +"""Attention layer with SageAttention.""" +from dataclasses import dataclass +from typing import Optional + +import torch +from sageattention import sageattn + +from aphrodite.attention.backends.abstract import (AttentionBackend, + AttentionImpl, + AttentionType) +from aphrodite.config import AphroditeConfig +from aphrodite.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) +from aphrodite.v1.kv_cache_interface import AttentionSpec + + +@dataclass +class SageAttentionMetadata: + """Metadata for SageAttentionBackend.""" + num_actual_tokens: int # Number of tokens excluding padding. + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + causal: bool = True + + +class SageAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = False + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + # SageAttention doesn't support head sizes larger than 128 + return [64, 96, 128] + + @classmethod + def validate_head_size(cls, head_size: int) -> None: + supported_head_sizes = cls.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by SageAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + + @staticmethod + def get_name() -> str: + return "SAGE_ATTN_APHRODITE_V1" + + @staticmethod + def get_impl_cls() -> type["SageAttentionImpl"]: + return SageAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["SageAttentionMetadata"]: + return SageAttentionMetadata + + @staticmethod + def get_builder_cls() -> type["SageAttentionMetadataBuilder"]: + return SageAttentionMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, ...]: + # Standard cache shape for paged attention + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_kv_cache_stride_order() -> tuple[int, ...]: + # Standard stride order + return (0, 1, 2, 3, 4) + + +class SageAttentionMetadataBuilder( + AttentionMetadataBuilder[SageAttentionMetadata]): + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + aphrodite_config: AphroditeConfig, device: torch.device): + self.device = device + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> SageAttentionMetadata: + """Build attention metadata.""" + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + max_seq_len = common_attn_metadata.max_seq_len + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + causal = common_attn_metadata.causal + + return SageAttentionMetadata( + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table_tensor, + slot_mapping=slot_mapping, + causal=causal) + + def use_cascade_attention(self, *args, **kwargs) -> bool: + # SageAttention doesn't support cascade attention + return False + + +class SageAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + ) -> None: + if logits_soft_cap is not None: + raise ValueError("SageAttention does not support logits soft cap.") + if kv_sharing_target_layer_name is not None: + raise NotImplementedError( + "KV sharing is not supported in SageAttention backend.") + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + self.attn_type = attn_type + + SageAttentionBackend.validate_head_size(head_size) + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if kv_cache_dtype != "auto": + raise NotImplementedError( + "SageAttention backend does not support quantized KV cache.") + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: SageAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with SageAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache: shape = [2, num_blocks, block_size, num_kv_heads, + head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "Fused output quantization is not supported for " + "SageAttention") + + if attn_metadata is None: + # Profiling run. + return query # Just return query for profiling + + # Create output tensor since accept_output_buffer = False + num_tokens = query.shape[0] + output = torch.empty(num_tokens, self.num_heads * self.head_size, + dtype=query.dtype, device=query.device) + + # Handle encoder attention differently - no KV cache needed + if self.attn_type in (AttentionType.ENCODER_ONLY, + AttentionType.ENCODER): + result = self._forward_encoder_attention( + query[:attn_metadata.num_actual_tokens], + key[:attn_metadata.num_actual_tokens], + value[:attn_metadata.num_actual_tokens], + attn_metadata) + output[:attn_metadata.num_actual_tokens] = result + return output + + num_actual_tokens = attn_metadata.num_actual_tokens + + # For decoder and cross-attention, use KV cache + if kv_cache.numel() > 0: + key_cache, value_cache = kv_cache.unbind(0) + + # Update KV cache if we have new key/value pairs + if key is not None and value is not None: + self._reshape_and_cache( + key, value, key_cache, value_cache, + attn_metadata.slot_mapping) + + # Use cached keys and values for attention computation + result = self._forward_with_kv_cache( + query[:num_actual_tokens], + key_cache, + value_cache, + attn_metadata) + output[:num_actual_tokens] = result + return output + else: + # Direct attention computation (prefill without cache) + result = self._forward_direct_attention( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + attn_metadata) + output[:num_actual_tokens] = result + return output + + def _reshape_and_cache( + self, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> None: + """Reshape and cache key/value tensors.""" + # Flatten slot mapping and use it to update cache + flat_slot_mapping = slot_mapping.flatten() + + # Get the number of tokens to cache + num_tokens = key.shape[0] + + # Reshape key and value to match cache format + key_to_cache = key.view(num_tokens, self.num_kv_heads, + self.head_size) + value_to_cache = value.view(num_tokens, self.num_kv_heads, + self.head_size) + + # Update the cache using slot mapping + for i, slot_idx in enumerate(flat_slot_mapping[:num_tokens]): + if slot_idx >= 0: # Valid slot + block_idx = slot_idx // key_cache.shape[1] + block_offset = slot_idx % key_cache.shape[1] + key_cache[block_idx, block_offset, :, :] = key_to_cache[i] + value_cache[block_idx, block_offset, :, :] = value_to_cache[i] + + def _forward_encoder_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: SageAttentionMetadata, + ) -> torch.Tensor: + """Forward pass for encoder attention without KV cache.""" + # For encoder attention, we use direct Q, K, V tensors + return self._compute_sage_attention( + query, key, value, attn_metadata, is_causal=False) + + def _forward_direct_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: SageAttentionMetadata, + ) -> torch.Tensor: + """Forward pass with direct Q, K, V tensors (prefill).""" + return self._compute_sage_attention( + query, key, value, attn_metadata, + is_causal=attn_metadata.causal) + + def _forward_with_kv_cache( + self, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + attn_metadata: SageAttentionMetadata, + ) -> torch.Tensor: + """Forward pass using KV cache (decode or prefill with cache).""" + batch_size = attn_metadata.query_start_loc.shape[0] - 1 + + max_seq_len = attn_metadata.max_seq_len + key_list = [] + value_list = [] + + for i in range(batch_size): + seq_len = attn_metadata.seq_lens[i].item() + block_table = attn_metadata.block_table[i] + + seq_key = self._extract_from_cache(key_cache, block_table, + seq_len) + seq_value = self._extract_from_cache(value_cache, block_table, + seq_len) + + key_list.append(seq_key) + value_list.append(seq_value) + + key = self._pad_and_concat(key_list, max_seq_len) + value = self._pad_and_concat(value_list, max_seq_len) + + return self._compute_sage_attention( + query, key, value, attn_metadata, + is_causal=attn_metadata.causal) + + def _extract_from_cache( + self, + cache: torch.Tensor, + block_table: torch.Tensor, + seq_len: int, + ) -> torch.Tensor: + """Extract a sequence from the paged cache.""" + block_size = cache.shape[1] + num_blocks = (seq_len + block_size - 1) // block_size + + extracted = [] + for block_idx in range(num_blocks): + block_id = block_table[block_idx].item() + if block_idx == num_blocks - 1: + # Last block might be partial + end_offset = seq_len % block_size + if end_offset == 0: + end_offset = block_size + extracted.append(cache[block_id, :end_offset]) + else: + extracted.append(cache[block_id]) + + return torch.cat(extracted, dim=0) + + def _pad_and_concat( + self, + tensor_list: list[torch.Tensor], + max_len: int, + ) -> torch.Tensor: + """Pad tensors to max_len and concatenate.""" + padded = [] + for tensor in tensor_list: + if tensor.shape[0] < max_len: + padding = torch.zeros( + max_len - tensor.shape[0], + *tensor.shape[1:], + dtype=tensor.dtype, + device=tensor.device) + padded.append(torch.cat([tensor, padding], dim=0)) + else: + padded.append(tensor[:max_len]) + + return torch.cat(padded, dim=0) + + def _compute_sage_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: SageAttentionMetadata, + is_causal: bool, + ) -> torch.Tensor: + """Compute attention using SageAttention.""" + # NOTE: We'll handle GQA expansion after reshaping to avoid dimension + # issues + + # NOTE: SageAttention doesn't support ALiBi or sliding window in basic + # API + if self.alibi_slopes is not None: + raise NotImplementedError( + "ALiBi slopes are not supported in SageAttention backend") + if self.sliding_window is not None: + raise NotImplementedError( + "Sliding window is not supported in SageAttention backend") + + num_tokens = query.shape[0] + output = torch.empty(num_tokens, self.num_heads * self.head_size, + dtype=query.dtype, device=query.device) + + # Process each sequence separately + batch_size = attn_metadata.query_start_loc.shape[0] - 1 + + for i in range(batch_size): + query_start = attn_metadata.query_start_loc[i].item() + query_end = attn_metadata.query_start_loc[i + 1].item() + query_len = query_end - query_start + seq_len = attn_metadata.seq_lens[i].item() + + seq_query = query[ + query_start:query_end] # [query_len, num_heads * head_size] + seq_key = key[ + i * attn_metadata.max_seq_len:i * + attn_metadata.max_seq_len + seq_len + ] # [seq_len, num_kv_heads * head_size] + seq_value = value[ + i * attn_metadata.max_seq_len:i * + attn_metadata.max_seq_len + + seq_len] # [seq_len, num_kv_heads * head_size] + + # First reshape to separate heads and head_size + query_len = query_end - query_start + seq_query = seq_query.view( + query_len, self.num_heads, + self.head_size) # [query_len, num_heads, head_size] + seq_key = seq_key.view( + seq_len, self.num_kv_heads, + self.head_size) # [seq_len, num_kv_heads, head_size] + seq_value = seq_value.view( + seq_len, self.num_kv_heads, + self.head_size) # [seq_len, num_kv_heads, head_size] + + # Expand KV heads if needed (for GQA) + if self.num_kv_heads != self.num_heads: + seq_key = seq_key.repeat_interleave( + self.num_queries_per_kv, + dim=1) # [seq_len, num_heads, head_size] + seq_value = seq_value.repeat_interleave( + self.num_queries_per_kv, + dim=1) # [seq_len, num_heads, head_size] + + # Reshape for SageAttention + # SageAttention expects + # [batch_size, num_heads, seq_len, head_size] (HND layout) + seq_query = seq_query.unsqueeze( + 0).transpose(1, 2) # [1, num_heads, query_len, head_size] + seq_key = seq_key.unsqueeze( + 0).transpose(1, 2) # [1, num_heads, seq_len, head_size] + seq_value = seq_value.unsqueeze( + 0).transpose(1, 2) # [1, num_heads, seq_len, head_size] + + # Compute attention using SageAttention + # For decode phase (query_len=1), we can't use causal since + # qo_len != kv_len + use_causal = is_causal and (query_len == seq_len) + + seq_output = sageattn( + q=seq_query, + k=seq_key, + v=seq_value, + tensor_layout="HND", + is_causal=use_causal, + sm_scale=self.scale) + + # Convert back to [query_len, num_heads, head_size] then flatten + # to [query_len, num_heads * head_size] + seq_output = seq_output.squeeze( + 0).transpose(0, 1) # [query_len, num_heads, head_size] + seq_output_flat = seq_output.contiguous().view( + seq_output.shape[0], -1) # [query_len, num_heads * head_size] + output[query_start:query_end] = seq_output_flat + + return output diff --git a/kernels/attention/sage_attn/cp_async.cuh b/kernels/attention/sage_attn/cp_async.cuh new file mode 100644 index 0000000000..c877c16e20 --- /dev/null +++ b/kernels/attention/sage_attn/cp_async.cuh @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * This file is based on code from Flashinfer, https://github.com/flashinfer-ai/flashinfer/blob/v0.1.5/include/flashinfer/cp_async.cuh + * Copyright (c) 2023 by FlashInfer team. + * Small modifications made by SageAttention team, 2024 (e.g., renamed namespace). + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace cp_async { + +enum class SharedMemFillMode { + kFillZero, // Fill zero to shared memory when predicate is false + kNoFill // Do not fill zero to shared memory when predicate is false +}; + +enum class PrefetchMode { + kNoPrefetch, // Do not fetch additional data from global memory to L2 + kPrefetch // Fetch additional data from global memory to L2 +}; + +#if (__CUDACC_VER_MAJOR__ >= 11) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) +#define CP_ASYNC_ENABLED +#endif +#endif + +/*! + * \brief Wrapper of PTX cp.async.commit_group instruction, commit all prior uncommitted + * cp.async instructions to a group + */ +__device__ __forceinline__ void commit_group() { +#ifdef CP_ASYNC_ENABLED + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +/*! + * \brief Wrapper of PTX cp.async.wait_group instruction + * \tparam n Wait till most recent n groups are committed + */ +template +__device__ __forceinline__ void wait_group() { +#ifdef CP_ASYNC_ENABLED + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +#endif +} + +/*! + * \brief Wrapper of PTX cp.async.cg.shared.global instruction, asynchronously copy data from + * global memory to shared memory + * \tparam prefetch_mode Whether to fetch additional data from global memory to L2 + * \tparam T Data type + * \param smem_ptr Pointer to shared memory + * \param gmem_ptr Pointer to global memory + */ +template +__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { +#ifdef CP_ASYNC_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(16)); + } else { + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(16)); + } +#else + *((uint4*)smem_ptr) = *((uint4*)gmem_ptr); +#endif +} + +/*! + * \brief Wrapper of PTX cp.async.cg.shared.global instruction, asynchronously copy data from + * global memory to shared memory with predicate. + * \tparam prefetch_mode Whether to fetch additional data from global memory to L2 + * \tparam fill_mode Whether to fill zero to shared memory when predicate is false + * \tparam T Data type + * \param smem_ptr Pointer to shared memory + * \param gmem_ptr Pointer to global memory + * \param predicate Predicate value + * \note fill zero is slower than not fill zero + */ +template +__device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, bool predicate) { +#ifdef CP_ASYNC_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 16 : 0; + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(src_in_bytes)); + } else { + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(src_in_bytes)); + } + } else { + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16)); + } + } +#else + if (predicate) { + *((uint4*)smem_ptr) = *((uint4*)gmem_ptr); + } else { + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + *((uint4*)smem_ptr) = make_uint4(0, 0, 0, 0); + } + } +#endif +} + +} // namespace cp_async \ No newline at end of file diff --git a/kernels/attention/sage_attn/dispatch_utils.h b/kernels/attention/sage_attn/dispatch_utils.h new file mode 100644 index 0000000000..04fceaa6f8 --- /dev/null +++ b/kernels/attention/sage_attn/dispatch_utils.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include + +#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + if (head_dim == 64) { \ + constexpr int HEAD_DIM = 64; \ + __VA_ARGS__ \ + } else if (head_dim == 128) { \ + constexpr int HEAD_DIM = 128; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported head dim: " << int(head_dim); \ + throw std::invalid_argument(err_msg.str()); \ + } + +#define DISPATCH_CAUSAL(is_causal, IS_CAUSAL, ...) \ + if (is_causal == 1) { \ + constexpr bool IS_CAUSAL = true; \ + __VA_ARGS__ \ + } else if (is_causal == 0) { \ + constexpr bool IS_CAUSAL = false; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported causal mode: " << int(is_causal); \ + throw std::invalid_argument(err_msg.str()); \ + } + +#define DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, ...) \ + if (qk_quant_gran == 2) { \ + constexpr int QK_QUANT_GRAN = 2; \ + __VA_ARGS__ \ + } else if (qk_quant_gran == 3) { \ + constexpr int QK_QUANT_GRAN = 3; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported qk_quant_gran: " << int(qk_quant_gran); \ + throw std::invalid_argument(err_msg.str()); \ + } + +#define DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, ...) \ + if (return_lse == 1) { \ + constexpr bool RETURN_LSE = true; \ + __VA_ARGS__ \ + } else if (return_lse == 0) { \ + constexpr bool RETURN_LSE = false; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported causal mode: " << int(return_lse); \ + throw std::invalid_argument(err_msg.str()); \ + } + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ + if (pytorch_dtype == at::ScalarType::Half) { \ + using c_type = half; \ + __VA_ARGS__ \ + } else if (pytorch_dtype == at::ScalarType::BFloat16) { \ + using c_type = nv_bfloat16; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + } + +#define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...) \ + if (block_size == 64) { \ + constexpr int BLOCK_SIZE = 64; \ + __VA_ARGS__ \ + } else if (block_size == 128) { \ + constexpr int BLOCK_SIZE = 128; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported block_size " << int(block_size); \ + throw std::invalid_argument(err_msg.str()); \ + } + +#define DISPATCH_WARP_BLOCK_SIZE(warp_block_size, WARP_BLOCK_SIZE, ...) \ + if (warp_block_size == 16) { \ + constexpr int WARP_BLOCK_SIZE = 16; \ + __VA_ARGS__ \ + } else if (warp_block_size == 32) { \ + constexpr int WARP_BLOCK_SIZE = 32; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported warp_block_size " << int(warp_block_size); \ + throw std::invalid_argument(err_msg.str()); \ + } diff --git a/kernels/attention/sage_attn/fused/fused.cu b/kernels/attention/sage_attn/fused/fused.cu new file mode 100644 index 0000000000..d9265e769a --- /dev/null +++ b/kernels/attention/sage_attn/fused/fused.cu @@ -0,0 +1,1083 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "../dispatch_utils.h" +#include "../utils.cuh" +#include "../reduction_utils.cuh" +#include "../numeric_conversion.cuh" +#include "../cp_async.cuh" +#include +#include + +enum class QuantType +{ + kInt8, + kInt4, +}; + +template +__device__ __forceinline__ float convert_to_float(T val) +{ + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + if constexpr (std::is_same::value) + { + return __half2float(val); + } + else if constexpr (std::is_same::value) + { + return __bfloat162float(val); + } +} + +template +__device__ __forceinline__ T convert_from_float(float val) +{ + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + if constexpr (std::is_same::value) + { + return __float2half_rn(val); + } + else if constexpr (std::is_same::value) + { + return __float2bfloat16_rn(val); + } +} + +template +__global__ void QuantInt8Kernel(T *__restrict__ input, T *__restrict__ mean, int8_t *__restrict__ output, float *__restrict__ scale, float sm_scale, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_seq_input, const uint32_t stride_h_input, + const uint32_t stride_bz_mean, const uint32_t stride_h_mean, + const uint32_t stride_bz_output, const uint32_t stride_seq_output, const uint32_t stride_h_output, + const uint32_t stride_bz_scale, const uint32_t stride_h_scale) +{ + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + static_assert(num_pack_per_thread > 0, "The number of pack per thread must be greater than 0"); + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + constexpr uint32_t num_threads_per_token = head_dim / pack_size; + + static_assert(num_threads_per_token <= 32, "The number of threads per token must be less than or equal to warp size"); + + T x_val[num_pack_per_thread][8]; + T mean_val[8]; + float x_val_float[num_pack_per_thread][8]; + float mean_val_float[8]; + + uint32_t bx = blockIdx.x; + uint32_t head_id = blockIdx.y; + uint32_t batch_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t thread_base_token = bx * BLOCK_SIZE + thread_id / num_threads_per_token; + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + thread_base_token * stride_seq_input + thread_id % num_threads_per_token * pack_size; + T *mean_ptr_base = mean + batch_id * stride_bz_mean + head_id * stride_h_mean + thread_id % num_threads_per_token * pack_size; + int8_t *output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + thread_base_token * stride_seq_output + thread_id % num_threads_per_token * pack_size; + float *scale_ptr_base = scale + batch_id * stride_bz_scale + head_id * stride_h_scale + bx; + + if constexpr (sub_mean) + { + *(float4*)(&mean_val[0]) = *(float4*)(mean_ptr_base); +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + mean_val_float[j] = convert_to_float(mean_val[j]); + } + } + + constexpr uint32_t iter_stride = BLOCK_SIZE / num_pack_per_thread; + + // load the data + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { + if (thread_base_token + i * iter_stride < num_tokens) + { + *(float4*)(&x_val[i][0]) = *(float4*)(input_ptr_base + i * iter_stride * stride_seq_input); +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[i][j] = convert_to_float(x_val[i][j]); + } + + if constexpr (sub_mean) + { +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[i][j] -= mean_val_float[j]; + } + } + + if constexpr (has_sm_scale) + { +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[i][j] *= sm_scale; + } + } + } + else + { +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[i][j] = 0.0f; + } + } + } + + float amax_val = 0.0000001f; // prevent from dividing by zero + +#pragma unroll + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + amax_val = fmaxf(amax_val, fabsf(x_val_float[i][j])); + } + } + + __shared__ float s_amax; + const float block_amax_val = aphrodite::blockReduceMax(amax_val); + if (thread_id == 0) + { + s_amax = block_amax_val; + scale_ptr_base[0] = s_amax / 127.0f; + } + + __syncthreads(); + + float tmp_scale = 127.0f / s_amax; + + char4 o_val[num_pack_per_thread][2]; + +#pragma unroll + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { +#pragma unroll + for (uint32_t j = 0; j < 2; j += 1) + { + o_val[i][j] = make_char4( + float_to_int8_rn(x_val_float[i][j * 4 + 0] * tmp_scale), + float_to_int8_rn(x_val_float[i][j * 4 + 1] * tmp_scale), + float_to_int8_rn(x_val_float[i][j * 4 + 2] * tmp_scale), + float_to_int8_rn(x_val_float[i][j * 4 + 3] * tmp_scale) + ); + } + } + + // int8 result +#pragma unroll + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { + + if (thread_base_token + i * iter_stride < num_tokens) + { + *reinterpret_cast(output_ptr_base + i * iter_stride * stride_seq_output) = *reinterpret_cast(&o_val[i][0]); + } + } +} + +template +__global__ void SubMeanKernel(T *__restrict__ input, T *__restrict__ mean, half *__restrict__ output, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_seq_input, const uint32_t stride_h_input, + const uint32_t stride_bz_mean, const uint32_t stride_h_mean, + const uint32_t stride_bz_output, const uint32_t stride_seq_output, const uint32_t stride_h_output) +{ + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + static_assert(num_pack_per_thread > 0, "The number of pack per thread must be greater than 0"); + + using T2 = typename std::conditional::value, half2, nv_bfloat162>::type; + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + constexpr uint32_t num_threads_per_token = head_dim / pack_size; + + static_assert(num_threads_per_token <= 32, "The number of threads per token must be less than or equal to warp size"); + + T2 x_val[num_pack_per_thread][4]; + T2 mean_val[4]; + + uint32_t bx = blockIdx.x; + uint32_t head_id = blockIdx.y; + uint32_t batch_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t thread_base_token = bx * BLOCK_SIZE + thread_id / num_threads_per_token; + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + thread_base_token * stride_seq_input + thread_id % num_threads_per_token * pack_size; + T *mean_ptr_base = mean + batch_id * stride_bz_mean + head_id * stride_h_mean + thread_id % num_threads_per_token * pack_size; + half *output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + thread_base_token * stride_seq_output + thread_id % num_threads_per_token * pack_size; + + *(float4*)(&mean_val[0]) = *(float4*)(mean_ptr_base); + + constexpr uint32_t iter_stride = BLOCK_SIZE / num_pack_per_thread; + + // load the data + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { + if (thread_base_token + i * iter_stride < num_tokens) + { + *(float4*)(&x_val[i][0]) = *(float4*)(input_ptr_base + i * iter_stride * stride_seq_input); +#pragma unroll + for (uint32_t j = 0; j < 4; j++) + { + x_val[i][j] = __hsub2(x_val[i][j], mean_val[j]); + + if constexpr (std::is_same::value) + { + ((half2*)x_val[i])[j] = __float22half2_rn(__bfloat1622float2(x_val[i][j])); + } + } + } + } + +#pragma unroll + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { + if (thread_base_token + i * iter_stride < num_tokens) + { + *reinterpret_cast(output_ptr_base + i * iter_stride * stride_seq_output) = *reinterpret_cast(&x_val[i][0]); + } + } +} + +template +__global__ void TransposePadPermuteKernel(T *__restrict__ input, T *__restrict__ output, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_seq_input, const uint32_t stride_h_input, + const uint32_t stride_bz_output, const uint32_t stride_d_output, const uint32_t stride_h_output) +{ + + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + uint32_t num_threads_per_token = head_dim / pack_size; + uint32_t num_threads_per_cta = CTA_SIZE / pack_size; + + uint32_t bx = blockIdx.x; + uint32_t head_id = blockIdx.y; + uint32_t batch_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t thread_base_token = bx * CTA_SIZE + thread_id / num_threads_per_token; + + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + thread_base_token * stride_seq_input + thread_id % num_threads_per_token * pack_size; + T* output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + bx * CTA_SIZE + thread_id % num_threads_per_cta * pack_size + thread_id / num_threads_per_cta * stride_d_output; + + __shared__ T shared_load[CTA_SIZE][head_dim]; + __shared__ T shared_store[head_dim][CTA_SIZE]; + + // 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15 + // permute on the seq dimension for fp8 mma + uint32_t smem_load_row_base = ((thread_id / num_threads_per_token) / 16) * 16; + uint32_t smem_load_row_mod = (thread_id / num_threads_per_token) % 16; + uint32_t smem_load_row = smem_load_row_base + (smem_load_row_mod / 8) * 2 + ((smem_load_row_mod / 2) % 4) * 4 + (smem_load_row_mod % 2); + + constexpr cp_async::SharedMemFillMode fill_mode = pad_zero ? cp_async::SharedMemFillMode::kFillZero : cp_async::SharedMemFillMode::kNoFill; + cp_async::pred_load_128b(shared_load[smem_load_row] + thread_id % num_threads_per_token * pack_size, input_ptr_base, thread_base_token < num_tokens); + cp_async::commit_group(); + cp_async::wait_group<0>(); + __syncthreads(); + + uint32_t smem_row_base = thread_id % CTA_SIZE; + uint32_t smem_col_base = thread_id / CTA_SIZE; + uint32_t smem_col_stride = head_dim / 8; + + // TODO: use ldmatrix to do permutation +#pragma unroll + for (uint32_t i = 0; i < 8; i++) + { + shared_store[smem_col_base + i * smem_col_stride][smem_row_base] = shared_load[smem_row_base][smem_col_base + i * smem_col_stride]; + } + + __syncthreads(); + + *(float4*)(output_ptr_base) = *(float4*)(&shared_store[thread_id / num_threads_per_cta][thread_id % num_threads_per_cta * pack_size]); +} + + +template +__global__ void MeanScaleKernel(T *__restrict__ input, int8_t *__restrict__ output, float *__restrict__ mean, float *__restrict__ scale, const float scale_max, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_d_input, const uint32_t stride_h_input, + const uint32_t stride_bz_output, const uint32_t stride_d_output, const uint32_t stride_h_output, + const uint32_t stride_bz_mean, const uint32_t stride_h_mean, + const uint32_t stride_bz_scale, const uint32_t stride_h_scale) +{ + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + + uint32_t head_id = blockIdx.x; + uint32_t batch_id = blockIdx.y; + uint32_t d_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t num_threads = blockDim.x; + uint32_t gmem_stride = num_threads * pack_size; + // pad the number of tokens to 16 to deal with fp8 permute in previous kernel + uint32_t fp8_padded_num_tokens = (num_tokens + 15) / 16 * 16; + uint32_t num_iters = fp8_padded_num_tokens / gmem_stride + ((fp8_padded_num_tokens % gmem_stride) > thread_id * pack_size); + + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + d_id * stride_d_input + thread_id * pack_size; + int8_t *output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + d_id * stride_d_output + thread_id * pack_size; + + T x_val[8]; + float x_val_float[8]; + uint32_t x_val_fp8[2]; + + float max_val = - 1000000.0f; + float min_val = 1000000.0f; + float sum_val = 0.0f; + + for (int i = 0; i < num_iters; i++) + { + *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + float x_temp = convert_to_float(x_val[j]); + max_val = fmaxf(max_val, x_temp); + min_val = fminf(min_val, x_temp); + + if constexpr (sub_mean) + { + sum_val += x_temp; + } + } + } + + // reduce + __shared__ float s_amax_val; + __shared__ float s_mean_val; + + float block_max_val = aphrodite::blockReduceMax(max_val); + float block_min_val = aphrodite::blockReduceMin(min_val); + float block_sum_val; + + if constexpr (sub_mean) + { + block_sum_val = aphrodite::blockReduceSum(sum_val); + } + + if (thread_id == 0) + { + s_mean_val = block_sum_val / fp8_padded_num_tokens; + + if constexpr (sub_mean) + { + s_amax_val = fmaxf(fabsf(block_max_val - s_mean_val), fabsf(block_min_val - s_mean_val)); + mean[batch_id * stride_bz_mean + head_id * stride_h_mean + d_id] = s_mean_val; + } + else + { + s_amax_val = fmaxf(fabsf(block_max_val), fabsf(block_min_val)); + } + + scale[batch_id * stride_bz_scale + head_id * stride_h_scale + d_id] = s_amax_val / scale_max; + } + + __syncthreads(); + + float mean_val = s_mean_val; + float recp_scale = scale_max / s_amax_val; + + // recalculate num_iters to cover all fp8 output tokens to prevent nan in random initialization + uint32_t padded_num_tokens = (num_tokens + pad_size - 1) / pad_size * pad_size; + num_iters = padded_num_tokens / gmem_stride + ((padded_num_tokens % gmem_stride) > thread_id * pack_size); + + for (int i = 0; i < num_iters; i++) + { + *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[j] = convert_to_float(x_val[j]); + if constexpr (sub_mean) + { + x_val_float[j] = (x_val_float[j] - mean_val) * recp_scale; + } + else + { + x_val_float[j] *= recp_scale; + } + } + + floatx4_to_e4m3x4(x_val_fp8, x_val_float, x_val_float + 2); + floatx4_to_e4m3x4(x_val_fp8 + 1, x_val_float + 4, x_val_float + 6); + + *(uint2*)(output_ptr_base + i * gmem_stride) = *(uint2*)(&x_val_fp8[0]); + } +} + +void quant_per_block_int8_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + double sm_scale, + int64_t block_size, + int64_t tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(scale, torch::kFloat); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_seq_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_h_input = input.stride(2); + stride_seq_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_h_input = input.stride(1); + stride_seq_output = output.stride(2); + stride_h_output = output.stride(1); + } + + auto input_dtype = input.scalar_type(); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + CHECK_SHAPE(scale, batch_size, num_heads, (num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE); + + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE, num_heads, batch_size); + + constexpr int num_pack_per_thread = (BLOCK_SIZE * (HEAD_DIM / 8) + 1023) / 1024; + + dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); + + QuantInt8Kernel<<>>( + reinterpret_cast(input.data_ptr()), + nullptr, + output.data_ptr(), + reinterpret_cast(scale.data_ptr()), + sm_scale, + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + 0, 0, + stride_bz_output, stride_seq_output, stride_h_output, + scale.stride(0), scale.stride(1) + ); + }); + }); + }); +} + +void quant_per_block_int8_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + int64_t block_size, + int64_t tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(scale, torch::kFloat); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_seq_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_h_input = input.stride(2); + stride_seq_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_h_input = input.stride(1); + stride_seq_output = output.stride(2); + stride_h_output = output.stride(1); + } + + auto input_dtype = input.scalar_type(); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + CHECK_SHAPE(scale, batch_size, num_heads, (num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE); + + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE, num_heads, batch_size); + + constexpr int num_pack_per_thread = (BLOCK_SIZE * (HEAD_DIM / 8) + 1023) / 1024; + + dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); + + QuantInt8Kernel<<>>( + reinterpret_cast(input.data_ptr()), + nullptr, + output.data_ptr(), + reinterpret_cast(scale.data_ptr()), + 0.0f, + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + 0, 0, + stride_bz_output, stride_seq_output, stride_h_output, + scale.stride(0), scale.stride(1) + ); + }); + }); + }); +} + +void quant_per_block_int8_fuse_sub_mean_cuda( + torch::Tensor input, + torch::Tensor mean, + torch::Tensor output, + torch::Tensor scale, + int64_t block_size, + int64_t tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(mean); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(scale, torch::kFloat); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(mean); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(mean, 3); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_seq_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_h_input = input.stride(2); + stride_seq_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_h_input = input.stride(1); + stride_seq_output = output.stride(2); + stride_h_output = output.stride(1); + } + + auto input_dtype = input.scalar_type(); + auto mean_dtype = mean.scalar_type(); + + TORCH_CHECK(input_dtype == mean_dtype, "Input and mean must have the same data type"); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + + CHECK_SHAPE(mean, batch_size, num_heads, head_dim); + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + CHECK_SHAPE(scale, batch_size, num_heads, (num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE); + + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE, num_heads, batch_size); + + constexpr int num_pack_per_thread = (BLOCK_SIZE * (HEAD_DIM / 8) + 1023) / 1024; + + dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); + + QuantInt8Kernel<<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(mean.data_ptr()), + output.data_ptr(), + reinterpret_cast(scale.data_ptr()), + 0.0f, + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + mean.stride(0), mean.stride(1), + stride_bz_output, stride_seq_output, stride_h_output, + scale.stride(0), scale.stride(1) + ); + }); + }); + }); +} + +// use block size 128 and warp_block size 32 +void quant_per_warp_int8_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + int64_t block_size, + int64_t warp_block_size, + int64_t tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(scale, torch::kFloat); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_seq_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_h_input = input.stride(2); + stride_seq_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_h_input = input.stride(1); + stride_seq_output = output.stride(2); + stride_h_output = output.stride(1); + } + + auto input_dtype = input.scalar_type(); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { + DISPATCH_WARP_BLOCK_SIZE(warp_block_size, WARP_BLOCK_SIZE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + CHECK_SHAPE(scale, batch_size, num_heads, (num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE * (BLOCK_SIZE / WARP_BLOCK_SIZE)); + + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE * (BLOCK_SIZE / WARP_BLOCK_SIZE), num_heads, batch_size); + + constexpr int num_pack_per_thread = (WARP_BLOCK_SIZE * (HEAD_DIM / 8) + 1023) / 1024; + + dim3 block(WARP_BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); + + QuantInt8Kernel<<>>( + reinterpret_cast(input.data_ptr()), + nullptr, + output.data_ptr(), + reinterpret_cast(scale.data_ptr()), + 0.0, + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + 0, 0, + stride_bz_output, stride_seq_output, stride_h_output, + scale.stride(0), scale.stride(1) + ); + }); + }); + }); + }); +} + +void sub_mean_cuda( + torch::Tensor input, + torch::Tensor mean, + torch::Tensor output, + int64_t tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(mean); + CHECK_CUDA(output); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(mean); + CHECK_CONTIGUOUS(output); + + CHECK_DIMS(input, 4); + CHECK_DIMS(mean, 3); + CHECK_DIMS(output, 4); + + CHECK_DTYPE(output, torch::kHalf); + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_seq_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_h_input = input.stride(2); + stride_seq_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_h_input = input.stride(1); + stride_seq_output = output.stride(2); + stride_h_output = output.stride(1); + } + + auto input_dtype = input.scalar_type(); + auto mean_dtype = mean.scalar_type(); + + TORCH_CHECK(input_dtype == mean_dtype, "Input and mean must have the same data type"); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + + CHECK_SHAPE(mean, batch_size, num_heads, head_dim); + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + + constexpr int BLOCK_SIZE = (HEAD_DIM == 128) ? 64 : 128; + + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE, num_heads, batch_size); + + constexpr int num_pack_per_thread = (BLOCK_SIZE * (HEAD_DIM / 8) + 1023) / 1024; + + dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); + + SubMeanKernel<<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(mean.data_ptr()), + reinterpret_cast(output.data_ptr()), + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + mean.stride(0), mean.stride(1), + stride_bz_output, stride_seq_output, stride_h_output + ); + }); + }); +} + +void transpose_pad_permute_cuda( + torch::Tensor input, + torch::Tensor output, + int64_t tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + + constexpr int CTA_SIZE = 64; + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_tokens, padded_num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_d_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_h_input = input.stride(2); + stride_d_output = output.stride(1); + stride_h_output = output.stride(2); + + padded_num_tokens = (num_tokens + CTA_SIZE - 1) / CTA_SIZE * CTA_SIZE; + + CHECK_SHAPE(output, batch_size, head_dim, num_heads, padded_num_tokens); + } + else + { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_h_input = input.stride(1); + stride_d_output = output.stride(2); + stride_h_output = output.stride(1); + + padded_num_tokens = (num_tokens + CTA_SIZE - 1) / CTA_SIZE * CTA_SIZE; + CHECK_SHAPE(output, batch_size, num_heads, head_dim, padded_num_tokens); + } + + auto input_dtype = input.scalar_type(); + auto output_dtype = output.scalar_type(); + + TORCH_CHECK(input_dtype == output_dtype, "Input and output must have the same data type"); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + dim3 grid(padded_num_tokens / CTA_SIZE, num_heads, batch_size); + + static_assert(CTA_SIZE * HEAD_DIM <= 8192); + + dim3 block(CTA_SIZE * (HEAD_DIM / 8)); + + TransposePadPermuteKernel<<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + stride_bz_output, stride_d_output, stride_h_output + ); + }); + }); +} + +void scale_fuse_quant_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + int64_t num_tokens, + double scale_max, + int64_t tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + // CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(scale, torch::kFloat); + + CHECK_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.size(0); + const int num_tokens_padded = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_heads, head_dim; + int stride_d_input, stride_h_input, stride_d_output, stride_h_output; + + if (tensor_layout == 0) + { + num_heads = input.size(2); + head_dim = input.size(1); + stride_d_input = input.stride(1); + stride_h_input = input.stride(2); + stride_d_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_heads = input.size(1); + head_dim = input.size(2); + stride_d_input = input.stride(2); + stride_h_input = input.stride(1); + stride_d_output = output.stride(2); + stride_h_output = output.stride(1); + } + + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + CHECK_SHAPE(scale, batch_size, num_heads, head_dim); + + constexpr int CTA_SIZE = 256; + + dim3 grid(num_heads, batch_size, head_dim); + dim3 block(CTA_SIZE); + + auto input_dtype = input.scalar_type(); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + MeanScaleKernel<64, false, c_type><<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + nullptr, + reinterpret_cast(scale.data_ptr()), + scale_max, + num_tokens, + stride_bz_input, stride_d_input, stride_h_input, + stride_bz_output, stride_d_output, stride_h_output, + 0, 0, + scale.stride(0), scale.stride(1) + ); + }); +} + +void mean_scale_fuse_quant_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor mean, + torch::Tensor scale, + int64_t num_tokens, + double scale_max, + int64_t tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(mean); + CHECK_CUDA(scale); + + // CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(mean, torch::kFloat); + CHECK_DTYPE(scale, torch::kFloat); + + CHECK_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(mean); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(mean, 3); + CHECK_DIMS(scale, 3); + + const int batch_size = input.size(0); + const int num_tokens_padded = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_heads, head_dim; + int stride_d_input, stride_h_input, stride_d_output, stride_h_output; + + if (tensor_layout == 0) + { + num_heads = input.size(2); + head_dim = input.size(1); + stride_d_input = input.stride(1); + stride_h_input = input.stride(2); + stride_d_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_heads = input.size(1); + head_dim = input.size(2); + stride_d_input = input.stride(2); + stride_h_input = input.stride(1); + stride_d_output = output.stride(2); + stride_h_output = output.stride(1); + } + + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + CHECK_SHAPE(mean, batch_size, num_heads, head_dim); + CHECK_SHAPE(scale, batch_size, num_heads, head_dim); + + constexpr int CTA_SIZE = 256; + + dim3 grid(num_heads, batch_size, head_dim); + dim3 block(CTA_SIZE); + + auto input_dtype = input.scalar_type(); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + MeanScaleKernel<64, true, c_type><<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(mean.data_ptr()), + reinterpret_cast(scale.data_ptr()), + scale_max, + num_tokens, + stride_bz_input, stride_d_input, stride_h_input, + stride_bz_output, stride_d_output, stride_h_output, + mean.stride(0), mean.stride(1), + scale.stride(0), scale.stride(1) + ); + }); +} \ No newline at end of file diff --git a/kernels/attention/sage_attn/math.cuh b/kernels/attention/sage_attn/math.cuh new file mode 100644 index 0000000000..0a60c08365 --- /dev/null +++ b/kernels/attention/sage_attn/math.cuh @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * This file is based on code from Flashinfer, https://github.com/flashinfer-ai/flashinfer/blob/v0.1.5/include/flashinfer/math.cuh + * Copyright (c) 2023 by FlashInfer team. + * Small modifications made by SageAttention team, 2024 (e.g., renamed namespace). + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#ifndef USHORT_TYPE +#define USHORT_TYPE +typedef unsigned short ushort; +#endif + +namespace math { + +// log2(e) +constexpr float log2e = 1.44269504088896340736f; +constexpr float log2e_recp = 1.0f / log2e; + +__forceinline__ __device__ half2 uint32_as_half2(uint32_t x) { return *(half2*)&x; } + +__forceinline__ __device__ uint32_t half2_as_uint32(half2 x) { return *(uint32_t*)&x; } + +/*! + * \brief Wrapper of PTX ex2.approx instruction, which computes 2^x + * \param x input + */ +__forceinline__ __device__ float ptx_exp2(float x) { + float y; + asm volatile("ex2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +/*! + * \brief Wrapper of PTX lg2.approx instruction, which computes log2(x) + * \param x input + */ +__forceinline__ __device__ float ptx_log2(float x) { + float y; + asm volatile("lg2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +/*! + * \brief Wrapper of PTX ex2.approx.f16x2 instruction, which computes 2^x + * \param x input + */ +__forceinline__ __device__ half2 ptx_exp2(half2 x) { + uint32_t y_u32; + uint32_t x_u32 = half2_as_uint32(x); + asm volatile("ex2.approx.f16x2 %0, %1;" : "=r"(y_u32) : "r"(x_u32)); + return uint32_as_half2(y_u32); +} + +/*! + * \brief Wrapper of PTX ex2.approx.f16 instruction, which computes 2^x + * \param x input + */ +__forceinline__ __device__ half ptx_exp2(half x) { + ushort y_u16; + asm volatile("ex2.approx.f16 %0, %1;" : "=h"(y_u16) : "h"(__half_as_ushort(x))); + return __ushort_as_half(y_u16); +} + +/*! + * \brief Wrapper of PTX rcp.approx instruction, which computes 1/x + * \param x input + */ +__forceinline__ __device__ float ptx_rcp(float x) { + float y; + asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +/*! + * \brief Wrapper of PTX shfl.sync.bfly instruction, which performs a butterfly shuffle + * between threads in a warp. + * \param x The value in the source lane + * \param lane_mask The mask to perform thread index xor with: y[i] <- x[i ^ delta] + */ +__forceinline__ __device__ float shfl_xor_sync(float x, int lane_mask) { + float y; + asm volatile("shfl.sync.bfly.b32 %0, %1, %2, 0x1f, 0xffffffff;" + : "=f"(y) + : "f"(x), "r"(lane_mask)); + return y; +} + +/*! + * \brief Wrapper of PTX shfl.sync.bfly instruction on half2, which performs a butterfly + * shuffle between threads in a warp. + * \param x The value in the source lane + * \param lane_mask The mask to perform thread index xor with: y[i] <- x[i ^ lane_mask] + */ +__forceinline__ __device__ half2 shfl_xor_sync(half2 x, int lane_mask) { + return __shfl_xor_sync(0xffffffff, x, lane_mask); +} + +/*! + * \brief Wrapper of PTX rsqrt approximation instruction, which computes 1/sqrt(x) + * \param x input + */ +__forceinline__ __device__ float rsqrt(float x) { + float y; + asm volatile("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +/*! + * \brief Wrapper of PTX tanh.approx.f32 instruction, which computes tanh(x) + * \param x input + */ +__forceinline__ __device__ float tanh(float x) { + float y; + asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +/*! + * \brief Wrapper of PTX tanh.approx.f16x2 instruction, which computes tanh(x) + * \param x input + */ +__forceinline__ __device__ half2 tanh(half2 x) { + uint32_t y_u32; + uint32_t x_u32 = half2_as_uint32(x); + asm volatile("tanh.approx.f16x2 %0, %1;" : "=r"(y_u32) : "r"(x_u32)); + return uint32_as_half2(y_u32); +} + +/*! + * \brief Wrapper of PTX tanh.approx.f16 instruction, which computes tanh(x) + * \param x input + */ +__forceinline__ __device__ half tanh(half x) { + ushort y_u16; + asm volatile("tanh.approx.f16 %0, %1;" : "=h"(y_u16) : "h"(__half_as_ushort(x))); + return __ushort_as_half(y_u16); +} + +} // namespace math \ No newline at end of file diff --git a/kernels/attention/sage_attn/mma.cuh b/kernels/attention/sage_attn/mma.cuh new file mode 100644 index 0000000000..7cda7a5e09 --- /dev/null +++ b/kernels/attention/sage_attn/mma.cuh @@ -0,0 +1,722 @@ +/* + * Adapted from Flashinfer, https://github.com/flashinfer-ai/flashinfer/blob/v0.1.5/include/flashinfer/mma.cuh + * Copyright (c) 2023 by FlashInfer team. + * + * Modifications copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include + +namespace mma{ + +#if (__CUDACC_VER_MAJOR__ >= 11) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) +#define MMA_F16F16F32_M16N8K16_ENABLED +#define MMA_F16F16F16_M16N8K16_ENABLED +#define MMA_S8S8S32_M16N8K32_ENABLED +#define MMA_S4S4S32_M16N8K64_ENABLED +#endif +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750)) +#define MMA_F16F16F32_M16N8K8_ENABLED +#define MMA_F16F16F16_M16N8K8_ENABLED +#define LDMATRIX_M8N8X2_ENABLED +#define LDMATRIX_M8N8X4_ENABLED +#endif +#endif + +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890)) +#define MMA_F8F8F32_M16N8K16_ENABLED +#endif +#endif + +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890)) +#define MMA_F8F8F16_M16N8K16_ENABLED +#endif +#endif + +#if defined(__CUDA_ARCH__) +#define RUNTIME_ASSERT(x) __brkpt() +#else +#include +#define RUNTIME_ASSERT(x) assert(0 && x) +#endif + +enum class MMAMode { + kInit = 0U, + kInplaceUpdate = 1U, +}; + +/*! + * \brief Wrapper of PTX ldmatrix m8n8.x2 instruction, loads data from shared memory + * to fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void ldmatrix_m8n8x2(uint32_t* R, T* smem_ptr) { +#ifdef LDMATRIX_M8N8X2_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" + : "=r"(R[0]), "=r"(R[1]) + : "r"(smem_int_ptr)); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! + * \brief Wrapper of PTX ldmatrix m8n8.x4 instruction, loads data from shared memory + * to fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t* R, T* smem_ptr) { +#ifdef LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! + * \brief Wrapper of PTX ldmatrix m8n8.x4 transposed instruction, loads data from + * shared memory to fragment and transposes the fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t* R, T* smem_ptr) { +#ifdef LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n8k16 instruction for row major and column major f16 matrix + * multiplication, accumulated in f32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n8k16_row_col_f16f16f32(float* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F16F16F32_M16N8K16_ENABLED + // ! only support half dtype now + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n16k16 instruction for row major and column major f16 matrix + * multiplication, accumulated in f32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F16F16F32_M16N8K16_ENABLED + // ! only support half dtype now + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), + "f"(C[6]), "f"(C[7])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n8k16 instruction for row major and column major f16 matrix + * multiplication, accumulated in f16. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n8k16_row_col_f16f16f16(uint32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F16F16F16_M16N8K16_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n16k16 instruction for row major and column major f16 matrix + * multiplication, accumulated in f16. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f16(uint32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F16F16F16_M16N8K16_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(C[2]), "r"(C[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0)); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n8k32 instruction for row major and column major int8 matrix + * multiplication, accumulated in int32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n8k32_row_col_s8s8s32(int32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_S8S8S32_M16N8K32_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), + "r"(C[2]), "r"(C[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0), + "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n16k32 instruction for row major and column major int8 matrix + * multiplication, accumulated in int32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k32_row_col_s8s8s32(int32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_S8S8S32_M16N8K32_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), + "r"(C[2]), "r"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(C[4]), "r"(C[5]), + "r"(C[6]), "r"(C[7])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0), + "r"(0), "r"(0)); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(0), "r"(0), + "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n8k32 instruction for row major and column major int4 matrix + * multiplication, accumulated in int32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n8k64_row_col_s4s4s32(int32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_S4S4S32_M16N8K64_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), + "r"(C[2]), "r"(C[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0), + "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n16k64 instruction for row major and column major int4 matrix + * multiplication, accumulated in int32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k64_row_col_s4s4s32(int32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_S4S4S32_M16N8K64_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), + "r"(C[2]), "r"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(C[4]), "r"(C[5]), + "r"(C[6]), "r"(C[7])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0), + "r"(0), "r"(0)); + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(0), "r"(0), + "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n8k32 instruction for row major and column major fp8 e4m3 matrix + * multiplication, accumulated in fp32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n8k32_row_col_f8f8f32(float* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F8F8F32_M16N8K16_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n16k32 instruction for row major and column major fp8 matrix + * multiplication, accumulated in fp16. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f16(uint32_t* C_uint32, uint32_t* A, + uint32_t* B) { + //uint32_t* C_uint32 = reinterpret_cast(C); +#ifdef MMA_F8F8F16_M16N8K16_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C_uint32[0]), "=r"(C_uint32[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C_uint32[0]), "r"(C_uint32[1])); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C_uint32[2]), "=r"(C_uint32[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(C_uint32[2]), "r"(C_uint32[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C_uint32[0]), "=r"(C_uint32[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0)); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C_uint32[2]), "=r"(C_uint32[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + + + +/*! + * \brief Wrapper of the mma m16n16k32 instruction for row major and column major fp8 matrix + * multiplication, accumulated in fp32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F8F8F32_M16N8K16_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), + "f"(C[6]), "f"(C[7])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Use mma instructions to compute rowsum. + */ +__device__ __forceinline__ void rowsum_f16f16f32(float* d, uint32_t* s) { +#ifdef MMA_F16F16F32_M16N8K16_ENABLED + asm volatile( + "{\n" + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, _, %1, _}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s[0]), "r"(s[1]), "r"(s[2]), "r"(s[3]), "r"(1006648320), // 1006648320 packs two 1.0f in half precision + "r"(1006648320), "f"(d[0]), "f"(d[1])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Use mma instructions to compute rowsum. + */ +__device__ __forceinline__ void rowsum_f8f8f32(float* d, uint32_t* s) { +#ifdef MMA_F8F8F32_M16N8K16_ENABLED + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, _, %1, _}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s[0]), "r"(s[1]), "r"(s[2]), "r"(s[3]), "r"(943208504), "r"(943208504), // 943208504 packs four 1.0f in e4m3 + "f"(d[0]), "f"(d[1])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +} // namespace mma diff --git a/kernels/attention/sage_attn/numeric_conversion.cuh b/kernels/attention/sage_attn/numeric_conversion.cuh new file mode 100644 index 0000000000..88f80f501e --- /dev/null +++ b/kernels/attention/sage_attn/numeric_conversion.cuh @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Inspired by CUTLASS, https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/numeric_conversion.h + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include + +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890)) +#define FP8_CAST_ENABLED +#endif +#endif + +#if defined(__CUDA_ARCH__) +#define RUNTIME_ASSERT(x) __brkpt() +#else +#include +#define RUNTIME_ASSERT(x) assert(0 && x) +#endif + +__device__ __forceinline__ void unpack_half2_from_uint32_to_float(float* dest, uint32_t source) { + uint16_t h0 = source & 0xFFFF; + uint16_t h1 = (source >> 16) & 0xFFFF; + asm("cvt.f32.f16 %0, %1;" : "=f"(dest[0]) : "h"(h0)); + asm("cvt.f32.f16 %0, %1;" : "=f"(dest[1]) : "h"(h1)); +} + +__device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t *dest, float *source0, float *source1) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e4m3x2.f32 lo, %2, %1;\n" \ + "cvt.rn.satfinite.e4m3x2.f32 hi, %4, %3;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(dest[0]) : "f"(source0[0]), "f"(source0[1]), "f"(source1[0]), "f"(source1[1])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ void floatx4_to_e5m2x4(uint32_t *dest, float *source0, float *source1) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e5m2x2.f32 lo, %2, %1;\n" \ + "cvt.rn.satfinite.e5m2x2.f32 hi, %4, %3;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(dest[0]) : "f"(source0[0]), "f"(source1[1]), "f"(source1[0]), "f"(source1[1])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ void halfx4_to_e4m3x4(uint32_t *dest, uint32_t *source0, uint32_t *source1) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e4m3x2.f16x2 lo, %1;\n" \ + "cvt.rn.satfinite.e4m3x2.f16x2 hi, %2;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(dest[0]) : "r"(source0[0]), "r"(source1[0])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ void halfx4_to_e5m2x4(uint32_t *dest, uint32_t *source0, uint32_t *source1) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e5m2x2.f16x2 lo, %1;\n" \ + "cvt.rn.satfinite.e5m2x2.f16x2 hi, %2;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(dest[0]) : "r"(source0[0]), "r"(source1[0])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ void e4m3x4_to_halfx4(uint32_t *dest0, uint32_t *dest1, uint32_t *source) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %2;\n" \ + "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ + "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ + "}\n" : "=r"(dest0[0]), "=r"(dest1[0]) : "r"(source[0])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ void e5m2x4_to_halfx4(uint32_t *dest0, uint32_t *dest1, uint32_t *source) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %2;\n" \ + "cvt.rn.f16x2.e5m2x2 %0, lo;\n" \ + "cvt.rn.f16x2.e5m2x2 %1, hi;\n" \ + "}\n" : "=r"(dest0[0]), "=r"(dest1[0]) : "r"(source[0])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ int8_t float_to_int8_rn(float x) +{ + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +} \ No newline at end of file diff --git a/kernels/attention/sage_attn/permuted_smem.cuh b/kernels/attention/sage_attn/permuted_smem.cuh new file mode 100644 index 0000000000..e831b9390c --- /dev/null +++ b/kernels/attention/sage_attn/permuted_smem.cuh @@ -0,0 +1,196 @@ +/* + * Adapted from Flashinfer, https://github.com/flashinfer-ai/flashinfer/blob/v0.1.5/include/flashinfer/permuted_smem.cuh + * Copyright (c) 2023 by FlashInfer team. + * + * Modifications copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +#include + +#include "cp_async.cuh" +#include "mma.cuh" + +enum class SwizzleMode { + k32B, // for k32B mode, a line of shared memory must have 32B (16 half value) + k64B, // for k64B mode, a line of shared memory must have 64B (32 half value) + k128B, // 128B already spans all banks in shared memory. a line of shared memory can have multiple 128B. +}; + +// Use 128bit as the granularity to fetch/store data per thread to maximize memory bandwidth +using b128_t = uint4; + +/*! + * \brief A stateless shared memory wrapper that uses templates to avoid runtime conditionals. It makes sure + * that access to consecutive rows idx in the same column idx will make full use of the shared memory bank through + * permutation in the granularity of 128bit. + * + * This struct treats all offsets to be the number of `b128_t` elements. It is designed to be stateless, + * meaning it does not maintain any information about the current pointer position. The offset returnd by + * the struct can be used to access the shared memory through the provided interface. + * + * The struct guarantees that the read to permuted offset (i, j) will be the value stored in permuted offset (i, j). + * We assume that shared memory operation operates on at least two consecutive 128-bit values in a row within a warp. + * Under this assumption, we do not permute for k32B mode. + */ +template +struct smem_t { + // The base pointer. + b128_t* base; + // How many b128_t value a row contains + // uint32_t stride; + + __device__ __forceinline__ smem_t() : base(nullptr) {} + template + __device__ __forceinline__ smem_t(T* base) : base((b128_t*)base) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + static_assert(stride % 8 == 0, "Stride must be multiple of 8 for 128B swizzle mode"); + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + static_assert(stride == 4, "Stride must be 4 for 64B swizzle mode"); + } else if constexpr (swizzle_mode == SwizzleMode::k32B) { + static_assert(stride == 2, "Stride must be 2 for 32B swizzle mode"); + } else { + static_assert(swizzle_mode != swizzle_mode, "Unsupported swizzle mode"); + } + } + + /*! + * \brief Set the base pointer. + */ + template + __device__ __forceinline__ void set_base(T* new_base) { + base = (b128_t*)new_base; + } + + /*! + * \brief Compute the element offset given coordinates in a permuted shared memory. + * \param i The row index. + * \param j The column index. + */ + static __device__ __forceinline__ uint32_t get_permuted_offset(const uint32_t &i, const uint32_t &j) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + return i * stride + (j ^ (i % 8)); + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + return i * stride + (j ^ ((i / 2) % 4)); + } else if constexpr (swizzle_mode == SwizzleMode::k32B) { + return i * stride + j; + } + } + + /*! + * \tparam step_size The step size to advance the offset in the permuted shared memory. + * \param offset The current offset. + */ + template + static __device__ __forceinline__ uint32_t advance_offset_by_column(const uint32_t &offset) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + static_assert(step_size % 8 == 0, + "Unsupported step size"); + return offset + step_size; + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + static_assert(step_size == 4, "Unsupported step size"); + return offset + step_size; + } else if constexpr (swizzle_mode == SwizzleMode::k32B) { + static_assert(step_size == 2, "Unsupported step size"); + return offset + step_size; + } + } + + // ! use with care + template + static __device__ __forceinline__ uint32_t advance_offset_by_column(const uint32_t &offset, const uint32_t &step_idx) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0, + "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + (step_idx % 4 == 3) * 8; + } else if constexpr (step_size == 4) { + return (offset ^ 0x4) + (step_idx % 2 == 1) * 8; + } else { + // step_size % 8 == 0 + return offset + step_size; + } + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + static_assert(step_size == 2 || step_size == 4, "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ 0x2) + (step_idx % 2 == 1) * 4; + } else { + return offset + step_size; + } + } else if constexpr (swizzle_mode == SwizzleMode::k32B) { + return offset + step_size; + } + } + + template + static __device__ __forceinline__ uint32_t advance_offset_by_row(const uint32_t &offset) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + static_assert(step_size == 4 || step_size % 8 == 0, "Unsupported step size"); + if constexpr (step_size == 4) { + return (offset ^ 0x4) + step_size * stride; + } else { + // step_size % 8 == 0 + return offset + step_size * stride; + } + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + static_assert(step_size == 4 || step_size % 8 == 0, "Unsupported step size"); + if constexpr (step_size == 4) { + return (offset ^ 0x2) + step_size * stride; + } else { + // step_size % 8 == 0 + return offset + step_size * stride; + } + } else if constexpr (swizzle_mode == SwizzleMode::k32B) { + return offset + step_size * stride; + } + } + + __device__ __forceinline__ void ldmatrix_m8n8x2(const uint32_t &offset, uint32_t* R) const { + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x2(R, smem_ptr); + } + + __device__ __forceinline__ void ldmatrix_m8n8x4(const uint32_t &offset, uint32_t* R) const { + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4(R, smem_ptr); + } + + __device__ __forceinline__ void ldmatrix_m8n8x4_trans(const uint32_t &offset, uint32_t* R) const { + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4_trans(R, smem_ptr); + } + + template + __device__ __forceinline__ void load_128b_async(const uint32_t &offset, const T* gptr, bool predicate) const { + b128_t* smem_ptr = base + offset; + cp_async::pred_load_128b( + smem_ptr, reinterpret_cast(gptr), predicate); + } + + template + __device__ __forceinline__ void load_128b_async(const uint32_t &offset, const T* gptr) const { + b128_t* smem_ptr = base + offset; + cp_async::load_128b(smem_ptr, reinterpret_cast(gptr)); + } + + template + __device__ __forceinline__ void store_128b(const uint32_t &offset, T* gptr) const { + *reinterpret_cast(gptr) = *(base + offset); + } +}; \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/attn_utils.cuh b/kernels/attention/sage_attn/qattn/attn_utils.cuh new file mode 100644 index 0000000000..471501d936 --- /dev/null +++ b/kernels/attention/sage_attn/qattn/attn_utils.cuh @@ -0,0 +1,992 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "../utils.cuh" +#include +#include +#include + +#include "../cp_async.cuh" +#include "../mma.cuh" +#include "../permuted_smem.cuh" +#include "../numeric_conversion.cuh" + +#define WARP_SIZE 32 + +#define S_FP8_OFFSET 8.807f +#define S_FP8_OFFSET_EXP 6680.8477f +#define S_FP8_OFFSET_EXP_INV 0.0022326917f + +#define div_ceil(M, N) (((M) + (N)-1) / (N)) + +enum class MaskMode { + kNone = 0, + kCausal = 1, +}; + +enum class DataType { + kHalf, + kInt8, + kInt4, + kE4M3, + kE5M2, +}; + +enum class QuantGranularity { + kPerTensor = 0, + kPerBlock = 1, + kPerWarp = 2, + kPerThread = 3, +}; + +enum class ComputeUnit { + kTensorCore, + kCudaCore, +}; + +__device__ __forceinline__ uint32_t get_warp_id() +{ + return threadIdx.y; +} + +__device__ __forceinline__ uint32_t get_lane_id() +{ + return threadIdx.x; +} + +template +__device__ __forceinline__ uint32_t get_warp_idx_q() +{ + return get_warp_id() / num_warps_k; +} + +template +__device__ __forceinline__ uint32_t get_warp_idx_k() +{ + return get_warp_id() % num_warps_k; +} + +template +__device__ __forceinline__ void load_global_to_share(T **lane_ptr, uint32_t &smem_offset, + const uint32_t &gmem_stride, + const smem_t &smem) +{ + static_assert(global_to_shared_copy_lines_per_warp_per_iter * global_to_shared_line_lanes == WARP_SIZE); + static_assert(std::is_same::value || std::is_same::value); + + constexpr uint32_t pack_size = std::is_same::value ? 8 : 16; + +#pragma unroll + for (uint32_t i = 0; i < smem_iters_col; i++) + { +#pragma unroll + for (uint32_t j = 0; j < smem_iters_row; j++) + { + smem.load_128b_async(smem_offset, *lane_ptr); + *lane_ptr += (global_to_shared_line_lanes * pack_size); + smem_offset = smem.advance_offset_by_column(smem_offset); + } + + smem_offset = smem.advance_offset_by_row(smem_offset - (smem_iters_row * global_to_shared_line_lanes)); + *lane_ptr += ((global_to_shared_copy_lines_per_warp_per_iter * gmem_stride) - (smem_iters_row * global_to_shared_line_lanes * pack_size)); + } + smem_offset -= (smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter * stride); + *lane_ptr += (CTA - smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter) * gmem_stride; +} + +// with predicate +template +__device__ __forceinline__ void load_global_to_share(T **lane_ptr, uint32_t &smem_offset, + const uint32_t &gmem_stride, + const smem_t &smem, uint32_t base_idx, uint32_t max_len) +{ + static_assert(global_to_shared_copy_lines_per_warp_per_iter * global_to_shared_line_lanes == WARP_SIZE); + static_assert(std::is_same::value || std::is_same::value); + + constexpr uint32_t pack_size = std::is_same::value ? 8 : 16; + +#pragma unroll + for (uint32_t i = 0; i < smem_iters_col; i++) + { +#pragma unroll + for (uint32_t j = 0; j < smem_iters_row; j++) + { + smem.load_128b_async(smem_offset, *lane_ptr, base_idx < max_len); + *lane_ptr += (global_to_shared_line_lanes * pack_size); + smem_offset = smem.advance_offset_by_column(smem_offset); + } + + smem_offset = smem.advance_offset_by_row(smem_offset - (smem_iters_row * global_to_shared_line_lanes)); + *lane_ptr += ((global_to_shared_copy_lines_per_warp_per_iter * gmem_stride) - (smem_iters_row * global_to_shared_line_lanes * pack_size)); + base_idx += global_to_shared_copy_lines_per_warp_per_iter; + } + smem_offset -= (smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter * stride); + *lane_ptr += (CTA - smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter) * gmem_stride; +} + +template +__device__ __forceinline__ void load_fp8_V_global_to_share(int8_t **lane_ptr, uint32_t &smem_offset, + const uint32_t &gmem_stride, + const smem_t &smem) +{ + static_assert(global_to_shared_copy_lines_per_warp_per_iter * global_to_shared_line_lanes == WARP_SIZE); + constexpr uint32_t pack_size_fp8 = 16; + +#pragma unroll + for (uint32_t i = 0; i < smem_iters_col; i++) + { +#pragma unroll + for (uint32_t j = 0; j < smem_iters_row; j++) + { + smem.load_128b_async(smem_offset, *lane_ptr); + *lane_ptr += (global_to_shared_line_lanes * pack_size_fp8); + smem_offset = smem.advance_offset_by_column(smem_offset); + } + + smem_offset = smem.advance_offset_by_row(smem_offset - (smem_iters_row * global_to_shared_line_lanes)); + *lane_ptr += ((global_to_shared_copy_lines_per_warp_per_iter * gmem_stride) - (smem_iters_row * global_to_shared_line_lanes * pack_size_fp8)); + } + smem_offset -= (smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter * stride); + // for QK: *lane_ptr += (CTA - smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter) * gmem_stride; + *lane_ptr += CTA; // ! prevent underflow + *lane_ptr -= (smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter) * gmem_stride; +} + +template +__device__ __forceinline__ void compute_int_qk(const smem_t &smem_Q, const smem_t &smem_K, int32_t RS[][num_tiles_k][8], uint32_t &offset_Q, uint32_t &offset_K) +{ + static_assert(DTypeQK == DataType::kInt8 || DTypeQK == DataType::kInt4); + + uint32_t RQ[num_tiles_q][4]; + uint32_t RK[4]; + + // the first iteration, mma mode is kInit +#pragma unroll + for (uint32_t iter = 0; iter < 1; iter++) + { + // load RQ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + smem_Q.ldmatrix_m8n8x4(offset_Q, RQ[fq]); + offset_Q = smem_Q.advance_offset_by_row<16>(offset_Q); + } + // ! using permutation invariance + offset_Q = smem_Q.advance_offset_by_column<2>(offset_Q - (num_tiles_q * 16 * stride), iter); + +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + // load RK + smem_K.ldmatrix_m8n8x4(offset_K, RK); + offset_K = smem_K.advance_offset_by_row<16>(offset_K); + + // mma +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (DTypeQK == DataType::kInt8) + { + mma::mma_sync_m16n16k32_row_col_s8s8s32(RS[fq][fk], RQ[fq], RK); + } + else if constexpr (DTypeQK == DataType::kInt4) + { + mma::mma_sync_m16n16k64_row_col_s4s4s32(RS[fq][fk], RQ[fq], RK); + } + } + } + offset_K = smem_K.advance_offset_by_column<2>(offset_K - (num_tiles_k * 16 * stride), iter); + } + + // following iteration, mma mode is kInplace +#pragma unroll + for (uint32_t iter = 1; iter < num_tiles_qk_inner; iter++) + { + // load RQ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + smem_Q.ldmatrix_m8n8x4(offset_Q, RQ[fq]); + offset_Q = smem_Q.advance_offset_by_row<16>(offset_Q); + } + offset_Q = smem_Q.advance_offset_by_column<2>(offset_Q - (num_tiles_q * 16 * stride), iter); + +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + // load RK + smem_K.ldmatrix_m8n8x4(offset_K, RK); + offset_K = smem_K.advance_offset_by_row<16>(offset_K); + + // mma +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (DTypeQK == DataType::kInt8) + { + mma::mma_sync_m16n16k32_row_col_s8s8s32(RS[fq][fk], RQ[fq], RK); + } + else if constexpr (DTypeQK == DataType::kInt4) + { + mma::mma_sync_m16n16k64_row_col_s4s4s32(RS[fq][fk], RQ[fq], RK); + } + } + } + offset_K = smem_K.advance_offset_by_column<2>(offset_K - (num_tiles_k * 16 * stride), iter); + } + + offset_Q -= (2 * num_tiles_qk_inner); + offset_K -= (2 * num_tiles_qk_inner); +} + +// for case when num_tiles_qk_inner = 1 +template +__device__ __forceinline__ void compute_int_qk(const smem_t &smem_K, int32_t RS[][num_tiles_k][8], uint32_t RQ[][4], uint32_t offset_K) +{ + static_assert(DTypeQK == DataType::kInt8 || DTypeQK == DataType::kInt4); + static_assert(num_tiles_qk_inner == 1); + + uint32_t RK[4]; + + // mma mode is kInit +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + // load RK + smem_K.ldmatrix_m8n8x4(offset_K, RK); + offset_K = smem_K.advance_offset_by_row<16>(offset_K); + + // mma +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (DTypeQK == DataType::kInt8) + { + mma::mma_sync_m16n16k32_row_col_s8s8s32(RS[fq][fk], RQ[fq], RK); + } + else if constexpr (DTypeQK == DataType::kInt4) + { + mma::mma_sync_m16n16k64_row_col_s4s4s32(RS[fq][fk], RQ[fq], RK); + } + } + } +} + +template +__device__ __forceinline__ void apply_causal_mask(const uint32_t &Q_idx_lane_base, const uint32_t &K_idx_lane_base, DTypeQKAccum RS[][num_tiles_k][8]) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + const uint32_t q_idx = Q_idx_lane_base + fq * 16 + 8 * ((k % 4) / 2); + const uint32_t kv_idx = K_idx_lane_base + fk * 16 + 8 * (k / 4) + k % 2; + const bool out_of_boundary = (kv_idx > q_idx); + + if constexpr (std::is_same::value) + { + RS[fq][fk][k] = (out_of_boundary ? -5000000.0f : RS[fq][fk][k]); + } + else if constexpr (std::is_same::value) + { + RS[fq][fk][k] = (out_of_boundary ? __float2half_rn(-50000.0f) : RS[fq][fk][k]); + } + } + } + } +} + +template +__device__ __forceinline__ void apply_out_of_bound_mask(const uint32_t &K_idx_lane_base, DTypeQKAccum RS[][num_tiles_k][8], const uint32_t &kv_len) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + const uint32_t kv_idx = K_idx_lane_base + fk * 16 + 8 * (k / 4) + k % 2; + const bool out_of_boundary = (kv_idx >= kv_len); + + if constexpr (std::is_same::value) + { + RS[fq][fk][k] = (out_of_boundary ? -5000000.0f : RS[fq][fk][k]); + } + else if constexpr (std::is_same::value) + { + RS[fq][fk][k] = (out_of_boundary ? __float2half_rn(-50000.0f) : RS[fq][fk][k]); + } + } + } + } +} + +// for DTypeQKAccum float +template +__device__ __forceinline__ void update_mdo(float RS[][num_tiles_k][8], DTypeSVAccum RO[][num_tiles_v][8], float m[][2], float d[][2], const float &sm_scale) +{ + static_assert(std::is_same::value || (!use_half_o_scale)); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + // assign the smallest value possible + float m_prev = m[fq][k]; + float m_temp = -5000000.0f; +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + float m_local = max(max(RS[fq][fk][k * 2 + 0], RS[fq][fk][k * 2 + 1]), + max(RS[fq][fk][k * 2 + 4], RS[fq][fk][k * 2 + 5])); + m_temp = max(m_temp, m_local); + } + + if constexpr (!fuse_scale) + { + if constexpr (exp_offset) + { + m_temp = fmaf(m_temp, sm_scale, -S_FP8_OFFSET); + } + else + { + m_temp *= sm_scale; + } + } + else if constexpr (exp_offset) + { + m_temp += (-S_FP8_OFFSET); + } + + // exchange element with the 4 threads in the row + m_temp = max(m_temp, __shfl_xor_sync(0xffffffff, m_temp, 0x1)); // 0 exchange with 1, 2 exchange with 3 + m_temp = max(m_temp, __shfl_xor_sync(0xffffffff, m_temp, 0x2)); // 0 exchange with 2, 1 exchange with 3 + + m[fq][k] = max(m[fq][k], m_temp); + + float o_scale = math::ptx_exp2(m_prev - m[fq][k]); + + // update denominator + d[fq][k] *= o_scale; + + half2 o_scale2; + if constexpr (use_half_o_scale) + { + o_scale2 = __floats2half2_rn(o_scale, o_scale); + } + + // update RO +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + if constexpr (std::is_same::value) + { + RO[fq][fv][k * 2 + 0] *= o_scale; + RO[fq][fv][k * 2 + 1] *= o_scale; + RO[fq][fv][k * 2 + 4] *= o_scale; + RO[fq][fv][k * 2 + 5] *= o_scale; + } + else if constexpr (std::is_same::value) + { + if constexpr (use_half_o_scale) + { + ((half2*)RO[fq][fv])[k] = __hmul2(((half2*)RO[fq][fv])[k], o_scale2); + ((half2*)RO[fq][fv])[k + 2] = __hmul2(((half2*)RO[fq][fv])[k + 2], o_scale2); + } + else + { + RO[fq][fv][k * 2 + 0] = __float2half_rn(__half2float(RO[fq][fv][k * 2 + 0]) * o_scale); + RO[fq][fv][k * 2 + 1] = __float2half_rn(__half2float(RO[fq][fv][k * 2 + 1]) * o_scale); + RO[fq][fv][k * 2 + 4] = __float2half_rn(__half2float(RO[fq][fv][k * 2 + 4]) * o_scale); + RO[fq][fv][k * 2 + 5] = __float2half_rn(__half2float(RO[fq][fv][k * 2 + 5]) * o_scale); + } + } + } + + // raise RS to exponent + float negative_m = -m[fq][k]; +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + if constexpr (fuse_scale) + { + RS[fq][fk][k * 2 + 0] = math::ptx_exp2(RS[fq][fk][k * 2 + 0] + negative_m); + RS[fq][fk][k * 2 + 1] = math::ptx_exp2(RS[fq][fk][k * 2 + 1] + negative_m); + RS[fq][fk][k * 2 + 4] = math::ptx_exp2(RS[fq][fk][k * 2 + 4] + negative_m); + RS[fq][fk][k * 2 + 5] = math::ptx_exp2(RS[fq][fk][k * 2 + 5] + negative_m); + } + else + { + RS[fq][fk][k * 2 + 0] = math::ptx_exp2(fmaf(RS[fq][fk][k * 2 + 0], sm_scale, negative_m)); + RS[fq][fk][k * 2 + 1] = math::ptx_exp2(fmaf(RS[fq][fk][k * 2 + 1], sm_scale, negative_m)); + RS[fq][fk][k * 2 + 4] = math::ptx_exp2(fmaf(RS[fq][fk][k * 2 + 4], sm_scale, negative_m)); + RS[fq][fk][k * 2 + 5] = math::ptx_exp2(fmaf(RS[fq][fk][k * 2 + 5], sm_scale, negative_m)); + } + } + } + } +} + +template +__device__ __forceinline__ void RS_32_to_16(T RS[][num_tiles_k][8], uint32_t RS_16[][num_tiles_k][4]) +{ + static_assert(sizeof(T) == 4); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + ((half2*)RS_16[fq][fk])[0] = __float22half2_rn(((float2*)RS[fq][fk])[0]); + ((half2*)RS_16[fq][fk])[1] = __float22half2_rn(((float2*)RS[fq][fk])[1]); + ((half2*)RS_16[fq][fk])[2] = __float22half2_rn(((float2*)RS[fq][fk])[2]); + ((half2*)RS_16[fq][fk])[3] = __float22half2_rn(((float2*)RS[fq][fk])[3]); + } + } +} + +template +__device__ __forceinline__ void RS_32_to_8(float RS[][num_tiles_k][8], uint32_t RS_8[][num_tiles_k / 2][4]) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k / 2; fk++) + { + floatx4_to_e4m3x4(RS_8[fq][fk], RS[fq][fk * 2 + 0], RS[fq][fk * 2 + 0] + 4); + floatx4_to_e4m3x4(RS_8[fq][fk] + 1, RS[fq][fk * 2 + 0] + 2, RS[fq][fk * 2 + 0] + 6); + floatx4_to_e4m3x4(RS_8[fq][fk] + 2, RS[fq][fk * 2 + 1], RS[fq][fk * 2 + 1] + 4); + floatx4_to_e4m3x4(RS_8[fq][fk] + 3, RS[fq][fk * 2 + 1] + 2, RS[fq][fk * 2 + 1] + 6); + } + } +} + +template +__device__ __forceinline__ void RS_16_to_8(uint32_t RS[][num_tiles_k][4], uint32_t RS_8[][num_tiles_k / 2][4]) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k / 2; fk++) + { + halfx4_to_e4m3x4(RS_8[fq][fk], RS[fq][fk * 2 + 0], RS[fq][fk * 2 + 0] + 2); + halfx4_to_e4m3x4(RS_8[fq][fk] + 1, RS[fq][fk * 2 + 0] + 1, RS[fq][fk * 2 + 0] + 3); + halfx4_to_e4m3x4(RS_8[fq][fk] + 2, RS[fq][fk * 2 + 1], RS[fq][fk * 2 + 1] + 2); + halfx4_to_e4m3x4(RS_8[fq][fk] + 3, RS[fq][fk * 2 + 1] + 1, RS[fq][fk * 2 + 1] + 3); + } + } +} + +template +__device__ __forceinline__ void RS_8_to_16(uint32_t RS_8[][num_tiles_k / 2][4], uint32_t RS[][num_tiles_k][4]) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k / 2; fk++) + { + e4m3x4_to_halfx4(RS[fq][fk * 2 + 0], RS[fq][fk * 2 + 0] + 2, RS_8[fq][fk]); + e4m3x4_to_halfx4(RS[fq][fk * 2 + 0] + 1, RS[fq][fk * 2 + 0] + 3, RS_8[fq][fk] + 1); + e4m3x4_to_halfx4(RS[fq][fk * 2 + 1], RS[fq][fk * 2 + 1] + 2, RS_8[fq][fk] + 2); + e4m3x4_to_halfx4(RS[fq][fk * 2 + 1] + 1, RS[fq][fk * 2 + 1] + 3, RS_8[fq][fk] + 3); + } + } +} + +template +__device__ __forceinline__ void accumulate_d(T RS[][num_tiles_k][(compute_unit == ComputeUnit::kTensorCore)? 4 : 8], float d[][2]) +{ + // for compute unit cuda core, RS is float + // for compute unit tensor core, RS is packed half + static_assert((std::is_same::value && compute_unit == ComputeUnit::kCudaCore) || + (std::is_same::value && compute_unit == ComputeUnit::kTensorCore)); + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + if constexpr (compute_unit == ComputeUnit::kTensorCore) + { + // full accumulate with tensor core + mma::rowsum_f16f16f32(d[fq], (uint32_t*)(RS[fq][fk])); + } + else if constexpr (compute_unit == ComputeUnit::kCudaCore) + { + // partial accumulate with cuda core + d[fq][0] += RS[fq][fk][0] + RS[fq][fk][1] + RS[fq][fk][4] + RS[fq][fk][5]; + d[fq][1] += RS[fq][fk][2] + RS[fq][fk][3] + RS[fq][fk][6] + RS[fq][fk][7]; + } + } + } +} + +template +__device__ __forceinline__ void accumulate_d_f8(uint32_t RS[][num_tiles_k / 2][4], float d[][2]) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k / 2; fk++) + { + mma::rowsum_f8f8f32(d[fq], RS[fq][fk]); + } + } +} + +template +__device__ __forceinline__ void compute_fp16_sv(const smem_t &smem_V, uint32_t RS_f16[][num_tiles_k][4], DTypeSVAccum RO[][num_tiles_v][8], float d[][2]) +{ + uint32_t smem_V_row_base = get_warp_idx_k() * (num_tiles_k * 16) + get_lane_id() % 16; + uint32_t smem_V_col_base = get_lane_id() / 16; +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fk * 16, smem_V_col_base + fv * 2); + smem_V.ldmatrix_m8n8x4_trans(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k16_row_col_f16f16f32(RO[fq][fv], RS_f16[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)RO[fq][fv], RS_f16[fq][fk], RV); + } + } + } + } +} + +template +__device__ __forceinline__ void compute_fp16_sv_permuted(const smem_t &smem_V, T RS_f16[][num_tiles_k][RS_width], DTypeSVAccum RO[][num_tiles_v][8], float d[][2], uint32_t &offset_V) +{ + static_assert(sizeof(T) == 4); + + // ! be sure you know what you are doing +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + smem_V.ldmatrix_m8n8x4_trans(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k16_row_col_f16f16f32(RO[fq][fv], (uint32_t*)(RS_f16[fq][fk]), RV); + } + else if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)RO[fq][fv], (uint32_t*)(RS_f16[fq][fk]), RV); + } + } + + offset_V = smem_V.advance_offset_by_column<2>(offset_V, fv); + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V - (2 * num_tiles_v)); + } + + // make offset_V their original value + offset_V -= (16 * num_tiles_k * stride); +} + +template +__device__ __forceinline__ void compute_fp16_sv_permuted_inst_buf(const smem_t &smem_V, T RS_f16[][num_tiles_k][RS_width], DTypeSVAccum RO[][num_tiles_v][8], float d[][2], uint32_t &offset_V) +{ + static_assert(sizeof(T) == 4); + static_assert(std::is_same::value); + + uint32_t RO_inst_buf[num_tiles_q][num_tiles_v][4]; + + // ! be sure you know what you are doing +#pragma unroll + for (uint32_t fk = 0; fk < 1; fk++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + smem_V.ldmatrix_m8n8x4_trans(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + { + mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)RO_inst_buf[fq][fv], (uint32_t*)(RS_f16[fq][fk]), RV); + } + } + + offset_V = smem_V.advance_offset_by_column<2>(offset_V, fv); + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V - (2 * num_tiles_v)); + } + +#pragma unroll + for (uint32_t fk = 1; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + smem_V.ldmatrix_m8n8x4_trans(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + { + mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)RO_inst_buf[fq][fv], (uint32_t*)(RS_f16[fq][fk]), RV); + } + } + + offset_V = smem_V.advance_offset_by_column<2>(offset_V, fv); + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V - (2 * num_tiles_v)); + } + + // accumulate into RO +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + RO[fq][fv][0] += __half2float(((half2*)RO_inst_buf[fq][fv])[0].x); + RO[fq][fv][1] += __half2float(((half2*)RO_inst_buf[fq][fv])[0].y); + RO[fq][fv][2] += __half2float(((half2*)RO_inst_buf[fq][fv])[1].x); + RO[fq][fv][3] += __half2float(((half2*)RO_inst_buf[fq][fv])[1].y); + RO[fq][fv][4] += __half2float(((half2*)RO_inst_buf[fq][fv])[2].x); + RO[fq][fv][5] += __half2float(((half2*)RO_inst_buf[fq][fv])[2].y); + RO[fq][fv][6] += __half2float(((half2*)RO_inst_buf[fq][fv])[3].x); + RO[fq][fv][7] += __half2float(((half2*)RO_inst_buf[fq][fv])[3].y); + } + } + + // make offset_V their original value + offset_V -= (16 * num_tiles_k * stride); +} + +template +__device__ __forceinline__ void normalize_d(DTypeSVAccum RO[][num_tiles_v][8], DTypeQKAccum m[][2], float d[][2]) +{ + if constexpr (compute_unit == ComputeUnit::kCudaCore) + { + // accumulate_d performs partial accumulation with cuda core + // aggregate d +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + d[fq][k] += __shfl_xor_sync(0xffffffff, d[fq][k], 0x1); // sum 0 and 1, 2 and 3 + d[fq][k] += __shfl_xor_sync(0xffffffff, d[fq][k], 0x2); // sum 0 and 2, 1 and 3 + } + } + } + + // divide O by d + float d_rcp[num_tiles_q][2]; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + // TODO: check m to prevent nan + d_rcp[fq][k] = math::ptx_rcp(d[fq][k]); + } + } + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + if constexpr (std::is_same::value) + { + RO[fq][fv][k] *= d_rcp[fq][(k % 4) / 2]; + } + else if constexpr (std::is_same::value) + { + RO[fq][fv][k] = __float2half_rn(__half2float(RO[fq][fv][k]) * d_rcp[fq][(k % 4) / 2]); + } + } + } + } +} + +template +__device__ __forceinline__ void compute_fp8_sv(const smem_t &smem_V, uint32_t RS_f8[][num_tiles_k / 2][4], DTypeSVAccum RO[][num_tiles_v][8], float d[][2]) +{ + uint32_t smem_V_row_base = get_lane_id() % 8 + (get_lane_id() / 16) * 8; + // uint32_t smem_V_col_base = get_warp_idx_k() * ((16 * num_tiles_k) / 16) + (get_lane_id() / 8) % 2; + uint32_t smem_V_col_base = (get_lane_id() / 8) % 2; +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k / 2; fk++) + { + uint32_t offset_V = smem_V.get_permuted_offset(smem_V_row_base, smem_V_col_base + fk * 2); +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + // uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fv * 16, smem_V_col_base + fk * 2); + smem_V.ldmatrix_m8n8x4(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k32_row_col_f8f8f32(RO[fq][fv], RS_f8[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + // ! Not Implemented + } + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V); + } + } +} + +template +__device__ __forceinline__ void compute_fp8_sv_inst_buf(const smem_t &smem_V, uint32_t RS_f8[][num_tiles_k / 2][4], DTypeSVAccum RO[][num_tiles_v][8], float d[][2]) +{ + uint32_t smem_V_row_base = get_lane_id() % 8 + (get_lane_id() / 16) * 8; + // uint32_t smem_V_col_base = get_warp_idx_k() * ((16 * num_tiles_k) / 16) + (get_lane_id() / 8) % 2; + uint32_t smem_V_col_base = (get_lane_id() / 8) % 2; + + float RO_inst_buf[num_tiles_q][num_tiles_v][8]; + +#pragma unroll + for (uint32_t fk = 0; fk < 1; fk++) + { + uint32_t offset_V = smem_V.get_permuted_offset(smem_V_row_base, smem_V_col_base + fk * 2); +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + // uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fv * 16, smem_V_col_base + fk * 2); + smem_V.ldmatrix_m8n8x4(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k32_row_col_f8f8f32(RO_inst_buf[fq][fv], RS_f8[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + // ! Not Implemented + } + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V); + } + } + +#pragma unroll + for (uint32_t fk = 1; fk < num_tiles_k / 2; fk++) + { + uint32_t offset_V = smem_V.get_permuted_offset(smem_V_row_base, smem_V_col_base + fk * 2); +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + // uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fv * 16, smem_V_col_base + fk * 2); + smem_V.ldmatrix_m8n8x4(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k32_row_col_f8f8f32(RO_inst_buf[fq][fv], RS_f8[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + // ! Not Implemented + } + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V); + } + } + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + RO[fq][fv][0] += RO_inst_buf[fq][fv][0]; + RO[fq][fv][1] += RO_inst_buf[fq][fv][1]; + RO[fq][fv][2] += RO_inst_buf[fq][fv][2]; + RO[fq][fv][3] += RO_inst_buf[fq][fv][3]; + RO[fq][fv][4] += RO_inst_buf[fq][fv][4]; + RO[fq][fv][5] += RO_inst_buf[fq][fv][5]; + RO[fq][fv][6] += RO_inst_buf[fq][fv][6]; + RO[fq][fv][7] += RO_inst_buf[fq][fv][7]; + } + } +} + +template +__device__ __forceinline__ void compute_fp8_sv_inst_buf_fp16_accu(const smem_t &smem_V, uint32_t RS_f8[][num_tiles_k / 2][4], DTypeSVAccum RO[][num_tiles_v][8], float d[][2]) +{ + uint32_t smem_V_row_base = get_lane_id() % 8 + (get_lane_id() / 16) * 8; + // uint32_t smem_V_col_base = get_warp_idx_k() * ((16 * num_tiles_k) / 16) + (get_lane_id() / 8) % 2; + uint32_t smem_V_col_base = (get_lane_id() / 8) % 2; + + uint32_t RO_int32[num_tiles_q][num_tiles_v][4]; + +#pragma unroll + for (uint32_t fk = 0; fk < 1; fk++) + { + uint32_t offset_V = smem_V.get_permuted_offset(smem_V_row_base, smem_V_col_base + fk * 2); +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + // uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fv * 16, smem_V_col_base + fk * 2); + smem_V.ldmatrix_m8n8x4(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + //mma::mma_sync_m16n16k32_row_col_f8f8f32(RO_inst_buf[fq][fv], RS_f8[fq][fk], RV); + mma::mma_sync_m16n16k32_row_col_f8f8f16(RO_int32[fq][fv], RS_f8[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + // ! Not Implemented + } + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V); + } + } + +#pragma unroll + for (uint32_t fk = 1; fk < num_tiles_k / 2; fk++) + { + uint32_t offset_V = smem_V.get_permuted_offset(smem_V_row_base, smem_V_col_base + fk * 2); +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + // uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fv * 16, smem_V_col_base + fk * 2); + smem_V.ldmatrix_m8n8x4(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + //mma::mma_sync_m16n16k32_row_col_f8f8f32(RO_inst_buf[fq][fv], RS_f8[fq][fk], RV); + mma::mma_sync_m16n16k32_row_col_f8f8f16(RO_int32[fq][fv], RS_f8[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + // ! Not Implemented + } + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V); + } + } + float RO_tmp_float[2]; +#pragma unroll + for(int i = 0; i < num_tiles_q; i++){ +#pragma unroll + for(int j = 0; j < num_tiles_v; j++){ + #pragma unroll + for(int k = 0; k < 4; k++){ + unpack_half2_from_uint32_to_float(RO_tmp_float, RO_int32[i][j][k]); + RO[i][j][k * 2 + 0] += RO_tmp_float[0]; + RO[i][j][k * 2 + 1] += RO_tmp_float[1]; + } + } + } + +// #pragma unroll +// for (uint32_t fq = 0; fq < num_tiles_q; fq++) +// { +// #pragma unroll +// for (uint32_t fv = 0; fv < num_tiles_v; fv++) +// { +// RO[fq][fv][0] += RO_inst_buf[fq][fv][0]; +// RO[fq][fv][1] += RO_inst_buf[fq][fv][1]; +// RO[fq][fv][2] += RO_inst_buf[fq][fv][2]; +// RO[fq][fv][3] += RO_inst_buf[fq][fv][3]; +// RO[fq][fv][4] += RO_inst_buf[fq][fv][4]; +// RO[fq][fv][5] += RO_inst_buf[fq][fv][5]; +// RO[fq][fv][6] += RO_inst_buf[fq][fv][6]; +// RO[fq][fv][7] += RO_inst_buf[fq][fv][7]; +// } +// } +} \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu b/kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu new file mode 100644 index 0000000000..f3249f4ed9 --- /dev/null +++ b/kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu @@ -0,0 +1,1380 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../utils.cuh" +#include +#include +#include + +#include "../cp_async.cuh" +#include "../mma.cuh" +#include "../permuted_smem.cuh" +#include "../math.cuh" +#include "../dispatch_utils.h" + +#include "attn_utils.cuh" + +#define PACK_SIZE_QK 16 // as if it is int8 +#define PACK_SIZE_V 8 // fp16 +#define PACK_SIZE_O 8 // fp16 + +// treat as if int8 tensor core +#define MMA_QK_M 16 +#define MMA_QK_N 16 +#define MMA_QK_K 32 + +// fp16 tensor core +#define MMA_SV_M 16 +#define MMA_SV_N 16 +#define MMA_SV_K 16 + +template +__global__ void qk_int_sv_f16_attn_kernel(int8_t *__restrict__ Q, int8_t *__restrict__ K, half *__restrict__ V, DTypeOut *__restrict__ O, float *__restrict__ Lse, + float *__restrict__ Q_scale, float *__restrict__ K_scale, DTypeOut *__restrict__ V_mean, + const uint32_t qo_len, const uint32_t kv_len, const uint32_t num_kv_groups, + const uint32_t stride_bz_q, const uint32_t stride_seq_q, const uint32_t stride_h_q, + const uint32_t stride_bz_k, const uint32_t stride_seq_k, const uint32_t stride_h_k, + const uint32_t stride_bz_v, const uint32_t stride_seq_v, const uint32_t stride_h_v, + const uint32_t stride_bz_o, const uint32_t stride_seq_o, const uint32_t stride_h_o, + float sm_scale) +{ + // compile time check + static_assert(DTypeQK == DataType::kInt8 || DTypeQK == DataType::kInt4, "DTypeQK must be int8 or int4"); + static_assert(Q_GRAN == QuantGranularity::kPerBlock || Q_GRAN == QuantGranularity::kPerWarp || Q_GRAN == QuantGranularity::kPerThread, "Q_GRAN must be kPerBlock, kPerWarp or kPerThread"); + static_assert(K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp || K_GRAN == QuantGranularity::kPerThread, "K_GRAN must be kPerBlock, kPerWarp or kPerThread"); + static_assert(std::is_same::value || !use_inst_buffer, "use_inst_buffer only supports DTypeSVAccum as float"); + static_assert(std::is_same::value || std::is_same::value, "DTypeSVAccum must be float or half"); + static_assert(std::is_same::value || std::is_same::value, "DTypeOut must be half or nv_bfloat16"); + static_assert(head_dim % 64 == 0, "head_dim must be a multiple of 64"); + static_assert(!fuse_v_mean || std::is_same::value, "fuse_v_mean only supports half"); + static_assert(CTA_Q / CTA_K <= 2); // for efficient causal implementation + + using DTypeOut2 = typename std::conditional::value, half2, nv_bfloat162>::type; + + constexpr uint32_t num_warps_q = CTA_Q / WARP_Q; + constexpr uint32_t num_warps_k = CTA_K / WARP_K; + constexpr uint32_t num_warps = num_warps_q * num_warps_k; + constexpr uint32_t num_tiles_q = WARP_Q / MMA_QK_M; + constexpr uint32_t num_tiles_k = WARP_K / MMA_QK_N; + constexpr uint32_t num_tiles_qk_inner = (DTypeQK == DataType::kInt8) ? (head_dim / MMA_QK_K) : (head_dim / 2 / MMA_QK_K); + constexpr uint32_t num_tiles_v = head_dim / MMA_SV_N; + + constexpr uint32_t QK_SMEM_STRIDE = (DTypeQK == DataType::kInt8) ? (head_dim) : (head_dim / 2); + constexpr uint32_t O_SMEM_STRIDE = head_dim; + constexpr uint32_t V_SMEM_STRIDE = head_dim; + + extern __shared__ int8_t smem[]; + + const uint32_t lane_id = get_lane_id(); + const uint32_t warp_id = get_warp_id(); + + // maximize L2 hit rate + const uint32_t batch_id = blockIdx.z; + const uint32_t bx = blockIdx.x; + const uint32_t num_qo_heads = gridDim.y; + const uint32_t head_id = blockIdx.y; + + // transfer to base 2 instead of base e with better numerical efficiency + sm_scale *= math::log2e; + + // RS holds the fragment of S + int32_t RS[num_tiles_q][num_tiles_k][8]; + DTypeSVAccum RO[num_tiles_q][num_tiles_v][8]; + float m[num_tiles_q][2]; // max + float d[num_tiles_q][2]; // denominator + + uint32_t q_scale_idx, k_scale_idx; + + if constexpr (Q_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_q = gridDim.x; + q_scale_idx = batch_id * num_qo_heads * num_block_q + head_id * num_block_q + bx; + } + else if constexpr (Q_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_q = gridDim.x * num_warps_q; + q_scale_idx = batch_id * num_qo_heads * num_warp_block_q + head_id * num_warp_block_q + bx * num_warps_q + get_warp_idx_q(); + } + else if constexpr (Q_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_q = gridDim.x * num_warps_q; + q_scale_idx = batch_id * num_qo_heads * (num_warp_block_q * 8) + head_id * (num_warp_block_q * 8) + bx * (num_warps_q * 8) + get_warp_idx_q() * 8 + lane_id / 4; + } + + if constexpr (K_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_k = div_ceil(kv_len, CTA_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_block_k + (head_id / num_kv_groups) * num_block_k; + } + else if constexpr (K_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_warp_block_k + (head_id / num_kv_groups) * num_warp_block_k + get_warp_idx_k(); + } + else if constexpr (K_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * (num_warp_block_k * 4) + (head_id / num_kv_groups) * (num_warp_block_k * 4) + get_warp_idx_k() * 4 + lane_id % 4; + } + + constexpr uint32_t k_scale_advance_offset = (K_GRAN == QuantGranularity::kPerBlock) ? 1 : (K_GRAN == QuantGranularity::kPerWarp) ? (CTA_K / WARP_K) : (CTA_K / WARP_K) * 4; + + // initialize o, m, d +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + if constexpr (std::is_same::value) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] = 0.0f; + } + } + else if constexpr (std::is_same::value) + { +#pragma unroll + for (uint32_t k = 0; k < 4; k++) + { + ((int32_t*)RO[fq][fv])[k] = 0; + } + } + } + } +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + m[fq][k] = -5000000.0f; + d[fq][k] = 1.0f; + } + } + + constexpr uint32_t K_smem_idx_offset = CTA_Q; + constexpr uint32_t V_smem_idx_offset = CTA_Q + CTA_K; + + constexpr SwizzleMode swizzle_mode_QK = (QK_SMEM_STRIDE == 32) ? SwizzleMode::k32B : (QK_SMEM_STRIDE == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_Q(smem); + smem_t smem_K(smem + K_smem_idx_offset * QK_SMEM_STRIDE); + constexpr SwizzleMode swizzle_mode_V = (V_SMEM_STRIDE == 32) ? SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_V(smem + V_smem_idx_offset * QK_SMEM_STRIDE); + constexpr SwizzleMode swizzle_mode_O = (O_SMEM_STRIDE == 32) ? SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_O(smem); + + constexpr uint32_t global_to_shared_line_lanes_QK = (QK_SMEM_STRIDE == 32) ? 2 : (QK_SMEM_STRIDE == 64) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_QK = (QK_SMEM_STRIDE == 32) ? 16 : (QK_SMEM_STRIDE == 64) ? 8 : 4; + constexpr uint32_t global_to_shared_line_lanes_V = (V_SMEM_STRIDE == 32) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_V = (V_SMEM_STRIDE == 32) ? 8 : 4; + constexpr uint32_t global_to_shared_line_lanes_O = (O_SMEM_STRIDE == 32) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_O = (O_SMEM_STRIDE == 32) ? 8 : 4; + + constexpr uint32_t QK_smem_iters_row = QK_SMEM_STRIDE / (global_to_shared_line_lanes_QK * PACK_SIZE_QK); + constexpr uint32_t Q_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_QK); + constexpr uint32_t K_smem_iters_col = CTA_K / (num_warps * global_to_shared_copy_lines_per_warp_QK); + constexpr uint32_t V_smem_iters_row = V_SMEM_STRIDE / (global_to_shared_line_lanes_V * PACK_SIZE_V); + constexpr uint32_t V_smem_iters_col = CTA_K / (num_warps * global_to_shared_copy_lines_per_warp_V); + constexpr uint32_t O_smem_iters_row = O_SMEM_STRIDE / (global_to_shared_line_lanes_O * PACK_SIZE_O); + constexpr uint32_t O_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_O); + + int8_t *Q_lane_base_ptr = Q + batch_id * stride_bz_q + head_id * stride_h_q + (bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_q + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; + int8_t *K_lane_base_ptr = K + batch_id * stride_bz_k + (head_id / num_kv_groups) * stride_h_k + (CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_k + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; + half *V_lane_base_ptr = V + batch_id * stride_bz_v + (head_id / num_kv_groups) * stride_h_v + (CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_V) * stride_seq_v + (lane_id % global_to_shared_line_lanes_V) * PACK_SIZE_V; + uint32_t Q_smem_offset_load = smem_Q.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * Q_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); + uint32_t K_smem_offset_load = smem_K.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * K_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); + uint32_t V_smem_offset_load = smem_V.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_V * V_smem_iters_col + lane_id / global_to_shared_line_lanes_V, lane_id % global_to_shared_line_lanes_V); + + uint32_t Q_smem_offset_mma = smem_Q.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id % 16, lane_id / 16); + uint32_t K_smem_offset_mma = smem_K.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 8 + (lane_id / 16) * 8, (lane_id / 8) % 2); + uint32_t V_smem_offset_mma = smem_V.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 16, lane_id / 16); + + // for causal masking + uint32_t Q_idx_lane_base = bx * CTA_Q + get_warp_idx_q() * WARP_Q + lane_id / 4; + uint32_t K_idx_lane_base = get_warp_idx_k() * WARP_K + 2 * (lane_id % 4); + + // for loading + uint32_t Q_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; + uint32_t K_load_idx_lane_base = CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; + uint32_t V_load_idx_lane_base = CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_V; + + const uint32_t num_iterations = div_ceil( + mask_mode == MaskMode::kCausal + ? min(kv_len, (bx + 1) * CTA_Q) + : kv_len, + CTA_K); + + // load Q with predicate + load_global_to_share( + &Q_lane_base_ptr, Q_smem_offset_load, stride_seq_q, smem_Q, Q_load_idx_lane_base, qo_len); + cp_async::commit_group(); + cp_async::wait_group<0>(); + __syncthreads(); + + // for num_tiles_qk_inner = 1, we load all Qs in register + uint32_t RQ[num_tiles_q][4]; + if constexpr (num_tiles_qk_inner == 1) + { +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + smem_Q.ldmatrix_m8n8x4(Q_smem_offset_mma, RQ[fq]); + Q_smem_offset_mma = smem_Q.advance_offset_by_row<16>(Q_smem_offset_mma); + } + } + + // load K with predicate + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + float q_scale = Q_scale[q_scale_idx]; + + float original_sm_scale = sm_scale; + float dequant_scale = q_scale * K_scale[k_scale_idx + 0 * k_scale_advance_offset]; + + sm_scale = original_sm_scale * dequant_scale; + + // load V with predicate + load_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V, V_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + K_load_idx_lane_base += CTA_K; + V_load_idx_lane_base += CTA_K; + +#pragma unroll + for (uint32_t iter = 1; iter < num_iterations - 1; iter++) + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]); + } + } + } + + // do not apply causal mask and out of bound mask for these iterations + K_idx_lane_base += CTA_K; + + if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, sm_scale); + } + else if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, sm_scale); + } + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; + RS_32_to_16(RS_f32, RS_f16); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d(RS_f16, d); + } + + __syncthreads(); + + // load K + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K); + cp_async::commit_group(); + + dequant_scale = q_scale * K_scale[k_scale_idx + iter * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // ensure V is ready + cp_async::wait_group<1>(); + __syncthreads(); + + if constexpr (!use_inst_buffer) + { + compute_fp16_sv_permuted( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + else + { + compute_fp16_sv_permuted_inst_buf( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + + __syncthreads(); + // load V + load_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V); + cp_async::commit_group(); + K_load_idx_lane_base += CTA_K; + V_load_idx_lane_base += CTA_K; + } + + // second last iter, apply causal mask + if (num_iterations > 1) + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; + } + } + } + + if constexpr (mask_mode == MaskMode::kCausal) + { + apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); + } + // apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); + K_idx_lane_base += CTA_K; + + if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + else if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; + RS_32_to_16(RS_f32, RS_f16); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d(RS_f16, d); + } + + __syncthreads(); + + // load K with predicate + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + dequant_scale = q_scale * K_scale[k_scale_idx + (num_iterations - 1) * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // ensure V is ready + cp_async::wait_group<1>(); + __syncthreads(); + + if constexpr (!use_inst_buffer) + { + compute_fp16_sv_permuted( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + else + { + compute_fp16_sv_permuted_inst_buf( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + + __syncthreads(); + // load V with predicate + load_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_seq_v, smem_V, V_load_idx_lane_base, kv_len); + cp_async::commit_group(); + K_load_idx_lane_base += CTA_K; + V_load_idx_lane_base += CTA_K; + } + + // last iter, apply causal mask and out of bound mask + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; + } + } + } + + if constexpr (mask_mode == MaskMode::kCausal) + { + apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); + } + // check out of bound in the last iter + apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); + K_idx_lane_base += CTA_K; + + if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + else if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f16[num_tiles_q][num_tiles_k][4]; + RS_32_to_16(RS_f32, RS_f16); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d(RS_f16, d); + } + + // ensure V is ready + cp_async::wait_group<0>(); + __syncthreads(); + + if constexpr (!use_inst_buffer) + { + compute_fp16_sv_permuted( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + else + { + compute_fp16_sv_permuted_inst_buf( + smem_V, RS_f16, RO, d, V_smem_offset_mma); + } + + __syncthreads(); + + } + + // TODO: thread block sync mdo state for num_warps_k > 0 + + normalize_d(RO, m, d); + + // save the result + // if (get_warp_idx_k() == 0) + // { + + // convert half to bfloat16 + if constexpr (std::is_same::value && std::is_same::value) + { +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + ((nv_bfloat162*)RO[fq][fv])[0] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[0])); + ((nv_bfloat162*)RO[fq][fv])[1] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[1])); + ((nv_bfloat162*)RO[fq][fv])[2] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[2])); + ((nv_bfloat162*)RO[fq][fv])[3] = __float22bfloat162_rn(__half22float2(((half2*)RO[fq][fv])[3])); + } + } + } + + // add v_mean + if constexpr (fuse_v_mean) + { + DTypeOut2 v_mean[2]; + DTypeOut *V_mean_lane_ptr = V_mean + batch_id * (num_qo_heads / num_kv_groups) * head_dim + (head_id / num_kv_groups) * head_dim + lane_id % 4 * 2; +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + v_mean[0] = *((DTypeOut2*)(V_mean_lane_ptr + fv * 16)); + v_mean[1] = *((DTypeOut2*)(V_mean_lane_ptr + 8 + fv * 16)); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + ((DTypeOut2*)RO[fq][fv])[0] = __hadd2(((DTypeOut2*)RO[fq][fv])[0], v_mean[0]); + ((DTypeOut2*)RO[fq][fv])[1] = __hadd2(((DTypeOut2*)RO[fq][fv])[1], v_mean[0]); + ((DTypeOut2*)RO[fq][fv])[2] = __hadd2(((DTypeOut2*)RO[fq][fv])[2], v_mean[1]); + ((DTypeOut2*)RO[fq][fv])[3] = __hadd2(((DTypeOut2*)RO[fq][fv])[3], v_mean[1]); + } + } + } + + // save the result to shared memory + uint32_t smem_O_row_base = get_warp_idx_q() * WARP_Q + lane_id / 4; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + uint32_t offset_O = smem_O.get_permuted_offset(smem_O_row_base + fq * MMA_QK_M, fv * (MMA_SV_N / PACK_SIZE_O)); + + if constexpr (std::is_same::value) + { + // convert RO to half + uint32_t RO_f16[4]; +#pragma unroll + for (uint32_t k = 0; k < 4; k++) + { + if constexpr (std::is_same::value) + { + ((half2*)RO_f16)[k] = __float22half2_rn(((float2*)RO[fq][fv])[k]); + } + else if constexpr (std::is_same::value) + { + ((nv_bfloat162*)RO_f16)[k] = __float22bfloat162_rn(((float2*)RO[fq][fv])[k]); + } + } + + ((uint32_t*)(smem_O.base + offset_O))[lane_id % 4] = RO_f16[0]; + ((uint32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[1]; + + // ! permuted, make sure you know what you are doing + ((uint32_t*)(smem_O.base + (offset_O ^ 0x1)))[lane_id % 4] = RO_f16[2]; + ((uint32_t*)(smem_O.base + (offset_O ^ 0x1) + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[3]; + } + else if constexpr (std::is_same::value) + { + ((uint32_t*)(smem_O.base + offset_O))[lane_id % 4] = ((uint32_t*)RO[fq][fv])[0]; + ((uint32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = ((uint32_t*)RO[fq][fv])[1]; + + // ! permuted, make sure you know what you are doing + ((uint32_t*)(smem_O.base + (offset_O ^ 0x1)))[lane_id % 4] = ((uint32_t*)RO[fq][fv])[2]; + ((uint32_t*)(smem_O.base + (offset_O ^ 0x1) + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = ((uint32_t*)RO[fq][fv])[3]; + } + } + } + + // ! do we need to sync here? + __syncwarp(); + + // shared memory to global memory + DTypeOut *O_lane_ptr = O + batch_id * stride_bz_o + head_id * stride_h_o + (bx * CTA_Q + WARP_Q * get_warp_idx_q() + lane_id / global_to_shared_line_lanes_O) * stride_seq_o + lane_id % global_to_shared_line_lanes_O * PACK_SIZE_O; + uint32_t offset_O = smem_O.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id / global_to_shared_line_lanes_O, lane_id % global_to_shared_line_lanes_O); + uint32_t O_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_O; + +#pragma unroll + for (uint32_t i = 0; i < O_smem_iters_col; i++) + { +#pragma unroll + for (uint32_t j = 0; j < O_smem_iters_row; j++) + { + if (O_load_idx_lane_base < qo_len) + { + smem_O.store_128b(offset_O, O_lane_ptr); + } + O_lane_ptr += (global_to_shared_line_lanes_O * PACK_SIZE_O); + offset_O = smem_O.advance_offset_by_column(offset_O); + } + + offset_O = smem_O.advance_offset_by_row(offset_O - (O_smem_iters_row * global_to_shared_line_lanes_O)); + O_lane_ptr += ((global_to_shared_copy_lines_per_warp_O * stride_seq_o) - (O_smem_iters_row * global_to_shared_line_lanes_O * PACK_SIZE_O)); + O_load_idx_lane_base += global_to_shared_copy_lines_per_warp_O; + } + + if constexpr (return_lse) + { + uint32_t lse_idx = bx * CTA_Q + lane_id / 4 + 8 * (lane_id % 4) + WARP_Q * get_warp_idx_q(); + float *lse_lane_ptr = Lse + batch_id * (qo_len * num_qo_heads) + head_id * qo_len + lse_idx; + uint32_t fq = (lane_id % 4) / 2; + uint32_t k = (lane_id % 4) % 2; + + if (lse_idx < qo_len && (lane_id % 4) < 2 * num_tiles_q) + { + lse_lane_ptr[0] = (math::ptx_log2(d[fq][k]) + m[fq][k]); + } + } + + // } +} + +// tensor_layout 0 for [B, N, H, D], 1 for [B, H, N, D] +torch::Tensor qk_int8_sv_f16_accum_f32_attn(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_CONTIGUOUS(query); + CHECK_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + CHECK_DTYPE(value, torch::kHalf); + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int head_dim = query.size(3); + const int batch_size = query.size(0); + + int stride_bz_q = query.stride(0); + int stride_bz_k = key.stride(0); + int stride_bz_v = value.stride(0); + int stride_bz_o = output.stride(0); + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; + int stride_h_q, stride_h_k, stride_h_v, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.size(1); + kv_len = key.size(1); + num_qo_heads = query.size(2); + num_kv_heads = key.size(2); + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(value, batch_size, kv_len, num_kv_heads, head_dim); + + stride_seq_q = query.stride(1); + stride_seq_k = key.stride(1); + stride_seq_v = value.stride(1); + stride_seq_o = output.stride(1); + + stride_h_q = query.stride(2); + stride_h_k = key.stride(2); + stride_h_v = value.stride(2); + stride_h_o = output.stride(2); + } + else if (tensor_layout == 1) + { + qo_len = query.size(2); + kv_len = key.size(2); + num_qo_heads = query.size(1); + num_kv_heads = key.size(1); + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(value, batch_size, num_kv_heads, kv_len, head_dim); + + stride_seq_q = query.stride(2); + stride_seq_k = key.stride(2); + stride_seq_v = value.stride(2); + stride_seq_o = output.stride(2); + + stride_h_q = query.stride(1); + stride_h_k = key.stride(1); + stride_h_v = value.stride(1); + stride_h_o = output.stride(1); + } + else + { + throw std::invalid_argument("tensor_layout must be 0 or 1"); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + torch::Tensor lse = torch::empty({0}); + if (return_lse) + { + lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32)); + } + + auto output_dtype = output.scalar_type(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(half), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f16_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), float, false, DTypeOut, ComputeUnit::kTensorCore, + mask_mode, RETURN_LSE, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data_ptr(), + key.data_ptr(), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(output.data_ptr()), + (RETURN_LSE) ? reinterpret_cast(lse.data_ptr()) : nullptr, + reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_seq_v, stride_h_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return lse; +} + +torch::Tensor qk_int8_sv_f16_accum_f16_attn(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_CONTIGUOUS(query); + CHECK_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + CHECK_DTYPE(value, torch::kHalf); + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int head_dim = query.size(3); + const int batch_size = query.size(0); + + int stride_bz_q = query.stride(0); + int stride_bz_k = key.stride(0); + int stride_bz_v = value.stride(0); + int stride_bz_o = output.stride(0); + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; + int stride_h_q, stride_h_k, stride_h_v, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.size(1); + kv_len = key.size(1); + num_qo_heads = query.size(2); + num_kv_heads = key.size(2); + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(value, batch_size, kv_len, num_kv_heads, head_dim); + + stride_seq_q = query.stride(1); + stride_seq_k = key.stride(1); + stride_seq_v = value.stride(1); + stride_seq_o = output.stride(1); + + stride_h_q = query.stride(2); + stride_h_k = key.stride(2); + stride_h_v = value.stride(2); + stride_h_o = output.stride(2); + } + else if (tensor_layout == 1) + { + qo_len = query.size(2); + kv_len = key.size(2); + num_qo_heads = query.size(1); + num_kv_heads = key.size(1); + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(value, batch_size, num_kv_heads, kv_len, head_dim); + + stride_seq_q = query.stride(2); + stride_seq_k = key.stride(2); + stride_seq_v = value.stride(2); + stride_seq_o = output.stride(2); + + stride_h_q = query.stride(1); + stride_h_k = key.stride(1); + stride_h_v = value.stride(1); + stride_h_o = output.stride(1); + } + else + { + throw std::invalid_argument("tensor_layout must be 0 or 1"); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + torch::Tensor lse = torch::empty({0}); + if (return_lse) + { + lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32)); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.scalar_type(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(half), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f16_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), half, false, DTypeOut, ComputeUnit::kTensorCore, + mask_mode, RETURN_LSE, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data_ptr(), + key.data_ptr(), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(output.data_ptr()), + (RETURN_LSE) ? reinterpret_cast(lse.data_ptr()) : nullptr, + reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_seq_v, stride_h_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return lse; +} + +torch::Tensor qk_int8_sv_f16_accum_f16_attn_inst_buf(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_CONTIGUOUS(query); + CHECK_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + CHECK_DTYPE(value, torch::kHalf); + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int head_dim = query.size(3); + const int batch_size = query.size(0); + + int stride_bz_q = query.stride(0); + int stride_bz_k = key.stride(0); + int stride_bz_v = value.stride(0); + int stride_bz_o = output.stride(0); + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; + int stride_h_q, stride_h_k, stride_h_v, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.size(1); + kv_len = key.size(1); + num_qo_heads = query.size(2); + num_kv_heads = key.size(2); + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(value, batch_size, kv_len, num_kv_heads, head_dim); + + stride_seq_q = query.stride(1); + stride_seq_k = key.stride(1); + stride_seq_v = value.stride(1); + stride_seq_o = output.stride(1); + + stride_h_q = query.stride(2); + stride_h_k = key.stride(2); + stride_h_v = value.stride(2); + stride_h_o = output.stride(2); + } + else if (tensor_layout == 1) + { + qo_len = query.size(2); + kv_len = key.size(2); + num_qo_heads = query.size(1); + num_kv_heads = key.size(1); + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(value, batch_size, num_kv_heads, kv_len, head_dim); + + stride_seq_q = query.stride(2); + stride_seq_k = key.stride(2); + stride_seq_v = value.stride(2); + stride_seq_o = output.stride(2); + + stride_h_q = query.stride(1); + stride_h_k = key.stride(1); + stride_h_v = value.stride(1); + stride_h_o = output.stride(1); + } + else + { + throw std::invalid_argument("tensor_layout must be 0 or 1"); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + torch::Tensor lse = torch::empty({0}); + if (return_lse) + { + lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32)); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.scalar_type(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = (HEAD_DIM == 64) ? 32 : 16; + constexpr int WARP_K = 64; + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(half), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f16_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), float, true, DTypeOut, ComputeUnit::kTensorCore, + mask_mode, RETURN_LSE, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data_ptr(), + key.data_ptr(), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(output.data_ptr()), + (RETURN_LSE) ? reinterpret_cast(lse.data_ptr()) : nullptr, + reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_seq_v, stride_h_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return lse; +} + +torch::Tensor qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_mean, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + CHECK_CUDA(value_mean); + + CHECK_CONTIGUOUS(query); + CHECK_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + CHECK_CONTIGUOUS(value_mean); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + CHECK_DTYPE(value, torch::kHalf); + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + CHECK_DIMS(value_mean, 3); + + const int head_dim = query.size(3); + const int batch_size = query.size(0); + + int stride_bz_q = query.stride(0); + int stride_bz_k = key.stride(0); + int stride_bz_v = value.stride(0); + int stride_bz_o = output.stride(0); + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; + int stride_h_q, stride_h_k, stride_h_v, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.size(1); + kv_len = key.size(1); + num_qo_heads = query.size(2); + num_kv_heads = key.size(2); + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(value, batch_size, kv_len, num_kv_heads, head_dim); + + stride_seq_q = query.stride(1); + stride_seq_k = key.stride(1); + stride_seq_v = value.stride(1); + stride_seq_o = output.stride(1); + + stride_h_q = query.stride(2); + stride_h_k = key.stride(2); + stride_h_v = value.stride(2); + stride_h_o = output.stride(2); + } + else if (tensor_layout == 1) + { + qo_len = query.size(2); + kv_len = key.size(2); + num_qo_heads = query.size(1); + num_kv_heads = key.size(1); + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(value, batch_size, num_kv_heads, kv_len, head_dim); + + stride_seq_q = query.stride(2); + stride_seq_k = key.stride(2); + stride_seq_v = value.stride(2); + stride_seq_o = output.stride(2); + + stride_h_q = query.stride(1); + stride_h_k = key.stride(1); + stride_h_v = value.stride(1); + stride_h_o = output.stride(1); + } + else + { + throw std::invalid_argument("tensor_layout must be 0 or 1"); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + torch::Tensor lse = torch::empty({0}); + if (return_lse) + { + lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32)); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.scalar_type(); + auto value_mean_dtype = value_mean.scalar_type(); + + TORCH_CHECK(value_mean_dtype == output_dtype, "value_mean and output must have the same dtype"); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CHECK_SHAPE(value_mean, batch_size, num_kv_heads, head_dim); + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(half), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f16_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), half, false, DTypeOut, ComputeUnit::kTensorCore, + mask_mode, RETURN_LSE, true>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data_ptr(), + key.data_ptr(), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(output.data_ptr()), + (RETURN_LSE) ? reinterpret_cast(lse.data_ptr()) : nullptr, + reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast(value_mean.data_ptr()), + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_seq_v, stride_h_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return lse; +} \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm89.cuh b/kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm89.cuh new file mode 100644 index 0000000000..b85b8291ed --- /dev/null +++ b/kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm89.cuh @@ -0,0 +1,710 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../utils.cuh" +#include +#include +#include + +#include "../cp_async.cuh" +#include "../mma.cuh" +#include "../permuted_smem.cuh" +#include "../math.cuh" +#include "../dispatch_utils.h" + +#include "attn_utils.cuh" + +#define PACK_SIZE_QK 16 // as if it is int8 +#define PACK_SIZE_V 16 // fp8 +#define PACK_SIZE_O 8 // fp16 + +// treat as if int8 tensor core +#define MMA_QK_M 16 +#define MMA_QK_N 16 +#define MMA_QK_K 32 + +// fp8 tensor core +#define MMA_SV_M 16 +#define MMA_SV_N 16 +#define MMA_SV_K 32 + +template +__global__ void qk_int_sv_f8_attn_kernel(int8_t *__restrict__ Q, int8_t *__restrict__ K, int8_t *__restrict__ V, DTypeOut *__restrict__ O, float *__restrict__ Lse, + float *__restrict__ Q_scale, float *__restrict__ K_scale, float *__restrict__ V_scale, float *__restrict__ V_mean, + const uint32_t qo_len, const uint32_t kv_len, const uint32_t num_kv_groups, + const uint32_t stride_bz_q, const uint32_t stride_seq_q, const uint32_t stride_h_q, + const uint32_t stride_bz_k, const uint32_t stride_seq_k, const uint32_t stride_h_k, + const uint32_t stride_bz_v, const uint32_t stride_h_v, const uint32_t stride_d_v, + const uint32_t stride_bz_o, const uint32_t stride_seq_o, const uint32_t stride_h_o, + float sm_scale) +{ + // compile time check + static_assert(DTypeQK == DataType::kInt8 || DTypeQK == DataType::kInt4, "DTypeQK must be int8 or int4"); + static_assert(Q_GRAN == QuantGranularity::kPerBlock || Q_GRAN == QuantGranularity::kPerWarp || Q_GRAN == QuantGranularity::kPerThread, "Q_GRAN must be kPerBlock, kPerWarp or kPerThread"); + static_assert(K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp || K_GRAN == QuantGranularity::kPerThread, "K_GRAN must be kPerBlock, kPerWarp or kPerThread"); + static_assert(head_dim % 64 == 0, "head_dim must be a multiple of 64"); + static_assert(std::is_same::value, "DTypeSVAccum must be float, half is WIP"); + static_assert(std::is_same::value || std::is_same::value, "DTypeOut must be half or nv_bfloat16"); + static_assert(CTA_K % 64 == 0); + static_assert(CTA_Q / CTA_K <= 2); // for efficient causal implementation + + constexpr uint32_t num_warps_q = CTA_Q / WARP_Q; + constexpr uint32_t num_warps_k = CTA_K / WARP_K; + constexpr uint32_t num_warps = num_warps_q * num_warps_k; + constexpr uint32_t num_tiles_q = WARP_Q / MMA_QK_M; + constexpr uint32_t num_tiles_k = WARP_K / MMA_QK_N; + constexpr uint32_t num_tiles_qk_inner = (DTypeQK == DataType::kInt8) ? (head_dim / MMA_QK_K) : (head_dim / 2 / MMA_QK_K); + constexpr uint32_t num_tiles_v = head_dim / MMA_SV_N; + + constexpr uint32_t QK_SMEM_STRIDE = (DTypeQK == DataType::kInt8) ? (head_dim) : (head_dim / 2); + constexpr uint32_t O_SMEM_STRIDE = head_dim; + // for fp16: head_dim + constexpr uint32_t V_SMEM_STRIDE = CTA_K; + + extern __shared__ int8_t smem[]; + + const uint32_t lane_id = get_lane_id(); + const uint32_t warp_id = get_warp_id(); + + // maximize L2 hit rate + const uint32_t batch_id = blockIdx.z; + const uint32_t bx = blockIdx.x; + const uint32_t num_qo_heads = gridDim.y; + const uint32_t head_id = blockIdx.y; + + // transfer to base 2 instead of base e with better numerical efficiency + sm_scale *= math::log2e; + + // RS holds the fragment of S + int32_t RS[num_tiles_q][num_tiles_k][8]; + DTypeSVAccum RO[num_tiles_q][num_tiles_v][8]; + float m[num_tiles_q][2]; // max + float d[num_tiles_q][2]; // denominator + + uint32_t q_scale_idx, k_scale_idx; + + if constexpr (Q_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_q = gridDim.x; + q_scale_idx = batch_id * num_qo_heads * num_block_q + head_id * num_block_q + bx; + } + else if constexpr (Q_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_q = gridDim.x * num_warps_q; + q_scale_idx = batch_id * num_qo_heads * num_warp_block_q + head_id * num_warp_block_q + bx * num_warps_q + get_warp_idx_q(); + } + else if constexpr (Q_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_q = gridDim.x * num_warps_q; + q_scale_idx = batch_id * num_qo_heads * (num_warp_block_q * 8) + head_id * (num_warp_block_q * 8) + bx * (num_warps_q * 8) + get_warp_idx_q() * 8 + lane_id / 4; + } + + if constexpr (K_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_k = div_ceil(kv_len, CTA_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_block_k + (head_id / num_kv_groups) * num_block_k; + } + else if constexpr (K_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_warp_block_k + (head_id / num_kv_groups) * num_warp_block_k + get_warp_idx_k(); + } + else if constexpr (K_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * (num_warp_block_k * 4) + (head_id / num_kv_groups) * (num_warp_block_k * 4) + get_warp_idx_k() * 4 + lane_id % 4; + } + + constexpr uint32_t k_scale_advance_offset = (K_GRAN == QuantGranularity::kPerBlock) ? 1 : (K_GRAN == QuantGranularity::kPerWarp) ? (CTA_K / WARP_K) : (CTA_K / WARP_K) * 4; + + // initialize o, m, d +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + if constexpr (std::is_same::value) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] = 0.0f; + } + } + else if constexpr (std::is_same::value) + { +#pragma unroll + for (uint32_t k = 0; k < 4; k++) + { + ((int32_t*)RO[fq][fv])[k] = 0; + } + } + } + } +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + m[fq][k] = -5000000.0f; + d[fq][k] = 1.0f; + } + } + + constexpr uint32_t K_smem_idx_offset = CTA_Q; + constexpr uint32_t V_smem_idx_offset = CTA_Q + CTA_K; + + constexpr SwizzleMode swizzle_mode_QK = (QK_SMEM_STRIDE == 32) ? SwizzleMode::k32B : (QK_SMEM_STRIDE == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_Q(smem); + smem_t smem_K(smem + K_smem_idx_offset * QK_SMEM_STRIDE); + // for fp16: 32 + constexpr SwizzleMode swizzle_mode_V = (V_SMEM_STRIDE == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_V(smem + V_smem_idx_offset * QK_SMEM_STRIDE); + constexpr SwizzleMode swizzle_mode_O = (O_SMEM_STRIDE == 32) ? SwizzleMode::k64B : SwizzleMode::k128B; + smem_t smem_O(smem); + + constexpr uint32_t global_to_shared_line_lanes_QK = (QK_SMEM_STRIDE == 32) ? 2 : (QK_SMEM_STRIDE == 64) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_QK = (QK_SMEM_STRIDE == 32) ? 16 : (QK_SMEM_STRIDE == 64) ? 8 : 4; + // for fp16: 32 + constexpr uint32_t global_to_shared_line_lanes_V = (V_SMEM_STRIDE == 64) ? 4 : 8; + // for fp16: 32 + constexpr uint32_t global_to_shared_copy_lines_per_warp_V = (V_SMEM_STRIDE == 64) ? 8 : 4; + constexpr uint32_t global_to_shared_line_lanes_O = (O_SMEM_STRIDE == 32) ? 4 : 8; + constexpr uint32_t global_to_shared_copy_lines_per_warp_O = (O_SMEM_STRIDE == 32) ? 8 : 4; + + constexpr uint32_t QK_smem_iters_row = QK_SMEM_STRIDE / (global_to_shared_line_lanes_QK * PACK_SIZE_QK); + constexpr uint32_t Q_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_QK); + constexpr uint32_t K_smem_iters_col = CTA_K / (num_warps * global_to_shared_copy_lines_per_warp_QK); + constexpr uint32_t V_smem_iters_row = V_SMEM_STRIDE / (global_to_shared_line_lanes_V * PACK_SIZE_V); + // for fp16: CTA_K + constexpr uint32_t V_smem_iters_col = head_dim / (num_warps * global_to_shared_copy_lines_per_warp_V); + constexpr uint32_t O_smem_iters_row = O_SMEM_STRIDE / (global_to_shared_line_lanes_O * PACK_SIZE_O); + constexpr uint32_t O_smem_iters_col = CTA_Q / (num_warps * global_to_shared_copy_lines_per_warp_O); + + int8_t *Q_lane_base_ptr = Q + batch_id * stride_bz_q + head_id * stride_h_q + (bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_q + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; + int8_t *K_lane_base_ptr = K + batch_id * stride_bz_k + (head_id / num_kv_groups) * stride_h_k + (CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK) * stride_seq_k + (lane_id % global_to_shared_line_lanes_QK) * PACK_SIZE_QK; + // for fp16: CTA_K / num_warps * warp_id * stride_seq_v + lane_id / global_to_shared_line_lanes_V * stride_seq_v + int8_t *V_lane_base_ptr = V + batch_id * stride_bz_v + (head_id / num_kv_groups) * stride_h_v + head_dim / num_warps * warp_id * stride_d_v + lane_id / global_to_shared_line_lanes_V * stride_d_v + (lane_id % global_to_shared_line_lanes_V) * PACK_SIZE_V; + uint32_t Q_smem_offset_load = smem_Q.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * Q_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); + uint32_t K_smem_offset_load = smem_K.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_QK * K_smem_iters_col + lane_id / global_to_shared_line_lanes_QK, lane_id % global_to_shared_line_lanes_QK); + uint32_t V_smem_offset_load = smem_V.get_permuted_offset(warp_id * global_to_shared_copy_lines_per_warp_V * V_smem_iters_col + lane_id / global_to_shared_line_lanes_V, lane_id % global_to_shared_line_lanes_V); + + uint32_t Q_smem_offset_mma = smem_Q.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id % 16, lane_id / 16); + uint32_t K_smem_offset_mma = smem_K.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 8 + (lane_id / 16) * 8, (lane_id / 8) % 2); + // for fp 16: + // uint32_t V_smem_offset_mma = smem_V.get_permuted_offset(get_warp_idx_k() * WARP_K + lane_id % 16, lane_id / 16); + uint32_t V_smem_offset_mma = smem_V.get_permuted_offset(lane_id % 8 + (lane_id / 16) * 8, get_warp_idx_k() * WARP_K / PACK_SIZE_V + (lane_id / 8) % 2); + + // for causal masking + uint32_t Q_idx_lane_base = bx * CTA_Q + get_warp_idx_q() * WARP_Q + lane_id / 4; + uint32_t K_idx_lane_base = get_warp_idx_k() * WARP_K + 2 * (lane_id % 4); + + // for loading + uint32_t Q_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; + uint32_t K_load_idx_lane_base = CTA_K / num_warps * warp_id + lane_id / global_to_shared_line_lanes_QK; + + const uint32_t num_iterations = div_ceil( + mask_mode == MaskMode::kCausal + ? min(kv_len, (bx + 1) * CTA_Q) + : kv_len, + CTA_K); + + // load Q with predicate + load_global_to_share( + &Q_lane_base_ptr, Q_smem_offset_load, stride_seq_q, smem_Q, Q_load_idx_lane_base, qo_len); + cp_async::commit_group(); + cp_async::wait_group<0>(); + __syncthreads(); + + // for num_tiles_qk_inner = 1, we load all Qs in register + uint32_t RQ[num_tiles_q][4]; + if constexpr (num_tiles_qk_inner == 1) + { +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + smem_Q.ldmatrix_m8n8x4(Q_smem_offset_mma, RQ[fq]); + Q_smem_offset_mma = smem_Q.advance_offset_by_row<16>(Q_smem_offset_mma); + } + } + + // load K with predicate + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + float q_scale = Q_scale[q_scale_idx]; + + float original_sm_scale = sm_scale; + float dequant_scale = q_scale * K_scale[k_scale_idx + 0 * k_scale_advance_offset]; + + sm_scale = original_sm_scale * dequant_scale; + + // load V + // ! we assume that V is padded. If not, there might be illegal memory access or nan issue. + // for fp16: + // load_global_to_share stride_seq_v + load_fp8_V_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_d_v, smem_V); + cp_async::commit_group(); + + K_load_idx_lane_base += CTA_K; + +#pragma unroll + for (uint32_t iter = 1; iter < num_iterations - 1; iter++) + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]); + } + } + } + + K_idx_lane_base += CTA_K; + + if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, sm_scale); + } + else if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, sm_scale); + } + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f8[num_tiles_q][num_tiles_k / 2][4]; + RS_32_to_8(RS_f32, RS_f8); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d_f8(RS_f8, d); + } + + __syncthreads(); + + // load K without predicate + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K); + cp_async::commit_group(); + + dequant_scale = q_scale * K_scale[k_scale_idx + iter * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // ensure V is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // for fp16: + // compute_fp16_sv_permuted( + // smem_V, RS_f16, RO, d, V_smem_offset_mma); + if constexpr (!use_inst_buffer) + { + compute_fp8_sv( + smem_V, RS_f8, RO, d); + } + else + { + if constexpr (!use_pv_fp16_accu){ + compute_fp8_sv_inst_buf( + smem_V, RS_f8, RO, d); + } + else{ + compute_fp8_sv_inst_buf_fp16_accu( + smem_V, RS_f8, RO, d); + } + } + __syncthreads(); + // load V + // for fp16: + // load_global_to_share stride_seq_v + load_fp8_V_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_d_v, smem_V); + cp_async::commit_group(); + + K_load_idx_lane_base += CTA_K; + } + + // second last iter, apply causal mask + if (num_iterations > 1) + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; + } + } + } + + if constexpr (mask_mode == MaskMode::kCausal) + { + apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); + } + // apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); + K_idx_lane_base += CTA_K; + + if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + else if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f8[num_tiles_q][num_tiles_k / 2][4]; + RS_32_to_8(RS_f32, RS_f8); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d_f8(RS_f8, d); + } + + __syncthreads(); + + // load K with predicate + load_global_to_share( + &K_lane_base_ptr, K_smem_offset_load, stride_seq_k, smem_K, K_load_idx_lane_base, kv_len); + cp_async::commit_group(); + + dequant_scale = q_scale * K_scale[k_scale_idx + (num_iterations - 1) * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // ensure V is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // for fp16: + // compute_fp16_sv_permuted( + // smem_V, RS_f16, RO, d, V_smem_offset_mma); + if constexpr (!use_inst_buffer) + { + compute_fp8_sv( + smem_V, RS_f8, RO, d); + } + else + { + if constexpr (!use_pv_fp16_accu){ + compute_fp8_sv_inst_buf( + smem_V, RS_f8, RO, d); + } + else{ + compute_fp8_sv_inst_buf_fp16_accu( + smem_V, RS_f8, RO, d); + } + } + + __syncthreads(); + // load V + // for fp16: + // load_global_to_share stride_seq_v + load_fp8_V_global_to_share( + &V_lane_base_ptr, V_smem_offset_load, stride_d_v, smem_V); + cp_async::commit_group(); + K_load_idx_lane_base += CTA_K; + } + + // last iter, apply causal mask and out of bound mask + { + // ensure K is ready + cp_async::wait_group<1>(); + __syncthreads(); + + // compute QK^T + if constexpr (num_tiles_qk_inner == 1) + { + compute_int_qk( + smem_K, RS, RQ, K_smem_offset_mma); + } + else + { + compute_int_qk( + smem_Q, smem_K, RS, Q_smem_offset_mma, K_smem_offset_mma); + } + + float RS_f32[num_tiles_q][num_tiles_k][8]; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; + } + } + } + + if constexpr (mask_mode == MaskMode::kCausal) + { + apply_causal_mask(Q_idx_lane_base, K_idx_lane_base, RS_f32); + } + apply_out_of_bound_mask(K_idx_lane_base, RS_f32, kv_len); + K_idx_lane_base += CTA_K; + + if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + else if constexpr (std::is_same::value) + { + update_mdo(RS_f32, RO, m, d, original_sm_scale); + } + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + accumulate_d(RS_f32, d); + } + + uint32_t RS_f8[num_tiles_q][num_tiles_k / 2][4]; + RS_32_to_8(RS_f32, RS_f8); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kTensorCore) + { + accumulate_d_f8(RS_f8, d); + } + + // ensure V is ready + cp_async::wait_group<0>(); + __syncthreads(); + + // for fp16: + // compute_fp16_sv_permuted( + // smem_V, RS_f16, RO, d, V_smem_offset_mma); + if constexpr (!use_inst_buffer) + { + compute_fp8_sv( + smem_V, RS_f8, RO, d); + } + else + { + if constexpr (!use_pv_fp16_accu){ + compute_fp8_sv_inst_buf( + smem_V, RS_f8, RO, d); + } + else{ + compute_fp8_sv_inst_buf_fp16_accu( + smem_V, RS_f8, RO, d); + } + } + + __syncthreads(); + + } + + // TODO: thread block sync mdo state for num_warps_k > 0. Then only one thread block needs to do the final saving. + + normalize_d(RO, m, d); + + // ! here we just implement the case for fp32 acumulation + if constexpr (fuse_v_scale) + { + float v_scale[4]; + float *V_scale_base_ptr = V_scale + batch_id * (num_qo_heads / num_kv_groups) * head_dim + (head_id / num_kv_groups) * head_dim + (lane_id % 4 ) * 2; +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + ((float2*)v_scale)[0] = *((float2*)(V_scale_base_ptr + fv * 16)); + ((float2*)v_scale)[1] = *((float2*)(V_scale_base_ptr + fv * 16 + 8)); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + RO[fq][fv][0] *= v_scale[0]; + RO[fq][fv][1] *= v_scale[1]; + RO[fq][fv][2] *= v_scale[0]; + RO[fq][fv][3] *= v_scale[1]; + RO[fq][fv][4] *= v_scale[2]; + RO[fq][fv][5] *= v_scale[3]; + RO[fq][fv][6] *= v_scale[2]; + RO[fq][fv][7] *= v_scale[3]; + } + } + } + + if constexpr (fuse_v_mean) + { + float v_mean[4]; + float *V_mean_base_ptr = V_mean + batch_id * (num_qo_heads / num_kv_groups) * head_dim + (head_id / num_kv_groups) * head_dim + (lane_id % 4 ) * 2; +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + ((float2*)v_mean)[0] = *((float2*)(V_mean_base_ptr + fv * 16)); + ((float2*)v_mean)[1] = *((float2*)(V_mean_base_ptr + fv * 16 + 8)); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + RO[fq][fv][0] += v_mean[0]; + RO[fq][fv][1] += v_mean[1]; + RO[fq][fv][2] += v_mean[0]; + RO[fq][fv][3] += v_mean[1]; + RO[fq][fv][4] += v_mean[2]; + RO[fq][fv][5] += v_mean[3]; + RO[fq][fv][6] += v_mean[2]; + RO[fq][fv][7] += v_mean[3]; + } + } + } + + // save the result to shared memory + uint32_t smem_O_row_base = get_warp_idx_q() * WARP_Q + lane_id / 4; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + uint32_t offset_O = smem_O.get_permuted_offset(smem_O_row_base + fq * MMA_QK_M, fv * (MMA_SV_N / PACK_SIZE_O)); + + if constexpr (std::is_same::value) + { + // convert RO to half + uint32_t RO_f16[4]; +#pragma unroll + for (uint32_t k = 0; k < 4; k++) + { + if constexpr (std::is_same::value) + { + ((half2*)RO_f16)[k] = __float22half2_rn(((float2*)RO[fq][fv])[k]); + } + else + { + ((nv_bfloat162*)RO_f16)[k] = __float22bfloat162_rn(((float2*)RO[fq][fv])[k]); + } + } + + ((uint32_t*)(smem_O.base + offset_O))[lane_id % 4] = RO_f16[0]; + ((uint32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[1]; + + offset_O = smem_O.get_permuted_offset(smem_O_row_base + fq * MMA_QK_M, fv * (MMA_SV_N / PACK_SIZE_O) + 1); + ((uint32_t*)(smem_O.base + offset_O))[lane_id % 4] = RO_f16[2]; + ((uint32_t*)(smem_O.base + offset_O + 8 * (O_SMEM_STRIDE / PACK_SIZE_O)))[lane_id % 4] = RO_f16[3]; + } + else if constexpr (std::is_same::value) + { + // TODO: not implement + } + } + } + + // ! do we need to sync here? + __syncwarp(); + + // shared memory to global memory + DTypeOut *O_lane_ptr = O + batch_id * stride_bz_o + head_id * stride_h_o + (bx * CTA_Q + WARP_Q * get_warp_idx_q() + lane_id / global_to_shared_line_lanes_O) * stride_seq_o + lane_id % global_to_shared_line_lanes_O * PACK_SIZE_O; + uint32_t offset_O = smem_O.get_permuted_offset(get_warp_idx_q() * WARP_Q + lane_id / global_to_shared_line_lanes_O, lane_id % global_to_shared_line_lanes_O); + uint32_t O_load_idx_lane_base = bx * CTA_Q + CTA_Q / num_warps * warp_id + lane_id / global_to_shared_line_lanes_O; + +#pragma unroll + for (uint32_t i = 0; i < O_smem_iters_col; i++) + { +#pragma unroll + for (uint32_t j = 0; j < O_smem_iters_row; j++) + { + if (O_load_idx_lane_base < qo_len) + { + smem_O.store_128b(offset_O, O_lane_ptr); + } + O_lane_ptr += (global_to_shared_line_lanes_O * PACK_SIZE_O); + offset_O = smem_O.advance_offset_by_column(offset_O); + } + + offset_O = smem_O.advance_offset_by_row(offset_O - (O_smem_iters_row * global_to_shared_line_lanes_O)); + O_lane_ptr += ((global_to_shared_copy_lines_per_warp_O * stride_seq_o) - (O_smem_iters_row * global_to_shared_line_lanes_O * PACK_SIZE_O)); + O_load_idx_lane_base += global_to_shared_copy_lines_per_warp_O; + } + + if constexpr (return_lse) + { + // ! this only works for num_tiles_q = 2 + uint32_t lse_idx = bx * CTA_Q + lane_id / 4 + 8 * (lane_id % 4) + WARP_Q * get_warp_idx_q(); + float *lse_lane_ptr = Lse + batch_id * (qo_len * num_qo_heads) + head_id * qo_len + lse_idx; + uint32_t fq = (lane_id % 4) / 2; + uint32_t k = (lane_id % 4) % 2; + + if (lse_idx < qo_len) + { + lse_lane_ptr[0] = (math::ptx_log2(d[fq][k]) + m[fq][k]); + } + } +} + + + + + + diff --git a/kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu b/kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu new file mode 100644 index 0000000000..9a366b058b --- /dev/null +++ b/kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu @@ -0,0 +1,916 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../utils.cuh" +#include +#include +#include +#include + +#include "../wgmma.cuh" +#include "../math.cuh" +#include "../dispatch_utils.h" + +#include "attn_utils.cuh" + +template +CUtensorMap create_tensor_map_4D(T* gmem_ptr, int d1, int d2, int d3, int d4, int stride1, int stride2, int stride3) { + constexpr int smem_stride = BlockMinorSize * sizeof(T); + static_assert(sizeof(T) == 2 || sizeof(T) == 1); + static_assert(smem_stride == 32 || smem_stride == 64 || smem_stride == 128); + + CUtensorMap tma_map; + void* gmem_address = (void*)gmem_ptr; + uint64_t gmem_prob_shape[5] = {(uint64_t)d4, (uint64_t)d3, (uint64_t)d2, (uint64_t)d1, 1}; + uint64_t gmem_prob_stride[5] = {(uint64_t)stride3 * sizeof(T), (uint64_t)stride2 * sizeof(T), (uint64_t)stride1 * sizeof(T), 0, 0}; + uint32_t smem_box_shape[5] = {uint32_t(BlockMinorSize), uint32_t(BlockMajorSize), 1, 1, 1}; + uint32_t smem_box_stride[5] = {1, 1, 1, 1, 1}; + + CUresult result = cuTensorMapEncodeTiled( + &tma_map, (sizeof(T) == 2) ? CU_TENSOR_MAP_DATA_TYPE_BFLOAT16 : CU_TENSOR_MAP_DATA_TYPE_UINT8, 4, gmem_address, gmem_prob_shape, + gmem_prob_stride, smem_box_shape, smem_box_stride, CU_TENSOR_MAP_INTERLEAVE_NONE, + (swizzle == false) ? CU_TENSOR_MAP_SWIZZLE_NONE : (smem_stride == 128) ? CU_TENSOR_MAP_SWIZZLE_128B : (smem_stride == 64) ? CU_TENSOR_MAP_SWIZZLE_64B : CU_TENSOR_MAP_SWIZZLE_32B, + promotion_mode, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + + assert(result == CUDA_SUCCESS); + + return tma_map; +} + +__device__ __forceinline__ void init_barrier(uint64_t* bar, int thread_count) { + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ( + "mbarrier.init.shared::cta.b64 [%0], %1;\n" + :: "r"(bar_ptr), "r"(thread_count) + ); +} + +template +__device__ __forceinline__ void expect_bytes(uint64_t* bar) { + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n" + :: "r"(bar_ptr), "n"(bytes)); +} + +template +__device__ __forceinline__ void load_async_4D(T *dst, void const* const src_tma_map, uint64_t* bar, int s0, int s1, int s2, int s3) { + uint64_t tma_ptr = reinterpret_cast(src_tma_map); + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(dst)); + + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6}], [%2];" + : + : "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr), + "r"(s0), "r"(s1), "r"(s2), "r"(s3) + : "memory" + ); +} + +template +__device__ __forceinline__ void store_async_4D(void const* dst_tma_map, T *src, int global_token_idx, int global_head_idx, int global_batch_idx) { + uint64_t tma_ptr = reinterpret_cast(dst_tma_map); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(src)); + + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.tile.bulk_group" + " [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(global_token_idx), "r"(global_head_idx), "r"(global_batch_idx) + : "memory" + ); +} + +__device__ __forceinline__ void wait(uint64_t* bar, int kPhaseBit) { + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(mbar_ptr), + "r"(kPhaseBit) + ); +} + +template +__device__ __forceinline__ void arrive(uint64_t* bar) { + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(bar)); + asm volatile ( + "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0], %1;\n" + : + : "r"(mbar_ptr), "n"(count) + : "memory" + ); +} + +template +__global__ void qk_int8_sv_f8_attn_kernel(const __grid_constant__ CUtensorMap tensorMapQ, + const __grid_constant__ CUtensorMap tensorMapK, + const __grid_constant__ CUtensorMap tensorMapV, + float *__restrict__ Q_scale, float *__restrict__ K_scale, float *__restrict__ V_scale, + DTypeOut* O, float *__restrict__ Lse, uint32_t stride_bz_o, uint32_t stride_h_o, uint32_t stride_seq_o, + const uint32_t qo_len, const uint32_t kv_len, const uint32_t num_kv_groups, + float sm_scale) +{ + static_assert(NUM_THREADS == 128); + static_assert(CTA_Q <= CTA_K); + + const uint32_t warp_idx = (threadIdx.x % 128) / 32; + const uint32_t lane_id = threadIdx.x % 32; + + constexpr uint32_t num_tiles_q = CTA_Q / 64; + constexpr uint32_t num_tiles_k = CTA_K / 16; + constexpr uint32_t num_tiles_qk_inner = head_dim / 32; + constexpr uint32_t num_tiles_v = head_dim / 16; + constexpr uint32_t num_tiles_pv_inner = CTA_K / 32; + + const uint32_t batch_id = blockIdx.z; + const uint32_t bx = blockIdx.x; + const uint32_t head_id = blockIdx.y; + const uint32_t num_qo_heads = gridDim.y; + const uint32_t kv_head_id = head_id / num_kv_groups; + + sm_scale *= math::log2e; + + extern __shared__ __align__(128) int8_t smem_[]; + + int8_t *sQ = (int8_t*)smem_; + int8_t *sK = (int8_t*)(smem_ + CTA_Q * head_dim * sizeof(int8_t)); + int8_t *sV = (int8_t*)(smem_ + CTA_Q * head_dim * sizeof(int8_t) + CTA_K * head_dim * sizeof(int8_t)); + half *sO = (half*)smem_; + + int32_t RS[num_tiles_q][num_tiles_k][8]; + float RO[num_tiles_q][num_tiles_v][8]; + float m[num_tiles_q][2]; + float d[num_tiles_q][2]; + + uint32_t q_scale_idx, k_scale_idx; + + if constexpr (Q_GRAN == QuantGranularity::kPerBlock) + { + const uint32_t num_block_q = gridDim.x; + q_scale_idx = batch_id * num_qo_heads * num_block_q + head_id * num_block_q + bx; + } + else if constexpr (Q_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_warp_block_q = gridDim.x * 4; + q_scale_idx = batch_id * num_qo_heads * num_warp_block_q + head_id * num_warp_block_q + bx * 4 + warp_idx; + } + else if constexpr (Q_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_q = gridDim.x * 4; + q_scale_idx = batch_id * num_qo_heads * (num_warp_block_q * 8) + head_id * (num_warp_block_q * 8) + bx * (4 * 8) + warp_idx * 8 + lane_id / 4; + } + + if constexpr (K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp) + { + const uint32_t num_block_k = div_ceil(kv_len, CTA_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * num_block_k + (head_id / num_kv_groups) * num_block_k; + } + else if constexpr (K_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_block_k = div_ceil(kv_len, CTA_K); + k_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * (num_block_k * 4) + (head_id / num_kv_groups) * (num_block_k * 4) + lane_id % 4; + } + + constexpr uint32_t k_scale_advance_offset = (K_GRAN == QuantGranularity::kPerBlock || K_GRAN == QuantGranularity::kPerWarp) ? 1 : 4; + + uint32_t Q_idx_lane_base = bx * CTA_Q + warp_idx * 16 + lane_id / 4; + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + m[fq][0] = -5000000.0f; + m[fq][1] = -5000000.0f; + d[fq][0] = 1.0f; + d[fq][1] = 1.0f; + } + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] = 0.0f; + } + } + } + + __shared__ __align__(8) uint64_t barrier_Q; + __shared__ __align__(8) uint64_t barrier_K; + __shared__ __align__(8) uint64_t barrier_V; + + if (threadIdx.x == 0) { + init_barrier(&barrier_Q, 1); + init_barrier(&barrier_K, 1); + init_barrier(&barrier_V, 1); + } + + __syncthreads(); + + // load Q, K, V + if (threadIdx.x == 0) + { + expect_bytes<(CTA_Q * head_dim) * sizeof(int8_t)>(&barrier_Q); + expect_bytes<(CTA_K * head_dim) * sizeof(int8_t)>(&barrier_K); + expect_bytes<(CTA_K * head_dim) * sizeof(int8_t)>(&barrier_V); + load_async_4D(sQ, &tensorMapQ, &barrier_Q, 0, bx * CTA_Q, head_id, batch_id); + load_async_4D(sK, &tensorMapK, &barrier_K, 0, 0, kv_head_id, batch_id); + load_async_4D(sV, &tensorMapV, &barrier_V, 0, 0, kv_head_id, batch_id); + } + + float q_scale = Q_scale[q_scale_idx]; + float original_sm_scale = sm_scale; + + // wait for Q + wait(&barrier_Q, 0); + + const uint32_t num_iterations = div_ceil( + mask_mode == MaskMode::kCausal + ? min(kv_len, (bx + 1) * CTA_Q) + : kv_len, + CTA_K); + + int p = 1; + for (uint32_t iter = 1; iter < num_iterations; iter++) + { + p ^= 1; + + float dequant_scale = q_scale * K_scale[k_scale_idx + (iter - 1) * k_scale_advance_offset]; + sm_scale = original_sm_scale * dequant_scale; + + // wait for K + wait(&barrier_K, p); + + // compute QK^T + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + int8_t *sQ_local = sQ + fq * 64 * head_dim; + wgmma::wgmma_s8s8s32(RS[fq], sQ_local, sK); +#pragma unroll + for (int k_it = 1; k_it < num_tiles_qk_inner; k_it++) + { + wgmma::wgmma_s8s8s32(RS[fq], &sQ_local[k_it*32], &sK[k_it*32]); + } + } + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + + // load K + if (threadIdx.x == 0) + { + expect_bytes<(CTA_K * head_dim) * sizeof(int8_t)>(&barrier_K); + load_async_4D(sK, &tensorMapK, &barrier_K, 0, iter * CTA_K, kv_head_id, batch_id); + } + + // convert RS to float + float RS_f32[num_tiles_q][num_tiles_k][8]; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]); + } + } + } + + update_mdo(RS_f32, RO, m, d, sm_scale); + + // accumulate d on thread basis +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unrol + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + d[fq][0] += (RS_f32[fq][fk][0] + RS_f32[fq][fk][1] + RS_f32[fq][fk][4] + RS_f32[fq][fk][5]); + d[fq][1] += (RS_f32[fq][fk][2] + RS_f32[fq][fk][3] + RS_f32[fq][fk][6] + RS_f32[fq][fk][7]); + } + } + + uint32_t RS_f8[num_tiles_q][num_tiles_pv_inner][4]; + RS_32_to_8(RS_f32, RS_f8); + + // wait for V + wait(&barrier_V, p); + + float RO_temp[num_tiles_q][num_tiles_v][8]; + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][0], &sV[0]); +#pragma unroll + for (uint32_t v_it = 1; v_it < num_tiles_pv_inner; v_it++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][v_it], &sV[v_it * 32]); + } + } + + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] += RO_temp[fq][fv][k]; + } + } + } + + // load V + if (threadIdx.x == 0) + { + expect_bytes<(CTA_K * head_dim) * sizeof(int8_t)>(&barrier_V); + load_async_4D(sV, &tensorMapV, &barrier_V, iter * CTA_K, 0, kv_head_id, batch_id); + } + } + + { + p ^= 1; + + float dequant_scale = q_scale * K_scale[k_scale_idx + (num_iterations - 1) * k_scale_advance_offset]; + sm_scale = original_sm_scale; + + // wait for K + wait(&barrier_K, p); + + // compute QK^T + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + int8_t *sQ_local = sQ + fq * 64 * head_dim; + wgmma::wgmma_s8s8s32(RS[fq], sQ_local, sK); +#pragma unroll + for (int k_it = 1; k_it < num_tiles_qk_inner; k_it++) + { + wgmma::wgmma_s8s8s32(RS[fq], &sQ_local[k_it*32], &sK[k_it*32]); + } + } + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + + // convert RS to float + float RS_f32[num_tiles_q][num_tiles_k][8]; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RS_f32[fq][fk][k] = __int2float_rz(RS[fq][fk][k]) * dequant_scale; + } + } + } + + // masking +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + const uint32_t q_idx = Q_idx_lane_base + fq * 64 + 8 * ((k % 4) / 2); + const uint32_t k_idx = (num_iterations - 1) * CTA_K + fk * 16 + 2 * (lane_id % 4) + 8 * (k / 4) + k % 2; + + bool is_out_of_bounds; + + if constexpr (mask_mode == MaskMode::kCausal) + { + is_out_of_bounds = (k_idx > q_idx) || (k_idx >= kv_len); + } + else + { + is_out_of_bounds = (k_idx >= kv_len); + } + + if (is_out_of_bounds) + { + RS_f32[fq][fk][k] = -5000000.0f; + } + } + } + } + + update_mdo(RS_f32, RO, m, d, sm_scale); + + // accumulate d on thread basis +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unrol + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + d[fq][0] += (RS_f32[fq][fk][0] + RS_f32[fq][fk][1] + RS_f32[fq][fk][4] + RS_f32[fq][fk][5]); + d[fq][1] += (RS_f32[fq][fk][2] + RS_f32[fq][fk][3] + RS_f32[fq][fk][6] + RS_f32[fq][fk][7]); + } + } + + uint32_t RS_f8[num_tiles_q][num_tiles_pv_inner][4]; + RS_32_to_8(RS_f32, RS_f8); + + // wait for V + wait(&barrier_V, p); + + float RO_temp[num_tiles_q][num_tiles_v][8]; + wgmma::warpgroup_arrive(); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][0], &sV[0]); +#pragma unroll + for (uint32_t v_it = 1; v_it < num_tiles_pv_inner; v_it++) + { + wgmma::wgmma_f8f8f32(RO_temp[fq], RS_f8[fq][v_it], &sV[v_it * 32]); + } + } + + wgmma::warpgroup_commit_batch(); + wgmma::warpgroup_wait<0>(); + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + RO[fq][fv][k] += RO_temp[fq][fv][k]; + } + } + } + } + + normalize_d(RO, m, d); + + if constexpr (fuse_v_scale) + { + float v_scale[4]; + float *V_scale_base_ptr = V_scale + batch_id * (num_qo_heads / num_kv_groups) * head_dim + (head_id / num_kv_groups) * head_dim + (lane_id % 4 ) * 2; + #pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + ((float2*)v_scale)[0] = *((float2*)(V_scale_base_ptr + fv * 16)); + ((float2*)v_scale)[1] = *((float2*)(V_scale_base_ptr + fv * 16 + 8)); + + #pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + RO[fq][fv][0] *= v_scale[0]; + RO[fq][fv][1] *= v_scale[1]; + RO[fq][fv][2] *= v_scale[0]; + RO[fq][fv][3] *= v_scale[1]; + RO[fq][fv][4] *= v_scale[2]; + RO[fq][fv][5] *= v_scale[3]; + RO[fq][fv][6] *= v_scale[2]; + RO[fq][fv][7] *= v_scale[3]; + } + } + } + + DTypeOut *O_lane_ptr = O + batch_id * stride_bz_o + head_id * stride_h_o + (bx * CTA_Q + warp_idx * 16 + (lane_id / 4)) * stride_seq_o + (lane_id % 4) * 2 ; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < head_dim/16; fv++) + { + if (Q_idx_lane_base + fq * 64 < qo_len) + { + if constexpr (std::is_same::value) + { + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[0]); + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[2]); + } + else + { + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[0]); + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[2]); + } + } + + if (Q_idx_lane_base + fq * 64 + 8 < qo_len) + { + if constexpr (std::is_same::value) + { + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 * stride_seq_o))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[1]); + ((half2*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 + 8 * stride_seq_o))[0] = __float22half2_rn(((float2*)(RO[fq][fv]))[3]); + } + else + { + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 * stride_seq_o))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[1]); + ((nv_bfloat162*)(O_lane_ptr + fq * 64 * stride_seq_o + fv * 16 + 8 + 8 * stride_seq_o))[0] = __float22bfloat162_rn(((float2*)(RO[fq][fv]))[3]); + } + } + } + + if constexpr (return_lse) + { + // only works for CTA_Q = 64 + uint32_t lse_idx = bx * CTA_Q + lane_id / 4 + 8 * (lane_id % 4) + 16 * warp_idx; + float *lse_lane_ptr = Lse + batch_id * (qo_len * num_qo_heads) + head_id * qo_len + lse_idx; + uint32_t fq = (lane_id % 4) / 2; + uint32_t k = (lane_id % 4) % 2; + + if (lse_idx < qo_len && (lane_id % 4) < 2) + { + lse_lane_ptr[0] = (math::ptx_log2(d[fq][k]) + m[fq][k]); + } + } + } +} + +torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf( + torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + CHECK_DTYPE(value, at::ScalarType::Float8_e4m3fn); + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int batch_size = query.size(0); + const int head_dim = query.size(3); + + int stride_bz_q = query.stride(0); + int stride_bz_k = key.stride(0); + int stride_bz_v = value.stride(0); + int stride_bz_o = output.stride(0); + + int qo_len, kv_len, padded_kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + assert(value.size(0) == batch_size); + + if (tensor_layout == 0) + { + qo_len = query.size(1); + kv_len = key.size(1); + num_qo_heads = query.size(2); + num_kv_heads = key.size(2); + + stride_seq_q = query.stride(1); + stride_h_q = query.stride(2); + stride_seq_k = key.stride(1); + stride_h_k = key.stride(2); + stride_h_v = value.stride(2); + stride_d_v = value.stride(1); + stride_seq_o = output.stride(1); + stride_h_o = output.stride(2); + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.size(1) == head_dim); + assert(value.size(2) == num_kv_heads); + } + else + { + qo_len = query.size(2); + kv_len = key.size(2); + num_qo_heads = query.size(1); + num_kv_heads = key.size(1); + + stride_seq_q = query.stride(2); + stride_h_q = query.stride(1); + stride_seq_k = key.stride(2); + stride_h_k = key.stride(1); + stride_h_v = value.stride(1); + stride_d_v = value.stride(2); + stride_seq_o = output.stride(2); + stride_h_o = output.stride(1); + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.size(2) == head_dim); + assert(value.size(1) == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + torch::Tensor lse = torch::empty({0}); + if (return_lse) + { + lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32)); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_type = output.scalar_type(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_type, DTypeOut, { + constexpr int CTA_Q = 64; + constexpr int CTA_K = 128; + constexpr int NUM_THREADS = 128; + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + assert(value.size(3) >= div_ceil(kv_len, CTA_K) * CTA_K); + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32))); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K))); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32) * 8)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * 4)); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CUtensorMap tma_map_Q = create_tensor_map_4D(reinterpret_cast(query.data_ptr()), batch_size, num_qo_heads, qo_len, HEAD_DIM, stride_bz_q, stride_h_q, stride_seq_q); + CUtensorMap tma_map_K = create_tensor_map_4D(reinterpret_cast(key.data_ptr()), batch_size, num_kv_heads, kv_len, HEAD_DIM, stride_bz_k, stride_h_k, stride_seq_k); + CUtensorMap tma_map_V = create_tensor_map_4D(reinterpret_cast(value.data_ptr()), batch_size, num_kv_heads, HEAD_DIM, value.size(3), stride_bz_v, stride_h_v, stride_d_v); + + auto* kernel = qk_int8_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), DTypeOut, mask_mode, RETURN_LSE, false>; + size_t sMemSize = CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t); + cudaFuncSetAttribute( + kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, sMemSize); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + kernel<<>>( + tma_map_Q, + tma_map_K, + tma_map_V, + reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + nullptr, + reinterpret_cast(output.data_ptr()), + (RETURN_LSE) ? reinterpret_cast(lse.data_ptr()) : nullptr, + stride_bz_o, stride_h_o, stride_seq_o, + qo_len, kv_len, num_kv_groups, sm_scale); + }); + }); + }); + }); + }); + + return lse; +} + +torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + CHECK_CUDA(value_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + CHECK_CONTIGUOUS(value_scale); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + CHECK_DTYPE(value, at::ScalarType::Float8_e4m3fn); + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + CHECK_DTYPE(value_scale, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + CHECK_DIMS(value_scale, 3); + + const int batch_size = query.size(0); + const int head_dim = query.size(3); + + int stride_bz_q = query.stride(0); + int stride_bz_k = key.stride(0); + int stride_bz_v = value.stride(0); + int stride_bz_o = output.stride(0); + + int qo_len, kv_len, padded_kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + assert(value.size(0) == batch_size); + + if (tensor_layout == 0) + { + qo_len = query.size(1); + kv_len = key.size(1); + num_qo_heads = query.size(2); + num_kv_heads = key.size(2); + + stride_seq_q = query.stride(1); + stride_h_q = query.stride(2); + stride_seq_k = key.stride(1); + stride_h_k = key.stride(2); + stride_h_v = value.stride(2); + stride_d_v = value.stride(1); + stride_seq_o = output.stride(1); + stride_h_o = output.stride(2); + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.size(1) == head_dim); + assert(value.size(2) == num_kv_heads); + } + else + { + qo_len = query.size(2); + kv_len = key.size(2); + num_qo_heads = query.size(1); + num_kv_heads = key.size(1); + + stride_seq_q = query.stride(2); + stride_h_q = query.stride(1); + stride_seq_k = key.stride(2); + stride_h_k = key.stride(1); + stride_h_v = value.stride(1); + stride_d_v = value.stride(2); + stride_seq_o = output.stride(2); + stride_h_o = output.stride(1); + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.size(2) == head_dim); + assert(value.size(1) == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + torch::Tensor lse = torch::empty({0}); + if (return_lse) + { + lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32)); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.scalar_type(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 64; + constexpr int CTA_K = 128; + constexpr int NUM_THREADS = 128; + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + assert(value.size(3) >= div_ceil(kv_len, CTA_K) * CTA_K); + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32))); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K))); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, static_cast(div_ceil(qo_len, CTA_Q) * (NUM_THREADS / 32) * 8)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, static_cast(div_ceil(kv_len, CTA_K) * 4)); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CHECK_SHAPE(value_scale, batch_size, num_kv_heads, head_dim); + + CUtensorMap tma_map_Q = create_tensor_map_4D(reinterpret_cast(query.data_ptr()), batch_size, num_qo_heads, qo_len, HEAD_DIM, stride_bz_q, stride_h_q, stride_seq_q); + CUtensorMap tma_map_K = create_tensor_map_4D(reinterpret_cast(key.data_ptr()), batch_size, num_kv_heads, kv_len, HEAD_DIM, stride_bz_k, stride_h_k, stride_seq_k); + CUtensorMap tma_map_V = create_tensor_map_4D(reinterpret_cast(value.data_ptr()), batch_size, num_kv_heads, HEAD_DIM, value.size(3), stride_bz_v, stride_h_v, stride_d_v); + + auto* kernel = qk_int8_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), DTypeOut, mask_mode, RETURN_LSE, true>; + size_t sMemSize = CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t); + cudaFuncSetAttribute( + kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, sMemSize); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + kernel<<>>( + tma_map_Q, + tma_map_K, + tma_map_V, + reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), + reinterpret_cast(output.data_ptr()), + (RETURN_LSE) ? reinterpret_cast(lse.data_ptr()) : nullptr, + stride_bz_o, stride_h_o, stride_seq_o, + qo_len, kv_len, num_kv_groups, sm_scale); + }); + }); + }); + }); + }); + + return lse; +} \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu new file mode 100644 index 0000000000..7fb573706c --- /dev/null +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu @@ -0,0 +1,180 @@ +#include "attn_cuda_sm89.h" +#include "qk_int_sv_f8_cuda_sm89.cuh" + +torch::Tensor qk_int8_sv_f8_accum_f16_attn_inst_buf(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_CONTIGUOUS(value); // ensure value is contiguous to prevent troubles in the kernel + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + // TODO: how to check fp8 data type? + // CHECK_DTYPE(value, torch::kHalf); + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int batch_size = query.size(0); + const int head_dim = query.size(3); + + int stride_bz_q = query.stride(0); + int stride_bz_k = key.stride(0); + int stride_bz_v = value.stride(0); + int stride_bz_o = output.stride(0); + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.size(1); + kv_len = key.size(1); + num_qo_heads = query.size(2); + num_kv_heads = key.size(2); + + stride_seq_q = query.stride(1); + stride_h_q = query.stride(2); + stride_seq_k = key.stride(1); + stride_h_k = key.stride(2); + stride_h_v = value.stride(2); + stride_d_v = value.stride(1); + stride_seq_o = output.stride(1); + stride_h_o = output.stride(2); + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.size(1) == head_dim); + assert(value.size(2) == num_kv_heads); + } + else + { + qo_len = query.size(2); + kv_len = key.size(2); + num_qo_heads = query.size(1); + num_kv_heads = key.size(1); + + stride_seq_q = query.stride(2); + stride_h_q = query.stride(1); + stride_seq_k = key.stride(2); + stride_h_k = key.stride(1); + stride_h_v = value.stride(1); + stride_d_v = value.stride(2); + stride_seq_o = output.stride(2); + stride_h_o = output.stride(1); + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.size(2) == head_dim); + assert(value.size(1) == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + torch::Tensor lse = torch::empty({0}); + if (return_lse) + { + lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32)); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.scalar_type(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + assert(value.size(0) == batch_size); + assert(value.size(3) >= div_ceil(kv_len, CTA_K) * CTA_K); + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), + float, true, DTypeOut, ComputeUnit::kCudaCore, mask_mode, RETURN_LSE, false, false, true>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data_ptr(), + key.data_ptr(), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(output.data_ptr()), + (RETURN_LSE) ? reinterpret_cast(lse.data_ptr()) : nullptr, + reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + nullptr, + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_h_v, stride_d_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return lse; +} \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu new file mode 100644 index 0000000000..212e4153e2 --- /dev/null +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu @@ -0,0 +1,187 @@ +#include "attn_cuda_sm89.h" +#include "qk_int_sv_f8_cuda_sm89.cuh" +torch::Tensor qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + CHECK_CUDA(value_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_CONTIGUOUS(value); // ensure value is contiguous to prevent troubles in the kernel + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + CHECK_CONTIGUOUS(value_scale); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + // TODO: how to check fp8 data type? + // CHECK_DTYPE(value, torch::kHalf); + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + CHECK_DTYPE(value_scale, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + CHECK_DIMS(value_scale, 3); + + const int batch_size = query.size(0); + const int head_dim = query.size(3); + + int stride_bz_q = query.stride(0); + int stride_bz_k = key.stride(0); + int stride_bz_v = value.stride(0); + int stride_bz_o = output.stride(0); + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.size(1); + kv_len = key.size(1); + num_qo_heads = query.size(2); + num_kv_heads = key.size(2); + + stride_seq_q = query.stride(1); + stride_h_q = query.stride(2); + stride_seq_k = key.stride(1); + stride_h_k = key.stride(2); + stride_h_v = value.stride(2); + stride_d_v = value.stride(1); + stride_seq_o = output.stride(1); + stride_h_o = output.stride(2); + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.size(1) == head_dim); + assert(value.size(2) == num_kv_heads); + } + else + { + qo_len = query.size(2); + kv_len = key.size(2); + num_qo_heads = query.size(1); + num_kv_heads = key.size(1); + + stride_seq_q = query.stride(2); + stride_h_q = query.stride(1); + stride_seq_k = key.stride(2); + stride_h_k = key.stride(1); + stride_h_v = value.stride(1); + stride_d_v = value.stride(2); + stride_seq_o = output.stride(2); + stride_h_o = output.stride(1); + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.size(2) == head_dim); + assert(value.size(1) == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + torch::Tensor lse = torch::empty({0}); + if (return_lse) + { + lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32)); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.scalar_type(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + assert(value.size(0) == batch_size); + assert(value.size(3) >= div_ceil(kv_len, CTA_K) * CTA_K); + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CHECK_SHAPE(value_scale, batch_size, num_kv_heads, head_dim); + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), + float, true, DTypeOut, ComputeUnit::kCudaCore, mask_mode, RETURN_LSE, true, false, true>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data_ptr(), + key.data_ptr(), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(output.data_ptr()), + (RETURN_LSE) ? reinterpret_cast(lse.data_ptr()) : nullptr, + reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_h_v, stride_d_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return lse; +} \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu new file mode 100644 index 0000000000..96ce7334b1 --- /dev/null +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu @@ -0,0 +1,180 @@ +#include "attn_cuda_sm89.h" +#include "qk_int_sv_f8_cuda_sm89.cuh" + +torch::Tensor qk_int8_sv_f8_accum_f32_attn(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_CONTIGUOUS(value); // ensure value is contiguous to prevent troubles in the kernel + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + // TODO: how to check fp8 data type? + // CHECK_DTYPE(value, torch::kHalf); + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int batch_size = query.size(0); + const int head_dim = query.size(3); + + int stride_bz_q = query.stride(0); + int stride_bz_k = key.stride(0); + int stride_bz_v = value.stride(0); + int stride_bz_o = output.stride(0); + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.size(1); + kv_len = key.size(1); + num_qo_heads = query.size(2); + num_kv_heads = key.size(2); + + stride_seq_q = query.stride(1); + stride_h_q = query.stride(2); + stride_seq_k = key.stride(1); + stride_h_k = key.stride(2); + stride_h_v = value.stride(2); + stride_d_v = value.stride(1); + stride_seq_o = output.stride(1); + stride_h_o = output.stride(2); + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.size(1) == head_dim); + assert(value.size(2) == num_kv_heads); + } + else + { + qo_len = query.size(2); + kv_len = key.size(2); + num_qo_heads = query.size(1); + num_kv_heads = key.size(1); + + stride_seq_q = query.stride(2); + stride_h_q = query.stride(1); + stride_seq_k = key.stride(2); + stride_h_k = key.stride(1); + stride_h_v = value.stride(1); + stride_d_v = value.stride(2); + stride_seq_o = output.stride(2); + stride_h_o = output.stride(1); + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.size(2) == head_dim); + assert(value.size(1) == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + torch::Tensor lse = torch::empty({0}); + if (return_lse) + { + lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32)); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.scalar_type(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + assert(value.size(0) == batch_size); + assert(value.size(3) >= div_ceil(kv_len, CTA_K) * CTA_K); + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), + float, false, DTypeOut, ComputeUnit::kCudaCore, mask_mode, RETURN_LSE, false, false, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data_ptr(), + key.data_ptr(), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(output.data_ptr()), + (RETURN_LSE) ? reinterpret_cast(lse.data_ptr()) : nullptr, + reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + nullptr, + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_h_v, stride_d_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return lse; +} \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu new file mode 100644 index 0000000000..b19b07b605 --- /dev/null +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu @@ -0,0 +1,179 @@ +#include "attn_cuda_sm89.h" +#include "qk_int_sv_f8_cuda_sm89.cuh" +torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_CONTIGUOUS(value); // ensure value is contiguous to prevent troubles in the kernel + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + // TODO: how to check fp8 data type? + // CHECK_DTYPE(value, torch::kHalf); + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + + const int batch_size = query.size(0); + const int head_dim = query.size(3); + + int stride_bz_q = query.stride(0); + int stride_bz_k = key.stride(0); + int stride_bz_v = value.stride(0); + int stride_bz_o = output.stride(0); + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.size(1); + kv_len = key.size(1); + num_qo_heads = query.size(2); + num_kv_heads = key.size(2); + + stride_seq_q = query.stride(1); + stride_h_q = query.stride(2); + stride_seq_k = key.stride(1); + stride_h_k = key.stride(2); + stride_h_v = value.stride(2); + stride_d_v = value.stride(1); + stride_seq_o = output.stride(1); + stride_h_o = output.stride(2); + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.size(1) == head_dim); + assert(value.size(2) == num_kv_heads); + } + else + { + qo_len = query.size(2); + kv_len = key.size(2); + num_qo_heads = query.size(1); + num_kv_heads = key.size(1); + + stride_seq_q = query.stride(2); + stride_h_q = query.stride(1); + stride_seq_k = key.stride(2); + stride_h_k = key.stride(1); + stride_h_v = value.stride(1); + stride_d_v = value.stride(2); + stride_seq_o = output.stride(2); + stride_h_o = output.stride(1); + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.size(2) == head_dim); + assert(value.size(1) == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + torch::Tensor lse = torch::empty({0}); + if (return_lse) + { + lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32)); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.scalar_type(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + assert(value.size(0) == batch_size); + assert(value.size(3) >= div_ceil(kv_len, CTA_K) * CTA_K); + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), + float, true, DTypeOut, ComputeUnit::kCudaCore, mask_mode, RETURN_LSE, false, false, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data_ptr(), + key.data_ptr(), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(output.data_ptr()), + (RETURN_LSE) ? reinterpret_cast(lse.data_ptr()) : nullptr, + reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + nullptr, + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_h_v, stride_d_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return lse; +} \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu new file mode 100644 index 0000000000..20e72af6fb --- /dev/null +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu @@ -0,0 +1,187 @@ +#include "attn_cuda_sm89.h" +#include "qk_int_sv_f8_cuda_sm89.cuh" +torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + CHECK_CUDA(value_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_CONTIGUOUS(value); // ensure value is contiguous to prevent troubles in the kernel + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + CHECK_CONTIGUOUS(value_scale); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + // TODO: how to check fp8 data type? + // CHECK_DTYPE(value, torch::kHalf); + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + CHECK_DTYPE(value_scale, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + CHECK_DIMS(value_scale, 3); + + const int batch_size = query.size(0); + const int head_dim = query.size(3); + + int stride_bz_q = query.stride(0); + int stride_bz_k = key.stride(0); + int stride_bz_v = value.stride(0); + int stride_bz_o = output.stride(0); + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.size(1); + kv_len = key.size(1); + num_qo_heads = query.size(2); + num_kv_heads = key.size(2); + + stride_seq_q = query.stride(1); + stride_h_q = query.stride(2); + stride_seq_k = key.stride(1); + stride_h_k = key.stride(2); + stride_h_v = value.stride(2); + stride_d_v = value.stride(1); + stride_seq_o = output.stride(1); + stride_h_o = output.stride(2); + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.size(1) == head_dim); + assert(value.size(2) == num_kv_heads); + } + else + { + qo_len = query.size(2); + kv_len = key.size(2); + num_qo_heads = query.size(1); + num_kv_heads = key.size(1); + + stride_seq_q = query.stride(2); + stride_h_q = query.stride(1); + stride_seq_k = key.stride(2); + stride_h_k = key.stride(1); + stride_h_v = value.stride(1); + stride_d_v = value.stride(2); + stride_seq_o = output.stride(2); + stride_h_o = output.stride(1); + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.size(2) == head_dim); + assert(value.size(1) == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + torch::Tensor lse = torch::empty({0}); + if (return_lse) + { + lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32)); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.scalar_type(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + assert(value.size(0) == batch_size); + assert(value.size(3) >= div_ceil(kv_len, CTA_K) * CTA_K); + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CHECK_SHAPE(value_scale, batch_size, num_kv_heads, head_dim); + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), + float, false, DTypeOut, ComputeUnit::kCudaCore, mask_mode, RETURN_LSE, true, false, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data_ptr(), + key.data_ptr(), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(output.data_ptr()), + (RETURN_LSE) ? reinterpret_cast(lse.data_ptr()) : nullptr, + reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_h_v, stride_d_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return lse; +} \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu new file mode 100644 index 0000000000..83c065ad7e --- /dev/null +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu @@ -0,0 +1,187 @@ +#include "attn_cuda_sm89.h" +#include "qk_int_sv_f8_cuda_sm89.cuh" +torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + CHECK_CUDA(value_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_CONTIGUOUS(value); // ensure value is contiguous to prevent troubles in the kernel + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + CHECK_CONTIGUOUS(value_scale); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + // TODO: how to check fp8 data type? + // CHECK_DTYPE(value, torch::kHalf); + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + CHECK_DTYPE(value_scale, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + CHECK_DIMS(value_scale, 3); + + const int batch_size = query.size(0); + const int head_dim = query.size(3); + + int stride_bz_q = query.stride(0); + int stride_bz_k = key.stride(0); + int stride_bz_v = value.stride(0); + int stride_bz_o = output.stride(0); + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.size(1); + kv_len = key.size(1); + num_qo_heads = query.size(2); + num_kv_heads = key.size(2); + + stride_seq_q = query.stride(1); + stride_h_q = query.stride(2); + stride_seq_k = key.stride(1); + stride_h_k = key.stride(2); + stride_h_v = value.stride(2); + stride_d_v = value.stride(1); + stride_seq_o = output.stride(1); + stride_h_o = output.stride(2); + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.size(1) == head_dim); + assert(value.size(2) == num_kv_heads); + } + else + { + qo_len = query.size(2); + kv_len = key.size(2); + num_qo_heads = query.size(1); + num_kv_heads = key.size(1); + + stride_seq_q = query.stride(2); + stride_h_q = query.stride(1); + stride_seq_k = key.stride(2); + stride_h_k = key.stride(1); + stride_h_v = value.stride(1); + stride_d_v = value.stride(2); + stride_seq_o = output.stride(2); + stride_h_o = output.stride(1); + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.size(2) == head_dim); + assert(value.size(1) == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + torch::Tensor lse = torch::empty({0}); + if (return_lse) + { + lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32)); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.scalar_type(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + assert(value.size(0) == batch_size); + assert(value.size(3) >= div_ceil(kv_len, CTA_K) * CTA_K); + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CHECK_SHAPE(value_scale, batch_size, num_kv_heads, head_dim); + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), + float, true, DTypeOut, ComputeUnit::kCudaCore, mask_mode, RETURN_LSE, true, false, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data_ptr(), + key.data_ptr(), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(output.data_ptr()), + (RETURN_LSE) ? reinterpret_cast(lse.data_ptr()) : nullptr, + reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), + nullptr, + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_h_v, stride_d_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return lse; +} \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu new file mode 100644 index 0000000000..9b2124d598 --- /dev/null +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu @@ -0,0 +1,192 @@ +#include "attn_cuda_sm89.h" +#include "qk_int_sv_f8_cuda_sm89.cuh" +torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_scale, + torch::Tensor value_mean, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + CHECK_CUDA(value_scale); + CHECK_CUDA(value_mean); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_CONTIGUOUS(value); // ensure value is contiguous to prevent troubles in the kernel + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + CHECK_CONTIGUOUS(value_scale); + CHECK_CONTIGUOUS(value_mean); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + // TODO: how to check fp8 data type? + // CHECK_DTYPE(value, torch::kHalf); + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + CHECK_DTYPE(value_scale, torch::kFloat32); + CHECK_DTYPE(value_mean, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + CHECK_DIMS(value_scale, 3); + CHECK_DIMS(value_mean, 3); + + const int batch_size = query.size(0); + const int head_dim = query.size(3); + + int stride_bz_q = query.stride(0); + int stride_bz_k = key.stride(0); + int stride_bz_v = value.stride(0); + int stride_bz_o = output.stride(0); + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + if (tensor_layout == 0) + { + qo_len = query.size(1); + kv_len = key.size(1); + num_qo_heads = query.size(2); + num_kv_heads = key.size(2); + + stride_seq_q = query.stride(1); + stride_h_q = query.stride(2); + stride_seq_k = key.stride(1); + stride_h_k = key.stride(2); + stride_h_v = value.stride(2); + stride_d_v = value.stride(1); + stride_seq_o = output.stride(1); + stride_h_o = output.stride(2); + + CHECK_SHAPE(key, batch_size, kv_len, num_kv_heads, head_dim); + CHECK_SHAPE(output, batch_size, qo_len, num_qo_heads, head_dim); + assert(value.size(1) == head_dim); + assert(value.size(2) == num_kv_heads); + } + else + { + qo_len = query.size(2); + kv_len = key.size(2); + num_qo_heads = query.size(1); + num_kv_heads = key.size(1); + + stride_seq_q = query.stride(2); + stride_h_q = query.stride(1); + stride_seq_k = key.stride(2); + stride_h_k = key.stride(1); + stride_h_v = value.stride(1); + stride_d_v = value.stride(2); + stride_seq_o = output.stride(2); + stride_h_o = output.stride(1); + + CHECK_SHAPE(key, batch_size, num_kv_heads, kv_len, head_dim); + CHECK_SHAPE(output, batch_size, num_qo_heads, qo_len, head_dim); + assert(value.size(2) == head_dim); + assert(value.size(1) == num_kv_heads); + } + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + torch::Tensor lse = torch::empty({0}); + if (return_lse) + { + lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32)); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.scalar_type(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + assert(value.size(0) == batch_size); + assert(value.size(3) >= div_ceil(kv_len, CTA_K) * CTA_K); + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + CHECK_SHAPE(value_scale, batch_size, num_kv_heads, head_dim); + CHECK_SHAPE(value_mean, batch_size, num_kv_heads, head_dim); + + // smem_Q smem_K smem_V smem_O + size_t smem_max = std::max(CTA_Q * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t) + CTA_K * HEAD_DIM * sizeof(int8_t), CTA_Q * HEAD_DIM * sizeof(half)); + + auto kernel_func = qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), static_cast(QK_QUANT_GRAN), + float, false, DTypeOut, ComputeUnit::kCudaCore, mask_mode, RETURN_LSE, true, true, false>; + + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max); + + dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); + dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); + + kernel_func<<>>( + query.data_ptr(), + key.data_ptr(), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(output.data_ptr()), + (RETURN_LSE) ? reinterpret_cast(lse.data_ptr()) : nullptr, + reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), + reinterpret_cast(value_mean.data_ptr()), + qo_len, + kv_len, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_h_v, stride_d_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + }); + }); + }); + }); + }); + + return lse; +} \ No newline at end of file diff --git a/kernels/attention/sage_attn/reduction_utils.cuh b/kernels/attention/sage_attn/reduction_utils.cuh new file mode 100644 index 0000000000..5ba6ef8a61 --- /dev/null +++ b/kernels/attention/sage_attn/reduction_utils.cuh @@ -0,0 +1,194 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Modifications copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#define FINAL_MASK 0xffffffff + +namespace aphrodite { + +template +__inline__ __device__ T warpReduceSum(T val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val += __shfl_xor_sync(0xffffffff, val, mask, 32); + return val; +} + +template +__inline__ __device__ T warpReduceSumV2(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) + { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + } + return (T) (0.0f); +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockReduceSum(T val) { + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + return val; +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockAllReduceSum(T val) { + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (lane < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + return val; +} + +template +__inline__ __device__ T blockReduceSumV2(T* val) +{ + static __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSumV2(val); + + if (lane == 0) + { +#pragma unroll + for (int i = 0; i < NUM; i++) + { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) + { + val[i] = is_mask ? shared[i][lane] : (T) (0.0f); + } + warpReduceSumV2(val); + return (T) 0.0f; +} + +template +__inline__ __device__ T warpReduceMax(T val) +{ +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = max(val, __shfl_xor_sync(0xffffffff, val, mask, 32)); + return val; +} +/* Calculate the maximum of all elements in a block */ +template +__inline__ __device__ T blockReduceMax(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + val = warpReduceMax(val); // get maxx in each warp + if (lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + __syncthreads(); + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; + val = warpReduceMax(val); + return val; +} + +/* Calculate the maximum of all elements in a block */ +template +__inline__ __device__ T blockAllReduceMax(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + val = warpReduceMax(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; + val = warpReduceMax(val); + + return val; +} + +template +__inline__ __device__ T warpReduceMin(T val) +{ +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = min(val, __shfl_xor_sync(0xffffffff, val, mask, 32)); + return val; +} +/* Calculate the minimum of all elements in a block */ +template +__inline__ __device__ T blockReduceMin(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + val = warpReduceMin(val); // get minx in each warp + if (lane == 0) // record in-warp minx by warp Idx + shared[wid] = val; + __syncthreads(); + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : 1e20f; + val = warpReduceMin(val); + return val; +} + +} // namespace aphrodite diff --git a/kernels/attention/sage_attn/utils.cuh b/kernels/attention/sage_attn/utils.cuh new file mode 100644 index 0000000000..d051cc1f44 --- /dev/null +++ b/kernels/attention/sage_attn/utils.cuh @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.is_cuda(), "Tensor " #x " must be on CUDA") +#define CHECK_DTYPE(x, true_dtype) \ + TORCH_CHECK(x.dtype() == true_dtype, \ + "Tensor " #x " must have dtype (" #true_dtype ")") +#define CHECK_DIMS(x, true_dim) \ + TORCH_CHECK(x.dim() == true_dim, \ + "Tensor " #x " must have dimension number (" #true_dim ")") +#define CHECK_NUMEL(x, minimum) \ + TORCH_CHECK(x.numel() >= minimum, \ + "Tensor " #x " must have at last " #minimum " elements") +#define CHECK_SHAPE(x, ...) \ + TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \ + "Tensor " #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), "Tensor " #x " must be contiguous") +#define CHECK_LASTDIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.stride(-1) == 1, \ + "Tensor " #x " must be contiguous at the last dimension") \ No newline at end of file diff --git a/kernels/attention/sage_attn/wgmma.cuh b/kernels/attention/sage_attn/wgmma.cuh new file mode 100644 index 0000000000..f25d6c8755 --- /dev/null +++ b/kernels/attention/sage_attn/wgmma.cuh @@ -0,0 +1,300 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace wgmma{ +__device__ __forceinline__ uint64_t matrix_descriptor_encode(uint64_t x) { return (((x) & 0x3FFFF) >> 0x4); } + +template +__device__ uint64_t make_smem_desc(T* ptr) { + static_assert(stride == 32 || stride == 64 || stride == 128); + uint32_t addr = static_cast(__cvta_generic_to_shared(ptr)); + uint64_t desc = 0x0000000000000000; + desc |= matrix_descriptor_encode(addr); + desc |= matrix_descriptor_encode((uint64_t)16) << 16; + desc |= matrix_descriptor_encode((uint64_t)(8 * stride)) << 32; + desc |= ((stride == 128) ? 1llu : (stride == 64) ? 2llu : 3llu) << 62; + return desc; +} + +__device__ __forceinline__ void warpgroup_arrive() { + asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); +} + +__device__ __forceinline__ void warpgroup_commit_batch() { + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +template +__device__ __forceinline__ void warpgroup_wait() { + static_assert(N >= 0 && N <= 7, "WGMMA wait: N must be in range [0, 7]"); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N) : "memory"); +} + +template +__device__ __forceinline__ void wgmma_m64n128k16_f16f16f32(float d[][8], T* sA, T* sB) { + uint64_t desc_a = make_smem_desc(&sA[0]); + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}, " + " %64," + " %65," + " %66, %67, %68, %69, %70;\n" + "}\n" + : "+f"(d[0][0]), "+f"(d[0][1]), "+f"(d[0][2]), "+f"(d[0][3]), "+f"(d[0][4]), "+f"(d[0][5]), "+f"(d[0][6]), "+f"(d[0][7]), + "+f"(d[1][0]), "+f"(d[1][1]), "+f"(d[1][2]), "+f"(d[1][3]), "+f"(d[1][4]), "+f"(d[1][5]), "+f"(d[1][6]), "+f"(d[1][7]), + "+f"(d[2][0]), "+f"(d[2][1]), "+f"(d[2][2]), "+f"(d[2][3]), "+f"(d[2][4]), "+f"(d[2][5]), "+f"(d[2][6]), "+f"(d[2][7]), + "+f"(d[3][0]), "+f"(d[3][1]), "+f"(d[3][2]), "+f"(d[3][3]), "+f"(d[3][4]), "+f"(d[3][5]), "+f"(d[3][6]), "+f"(d[3][7]), + "+f"(d[4][0]), "+f"(d[4][1]), "+f"(d[4][2]), "+f"(d[4][3]), "+f"(d[4][4]), "+f"(d[4][5]), "+f"(d[4][6]), "+f"(d[4][7]), + "+f"(d[5][0]), "+f"(d[5][1]), "+f"(d[5][2]), "+f"(d[5][3]), "+f"(d[5][4]), "+f"(d[5][5]), "+f"(d[5][6]), "+f"(d[5][7]), + "+f"(d[6][0]), "+f"(d[6][1]), "+f"(d[6][2]), "+f"(d[6][3]), "+f"(d[6][4]), "+f"(d[6][5]), "+f"(d[6][6]), "+f"(d[6][7]), + "+f"(d[7][0]), "+f"(d[7][1]), "+f"(d[7][2]), "+f"(d[7][3]), "+f"(d[7][4]), "+f"(d[7][5]), "+f"(d[7][6]), "+f"(d[7][7]) + : "l"(desc_a), "l"(desc_b), "n"(int32_t(ScaleD)), "n"(int32_t(ScaleA)), + "n"(int32_t(ScaleB)), "n"(int32_t(TransA)), "n"(int32_t(TransB))); +} + +template +__device__ __forceinline__ void wgmma_m64n64k16_f16f16f32(float d[][8], T* sA, T* sB) { + uint64_t desc_a = make_smem_desc(&sA[0]); + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}, " + " %32," + " %33," + " %34, %35, %36, %37, %38;\n" + "}\n" + : "+f"(d[0][0]), "+f"(d[0][1]), "+f"(d[0][2]), "+f"(d[0][3]), "+f"(d[0][4]), "+f"(d[0][5]), "+f"(d[0][6]), "+f"(d[0][7]), + "+f"(d[1][0]), "+f"(d[1][1]), "+f"(d[1][2]), "+f"(d[1][3]), "+f"(d[1][4]), "+f"(d[1][5]), "+f"(d[1][6]), "+f"(d[1][7]), + "+f"(d[2][0]), "+f"(d[2][1]), "+f"(d[2][2]), "+f"(d[2][3]), "+f"(d[2][4]), "+f"(d[2][5]), "+f"(d[2][6]), "+f"(d[2][7]), + "+f"(d[3][0]), "+f"(d[3][1]), "+f"(d[3][2]), "+f"(d[3][3]), "+f"(d[3][4]), "+f"(d[3][5]), "+f"(d[3][6]), "+f"(d[3][7]) + : "l"(desc_a), "l"(desc_b), "n"(int32_t(ScaleD)), "n"(int32_t(ScaleA)), + "n"(int32_t(ScaleB)), "n"(int32_t(TransA)), "n"(int32_t(TransB))); +} + +template +__device__ __forceinline__ void wgmma_m64n128k16_f16f16f32(float d[][8], uint32_t RA[], T* sB) { + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}, " + "{%64, %65, %66, %67}, " + " %68," + " %69, %70, %71, %72;\n" + "}\n" + : "+f"(d[0][0]), "+f"(d[0][1]), "+f"(d[0][2]), "+f"(d[0][3]), "+f"(d[0][4]), "+f"(d[0][5]), "+f"(d[0][6]), "+f"(d[0][7]), + "+f"(d[1][0]), "+f"(d[1][1]), "+f"(d[1][2]), "+f"(d[1][3]), "+f"(d[1][4]), "+f"(d[1][5]), "+f"(d[1][6]), "+f"(d[1][7]), + "+f"(d[2][0]), "+f"(d[2][1]), "+f"(d[2][2]), "+f"(d[2][3]), "+f"(d[2][4]), "+f"(d[2][5]), "+f"(d[2][6]), "+f"(d[2][7]), + "+f"(d[3][0]), "+f"(d[3][1]), "+f"(d[3][2]), "+f"(d[3][3]), "+f"(d[3][4]), "+f"(d[3][5]), "+f"(d[3][6]), "+f"(d[3][7]), + "+f"(d[4][0]), "+f"(d[4][1]), "+f"(d[4][2]), "+f"(d[4][3]), "+f"(d[4][4]), "+f"(d[4][5]), "+f"(d[4][6]), "+f"(d[4][7]), + "+f"(d[5][0]), "+f"(d[5][1]), "+f"(d[5][2]), "+f"(d[5][3]), "+f"(d[5][4]), "+f"(d[5][5]), "+f"(d[5][6]), "+f"(d[5][7]), + "+f"(d[6][0]), "+f"(d[6][1]), "+f"(d[6][2]), "+f"(d[6][3]), "+f"(d[6][4]), "+f"(d[6][5]), "+f"(d[6][6]), "+f"(d[6][7]), + "+f"(d[7][0]), "+f"(d[7][1]), "+f"(d[7][2]), "+f"(d[7][3]), "+f"(d[7][4]), "+f"(d[7][5]), "+f"(d[7][6]), "+f"(d[7][7]) + : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]), + "l"(desc_b), "n"(int32_t(ScaleD)), "n"(int32_t(ScaleA)), + "n"(int32_t(ScaleB)), "n"(int32_t(TransB))); +} + +template +__device__ __forceinline__ void wgmma_m64n64k16_f16f16f32(float d[][8], uint32_t RA[], T* sB) { + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}, " + "{%32, %33, %34, %35}, " + " %36," + " %37, %38, %39, %40;\n" + "}\n" + : "+f"(d[0][0]), "+f"(d[0][1]), "+f"(d[0][2]), "+f"(d[0][3]), "+f"(d[0][4]), "+f"(d[0][5]), "+f"(d[0][6]), "+f"(d[0][7]), + "+f"(d[1][0]), "+f"(d[1][1]), "+f"(d[1][2]), "+f"(d[1][3]), "+f"(d[1][4]), "+f"(d[1][5]), "+f"(d[1][6]), "+f"(d[1][7]), + "+f"(d[2][0]), "+f"(d[2][1]), "+f"(d[2][2]), "+f"(d[2][3]), "+f"(d[2][4]), "+f"(d[2][5]), "+f"(d[2][6]), "+f"(d[2][7]), + "+f"(d[3][0]), "+f"(d[3][1]), "+f"(d[3][2]), "+f"(d[3][3]), "+f"(d[3][4]), "+f"(d[3][5]), "+f"(d[3][6]), "+f"(d[3][7]) + : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]), + "l"(desc_b), "n"(int32_t(ScaleD)), "n"(int32_t(ScaleA)), + "n"(int32_t(ScaleB)), "n"(int32_t(TransB))); +} + +template +__device__ __forceinline__ void wgmma_m64n64k32_f8f8f32(float d[][8], uint32_t RA[], T* sB) { + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}, " + "{%32, %33, %34, %35}, " + " %36," + " %37," + " %38, %39;\n" + "}\n" + : "+f"(d[0][0]), "+f"(d[0][1]), "+f"(d[0][2]), "+f"(d[0][3]), "+f"(d[0][4]), "+f"(d[0][5]), "+f"(d[0][6]), "+f"(d[0][7]), + "+f"(d[1][0]), "+f"(d[1][1]), "+f"(d[1][2]), "+f"(d[1][3]), "+f"(d[1][4]), "+f"(d[1][5]), "+f"(d[1][6]), "+f"(d[1][7]), + "+f"(d[2][0]), "+f"(d[2][1]), "+f"(d[2][2]), "+f"(d[2][3]), "+f"(d[2][4]), "+f"(d[2][5]), "+f"(d[2][6]), "+f"(d[2][7]), + "+f"(d[3][0]), "+f"(d[3][1]), "+f"(d[3][2]), "+f"(d[3][3]), "+f"(d[3][4]), "+f"(d[3][5]), "+f"(d[3][6]), "+f"(d[3][7]) + : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]), + "l"(desc_b), "n"(int32_t(ScaleD)), + "n"(1), "n"(1)); +} + +template +__device__ __forceinline__ void wgmma_m64n128k32_f8f8f32(float d[][8], uint32_t RA[], T* sB) { + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}, " + "{%64, %65, %66, %67}, " + " %68," + " %69," + " %70, %71;\n" + "}\n" + : "+f"(d[0][0]), "+f"(d[0][1]), "+f"(d[0][2]), "+f"(d[0][3]), "+f"(d[0][4]), "+f"(d[0][5]), "+f"(d[0][6]), "+f"(d[0][7]), + "+f"(d[1][0]), "+f"(d[1][1]), "+f"(d[1][2]), "+f"(d[1][3]), "+f"(d[1][4]), "+f"(d[1][5]), "+f"(d[1][6]), "+f"(d[1][7]), + "+f"(d[2][0]), "+f"(d[2][1]), "+f"(d[2][2]), "+f"(d[2][3]), "+f"(d[2][4]), "+f"(d[2][5]), "+f"(d[2][6]), "+f"(d[2][7]), + "+f"(d[3][0]), "+f"(d[3][1]), "+f"(d[3][2]), "+f"(d[3][3]), "+f"(d[3][4]), "+f"(d[3][5]), "+f"(d[3][6]), "+f"(d[3][7]), + "+f"(d[4][0]), "+f"(d[4][1]), "+f"(d[4][2]), "+f"(d[4][3]), "+f"(d[4][4]), "+f"(d[4][5]), "+f"(d[4][6]), "+f"(d[4][7]), + "+f"(d[5][0]), "+f"(d[5][1]), "+f"(d[5][2]), "+f"(d[5][3]), "+f"(d[5][4]), "+f"(d[5][5]), "+f"(d[5][6]), "+f"(d[5][7]), + "+f"(d[6][0]), "+f"(d[6][1]), "+f"(d[6][2]), "+f"(d[6][3]), "+f"(d[6][4]), "+f"(d[6][5]), "+f"(d[6][6]), "+f"(d[6][7]), + "+f"(d[7][0]), "+f"(d[7][1]), "+f"(d[7][2]), "+f"(d[7][3]), "+f"(d[7][4]), "+f"(d[7][5]), "+f"(d[7][6]), "+f"(d[7][7]) + : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]), + "l"(desc_b), "n"(int32_t(ScaleD)), + "n"(1), "n"(1)); +} + +template +__device__ void wgmma_m64n128k32_s8s8s32(int32_t d[][8], T* sA, T* sB) { + uint64_t desc_a = make_smem_desc(&sA[0]); + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}, " + " %64," + " %65," + " %66;\n" + "}\n" + : "+r"(d[0][0]), "+r"(d[0][1]), "+r"(d[0][2]), "+r"(d[0][3]), "+r"(d[0][4]), "+r"(d[0][5]), "+r"(d[0][6]), "+r"(d[0][7]), + "+r"(d[1][0]), "+r"(d[1][1]), "+r"(d[1][2]), "+r"(d[1][3]), "+r"(d[1][4]), "+r"(d[1][5]), "+r"(d[1][6]), "+r"(d[1][7]), + "+r"(d[2][0]), "+r"(d[2][1]), "+r"(d[2][2]), "+r"(d[2][3]), "+r"(d[2][4]), "+r"(d[2][5]), "+r"(d[2][6]), "+r"(d[2][7]), + "+r"(d[3][0]), "+r"(d[3][1]), "+r"(d[3][2]), "+r"(d[3][3]), "+r"(d[3][4]), "+r"(d[3][5]), "+r"(d[3][6]), "+r"(d[3][7]), + "+r"(d[4][0]), "+r"(d[4][1]), "+r"(d[4][2]), "+r"(d[4][3]), "+r"(d[4][4]), "+r"(d[4][5]), "+r"(d[4][6]), "+r"(d[4][7]), + "+r"(d[5][0]), "+r"(d[5][1]), "+r"(d[5][2]), "+r"(d[5][3]), "+r"(d[5][4]), "+r"(d[5][5]), "+r"(d[5][6]), "+r"(d[5][7]), + "+r"(d[6][0]), "+r"(d[6][1]), "+r"(d[6][2]), "+r"(d[6][3]), "+r"(d[6][4]), "+r"(d[6][5]), "+r"(d[6][6]), "+r"(d[6][7]), + "+r"(d[7][0]), "+r"(d[7][1]), "+r"(d[7][2]), "+r"(d[7][3]), "+r"(d[7][4]), "+r"(d[7][5]), "+r"(d[7][6]), "+r"(d[7][7]) + : "l"(desc_a), "l"(desc_b), "n"(int32_t(ScaleD))); +} + +template +__device__ void wgmma_m64n64k32_s8s8s32(int32_t d[][8], T* sA, T* sB) { + uint64_t desc_a = make_smem_desc(&sA[0]); + uint64_t desc_b = make_smem_desc(&sB[0]); + asm volatile( + "{\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34;\n" + "}\n" + : "+r"(d[0][0]), "+r"(d[0][1]), "+r"(d[0][2]), "+r"(d[0][3]), "+r"(d[0][4]), "+r"(d[0][5]), "+r"(d[0][6]), "+r"(d[0][7]), + "+r"(d[1][0]), "+r"(d[1][1]), "+r"(d[1][2]), "+r"(d[1][3]), "+r"(d[1][4]), "+r"(d[1][5]), "+r"(d[1][6]), "+r"(d[1][7]), + "+r"(d[2][0]), "+r"(d[2][1]), "+r"(d[2][2]), "+r"(d[2][3]), "+r"(d[2][4]), "+r"(d[2][5]), "+r"(d[2][6]), "+r"(d[2][7]), + "+r"(d[3][0]), "+r"(d[3][1]), "+r"(d[3][2]), "+r"(d[3][3]), "+r"(d[3][4]), "+r"(d[3][5]), "+r"(d[3][6]), "+r"(d[3][7]) + : "l"(desc_a), "l"(desc_b), "n"(int32_t(ScaleD))); +} + +template +__device__ __forceinline__ void wgmma_f16f16f32(float d[WGMMA_N/16][8], T* sA, T* sB) { + static_assert(std::is_same::value); + + static_assert(WGMMA_N == 128 || WGMMA_N == 64); + if constexpr (WGMMA_N == 128) { + wgmma_m64n128k16_f16f16f32(d, sA, sB); + } + else if constexpr (WGMMA_N == 64) { + wgmma_m64n64k16_f16f16f32(d, sA, sB); + } +} + +template +__device__ __forceinline__ void wgmma_s8s8s32(int32_t d[WGMMA_N/16][8], T* sA, T* sB) { + static_assert(WGMMA_N == 128 || WGMMA_N == 64); + if constexpr (WGMMA_N == 128) { + wgmma_m64n128k32_s8s8s32(d, sA, sB); + } + else if constexpr (WGMMA_N == 64) { + wgmma_m64n64k32_s8s8s32(d, sA, sB); + } +} + +template +__device__ __forceinline__ void wgmma_f8f8f32(float d[][8], uint32_t* RA, T* sB) { + static_assert(WGMMA_N == 128 || WGMMA_N == 64); + if constexpr (WGMMA_N == 128) { + wgmma_m64n128k32_f8f8f32(d, RA, sB); + } + else if constexpr (WGMMA_N == 64) { + wgmma_m64n64k32_f8f8f32(d, RA, sB); + } +} + +} // namespace wgmma \ No newline at end of file diff --git a/kernels/ops.h b/kernels/ops.h index 4cb90f5594..d2ef9ad913 100644 --- a/kernels/ops.h +++ b/kernels/ops.h @@ -432,4 +432,233 @@ void qr_open_handles(fptr_t _fa, const std::vector& handles); void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false); int64_t qr_max_size(); -#endif \ No newline at end of file +#endif + +// Sage Attention +#ifndef USE_ROCM + +void quant_per_block_int8_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + double sm_scale, + int64_t block_size, + int64_t tensor_layout); + +void quant_per_block_int8_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + int64_t block_size, + int64_t tensor_layout); + +void quant_per_block_int8_fuse_sub_mean_cuda( + torch::Tensor input, + torch::Tensor mean, + torch::Tensor output, + torch::Tensor scale, + int64_t block_size, + int64_t tensor_layout); + +void quant_per_warp_int8_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + int64_t block_size, + int64_t warp_block_size, + int64_t tensor_layout); + +void sub_mean_cuda( + torch::Tensor input, + torch::Tensor mean, + torch::Tensor output, + int64_t tensor_layout); + +void transpose_pad_permute_cuda( + torch::Tensor input, + torch::Tensor output, + int64_t tensor_layout); + +void scale_fuse_quant_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + int64_t num_tokens, + double scale_max, + int64_t tensor_layout); + +void mean_scale_fuse_quant_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor mean, + torch::Tensor scale, + int64_t num_tokens, + double scale_max, + int64_t tensor_layout); + +torch::Tensor qk_int8_sv_f16_accum_f32_attn(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse); + +torch::Tensor qk_int8_sv_f16_accum_f16_attn(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse); + +torch::Tensor qk_int8_sv_f16_accum_f16_attn_inst_buf(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse); + +torch::Tensor qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_mean, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse); + +torch::Tensor qk_int8_sv_f8_accum_f32_attn(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse); + +torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse); + +torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_scale, + torch::Tensor value_mean, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse); + +torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse); + +torch::Tensor qk_int8_sv_f8_accum_f16_attn_inst_buf(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse); + +torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse); + +torch::Tensor qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse); + +torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf( + torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse); + +torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_scale, + int64_t tensor_layout, + int64_t is_causal, + int64_t qk_quant_gran, + double sm_scale, + int64_t return_lse); + +#endif diff --git a/kernels/torch_bindings.cpp b/kernels/torch_bindings.cpp index 2536008820..d9125f90a4 100644 --- a/kernels/torch_bindings.cpp +++ b/kernels/torch_bindings.cpp @@ -788,6 +788,277 @@ ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); &top_k_mask_logits); ops.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits); + // Sage Attention + ops.def( + "quant_per_block_int8_cuda(" + " Tensor input," + " Tensor output," + " Tensor scale," + " float sm_scale," + " int block_size," + " int tensor_layout) -> ()"); + ops.impl("quant_per_block_int8_cuda", torch::kCUDA, + static_cast(&quant_per_block_int8_cuda)); + + ops.def( + "quant_per_block_int8_cuda(" + " Tensor input," + " Tensor output," + " Tensor scale," + " int block_size," + " int tensor_layout) -> ()"); + ops.impl("quant_per_block_int8_cuda", torch::kCUDA, + static_cast(&quant_per_block_int8_cuda)); + + ops.def( + "quant_per_block_int8_fuse_sub_mean_cuda(" + " Tensor input," + " Tensor mean," + " Tensor output," + " Tensor scale," + " int block_size," + " int tensor_layout) -> ()"); + ops.impl("quant_per_block_int8_fuse_sub_mean_cuda", torch::kCUDA, &quant_per_block_int8_fuse_sub_mean_cuda); + + ops.def( + "quant_per_warp_int8_cuda(" + " Tensor input," + " Tensor output," + " Tensor scale," + " int block_size," + " int warp_block_size," + " int tensor_layout) -> ()"); + ops.impl("quant_per_warp_int8_cuda", torch::kCUDA, &quant_per_warp_int8_cuda); + + ops.def("sub_mean_cuda(Tensor input, Tensor mean, Tensor output, int tensor_layout) -> ()"); + ops.impl("sub_mean_cuda", torch::kCUDA, &sub_mean_cuda); + + ops.def("transpose_pad_permute_cuda(Tensor input, Tensor output, int tensor_layout) -> ()"); + ops.impl("transpose_pad_permute_cuda", torch::kCUDA, &transpose_pad_permute_cuda); + + ops.def( + "scale_fuse_quant_cuda(" + " Tensor input," + " Tensor output," + " Tensor scale," + " int num_tokens," + " float scale_max," + " int tensor_layout) -> ()"); + ops.impl("scale_fuse_quant_cuda", torch::kCUDA, &scale_fuse_quant_cuda); + + ops.def( + "mean_scale_fuse_quant_cuda(" + " Tensor input," + " Tensor output," + " Tensor mean," + " Tensor scale," + " int num_tokens," + " float scale_max," + " int tensor_layout) -> ()"); + ops.impl("mean_scale_fuse_quant_cuda", torch::kCUDA, &mean_scale_fuse_quant_cuda); + + ops.def( + "qk_int8_sv_f16_accum_f32_attn(" + " Tensor query," + " Tensor key," + " Tensor value," + " Tensor output," + " Tensor query_scale," + " Tensor key_scale," + " int tensor_layout," + " int is_causal," + " int qk_quant_gran," + " float sm_scale," + " int return_lse) -> Tensor"); + ops.impl("qk_int8_sv_f16_accum_f32_attn", torch::kCUDA, &qk_int8_sv_f16_accum_f32_attn); + + ops.def( + "qk_int8_sv_f16_accum_f16_attn(" + " Tensor query," + " Tensor key," + " Tensor value," + " Tensor output," + " Tensor query_scale," + " Tensor key_scale," + " int tensor_layout," + " int is_causal," + " int qk_quant_gran," + " float sm_scale," + " int return_lse) -> Tensor"); + ops.impl("qk_int8_sv_f16_accum_f16_attn", torch::kCUDA, &qk_int8_sv_f16_accum_f16_attn); + + ops.def( + "qk_int8_sv_f16_accum_f16_attn_inst_buf(" + " Tensor query," + " Tensor key," + " Tensor value," + " Tensor output," + " Tensor query_scale," + " Tensor key_scale," + " int tensor_layout," + " int is_causal," + " int qk_quant_gran," + " float sm_scale," + " int return_lse) -> Tensor"); + ops.impl("qk_int8_sv_f16_accum_f16_attn_inst_buf", torch::kCUDA, &qk_int8_sv_f16_accum_f16_attn_inst_buf); + + ops.def( + "qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(" + " Tensor query," + " Tensor key," + " Tensor value," + " Tensor output," + " Tensor query_scale," + " Tensor key_scale," + " Tensor value_mean," + " int tensor_layout," + " int is_causal," + " int qk_quant_gran," + " float sm_scale," + " int return_lse) -> Tensor"); + ops.impl("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn", torch::kCUDA, &qk_int8_sv_f16_accum_f16_fuse_v_mean_attn); + + ops.def( + "qk_int8_sv_f8_accum_f32_attn(" + " Tensor query," + " Tensor key," + " Tensor value," + " Tensor output," + " Tensor query_scale," + " Tensor key_scale," + " int tensor_layout," + " int is_causal," + " int qk_quant_gran," + " float sm_scale," + " int return_lse) -> Tensor"); + ops.impl("qk_int8_sv_f8_accum_f32_attn", torch::kCUDA, &qk_int8_sv_f8_accum_f32_attn); + + ops.def( + "qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(" + " Tensor query," + " Tensor key," + " Tensor value," + " Tensor output," + " Tensor query_scale," + " Tensor key_scale," + " Tensor value_scale," + " int tensor_layout," + " int is_causal," + " int qk_quant_gran," + " float sm_scale," + " int return_lse) -> Tensor"); + ops.impl("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn", torch::kCUDA, &qk_int8_sv_f8_accum_f32_fuse_v_scale_attn); + + ops.def( + "qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(" + " Tensor query," + " Tensor key," + " Tensor value," + " Tensor output," + " Tensor query_scale," + " Tensor key_scale," + " Tensor value_scale," + " Tensor value_mean," + " int tensor_layout," + " int is_causal," + " int qk_quant_gran," + " float sm_scale," + " int return_lse) -> Tensor"); + ops.impl("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn", torch::kCUDA, &qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn); + + ops.def( + "qk_int8_sv_f8_accum_f32_attn_inst_buf(" + " Tensor query," + " Tensor key," + " Tensor value," + " Tensor output," + " Tensor query_scale," + " Tensor key_scale," + " int tensor_layout," + " int is_causal," + " int qk_quant_gran," + " float sm_scale," + " int return_lse) -> Tensor"); + ops.impl("qk_int8_sv_f8_accum_f32_attn_inst_buf", torch::kCUDA, &qk_int8_sv_f8_accum_f32_attn_inst_buf); + + ops.def( + "qk_int8_sv_f8_accum_f16_attn_inst_buf(" + " Tensor query," + " Tensor key," + " Tensor value," + " Tensor output," + " Tensor query_scale," + " Tensor key_scale," + " int tensor_layout," + " int is_causal," + " int qk_quant_gran," + " float sm_scale," + " int return_lse) -> Tensor"); + ops.impl("qk_int8_sv_f8_accum_f16_attn_inst_buf", torch::kCUDA, &qk_int8_sv_f8_accum_f16_attn_inst_buf); + + ops.def( + "qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(" + " Tensor query," + " Tensor key," + " Tensor value," + " Tensor output," + " Tensor query_scale," + " Tensor key_scale," + " Tensor value_scale," + " int tensor_layout," + " int is_causal," + " int qk_quant_gran," + " float sm_scale," + " int return_lse) -> Tensor"); + ops.impl("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf", torch::kCUDA, &qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf); + + ops.def( + "qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(" + " Tensor query," + " Tensor key," + " Tensor value," + " Tensor output," + " Tensor query_scale," + " Tensor key_scale," + " Tensor value_scale," + " int tensor_layout," + " int is_causal," + " int qk_quant_gran," + " float sm_scale," + " int return_lse) -> Tensor"); + ops.impl("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf", torch::kCUDA, &qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf); + + ops.def( + "qk_int8_sv_f8_accum_f32_attn_inst_buf(" + " Tensor query," + " Tensor key," + " Tensor value," + " Tensor output," + " Tensor query_scale," + " Tensor key_scale," + " int tensor_layout," + " int is_causal," + " int qk_quant_gran," + " float sm_scale," + " int return_lse) -> Tensor"); + ops.impl("qk_int8_sv_f8_accum_f32_attn_inst_buf", torch::kCUDA, &qk_int8_sv_f8_accum_f32_attn_inst_buf); + + ops.def( + "qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(" + " Tensor query," + " Tensor key," + " Tensor value," + " Tensor output," + " Tensor query_scale," + " Tensor key_scale," + " Tensor value_scale," + " int tensor_layout," + " int is_causal," + " int qk_quant_gran," + " float sm_scale," + " int return_lse) -> Tensor"); + ops.impl("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf", torch::kCUDA, &qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf); + #endif } diff --git a/tools/generate_torch_registration.py b/tools/generate_torch_registration.py new file mode 100755 index 0000000000..99bbeba316 --- /dev/null +++ b/tools/generate_torch_registration.py @@ -0,0 +1,407 @@ +#!/usr/bin/env python3 +""" +Script to generate PyTorch operation registration code from C++ function +signatures. + +Usage: + python generate_torch_registration.py input_file.h [--namespace my_ops] \ + [--backend torch::kCUDA] +""" + +import re +import argparse +import sys +from typing import List, Tuple, Optional, Dict + + +class CPPToTorchMapper: + """Maps C++ types to PyTorch schema types.""" + + def __init__(self): + # Mapping from C++ types to PyTorch schema types + self.type_mapping = { + # Tensor types + 'torch::Tensor&': 'Tensor!', + 'const torch::Tensor&': 'Tensor', + 'torch::Tensor': 'Tensor', + 'std::optional&': 'Tensor!?', + 'const std::optional&': 'Tensor?', + 'std::optional': 'Tensor?', + 'std::vector&': 'Tensor[]', # Note: aliasing needs manual handling # noqa: E501 + 'const std::vector&': 'Tensor[]', + 'std::vector': 'Tensor[]', + + # Scalar types + 'int64_t': 'int', + 'const int64_t': 'int', + 'int64_t&': 'int', # Unusual but possible + 'double': 'float', + 'const double': 'float', + 'float': 'float', + 'const float': 'float', + 'bool': 'bool', + 'const bool': 'bool', + + # String types + 'std::string': 'str', + 'const std::string&': 'str', + 'std::string&': 'str', + + # Vector types + 'std::vector': 'int[]', + 'const std::vector&': 'int[]', + 'std::vector': 'str[]', + 'const std::vector&': 'str[]', + + # PyTorch specific types + 'at::ScalarType': 'ScalarType', + 'const at::ScalarType': 'ScalarType', + 'std::optional': 'ScalarType?', + 'const std::optional&': 'ScalarType?', + } + + # Common type aliases that might appear in headers + self.type_aliases = { + 'c10::optional': 'std::optional', + } + + def normalize_type(self, cpp_type: str) -> str: + """Normalize C++ type string by removing extra spaces and applying + aliases.""" + # Remove extra whitespace + cpp_type = re.sub(r'\s+', ' ', cpp_type.strip()) + + # Apply type aliases + for alias, replacement in self.type_aliases.items(): + cpp_type = cpp_type.replace(alias, replacement) + + return cpp_type + + def cpp_to_schema_type(self, cpp_type: str) -> str: + """Convert C++ type to PyTorch schema type.""" + normalized = self.normalize_type(cpp_type) + + if normalized in self.type_mapping: + return self.type_mapping[normalized] + + # Handle some common variations + if 'torch::Tensor' in normalized and 'optional' in normalized: + if '&' in normalized and 'const' not in normalized: + return 'Tensor!?' + else: + return 'Tensor?' + elif 'torch::Tensor' in normalized and '&' in normalized and \ + 'const' not in normalized: + return 'Tensor!' + elif 'torch::Tensor' in normalized and 'const' in normalized or \ + 'torch::Tensor' in normalized: + return 'Tensor' + + # Fallback - silently use a placeholder + return 'Unknown' + + +class FunctionSignatureParser: + """Parses C++ function signatures into structured data.""" + + def __init__(self): + self.mapper = CPPToTorchMapper() + + def parse_signature(self, signature: str) -> Optional[Dict]: + """Parse a C++ function signature into components.""" + # Remove comments and normalize whitespace + signature = re.sub(r'//.*$', '', signature, flags=re.MULTILINE) + signature = re.sub(r'/\*.*?\*/', '', signature, flags=re.DOTALL) + signature = re.sub(r'\s+', ' ', signature.strip()) + + # Match function signature pattern + # Return type, function name, parameters + pattern = r'(\w+(?:\s*::\s*\w+)*(?:\s*<[^>]*>)?(?:\s*&)?)\s+(\w+)\s*\((.*?)\)\s*;?' # noqa: E501 + match = re.match(pattern, signature, re.DOTALL) + + if not match: + return None + + return_type, func_name, params_str = match.groups() + + # Parse parameters + parameters = self.parse_parameters(params_str) + + return { + 'return_type': return_type.strip(), + 'function_name': func_name.strip(), + 'parameters': parameters, + 'raw_signature': signature + } + + def parse_parameters(self, params_str: str) -> List[Tuple[str, str]]: + """Parse parameter list into (type, name) tuples.""" + if not params_str.strip(): + return [] + + parameters = [] + # Split by comma, but be careful of nested templates and function + # pointers + param_parts = self.split_parameters(params_str) + + for param in param_parts: + param = param.strip() + if not param: + continue + + # Extract type and name + # Handle cases like "const torch::Tensor& input", "int64_t size", + # etc. + # Look for the last identifier as the parameter name + tokens = param.split() + if len(tokens) >= 2: + param_name = tokens[-1] + param_type = ' '.join(tokens[:-1]) + else: + # No parameter name provided, generate one + param_type = param + param_name = f"param_{len(parameters)}" + + parameters.append((param_type, param_name)) + + return parameters + + def split_parameters(self, params_str: str) -> List[str]: + """Split parameter string by commas, respecting nested templates.""" + parameters = [] + current_param = "" + paren_depth = 0 + angle_depth = 0 + + for char in params_str: + if char == '<': + angle_depth += 1 + elif char == '>': + angle_depth -= 1 + elif char == '(': + paren_depth += 1 + elif char == ')': + paren_depth -= 1 + elif char == ',' and paren_depth == 0 and angle_depth == 0: + parameters.append(current_param.strip()) + current_param = "" + continue + + current_param += char + + if current_param.strip(): + parameters.append(current_param.strip()) + + return parameters + +class TorchRegistrationGenerator: + """Generates PyTorch registration code from parsed function signatures.""" + + def __init__( + self, namespace: str = "my_ops", + backend: str = "torch::kCUDA" + ): + self.namespace = namespace + self.backend = backend + self.mapper = CPPToTorchMapper() + + def generate_schema(self, func_info: Dict) -> str: + """Generate PyTorch schema string from function info.""" + func_name = func_info['function_name'] + parameters = func_info['parameters'] + return_type = func_info['return_type'] + + # Convert parameters to schema format + schema_params = [] + for param_type, param_name in parameters: + schema_type = self.mapper.cpp_to_schema_type(param_type) + schema_params.append(f"{schema_type} {param_name}") + + # Determine return type + if return_type == 'void': + schema_return = '()' + elif 'torch::Tensor' in return_type and 'vector' not in return_type.lower(): + schema_return = 'Tensor' + elif 'std::vector' in return_type: + schema_return = 'Tensor[]' + elif 'int64_t' in return_type: + schema_return = 'int' + elif 'bool' in return_type: + schema_return = 'bool' + elif 'string' in return_type.lower(): + schema_return = 'str' + else: + schema_return = f'Unknown /* {return_type} */' + + # Build schema string + params_str = ', '.join(schema_params) + schema = f"{func_name}({params_str}) -> {schema_return}" + + return schema + + def generate_registration(self, func_info: Dict) -> str: + """Generate complete registration code for a function.""" + func_name = func_info['function_name'] + schema = self.generate_schema(func_info) + + # Split long schemas for readability + if len(schema) > 80: + def_code = self.format_multiline_def(func_name, schema) + else: + def_code = f' ops.def("{schema}");' + + impl_code = f' ops.impl("{func_name}", {self.backend}, &{func_name});' + + return f"{def_code}\n{impl_code}" + + def format_multiline_def(self, func_name: str, schema: str) -> str: + """Format a long schema across multiple lines for ops.def().""" + # Find the parameter list + if '(' in schema and ')' in schema: + func_part = schema[:schema.find('(') + 1] + params_part = schema[schema.find('(') + 1:schema.rfind(')')] + return_part = schema[schema.rfind(')'):] + + # Split parameters + params = [p.strip() for p in params_part.split(',') if p.strip()] + + if len(params) <= 3: + return f' ops.def("{schema}");' + + # Format as proper C++ multiline string literal + lines = [] + lines.append(' ops.def(') + lines.append(f' "{func_part}"') + + for i, param in enumerate(params): + if i == len(params) - 1: + # Last parameter - add return part + lines.append(f' " {param}{return_part}");') + else: + lines.append(f' " {param},"') + + return '\n'.join(lines) + + return f' ops.def("{schema}");' + + def generate_library_block(self, func_infos: List[Dict]) -> str: + """Generate complete TORCH_LIBRARY block.""" + header = f"TORCH_LIBRARY({self.namespace}, ops) {{\n" + + registrations = [] + for func_info in func_infos: + registration = self.generate_registration(func_info) + registrations.append( + f" // {func_info['function_name']}\n{registration}") + + body = '\n\n'.join(registrations) + footer = "\n}" + + return header + body + footer + + def generate_raw_registrations(self, func_infos: List[Dict]) -> str: + """Generate just the registration statements without wrapper.""" + registrations = [] + for func_info in func_infos: + registration = self.generate_registration(func_info) + registrations.append(registration) + + return '\n\n'.join(registrations) + +def parse_header_file(filepath: str) -> List[str]: + """Extract function signatures from a header file.""" + try: + with open(filepath, 'r') as f: + content = f.read() + except FileNotFoundError: + print(f"Error: File '{filepath}' not found.", file=sys.stderr) + sys.exit(1) + + signatures = [] + lines = content.split('\n') + current_signature = "" + in_function = False + + for line in lines: + line = line.strip() + + # Skip preprocessor directives, comments, and empty lines + if (line.startswith('#') or line.startswith('//') or + line.startswith('/*') or not line): + continue + + # Skip common non-function declarations + if (line.startswith('class ') or line.startswith('struct ') or + line.startswith('namespace ') or line.startswith('using ') or + line.startswith('typedef ') or line.startswith('template')): + continue + + # Accumulate multi-line function signatures + if ('(' in line and not line.endswith(';')) or in_function: + in_function = True + current_signature += " " + line + if ';' in line: + signatures.append(current_signature.strip()) + current_signature = "" + in_function = False + elif '(' in line and ';' in line: + # Single-line function signature + signatures.append(line) + + return signatures + +def main(): + parser = argparse.ArgumentParser( + description="Generate PyTorch registration code from C++ function" + " signatures." + ) + parser.add_argument( + "input_file", help="Input header file with C++ function signatures") + parser.add_argument( + "--namespace", default="my_ops", help="PyTorch library namespace") + parser.add_argument( + "--backend", default="torch::kCUDA", help="Backend for ops.impl") + parser.add_argument( + "--output", "-o", help="Output file (default: stdout)") + parser.add_argument( + "--with-library", + action="store_true", + help="Include TORCH_LIBRARY wrapper (default: raw statements only)") + + args = parser.parse_args() + + # Parse input file + signatures = parse_header_file(args.input_file) + + if not signatures: + sys.exit(1) + + # Parse signatures + parser = FunctionSignatureParser() + func_infos = [] + + for sig in signatures: + func_info = parser.parse_signature(sig) + if func_info: + func_infos.append(func_info) + + if not func_infos: + sys.exit(1) + + # Generate registration code + generator = TorchRegistrationGenerator(args.namespace, args.backend) + if args.with_library: + registration_code = generator.generate_library_block(func_infos) + else: + registration_code = generator.generate_raw_registrations(func_infos) + + # Output + if args.output: + with open(args.output, 'w') as f: + f.write(registration_code) + else: + print(registration_code) + + +if __name__ == "__main__": + main() From fc2d021a6296a6362f0d3f358f586546a0c1a204 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 12 Sep 2025 05:40:33 +0000 Subject: [PATCH 2/3] port sageattn triton kernels --- CMakeLists.txt | 83 +- aphrodite/_custom_ops.py | 317 ++++ .../attention/ops/sage_attention/__init__.py | 14 + .../attn_qk_int8_block_varlen.py | 196 +++ .../sage_attention/attn_qk_int8_per_block.py | 358 +++++ .../attn_qk_int8_per_block_causal.py | 330 ++++ .../attn_qk_int8_per_block_causal_varlen.py | 205 +++ .../attention/ops/sage_attention/core.py | 1367 +++++++++++++++++ .../attention/ops/sage_attention/quant.py | 348 +++++ .../ops/sage_attention/quant_per_block.py | 186 +++ .../sage_attention/quant_per_block_varlen.py | 165 ++ .../ops/sage_attention/quant_per_thread.py | 379 +++++ aphrodite/v1/attention/backends/sage_attn.py | 2 +- kernels/attention/sage_attn/fused/fused.cu | 16 + .../qattn/qk_int_sv_f16_cuda_sm80.cu | 7 + .../sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu | 4 + ...9_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu | 2 + ...f8_accum_f16_fuse_v_scale_attn_inst_buf.cu | 2 + .../sm89_qk_int8_sv_f8_accum_f32_attn.cu | 2 + ...9_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu | 2 + ..._int8_sv_f8_accum_f32_fuse_v_scale_attn.cu | 2 + ...f8_accum_f32_fuse_v_scale_attn_inst_buf.cu | 2 + ...accum_f32_fuse_v_scale_fuse_v_mean_attn.cu | 2 + kernels/attention/sage_attn/utils.cuh | 2 +- 24 files changed, 3919 insertions(+), 74 deletions(-) create mode 100644 aphrodite/attention/ops/sage_attention/__init__.py create mode 100644 aphrodite/attention/ops/sage_attention/attn_qk_int8_block_varlen.py create mode 100644 aphrodite/attention/ops/sage_attention/attn_qk_int8_per_block.py create mode 100644 aphrodite/attention/ops/sage_attention/attn_qk_int8_per_block_causal.py create mode 100644 aphrodite/attention/ops/sage_attention/attn_qk_int8_per_block_causal_varlen.py create mode 100644 aphrodite/attention/ops/sage_attention/core.py create mode 100644 aphrodite/attention/ops/sage_attention/quant.py create mode 100644 aphrodite/attention/ops/sage_attention/quant_per_block.py create mode 100644 aphrodite/attention/ops/sage_attention/quant_per_block_varlen.py create mode 100644 aphrodite/attention/ops/sage_attention/quant_per_thread.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 9a0dd3dd50..09662080d8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -307,7 +307,17 @@ set(APHRODITE_EXT_SRC "kernels/sparse/cutlass/sparse_scaled_mm_entry.cu" "kernels/cutlass_extensions/common.cpp" "kernels/attention/mla/cutlass_mla_entry.cu" - "kernels/quantization/fp8/per_token_group_quant.cu") + "kernels/quantization/fp8/per_token_group_quant.cu" + "kernels/attention/sage_attn/fused/fused.cu" + "kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu" + "kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu") set_gencode_flags_for_srcs( SRCS "${APHRODITE_EXT_SRC}" @@ -790,77 +800,6 @@ set(APHRODITE_EXT_SRC endif() endif() - # - # SageAttention kernels - # - - # Only build SageAttention kernels if we are building for at least SM 8.0 compatible archs - cuda_archs_loose_intersection(SAGE_ATTN_ARCHS "8.0;8.6;8.7;8.9;9.0+PTX" "${CUDA_ARCHS}") - if (SAGE_ATTN_ARCHS) - - # Base SageAttention sources (always included) - set(SAGE_ATTN_BASE_SRCS - "kernels/attention/sage_attn/fused/fused.cu") - - # SM 8.0 specific kernels - cuda_archs_loose_intersection(SAGE_ATTN_SM80_ARCHS "8.0;8.6;8.7" "${CUDA_ARCHS}") - set(SAGE_ATTN_SM80_SRCS) - if (SAGE_ATTN_SM80_ARCHS) - list(APPEND SAGE_ATTN_SM80_SRCS - "kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu") - set_gencode_flags_for_srcs( - SRCS "${SAGE_ATTN_SM80_SRCS}" - CUDA_ARCHS "${SAGE_ATTN_SM80_ARCHS}") - message(STATUS "Building SageAttention SM80 kernels for archs: ${SAGE_ATTN_SM80_ARCHS}") - endif() - - # SM 8.9 specific kernels - cuda_archs_loose_intersection(SAGE_ATTN_SM89_ARCHS "8.9" "${CUDA_ARCHS}") - set(SAGE_ATTN_SM89_SRCS) - if (SAGE_ATTN_SM89_ARCHS) - list(APPEND SAGE_ATTN_SM89_SRCS - "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu" - "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu" - "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu" - "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu" - "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu" - "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu" - "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu") - set_gencode_flags_for_srcs( - SRCS "${SAGE_ATTN_SM89_SRCS}" - CUDA_ARCHS "${SAGE_ATTN_SM89_ARCHS}") - message(STATUS "Building SageAttention SM89 kernels for archs: ${SAGE_ATTN_SM89_ARCHS}") - endif() - - # SM 9.0 specific kernels - cuda_archs_loose_intersection(SAGE_ATTN_SM90_ARCHS "9.0+PTX" "${CUDA_ARCHS}") - set(SAGE_ATTN_SM90_SRCS) - if (SAGE_ATTN_SM90_ARCHS) - list(APPEND SAGE_ATTN_SM90_SRCS - "kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu") - set_gencode_flags_for_srcs( - SRCS "${SAGE_ATTN_SM90_SRCS}" - CUDA_ARCHS "${SAGE_ATTN_SM90_ARCHS}") - message(STATUS "Building SageAttention SM90 kernels for archs: ${SAGE_ATTN_SM90_ARCHS}") - endif() - - set(SAGE_ATTN_SRCS ${SAGE_ATTN_BASE_SRCS}) - list(APPEND SAGE_ATTN_SRCS ${SAGE_ATTN_SM80_SRCS}) - list(APPEND SAGE_ATTN_SRCS ${SAGE_ATTN_SM89_SRCS}) - list(APPEND SAGE_ATTN_SRCS ${SAGE_ATTN_SM90_SRCS}) - - set_gencode_flags_for_srcs( - SRCS "${SAGE_ATTN_BASE_SRCS}" - CUDA_ARCHS "${SAGE_ATTN_ARCHS}") - - list(APPEND APHRODITE_EXT_SRC "${SAGE_ATTN_SRCS}") - - message(STATUS "Building SageAttention kernels for archs: ${SAGE_ATTN_ARCHS}") - else() - message(STATUS "Not building SageAttention kernels as no compatible archs found" - " in CUDA target architectures (requires SM 8.0 or above)") - endif() - # if CUDA endif endif() diff --git a/aphrodite/_custom_ops.py b/aphrodite/_custom_ops.py index f37c409cb6..566d56ca86 100644 --- a/aphrodite/_custom_ops.py +++ b/aphrodite/_custom_ops.py @@ -173,6 +173,323 @@ def merge_attn_states(output: torch.Tensor, prefix_lse, suffix_output, suffix_lse) +# sage attention ops +def quant_per_block_int8( + input: torch.Tensor, + output: torch.Tensor, + scale: torch.Tensor, + sm_scale: Optional[float] = None, + block_size: int = 16, + tensor_layout: int = 0, +) -> None: + """Quantize per block to int8. + + Args: + input: Input tensor to quantize + output: Output quantized tensor + scale: Output scale tensor + sm_scale: Optional softmax scale + block_size: Block size for quantization + tensor_layout: Tensor layout format + """ + if sm_scale is not None: + torch.ops._C.quant_per_block_int8_cuda( + input, output, scale, sm_scale, block_size, tensor_layout) + else: + torch.ops._C.quant_per_block_int8_cuda( + input, output, scale, block_size, tensor_layout) + + +def quant_per_block_int8_fuse_sub_mean( + input: torch.Tensor, + mean: torch.Tensor, + output: torch.Tensor, + scale: torch.Tensor, + block_size: int = 16, + tensor_layout: int = 0, +) -> None: + """Quantize per block to int8 with fused mean subtraction.""" + torch.ops._C.quant_per_block_int8_fuse_sub_mean_cuda( + input, mean, output, scale, block_size, tensor_layout) + + +def quant_per_warp_int8( + input: torch.Tensor, + output: torch.Tensor, + scale: torch.Tensor, + block_size: int = 16, + warp_block_size: int = 32, + tensor_layout: int = 0, +) -> None: + """Quantize per warp to int8.""" + torch.ops._C.quant_per_warp_int8_cuda( + input, output, scale, block_size, warp_block_size, tensor_layout) + + +def sub_mean( + input: torch.Tensor, + mean: torch.Tensor, + output: torch.Tensor, + tensor_layout: int = 0, +) -> None: + """Subtract mean from input.""" + torch.ops._C.sub_mean_cuda(input, mean, output, tensor_layout) + + +def transpose_pad_permute( + input: torch.Tensor, + output: torch.Tensor, + tensor_layout: int = 0, +) -> None: + """Transpose, pad and permute tensor.""" + torch.ops._C.transpose_pad_permute_cuda(input, output, tensor_layout) + + +def scale_fuse_quant( + input: torch.Tensor, + output: torch.Tensor, + scale: torch.Tensor, + num_tokens: int, + scale_max: float, + tensor_layout: int = 0, +) -> None: + """Scale and fuse quantization.""" + torch.ops._C.scale_fuse_quant_cuda( + input, output, scale, num_tokens, scale_max, tensor_layout) + + +def mean_scale_fuse_quant( + input: torch.Tensor, + output: torch.Tensor, + mean: torch.Tensor, + scale: torch.Tensor, + num_tokens: int, + scale_max: float, + tensor_layout: int = 0, +) -> None: + """Mean, scale and fuse quantization.""" + torch.ops._C.mean_scale_fuse_quant_cuda( + input, output, mean, scale, num_tokens, scale_max, tensor_layout) + + +def qk_int8_sv_f16_accum_f32_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int = 0, + is_causal: int = 0, + qk_quant_gran: int = 16, + sm_scale: float = 1.0, + return_lse: int = 0, +) -> torch.Tensor: + """QK int8 attention with SV f16 accumulation in f32.""" + return torch.ops._C.qk_int8_sv_f16_accum_f32_attn( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + + +def qk_int8_sv_f16_accum_f16_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int = 0, + is_causal: int = 0, + qk_quant_gran: int = 16, + sm_scale: float = 1.0, + return_lse: int = 0, +) -> torch.Tensor: + """QK int8 attention with SV f16 accumulation in f16.""" + return torch.ops._C.qk_int8_sv_f16_accum_f16_attn( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + + +def qk_int8_sv_f16_accum_f16_attn_inst_buf( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int = 0, + is_causal: int = 0, + qk_quant_gran: int = 16, + sm_scale: float = 1.0, + return_lse: int = 0, +) -> torch.Tensor: + """QK int8 attention with SV f16 accumulation in f16 with instant buffer.""" + return torch.ops._C.qk_int8_sv_f16_accum_f16_attn_inst_buf( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + + +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_mean: torch.Tensor, + tensor_layout: int = 0, + is_causal: int = 0, + qk_quant_gran: int = 16, + sm_scale: float = 1.0, + return_lse: int = 0, +) -> torch.Tensor: + """QK int8 attention with SV f16 accumulation and fused value mean.""" + return torch.ops._C.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + + +def qk_int8_sv_f8_accum_f32_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int = 0, + is_causal: int = 0, + qk_quant_gran: int = 16, + sm_scale: float = 1.0, + return_lse: int = 0, +) -> torch.Tensor: + """QK int8 attention with SV f8 accumulation in f32.""" + return torch.ops._C.qk_int8_sv_f8_accum_f32_attn( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + + +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_scale: torch.Tensor, + tensor_layout: int = 0, + is_causal: int = 0, + qk_quant_gran: int = 16, + sm_scale: float = 1.0, + return_lse: int = 0, +) -> torch.Tensor: + """QK int8 attention with SV f8 accumulation and fused value scale.""" + return torch.ops._C.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + + +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_scale: torch.Tensor, + value_mean: torch.Tensor, + tensor_layout: int = 0, + is_causal: int = 0, + qk_quant_gran: int = 16, + sm_scale: float = 1.0, + return_lse: int = 0, +) -> torch.Tensor: + """QK int8 attention with SV f8 accumulation and fused value scale/mean.""" + return torch.ops._C.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + query, key, value, output, query_scale, key_scale, value_scale, + value_mean, tensor_layout, is_causal, qk_quant_gran, sm_scale, + return_lse) + + +def qk_int8_sv_f8_accum_f32_attn_inst_buf( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int = 0, + is_causal: int = 0, + qk_quant_gran: int = 16, + sm_scale: float = 1.0, + return_lse: int = 0, +) -> torch.Tensor: + """QK int8 attention with SV f8 accumulation in f32 with instant buffer.""" + return torch.ops._C.qk_int8_sv_f8_accum_f32_attn_inst_buf( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + + +def qk_int8_sv_f8_accum_f16_attn_inst_buf( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int = 0, + is_causal: int = 0, + qk_quant_gran: int = 16, + sm_scale: float = 1.0, + return_lse: int = 0, +) -> torch.Tensor: + """QK int8 attention with SV f8 accumulation in f16 with instant buffer.""" + return torch.ops._C.qk_int8_sv_f8_accum_f16_attn_inst_buf( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + + +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_scale: torch.Tensor, + tensor_layout: int = 0, + is_causal: int = 0, + qk_quant_gran: int = 16, + sm_scale: float = 1.0, + return_lse: int = 0, +) -> torch.Tensor: + """QK int8 attention with SV f8 accumulation with fused scale and instant + buffer.""" + return torch.ops._C.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + + +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_scale: torch.Tensor, + tensor_layout: int = 0, + is_causal: int = 0, + qk_quant_gran: int = 16, + sm_scale: float = 1.0, + return_lse: int = 0, +) -> torch.Tensor: + """QK int8 attention with SV f8 accumulation with fused scale and instant + buffer.""" + return torch.ops._C.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + + def convert_vertical_slash_indexes( q_seqlens: torch.Tensor, # [BATCH, ] kv_seqlens: torch.Tensor, # [BATCH, ] diff --git a/aphrodite/attention/ops/sage_attention/__init__.py b/aphrodite/attention/ops/sage_attention/__init__.py new file mode 100644 index 0000000000..ed112946eb --- /dev/null +++ b/aphrodite/attention/ops/sage_attention/__init__.py @@ -0,0 +1,14 @@ +from .core import (sageattn, sageattn_qk_int8_pv_fp8_cuda, + sageattn_qk_int8_pv_fp8_cuda_sm90, + sageattn_qk_int8_pv_fp16_cuda, + sageattn_qk_int8_pv_fp16_triton, sageattn_varlen) + + +__all__ = [ + "sageattn", + "sageattn_qk_int8_pv_fp8_cuda", + "sageattn_qk_int8_pv_fp8_cuda_sm90", + "sageattn_qk_int8_pv_fp16_cuda", + "sageattn_qk_int8_pv_fp16_triton", + "sageattn_varlen" +] diff --git a/aphrodite/attention/ops/sage_attention/attn_qk_int8_block_varlen.py b/aphrodite/attention/ops/sage_attention/attn_qk_int8_block_varlen.py new file mode 100644 index 0000000000..d4f6665ec8 --- /dev/null +++ b/aphrodite/attention/ops/sage_attention/attn_qk_int8_block_varlen.py @@ -0,0 +1,196 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch + +from aphrodite.triton_utils import tl, triton + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + q_scale, + kv_len, + K_ptrs, + K_scale_ptr, + V_ptrs, + stride_kn, + stride_vn, + start_m, + H: tl.constexpr, + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, +): + lo, hi = 0, kv_len + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k_mask = offs_n[None, :] < (kv_len - start_n) + k = tl.load(K_ptrs, mask=k_mask) + k_scale = tl.load(K_scale_ptr) + qk = tl.dot(q, k).to(tl.float32) * (q_scale * k_scale) + + qk += tl.where(k_mask, 0, float('-inf')) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + + acc = acc * alpha[:, None] + + v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n)) + p = p.to(tl.float16) + + acc += tl.dot(p, v, out_dtype=tl.float16) + m_i = m_ij + K_ptrs += BLOCK_N * stride_kn + K_scale_ptr += H + V_ptrs += BLOCK_N * stride_vn + return acc, l_i + + +@triton.jit +def _attn_fwd( + Q, + K, + V, + cu_seqlens_q, + cu_seqlens_k, + Q_scale, + K_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + Out, + stride_qh, stride_qn, + stride_kh, stride_kn, + stride_vh, stride_vn, + stride_oh, stride_on, + H: tl.constexpr, + num_kv_groups: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr +): + start_m = tl.program_id(0) + + off_z = tl.program_id(2).to(tl.int64) + off_h = tl.program_id(1).to(tl.int64) + + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + + qo_len = cu_seqlens_q_end - cu_seqlens_q_start + + if (start_m * BLOCK_M) >= qo_len: + return + + cu_seq_lens_q_scale_start = tl.load(cu_seqlens_q_scale + off_z) + cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z) + + q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H + k_scale_offset = cu_seq_lens_k_scale_start * ( + H // num_kv_groups) + off_h // num_kv_groups + + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + + kv_len = cu_seqlens_k_end - cu_seqlens_k_start + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, HEAD_DIM) + Q_ptrs = Q + (cu_seqlens_q_start * stride_qn + off_h * stride_qh) + \ + offs_m[:, None] * stride_qn + offs_k[None, :] + Q_scale_ptr = Q_scale + q_scale_offset + K_ptrs = K + (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * + stride_kh) + offs_n[None, :] * stride_kn + offs_k[:, None] + K_scale_ptr = K_scale + k_scale_offset + V_ptrs = V + (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * + stride_vh) + offs_n[:, None] * stride_vn + offs_k[None, :] + O_block_ptr = (Out + (cu_seqlens_q_start * stride_on + off_h * stride_oh) + + offs_m[:, None] * stride_on + offs_k[None, :]) + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len) + q_scale = tl.load(Q_scale_ptr) + acc, l_i = _attn_fwd_inner( + acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, + K_scale_ptr, V_ptrs, stride_kn, stride_vn, + start_m, + H // num_kv_groups, + BLOCK_M, HEAD_DIM, BLOCK_N, + 4 - STAGE, offs_m, offs_n + ) + acc = acc / l_i[:, None] + tl.store(O_block_ptr, acc.to(Out.type.element_ty), + mask=(offs_m[:, None] < qo_len)) + + +def forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + q_scale, + k_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + output_dtype=torch.float16, +): + BLOCK_M = 128 + BLOCK_N = 64 + stage = 1 + + o = torch.empty(q.shape, dtype=output_dtype, device=q.device) + + b = cu_seqlens_q.shape[0] - 1 + _, h_qo, head_dim = q.shape + _, h_kv, _ = k.shape + + HEAD_DIM_K = head_dim + num_kv_groups = h_qo // h_kv + + grid = (triton.cdiv(max_seqlen_q, BLOCK_M), h_qo, b) + _attn_fwd[grid]( + q, k, v, cu_seqlens_q, cu_seqlens_k, + q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, + o, + q.stride(1), q.stride(0), + k.stride(1), k.stride(0), + v.stride(1), v.stride(0), + o.stride(1), o.stride(0), + h_qo, num_kv_groups, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K, + STAGE=stage, + num_warps=4 if head_dim == 64 else 8, + num_stages=3 if head_dim == 64 else 4) + return o diff --git a/aphrodite/attention/ops/sage_attention/attn_qk_int8_per_block.py b/aphrodite/attention/ops/sage_attention/attn_qk_int8_per_block.py new file mode 100644 index 0000000000..6e40146cba --- /dev/null +++ b/aphrodite/attention/ops/sage_attention/attn_qk_int8_per_block.py @@ -0,0 +1,358 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch + +from aphrodite.triton_utils import tl, triton + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + q_scale, + qo_len, + kv_len, + K_ptrs, + K_scale_ptr, + V_ptrs, + stride_kn, + stride_vn, + start_m, + mask_ptrs, + stride_maskn, + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, +): + lo, hi = 0, kv_len + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + mask_block = None + skip = False + if mask_ptrs is not None: + if mask_ptrs.dtype.element_ty == tl.int1: + mask_block = tl.load( + mask_ptrs + start_n * stride_maskn, + mask=(offs_m[:, None] < qo_len) + & (offs_n[None, :] < kv_len - start_n), + other=False, + ) + if tl.max(mask_block) == 0: + skip = True + else: + mask_block = tl.load( + mask_ptrs + start_n * stride_maskn, + mask=(offs_m[:, None] < qo_len) + & (offs_n[None, :] < kv_len - start_n), + other=-1.0e6, + ) + if not skip: + k_mask = offs_n[None, :] < (kv_len - start_n) + k = tl.load(K_ptrs, mask=k_mask) + k_scale = tl.load(K_scale_ptr) + + qk = tl.dot(q, k).to(tl.float32) * (q_scale * k_scale) + + if mask_block is not None: + if mask_block.dtype == tl.int1: + qk = qk + tl.where(mask_block, 0, -1.0e6) + else: + qk = qk + mask_block + else: + qk += tl.where(k_mask, 0, -1.0e6) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + + acc = acc * alpha[:, None] + + v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n)) + p = p.to(tl.float16) + + acc += tl.dot(p, v, out_dtype=tl.float16) + m_i = m_ij + K_ptrs += BLOCK_N * stride_kn + K_scale_ptr += 1 + V_ptrs += BLOCK_N * stride_vn + return acc, l_i, m_i + + +@triton.jit +def _attn_fwd( + Q, + K, + V, + Q_scale, + K_scale, + Out, + mask, + Lse, + stride_qz, + stride_qh, + stride_qn, + stride_kz, + stride_kh, + stride_kn, + stride_vz, + stride_vh, + stride_vn, + stride_oz, + stride_oh, + stride_on, + stride_maskz, + stride_maskh, + stride_maskm, + stride_maskn, + qo_len, + kv_len, + H: tl.constexpr, + num_kv_groups: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + RETURN_LSE: tl.constexpr, +): + start_m = tl.program_id(0) + + off_z = tl.program_id(2).to(tl.int64) + off_h = tl.program_id(1).to(tl.int64) + + q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M) + k_scale_offset = ( + off_z * (H // num_kv_groups) + off_h // num_kv_groups + ) * tl.cdiv(kv_len, BLOCK_N) + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, HEAD_DIM) + Q_ptrs = ( + Q + + (off_z * stride_qz + off_h * stride_qh) + + offs_m[:, None] * stride_qn + + offs_k[None, :] + ) + Q_scale_ptr = Q_scale + q_scale_offset + start_m + K_ptrs = ( + K + + (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh) + + offs_n[None, :] * stride_kn + + offs_k[:, None] + ) + K_scale_ptr = K_scale + k_scale_offset + V_ptrs = ( + V + + (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh) + + offs_n[:, None] * stride_vn + + offs_k[None, :] + ) + O_block_ptr = ( + Out + + (off_z * stride_oz + off_h * stride_oh) + + offs_m[:, None] * stride_on + + offs_k[None, :] + ) + if mask is None: + mask_ptrs = None + else: + mask_ptrs = ( + mask + + (off_z * stride_maskz + off_h * stride_maskh) + + offs_m[:, None] * stride_maskm + + offs_n[None, :] * stride_maskn + ) + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len) + q_scale = tl.load(Q_scale_ptr) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + q_scale, + qo_len, + kv_len, + K_ptrs, + K_scale_ptr, + V_ptrs, + stride_kn, + stride_vn, + start_m, + mask_ptrs, + stride_maskn, + BLOCK_M, + HEAD_DIM, + BLOCK_N, + 4 - STAGE, + offs_m, + offs_n, + ) + acc = acc / l_i[:, None] + tl.store( + O_block_ptr, + acc.to(Out.type.element_ty), + mask=(offs_m[:, None] < qo_len), + ) + + if RETURN_LSE: + lse_ptrs = Lse + (off_z * qo_len * H + off_h * qo_len) + offs_m + l_i = tl.log2(l_i) + m_i + tl.store(lse_ptrs, l_i, mask=(offs_m < qo_len)) + + +def forward( + q, + k, + v, + q_scale, + k_scale, + tensor_layout="HND", + attn_mask=None, + output_dtype=torch.float16, + return_lse=False, +): + BLOCK_M = 128 + BLOCK_N = 64 + stage = 1 + + o = torch.empty(q.shape, dtype=output_dtype, device=q.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = ( + q.stride(0), + q.stride(1), + q.stride(2), + ) + stride_bz_k, stride_h_k, stride_seq_k = ( + k.stride(0), + k.stride(1), + k.stride(2), + ) + stride_bz_v, stride_h_v, stride_seq_v = ( + v.stride(0), + v.stride(1), + v.stride(2), + ) + stride_bz_o, stride_h_o, stride_seq_o = ( + o.stride(0), + o.stride(1), + o.stride(2), + ) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = ( + q.stride(0), + q.stride(2), + q.stride(1), + ) + stride_bz_k, stride_h_k, stride_seq_k = ( + k.stride(0), + k.stride(2), + k.stride(1), + ) + stride_bz_v, stride_h_v, stride_seq_v = ( + v.stride(0), + v.stride(2), + v.stride(1), + ) + stride_bz_o, stride_h_o, stride_seq_o = ( + o.stride(0), + o.stride(2), + o.stride(1), + ) + else: + raise ValueError(f"tensor_layout {tensor_layout} not supported") + + if attn_mask is not None: + stride_bz_mask, stride_h_mask, stride_m_mask, stride_n_mask = ( + attn_mask.stride(0), + attn_mask.stride(1), + attn_mask.stride(2), + attn_mask.stride(3), + ) + else: + stride_bz_mask, stride_h_mask, stride_m_mask, stride_n_mask = 0, 0, 0, 0 + + HEAD_DIM_K = head_dim + num_kv_groups = h_qo // h_kv + + if return_lse: + lse = torch.empty( + [b, h_qo, qo_len], dtype=torch.float32, device=q.device + ) + else: + lse = torch.empty([0], dtype=torch.float32, device="cpu") + + grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b) + _attn_fwd[grid]( + q, + k, + v, + q_scale, + k_scale, + o, + attn_mask, + lse, + stride_bz_q, + stride_h_q, + stride_seq_q, + stride_bz_k, + stride_h_k, + stride_seq_k, + stride_bz_v, + stride_h_v, + stride_seq_v, + stride_bz_o, + stride_h_o, + stride_seq_o, + stride_bz_mask, + stride_h_mask, + stride_m_mask, + stride_n_mask, + qo_len, + kv_len, + h_qo, + num_kv_groups, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + HEAD_DIM=HEAD_DIM_K, + STAGE=stage, + RETURN_LSE=return_lse, + num_warps=4 if head_dim == 64 else 8, + num_stages=3 if head_dim == 64 else 4, + ) + + return o, lse diff --git a/aphrodite/attention/ops/sage_attention/attn_qk_int8_per_block_causal.py b/aphrodite/attention/ops/sage_attention/attn_qk_int8_per_block_causal.py new file mode 100644 index 0000000000..4ee1e8bbd1 --- /dev/null +++ b/aphrodite/attention/ops/sage_attention/attn_qk_int8_per_block_causal.py @@ -0,0 +1,330 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch + +from aphrodite.triton_utils import tl, triton + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + q_scale, + kv_len, + K_ptrs, + K_scale_ptr, + V_ptrs, + stride_kn, + stride_vn, + start_m, + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, +): + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + K_scale_ptr += lo // BLOCK_N + K_ptrs += stride_kn * lo + V_ptrs += stride_vn * lo + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k_mask = offs_n[None, :] < (kv_len - start_n) + k = tl.load(K_ptrs, mask=k_mask) + k_scale = tl.load(K_scale_ptr) + qk = tl.dot(q, k).to(tl.float32) * (q_scale * k_scale) + + mask = k_mask + if STAGE == 2: + mask &= offs_m[:, None] >= (start_n + offs_n[None, :]) + qk += tl.where(mask, 0, float("-inf")) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + + acc = acc * alpha[:, None] + + v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n)) + p = p.to(tl.float16) + + acc += tl.dot(p, v, out_dtype=tl.float16) + m_i = m_ij + K_ptrs += BLOCK_N * stride_kn + K_scale_ptr += 1 + V_ptrs += BLOCK_N * stride_vn + return acc, l_i, m_i + + +@triton.jit +def _attn_fwd( + Q, + K, + V, + Q_scale, + K_scale, + Out, + Lse, + stride_qz, + stride_qh, + stride_qn, + stride_kz, + stride_kh, + stride_kn, + stride_vz, + stride_vh, + stride_vn, + stride_oz, + stride_oh, + stride_on, + qo_len, + kv_len, + H: tl.constexpr, + num_kv_groups: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + RETURN_LSE: tl.constexpr, +): + start_m = tl.program_id(0) + + off_z = tl.program_id(2).to(tl.int64) + off_h = tl.program_id(1).to(tl.int64) + + q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M) + k_scale_offset = ( + off_z * (H // num_kv_groups) + off_h // num_kv_groups + ) * tl.cdiv(kv_len, BLOCK_N) + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, HEAD_DIM) + Q_ptrs = ( + Q + + (off_z * stride_qz + off_h * stride_qh) + + offs_m[:, None] * stride_qn + + offs_k[None, :] + ) + Q_scale_ptr = Q_scale + q_scale_offset + start_m + K_ptrs = ( + K + + (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh) + + offs_n[None, :] * stride_kn + + offs_k[:, None] + ) + K_scale_ptr = K_scale + k_scale_offset + V_ptrs = ( + V + + (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh) + + offs_n[:, None] * stride_vn + + offs_k[None, :] + ) + O_block_ptr = ( + Out + + (off_z * stride_oz + off_h * stride_oh) + + offs_m[:, None] * stride_on + + offs_k[None, :] + ) + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len) + q_scale = tl.load(Q_scale_ptr) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + q_scale, + kv_len, + K_ptrs, + K_scale_ptr, + V_ptrs, + stride_kn, + stride_vn, + start_m, + BLOCK_M, + HEAD_DIM, + BLOCK_N, + 4 - STAGE, + offs_m, + offs_n, + ) + + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + q_scale, + kv_len, + K_ptrs, + K_scale_ptr, + V_ptrs, + stride_kn, + stride_vn, + start_m, + BLOCK_M, + HEAD_DIM, + BLOCK_N, + 2, + offs_m, + offs_n, + ) + acc = acc / l_i[:, None] + tl.store( + O_block_ptr, + acc.to(Out.type.element_ty), + mask=(offs_m[:, None] < qo_len), + ) + + if RETURN_LSE: + lse_ptrs = Lse + (off_z * qo_len * H + off_h * qo_len) + offs_m + l_i = tl.log2(l_i) + m_i + tl.store(lse_ptrs, l_i, mask=(offs_m < qo_len)) + + +def forward( + q, + k, + v, + q_scale, + k_scale, + tensor_layout="HND", + output_dtype=torch.float16, + return_lse=False, +): + BLOCK_M = 128 + BLOCK_N = 64 + stage = 3 + + o = torch.empty(q.shape, dtype=output_dtype, device=q.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = ( + q.stride(0), + q.stride(1), + q.stride(2), + ) + stride_bz_k, stride_h_k, stride_seq_k = ( + k.stride(0), + k.stride(1), + k.stride(2), + ) + stride_bz_v, stride_h_v, stride_seq_v = ( + v.stride(0), + v.stride(1), + v.stride(2), + ) + stride_bz_o, stride_h_o, stride_seq_o = ( + o.stride(0), + o.stride(1), + o.stride(2), + ) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = ( + q.stride(0), + q.stride(2), + q.stride(1), + ) + stride_bz_k, stride_h_k, stride_seq_k = ( + k.stride(0), + k.stride(2), + k.stride(1), + ) + stride_bz_v, stride_h_v, stride_seq_v = ( + v.stride(0), + v.stride(2), + v.stride(1), + ) + stride_bz_o, stride_h_o, stride_seq_o = ( + o.stride(0), + o.stride(2), + o.stride(1), + ) + else: + raise ValueError(f"tensor_layout {tensor_layout} not supported") + + assert qo_len == kv_len, ( + "qo_len and kv_len must be equal for causal attention" + ) + + HEAD_DIM_K = head_dim + num_kv_groups = h_qo // h_kv + + if return_lse: + lse = torch.empty( + [b, h_qo, qo_len], dtype=torch.float32, device=q.device + ) + else: + lse = torch.empty([0], dtype=torch.float32, device="cpu") + + grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b) + _attn_fwd[grid]( + q, + k, + v, + q_scale, + k_scale, + o, + lse, + stride_bz_q, + stride_h_q, + stride_seq_q, + stride_bz_k, + stride_h_k, + stride_seq_k, + stride_bz_v, + stride_h_v, + stride_seq_v, + stride_bz_o, + stride_h_o, + stride_seq_o, + qo_len, + kv_len, + h_qo, + num_kv_groups, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + HEAD_DIM=HEAD_DIM_K, + STAGE=stage, + RETURN_LSE=return_lse, + num_warps=4 if head_dim == 64 else 8, + num_stages=4, + ) + + return o, lse diff --git a/aphrodite/attention/ops/sage_attention/attn_qk_int8_per_block_causal_varlen.py b/aphrodite/attention/ops/sage_attention/attn_qk_int8_per_block_causal_varlen.py new file mode 100644 index 0000000000..c70171bdf8 --- /dev/null +++ b/aphrodite/attention/ops/sage_attention/attn_qk_int8_per_block_causal_varlen.py @@ -0,0 +1,205 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch + +from aphrodite.triton_utils import tl, triton + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + q_scale, + kv_len, + K_ptrs, + K_scale_ptr, + V_ptrs, + stride_kn, + stride_vn, + start_m, + H: tl.constexpr, + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, +): + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + K_scale_ptr += (lo // BLOCK_N) * H + K_ptrs += stride_kn * lo + V_ptrs += stride_vn * lo + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k_mask = offs_n[None, :] < (kv_len - start_n) + k = tl.load(K_ptrs, mask=k_mask) + k_scale = tl.load(K_scale_ptr) + qk = tl.dot(q, k).to(tl.float32) * (q_scale * k_scale) + + mask = k_mask + if STAGE == 2: + mask &= offs_m[:, None] >= (start_n + offs_n[None, :]) + qk += tl.where(mask, 0, float('-inf')) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + + acc = acc * alpha[:, None] + + v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n)) + p = p.to(tl.float16) + + acc += tl.dot(p, v, out_dtype=tl.float16) + m_i = m_ij + K_ptrs += BLOCK_N * stride_kn + K_scale_ptr += H + V_ptrs += BLOCK_N * stride_vn + return acc, l_i, m_i + + +@triton.jit +def _attn_fwd(Q, K, V, + cu_seqlens_q, cu_seqlens_k, + Q_scale, K_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, + Out, + stride_qh, stride_qn, + stride_kh, stride_kn, + stride_vh, stride_vn, + stride_oh, stride_on, + H: tl.constexpr, num_kv_groups: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr + ): + start_m = tl.program_id(0) + + off_z = tl.program_id(2).to(tl.int64) + off_h = tl.program_id(1).to(tl.int64) + + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + + qo_len = cu_seqlens_q_end - cu_seqlens_q_start + + if (start_m * BLOCK_M) >= qo_len: + return + + cu_seq_lens_q_scale_start = tl.load(cu_seqlens_q_scale + off_z) + cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z) + + q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H + k_scale_offset = cu_seq_lens_k_scale_start * ( + H // num_kv_groups) + off_h // num_kv_groups + + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + + kv_len = cu_seqlens_k_end - cu_seqlens_k_start + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, HEAD_DIM) + Q_ptrs = Q + ((cu_seqlens_q_start * stride_qn + off_h * stride_qh) + + offs_m[:, None] * stride_qn + offs_k[None, :]) + Q_scale_ptr = Q_scale + q_scale_offset + K_ptrs = K + (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * + stride_kh) + offs_n[None, :] * stride_kn + offs_k[:, None] + K_scale_ptr = K_scale + k_scale_offset + V_ptrs = V + (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * + stride_vh) + offs_n[:, None] * stride_vn + offs_k[None, :] + O_block_ptr = Out + ((cu_seqlens_q_start * stride_on + off_h * stride_oh) + + offs_m[:, None] * stride_on + offs_k[None, :]) + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len) + q_scale = tl.load(Q_scale_ptr) + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, + K_scale_ptr, V_ptrs, stride_kn, stride_vn, + start_m, H // num_kv_groups, + BLOCK_M, HEAD_DIM, BLOCK_N, + 4 - STAGE, offs_m, offs_n + ) + + acc, l_i, _ = _attn_fwd_inner( + acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, + K_scale_ptr, V_ptrs, stride_kn, stride_vn, + start_m, H // num_kv_groups, + BLOCK_M, HEAD_DIM, BLOCK_N, + 2, offs_m, offs_n + ) + acc = acc / l_i[:, None] + tl.store(O_block_ptr, acc.to(Out.type.element_ty), + mask=(offs_m[:, None] < qo_len)) + + +def forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + q_scale, + k_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + output_dtype=torch.float16, +): + BLOCK_M = 128 + BLOCK_N = 64 + stage = 3 + + o = torch.empty(q.shape, dtype=output_dtype, device=q.device) + + b = cu_seqlens_q.shape[0] - 1 + _, h_qo, head_dim = q.shape + _, h_kv, _ = k.shape + + HEAD_DIM_K = head_dim + num_kv_groups = h_qo // h_kv + + grid = (triton.cdiv(max_seqlen_q, BLOCK_M), h_qo, b) + _attn_fwd[grid]( + q, k, v, cu_seqlens_q, cu_seqlens_k, + q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, + o, + q.stride(1), q.stride(0), + k.stride(1), k.stride(0), + v.stride(1), v.stride(0), + o.stride(1), o.stride(0), + h_qo, num_kv_groups, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K, + STAGE=stage, + num_warps=4 if head_dim == 64 else 8, + num_stages=4) + return o diff --git a/aphrodite/attention/ops/sage_attention/core.py b/aphrodite/attention/ops/sage_attention/core.py new file mode 100644 index 0000000000..d885449692 --- /dev/null +++ b/aphrodite/attention/ops/sage_attention/core.py @@ -0,0 +1,1367 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from loguru import logger +from typing import Any, Optional + +import torch + +import aphrodite._custom_ops as ops +from aphrodite.platforms import current_platform + +from .attn_qk_int8_block_varlen import forward as attn_false_varlen +from .attn_qk_int8_per_block import forward as attn_false +from .attn_qk_int8_per_block_causal import forward as attn_true +from .attn_qk_int8_per_block_causal_varlen import forward as attn_true_varlen +from .quant import per_block_int8 as per_block_int8_cuda +from .quant import per_channel_fp8 +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean +from .quant_per_block import per_block_int8 as per_block_int8_triton +from .quant_per_block_varlen import ( + per_block_int8 as per_block_int8_varlen_triton) +from .quant_per_thread import per_thread_int8 as per_thread_int8_triton + + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + return_lse: bool = False, + **kwargs: Any, +): + """ + Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ # noqa: E501 + + arch = current_platform.get_device_capability().as_version_str() + if arch == "8.0": + return sageattn_qk_int8_pv_fp16_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32", + ) + elif arch == "8.6": + return sageattn_qk_int8_pv_fp16_triton( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + ) + elif arch == "8.9": + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) + elif arch == "9.0": + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp32", + ) + elif arch == "12.0": + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + qk_quant_gran="per_warp", + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel + # is currently not usable on sm120. + else: + raise ValueError(f"Unsupported CUDA architecture: {arch}") + + +@torch.compiler.disable +def sageattn_qk_int8_pv_fp16_triton( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + quantization_backend: str = "triton", + is_causal: bool = False, + attn_mask: Optional[torch.Tensor] = None, + sm_scale: Optional[float] = None, + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with per-block INT8 quantization for Q and K, FP16 PV with FP16 accumulation, implemented using Triton. + The FP16 accumulator is added to a FP32 buffer immediately after each iteration. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + quantization_backend : str + The quantization backend, either "triton" or "cuda". + "cuda" backend offers better performance due to kernel fusion. + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + attn_mask : Optional[torch.Tensor] + The attention mask tensor, of dtype bool or float32. + Should be able to broadcast to the shape of the matrix qk^T. + Default: None. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16``, ``torch.bfloat16`` or ``torch.float32``. + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ # noqa: E501 + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert q.device == k.device == v.device, ( + "All tensors must be on the same device." + ) + assert q.dtype == k.dtype == v.dtype, ( + "All tensors must have the same dtype." + ) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool or attn_mask.dtype == q.dtype, ( + "attn_mask must be of dtype bool or the same dtype as q." + ) + assert attn_mask.device == q.device, ( + "All tensors must be on the same device." + ) + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + seq_dim = 1 if tensor_layout == "NHD" else 2 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if dtype == torch.bfloat16 or dtype == torch.float32: + v = v.to(torch.float16) + + if sm_scale is None: + sm_scale = 1.0 / (head_dim_og**0.5) + + if quantization_backend == "triton": + q_int8, q_scale, k_int8, k_scale = per_block_int8_triton( + q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout + ) + elif quantization_backend == "cuda": + q_int8, q_scale, k_int8, k_scale = per_block_int8_cuda( + q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout + ) + else: + raise ValueError( + f"Unsupported quantization backend: {quantization_backend}" + ) + if is_causal: + assert attn_mask is None, "Mask should be None for causal attention." + o, lse = attn_true( + q_int8, + k_int8, + v, + q_scale, + k_scale, + tensor_layout=tensor_layout, + output_dtype=dtype, + return_lse=return_lse, + ) + else: + if attn_mask is not None: + if tensor_layout == "HND": + target_shape = (q.shape[0], q.shape[1], q.shape[2], k.shape[2]) + elif tensor_layout == "NHD": + target_shape = (q.shape[0], q.shape[2], q.shape[1], k.shape[1]) + else: + raise ValueError( + f"tensor_layout {tensor_layout} not supported") + try: + attn_mask = attn_mask.expand(target_shape) + except Exception as e: + raise AssertionError( + f"attn_mask shape {attn_mask.shape} cannot be broadcast " + f"to {target_shape}" + ) from e + o, lse = attn_false( + q_int8, + k_int8, + v, + q_scale, + k_scale, + tensor_layout=tensor_layout, + output_dtype=dtype, + attn_mask=attn_mask, + return_lse=return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +@torch.compiler.disable +def sageattn_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + is_causal: bool = False, + sm_scale: Optional[float] = None, + smooth_k: bool = True, + **kwargs: Any, +) -> torch.Tensor: + """ + + Parameters + ---------- + q : torch.Tensor + The query tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``. + + cu_seqlens_q : torch.Tensor + The cumulative sequence lengths for the query sequences in the batch, used to index into `q`. + Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index. + + cu_seqlens_k : torch.Tensor + The cumulative sequence lengths for the key and value sequences in the batch, used to index into `k` and `v`. + Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index. + + max_seqlen_q : int + The maximum sequence length for the query tensor in the batch. + + max_seqlen_k : int + The maximum sequence length for the key and value tensors in the batch. + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len for each sequence. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + Returns + ------- + torch.Tensor + The output tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16``, ``torch.bfloat16`` or ``torch.float32``. + - The tensors `cu_seqlens_q` and `cu_seqlens_k` must have the dtype ``torch.int32`` or ``torch.int64``. + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ # noqa: E501 + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert q.device == k.device == v.device, ( + "All tensors must be on the same device." + ) + assert q.dtype == k.dtype == v.dtype, ( + "All tensors must have the same dtype." + ) + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + assert cu_seqlens_q.is_contiguous() and cu_seqlens_k.is_contiguous(), ( + "cu_seqlens_q and cu_seqlens_k must be contiguous." + ) + + if dtype == torch.bfloat16 or dtype == torch.float32: + v = v.to(torch.float16) + + if smooth_k: + km = k.mean( + dim=0, keepdim=True + ) # ! km is calculated on the all the batches. Calculate over each + # individual sequence requires dedicated kernel. + k = k - km + + if sm_scale is None: + sm_scale = 1.0 / (head_dim_og**0.5) + + q_int8, q_scale, k_int8, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale = ( # noqa: E501 + per_block_int8_varlen_triton( + q, + k, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale=sm_scale, + ) + ) + + if is_causal: + o = attn_true_varlen( + q_int8, + k_int8, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + q_scale, + k_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + output_dtype=dtype, + ) + else: + o = attn_false_varlen( + q_int8, + k_int8, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + q_scale, + k_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + output_dtype=dtype, + ) + + o = o[..., :head_dim_og] + + return o + + +@torch.compiler.disable +def sageattn_qk_int8_pv_fp16_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32". + - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b). + - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead. + - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ # noqa: E501 + + dtype = q.dtype + assert current_platform.get_device_capability().as_version_str() == "8.0", ( # noqa: E501 + "SM80 kernel is not available. make sure you GPUs with compute" + " capability 8.0 or higher." + ) + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, ( + "All tensors must be on the same device." + ) + assert q.dtype == k.dtype == v.dtype, ( + "All tensors must have the same dtype." + ) + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=( + 16 + if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") + else 32 + ), + BLKK=64, + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=( + 16 + if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") + else 32 + ), + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + logger.warning( + f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored." + ) + smooth_v = False + + if pv_accum_dtype == "fp32": + v = v.to(torch.float16) + lse = ops.qk_int8_sv_f16_accum_f32_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + v = v.to(torch.float16) + lse = ops.qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(torch.float16) + lse = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +@torch.compiler.disable +def sageattn_qk_int8_pv_fp8_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp16", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ # noqa: E501 + + dtype = q.dtype + assert current_platform.get_device_capability().as_version_str() == "8.9", ( # noqa: E501 + "SM89 kernel is not available. Make sure you GPUs with compute" + " capability 8.9." + ) + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, ( + "All tensors must be on the same device." + ) + assert q.dtype == k.dtype == v.dtype, ( + "All tensors must have the same dtype." + ) + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=32, + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype == "fp32+fp32" and smooth_v: + logger.warning( + "pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored." + ) + smooth_v = False + + if pv_accum_dtype == "fp32+fp16" and smooth_v: + logger.warning( + "pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored." + ) + smooth_v = False + + quant_v_scale_max = 448.0 + if pv_accum_dtype == "fp32+fp16": + quant_v_scale_max = 2.25 + + v_fp8, v_scale, vm = per_channel_fp8( + v, + tensor_layout=tensor_layout, + scale_max=quant_v_scale_max, + smooth_v=smooth_v, + ) + + if pv_accum_dtype == "fp32": + if smooth_v: + lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp16": + lse = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +@torch.compiler.disable +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ # noqa: E501 + + dtype = q.dtype + assert current_platform.get_device_capability().as_version_str() == "9.0", ( # noqa: E501 + "SM90 kernel is not available. Make sure you GPUs with compute" + " capability 9.0." + ) + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, ( + "All tensors must be on the same device." + ) + assert q.dtype == k.dtype == v.dtype, ( + "All tensors must have the same dtype." + ) + + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=64, + WARPQ=16, + BLKK=128, + WARPK=128, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + # pad v to multiple of 128 + # TODO: modify per_channel_fp8 kernel to handle this + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v.size(1), + v_pad_len, + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=2, + ) + else: + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v_pad_len, + v.size(2), + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=1, + ) + + v_fp8, v_scale, _ = per_channel_fp8( + v, tensor_layout=tensor_layout, smooth_v=False + ) + + if pv_accum_dtype == "fp32": + raise NotImplementedError( + "Please use pv_accum_dtype='fp32+fp32' for sm90." + ) + lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o diff --git a/aphrodite/attention/ops/sage_attention/quant.py b/aphrodite/attention/ops/sage_attention/quant.py new file mode 100644 index 0000000000..7b1c4e38a9 --- /dev/null +++ b/aphrodite/attention/ops/sage_attention/quant.py @@ -0,0 +1,348 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Optional + +import torch + +import aphrodite._custom_ops as ops + + +def per_block_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` and the key tensor `k` with per block + quantization. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + sm_scale : Optional[float] + The scale factor for the softmax operation. Default is ``head_dim**-0.5``. + It will be multiplied by ``1.44269504`` to work together with the triton attention kernel. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. + + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ # noqa: E501 + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), + device=q.device, + dtype=torch.float32, + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + _fused.quant_per_block_int8_cuda( + q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout + ) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + _fused.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + _fused.quant_per_block_int8_cuda( + k, k_int8, k_scale, BLKK, _tensor_layout + ) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization. + Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128. + Block size of quantizing `k` is 64 or 128. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. + + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ # noqa: E501 + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), + device=q.device, + dtype=torch.float32, + ) + + _fused.quant_per_warp_int8_cuda( + q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout + ) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + _fused.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + _fused.quant_per_block_int8_cuda( + k, k_int8, k_scale, BLKK, _tensor_layout + ) + + return q_int8, q_scale, k_int8, k_scale + + +def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"): + """ + Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype. + - The returned mean tensor will have the same dtype as the input tensor. + """ # noqa: E501 + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device) + + # subtract mean and store the result as fp16 + _fused.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm + + +def per_channel_fp8( + v: torch.Tensor, + tensor_layout: str = "HND", + scale_max: float = 448.0, + smooth_v: bool = True, +): + """ + Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization. + `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64. + After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``. + The quantization is done per channel, with the scale value and smooth factor calculated per channel. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + scale_max : float + The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format). + + smooth_v : bool + Whether to smooth the quantized tensor. Default is True. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + A tuple containing: + - The quantized tensor `v_fp8`. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``. + """ # noqa: E501 + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device + ) + + ops.transpose_pad_permute(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, + device=v.device + ) + + v_scale = torch.empty( + (b, h_kv, head_dim), dtype=torch.float32, device=v.device + ) + vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + + if smooth_v: + ops.mean_scale_fuse_quant( + v_transposed_permutted, + v_fp8, + vm, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, vm + else: + ops.scale_fuse_quant( + v_transposed_permutted, + v_fp8, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, None diff --git a/aphrodite/attention/ops/sage_attention/quant_per_block.py b/aphrodite/attention/ops/sage_attention/quant_per_block.py new file mode 100644 index 0000000000..8c2d7497ac --- /dev/null +++ b/aphrodite/attention/ops/sage_attention/quant_per_block.py @@ -0,0 +1,186 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch + +from aphrodite.triton_utils import tl, triton + + +@triton.jit +def quant_per_block_int8_kernel( + Input, + Output, + Scale, + L, + stride_iz, + stride_ih, + stride_in, + stride_oz, + stride_oh, + stride_on, + stride_sz, + stride_sh, + sm_scale, + C: tl.constexpr, + BLK: tl.constexpr, +): + off_blk = tl.program_id(0) + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK) + offs_k = tl.arange(0, C) + + input_ptrs = ( + Input + + off_b * stride_iz + + off_h * stride_ih + + offs_n[:, None] * stride_in + + offs_k[None, :] + ) + output_ptrs = ( + Output + + off_b * stride_oz + + off_h * stride_oh + + offs_n[:, None] * stride_on + + offs_k[None, :] + ) + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + x *= sm_scale + scale = tl.max(tl.abs(x)) / 127.0 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + + +def per_block_int8( + q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND" +): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = ( + q.stride(0), + q.stride(1), + q.stride(2), + ) + stride_bz_qo, stride_h_qo, stride_seq_qo = ( + q_int8.stride(0), + q_int8.stride(1), + q_int8.stride(2), + ) + stride_bz_k, stride_h_k, stride_seq_k = ( + k.stride(0), + k.stride(1), + k.stride(2), + ) + stride_bz_ko, stride_h_ko, stride_seq_ko = ( + k_int8.stride(0), + k_int8.stride(1), + k_int8.stride(2), + ) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = ( + q.stride(0), + q.stride(2), + q.stride(1), + ) + stride_bz_qo, stride_h_qo, stride_seq_qo = ( + q_int8.stride(0), + q_int8.stride(2), + q_int8.stride(1), + ) + stride_bz_k, stride_h_k, stride_seq_k = ( + k.stride(0), + k.stride(2), + k.stride(1), + ) + stride_bz_ko, stride_h_ko, stride_seq_ko = ( + k_int8.stride(0), + k_int8.stride(2), + k_int8.stride(1), + ) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), + device=q.device, + dtype=torch.float32, + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b) + quant_per_block_int8_kernel[grid]( + q, + q_int8, + q_scale, + qo_len, + stride_bz_q, + stride_h_q, + stride_seq_q, + stride_bz_qo, + stride_h_qo, + stride_seq_qo, + q_scale.stride(0), + q_scale.stride(1), + sm_scale=(sm_scale * 1.44269504), + C=head_dim, + BLK=BLKQ, + ) + + grid = ((kv_len + BLKK - 1) // BLKK, h_kv, b) + quant_per_block_int8_kernel[grid]( + k, + k_int8, + k_scale, + kv_len, + stride_bz_k, + stride_h_k, + stride_seq_k, + stride_bz_ko, + stride_h_ko, + stride_seq_ko, + k_scale.stride(0), + k_scale.stride(1), + sm_scale=1.0, + C=head_dim, + BLK=BLKK, + ) + + return q_int8, q_scale, k_int8, k_scale diff --git a/aphrodite/attention/ops/sage_attention/quant_per_block_varlen.py b/aphrodite/attention/ops/sage_attention/quant_per_block_varlen.py new file mode 100644 index 0000000000..acf5ac96da --- /dev/null +++ b/aphrodite/attention/ops/sage_attention/quant_per_block_varlen.py @@ -0,0 +1,165 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch + +from aphrodite.triton_utils import tl, triton + + +@triton.jit +def quant_per_block_int8_kernel( + Input, + Output, + Scale, + cu_seqlens_input, + cu_seqlens_scale, + stride_ih, + stride_in, + stride_oh, + stride_on, + sm_scale, + H: tl.constexpr, + C: tl.constexpr, + BLK: tl.constexpr, +): + off_blk = tl.program_id(0) + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + cu_seqlens_input_start = tl.load(cu_seqlens_input + off_b) + cu_seqlens_input_end = tl.load(cu_seqlens_input + off_b + 1) + + L = cu_seqlens_input_end - cu_seqlens_input_start + + if (off_blk * BLK) >= L: + return + + cu_seqlens_scale_start = tl.load(cu_seqlens_scale + off_b) + + offs_n = off_blk * BLK + tl.arange(0, BLK) + offs_k = tl.arange(0, C) + + input_ptrs = ( + Input + + cu_seqlens_input_start * stride_in + + off_h * stride_ih + + offs_n[:, None] * stride_in + + offs_k[None, :] + ) + output_ptrs = ( + Output + + cu_seqlens_input_start * stride_on + + off_h * stride_oh + + offs_n[:, None] * stride_on + + offs_k[None, :] + ) + scale_ptrs = Scale + cu_seqlens_scale_start * H + off_h + off_blk * H + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + x *= sm_scale + scale = tl.max(tl.abs(x)) / 127.0 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + + +def per_block_int8( + q, + k, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + BLKQ=128, + BLKK=64, + sm_scale=None, +): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + h_qo = q.shape[1] + h_kv = k.shape[1] + head_dim = q.shape[-1] + + b = cu_seqlens_q.shape[0] - 1 + q_batch_len = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + k_batch_len = cu_seqlens_k[1:] - cu_seqlens_k[:-1] + + q_scale_len = (q_batch_len + BLKQ - 1) // BLKQ + k_scale_len = (k_batch_len + BLKK - 1) // BLKK + + cu_seqlens_q_scale = torch.nn.functional.pad( + torch.cumsum(q_scale_len, dim=0), (1, 0), value=0 + ) + cu_seqlens_k_scale = torch.nn.functional.pad( + torch.cumsum(k_scale_len, dim=0), (1, 0), value=0 + ) + + q_scale = torch.empty( + (cu_seqlens_q_scale[-1], h_qo), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (cu_seqlens_k_scale[-1], h_kv), device=k.device, dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((max_seqlen_q + BLKQ - 1) // BLKQ, h_qo, b) + quant_per_block_int8_kernel[grid]( + q, + q_int8, + q_scale, + cu_seqlens_q, + cu_seqlens_q_scale, + q.stride(1), + q.stride(0), + q_int8.stride(1), + q_int8.stride(0), + sm_scale=(sm_scale * 1.44269504), + H=h_qo, + C=head_dim, + BLK=BLKQ, + ) + + grid = ((max_seqlen_k + BLKK - 1) // BLKK, h_kv, b) + quant_per_block_int8_kernel[grid]( + k, + k_int8, + k_scale, + cu_seqlens_k, + cu_seqlens_k_scale, + k.stride(1), + k.stride(0), + k_int8.stride(1), + k_int8.stride(0), + sm_scale=1.0, + H=h_kv, + C=head_dim, + BLK=BLKK, + ) + + return ( + q_int8, + q_scale, + k_int8, + k_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + ) diff --git a/aphrodite/attention/ops/sage_attention/quant_per_thread.py b/aphrodite/attention/ops/sage_attention/quant_per_thread.py new file mode 100644 index 0000000000..37fde70cbe --- /dev/null +++ b/aphrodite/attention/ops/sage_attention/quant_per_thread.py @@ -0,0 +1,379 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch + +from aphrodite.triton_utils import tl, triton + + +@triton.jit +def quant_query_per_thread_int8_kernel( + Input, + Output, + Scale, + L, + stride_iz, + stride_ih, + stride_in, + stride_oz, + stride_oh, + stride_on, + stride_sz, + stride_sh, + C: tl.constexpr, + BLK: tl.constexpr, +): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = ( + Input + + off_b * stride_iz + + off_h * stride_ih + + offs_n[:, None] * stride_in + + offs_k[None, :] + ) + output_ptrs = ( + Output + + off_b * stride_oz + + off_h * stride_oh + + offs_n[:, None] * stride_on + + offs_k[None, :] + ) + scale_ptrs = ( + Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + ) + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 127.0 + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + + +@triton.jit +def quant_key_per_thread_int8_kernel( + Input, + Output, + Scale, + L, + stride_iz, + stride_ih, + stride_in, + stride_oz, + stride_oh, + stride_on, + stride_sz, + stride_sh, + C: tl.constexpr, + BLK: tl.constexpr, +): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 + offs_k = tl.arange(0, C) + + input_ptrs0 = ( + Input + + off_b * stride_iz + + off_h * stride_ih + + offs_n0[:, None] * stride_in + + offs_k[None, :] + ) + input_ptrs1 = ( + Input + + off_b * stride_iz + + off_h * stride_ih + + offs_n1[:, None] * stride_in + + offs_k[None, :] + ) + output_ptrs0 = ( + Output + + off_b * stride_oz + + off_h * stride_oh + + offs_n0[:, None] * stride_on + + offs_k[None, :] + ) + output_ptrs1 = ( + Output + + off_b * stride_oz + + off_h * stride_oh + + offs_n1[:, None] * stride_on + + offs_k[None, :] + ) + scale_ptrs = ( + Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + ) + + x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) + x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) + x0 = x0.to(tl.float32) + x1 = x1.to(tl.float32) + scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127.0 + 0.0000001 + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + x0_int8 = x0_int8.to(tl.int8) + x1_int8 = x1_int8.to(tl.int8) + tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) + tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) + tl.store(scale_ptrs, scale) + + +@triton.jit +def quant_query_per_thread_int4_kernel( + Input, + Output, + Scale, + L, + stride_iz, + stride_ih, + stride_in, + stride_oz, + stride_oh, + stride_on, + stride_sz, + stride_sh, + C: tl.constexpr, + BLK: tl.constexpr, +): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = ( + Input + + off_b * stride_iz + + off_h * stride_ih + + offs_n[:, None] * stride_in + + offs_k[None, :] + ) + output_ptrs = ( + Output + + off_b * stride_oz + + off_h * stride_oh + + offs_n[:, None] * stride_on + + offs_k[None, :] + ) + scale_ptrs = ( + Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + ) + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7.0 + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + + +@triton.jit +def quant_key_per_thread_int4_kernel( + Input, + Output, + Scale, + L, + stride_iz, + stride_ih, + stride_in, + stride_oz, + stride_oh, + stride_on, + stride_sz, + stride_sh, + C: tl.constexpr, + BLK: tl.constexpr, +): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = ( + off_blk * BLK + + tl.cat( + tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True + ) + + off_tld * 2 + ) + offs_k = tl.arange(0, C) + + input_ptrs = ( + Input + + off_b * stride_iz + + off_h * stride_ih + + offs_n[:, None] * stride_in + + offs_k[None, :] + ) + output_ptrs = ( + Output + + off_b * stride_oz + + off_h * stride_oh + + offs_n[:, None] * stride_on + + offs_k[None, :] + ) + scale_ptrs = ( + Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + ) + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7.0 + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + + +def per_thread_int8( + q, + k, + km=None, + BLKQ=128, + WARPQ=32, + BLKK=64, + WARPK=64, + sm_scale=None, + tensor_layout="HND", +): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = ( + q.stride(0), + q.stride(1), + q.stride(2), + ) + stride_bz_qo, stride_h_qo, stride_seq_qo = ( + q_int8.stride(0), + q_int8.stride(1), + q_int8.stride(2), + ) + stride_bz_k, stride_h_k, stride_seq_k = ( + k.stride(0), + k.stride(1), + k.stride(2), + ) + stride_bz_ko, stride_h_ko, stride_seq_ko = ( + k_int8.stride(0), + k_int8.stride(1), + k_int8.stride(2), + ) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = ( + q.stride(0), + q.stride(2), + q.stride(1), + ) + stride_bz_qo, stride_h_qo, stride_seq_qo = ( + q_int8.stride(0), + q_int8.stride(2), + q_int8.stride(1), + ) + stride_bz_k, stride_h_k, stride_seq_k = ( + k.stride(0), + k.stride(2), + k.stride(1), + ) + stride_bz_ko, stride_h_ko, stride_seq_ko = ( + k_int8.stride(0), + k_int8.stride(2), + k_int8.stride(1), + ) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), + device=q.device, + dtype=torch.float32, + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, + q_int8, + q_scale, + qo_len, + stride_bz_q, + stride_h_q, + stride_seq_q, + stride_bz_qo, + stride_h_qo, + stride_seq_qo, + q_scale.stride(0), + q_scale.stride(1), + C=head_dim, + BLK=WARPQ, + ) + + grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, + k_int8, + k_scale, + kv_len, + stride_bz_k, + stride_h_k, + stride_seq_k, + stride_bz_ko, + stride_h_ko, + stride_seq_ko, + k_scale.stride(0), + k_scale.stride(1), + C=head_dim, + BLK=WARPK, + ) + + return q_int8, q_scale, k_int8, k_scale diff --git a/aphrodite/v1/attention/backends/sage_attn.py b/aphrodite/v1/attention/backends/sage_attn.py index 4865fd70b4..614ada2aed 100644 --- a/aphrodite/v1/attention/backends/sage_attn.py +++ b/aphrodite/v1/attention/backends/sage_attn.py @@ -3,11 +3,11 @@ from typing import Optional import torch -from sageattention import sageattn from aphrodite.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) +from aphrodite.attention.ops.sage_attention import sageattn from aphrodite.config import AphroditeConfig from aphrodite.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) diff --git a/kernels/attention/sage_attn/fused/fused.cu b/kernels/attention/sage_attn/fused/fused.cu index d9265e769a..a9d02784a4 100644 --- a/kernels/attention/sage_attn/fused/fused.cu +++ b/kernels/attention/sage_attn/fused/fused.cu @@ -434,6 +434,7 @@ void quant_per_block_int8_cuda( int64_t block_size, int64_t tensor_layout) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 CHECK_CUDA(input); CHECK_CUDA(output); CHECK_CUDA(scale); @@ -507,6 +508,7 @@ void quant_per_block_int8_cuda( }); }); }); +#endif } void quant_per_block_int8_cuda( @@ -516,6 +518,7 @@ void quant_per_block_int8_cuda( int64_t block_size, int64_t tensor_layout) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 CHECK_CUDA(input); CHECK_CUDA(output); CHECK_CUDA(scale); @@ -589,6 +592,7 @@ void quant_per_block_int8_cuda( }); }); }); +#endif } void quant_per_block_int8_fuse_sub_mean_cuda( @@ -599,6 +603,7 @@ void quant_per_block_int8_fuse_sub_mean_cuda( int64_t block_size, int64_t tensor_layout) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 CHECK_CUDA(input); CHECK_CUDA(mean); CHECK_CUDA(output); @@ -679,6 +684,7 @@ void quant_per_block_int8_fuse_sub_mean_cuda( }); }); }); +#endif } // use block size 128 and warp_block size 32 @@ -690,6 +696,7 @@ void quant_per_warp_int8_cuda( int64_t warp_block_size, int64_t tensor_layout) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 CHECK_CUDA(input); CHECK_CUDA(output); CHECK_CUDA(scale); @@ -765,6 +772,7 @@ void quant_per_warp_int8_cuda( }); }); }); +#endif } void sub_mean_cuda( @@ -773,6 +781,7 @@ void sub_mean_cuda( torch::Tensor output, int64_t tensor_layout) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 CHECK_CUDA(input); CHECK_CUDA(mean); CHECK_CUDA(output); @@ -845,6 +854,7 @@ void sub_mean_cuda( ); }); }); +#endif } void transpose_pad_permute_cuda( @@ -852,6 +862,7 @@ void transpose_pad_permute_cuda( torch::Tensor output, int64_t tensor_layout) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 CHECK_CUDA(input); CHECK_CUDA(output); @@ -920,6 +931,7 @@ void transpose_pad_permute_cuda( ); }); }); +#endif } void scale_fuse_quant_cuda( @@ -930,6 +942,7 @@ void scale_fuse_quant_cuda( double scale_max, int64_t tensor_layout) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 CHECK_CUDA(input); CHECK_CUDA(output); CHECK_CUDA(scale); @@ -997,6 +1010,7 @@ void scale_fuse_quant_cuda( scale.stride(0), scale.stride(1) ); }); +#endif } void mean_scale_fuse_quant_cuda( @@ -1008,6 +1022,7 @@ void mean_scale_fuse_quant_cuda( double scale_max, int64_t tensor_layout) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 CHECK_CUDA(input); CHECK_CUDA(output); CHECK_CUDA(mean); @@ -1080,4 +1095,5 @@ void mean_scale_fuse_quant_cuda( scale.stride(0), scale.stride(1) ); }); +#endif } \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu b/kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu index f3249f4ed9..693d466e1a 100644 --- a/kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu +++ b/kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu @@ -668,6 +668,7 @@ __global__ void qk_int_sv_f16_attn_kernel(int8_t *__restrict__ Q, int8_t *__rest } // } +#endif } // tensor_layout 0 for [B, N, H, D], 1 for [B, H, N, D] @@ -683,6 +684,7 @@ torch::Tensor qk_int8_sv_f16_accum_f32_attn(torch::Tensor query, double sm_scale, int64_t return_lse) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 CHECK_CUDA(query); CHECK_CUDA(key); CHECK_CUDA(value); @@ -843,6 +845,7 @@ torch::Tensor qk_int8_sv_f16_accum_f32_attn(torch::Tensor query, }); return lse; +#endif } torch::Tensor qk_int8_sv_f16_accum_f16_attn(torch::Tensor query, @@ -857,6 +860,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_attn(torch::Tensor query, double sm_scale, int64_t return_lse) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 CHECK_CUDA(query); CHECK_CUDA(key); CHECK_CUDA(value); @@ -1018,6 +1022,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_attn(torch::Tensor query, }); return lse; +#endif } torch::Tensor qk_int8_sv_f16_accum_f16_attn_inst_buf(torch::Tensor query, @@ -1032,6 +1037,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_attn_inst_buf(torch::Tensor query, double sm_scale, int64_t return_lse) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 CHECK_CUDA(query); CHECK_CUDA(key); CHECK_CUDA(value); @@ -1193,6 +1199,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_attn_inst_buf(torch::Tensor query, }); return lse; +#endif } torch::Tensor qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(torch::Tensor query, diff --git a/kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu b/kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu index 9a366b058b..1b6b88fcf5 100644 --- a/kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu +++ b/kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu @@ -579,6 +579,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf( double sm_scale, int64_t return_lse) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 CHECK_CUDA(query); CHECK_CUDA(key); CHECK_CUDA(value); @@ -735,6 +736,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf( }); return lse; +#endif } torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( @@ -751,6 +753,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( double sm_scale, int64_t return_lse) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 CHECK_CUDA(query); CHECK_CUDA(key); CHECK_CUDA(value); @@ -913,4 +916,5 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( }); return lse; +#endif } \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu index 7fb573706c..ccf81363a3 100644 --- a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu @@ -13,6 +13,7 @@ torch::Tensor qk_int8_sv_f8_accum_f16_attn_inst_buf(torch::Tensor query, double sm_scale, int64_t return_lse) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 CHECK_CUDA(query); CHECK_CUDA(key); CHECK_CUDA(value); @@ -177,4 +178,5 @@ torch::Tensor qk_int8_sv_f8_accum_f16_attn_inst_buf(torch::Tensor query, }); return lse; +#endif } \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu index 212e4153e2..663bdea2fb 100644 --- a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu @@ -13,6 +13,7 @@ torch::Tensor qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(torch::Tensor q double sm_scale, int64_t return_lse) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 CHECK_CUDA(query); CHECK_CUDA(key); CHECK_CUDA(value); @@ -184,4 +185,5 @@ torch::Tensor qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(torch::Tensor q }); return lse; +#endif } \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu index 96ce7334b1..1170bee362 100644 --- a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu @@ -13,6 +13,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn(torch::Tensor query, double sm_scale, int64_t return_lse) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 CHECK_CUDA(query); CHECK_CUDA(key); CHECK_CUDA(value); @@ -177,4 +178,5 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn(torch::Tensor query, }); return lse; +#endif } \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu index b19b07b605..3eafe0454d 100644 --- a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu @@ -12,6 +12,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(torch::Tensor query, double sm_scale, int64_t return_lse) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 CHECK_CUDA(query); CHECK_CUDA(key); CHECK_CUDA(value); @@ -176,4 +177,5 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(torch::Tensor query, }); return lse; +#endif } \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu index 20e72af6fb..f51feec1ef 100644 --- a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu @@ -13,6 +13,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(torch::Tensor query, double sm_scale, int64_t return_lse) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 CHECK_CUDA(query); CHECK_CUDA(key); CHECK_CUDA(value); @@ -184,4 +185,5 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(torch::Tensor query, }); return lse; +#endif } \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu index 83c065ad7e..c48b02d9fc 100644 --- a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu @@ -13,6 +13,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(torch::Tensor q double sm_scale, int64_t return_lse) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 CHECK_CUDA(query); CHECK_CUDA(key); CHECK_CUDA(value); @@ -184,4 +185,5 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(torch::Tensor q }); return lse; +#endif } \ No newline at end of file diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu index 9b2124d598..8a575a939e 100644 --- a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu @@ -14,6 +14,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(torch::Tenso double sm_scale, int64_t return_lse) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 CHECK_CUDA(query); CHECK_CUDA(key); CHECK_CUDA(value); @@ -189,4 +190,5 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(torch::Tenso }); return lse; +#endif } \ No newline at end of file diff --git a/kernels/attention/sage_attn/utils.cuh b/kernels/attention/sage_attn/utils.cuh index d051cc1f44..f94d29fd51 100644 --- a/kernels/attention/sage_attn/utils.cuh +++ b/kernels/attention/sage_attn/utils.cuh @@ -35,4 +35,4 @@ TORCH_CHECK(x.is_contiguous(), "Tensor " #x " must be contiguous") #define CHECK_LASTDIM_CONTIGUOUS(x) \ TORCH_CHECK(x.stride(-1) == 1, \ - "Tensor " #x " must be contiguous at the last dimension") \ No newline at end of file + "Tensor " #x " must be contiguous at the last dimension") From b60bf1565089f6dafed28d376fa5f83c4e067727 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 12 Sep 2025 19:08:29 +0000 Subject: [PATCH 3/3] more fixes --- CMakeLists.txt | 92 ++++++++++++++++--- .../qattn/qk_int_sv_f16_cuda_sm80.cu | 1 - ...9_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu | 1 - ...f8_accum_f16_fuse_v_scale_attn_inst_buf.cu | 1 - .../sm89_qk_int8_sv_f8_accum_f32_attn.cu | 1 - ...9_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu | 1 - ..._int8_sv_f8_accum_f32_fuse_v_scale_attn.cu | 1 - ...f8_accum_f32_fuse_v_scale_attn_inst_buf.cu | 1 - ...accum_f32_fuse_v_scale_fuse_v_mean_attn.cu | 1 - kernels/ops.h | 5 + kernels/torch_bindings.cpp | 4 + 11 files changed, 90 insertions(+), 19 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 09662080d8..3d5c68f615 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -307,17 +307,7 @@ set(APHRODITE_EXT_SRC "kernels/sparse/cutlass/sparse_scaled_mm_entry.cu" "kernels/cutlass_extensions/common.cpp" "kernels/attention/mla/cutlass_mla_entry.cu" - "kernels/quantization/fp8/per_token_group_quant.cu" - "kernels/attention/sage_attn/fused/fused.cu" - "kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu" - "kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu" - "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu" - "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu" - "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu" - "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu" - "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu" - "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu" - "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu") + "kernels/quantization/fp8/per_token_group_quant.cu") set_gencode_flags_for_srcs( SRCS "${APHRODITE_EXT_SRC}" @@ -800,6 +790,86 @@ set(APHRODITE_EXT_SRC endif() endif() + # + # SageAttention kernels + # + + # Only build SageAttention kernels if we are building for at least SM 8.0 compatible archs + cuda_archs_loose_intersection(SAGE_ATTN_ARCHS "8.0;8.6;8.7;8.9;9.0+PTX" "${CUDA_ARCHS}") + if (SAGE_ATTN_ARCHS) + + # Base SageAttention sources (always included) + set(SAGE_ATTN_BASE_SRCS + "kernels/attention/sage_attn/fused/fused.cu") + + # SM 8.0 specific kernels + cuda_archs_loose_intersection(SAGE_ATTN_SM80_ARCHS "8.0;8.6;8.7" "${CUDA_ARCHS}") + set(SAGE_ATTN_SM80_SRCS) + if (SAGE_ATTN_SM80_ARCHS) + list(APPEND SAGE_ATTN_SM80_SRCS + "kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu") + set_gencode_flags_for_srcs( + SRCS "${SAGE_ATTN_SM80_SRCS}" + CUDA_ARCHS "${SAGE_ATTN_SM80_ARCHS}") + message(STATUS "Building SageAttention SM80 kernels for archs: ${SAGE_ATTN_SM80_ARCHS}") + endif() + + # SM 8.9 specific kernels + cuda_archs_loose_intersection(SAGE_ATTN_SM89_ARCHS "8.9" "${CUDA_ARCHS}") + set(SAGE_ATTN_SM89_SRCS) + if (SAGE_ATTN_SM89_ARCHS) + list(APPEND SAGE_ATTN_SM89_SRCS + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu" + "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu") + set_gencode_flags_for_srcs( + SRCS "${SAGE_ATTN_SM89_SRCS}" + CUDA_ARCHS "${SAGE_ATTN_SM89_ARCHS}") + # Define a compile-time macro to indicate SM89 kernels are available + list(APPEND APHRODITE_GPU_FLAGS "-DSAGE_ATTN_HAS_SM89=1") + message(STATUS "Building SageAttention SM89 kernels for archs: ${SAGE_ATTN_SM89_ARCHS}") + else() + list(APPEND APHRODITE_GPU_FLAGS "-DSAGE_ATTN_HAS_SM89=0") + endif() + + # SM 9.0 specific kernels + cuda_archs_loose_intersection(SAGE_ATTN_SM90_ARCHS "9.0+PTX" "${CUDA_ARCHS}") + set(SAGE_ATTN_SM90_SRCS) + if (SAGE_ATTN_SM90_ARCHS) + list(APPEND SAGE_ATTN_SM90_SRCS + "kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu") + set_gencode_flags_for_srcs( + SRCS "${SAGE_ATTN_SM90_SRCS}" + CUDA_ARCHS "${SAGE_ATTN_SM90_ARCHS}") + # Define a compile-time macro to indicate SM90 kernels are available + list(APPEND APHRODITE_GPU_FLAGS "-DSAGE_ATTN_HAS_SM90=1") + message(STATUS "Building SageAttention SM90 kernels for archs: ${SAGE_ATTN_SM90_ARCHS}") + else() + list(APPEND APHRODITE_GPU_FLAGS "-DSAGE_ATTN_HAS_SM90=0") + endif() + + set(SAGE_ATTN_SRCS ${SAGE_ATTN_BASE_SRCS}) + list(APPEND SAGE_ATTN_SRCS ${SAGE_ATTN_SM80_SRCS}) + list(APPEND SAGE_ATTN_SRCS ${SAGE_ATTN_SM89_SRCS}) + list(APPEND SAGE_ATTN_SRCS ${SAGE_ATTN_SM90_SRCS}) + + + set_gencode_flags_for_srcs( + SRCS "${SAGE_ATTN_BASE_SRCS}" + CUDA_ARCHS "${SAGE_ATTN_ARCHS}") + + list(APPEND APHRODITE_EXT_SRC "${SAGE_ATTN_SRCS}") + + message(STATUS "Building SageAttention kernels for archs: ${SAGE_ATTN_ARCHS}") + else() + message(STATUS "Not building SageAttention kernels as no compatible archs found" + " in CUDA target architectures (requires SM 8.0 or above)") + endif() + # if CUDA endif endif() diff --git a/kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu b/kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu index 693d466e1a..ae7eddcee8 100644 --- a/kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu +++ b/kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu @@ -668,7 +668,6 @@ __global__ void qk_int_sv_f16_attn_kernel(int8_t *__restrict__ Q, int8_t *__rest } // } -#endif } // tensor_layout 0 for [B, N, H, D], 1 for [B, H, N, D] diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu index ccf81363a3..f696623f1c 100644 --- a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu @@ -1,4 +1,3 @@ -#include "attn_cuda_sm89.h" #include "qk_int_sv_f8_cuda_sm89.cuh" torch::Tensor qk_int8_sv_f8_accum_f16_attn_inst_buf(torch::Tensor query, diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu index 663bdea2fb..e3252fa5ad 100644 --- a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu @@ -1,4 +1,3 @@ -#include "attn_cuda_sm89.h" #include "qk_int_sv_f8_cuda_sm89.cuh" torch::Tensor qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(torch::Tensor query, torch::Tensor key, diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu index 1170bee362..6e8c7bd57b 100644 --- a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu @@ -1,4 +1,3 @@ -#include "attn_cuda_sm89.h" #include "qk_int_sv_f8_cuda_sm89.cuh" torch::Tensor qk_int8_sv_f8_accum_f32_attn(torch::Tensor query, diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu index 3eafe0454d..fbdb83aed3 100644 --- a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu @@ -1,4 +1,3 @@ -#include "attn_cuda_sm89.h" #include "qk_int_sv_f8_cuda_sm89.cuh" torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(torch::Tensor query, torch::Tensor key, diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu index f51feec1ef..e8d1a4656a 100644 --- a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu @@ -1,4 +1,3 @@ -#include "attn_cuda_sm89.h" #include "qk_int_sv_f8_cuda_sm89.cuh" torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(torch::Tensor query, torch::Tensor key, diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu index c48b02d9fc..b089b946f8 100644 --- a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu @@ -1,4 +1,3 @@ -#include "attn_cuda_sm89.h" #include "qk_int_sv_f8_cuda_sm89.cuh" torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(torch::Tensor query, torch::Tensor key, diff --git a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu index 8a575a939e..28e8562b45 100644 --- a/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu +++ b/kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu @@ -1,4 +1,3 @@ -#include "attn_cuda_sm89.h" #include "qk_int_sv_f8_cuda_sm89.cuh" torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(torch::Tensor query, torch::Tensor key, diff --git a/kernels/ops.h b/kernels/ops.h index d2ef9ad913..12c29a8aa0 100644 --- a/kernels/ops.h +++ b/kernels/ops.h @@ -545,6 +545,10 @@ torch::Tensor qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(torch::Tensor query, double sm_scale, int64_t return_lse); +// SageAttention SM89+ specific functions (FP8 kernels) +// Only declare these if SM89 kernels are compiled +#if defined(SAGE_ATTN_HAS_SM89) && SAGE_ATTN_HAS_SM89 + torch::Tensor qk_int8_sv_f8_accum_f32_attn(torch::Tensor query, torch::Tensor key, torch::Tensor value, @@ -662,3 +666,4 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( int64_t return_lse); #endif +#endif diff --git a/kernels/torch_bindings.cpp b/kernels/torch_bindings.cpp index d9125f90a4..fb723d42d8 100644 --- a/kernels/torch_bindings.cpp +++ b/kernels/torch_bindings.cpp @@ -918,6 +918,9 @@ ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); " int return_lse) -> Tensor"); ops.impl("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn", torch::kCUDA, &qk_int8_sv_f16_accum_f16_fuse_v_mean_attn); + // SageAttention SM89+ specific functions (FP8 kernels) + // Only register these if SM89 kernels are compiled +#if defined(SAGE_ATTN_HAS_SM89) && SAGE_ATTN_HAS_SM89 ops.def( "qk_int8_sv_f8_accum_f32_attn(" " Tensor query," @@ -1059,6 +1062,7 @@ ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); " int return_lse) -> Tensor"); ops.impl("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf", torch::kCUDA, &qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf); +#endif #endif }