diff --git a/CMakeLists.txt b/CMakeLists.txt index f7cabf6198..43c65fa9e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -246,7 +246,8 @@ set(APHRODITE_EXT_SRC "kernels/activation_kernels.cu" "kernels/layernorm_kernels.cu" "kernels/layernorm_quant_kernels.cu" - "kernels/sampler.cu" + "kernels/sampling/repetition_penalty.cu" + "kernels/sampling/topk_topp.cu" "kernels/cuda_view.cu" "kernels/quantization/squeezellm/quant_cuda_kernel.cu" "kernels/quantization/gptq/q_gemm.cu" @@ -300,7 +301,6 @@ set(APHRODITE_EXT_SRC "kernels/quantization/awq/gemm_kernels.cu" "kernels/quantization/quip/origin_order.cu" "kernels/permute_cols.cu" - "kernels/sampling/sampling.cu" "kernels/quantization/cutlass_w8a8/scaled_mm_entry.cu" "kernels/quantization/fp4/nvfp4_quant_entry.cu" "kernels/quantization/fp4/nvfp4_scaled_mm_entry.cu" diff --git a/aphrodite/_custom_ops.py b/aphrodite/_custom_ops.py index 0d13b895a2..212e249d64 100644 --- a/aphrodite/_custom_ops.py +++ b/aphrodite/_custom_ops.py @@ -349,6 +349,30 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor, repetition_penalties) +def apply_top_k_top_p_cuda( + logits: torch.Tensor, + output_ids: torch.Tensor, + top_k_values: torch.Tensor, + top_p_values: Optional[torch.Tensor] = None, + curand_states: Optional[torch.Tensor] = None, + output_logprobs: Optional[torch.Tensor] = None, + normalize_logprobs: bool = False, +) -> None: + """Apply top-k and top-p sampling using CUDA kernel. + Args: + logits: The logits tensor of shape [num_seqs, vocab_size]. + output_ids: Output tensor for sampled token ids [num_seqs]. + top_k_values: Top-k values per sequence [num_seqs]. + top_p_values: Optional top-p values per sequence [num_seqs]. + curand_states: Optional CUDA random states for sampling [num_seqs]. + output_logprobs: Optional output for log probabilities [num_seqs]. + normalize_logprobs: Whether to normalize log probabilities. + """ + torch.ops._C.topk_topp_sampling( + logits, output_ids, top_k_values, top_p_values, curand_states, + output_logprobs, normalize_logprobs) + + def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, @@ -2020,129 +2044,3 @@ def int8_scaled_mm_with_quant_fake( M = mat1.size(0) N = mat2.size(0) return torch.empty((M, N), dtype=out_dtype) - - -# Sampling Kernels -def sampling_from_probs(probs: torch.Tensor, - uniform_samplers: torch.Tensor, - deterministic: bool = True, - check_nan: bool = False) -> torch.Tensor: - if check_nan and torch.any(torch.isnan(probs)): - raise ValueError("NaN detected in probs") - return torch.ops._C.sampling_from_probs(probs, uniform_samplers, - deterministic) - - -def _to_tensor_scalar_tuple(x): - if isinstance(x, torch.Tensor): - return (x, 0) - else: - return (None, x) - - -def top_p_sampling_from_probs( - probs: torch.Tensor, - uniform_samples: torch.Tensor, - top_p: Union[torch.Tensor, float], - deterministic: bool = True, - check_nan: bool = False) -> tuple[torch.Tensor, torch.Tensor]: - if check_nan and torch.any(torch.isnan(probs)): - raise ValueError("NaN detected in probs") - return torch.ops._C.top_p_sampling_from_probs( - probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic) - - -def top_k_sampling_from_probs( - probs: torch.Tensor, - uniform_samples: torch.Tensor, - top_k: Union[torch.Tensor, int], - deterministic: bool = True, - check_nan: bool = False) -> tuple[torch.Tensor, torch.Tensor]: - if check_nan and torch.any(torch.isnan(probs)): - raise ValueError("NaN detected in probs") - return torch.ops._C.top_k_sampling_from_probs( - probs, uniform_samples, *_to_tensor_scalar_tuple(top_k), deterministic) - - -def min_p_sampling_from_probs( - probs: torch.Tensor, - uniform_samples: torch.Tensor, - min_p: Union[torch.Tensor, float], - deterministic: bool = True, - check_nan: bool = False) -> tuple[torch.Tensor, torch.Tensor]: - if check_nan and torch.any(torch.isnan(probs)): - raise ValueError("NaN detected in probs") - return torch.ops._C.min_p_sampling_from_probs( - probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic) - - -def top_k_mask_logits( - logits: torch.Tensor, - top_k: Union[torch.Tensor, int], -) -> torch.Tensor: - return torch.ops._C.top_k_mask_logits(logits, - *_to_tensor_scalar_tuple(top_k)) - - -def top_p_renorm_prob( - probs: torch.Tensor, - top_p: Union[torch.Tensor, float], -) -> torch.Tensor: - return torch.ops._C.top_p_renorm_prob(probs, - *_to_tensor_scalar_tuple(top_p)) - - -def top_k_renorm_prob( - probs: torch.Tensor, - top_k: Union[torch.Tensor, int], -) -> torch.Tensor: - return torch.ops._C.top_k_renorm_prob(probs, - *_to_tensor_scalar_tuple(top_k)) - - -def top_k_top_p_sampling_from_logits( - probs: torch.Tensor, - uniform_samples: torch.Tensor, - top_k: Union[torch.Tensor, int], - top_p: Union[torch.Tensor, float], - filter_apply_order: str = "top_k_first", - deterministic: bool = True, - check_nan: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: - if filter_apply_order == "top_k_first": - masked_logits = top_k_mask_logits(probs, top_k) - probs = torch.softmax(masked_logits, dim=-1) - return top_p_sampling_from_probs(probs, uniform_samples, top_p, - deterministic, check_nan) - elif filter_apply_order == "joint": - probs = torch.softmax(probs, dim=-1) - if check_nan and torch.any(torch.isnan(probs)): - raise ValueError("NaN detected in probs") - return torch.ops._C.top_k_top_p_sampling_from_logits( - probs, uniform_samples, *_to_tensor_scalar_tuple(top_k), - *_to_tensor_scalar_tuple(top_p), deterministic) - else: - raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") - - -def top_k_top_p_sampling_from_probs( - probs: torch.Tensor, - uniform_samples: torch.Tensor, - top_k: Union[torch.Tensor, int], - top_p: Union[torch.Tensor, float], - filter_apply_order: str = "top_k_first", - deterministic: bool = True, - check_nan: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: - if filter_apply_order == "top_k_first": - renorm_probs = top_k_renorm_prob(probs, top_k) - return top_p_sampling_from_probs(renorm_probs, uniform_samples, top_p, - deterministic, check_nan) - elif filter_apply_order == "joint": - if check_nan and torch.any(torch.isnan(probs)): - raise ValueError("NaN detected in probs") - return torch.ops._C.top_k_top_p_sampling_from_probs( - probs, uniform_samples, *_to_tensor_scalar_tuple(top_k), - *_to_tensor_scalar_tuple(top_p), deterministic) - else: - raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") diff --git a/aphrodite/common/envs.py b/aphrodite/common/envs.py index 7484e08187..83563b8f1c 100755 --- a/aphrodite/common/envs.py +++ b/aphrodite/common/envs.py @@ -160,7 +160,7 @@ APHRODITE_REQUEST_LEVEL_METRICS: bool = False APHRODITE_USE_SAMPLING_KERNELS: bool = False APHRODITE_NO_DEPRECATION_WARNING: bool = False - APHRODITE_DISABLE_FLASH_ATTN: bool = False + APHRODITE_DISABLE_FLASH_ATTN_COMPILE: bool = False APHRODITE_DYNAMIC_ROPE_SCALING: bool = False APHRODITE_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False APHRODITE_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False diff --git a/aphrodite/v1/sample/ops/temperatures.py b/aphrodite/v1/sample/ops/temperatures.py index 90631695df..bf1a786e8b 100644 --- a/aphrodite/v1/sample/ops/temperatures.py +++ b/aphrodite/v1/sample/ops/temperatures.py @@ -2,6 +2,7 @@ _SAMPLING_EPS = 1e-5 + def _tensor_or_zeros(tens, like_tensor): return tens if tens is not None else torch.zeros_like(like_tensor) diff --git a/aphrodite/v1/sample/ops/topk_topp_sampler.py b/aphrodite/v1/sample/ops/topk_topp_sampler.py index 49f2c8d739..3b30e35eda 100644 --- a/aphrodite/v1/sample/ops/topk_topp_sampler.py +++ b/aphrodite/v1/sample/ops/topk_topp_sampler.py @@ -8,6 +8,7 @@ from aphrodite.common import envs from aphrodite.common.logger import log_once from aphrodite.platforms import current_platform +from aphrodite._custom_ops import apply_top_k_top_p_cuda try: import flashinfer.sampling @@ -18,8 +19,9 @@ try: from aphrodite.distributed.parallel_state import ( get_tensor_model_parallel_rank) + rank = get_tensor_model_parallel_rank() except Exception: - get_tensor_model_parallel_rank = lambda: 0 + rank = 0 class TopKTopPSampler(nn.Module): @@ -45,28 +47,34 @@ def __init__(self): # earlier design. # https://github.com/flashinfer-ai/flashinfer/releases/ # tag/v0.2.3 - if get_tensor_model_parallel_rank() == 0: + if rank == 0: logger.info( "FlashInfer version >= 0.2.3 required. " "Falling back to default sampling implementation.") self.forward = self.forward_native - elif envs.APHRODITE_USE_SAMPLING_KERNELS is not False: + elif envs.APHRODITE_USE_SAMPLING_KERNELS is True: + # Use custom CUDA kernel for top-k/top-p sampling + if rank == 0: + logger.info("Using custom CUDA kernel for top-p & " + "top-k sampling.") + self.forward = self.forward_cuda_kernel + elif envs.APHRODITE_USE_FLASHINFER_SAMPLER is not None: # NOTE: The V0 sampler doesn't use FlashInfer for - # sampling unless APHRODITE_USE_SAMPLING_KERNELS=1 (i.e., by + # sampling unless APHRODITE_USE_FLASHINFER_SAMPLER=1 (i.e., by # default it is unused). For backward compatibility, we set - # `APHRODITE_USE_SAMPLING_KERNELS` as None by default and + # `APHRODITE_USE_FLASHINFER_SAMPLER` as None by default and # interpret it differently in V0 and V1 samplers: In V0, # None means False, while in V1, None means True. This is # why we use the condition - # `envs.APHRODITE_USE_SAMPLING_KERNELS is not False` here. + # `envs.APHRODITE_USE_FLASHINFER_SAMPLER is not None` here. logger.info("Using FlashInfer for top-p & top-k sampling.") self.forward = self.forward_cuda else: if get_tensor_model_parallel_rank() == 0: logger.warning( "FlashInfer is available, but it is not enabled. " - "Falling back to the PyTorch-native implementation " - "of top-p & top-k sampling. For the best " + "Falling back to the PyTorch-native implementation" + " of top-p & top-k sampling. For the best " "performance, please set " "APHRODITE_USE_SAMPLING_KERNELS=1.") self.forward = self.forward_native @@ -136,6 +144,67 @@ def forward_tpu( probs = logits.softmax(dim=-1, dtype=torch.float32) return random_sample(probs, generators) + def forward_cuda_kernel( + self, + logits: torch.Tensor, + generators: dict[int, torch.Generator], + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], + ) -> torch.Tensor: + """Use custom CUDA kernel for top-k and top-p sampling.""" + if k is None and p is None: + # No filtering needed, use regular sampling + probs = logits.softmax(dim=-1, dtype=torch.float32) + return random_sample(probs, generators) + + if generators: + log_once( + "WARNING", + "Custom CUDA kernel does not support per-request generators. " + "Falling back to PyTorch-native implementation.") + return self.forward_native(logits, generators, k, p) + + num_seqs = logits.size(0) + vocab_size = logits.size(1) + + # Prepare output tensor for the CUDA kernel + output_ids = torch.empty(num_seqs, dtype=torch.int64, + device=logits.device) + + # Prepare top-k and top-p values + # Convert to the format expected by CUDA kernel + if k is not None: + top_k_values = k.to(dtype=torch.int32, device=logits.device) + else: + top_k_values = torch.full((num_seqs,), vocab_size, + dtype=torch.int32, device=logits.device) + + if p is not None: + top_p_values = p.to(dtype=torch.float32, device=logits.device) + else: + top_p_values = None + + # Call the CUDA kernel + # Note: We don't use curand_states for now, relying on the + # kernel's internal randomness + try: + apply_top_k_top_p_cuda( + logits=logits, + output_ids=output_ids, + top_k_values=top_k_values, + top_p_values=top_p_values, + curand_states=None, # Not using CUDA random states for now + output_logprobs=None, # Not requesting log probabilities + normalize_logprobs=False + ) + return output_ids + except Exception as e: + log_once( + "WARNING", + f"Custom CUDA kernel failed: {e}. Falling back to " + "PyTorch-native implementation.") + return self.forward_native(logits, generators, k, p) + def apply_top_k_top_p_tpu( logits: torch.Tensor, diff --git a/kernels/ops.h b/kernels/ops.h index 69117e3484..0d662cbc93 100644 --- a/kernels/ops.h +++ b/kernels/ops.h @@ -97,6 +97,13 @@ void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& output_mask, const torch::Tensor& repetition_penalties); +void topk_topp_sampling(torch::Tensor& logits, torch::Tensor& output_ids, + const torch::Tensor& top_k_values, + const std::optional& top_p_values, + const std::optional& curand_states, + std::optional& output_logprobs, + bool normalize_logprobs = false); + void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, torch::Tensor& scale, double epsilon); @@ -205,39 +212,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); -// Sampling kernels -#ifndef USE_ROCM -torch::Tensor sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - bool deterministic); -std::vector top_p_sampling_from_probs( - torch::Tensor probs, torch::Tensor uniform_samples, - std::optional maybe_top_p_arr, double top_p_val, - bool deterministic); -std::vector top_k_sampling_from_probs( - torch::Tensor probs, torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, int64_t top_k_val, - bool deterministic); -std::vector min_p_sampling_from_probs( - torch::Tensor probs, torch::Tensor uniform_samples, - std::optional maybe_min_p_arr, double min_p_val, - bool deterministic); -std::vector top_k_top_p_sampling_from_probs( - torch::Tensor probs, torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, double top_k_val, - std::optional maybe_top_p_arr, double top_p_val, - bool deterministic); -torch::Tensor top_p_renorm_prob(torch::Tensor probs, - std::optional maybe_top_p_arr, - double top_p_val); -torch::Tensor top_k_renorm_prob(torch::Tensor probs, - std::optional maybe_top_k_arr, - int64_t top_k_val); -torch::Tensor top_k_mask_logits(torch::Tensor logits, - std::optional maybe_top_k_arr, - int64_t top_k_val); - -#endif // Quantization kernels #ifndef USE_ROCM diff --git a/kernels/sampling/math.cuh b/kernels/sampling/math.cuh deleted file mode 100644 index 6117280974..0000000000 --- a/kernels/sampling/math.cuh +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Copyright (c) 2024 by PygmalionAI team. - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef APHRODITE_MATH_CUH_ -#define APHRODITE_MATH_CUH_ - -#include -#include - -namespace aphrodite { -namespace math { - -// log2(e) -constexpr float log2e = 1.44269504088896340736f; - -__forceinline__ __device__ half2 uint32_as_half2(uint32_t x) { - return *(half2*)&x; -} - -__forceinline__ __device__ uint32_t half2_as_uint32(half2 x) { - return *(uint32_t*)&x; -} - -/*! - * \brief Wrapper of PTX ex2.approx instruction, which computes 2^x - * \param x input - */ -__forceinline__ __device__ float ptx_exp2(float x) { - float y; - asm volatile("ex2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); - return y; -} - -/*! - * \brief Wrapper of PTX lg2.approx instruction, which computes log2(x) - * \param x input - */ -__forceinline__ __device__ float ptx_log2(float x) { - float y; - asm volatile("lg2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); - return y; -} - -/*! - * \brief Wrapper of PTX ex2.approx.f16x2 instruction, which computes 2^x - * \param x input - */ -__forceinline__ __device__ half2 ptx_exp2(half2 x) { - uint32_t y_u32; - uint32_t x_u32 = half2_as_uint32(x); - asm volatile("ex2.approx.f16x2 %0, %1;" : "=r"(y_u32) : "r"(x_u32)); - return uint32_as_half2(y_u32); -} - -/*! - * \brief Wrapper of PTX ex2.approx.f16 instruction, which computes 2^x - * \param x input - */ -__forceinline__ __device__ half ptx_exp2(half x) { - ushort y_u16; - asm volatile("ex2.approx.f16 %0, %1;" - : "=h"(y_u16) - : "h"(__half_as_ushort(x))); - return __ushort_as_half(y_u16); -} - -/*! - * \brief Wrapper of PTX rcp.approx instruction, which computes 1/x - * \param x input - */ -__forceinline__ __device__ float ptx_rcp(float x) { - float y; - asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); - return y; -} - -/*! - * \brief Wrapper of PTX shfl.sync.bfly instruction, which performs a butterfly - * shuffle between threads in a warp. \param x The value in the source lane - * \param lane_mask The mask to perform thread index xor with: y[i] <- x[i ^ - * delta] - */ -__forceinline__ __device__ float shfl_xor_sync(float x, int lane_mask) { - float y; - asm volatile("shfl.sync.bfly.b32 %0, %1, %2, 0x1f, 0xffffffff;" - : "=f"(y) - : "f"(x), "r"(lane_mask)); - return y; -} - -/*! - * \brief Wrapper of PTX shfl.sync.bfly instruction on half2, which performs a - * butterfly shuffle between threads in a warp. \param x The value in the source - * lane \param lane_mask The mask to perform thread index xor with: y[i] <- x[i - * ^ lane_mask] - */ -__forceinline__ __device__ half2 shfl_xor_sync(half2 x, int lane_mask) { - return __shfl_xor_sync(0xffffffff, x, lane_mask); -} - -/*! - * \brief Wrapper of PTX rsqrt approximation instruction, which computes - * 1/sqrt(x) \param x input - */ -__forceinline__ __device__ float rsqrt(float x) { - float y; - asm volatile("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); - return y; -} - -/*! - * \brief Wrapper of PTX tanh.approx.f32 instruction, which computes tanh(x) - * \param x input - */ -__forceinline__ __device__ float tanh(float x) { - float y; - asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x)); - return y; -} - -/*! - * \brief Wrapper of PTX tanh.approx.f16x2 instruction, which computes tanh(x) - * \param x input - */ -__forceinline__ __device__ half2 tanh(half2 x) { - uint32_t y_u32; - uint32_t x_u32 = half2_as_uint32(x); - asm volatile("tanh.approx.f16x2 %0, %1;" : "=r"(y_u32) : "r"(x_u32)); - return uint32_as_half2(y_u32); -} - -/*! - * \brief Wrapper of PTX tanh.approx.f16 instruction, which computes tanh(x) - * \param x input - */ -__forceinline__ __device__ half tanh(half x) { - ushort y_u16; - asm volatile("tanh.approx.f16 %0, %1;" - : "=h"(y_u16) - : "h"(__half_as_ushort(x))); - return __ushort_as_half(y_u16); -} - -} // namespace math -} // namespace aphrodite -#endif // APHRODITE_MATH_CUH_ \ No newline at end of file diff --git a/kernels/sampler.cu b/kernels/sampling/repetition_penalty.cu similarity index 100% rename from kernels/sampler.cu rename to kernels/sampling/repetition_penalty.cu diff --git a/kernels/sampling/sampling.cu b/kernels/sampling/sampling.cu deleted file mode 100644 index 032a8b0c27..0000000000 --- a/kernels/sampling/sampling.cu +++ /dev/null @@ -1,391 +0,0 @@ -/* - * Copyright (c) 2024 by PygmalionAI team. - * Copyright (c) 2024 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "sampling.cuh" -#include "../ops.h" -#include "utils.cuh" - -// Check utils -#define CUDA_CHECK(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") - -#define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -#define CHECK_INPUT(x) \ - CUDA_CHECK(x); \ - CHECK_CONTIGUOUS(x) - -#define CHECK_EQ(a, b) \ - TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) - -#define CHECK_GE(a, b) \ - TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b) - -#define CHECK_DIM(d, x) \ - TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") - -using namespace aphrodite; - -torch::Tensor sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - bool deterministic) { - CHECK_INPUT(probs); - CHECK_INPUT(uniform_samples); - auto device = probs.device(); - CHECK_EQ(uniform_samples.device(), device); - CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) - CHECK_DIM(1, uniform_samples); // uniform_samples: (batch_size) - CHECK_EQ(probs.size(0), uniform_samples.size(0)); - unsigned int batch_size = probs.size(0); - unsigned int vocab_size = probs.size(1); - probs = probs.to(torch::kFloat32); - uniform_samples = uniform_samples.to(torch::kFloat32); - - cudaStream_t torch_current_stream = - c10::cuda::getCurrentCUDAStream(device.index()); - auto samples = - torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device)); - - cudaError_t status = sampling::SamplingFromProb( - static_cast(probs.data_ptr()), - static_cast(uniform_samples.data_ptr()), - static_cast(samples.data_ptr()), batch_size, vocab_size, - deterministic, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SamplingFromProbs failed with error code " + - std::string(cudaGetErrorString(status))); - return samples; -} - -std::vector top_p_sampling_from_probs( - torch::Tensor probs, torch::Tensor uniform_samples, - std::optional maybe_top_p_arr, double top_p_val, - bool deterministic) { - CHECK_INPUT(probs); - CHECK_INPUT(uniform_samples); - auto device = probs.device(); - CHECK_EQ(uniform_samples.device(), device); - CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) - CHECK_DIM( - 2, uniform_samples); // uniform_samples: (max_top_p_rounds, batch_size) - CHECK_EQ(probs.size(0), uniform_samples.size(1)); - unsigned int batch_size = probs.size(0); - unsigned int vocab_size = probs.size(1); - unsigned int max_top_p_rounds = uniform_samples.size(0); - bool has_top_p_arr = maybe_top_p_arr.has_value(); - auto top_p_arr = maybe_top_p_arr.value_or( - torch::empty({0}, torch::dtype(torch::kFloat32))); - if (has_top_p_arr) { - CHECK_INPUT(top_p_arr); - CHECK_DIM(1, top_p_arr); // top_p_arr: (batch_size,) - CHECK_EQ(top_p_arr.size(0), batch_size); - CHECK_EQ(top_p_arr.device(), device); - } - probs = probs.to(torch::kFloat32); - uniform_samples = uniform_samples.to(torch::kFloat32); - top_p_arr = top_p_arr.to(torch::kFloat32); - - cudaStream_t torch_current_stream = - c10::cuda::getCurrentCUDAStream(device.index()); - auto samples = - torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device)); - auto success = - torch::empty({batch_size}, torch::dtype(torch::kBool).device(device)); - - cudaError_t status = sampling::TopPSamplingFromProb( - static_cast(probs.data_ptr()), - static_cast(uniform_samples.data_ptr()), - static_cast(samples.data_ptr()), - static_cast(success.data_ptr()), - has_top_p_arr ? static_cast(top_p_arr.data_ptr()) : nullptr, - batch_size, top_p_val, vocab_size, max_top_p_rounds, deterministic, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "TopPSamplingFromProbs failed with error code " + - std::string(cudaGetErrorString(status))); - - return {samples, success}; -} - -std::vector top_k_sampling_from_probs( - torch::Tensor probs, torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, int64_t top_k_val, - bool deterministic) { - CHECK_INPUT(probs); - CHECK_INPUT(uniform_samples); - auto device = probs.device(); - CHECK_EQ(uniform_samples.device(), device); - CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) - CHECK_DIM( - 2, uniform_samples); // uniform_samples: (max_top_k_rounds, batch_size) - CHECK_EQ(probs.size(0), uniform_samples.size(1)); - unsigned int batch_size = probs.size(0); - unsigned int vocab_size = probs.size(1); - unsigned int max_top_k_rounds = uniform_samples.size(0); - bool has_top_k_arr = maybe_top_k_arr.has_value(); - auto top_k_arr = - maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32))); - if (has_top_k_arr) { - CHECK_INPUT(top_k_arr); - CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,) - CHECK_EQ(top_k_arr.size(0), batch_size); - CHECK_EQ(top_k_arr.device(), device); - } - probs = probs.to(torch::kFloat32); - uniform_samples = uniform_samples.to(torch::kFloat32); - top_k_arr = top_k_arr.to(torch::kInt32); - - cudaStream_t torch_current_stream = - c10::cuda::getCurrentCUDAStream(device.index()); - auto samples = - torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device)); - auto success = - torch::empty({batch_size}, torch::dtype(torch::kBool).device(device)); - - cudaError_t status = sampling::TopKSamplingFromProb( - static_cast(probs.data_ptr()), - static_cast(uniform_samples.data_ptr()), - static_cast(samples.data_ptr()), - static_cast(success.data_ptr()), - has_top_k_arr ? static_cast(top_k_arr.data_ptr()) : nullptr, - batch_size, top_k_val, vocab_size, max_top_k_rounds, deterministic, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "TopKSamplingFromProbs failed with error code " + - std::string(cudaGetErrorString(status))); - - return {samples, success}; -} - -std::vector min_p_sampling_from_probs( - torch::Tensor probs, torch::Tensor uniform_samples, - std::optional maybe_min_p_arr, double min_p_val, - bool deterministic) { - CHECK_INPUT(probs); - CHECK_INPUT(uniform_samples); - auto device = probs.device(); - CHECK_EQ(uniform_samples.device(), device); - CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) - CHECK_DIM(2, uniform_samples); // uniform_samples: (max_rounds, batch_size) - unsigned int batch_size = probs.size(0); - unsigned int vocab_size = probs.size(1); - unsigned int max_rounds = uniform_samples.size(0); - CHECK_EQ(uniform_samples.size(1), batch_size); - bool has_min_p_arr = maybe_min_p_arr.has_value(); - auto min_p_arr = maybe_min_p_arr.value_or( - torch::empty({0}, torch::dtype(torch::kFloat32))); - if (has_min_p_arr) { - CHECK_INPUT(min_p_arr); - CHECK_DIM(1, min_p_arr); // min_p_arr: (batch_size,) - CHECK_EQ(min_p_arr.size(0), batch_size); - CHECK_EQ(min_p_arr.device(), device); - } - min_p_arr = min_p_arr.to(torch::kFloat32); - probs = probs.to(torch::kFloat32); - uniform_samples = uniform_samples.to(torch::kFloat32); - - cudaStream_t torch_current_stream = - c10::cuda::getCurrentCUDAStream(device.index()); - auto samples = - torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device)); - auto success = - torch::empty({batch_size}, torch::dtype(torch::kBool).device(device)); - - cudaError_t status = sampling::MinPSamplingFromProb( - static_cast(probs.data_ptr()), - static_cast(uniform_samples.data_ptr()), - has_min_p_arr ? static_cast(min_p_arr.data_ptr()) : nullptr, - static_cast(samples.data_ptr()), - static_cast(success.data_ptr()), batch_size, min_p_val, vocab_size, - max_rounds, deterministic, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "MinPSamplingFromProb failed with error code " + - std::string(cudaGetErrorString(status))); - - return {samples, success}; -} - -std::vector top_k_top_p_sampling_from_probs( - torch::Tensor probs, torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, double top_k_val, - std::optional maybe_top_p_arr, double top_p_val, - bool deterministic) { - CHECK_INPUT(probs); - CHECK_INPUT(uniform_samples); - auto device = probs.device(); - CHECK_EQ(uniform_samples.device(), device); - CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) - CHECK_DIM(2, uniform_samples); // uniform_samples: (max_rounds, batch_size) - unsigned int batch_size = probs.size(0); - unsigned int vocab_size = probs.size(1); - unsigned int max_rounds = uniform_samples.size(0); - CHECK_EQ(uniform_samples.size(1), batch_size); - bool has_top_k_arr = maybe_top_k_arr.has_value(); - auto top_k_arr = - maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32))); - if (has_top_k_arr) { - CHECK_INPUT(top_k_arr); - CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,) - CHECK_EQ(top_k_arr.size(0), batch_size); - CHECK_EQ(top_k_arr.device(), device); - } - top_k_arr = top_k_arr.to(torch::kInt32); - bool has_top_p_arr = maybe_top_p_arr.has_value(); - auto top_p_arr = maybe_top_p_arr.value_or( - torch::empty({0}, torch::dtype(torch::kFloat32))); - if (has_top_p_arr) { - CHECK_INPUT(top_p_arr); - CHECK_DIM(1, top_p_arr); // top_p_arr: (batch_size,) - CHECK_EQ(top_p_arr.size(0), batch_size); - CHECK_EQ(top_p_arr.device(), device); - } - top_p_arr = top_p_arr.to(torch::kFloat32); - probs = probs.to(torch::kFloat32); - uniform_samples = uniform_samples.to(torch::kFloat32); - - cudaStream_t torch_current_stream = - c10::cuda::getCurrentCUDAStream(device.index()); - auto samples = - torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device)); - auto success = - torch::empty({batch_size}, torch::dtype(torch::kBool).device(device)); - - cudaError_t status = sampling::TopKTopPSamplingFromProb( - static_cast(probs.data_ptr()), - static_cast(uniform_samples.data_ptr()), - has_top_k_arr ? static_cast(top_k_arr.data_ptr()) : nullptr, - has_top_p_arr ? static_cast(top_p_arr.data_ptr()) : nullptr, - static_cast(samples.data_ptr()), - static_cast(success.data_ptr()), batch_size, top_k_val, top_p_val, - vocab_size, max_rounds, deterministic, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "TopKTopPSamplingFromProbs failed with error code " + - std::string(cudaGetErrorString(status))); - - return {samples, success}; -} - -torch::Tensor top_p_renorm_prob(torch::Tensor probs, - std::optional maybe_top_p_arr, - double top_p_val) { - CHECK_INPUT(probs); - auto device = probs.device(); - CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) - unsigned int batch_size = probs.size(0); - unsigned int vocab_size = probs.size(1); - bool has_top_p_arr = maybe_top_p_arr.has_value(); - auto top_p_arr = maybe_top_p_arr.value_or( - torch::empty({0}, torch::dtype(torch::kFloat32))); - if (has_top_p_arr) { - CHECK_INPUT(top_p_arr); - CHECK_DIM(1, top_p_arr); // top_p_arr: (batch_size,) - CHECK_EQ(top_p_arr.size(0), batch_size); - CHECK_EQ(top_p_arr.device(), device); - } - top_p_arr = top_p_arr.to(torch::kFloat32); - probs = probs.to(torch::kFloat32); - - cudaStream_t torch_current_stream = - c10::cuda::getCurrentCUDAStream(device.index()); - auto renorm_probs = torch::empty( - {batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device)); - - cudaError_t status = sampling::TopPRenormProb( - static_cast(probs.data_ptr()), - static_cast(renorm_probs.data_ptr()), - has_top_p_arr ? static_cast(top_p_arr.data_ptr()) : nullptr, - batch_size, top_p_val, vocab_size, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "TopPRenormProb failed with error code " + - std::string(cudaGetErrorString(status))); - return renorm_probs; -} - -torch::Tensor top_k_renorm_prob(torch::Tensor probs, - std::optional maybe_top_k_arr, - int64_t top_k_val) { - CHECK_INPUT(probs); - auto device = probs.device(); - CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) - unsigned int batch_size = probs.size(0); - unsigned int vocab_size = probs.size(1); - bool has_top_k_arr = maybe_top_k_arr.has_value(); - auto top_k_arr = - maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32))); - if (has_top_k_arr) { - CHECK_INPUT(top_k_arr); - CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,) - CHECK_EQ(top_k_arr.size(0), batch_size); - CHECK_EQ(top_k_arr.device(), device); - } - top_k_arr = top_k_arr.to(torch::kInt32); - probs = probs.to(torch::kFloat32); - - cudaStream_t torch_current_stream = - c10::cuda::getCurrentCUDAStream(device.index()); - auto renorm_probs = torch::empty( - {batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device)); - - cudaError_t status = sampling::TopKRenormProb( - static_cast(probs.data_ptr()), - static_cast(renorm_probs.data_ptr()), - has_top_k_arr ? static_cast(top_k_arr.data_ptr()) : nullptr, - batch_size, top_k_val, vocab_size, torch_current_stream); - - TORCH_CHECK(status == cudaSuccess, - "TopKRenormProb failed with error code " + - std::string(cudaGetErrorString(status))); - return renorm_probs; -} - -torch::Tensor top_k_mask_logits(torch::Tensor logits, - std::optional maybe_top_k_arr, - int64_t top_k_val) { - CHECK_INPUT(logits); - auto device = logits.device(); - CHECK_DIM(2, logits); // logits: (batch_size, vocab_size) - unsigned int batch_size = logits.size(0); - unsigned int vocab_size = logits.size(1); - bool has_top_k_arr = maybe_top_k_arr.has_value(); - auto top_k_arr = - maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32))); - if (has_top_k_arr) { - CHECK_INPUT(top_k_arr); - CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,) - CHECK_EQ(top_k_arr.size(0), batch_size); - CHECK_EQ(top_k_arr.device(), device); - } - top_k_arr = top_k_arr.to(torch::kInt32); - logits = logits.to(torch::kFloat32); - - cudaStream_t torch_current_stream = - c10::cuda::getCurrentCUDAStream(device.index()); - auto mask_logits = torch::empty({batch_size, vocab_size}, - torch::dtype(torch::kFloat32).device(device)); - - cudaError_t status = sampling::TopKMaskLogits( - static_cast(logits.data_ptr()), - static_cast(mask_logits.data_ptr()), - has_top_k_arr ? static_cast(top_k_arr.data_ptr()) : nullptr, - batch_size, top_k_val, vocab_size, torch_current_stream); - - TORCH_CHECK(status == cudaSuccess, - "TopKMaskLogits failed with error code " + - std::string(cudaGetErrorString(status))); - return mask_logits; -} \ No newline at end of file diff --git a/kernels/sampling/sampling.cuh b/kernels/sampling/sampling.cuh deleted file mode 100644 index 84deab2816..0000000000 --- a/kernels/sampling/sampling.cuh +++ /dev/null @@ -1,1398 +0,0 @@ -/* - * Copyright (c) 2024 by PygmalionAI team. - * Copyright (c) 2024 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef APHRODITE_SAMPLING_CUH_ -#define APHRODITE_SAMPLING_CUH_ - -#include -#include -#include -#include - -#include "math.cuh" -#include "utils.cuh" -#include "vec_dtypes.cuh" - -namespace aphrodite { - -namespace sampling { - -using namespace cub; - -#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \ - if (deterministic) { \ - constexpr bool DETERMINISTIC = true; \ - __VA_ARGS__ \ - } else { \ - constexpr bool DETERMINISTIC = false; \ - __VA_ARGS__ \ - } - -constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; -constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; - -#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120100) - #define APHRODITE_CUB_SUBTRACTLEFT_DEFINED -#endif - -template -struct Pair { - T value; - int count; - - __device__ Pair operator+(const Pair& other) const { - return {value + other.value, count + other.count}; - } - __device__ Pair& operator+=(const Pair& other) { - value += other.value; - count += other.count; - return *this; - } -}; - -struct BoolDiffOp { - __device__ __forceinline__ bool operator()(const bool& lhs, - const bool& rhs) const { - return lhs != rhs; - } -}; - -template -struct SamplingTempStorage { - union { - T deterministic_scan[BLOCK_THREADS / 32]; - typename BlockScan::TempStorage scan; - typename BlockReduce::TempStorage - reduce; - typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage - reduce_pair; - typename BlockAdjacentDifference::TempStorage adj_diff; - } block_prim; - struct { - int32_t sampled_id; - union { - T value; - Pair pair; - T max_p; - } block_aggregate; - } data; -}; - -/*! - * \brief Deterministic inclusive scan implementation, use Belloch scan - * algorithm. \note This implementation is slower than the cub::BlockScan, but - * it is deterministic. - */ -template -__device__ __forceinline__ void DeterministicInclusiveSum( - const T* in_data, T* out_data, - SamplingTempStorage* - temp_storage) { - T* smem_prefix_sum = temp_storage->block_prim.deterministic_scan; - T thread_data[VEC_SIZE]; - T thread_sum = 0; -#pragma unroll - for (uint32_t i = 0; i < VEC_SIZE; ++i) { - thread_sum += in_data[i]; - thread_data[i] = thread_sum; - } - - T thread_exclusive_prefix_sum = thread_sum; - -#pragma unroll - for (uint32_t offset = 1; offset < 32; offset *= 2) { - T tmp = __shfl_up_sync(0xffffffff, thread_exclusive_prefix_sum, offset); - if ((threadIdx.x + 1) % (offset * 2) == 0) { - thread_exclusive_prefix_sum += tmp; - } - } - - T warp_sum = __shfl_sync(0xffffffff, thread_exclusive_prefix_sum, - threadIdx.x | 0xffffffff); - if (threadIdx.x % 32 == 31) { - thread_exclusive_prefix_sum = 0; - } - -#pragma unroll - for (uint32_t offset = 16; offset >= 1; offset /= 2) { - T tmp = __shfl_xor_sync(0xffffffff, thread_exclusive_prefix_sum, offset); - if ((threadIdx.x + 1) % (offset * 2) == 0) { - thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum; - } - if ((threadIdx.x + 1) % (offset * 2) == offset) { - thread_exclusive_prefix_sum = tmp; - } - } - - smem_prefix_sum[threadIdx.x / 32] = warp_sum; - __syncthreads(); - - if (threadIdx.x < 32) { - T warp_exclusive_prefix_sum = - (threadIdx.x < BLOCK_THREADS / 32) ? smem_prefix_sum[threadIdx.x] : 0; - -#pragma unroll - for (uint32_t offset = 1; offset < 32; offset *= 2) { - T tmp = __shfl_up_sync(0xffffffff, warp_exclusive_prefix_sum, offset); - if ((threadIdx.x + 1) % (offset * 2) == 0) { - warp_exclusive_prefix_sum += tmp; - } - } - - if (threadIdx.x % 32 == 31) { - warp_exclusive_prefix_sum = 0; - } - -#pragma unroll - for (uint32_t offset = 16; offset >= 1; offset /= 2) { - T tmp = __shfl_xor_sync(0xffffffff, warp_exclusive_prefix_sum, offset); - if ((threadIdx.x + 1) % (offset * 2) == 0) { - warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum; - } - if ((threadIdx.x + 1) % (offset * 2) == offset) { - warp_exclusive_prefix_sum = tmp; - } - } - if (threadIdx.x < BLOCK_THREADS / 32) { - smem_prefix_sum[threadIdx.x] = warp_exclusive_prefix_sum; - } - } - __syncthreads(); - -#pragma unroll - for (uint32_t i = 0; i < VEC_SIZE; ++i) { - out_data[i] = smem_prefix_sum[threadIdx.x / 32] + - thread_exclusive_prefix_sum + thread_data[i]; - } -} - -template -__device__ __forceinline__ void DeviceSamplingFromProb( - uint32_t i, uint32_t d, T threshold, T u, vec_t prob_vec, - T& aggregate, - SamplingTempStorage* - temp_storage) { - const uint32_t tx = threadIdx.x; - T prob_greater_than_threshold[VEC_SIZE]; - T inclusive_cdf[VEC_SIZE]; - bool greater_than_u[VEC_SIZE], valid[VEC_SIZE]; -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - prob_greater_than_threshold[j] = - (prob_vec[j] > threshold) ? prob_vec[j] : T(0); - valid[j] = - prob_vec[j] > threshold && (i * BLOCK_THREADS + tx) * VEC_SIZE < d; - } - T aggregate_local = BlockReduce( - temp_storage->block_prim.reduce) - .Sum(prob_greater_than_threshold); - if (tx == 0) { - temp_storage->data.block_aggregate.value = aggregate_local; - } - __syncthreads(); - aggregate_local = temp_storage->data.block_aggregate.value; - - if (aggregate + aggregate_local > u) { - if constexpr (DETERMINISTIC) { - DeterministicInclusiveSum( - prob_greater_than_threshold, inclusive_cdf, temp_storage); - } else { - BlockScan(temp_storage->block_prim.scan) - .InclusiveSum(prob_greater_than_threshold, inclusive_cdf); - - __syncthreads(); - } - -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - greater_than_u[j] = inclusive_cdf[j] + aggregate > u; - } - - bool greater_than_u_diff[VEC_SIZE]; -#ifdef APHRODITE_CUB_SUBTRACTLEFT_DEFINED - BlockAdjacentDifference( - temp_storage->block_prim.adj_diff) - .SubtractLeft(greater_than_u, greater_than_u_diff, - BoolDiffOp()); -#else - BlockAdjacentDifference( - temp_storage->block_prim.adj_diff) - .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), - 0); -#endif - __syncthreads(); - -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - if (greater_than_u_diff[j] && valid[j]) { - if constexpr (DETERMINISTIC) { - temp_storage->data.sampled_id = - (i * BLOCK_THREADS + tx) * VEC_SIZE + j; - } else { - // cub's block scan result might not be monotonic, so we need to find - // the first element - atomicMin(&(temp_storage->data.sampled_id), - (i * BLOCK_THREADS + tx) * VEC_SIZE + j); - } - } - } - __syncthreads(); - } - aggregate += aggregate_local; -} - -template -__global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, - IdType* output, IdType* row_indices, - uint32_t d) { - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx]; - - extern __shared__ __align__( - alignof(SamplingTempStorage)) uint8_t smem_sampling[]; - auto& temp_storage = - reinterpret_cast&>(smem_sampling); - temp_storage.data.sampled_id = d - 1; - __syncthreads(); - - vec_t probs_vec; - DType aggregate(0); - float u = uniform_samples[bx]; - - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + - tx * VEC_SIZE); - } - - DeviceSamplingFromProb( - i, d, DType(0), u, probs_vec, aggregate, &temp_storage); - if (float(aggregate) > u) { - break; - } - } - output[bx] = temp_storage.data.sampled_id; -} - -template -__global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, - IdType* output, bool* success, - IdType* top_k_arr, - uint32_t top_k_val, uint32_t d, - uint32_t max_top_k_rounds) { - const uint32_t batch_size = gridDim.x; - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; - - extern __shared__ __align__( - alignof(SamplingTempStorage)) uint8_t smem_sampling[]; - auto& temp_storage = - reinterpret_cast&>(smem_sampling); - - vec_t probs_vec; - DType aggregate; - DType q = DType(1); - DType pivot = DType(0); - IdType sampled_id; - for (uint32_t round = 0; round < max_top_k_rounds; ++round) { - temp_storage.data.sampled_id = d - 1; - __syncthreads(); - DType u = uniform_samples[round * batch_size + bx] * q; - aggregate = DType(0); - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - DeviceSamplingFromProb( - i, d, pivot, u, probs_vec, aggregate, &temp_storage); - if (aggregate > u) { - break; - } - } - __syncthreads(); - sampled_id = temp_storage.data.sampled_id; - pivot = max(pivot, probs[bx * d + sampled_id]); - - Pair aggregate_gt_pivot{DType(0), 0}; - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - Pair probs_gt_pivot[VEC_SIZE]; -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0), - (probs_vec[j] > pivot && - (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; - } - - aggregate_gt_pivot += - BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_pair) - .Sum(probs_gt_pivot); - if (tx == 0) { - temp_storage.data.block_aggregate.pair = aggregate_gt_pivot; - } - __syncthreads(); - } - q = temp_storage.data.block_aggregate.pair.value; - if (temp_storage.data.block_aggregate.pair.count < k) { - break; - } - } - __syncthreads(); - if (tx == 0) { - output[bx] = sampled_id; - if (temp_storage.data.block_aggregate.pair.count >= k) { - // failed to sample within MAX_TOP_P_ROUNDS - if (success != nullptr) { - success[bx] = false; - } - } else { - if (success != nullptr) { - success[bx] = true; - } - } - } -} - -template -__global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, - IdType* output, bool* success, - IdType* row_indices, - float* top_p_arr, float top_p_val, - uint32_t d, - uint32_t max_top_p_rounds) { - const uint32_t batch_size = gridDim.x; - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - float top_p = (top_p_arr == nullptr) ? top_p_val : top_p_arr[bx]; - - const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx]; - - extern __shared__ __align__( - alignof(SamplingTempStorage)) uint8_t smem_sampling[]; - auto& temp_storage = - reinterpret_cast&>(smem_sampling); - - vec_t probs_vec; - DType aggregate; - DType q = DType(1); - DType pivot = DType(0); - IdType sampled_id; - for (uint32_t round = 0; round < max_top_p_rounds; ++round) { - temp_storage.data.sampled_id = d - 1; - __syncthreads(); - DType u = uniform_samples[round * batch_size + bx] * q; - aggregate = DType(0); - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + - (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - DeviceSamplingFromProb( - i, d, pivot, u, probs_vec, aggregate, &temp_storage); - if (aggregate > u) { - break; - } - } - __syncthreads(); - sampled_id = temp_storage.data.sampled_id; - pivot = max(pivot, probs[row_idx * d + sampled_id]); - - DType aggregate_gt_pivot = DType(0); - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + - (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - DType probs_gt_pivot[VEC_SIZE]; -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0); - } - - aggregate_gt_pivot += - BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot); - if (tx == 0) { - temp_storage.data.block_aggregate.value = aggregate_gt_pivot; - } - __syncthreads(); - } - q = temp_storage.data.block_aggregate.value; - if (float(q) < top_p) { - break; - } - } - __syncthreads(); - if (tx == 0) { - output[bx] = sampled_id; - if (float(q) >= top_p) { - // failed to sample within MAX_TOP_P_ROUNDS - if (success != nullptr) { - success[bx] = false; - } - } else { - if (success != nullptr) { - success[bx] = true; - } - } - } -} - -template -__global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples, - DType* min_p_arr, IdType* output, - bool* success, float min_p_val, - uint32_t d, - uint32_t max_min_p_rounds) { - const uint32_t batch_size = gridDim.x; - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - DType p = (min_p_arr == nullptr) ? min_p_val : min_p_arr[bx]; - - extern __shared__ __align__( - alignof(SamplingTempStorage)) uint8_t smem_sampling[]; - auto& temp_storage = - reinterpret_cast&>(smem_sampling); - - vec_t probs_vec; - DType aggregate; - DType q = DType(1); - DType pivot = DType(0); - - DType max_p = 0; - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - DType probs_[VEC_SIZE]; -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_[j] = probs_vec[j]; - } - max_p = max( - max_p, BlockReduce(temp_storage.block_prim.reduce) - .Reduce(probs_, cub::Max())); - __syncthreads(); - } - if (tx == 0) { - temp_storage.data.block_aggregate.max_p = max_p; - } - __syncthreads(); - DType scaled_p = temp_storage.data.block_aggregate.max_p * p; - - IdType sampled_id; - for (uint32_t round = 0; round < max_min_p_rounds; ++round) { - temp_storage.data.sampled_id = d - 1; - __syncthreads(); - DType u = uniform_samples[round * batch_size + bx] * q; - aggregate = DType(0); - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - DeviceSamplingFromProb( - i, d, pivot, u, probs_vec, aggregate, &temp_storage); - if (aggregate > u) { - break; - } - } - __syncthreads(); - sampled_id = temp_storage.data.sampled_id; - pivot = max(pivot, probs[bx * d + sampled_id]); - if (pivot >= scaled_p) { - break; - } - - DType aggregate_gt_pivot = DType(0); - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - DType probs_gt_pivot[VEC_SIZE]; -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0); - } - - aggregate_gt_pivot += - BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot); - if (tx == 0) { - temp_storage.data.block_aggregate.value = aggregate_gt_pivot; - } - __syncthreads(); - } - q = temp_storage.data.block_aggregate.value; - } - __syncthreads(); - if (tx == 0) { - output[bx] = sampled_id; - if (pivot < scaled_p) { - // failed to sample within MAX_ROUNDS - if (success != nullptr) { - success[bx] = false; - } - } else { - if (success != nullptr) { - success[bx] = true; - } - } - } -} - -template -__global__ void TopKTopPSamplingFromProbKernel( - DType* probs, DType* uniform_samples, IdType* top_k_arr, DType* top_p_arr, - IdType* output, bool* success, IdType top_k_val, DType top_p_val, - uint32_t d, uint32_t max_rounds) { - const uint32_t batch_size = gridDim.x; - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - IdType k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; - DType p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx]; - - extern __shared__ __align__( - alignof(SamplingTempStorage)) uint8_t smem_sampling[]; - auto& temp_storage = - reinterpret_cast&>(smem_sampling); - - vec_t probs_vec; - DType aggregate; - DType q = DType(1); - DType pivot = DType(0); - IdType sampled_id; - for (uint32_t round = 0; round < max_rounds; ++round) { - temp_storage.data.sampled_id = d - 1; - __syncthreads(); - DType u = uniform_samples[round * batch_size + bx] * q; - aggregate = DType(0); - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - DeviceSamplingFromProb( - i, d, pivot, u, probs_vec, aggregate, &temp_storage); - if (aggregate > u) { - break; - } - } - __syncthreads(); - sampled_id = temp_storage.data.sampled_id; - pivot = max(pivot, probs[bx * d + sampled_id]); - - Pair aggregate_gt_pivot{DType(0), 0}; - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - Pair probs_gt_pivot[VEC_SIZE]; -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0), - (probs_vec[j] > pivot && - (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; - } - - aggregate_gt_pivot += - BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_pair) - .Sum(probs_gt_pivot); - if (tx == 0) { - temp_storage.data.block_aggregate.pair = aggregate_gt_pivot; - } - __syncthreads(); - } - q = temp_storage.data.block_aggregate.pair.value; - if (temp_storage.data.block_aggregate.pair.count < k && float(q) < p) { - break; - } - } - __syncthreads(); - if (tx == 0) { - output[bx] = sampled_id; - if (temp_storage.data.block_aggregate.pair.count >= k || float(q) >= p) { - // failed to sample within MAX_TOP_P_ROUNDS - if (success != nullptr) { - success[bx] = false; - } - } else { - if (success != nullptr) { - success[bx] = true; - } - } - } -} - -template -cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, - uint32_t batch_size, uint32_t d, - bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - IdType* row_indices_placeholder = nullptr; - void* args[] = {&probs, &uniform_samples, &output, &row_indices_placeholder, - &d}; - const uint32_t smem_size = - sizeof(SamplingTempStorage); - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, - {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = - SamplingFromProbKernel; - APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, - smem_size, stream)); - })}); - return cudaSuccess; -} - -template -cudaError_t ParallelSamplingFromProb(T* probs, T* uniform_samples, - IdType* output, IdType* row_indices, - uint32_t batch_size, uint32_t d, - bool deterministic, - cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &uniform_samples, &output, &row_indices, &d}; - const uint32_t smem_size = - sizeof(SamplingTempStorage); - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, - {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = - SamplingFromProbKernel; - APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, - smem_size, stream)); - })}); - return cudaSuccess; -} - -template -cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, - bool* success, T* top_k_arr, - uint32_t batch_size, uint32_t top_k_val, - uint32_t d, uint32_t max_top_k_rounds, - bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - const uint32_t smem_size = - sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &uniform_samples, &output, &success, - &top_k_arr, &top_k_val, &d, &max_top_k_rounds}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, - {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = - TopKSamplingFromProbKernel; - APHRODITE_CUDA_CALL(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, - smem_size, stream)); - })}); - return cudaSuccess; -} - -template -cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, - bool* success, T* top_p_arr, - uint32_t batch_size, T top_p_val, uint32_t d, - uint32_t max_top_p_rounds, bool deterministic, - cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - const uint32_t smem_size = - sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - IdType* row_indices_placeholder = nullptr; - void* args[] = {&probs, - &uniform_samples, - &output, - &success, - &row_indices_placeholder, - &top_p_arr, - &top_p_val, - &d, - &max_top_p_rounds}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, - {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = - TopPSamplingFromProbKernel; - APHRODITE_CUDA_CALL(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, - smem_size, stream)); - })}); - return cudaSuccess; -} - -template -cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p_arr, - IdType* output, bool* success, - uint32_t batch_size, float min_p_val, - uint32_t d, uint32_t max_rounds, - bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - const uint32_t smem_size = - sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &uniform_samples, &min_p_arr, &output, - &success, &min_p_val, &d, &max_rounds}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, - {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = - MinPSamplingFromProbKernel; - APHRODITE_CUDA_CALL(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, - smem_size, stream)); - })}); - return cudaSuccess; -} - -template -cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, - IdType* top_k_arr, T* top_p_arr, - IdType* output, bool* success, - uint32_t batch_size, IdType top_k_val, - T top_p_val, uint32_t d, - uint32_t max_rounds, bool deterministic, - cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - const uint32_t smem_size = - sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &uniform_samples, &top_k_arr, &top_p_arr, - &output, &success, &top_k_val, &top_p_val, - &d, &max_rounds}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, - {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopKTopPSamplingFromProbKernel; - APHRODITE_CUDA_CALL(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, - smem_size, stream)); - })}); - return cudaSuccess; -} - -template -struct RenormTempStorage { - union { - typename BlockReduce::TempStorage - reduce; - typename BlockReduce::TempStorage - reduce_int; - typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage - reduce_pair; - } block_prim; - struct { - T max_val; - T min_val; - union { - T value; - int count; - Pair pair; - } block_aggregate; - } data; -}; - -template -__global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, - DType* top_p_arr, float top_p_val, - uint32_t d) { - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - const uint32_t row_idx = bx; - float p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx]; - - extern __shared__ __align__( - alignof(RenormTempStorage)) - uint8_t smem_renorm[]; - auto& temp_storage = - reinterpret_cast&>( - smem_renorm); - temp_storage.data.max_val = DType(0); - vec_t probs_vec; - DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0 - - DType threadlocal_max_val = DType(0); - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + - tx * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_greater_than_pivot[j] = probs_vec[j]; - } - threadlocal_max_val = - max(threadlocal_max_val, - BlockReduce( - temp_storage.block_prim.reduce) - .Reduce(probs_greater_than_pivot, cub::Max())); - __syncthreads(); - } - if (tx == 0) { - temp_storage.data.max_val = threadlocal_max_val; - } - __syncthreads(); - threadlocal_max_val = temp_storage.data.max_val; - - float low = 0, high = threadlocal_max_val; - DType min_gt_low, max_le_high; - DType sum_low(1); - // f(x) = sum(probs[probs > x]), f(x) is non-increasing - // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p - // <= high} loop invariant: - // - f(low) >= p, f(high) < p - // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) - // stopping condition - // - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p - do { - DType threadlocal_sum(0); - float mid = (low + high) / 2; - min_gt_low = high; - max_le_high = low; - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + - tx * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_greater_than_pivot[j] = - (probs_vec[j] > mid) ? probs_vec[j] : DType(0); - if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { - min_gt_low = min(min_gt_low, probs_vec[j]); - } - if (probs_vec[j] <= high && - (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { - max_le_high = max(max_le_high, probs_vec[j]); - } - } - threadlocal_sum += BlockReduce( - temp_storage.block_prim.reduce) - .Sum(probs_greater_than_pivot); - __syncthreads(); - } - min_gt_low = BlockReduce( - temp_storage.block_prim.reduce) - .Reduce(min_gt_low, cub::Min()); - __syncthreads(); - max_le_high = BlockReduce( - temp_storage.block_prim.reduce) - .Reduce(max_le_high, cub::Max()); - if (tx == 0) { - temp_storage.data.block_aggregate.value = threadlocal_sum; - temp_storage.data.min_val = min_gt_low; - temp_storage.data.max_val = max_le_high; - } - __syncthreads(); - threadlocal_sum = temp_storage.data.block_aggregate.value; - min_gt_low = temp_storage.data.min_val; - max_le_high = temp_storage.data.max_val; - if (threadlocal_sum >= p) { - low = mid; - sum_low = float(threadlocal_sum); - } else { - high = min(mid, max_le_high); - } - } while (min_gt_low != max_le_high); - - DType normalizer = math::ptx_rcp(max(sum_low, 1e-8)); - - // normalize - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + - tx * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_vec[j] = - (probs_vec[j] > low) ? probs_vec[j] * normalizer : DType(0); - } - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.store(renormed_prob + row_idx * d + - i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } - } -} - -template -__global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, - IdType* top_k_arr, uint32_t top_k_val, - uint32_t d) { - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - const uint32_t row_idx = bx; - uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; - float pivot = -std::numeric_limits::infinity(); - vec_t logits_vec; - if (k < d) { - extern __shared__ __align__( - alignof(RenormTempStorage)) - uint8_t smem_renorm[]; - auto& temp_storage = - reinterpret_cast&>( - smem_renorm); - DType logits_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0 - - DType threadlocal_max_val = DType(-std::numeric_limits::infinity()), - threadlocal_min_val = DType(std::numeric_limits::infinity()); - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - logits_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + - tx * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - logits_greater_than_pivot[j] = logits_vec[j]; - } - threadlocal_max_val = - max(threadlocal_max_val, - BlockReduce( - temp_storage.block_prim.reduce) - .Reduce(logits_greater_than_pivot, cub::Max())); - __syncthreads(); - threadlocal_min_val = - min(threadlocal_min_val, - BlockReduce( - temp_storage.block_prim.reduce) - .Reduce(logits_greater_than_pivot, cub::Min())); - __syncthreads(); - } - if (tx == 0) { - temp_storage.data.max_val = threadlocal_max_val; - temp_storage.data.min_val = threadlocal_min_val; - } - __syncthreads(); - threadlocal_max_val = temp_storage.data.max_val; - threadlocal_min_val = temp_storage.data.min_val; - - float low = threadlocal_min_val - 1, high = threadlocal_max_val; - DType min_gt_low, max_le_high; - // f(x) = len(nonzero(probs > x)), f(x) is non-increasing - // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | - // p <= high} loop invariant: - // - f(low) >= k, f(high) < k - // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) - // stopping condition: min_gt_low == max_le_high - // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k - do { - int threadlocal_count_sum = 0; - int probs_greater_than_pivot_count[VEC_SIZE]; // pivot initialized to 0 - float mid = (low + high) / 2; - min_gt_low = high; - max_le_high = low; - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - logits_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + - tx * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_greater_than_pivot_count[j] = - logits_vec[j] > mid && - (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d; - if (logits_vec[j] > low && - (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { - min_gt_low = min(min_gt_low, logits_vec[j]); - } - if (logits_vec[j] <= high && - (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { - max_le_high = max(max_le_high, logits_vec[j]); - } - } - threadlocal_count_sum += - BlockReduce( - temp_storage.block_prim.reduce_int) - .Sum(probs_greater_than_pivot_count); - __syncthreads(); - } - min_gt_low = BlockReduce( - temp_storage.block_prim.reduce) - .Reduce(min_gt_low, cub::Min()); - __syncthreads(); - max_le_high = BlockReduce( - temp_storage.block_prim.reduce) - .Reduce(max_le_high, cub::Max()); - if (tx == 0) { - temp_storage.data.block_aggregate.count = threadlocal_count_sum; - temp_storage.data.min_val = min_gt_low; - temp_storage.data.max_val = max_le_high; - } - __syncthreads(); - threadlocal_count_sum = temp_storage.data.block_aggregate.count; - min_gt_low = temp_storage.data.min_val; - max_le_high = temp_storage.data.max_val; - if (threadlocal_count_sum >= k) { - low = mid; - } else { - high = min(mid, max_le_high); - } - } while (min_gt_low != max_le_high); - pivot = low; - } - - // masking - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - logits_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + - tx * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - logits_vec[j] = (logits_vec[j] > pivot) - ? logits_vec[j] - : DType(-std::numeric_limits::infinity()); - } - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - logits_vec.store(masked_logits + row_idx * d + - i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } - } -} - -template -__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, - IdType* top_k_arr, uint32_t top_k_val, - uint32_t d) { - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - const uint32_t row_idx = bx; - uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; - float pivot = -std::numeric_limits::infinity(), normalizer = 1; - vec_t probs_vec; - if (k < d) { - extern __shared__ __align__( - alignof(RenormTempStorage)) - uint8_t smem_renorm[]; - auto& temp_storage = - reinterpret_cast&>( - smem_renorm); - temp_storage.data.max_val = DType(0); - DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0 - - DType threadlocal_max_val = DType(0); - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + - tx * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_greater_than_pivot[j] = probs_vec[j]; - } - threadlocal_max_val = - max(threadlocal_max_val, - BlockReduce( - temp_storage.block_prim.reduce) - .Reduce(probs_greater_than_pivot, cub::Max())); - __syncthreads(); - } - if (tx == 0) { - temp_storage.data.max_val = threadlocal_max_val; - } - __syncthreads(); - threadlocal_max_val = temp_storage.data.max_val; - - float low = 0, high = threadlocal_max_val; - DType min_gt_low, max_le_high; - DType sum_low(1); - // f(x) = len(nonzero(probs > x)), f(x) is non-increasing - // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | - // p <= high} loop invariant: - // - f(low) >= k, f(high) < k - // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) - // stopping condition: min_gt_low == max_le_high - // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k - do { - Pair threadlocal_sum{DType(0), 0}; - Pair - probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0 - float mid = (low + high) / 2; - min_gt_low = high; - max_le_high = low; - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + - tx * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_greater_than_pivot_pair[j] = { - (probs_vec[j] > mid) ? probs_vec[j] : DType(0), - (probs_vec[j] > mid && - (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; - if (probs_vec[j] > low && - (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { - min_gt_low = min(min_gt_low, probs_vec[j]); - } - if (probs_vec[j] <= high && - (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { - max_le_high = max(max_le_high, probs_vec[j]); - } - } - threadlocal_sum += - BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_pair) - .Sum(probs_greater_than_pivot_pair); - __syncthreads(); - } - min_gt_low = BlockReduce( - temp_storage.block_prim.reduce) - .Reduce(min_gt_low, cub::Min()); - __syncthreads(); - max_le_high = BlockReduce( - temp_storage.block_prim.reduce) - .Reduce(max_le_high, cub::Max()); - if (tx == 0) { - temp_storage.data.block_aggregate.pair = threadlocal_sum; - temp_storage.data.min_val = min_gt_low; - temp_storage.data.max_val = max_le_high; - } - __syncthreads(); - threadlocal_sum = temp_storage.data.block_aggregate.pair; - min_gt_low = temp_storage.data.min_val; - max_le_high = temp_storage.data.max_val; - if (threadlocal_sum.count >= k) { - low = mid; - sum_low = float(threadlocal_sum.value); - } else { - high = min(mid, max_le_high); - } - } while (min_gt_low != max_le_high); - - normalizer = math::ptx_rcp(max(sum_low, 1e-8)); - pivot = low; - } - - // normalize - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + - tx * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_vec[j] = - (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : DType(0); - } - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.store(renormed_prob + row_idx * d + - i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } - } -} - -template -cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr, - uint32_t batch_size, float top_p_val, uint32_t d, - cudaStream_t stream = 0) { - const uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); - - const uint32_t smem_size = - sizeof(RenormTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = - TopPRenormProbKernel; - APHRODITE_CUDA_CALL(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - APHRODITE_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); - return cudaSuccess; -} - -template -cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, - IdType* top_k_arr, uint32_t batch_size, - uint32_t top_k_val, uint32_t d, - cudaStream_t stream = 0) { - const uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); - - const uint32_t smem_size = - sizeof(RenormTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopKRenormProbKernel; - APHRODITE_CUDA_CALL(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - APHRODITE_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); - return cudaSuccess; -} - -template -cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, - IdType* top_k_arr, uint32_t batch_size, - uint32_t top_k_val, uint32_t d, - cudaStream_t stream = 0) { - const uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); - - const uint32_t smem_size = - sizeof(RenormTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopKMaskLogitsKernel; - APHRODITE_CUDA_CALL(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - APHRODITE_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); - return cudaSuccess; -} - -template -cudaError_t ParallelTopPSamplingFromProb( - T* probs, T* uniform_samples, IdType* output, bool* success, - IdType* row_indices, T* top_p_arr, uint32_t batch_size, uint32_t d, - uint32_t max_top_p_rounds, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - const uint32_t smem_size = - sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - T top_p_placeholder = 0; - void* args[] = { - &probs, &uniform_samples, &output, &success, &row_indices, - &top_p_arr, &top_p_placeholder, &d, &max_top_p_rounds}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, - {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = - TopPSamplingFromProbKernel; - APHRODITE_CUDA_CALL(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, - smem_size, stream)); - })}); - return cudaSuccess; -} - -} // namespace sampling - -} // namespace aphrodite - -#endif // APHRODITE_SAMPLING_CUH_ \ No newline at end of file diff --git a/kernels/sampling/topk_topp.cu b/kernels/sampling/topk_topp.cu new file mode 100644 index 0000000000..76648a8e81 --- /dev/null +++ b/kernels/sampling/topk_topp.cu @@ -0,0 +1,286 @@ +#include "dispatch_utils.h" + +#include +#include +#include + +#ifndef USE_ROCM + #include +#else + #include +#endif + +namespace aphrodite { + +constexpr int TOP_K_MAX = 256; + +template +struct TopK_2 { + int idx = -1; + T val; + + __device__ __forceinline__ void insert(T elem, int elem_idx) { + // Use float comparison to avoid operator ambiguity + if (static_cast(elem) > static_cast(val)) { + val = elem; + idx = elem_idx; + } + } + + __device__ __forceinline__ void init() { + // Initialize with appropriate minimum value based on type + if constexpr (std::is_same_v) { + val = -FLT_MAX; + } else { + val = static_cast(-65504.0f); // Half precision min + } + idx = -1; + } +}; + +template +__device__ __forceinline__ TopK_2 reduce_topk_op(const TopK_2& a, const TopK_2& b) { + // Use float comparison to avoid operator ambiguity + return static_cast(a.val) > static_cast(b.val) ? a : b; +} + +// Stage 1: Find top-k values per block +template +__global__ void topk_stage1_kernel( + const scalar_t* __restrict__ logits, // [num_seqs, vocab_size] + scalar_t* __restrict__ tmp_logits, // [num_seqs, vocab_size] + int* __restrict__ topk_tmp_id_buf, // [num_seqs, BLOCKS_PER_SEQ * k] + scalar_t* __restrict__ topk_tmp_val_buf, // [num_seqs, BLOCKS_PER_SEQ * k] + const int* __restrict__ topk_values, // [num_seqs] - k values per sequence + const int num_seqs, + const int vocab_size) { + + typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + const int tid = threadIdx.x; + const int seq_idx = blockIdx.x / BLOCKS_PER_SEQ; + const int block_lane = blockIdx.x % BLOCKS_PER_SEQ; + + if (seq_idx >= num_seqs) return; + + const int k = topk_values[seq_idx]; + if (k == 0) return; + + const int logits_offset = seq_idx * vocab_size; + const int tmp_buf_offset = seq_idx * BLOCKS_PER_SEQ * TOP_K_MAX + block_lane * k; + + // Copy logits to temporary buffer + for (int i = tid + block_lane * BLOCK_SIZE; i < vocab_size; i += BLOCK_SIZE * BLOCKS_PER_SEQ) { + tmp_logits[logits_offset + i] = logits[logits_offset + i]; + } + + // Find top-k values iteratively + for (int ite = 0; ite < k; ite++) { + TopK_2 partial; + partial.init(); + + // Each thread finds its maximum + for (int i = tid + block_lane * BLOCK_SIZE; i < vocab_size; i += BLOCK_SIZE * BLOCKS_PER_SEQ) { + int idx = logits_offset + i; + partial.insert(tmp_logits[idx], idx); + } + + // Reduce across block + TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); + + if (tid == 0) { + topk_tmp_id_buf[tmp_buf_offset + ite] = total.idx; + topk_tmp_val_buf[tmp_buf_offset + ite] = total.val; + if (total.idx >= 0) { + tmp_logits[total.idx] = static_cast(-65504.0f); // Safe min for half precision + } + } + __syncthreads(); + } +} + +// Stage 2: Merge results and sample +template +__global__ void topk_stage2_sampling_kernel( + const int* __restrict__ topk_tmp_id_buf, // [num_seqs, BLOCKS_PER_SEQ * k] + scalar_t* __restrict__ topk_tmp_val_buf, // [num_seqs, BLOCKS_PER_SEQ * k] + int64_t* __restrict__ output_ids, // [num_seqs] + float* __restrict__ output_logprobs, // [num_seqs] optional + const int* __restrict__ topk_values, // [num_seqs] + const float* __restrict__ top_p_values, // [num_seqs] optional + curandState_t* __restrict__ curand_states, // [num_seqs] + const int num_seqs, + const int vocab_size, + const bool normalize_logprobs) { + + typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + extern __shared__ char shared_mem[]; + + const int tid = threadIdx.x; + const int seq_idx = blockIdx.x; + + if (seq_idx >= num_seqs) return; + + const int k = topk_values[seq_idx]; + if (k == 0) return; + + const float top_p = top_p_values ? top_p_values[seq_idx] : 1.0f; + const int stride = TOP_K_MAX * BLOCKS_PER_SEQ; + + // Shared memory arrays + int* s_id = reinterpret_cast(shared_mem); + float* s_val = reinterpret_cast(s_id + k); + + __shared__ float s_sum; + if (tid == 0) { + s_sum = 0.0f; + } + __syncthreads(); + + scalar_t* val_buf = topk_tmp_val_buf + seq_idx * stride; + + // Find top-k across all blocks + float max_logit = -FLT_MAX; + for (int ite = 0; ite < k; ite++) { + TopK_2 partial; + partial.init(); + + // Each thread searches in the merged buffer + for (int i = tid; i < k * BLOCKS_PER_SEQ; i += BLOCK_SIZE) { + partial.insert(static_cast(val_buf[i]), i); + } + + TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); + + if (tid == 0) { + if (ite == 0) { + max_logit = total.val; + } + s_id[ite] = total.idx; + val_buf[total.idx] = static_cast(-65504.0f); // Safe min for half precision + + // Convert to probability + total.val = expf(total.val - max_logit); + s_val[ite] = total.val; + s_sum += total.val; + } + __syncthreads(); + } + + // Sample from top-k + if (tid == 0) { + float rand_num = curand_states ? + curand_uniform(&curand_states[seq_idx]) * top_p * s_sum : + top_p * s_sum; + + int selected_idx = k - 1; // Default to last element + for (int i = 0; i < k; i++) { + rand_num -= s_val[i]; + if (rand_num <= 0.0f) { + selected_idx = i; + break; + } + } + + // Get actual token id + int buffer_idx = s_id[selected_idx]; + int token_id = buffer_idx >= 0 ? + topk_tmp_id_buf[seq_idx * stride + buffer_idx] % vocab_size : + vocab_size - 1; + + output_ids[seq_idx] = token_id; + + // Optional: output log probability + if (output_logprobs) { + float log_prob = logf(s_val[selected_idx]); + if (normalize_logprobs) { + log_prob -= logf(s_sum); + } + output_logprobs[seq_idx] = log_prob; + } + } +} + +} // namespace aphrodite + +void topk_topp_sampling( + torch::Tensor& logits, // [num_seqs, vocab_size] + torch::Tensor& output_ids, // [num_seqs] + const torch::Tensor& top_k_values, // [num_seqs] + const std::optional& top_p_values, // [num_seqs] optional + const std::optional& curand_states, // [num_seqs] optional + std::optional& output_logprobs, // [num_seqs] optional + bool normalize_logprobs = false) { + + TORCH_CHECK(logits.is_contiguous()); + TORCH_CHECK(output_ids.is_contiguous()); + TORCH_CHECK(top_k_values.is_contiguous()); + + int num_seqs = logits.size(0); + int vocab_size = logits.size(1); + + if (num_seqs == 0) return; + + // Allocate temporary buffers + constexpr int BLOCKS_PER_SEQ = 8; + auto options = torch::TensorOptions() + .dtype(logits.dtype()) + .device(logits.device()); + + torch::Tensor tmp_logits = torch::empty_like(logits); + torch::Tensor topk_tmp_id_buf = torch::empty({num_seqs, BLOCKS_PER_SEQ * aphrodite::TOP_K_MAX}, + torch::kInt32).to(logits.device()); + torch::Tensor topk_tmp_val_buf = torch::empty({num_seqs, BLOCKS_PER_SEQ * aphrodite::TOP_K_MAX}, + options); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(logits)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Determine grid and block sizes + constexpr int BLOCK_SIZE_STAGE1 = 256; + constexpr int BLOCK_SIZE_STAGE2 = 128; + + dim3 grid1(num_seqs * BLOCKS_PER_SEQ); + dim3 block1(BLOCK_SIZE_STAGE1); + + dim3 grid2(num_seqs); + dim3 block2(BLOCK_SIZE_STAGE2); + + // Calculate shared memory size for stage 2 + // Assuming max k = TOP_K_MAX + size_t shared_mem_size = aphrodite::TOP_K_MAX * sizeof(int) + + aphrodite::TOP_K_MAX * sizeof(float); + + APHRODITE_DISPATCH_FLOATING_TYPES( + logits.scalar_type(), "topk_sampling", [&] { + // Stage 1: Find top-k per block + aphrodite::topk_stage1_kernel + <<>>( + logits.data_ptr(), + tmp_logits.data_ptr(), + topk_tmp_id_buf.data_ptr(), + topk_tmp_val_buf.data_ptr(), + top_k_values.data_ptr(), + num_seqs, + vocab_size); + + // Stage 2: Merge and sample + aphrodite::topk_stage2_sampling_kernel + <<>>( + topk_tmp_id_buf.data_ptr(), + topk_tmp_val_buf.data_ptr(), + output_ids.data_ptr(), + output_logprobs.has_value() ? + output_logprobs.value().data_ptr() : nullptr, + top_k_values.data_ptr(), + top_p_values.has_value() ? + top_p_values.value().data_ptr() : nullptr, + curand_states.has_value() ? + static_cast(curand_states.value().data_ptr()) : nullptr, + num_seqs, + vocab_size, + normalize_logprobs); + }); +} diff --git a/kernels/sampling/utils.cuh b/kernels/sampling/utils.cuh deleted file mode 100644 index eaddffa4e2..0000000000 --- a/kernels/sampling/utils.cuh +++ /dev/null @@ -1,273 +0,0 @@ -/* - * Copyright (c) 2024 by PygmalionAI team. - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef APHRODITE_UTILS_CUH_ -#define APHRODITE_UTILS_CUH_ -#include - -#include -#include -#include -#include -#include - -#define STR_HELPER(x) #x -#define STR(x) STR_HELPER(x) - -// macro to turn off fp16 qk reduction to reduce binary -#ifndef APHRODITE_ALWAYS_DISALLOW_FP16_QK_REDUCTION - #define APHRODITE_ALWAYS_DISALLOW_FP16_QK_REDUCTION 0 -#endif - -#ifndef NDEBUG - #define APHRODITE_CUDA_CALL(func, ...) \ - { \ - cudaError_t e = (func); \ - if (e != cudaSuccess) { \ - std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \ - << ") " << __FILE__ << ": line " << __LINE__ \ - << " at function " << STR(func) << std::endl; \ - return e; \ - } \ - } -#else - #define APHRODITE_CUDA_CALL(func, ...) \ - { \ - cudaError_t e = (func); \ - if (e != cudaSuccess) { \ - return e; \ - } \ - } -#endif - -#define DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction, \ - ALLOW_FP16_QK_REDUCTION, ...) \ - if (allow_fp16_qk_reduction) { \ - throw std::runtime_error("FP16_QK_REDUCTION disabled at compile time"); \ - } else { \ - constexpr bool ALLOW_FP16_QK_REDUCTION = false; \ - __VA_ARGS__ \ - } - -#define DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, ...) \ - if (num_frags_x == 1) { \ - constexpr size_t NUM_FRAGS_X = 1; \ - __VA_ARGS__ \ - } else if (num_frags_x == 2) { \ - constexpr size_t NUM_FRAGS_X = 2; \ - __VA_ARGS__ \ - } else { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported num_frags_x: " << num_frags_x; \ - throw std::invalid_argument(err_msg.str()); \ - } - -#define DISPATCH_NUM_FRAGS_Z(max_frags_z, NUM_FRAGS_Z, ...) \ - if (max_frags_z >= 8) { \ - constexpr size_t NUM_FRAGS_Z = 8; \ - __VA_ARGS__ \ - } else if (max_frags_z >= 4) { \ - constexpr size_t NUM_FRAGS_Z = 4; \ - __VA_ARGS__ \ - } else if (max_frags_z >= 2) { \ - constexpr size_t NUM_FRAGS_Z = 2; \ - __VA_ARGS__ \ - } else if (max_frags_z >= 1) { \ - constexpr size_t NUM_FRAGS_Z = 1; \ - __VA_ARGS__ \ - } else { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported max_frags_z: " << max_frags_z; \ - throw std::invalid_argument(err_msg.str()); \ - } - -#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ - if (group_size == 1) { \ - constexpr size_t GROUP_SIZE = 1; \ - __VA_ARGS__ \ - } else if (group_size == 2) { \ - constexpr size_t GROUP_SIZE = 2; \ - __VA_ARGS__ \ - } else if (group_size == 4) { \ - constexpr size_t GROUP_SIZE = 4; \ - __VA_ARGS__ \ - } else if (group_size == 8) { \ - constexpr size_t GROUP_SIZE = 8; \ - __VA_ARGS__ \ - } else { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported group_size: " << group_size; \ - throw std::invalid_argument(err_msg.str()); \ - } - -#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \ - switch (mask_mode) { \ - case MaskMode::kNone: { \ - constexpr MaskMode MASK_MODE = MaskMode::kNone; \ - __VA_ARGS__ \ - break; \ - } \ - case MaskMode::kCausal: { \ - constexpr MaskMode MASK_MODE = MaskMode::kCausal; \ - __VA_ARGS__ \ - break; \ - } \ - case MaskMode::kCustom: { \ - constexpr MaskMode MASK_MODE = MaskMode::kCustom; \ - __VA_ARGS__ \ - break; \ - } \ - default: { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported mask_mode: " << int(mask_mode); \ - throw std::invalid_argument(err_msg.str()); \ - } \ - } - -#define DISPATCH_LOGITS_POST_HOOK(logits_soft_cap, LOGITS_POST_HOOK, ...) \ - if (logits_soft_cap > 0.f) { \ - constexpr LogitsPostHook LOGITS_POST_HOOK = LogitsPostHook::kSoftCap; \ - __VA_ARGS__ \ - } else if (logits_soft_cap == 0.f) { \ - constexpr LogitsPostHook LOGITS_POST_HOOK = LogitsPostHook::kNone; \ - __VA_ARGS__ \ - } else { \ - std::ostringstream err_msg; \ - err_msg << "Invalid logits_soft_cap (should be >= 0): " \ - << logits_soft_cap; \ - throw std::invalid_argument(err_msg.str()); \ - } - -#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ - switch (head_dim) { \ - case 64: { \ - constexpr size_t HEAD_DIM = 64; \ - __VA_ARGS__ \ - break; \ - } \ - case 128: { \ - constexpr size_t HEAD_DIM = 128; \ - __VA_ARGS__ \ - break; \ - } \ - case 256: { \ - constexpr size_t HEAD_DIM = 256; \ - __VA_ARGS__ \ - break; \ - } \ - default: { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported head_dim: " << head_dim; \ - throw std::invalid_argument(err_msg.str()); \ - } \ - } - -#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) \ - switch (pos_encoding_mode) { \ - case PosEncodingMode::kNone: { \ - constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone; \ - __VA_ARGS__ \ - break; \ - } \ - case PosEncodingMode::kRoPELlama: { \ - constexpr PosEncodingMode POS_ENCODING_MODE = \ - PosEncodingMode::kRoPELlama; \ - __VA_ARGS__ \ - break; \ - } \ - case PosEncodingMode::kALiBi: { \ - constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi; \ - __VA_ARGS__ \ - break; \ - } \ - default: { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \ - throw std::invalid_argument(err_msg.str()); \ - } \ - } - -#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \ - switch (aligned_vec_size) { \ - case 16: { \ - constexpr size_t ALIGNED_VEC_SIZE = 16; \ - __VA_ARGS__ \ - break; \ - } \ - case 8: { \ - constexpr size_t ALIGNED_VEC_SIZE = 8; \ - __VA_ARGS__ \ - break; \ - } \ - case 4: { \ - constexpr size_t ALIGNED_VEC_SIZE = 4; \ - __VA_ARGS__ \ - break; \ - } \ - case 2: { \ - constexpr size_t ALIGNED_VEC_SIZE = 2; \ - __VA_ARGS__ \ - break; \ - } \ - case 1: { \ - constexpr size_t ALIGNED_VEC_SIZE = 1; \ - __VA_ARGS__ \ - break; \ - } \ - default: { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ - throw std::invalid_argument(err_msg.str()); \ - } \ - } - -namespace aphrodite { - -template -__forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) { - return (x + y - 1) / y; -} - -template -inline void DebugPrintCUDAArray(T* device_ptr, size_t size, - std::string prefix = "") { - std::vector host_array(size); - std::cout << prefix; - cudaMemcpy(host_array.data(), device_ptr, size * sizeof(T), - cudaMemcpyDeviceToHost); - for (size_t i = 0; i < size; ++i) { - std::cout << host_array[i] << " "; - } - std::cout << std::endl; -} - -/*! - * \brief Return x - y if x > y, otherwise return 0. - */ -__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, - uint32_t y) { - return (x > y) ? x - y : 0U; -} - -__device__ __forceinline__ void swap(uint32_t& a, uint32_t& b) { - uint32_t tmp = a; - a = b; - b = tmp; -} - -} // namespace aphrodite - -#endif // APHRODITE_UTILS_CUH_ \ No newline at end of file diff --git a/kernels/sampling/vec_dtypes.cuh b/kernels/sampling/vec_dtypes.cuh deleted file mode 100644 index acff9dfe03..0000000000 --- a/kernels/sampling/vec_dtypes.cuh +++ /dev/null @@ -1,1501 +0,0 @@ -/* - * Copyright (c) 2024 by PygmalionAI team. - * Copyright (c) 2023 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef VEC_DTYPES_CUH_ -#define VEC_DTYPES_CUH_ - -#include -#include -#include -#include - -#include - -namespace aphrodite { - -#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900)) - #define APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED -#endif - -#define APHRODITE_INLINE inline __attribute__((always_inline)) __device__ - -/******************* vec_t type cast *******************/ - -template -struct vec_cast { - template - APHRODITE_INLINE static void cast(dst_t* dst, const src_t* src) { -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = (dst_t)src[i]; - } - } -}; - -template <> -struct vec_cast { - template - APHRODITE_INLINE static void cast(float* dst, const half* src) { - if constexpr (vec_size == 1) { - dst[0] = (float)src[0]; - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); - } - } - } -}; - -template <> -struct vec_cast { - template - APHRODITE_INLINE static void cast(half* dst, const float* src) { - if constexpr (vec_size == 1) { - dst[0] = __float2half(src[0]); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); - } - } - } -}; - -template -constexpr APHRODITE_INLINE int get_exponent_bits() { - if constexpr (std::is_same::value) { - return 4; - } else if constexpr (std::is_same::value) { - return 5; - } else if constexpr (std::is_same::value) { - return 5; - } else if constexpr (std::is_same::value) { - return 8; - } -} - -template -constexpr APHRODITE_INLINE int get_mantissa_bits() { - if constexpr (std::is_same::value) { - return 3; - } else if constexpr (std::is_same::value) { - return 2; - } else if constexpr (std::is_same::value) { - return 11; - } else if constexpr (std::is_same::value) { - return 7; - } -} - -/*! - * \brief Fallback to software fast dequant implementation if hardware - * dequantization is not available. \note Inspired by Marlin's fast - * dequantization, but here we don't have to permute weights order. \ref - * https://github.com/vllm-project/vllm/blob/6dffa4b0a6120159ef2fe44d695a46817aff65bc/csrc/quantization/fp8/fp8_marlin.cu#L120 - */ -template -__device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) { - uint32_t q = *input; - if constexpr (std::is_same::value && - std::is_same::value) { - output->x = __byte_perm(0U, q, 0x5140); - output->y = __byte_perm(0U, q, 0x7362); - } else { - constexpr int FP8_EXPONENT = get_exponent_bits(); - constexpr int FP8_MANTISSA = get_mantissa_bits(); - constexpr int FP16_EXPONENT = get_exponent_bits(); - - constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; - // Calculate MASK for extracting mantissa and exponent - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - q = __byte_perm(q, q, 0x1302); - - // Extract and shift FP8 values to FP16 format - uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - uint32_t Out2 = - ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); - - constexpr int BIAS_OFFSET = - (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); - // Construct and apply exponent bias - if constexpr (std::is_same::value) { - const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); - - // Convert to half2 and apply bias - *(half2*)&(output->x) = - __hmul2(*reinterpret_cast(&Out1), bias_reg); - *(half2*)&(output->y) = - __hmul2(*reinterpret_cast(&Out2), bias_reg); - } else { - constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; - const nv_bfloat162 bias_reg = - __float2bfloat162_rn(*reinterpret_cast(&BIAS)); - // Convert to bfloat162 and apply bias - *(nv_bfloat162*)&(output->x) = - __hmul2(*reinterpret_cast(&Out1), bias_reg); - *(nv_bfloat162*)&(output->y) = - __hmul2(*reinterpret_cast(&Out2), bias_reg); - } - } -} - -template <> -struct vec_cast { - template - APHRODITE_INLINE static void cast(nv_bfloat16* dst, - const __nv_fp8_e4m3* src) { - if constexpr (vec_size == 1) { - dst[0] = nv_bfloat16(src[0]); - } else if constexpr (vec_size == 2) { - dst[0] = nv_bfloat16(src[0]); - dst[1] = nv_bfloat16(src[1]); - } else { - static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); -#pragma unroll - for (uint32_t i = 0; i < vec_size / 4; ++i) { - fast_dequant_f8f16x4<__nv_fp8_e4m3, nv_bfloat16>((uint32_t*)&src[i * 4], - (uint2*)&dst[i * 4]); - } - } - } -}; - -template <> -struct vec_cast { - template - APHRODITE_INLINE static void cast(nv_bfloat16* dst, - const __nv_fp8_e5m2* src) { - if constexpr (vec_size == 1) { - dst[0] = nv_bfloat16(src[0]); - } else if constexpr (vec_size == 2) { - dst[0] = nv_bfloat16(src[0]); - dst[1] = nv_bfloat16(src[1]); - } else { - static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); -#pragma unroll - for (uint32_t i = 0; i < vec_size / 4; ++i) { - fast_dequant_f8f16x4<__nv_fp8_e5m2, nv_bfloat16>((uint32_t*)&src[i * 4], - (uint2*)&dst[i * 4]); - } - } - } -}; - -template <> -struct vec_cast<__nv_fp8_e4m3, half> { - template - APHRODITE_INLINE static void cast(__nv_fp8_e4m3* dst, const half* src) { -#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = __nv_fp8_e4m3(src[0]); - } else { - #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint16_t y; - uint32_t x = *(uint32_t*)&src[i * 2]; - asm volatile("cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;" - : "=h"(y) - : "r"(x)); - *(uint16_t*)&dst[i * 2] = y; - } - } -#else - #pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = __nv_fp8_e4m3(src[i]); - } -#endif // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED - } -}; - -template <> -struct vec_cast<__nv_fp8_e5m2, half> { - template - APHRODITE_INLINE static void cast(__nv_fp8_e5m2* dst, const half* src) { -#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = __nv_fp8_e5m2(src[0]); - } else { - #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint16_t y; - uint32_t x = *(uint32_t*)&src[i * 2]; - asm volatile("cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;" - : "=h"(y) - : "r"(x)); - *(uint16_t*)&dst[i * 2] = y; - } - } -#else - #pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = __nv_fp8_e5m2(src[i]); - } -#endif // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED - } -}; - -template <> -struct vec_cast { - template - APHRODITE_INLINE static void cast(half* dst, const __nv_fp8_e4m3* src) { -#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } else { - #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint32_t y; - uint16_t x = *(uint16_t*)&src[i * 2]; - asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;" : "=r"(y) : "h"(x)); - *(uint32_t*)&dst[i * 2] = y; - } - } -#else - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } else if constexpr (vec_size == 2) { - dst[0] = half(src[0]); - dst[1] = half(src[1]); - } else { - static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); - #pragma unroll - for (uint32_t i = 0; i < vec_size / 4; ++i) { - fast_dequant_f8f16x4<__nv_fp8_e4m3, half>((uint32_t*)&src[i * 4], - (uint2*)&dst[i * 4]); - } - } -#endif // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED - } -}; - -template <> -struct vec_cast { - template - APHRODITE_INLINE static void cast(half* dst, const __nv_fp8_e5m2* src) { -#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } else { - #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint32_t y; - uint16_t x = *(uint16_t*)&src[i * 2]; - asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;" : "=r"(y) : "h"(x)); - *(uint32_t*)&dst[i * 2] = y; - } - } -#else - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } else if constexpr (vec_size == 2) { - dst[0] = half(src[0]); - dst[1] = half(src[1]); - } else { - static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); - #pragma unroll - for (uint32_t i = 0; i < vec_size / 4; ++i) { - fast_dequant_f8f16x4<__nv_fp8_e5m2, half>((uint32_t*)&src[i * 4], - (uint2*)&dst[i * 4]); - } - } -#endif // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED - } -}; - -template <> -struct vec_cast { - template - APHRODITE_INLINE static void cast(float* dst, const nv_bfloat16* src) { - if constexpr (vec_size == 1) { - dst[0] = (float)src[0]; - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]); - } - } - } -}; - -template <> -struct vec_cast { - template - APHRODITE_INLINE static void cast(nv_bfloat16* dst, const float* src) { - if constexpr (vec_size == 1) { - dst[0] = nv_bfloat16(src[0]); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); - } - } - } -}; - -template -struct vec_t { - APHRODITE_INLINE float_t& operator[](size_t i); - APHRODITE_INLINE const float_t& operator[](size_t i) const; - APHRODITE_INLINE void fill(float_t val); - APHRODITE_INLINE void load(const float_t* ptr); - APHRODITE_INLINE void store(float_t* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src); - template - APHRODITE_INLINE void cast_load(const T* ptr); - template - APHRODITE_INLINE void cast_store(T* ptr) const; - APHRODITE_INLINE static void memcpy(float_t* dst, const float_t* src); - APHRODITE_INLINE float_t* ptr(); -}; - -template -APHRODITE_INLINE void cast_from_impl(vec_t& dst, - const vec_t& src) { - vec_cast::cast( - dst.ptr(), const_cast*>(&src)->ptr()); -} - -template -APHRODITE_INLINE void cast_load_impl(vec_t& dst, - const src_float_t* src_ptr) { - if constexpr (std::is_same::value) { - dst.load(src_ptr); - } else { - vec_t tmp; - tmp.load(src_ptr); - dst.cast_from(tmp); - } -} - -template -APHRODITE_INLINE void cast_store_impl(tgt_float_t* dst_ptr, - const vec_t& src) { - if constexpr (std::is_same::value) { - src.store(dst_ptr); - } else { - vec_t tmp; - tmp.cast_from(src); - tmp.store(dst_ptr); - } -} - -/******************* vec_t<__nv_fp8_e4m3> *******************/ - -// __nv_fp8_e4m3 x 1 -template <> -struct vec_t<__nv_fp8_e4m3, 1> { - __nv_fp8_e4m3 data; - - APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) { - return ((__nv_fp8_e4m3*)(&data))[i]; - } - APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const { - return ((const __nv_fp8_e4m3*)(&data))[i]; - } - APHRODITE_INLINE __nv_fp8_e4m3* ptr() { - return reinterpret_cast<__nv_fp8_e4m3*>(&data); - } - APHRODITE_INLINE void fill(__nv_fp8_e4m3 val); - APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr); - APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst, - const __nv_fp8_e4m3* src); -}; - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) { - data = val; -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3* ptr) { - data = *ptr; -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::store(__nv_fp8_e4m3* ptr) const { - *ptr = data; -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy( - __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) { - *dst = *src; -} - -// __nv_fp8_e4m3 x 2 -template <> -struct vec_t<__nv_fp8_e4m3, 2> { - __nv_fp8x2_e4m3 data; - - APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) { - return ((__nv_fp8_e4m3*)(&data))[i]; - } - APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const { - return ((const __nv_fp8_e4m3*)(&data))[i]; - } - APHRODITE_INLINE __nv_fp8_e4m3* ptr() { - return reinterpret_cast<__nv_fp8_e4m3*>(&data); - } - APHRODITE_INLINE void fill(__nv_fp8_e4m3 val); - APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr); - APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst, - const __nv_fp8_e4m3* src); -}; - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) { - data.__x = - (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3* ptr) { - data = *((__nv_fp8x2_e4m3*)ptr); -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::store(__nv_fp8_e4m3* ptr) const { - *((__nv_fp8x2_e4m3*)ptr) = data; -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy( - __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) { - *((__nv_fp8x2_e4m3*)dst) = *((__nv_fp8x2_e4m3*)src); -} - -// __nv_fp8_e4m3 x 4 - -template <> -struct vec_t<__nv_fp8_e4m3, 4> { - __nv_fp8x4_e4m3 data; - - APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) { - return ((__nv_fp8_e4m3*)(&data))[i]; - } - APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const { - return ((const __nv_fp8_e4m3*)(&data))[i]; - } - APHRODITE_INLINE __nv_fp8_e4m3* ptr() { - return reinterpret_cast<__nv_fp8_e4m3*>(&data); - } - APHRODITE_INLINE void fill(__nv_fp8_e4m3 val); - APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr); - APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst, - const __nv_fp8_e4m3* src); -}; - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) { - data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3* ptr) { - data = *((__nv_fp8x4_e4m3*)ptr); -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::store(__nv_fp8_e4m3* ptr) const { - *((__nv_fp8x4_e4m3*)ptr) = data; -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy( - __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) { - *((__nv_fp8x4_e4m3*)dst) = *((__nv_fp8x4_e4m3*)src); -} - -// __nv_fp8_e4m3 x 8 - -template <> -struct vec_t<__nv_fp8_e4m3, 8> { - uint2 data; - - APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) { - return ((__nv_fp8_e4m3*)(&data))[i]; - } - APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const { - return ((const __nv_fp8_e4m3*)(&data))[i]; - } - APHRODITE_INLINE __nv_fp8_e4m3* ptr() { - return reinterpret_cast<__nv_fp8_e4m3*>(&data); - } - APHRODITE_INLINE void fill(__nv_fp8_e4m3 val); - APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr); - APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst, - const __nv_fp8_e4m3* src); -}; - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) { - ((__nv_fp8x4_e4m3*)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3*)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3* ptr) { - data = *((uint2*)ptr); -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::store(__nv_fp8_e4m3* ptr) const { - *((uint2*)ptr) = data; -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy( - __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) { - *((uint2*)dst) = *((uint2*)src); -} - -// __nv_fp8_e4m3 x 16 or more -template -struct vec_t<__nv_fp8_e4m3, vec_size> { - uint4 data[vec_size / 16]; - - APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) { - return ((__nv_fp8_e4m3*)data)[i]; - } - APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const { - return ((const __nv_fp8_e4m3*)data)[i]; - } - APHRODITE_INLINE __nv_fp8_e4m3* ptr() { - return reinterpret_cast<__nv_fp8_e4m3*>(&data); - } - APHRODITE_INLINE void fill(__nv_fp8_e4m3 val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__nv_fp8x4_e4m3*)(&(data[i].x)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3*)(&(data[i].y)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3*)(&(data[i].z)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3*)(&(data[i].w)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - } - } - APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4*)ptr)[i]; - } - } - APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4*)ptr)[i] = data[i]; - } - } - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst, - const __nv_fp8_e4m3* src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4*)dst)[i] = ((uint4*)src)[i]; - } - } -}; - -/******************* vec_t<__nv_fp8_e5m2> *******************/ - -// __nv_fp8_e5m2 x 1 -template <> -struct vec_t<__nv_fp8_e5m2, 1> { - __nv_fp8_e5m2 data; - - APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) { - return ((__nv_fp8_e5m2*)(&data))[i]; - } - APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const { - return ((const __nv_fp8_e5m2*)(&data))[i]; - } - APHRODITE_INLINE __nv_fp8_e5m2* ptr() { - return reinterpret_cast<__nv_fp8_e5m2*>(&data); - } - APHRODITE_INLINE void fill(__nv_fp8_e5m2 val); - APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr); - APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst, - const __nv_fp8_e5m2* src); -}; - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) { - data = val; -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2* ptr) { - data = *ptr; -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::store(__nv_fp8_e5m2* ptr) const { - *ptr = data; -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy( - __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) { - *dst = *src; -} - -// __nv_fp8_e5m2 x 2 -template <> -struct vec_t<__nv_fp8_e5m2, 2> { - __nv_fp8x2_e5m2 data; - - APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) { - return ((__nv_fp8_e5m2*)(&data))[i]; - } - APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const { - return ((const __nv_fp8_e5m2*)(&data))[i]; - } - APHRODITE_INLINE __nv_fp8_e5m2* ptr() { - return reinterpret_cast<__nv_fp8_e5m2*>(&data); - } - APHRODITE_INLINE void fill(__nv_fp8_e5m2 val); - APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr); - APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst, - const __nv_fp8_e5m2* src); -}; - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) { - data.__x = - (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2* ptr) { - data = *((__nv_fp8x2_e5m2*)ptr); -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::store(__nv_fp8_e5m2* ptr) const { - *((__nv_fp8x2_e5m2*)ptr) = data; -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy( - __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) { - *((__nv_fp8x2_e5m2*)dst) = *((__nv_fp8x2_e5m2*)src); -} - -// __nv_fp8_e5m2 x 4 - -template <> -struct vec_t<__nv_fp8_e5m2, 4> { - __nv_fp8x4_e5m2 data; - - APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) { - return ((__nv_fp8_e5m2*)(&data))[i]; - } - APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const { - return ((const __nv_fp8_e5m2*)(&data))[i]; - } - APHRODITE_INLINE __nv_fp8_e5m2* ptr() { - return reinterpret_cast<__nv_fp8_e5m2*>(&data); - } - APHRODITE_INLINE void fill(__nv_fp8_e5m2 val); - APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr); - APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst, - const __nv_fp8_e5m2* src); -}; - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) { - data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2* ptr) { - data = *((__nv_fp8x4_e5m2*)ptr); -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::store(__nv_fp8_e5m2* ptr) const { - *((__nv_fp8x4_e5m2*)ptr) = data; -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy( - __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) { - *((__nv_fp8x4_e5m2*)dst) = *((__nv_fp8x4_e5m2*)src); -} - -// __nv_fp8_e5m2 x 8 - -template <> -struct vec_t<__nv_fp8_e5m2, 8> { - uint2 data; - - APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) { - return ((__nv_fp8_e5m2*)(&data))[i]; - } - APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const { - return ((const __nv_fp8_e5m2*)(&data))[i]; - } - APHRODITE_INLINE __nv_fp8_e5m2* ptr() { - return reinterpret_cast<__nv_fp8_e5m2*>(&data); - } - APHRODITE_INLINE void fill(__nv_fp8_e5m2 val); - APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr); - APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst, - const __nv_fp8_e5m2* src); -}; - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) { - ((__nv_fp8x4_e5m2*)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2*)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2* ptr) { - data = *((uint2*)ptr); -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::store(__nv_fp8_e5m2* ptr) const { - *((uint2*)ptr) = data; -} - -APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy( - __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) { - *((uint2*)dst) = *((uint2*)src); -} - -// __nv_fp8_e5m2 x 16 or more - -template -struct vec_t<__nv_fp8_e5m2, vec_size> { - uint4 data[vec_size / 16]; - - APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) { - return ((__nv_fp8_e5m2*)data)[i]; - } - APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const { - return ((const __nv_fp8_e5m2*)data)[i]; - } - APHRODITE_INLINE __nv_fp8_e5m2* ptr() { - return reinterpret_cast<__nv_fp8_e5m2*>(&data); - } - APHRODITE_INLINE void fill(__nv_fp8_e5m2 val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__nv_fp8x4_e5m2*)(&(data[i].x)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2*)(&(data[i].y)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2*)(&(data[i].z)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2*)(&(data[i].w)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - } - } - APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4*)ptr)[i]; - } - } - APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4*)ptr)[i] = data[i]; - } - } - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst, - const __nv_fp8_e5m2* src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4*)dst)[i] = ((uint4*)src)[i]; - } - } -}; - -/******************* vec_t *******************/ - -// half x 1 -template <> -struct vec_t { - half data; - - APHRODITE_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } - APHRODITE_INLINE const half& operator[](size_t i) const { - return ((const half*)(&data))[i]; - } - APHRODITE_INLINE half* ptr() { return reinterpret_cast(&data); } - APHRODITE_INLINE void fill(half val); - APHRODITE_INLINE void load(const half* ptr); - APHRODITE_INLINE void store(half* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - APHRODITE_INLINE static void memcpy(half* dst, const half* src); -}; - -APHRODITE_INLINE void vec_t::fill(half val) { data = val; } - -APHRODITE_INLINE void vec_t::load(const half* ptr) { data = *ptr; } - -APHRODITE_INLINE void vec_t::store(half* ptr) const { *ptr = data; } - -APHRODITE_INLINE void vec_t::memcpy(half* dst, const half* src) { - *dst = *src; -} - -// half x 2 -template <> -struct vec_t { - half2 data; - - APHRODITE_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } - APHRODITE_INLINE const half& operator[](size_t i) const { - return ((const half*)(&data))[i]; - } - APHRODITE_INLINE half* ptr() { return reinterpret_cast(&data); } - APHRODITE_INLINE void fill(half val); - APHRODITE_INLINE void load(const half* ptr); - APHRODITE_INLINE void store(half* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - APHRODITE_INLINE static void memcpy(half* dst, const half* src); -}; - -APHRODITE_INLINE void vec_t::fill(half val) { - data = make_half2(val, val); -} - -APHRODITE_INLINE void vec_t::load(const half* ptr) { - data = *((half2*)ptr); -} - -APHRODITE_INLINE void vec_t::store(half* ptr) const { - *((half2*)ptr) = data; -} - -APHRODITE_INLINE void vec_t::memcpy(half* dst, const half* src) { - *((half2*)dst) = *((half2*)src); -} - -// half x 4 - -template <> -struct vec_t { - uint2 data; - - APHRODITE_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } - APHRODITE_INLINE const half& operator[](size_t i) const { - return ((const half*)(&data))[i]; - } - APHRODITE_INLINE half* ptr() { return reinterpret_cast(&data); } - APHRODITE_INLINE void fill(half val); - APHRODITE_INLINE void load(const half* ptr); - APHRODITE_INLINE void store(half* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - APHRODITE_INLINE static void memcpy(half* dst, const half* src); -}; - -APHRODITE_INLINE void vec_t::fill(half val) { - *(half2*)(&data.x) = make_half2(val, val); - *(half2*)(&data.y) = make_half2(val, val); -} - -APHRODITE_INLINE void vec_t::load(const half* ptr) { - data = *((uint2*)ptr); -} - -APHRODITE_INLINE void vec_t::store(half* ptr) const { - *((uint2*)ptr) = data; -} - -APHRODITE_INLINE void vec_t::memcpy(half* dst, const half* src) { - *((uint2*)dst) = *((uint2*)src); -} - -// half x 8 or more - -template -struct vec_t { - uint4 data[vec_size / 8]; - APHRODITE_INLINE half& operator[](size_t i) { return ((half*)data)[i]; } - APHRODITE_INLINE const half& operator[](size_t i) const { - return ((const half*)data)[i]; - } - APHRODITE_INLINE half* ptr() { return reinterpret_cast(&data); } - APHRODITE_INLINE void fill(half val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - *(half2*)(&(data[i].x)) = make_half2(val, val); - *(half2*)(&(data[i].y)) = make_half2(val, val); - *(half2*)(&(data[i].z)) = make_half2(val, val); - *(half2*)(&(data[i].w)) = make_half2(val, val); - } - } - APHRODITE_INLINE void load(const half* ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - data[i] = ((uint4*)ptr)[i]; - } - } - APHRODITE_INLINE void store(half* ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4*)ptr)[i] = data[i]; - } - } - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - APHRODITE_INLINE static void memcpy(half* dst, const half* src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4*)dst)[i] = ((uint4*)src)[i]; - } - } -}; - -/******************* vec_t *******************/ - -// nv_bfloat16 x 1 -template <> -struct vec_t { - nv_bfloat16 data; - APHRODITE_INLINE nv_bfloat16& operator[](size_t i) { - return ((nv_bfloat16*)(&data))[i]; - } - APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const { - return ((const nv_bfloat16*)(&data))[i]; - } - APHRODITE_INLINE nv_bfloat16* ptr() { - return reinterpret_cast(&data); - } - APHRODITE_INLINE void fill(nv_bfloat16 val); - APHRODITE_INLINE void load(const nv_bfloat16* ptr); - APHRODITE_INLINE void store(nv_bfloat16* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - APHRODITE_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src); -}; - -APHRODITE_INLINE void vec_t::fill(nv_bfloat16 val) { - data = val; -} - -APHRODITE_INLINE void vec_t::load(const nv_bfloat16* ptr) { - data = *ptr; -} - -APHRODITE_INLINE void vec_t::store(nv_bfloat16* ptr) const { - *ptr = data; -} - -APHRODITE_INLINE void vec_t::memcpy(nv_bfloat16* dst, - const nv_bfloat16* src) { - *dst = *src; -} - -// nv_bfloat16 x 2 -template <> -struct vec_t { - nv_bfloat162 data; - - APHRODITE_INLINE nv_bfloat16& operator[](size_t i) { - return ((nv_bfloat16*)(&data))[i]; - } - APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const { - return ((const nv_bfloat16*)(&data))[i]; - } - APHRODITE_INLINE nv_bfloat16* ptr() { - return reinterpret_cast(&data); - } - APHRODITE_INLINE void fill(nv_bfloat16 val); - APHRODITE_INLINE void load(const nv_bfloat16* ptr); - APHRODITE_INLINE void store(nv_bfloat16* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - APHRODITE_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src); -}; - -APHRODITE_INLINE void vec_t::fill(nv_bfloat16 val) { - data = make_bfloat162(val, val); -} - -APHRODITE_INLINE void vec_t::load(const nv_bfloat16* ptr) { - data = *((nv_bfloat162*)ptr); -} - -APHRODITE_INLINE void vec_t::store(nv_bfloat16* ptr) const { - *((nv_bfloat162*)ptr) = data; -} - -APHRODITE_INLINE void vec_t::memcpy(nv_bfloat16* dst, - const nv_bfloat16* src) { - *((nv_bfloat162*)dst) = *((nv_bfloat162*)src); -} - -// nv_bfloat16 x 4 - -template <> -struct vec_t { - uint2 data; - - APHRODITE_INLINE nv_bfloat16& operator[](size_t i) { - return ((nv_bfloat16*)(&data))[i]; - } - APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const { - return ((const nv_bfloat16*)(&data))[i]; - } - APHRODITE_INLINE nv_bfloat16* ptr() { - return reinterpret_cast(&data); - } - APHRODITE_INLINE void fill(nv_bfloat16 val); - APHRODITE_INLINE void load(const nv_bfloat16* ptr); - APHRODITE_INLINE void store(nv_bfloat16* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - APHRODITE_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src); -}; - -APHRODITE_INLINE void vec_t::fill(nv_bfloat16 val) { - *(nv_bfloat162*)(&data.x) = make_bfloat162(val, val); - *(nv_bfloat162*)(&data.y) = make_bfloat162(val, val); -} - -APHRODITE_INLINE void vec_t::load(const nv_bfloat16* ptr) { - data = *((uint2*)ptr); -} - -APHRODITE_INLINE void vec_t::store(nv_bfloat16* ptr) const { - *((uint2*)ptr) = data; -} - -APHRODITE_INLINE void vec_t::memcpy(nv_bfloat16* dst, - const nv_bfloat16* src) { - *((uint2*)dst) = *((uint2*)src); -} - -// nv_bfloat16 x 8 or more - -template -struct vec_t { - uint4 data[vec_size / 8]; - - APHRODITE_INLINE nv_bfloat16& operator[](size_t i) { - return ((nv_bfloat16*)data)[i]; - } - APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const { - return ((const nv_bfloat16*)data)[i]; - } - APHRODITE_INLINE nv_bfloat16* ptr() { - return reinterpret_cast(&data); - } - APHRODITE_INLINE void fill(nv_bfloat16 val) { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - *(nv_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val); - *(nv_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val); - *(nv_bfloat162*)(&(data[i].z)) = make_bfloat162(val, val); - *(nv_bfloat162*)(&(data[i].w)) = make_bfloat162(val, val); - } - } - APHRODITE_INLINE void load(const nv_bfloat16* ptr) { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - data[i] = ((uint4*)ptr)[i]; - } - } - APHRODITE_INLINE void store(nv_bfloat16* ptr) const { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4*)ptr)[i] = data[i]; - } - } - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - APHRODITE_INLINE static void memcpy(nv_bfloat16* dst, - const nv_bfloat16* src) { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4*)dst)[i] = ((uint4*)src)[i]; - } - } -}; - -/******************* vec_t *******************/ - -// float x 1 - -template <> -struct vec_t { - float data; - - APHRODITE_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } - APHRODITE_INLINE const float& operator[](size_t i) const { - return ((const float*)(&data))[i]; - } - APHRODITE_INLINE float* ptr() { return reinterpret_cast(&data); } - APHRODITE_INLINE void fill(float val); - APHRODITE_INLINE void load(const float* ptr); - APHRODITE_INLINE void store(float* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - APHRODITE_INLINE static void memcpy(float* dst, const float* src); -}; - -APHRODITE_INLINE void vec_t::fill(float val) { data = val; } - -APHRODITE_INLINE void vec_t::load(const float* ptr) { data = *ptr; } - -APHRODITE_INLINE void vec_t::store(float* ptr) const { *ptr = data; } - -APHRODITE_INLINE void vec_t::memcpy(float* dst, const float* src) { - *dst = *src; -} - -// float x 2 - -template <> -struct vec_t { - float2 data; - - APHRODITE_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } - APHRODITE_INLINE const float& operator[](size_t i) const { - return ((const float*)(&data))[i]; - } - APHRODITE_INLINE float* ptr() { return reinterpret_cast(&data); } - APHRODITE_INLINE void fill(float val); - APHRODITE_INLINE void load(const float* ptr); - APHRODITE_INLINE void store(float* ptr) const; - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - APHRODITE_INLINE static void memcpy(float* dst, const float* src); -}; - -APHRODITE_INLINE void vec_t::fill(float val) { - data = make_float2(val, val); -} - -APHRODITE_INLINE void vec_t::load(const float* ptr) { - data = *((float2*)ptr); -} - -APHRODITE_INLINE void vec_t::store(float* ptr) const { - *((float2*)ptr) = data; -} - -APHRODITE_INLINE void vec_t::memcpy(float* dst, const float* src) { - *((float2*)dst) = *((float2*)src); -} - -// float x 4 or more -template -struct vec_t { - float4 data[vec_size / 4]; - - APHRODITE_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; } - APHRODITE_INLINE const float& operator[](size_t i) const { - return ((const float*)(data))[i]; - } - APHRODITE_INLINE float* ptr() { return reinterpret_cast(&data); } - APHRODITE_INLINE void fill(float val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - data[i] = make_float4(val, val, val, val); - } - } - APHRODITE_INLINE void load(const float* ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - data[i] = ((float4*)ptr)[i]; - } - } - APHRODITE_INLINE void store(float* ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4*)ptr)[i] = data[i]; - } - } - template - APHRODITE_INLINE void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - APHRODITE_INLINE void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - APHRODITE_INLINE void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - APHRODITE_INLINE static void memcpy(float* dst, const float* src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4*)dst)[i] = ((float4*)src)[i]; - } - } -}; - -} // namespace aphrodite - -#endif // VEC_DTYPES_CUH_ \ No newline at end of file diff --git a/kernels/torch_bindings.cpp b/kernels/torch_bindings.cpp index 803a0d7f4a..2c66a880ef 100644 --- a/kernels/torch_bindings.cpp +++ b/kernels/torch_bindings.cpp @@ -186,6 +186,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("apply_repetition_penalties_", torch::kCUDA, &apply_repetition_penalties_); + ops.def( + "topk_topp_sampling(Tensor! logits, Tensor! output_ids, Tensor top_k_values, " + "Tensor? top_p_values, Tensor? curand_states, Tensor!? output_logprobs, bool normalize_logprobs) -> ()"); + ops.impl("topk_topp_sampling", torch::kCUDA, &topk_topp_sampling); + // Layernorm-quant // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( @@ -702,65 +707,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor"); // conditionally compiled so impl in source file - // Sampling Kernels - ops.def( - "sampling_from_probs(Tensor probs, Tensor uniform_samples, bool " - "deterministic) -> Tensor", - &sampling_from_probs); - ops.impl("sampling_from_probs", torch::kCUDA, &sampling_from_probs); - - ops.def( - "top_k_sampling_from_probs(Tensor probs, Tensor uniform_samples," - " Tensor? maybe_top_k_arr, int top_k_val," - " bool deterministic) -> Tensor[]", - &top_k_sampling_from_probs); - ops.impl("top_k_sampling_from_probs", torch::kCUDA, - &top_k_sampling_from_probs); - - ops.def( - "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples," - " Tensor? maybe_min_p_arr, float min_p_val," - " bool deterministic) -> Tensor[]", - &min_p_sampling_from_probs); - ops.impl("min_p_sampling_from_probs", torch::kCUDA, - &min_p_sampling_from_probs); - - ops.def( - "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples," - " Tensor? maybe_top_p_arr, float top_p_val," - " bool deterministic) -> Tensor[]", - &top_p_sampling_from_probs); - ops.impl("top_p_sampling_from_probs", torch::kCUDA, - &top_p_sampling_from_probs); - - ops.def( - "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples," - " Tensor? maybe_top_k_arr, float top_k_val," - " Tensor? maybe_top_p_arr, float top_p_val," - " bool deterministic) -> Tensor[]", - &top_k_top_p_sampling_from_probs); - ops.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, - &top_k_top_p_sampling_from_probs); - - ops.def( - "top_k_renorm_prob(Tensor probs, Tensor? maybe_top_k_arr, int top_k_val) " - "-> Tensor", - &top_k_renorm_prob); - ops.impl("top_k_renorm_prob", torch::kCUDA, &top_k_renorm_prob); - - ops.def( - "top_p_renorm_prob(Tensor probs, Tensor? maybe_top_p_arr, float " - "top_p_val) " - "-> Tensor", - &top_p_renorm_prob); - ops.impl("top_p_renorm_prob", torch::kCUDA, &top_p_renorm_prob); - - ops.def( - "top_k_mask_logits(Tensor logits, Tensor? maybe_top_k_arr, int " - "top_k_val) -> Tensor", - &top_k_mask_logits); - ops.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits); - #endif }