From 4659e934c0403417c438c09c4d53f4aa31edbe23 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Sat, 2 May 2026 03:27:30 +0430 Subject: [PATCH] feat: add silu clamp limit to shared expert for DeepSeek-V4 Signed-off-by: AlpinDale --- aphrodite/model_executor/layers/activation.py | 40 +++++++ .../layers/fused_moe/cpu_fused_moe.py | 2 +- .../model_executor/models/deepseek_v4.py | 56 ++++++++- csrc/activation_kernels.cu | 110 ++++++++++++++---- csrc/ops.h | 2 + csrc/torch_bindings.cpp | 6 + 6 files changed, 187 insertions(+), 29 deletions(-) diff --git a/aphrodite/model_executor/layers/activation.py b/aphrodite/model_executor/layers/activation.py index f8e440c2ce..c8e61731ee 100644 --- a/aphrodite/model_executor/layers/activation.py +++ b/aphrodite/model_executor/layers/activation.py @@ -157,6 +157,46 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: return self.forward_cuda(x) +@CustomOp.register("silu_and_mul_with_clamp") +class SiluAndMulWithClamp(CustomOp): + """SwiGLU activation with input clamping (used by some MoE shared experts). + + Computes: + gate = clamp(x[..., :d], max=swiglu_limit) + up = clamp(x[..., d:], min=-swiglu_limit, max=swiglu_limit) + out = silu(gate) * up + where d = x.shape[-1] // 2. + + Shapes: + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) + """ + + def __init__(self, swiglu_limit: float, *, compile_native: bool = True): + super().__init__(compile_native=compile_native) + self.swiglu_limit = float(swiglu_limit) + if current_platform.is_cuda_alike() or current_platform.is_xpu(): + self.op = torch.ops._C.silu_and_mul_with_clamp + elif current_platform.is_cpu(): + self._forward_method = self.forward_native + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + gate = torch.clamp(x[..., :d], max=self.swiglu_limit) + up = torch.clamp(x[..., d:], min=-self.swiglu_limit, max=self.swiglu_limit) + return F.silu(gate) * up + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = x.shape[:-1] + (d,) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + self.op(out, x, self.swiglu_limit) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + return self.forward_cuda(x) + + # --8<-- [start:mul_and_silu] @CustomOp.register("mul_and_silu") class MulAndSilu(CustomOp): diff --git a/aphrodite/model_executor/layers/fused_moe/cpu_fused_moe.py b/aphrodite/model_executor/layers/fused_moe/cpu_fused_moe.py index 3809b0c111..56801ee9bd 100644 --- a/aphrodite/model_executor/layers/fused_moe/cpu_fused_moe.py +++ b/aphrodite/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -45,7 +45,7 @@ def _gelu_and_mul( # Uses static methods or standalone functions to avoid instantiating CustomOp # classes, which would call get_current_aphrodite_config() before config is set. _CPU_MOE_ACT_FN: dict[MoEActivation, Callable[[torch.Tensor], torch.Tensor]] = { - MoEActivation.SILU: SiluAndMul.forward_native, + MoEActivation.SILU: lambda x: SiluAndMul(compile_native=False).forward_native(x), MoEActivation.SWIGLUOAI: _swigluoai_forward_native, MoEActivation.GELU: _gelu_and_mul, } diff --git a/aphrodite/model_executor/models/deepseek_v4.py b/aphrodite/model_executor/models/deepseek_v4.py index b945c62382..ab94b37f74 100644 --- a/aphrodite/model_executor/models/deepseek_v4.py +++ b/aphrodite/model_executor/models/deepseek_v4.py @@ -17,6 +17,7 @@ get_tensor_model_parallel_world_size, ) from aphrodite.forward_context import get_forward_context +from aphrodite.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClamp from aphrodite.model_executor.layers.deepseek_v4_attention import ( DeepseekV4Indexer, DeepseekV4MLAModules, @@ -34,7 +35,7 @@ RowParallelLinear, ) from aphrodite.model_executor.layers.logits_processor import LogitsProcessor -from aphrodite.model_executor.layers.quantization import QuantizationMethods +from aphrodite.model_executor.layers.quantization import QuantizationConfig, QuantizationMethods from aphrodite.model_executor.layers.quantization.fp8 import Fp8Config from aphrodite.model_executor.layers.quantization.mxfp4 import Mxfp4MoEMethod from aphrodite.model_executor.layers.quantization.utils.quant_utils import ( @@ -46,7 +47,6 @@ VocabParallelEmbedding, ) from aphrodite.model_executor.model_loader.weight_utils import default_weight_loader -from aphrodite.model_executor.models.deepseek_v2 import DeepseekV2MLP from aphrodite.model_executor.utils import set_weight_attrs from aphrodite.platforms import current_platform from aphrodite.sequence import IntermediateTensors @@ -63,6 +63,55 @@ ) +class DeepseekV4MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + swiglu_limit: float | None = None, + quant_config: QuantizationConfig | None = None, + reduce_results: bool = True, + is_sequence_parallel: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + + # If is_sequence_parallel, the input and output tensors are sharded + # across the ranks within the tp_group. In this case the weights are + # replicated and no collective ops are needed. + # Otherwise we use standard TP with an allreduce at the end. + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + disable_tp=is_sequence_parallel, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + disable_tp=is_sequence_parallel, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. Only silu is supported for now.") + if swiglu_limit is not None: + self.act_fn = SiluAndMulWithClamp(swiglu_limit) + else: + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + class DeepseekV4FP8Config(Fp8Config): """FP8 config that routes MoE layers to MXFP4 quantization. @@ -642,10 +691,11 @@ def __init__( else: intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV2MLP( + self.shared_experts = DeepseekV4MLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, + swiglu_limit=self.swiglu_limit, quant_config=quant_config, reduce_results=self.use_mega_moe, prefix=f"{prefix}.shared_experts", diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 9016f7aaa5..f5d18c957f 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -11,29 +11,74 @@ namespace aphrodite { template + bool act_first, bool HAS_CLAMP> __device__ __forceinline__ scalar_t compute(const scalar_t& x, - const scalar_t& y) { - return act_first ? ACT_FN(x) * y : x * ACT_FN(y); + const scalar_t& y, + const float limit) { + if constexpr (act_first) { + scalar_t gate = x; + scalar_t up = y; + if constexpr (HAS_CLAMP) { + gate = (scalar_t)fminf((float)gate, limit); + up = (scalar_t)fmaxf(fminf((float)up, limit), -limit); + } + return ACT_FN(gate) * up; + } else { + scalar_t gate = x; + scalar_t up = y; + if constexpr (HAS_CLAMP) { + gate = (scalar_t)fmaxf(fminf((float)gate, limit), -limit); + up = (scalar_t)fminf((float)up, limit); + } + return gate * ACT_FN(up); + } } template + bool act_first, bool HAS_CLAMP> __device__ __forceinline__ packed_t packed_compute(const packed_t& x, - const packed_t& y) { - return act_first ? packed_mul(PACKED_ACT_FN(x), y) - : packed_mul(x, PACKED_ACT_FN(y)); + const packed_t& y, + const float limit) { + if constexpr (act_first) { + packed_t gate = x; + packed_t up = y; + if constexpr (HAS_CLAMP) { + float2 g = cast_to_float2(gate); + float2 u = cast_to_float2(up); + g.x = fminf(g.x, limit); + g.y = fminf(g.y, limit); + u.x = fmaxf(fminf(u.x, limit), -limit); + u.y = fmaxf(fminf(u.y, limit), -limit); + gate = cast_to_packed(g); + up = cast_to_packed(u); + } + return packed_mul(PACKED_ACT_FN(gate), up); + } else { + packed_t gate = x; + packed_t up = y; + if constexpr (HAS_CLAMP) { + float2 g = cast_to_float2(gate); + float2 u = cast_to_float2(up); + g.x = fmaxf(fminf(g.x, limit), -limit); + g.y = fmaxf(fminf(g.y, limit), -limit); + u.x = fminf(u.x, limit); + u.y = fminf(u.y, limit); + gate = cast_to_packed(g); + up = cast_to_packed(u); + } + return packed_mul(gate, PACKED_ACT_FN(up)); + } } // Activation and gating kernel template. template + bool use_vec, bool HAS_CLAMP, bool use_256b = false> __global__ void act_and_mul_kernel( scalar_t* __restrict__ out, // [..., d] const scalar_t* __restrict__ input, // [..., 2, d] - const int d) { + const int d, const float limit) { const scalar_t* x_ptr = input + blockIdx.x * 2 * d; const scalar_t* y_ptr = x_ptr + d; scalar_t* out_ptr = out + blockIdx.x * d; @@ -58,8 +103,9 @@ __global__ void act_and_mul_kernel( } #pragma unroll for (int j = 0; j < pvec_t::NUM_ELTS; j++) { - x.elts[j] = packed_compute( - x.elts[j], y.elts[j]); + x.elts[j] = + packed_compute( + x.elts[j], y.elts[j], limit); } if constexpr (use_256b) { st256(x, &out_vec[i]); @@ -72,7 +118,8 @@ __global__ void act_and_mul_kernel( for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = APHRODITE_LDG(&x_ptr[idx]); const scalar_t y = APHRODITE_LDG(&y_ptr[idx]); - out_ptr[idx] = compute(x, y); + out_ptr[idx] = + compute(x, y, limit); } } } @@ -176,8 +223,11 @@ packed_gelu_tanh_kernel(const packed_t& val) { // Launch activation and gating kernel. // Use ACT_FIRST (bool) indicating whether to apply the activation function -// first. -#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \ +// first. HAS_CLAMP (bool) enables pre-activation clamping: gate input is +// clamped (max only) and up input is clamped (both sides) before the +// activation function is applied. +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST, \ + HAS_CLAMP, LIMIT) \ auto dtype = input.scalar_type(); \ int d = input.size(-1) / 2; \ int64_t num_tokens = input.numel() / input.size(-1); \ @@ -203,8 +253,8 @@ packed_gelu_tanh_kernel(const packed_t& val) { KERNEL, \ PACKED_KERNEL< \ typename aphrodite::PackedTypeConverter::Type>, \ - ACT_FIRST, true, true><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ + ACT_FIRST, true, HAS_CLAMP, true><<>>( \ + out.data_ptr(), input.data_ptr(), d, LIMIT); \ }); \ } else { \ APHRODITE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ @@ -213,8 +263,8 @@ packed_gelu_tanh_kernel(const packed_t& val) { KERNEL, \ PACKED_KERNEL< \ typename aphrodite::PackedTypeConverter::Type>, \ - ACT_FIRST, true, false><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ + ACT_FIRST, true, HAS_CLAMP, false><<>>( \ + out.data_ptr(), input.data_ptr(), d, LIMIT); \ }); \ } \ } else { \ @@ -225,16 +275,24 @@ packed_gelu_tanh_kernel(const packed_t& val) { KERNEL, \ PACKED_KERNEL< \ typename aphrodite::PackedTypeConverter::Type>, \ - ACT_FIRST, false><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ + ACT_FIRST, false, HAS_CLAMP><<>>( \ + out.data_ptr(), input.data_ptr(), d, LIMIT); \ }); \ } void silu_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { + LAUNCH_ACTIVATION_GATE_KERNEL( + aphrodite::silu_kernel, aphrodite::packed_silu_kernel, true, false, 0.0f); +} + +void silu_and_mul_clamp(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., 2 * d] + double limit) { LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::silu_kernel, - aphrodite::packed_silu_kernel, true); + aphrodite::packed_silu_kernel, true, true, + (float)limit); } void mul_and_silu(torch::Tensor& out, // [..., d] @@ -243,7 +301,8 @@ void mul_and_silu(torch::Tensor& out, // [..., d] // The difference between mul_and_silu and silu_and_mul is that mul_and_silu // applies the silu to the latter half of the input. LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::silu_kernel, - aphrodite::packed_silu_kernel, false); + aphrodite::packed_silu_kernel, false, false, + 0.0f); } void silu_mul(torch::Tensor& out, torch::Tensor const& gate, @@ -304,15 +363,16 @@ void make_gate_up_indices(torch::Tensor& out, torch::Tensor const& indices, void gelu_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { - LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::gelu_kernel, - aphrodite::packed_gelu_kernel, true); + LAUNCH_ACTIVATION_GATE_KERNEL( + aphrodite::gelu_kernel, aphrodite::packed_gelu_kernel, true, false, 0.0f); } void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::gelu_tanh_kernel, - aphrodite::packed_gelu_tanh_kernel, true); + aphrodite::packed_gelu_tanh_kernel, true, false, + 0.0f); } namespace aphrodite { diff --git a/csrc/ops.h b/csrc/ops.h index dac9137ede..9c5384e418 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -259,6 +259,8 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void silu_and_mul(torch::Tensor& out, torch::Tensor& input); +void silu_and_mul_clamp(torch::Tensor& out, torch::Tensor& input, double limit); + void silu_mul(torch::Tensor& out, torch::Tensor const& gate, torch::Tensor const& up); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 9f4788f15e..6240b37ec8 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -168,6 +168,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()"); ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); + // SwiGLU activation with input clamping. + ops.def( + "silu_and_mul_with_clamp(Tensor! result, Tensor input, float limit) " + "-> ()"); + ops.impl("silu_and_mul_with_clamp", torch::kCUDA, &silu_and_mul_clamp); + ops.def("silu_mul(Tensor! result, Tensor gate, Tensor up) -> ()"); ops.impl("silu_mul", torch::kCUDA, &silu_mul);