Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ set(APHRODITE_EXT_SRC
"kernels/activation_kernels.cu"
"kernels/layernorm_kernels.cu"
"kernels/layernorm_quant_kernels.cu"
"kernels/sampler.cu"
"kernels/sampling/repetition_penalty.cu"
"kernels/sampling/topk_topp.cu"
"kernels/cuda_view.cu"
"kernels/quantization/squeezellm/quant_cuda_kernel.cu"
"kernels/quantization/gptq/q_gemm.cu"
Expand Down Expand Up @@ -300,7 +301,6 @@ set(APHRODITE_EXT_SRC
"kernels/quantization/awq/gemm_kernels.cu"
"kernels/quantization/quip/origin_order.cu"
"kernels/permute_cols.cu"
"kernels/sampling/sampling.cu"
"kernels/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"kernels/quantization/fp4/nvfp4_quant_entry.cu"
"kernels/quantization/fp4/nvfp4_scaled_mm_entry.cu"
Expand Down
150 changes: 24 additions & 126 deletions aphrodite/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,30 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
repetition_penalties)


def apply_top_k_top_p_cuda(
logits: torch.Tensor,
output_ids: torch.Tensor,
top_k_values: torch.Tensor,
top_p_values: Optional[torch.Tensor] = None,
curand_states: Optional[torch.Tensor] = None,
output_logprobs: Optional[torch.Tensor] = None,
normalize_logprobs: bool = False,
) -> None:
"""Apply top-k and top-p sampling using CUDA kernel.
Args:
logits: The logits tensor of shape [num_seqs, vocab_size].
output_ids: Output tensor for sampled token ids [num_seqs].
top_k_values: Top-k values per sequence [num_seqs].
top_p_values: Optional top-p values per sequence [num_seqs].
curand_states: Optional CUDA random states for sampling [num_seqs].
output_logprobs: Optional output for log probabilities [num_seqs].
normalize_logprobs: Whether to normalize log probabilities.
"""
torch.ops._C.topk_topp_sampling(
logits, output_ids, top_k_values, top_p_values, curand_states,
output_logprobs, normalize_logprobs)


def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
Expand Down Expand Up @@ -2020,129 +2044,3 @@ def int8_scaled_mm_with_quant_fake(
M = mat1.size(0)
N = mat2.size(0)
return torch.empty((M, N), dtype=out_dtype)


# Sampling Kernels
def sampling_from_probs(probs: torch.Tensor,
uniform_samplers: torch.Tensor,
deterministic: bool = True,
check_nan: bool = False) -> torch.Tensor:
if check_nan and torch.any(torch.isnan(probs)):
raise ValueError("NaN detected in probs")
return torch.ops._C.sampling_from_probs(probs, uniform_samplers,
deterministic)


def _to_tensor_scalar_tuple(x):
if isinstance(x, torch.Tensor):
return (x, 0)
else:
return (None, x)


def top_p_sampling_from_probs(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
top_p: Union[torch.Tensor, float],
deterministic: bool = True,
check_nan: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
if check_nan and torch.any(torch.isnan(probs)):
raise ValueError("NaN detected in probs")
return torch.ops._C.top_p_sampling_from_probs(
probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic)


def top_k_sampling_from_probs(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
top_k: Union[torch.Tensor, int],
deterministic: bool = True,
check_nan: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
if check_nan and torch.any(torch.isnan(probs)):
raise ValueError("NaN detected in probs")
return torch.ops._C.top_k_sampling_from_probs(
probs, uniform_samples, *_to_tensor_scalar_tuple(top_k), deterministic)


def min_p_sampling_from_probs(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
min_p: Union[torch.Tensor, float],
deterministic: bool = True,
check_nan: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
if check_nan and torch.any(torch.isnan(probs)):
raise ValueError("NaN detected in probs")
return torch.ops._C.min_p_sampling_from_probs(
probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic)


def top_k_mask_logits(
logits: torch.Tensor,
top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
return torch.ops._C.top_k_mask_logits(logits,
*_to_tensor_scalar_tuple(top_k))


def top_p_renorm_prob(
probs: torch.Tensor,
top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
return torch.ops._C.top_p_renorm_prob(probs,
*_to_tensor_scalar_tuple(top_p))


def top_k_renorm_prob(
probs: torch.Tensor,
top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
return torch.ops._C.top_k_renorm_prob(probs,
*_to_tensor_scalar_tuple(top_k))


def top_k_top_p_sampling_from_logits(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
top_k: Union[torch.Tensor, int],
top_p: Union[torch.Tensor, float],
filter_apply_order: str = "top_k_first",
deterministic: bool = True,
check_nan: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
if filter_apply_order == "top_k_first":
masked_logits = top_k_mask_logits(probs, top_k)
probs = torch.softmax(masked_logits, dim=-1)
return top_p_sampling_from_probs(probs, uniform_samples, top_p,
deterministic, check_nan)
elif filter_apply_order == "joint":
probs = torch.softmax(probs, dim=-1)
if check_nan and torch.any(torch.isnan(probs)):
raise ValueError("NaN detected in probs")
return torch.ops._C.top_k_top_p_sampling_from_logits(
probs, uniform_samples, *_to_tensor_scalar_tuple(top_k),
*_to_tensor_scalar_tuple(top_p), deterministic)
else:
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")


def top_k_top_p_sampling_from_probs(
probs: torch.Tensor,
uniform_samples: torch.Tensor,
top_k: Union[torch.Tensor, int],
top_p: Union[torch.Tensor, float],
filter_apply_order: str = "top_k_first",
deterministic: bool = True,
check_nan: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
if filter_apply_order == "top_k_first":
renorm_probs = top_k_renorm_prob(probs, top_k)
return top_p_sampling_from_probs(renorm_probs, uniform_samples, top_p,
deterministic, check_nan)
elif filter_apply_order == "joint":
if check_nan and torch.any(torch.isnan(probs)):
raise ValueError("NaN detected in probs")
return torch.ops._C.top_k_top_p_sampling_from_probs(
probs, uniform_samples, *_to_tensor_scalar_tuple(top_k),
*_to_tensor_scalar_tuple(top_p), deterministic)
else:
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
2 changes: 1 addition & 1 deletion aphrodite/common/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@
APHRODITE_REQUEST_LEVEL_METRICS: bool = False
APHRODITE_USE_SAMPLING_KERNELS: bool = False
APHRODITE_NO_DEPRECATION_WARNING: bool = False
APHRODITE_DISABLE_FLASH_ATTN: bool = False
APHRODITE_DISABLE_FLASH_ATTN_COMPILE: bool = False
APHRODITE_DYNAMIC_ROPE_SCALING: bool = False
APHRODITE_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
APHRODITE_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
Expand Down
1 change: 1 addition & 0 deletions aphrodite/v1/sample/ops/temperatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

_SAMPLING_EPS = 1e-5


def _tensor_or_zeros(tens, like_tensor):
return tens if tens is not None else torch.zeros_like(like_tensor)

Expand Down
85 changes: 77 additions & 8 deletions aphrodite/v1/sample/ops/topk_topp_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aphrodite.common import envs
from aphrodite.common.logger import log_once
from aphrodite.platforms import current_platform
from aphrodite._custom_ops import apply_top_k_top_p_cuda

try:
import flashinfer.sampling
Expand All @@ -18,8 +19,9 @@
try:
from aphrodite.distributed.parallel_state import (
get_tensor_model_parallel_rank)
rank = get_tensor_model_parallel_rank()
except Exception:
get_tensor_model_parallel_rank = lambda: 0
rank = 0


class TopKTopPSampler(nn.Module):
Expand All @@ -45,28 +47,34 @@ def __init__(self):
# earlier design.
# https://github.com/flashinfer-ai/flashinfer/releases/
# tag/v0.2.3
if get_tensor_model_parallel_rank() == 0:
if rank == 0:
logger.info(
"FlashInfer version >= 0.2.3 required. "
"Falling back to default sampling implementation.")
self.forward = self.forward_native
elif envs.APHRODITE_USE_SAMPLING_KERNELS is not False:
elif envs.APHRODITE_USE_SAMPLING_KERNELS is True:
# Use custom CUDA kernel for top-k/top-p sampling
if rank == 0:
logger.info("Using custom CUDA kernel for top-p & "
"top-k sampling.")
self.forward = self.forward_cuda_kernel
elif envs.APHRODITE_USE_FLASHINFER_SAMPLER is not None:
# NOTE: The V0 sampler doesn't use FlashInfer for
# sampling unless APHRODITE_USE_SAMPLING_KERNELS=1 (i.e., by
# sampling unless APHRODITE_USE_FLASHINFER_SAMPLER=1 (i.e., by
# default it is unused). For backward compatibility, we set
# `APHRODITE_USE_SAMPLING_KERNELS` as None by default and
# `APHRODITE_USE_FLASHINFER_SAMPLER` as None by default and
# interpret it differently in V0 and V1 samplers: In V0,
# None means False, while in V1, None means True. This is
# why we use the condition
# `envs.APHRODITE_USE_SAMPLING_KERNELS is not False` here.
# `envs.APHRODITE_USE_FLASHINFER_SAMPLER is not None` here.
logger.info("Using FlashInfer for top-p & top-k sampling.")
self.forward = self.forward_cuda
else:
if get_tensor_model_parallel_rank() == 0:
logger.warning(
"FlashInfer is available, but it is not enabled. "
"Falling back to the PyTorch-native implementation "
"of top-p & top-k sampling. For the best "
"Falling back to the PyTorch-native implementation"
" of top-p & top-k sampling. For the best "
"performance, please set "
"APHRODITE_USE_SAMPLING_KERNELS=1.")
self.forward = self.forward_native
Expand Down Expand Up @@ -136,6 +144,67 @@ def forward_tpu(
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)

def forward_cuda_kernel(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
"""Use custom CUDA kernel for top-k and top-p sampling."""
if k is None and p is None:
# No filtering needed, use regular sampling
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)

if generators:
log_once(
"WARNING",
"Custom CUDA kernel does not support per-request generators. "
"Falling back to PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p)

num_seqs = logits.size(0)
vocab_size = logits.size(1)

# Prepare output tensor for the CUDA kernel
output_ids = torch.empty(num_seqs, dtype=torch.int64,
device=logits.device)

# Prepare top-k and top-p values
# Convert to the format expected by CUDA kernel
if k is not None:
top_k_values = k.to(dtype=torch.int32, device=logits.device)
else:
top_k_values = torch.full((num_seqs,), vocab_size,
dtype=torch.int32, device=logits.device)

if p is not None:
top_p_values = p.to(dtype=torch.float32, device=logits.device)
else:
top_p_values = None

# Call the CUDA kernel
# Note: We don't use curand_states for now, relying on the
# kernel's internal randomness
try:
apply_top_k_top_p_cuda(
logits=logits,
output_ids=output_ids,
top_k_values=top_k_values,
top_p_values=top_p_values,
curand_states=None, # Not using CUDA random states for now
output_logprobs=None, # Not requesting log probabilities
normalize_logprobs=False
)
return output_ids
except Exception as e:
log_once(
"WARNING",
f"Custom CUDA kernel failed: {e}. Falling back to "
"PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p)


def apply_top_k_top_p_tpu(
logits: torch.Tensor,
Expand Down
40 changes: 7 additions & 33 deletions kernels/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& output_mask,
const torch::Tensor& repetition_penalties);

void topk_topp_sampling(torch::Tensor& logits, torch::Tensor& output_ids,
const torch::Tensor& top_k_values,
const std::optional<torch::Tensor>& top_p_values,
const std::optional<torch::Tensor>& curand_states,
std::optional<torch::Tensor>& output_logprobs,
bool normalize_logprobs = false);

void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale,
double epsilon);
Expand Down Expand Up @@ -205,39 +212,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,

torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);

// Sampling kernels
#ifndef USE_ROCM
torch::Tensor sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples,
bool deterministic);
std::vector<torch::Tensor> top_p_sampling_from_probs(
torch::Tensor probs, torch::Tensor uniform_samples,
std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
bool deterministic);
std::vector<torch::Tensor> top_k_sampling_from_probs(
torch::Tensor probs, torch::Tensor uniform_samples,
std::optional<torch::Tensor> maybe_top_k_arr, int64_t top_k_val,
bool deterministic);
std::vector<torch::Tensor> min_p_sampling_from_probs(
torch::Tensor probs, torch::Tensor uniform_samples,
std::optional<torch::Tensor> maybe_min_p_arr, double min_p_val,
bool deterministic);
std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
torch::Tensor probs, torch::Tensor uniform_samples,
std::optional<torch::Tensor> maybe_top_k_arr, double top_k_val,
std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
bool deterministic);
torch::Tensor top_p_renorm_prob(torch::Tensor probs,
std::optional<torch::Tensor> maybe_top_p_arr,
double top_p_val);
torch::Tensor top_k_renorm_prob(torch::Tensor probs,
std::optional<torch::Tensor> maybe_top_k_arr,
int64_t top_k_val);
torch::Tensor top_k_mask_logits(torch::Tensor logits,
std::optional<torch::Tensor> maybe_top_k_arr,
int64_t top_k_val);

#endif

// Quantization kernels
#ifndef USE_ROCM
Expand Down
Loading
Loading