Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 83 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ set(APHRODITE_EXT_SRC
SRCS "${SRCS}"
CUDA_ARCHS "${W4A8_ARCHS}")

list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND APHRODITE_EXT_SRC "${SRCS}")

message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
else()
Expand All @@ -790,6 +790,86 @@ set(APHRODITE_EXT_SRC
endif()
endif()

#
# SageAttention kernels
#

# SageAttention is only built when at least one requested CUDA arch loosely
# matches the supported set below (SM 8.0 and newer).
cuda_archs_loose_intersection(SAGE_ATTN_ARCHS "8.0;8.6;8.7;8.9;9.0+PTX" "${CUDA_ARCHS}")
if (NOT SAGE_ATTN_ARCHS)
  # Nothing compatible was requested: no sources are added and no
  # SAGE_ATTN_HAS_* flags are appended on this path.
  message(STATUS "Not building SageAttention kernels as no compatible archs found"
                 " in CUDA target architectures (requires SM 8.0 or above)")
else()

  # Sources shared by every supported arch; compiled for all of SAGE_ATTN_ARCHS.
  set(SAGE_ATTN_BASE_SRCS
    "kernels/attention/sage_attn/fused/fused.cu")

  # ---- SM 8.0 family (8.0 / 8.6 / 8.7) ------------------------------------
  cuda_archs_loose_intersection(SAGE_ATTN_SM80_ARCHS "8.0;8.6;8.7" "${CUDA_ARCHS}")
  # Start from an unset list so a stale value can never leak into the build.
  unset(SAGE_ATTN_SM80_SRCS)
  if (SAGE_ATTN_SM80_ARCHS)
    list(APPEND SAGE_ATTN_SM80_SRCS
      "kernels/attention/sage_attn/qattn/qk_int_sv_f16_cuda_sm80.cu")
    set_gencode_flags_for_srcs(
      SRCS "${SAGE_ATTN_SM80_SRCS}"
      CUDA_ARCHS "${SAGE_ATTN_SM80_ARCHS}")
    message(STATUS "Building SageAttention SM80 kernels for archs: ${SAGE_ATTN_SM80_ARCHS}")
  endif()

  # ---- SM 8.9 ---------------------------------------------------------------
  cuda_archs_loose_intersection(SAGE_ATTN_SM89_ARCHS "8.9" "${CUDA_ARCHS}")
  unset(SAGE_ATTN_SM89_SRCS)
  if (SAGE_ATTN_SM89_ARCHS)
    list(APPEND SAGE_ATTN_SM89_SRCS
      "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu"
      "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
      "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu"
      "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu"
      "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu"
      "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu"
      "kernels/attention/sage_attn/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu")
    set_gencode_flags_for_srcs(
      SRCS "${SAGE_ATTN_SM89_SRCS}"
      CUDA_ARCHS "${SAGE_ATTN_SM89_ARCHS}")
    # SAGE_ATTN_HAS_SM89 is always given an explicit 0/1 value on this path,
    # recording whether the SM89 kernels were added to the build.
    list(APPEND APHRODITE_GPU_FLAGS "-DSAGE_ATTN_HAS_SM89=1")
    message(STATUS "Building SageAttention SM89 kernels for archs: ${SAGE_ATTN_SM89_ARCHS}")
  else()
    list(APPEND APHRODITE_GPU_FLAGS "-DSAGE_ATTN_HAS_SM89=0")
  endif()

  # ---- SM 9.0 ---------------------------------------------------------------
  # NOTE(review): the request here is plain "9.0+PTX"; confirm the SM90 kernel
  # does not require an arch-specific target (e.g. 9.0a) to compile.
  cuda_archs_loose_intersection(SAGE_ATTN_SM90_ARCHS "9.0+PTX" "${CUDA_ARCHS}")
  unset(SAGE_ATTN_SM90_SRCS)
  if (SAGE_ATTN_SM90_ARCHS)
    list(APPEND SAGE_ATTN_SM90_SRCS
      "kernels/attention/sage_attn/qattn/qk_int_sv_f8_cuda_sm90.cu")
    set_gencode_flags_for_srcs(
      SRCS "${SAGE_ATTN_SM90_SRCS}"
      CUDA_ARCHS "${SAGE_ATTN_SM90_ARCHS}")
    # Same 0/1 convention as SAGE_ATTN_HAS_SM89 above.
    list(APPEND APHRODITE_GPU_FLAGS "-DSAGE_ATTN_HAS_SM90=1")
    message(STATUS "Building SageAttention SM90 kernels for archs: ${SAGE_ATTN_SM90_ARCHS}")
  else()
    list(APPEND APHRODITE_GPU_FLAGS "-DSAGE_ATTN_HAS_SM90=0")
  endif()

  # Full source list: base sources first, then whichever per-arch groups were
  # enabled above (groups that were skipped expand to nothing).
  set(SAGE_ATTN_SRCS
    ${SAGE_ATTN_BASE_SRCS}
    ${SAGE_ATTN_SM80_SRCS}
    ${SAGE_ATTN_SM89_SRCS}
    ${SAGE_ATTN_SM90_SRCS})

  # Only the base sources get the combined arch list; the per-arch sources
  # already received their own gencode flags in their sections.
  set_gencode_flags_for_srcs(
    SRCS "${SAGE_ATTN_BASE_SRCS}"
    CUDA_ARCHS "${SAGE_ATTN_ARCHS}")

  list(APPEND APHRODITE_EXT_SRC "${SAGE_ATTN_SRCS}")

  message(STATUS "Building SageAttention kernels for archs: ${SAGE_ATTN_ARCHS}")
endif()

# if CUDA endif
endif()

Expand Down Expand Up @@ -953,10 +1033,10 @@ if (APHRODITE_GPU_LANG STREQUAL "CUDA")
include(cmake/external_project/flashmla.cmake)

# Only build flash attention if not disabled
if (NOT DEFINED ENV{APHRODITE_DISABLE_FLASH_ATTN} OR NOT $ENV{APHRODITE_DISABLE_FLASH_ATTN})
if (NOT DEFINED ENV{APHRODITE_DISABLE_FLASH_ATTN_COMPILE} OR NOT $ENV{APHRODITE_DISABLE_FLASH_ATTN_COMPILE})
# vllm-flash-attn should be last as it overwrites some CMake functions
include(cmake/external_project/vllm_flash_attn.cmake)
else()
message(STATUS "Flash attention compilation disabled by APHRODITE_DISABLE_FLASH_ATTN")
message(STATUS "Flash attention compilation disabled by APHRODITE_DISABLE_FLASH_ATTN_COMPILE")
endif()
endif()
Loading
Loading