From 4b1a350ab40c89d96ee241dd4596622ef450e640 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Wed, 5 Nov 2025 20:18:45 +0000 Subject: [PATCH] [kernel][moe] better splitK for fused moe Signed-off-by: AlpinDale --- .../modeling/layers/fused_moe/fused_moe.py | 62 ++++++++++++++----- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/aphrodite/modeling/layers/fused_moe/fused_moe.py b/aphrodite/modeling/layers/fused_moe/fused_moe.py index df51f00cba..38a5d5fadd 100644 --- a/aphrodite/modeling/layers/fused_moe/fused_moe.py +++ b/aphrodite/modeling/layers/fused_moe/fused_moe.py @@ -107,8 +107,8 @@ def fused_moe_kernel_gptq_awq( BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, compute_type: tl.constexpr, @@ -323,8 +323,8 @@ def fused_moe_kernel( BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, compute_type: tl.constexpr, @@ -363,7 +363,7 @@ def fused_moe_kernel( # ----------------------------------------------------------- # Map program ids `pid` to the block of C it should compute. # This is done in a grouped ordering to promote L2 data reuse. - pid = tl.program_id(axis=0) + pid = tl.program_id(axis=1) num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_in_group = GROUP_SIZE_M * num_pid_n @@ -372,6 +372,7 @@ def fused_moe_kernel( group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m + pid_k = tl.program_id(axis=0) # ---------------------------------------------------------- # Create pointers for the first blocks of A and B. @@ -406,7 +407,7 @@ def fused_moe_kernel( return offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) @@ -441,21 +442,22 @@ def fused_moe_kernel( # of fp32 values for higher accuracy. # `accumulator` will be converted back to fp16 after the loop. accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. + k_remaining = K - k * BLOCK_SIZE_K * SPLIT_K a = tl.load( a_ptrs, - mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + mask=token_mask[:, None] & (offs_k[None, :] < k_remaining), other=0.0, ) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) # We accumulate along the K dimension. if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8 or use_int8_w8a8: if group_k > 0 and group_n > 0: - k_start = k * BLOCK_SIZE_K + k_start = pid_k * BLOCK_SIZE_K + k * BLOCK_SIZE_K * SPLIT_K offs_ks = k_start // group_k a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0) b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) @@ -470,9 +472,11 @@ def fused_moe_kernel( else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - if HAS_BIAS: + a_ptrs += (BLOCK_SIZE_K * SPLIT_K) * stride_ak + b_ptrs += (BLOCK_SIZE_K * SPLIT_K) * stride_bk + + # Only add bias in the first k partition + if pid_k == 0 and HAS_BIAS: accumulator = accumulator + bias[None, :] if MUL_ROUTED_WEIGHT: moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) @@ -492,7 +496,18 @@ def fused_moe_kernel( offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) + if SPLIT_K == 1: + tl.store(c_ptrs, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed") + + +def _zero_output(*args, **kwargs): + if kwargs["SPLIT_K"] != 1: + args[2].zero_() + + +fused_moe_kernel.add_pre_run_hook(_zero_output) def invoke_fused_moe_kernel( @@ -515,6 +530,7 @@ def invoke_fused_moe_kernel( use_int8_w8a16: bool, use_int4_w4a16: bool, per_channel_quant: bool, + do_split_k: bool = False, block_shape: list[int] | None = None, B_bias: torch.Tensor | None = None, ) -> None: @@ -544,12 +560,15 @@ def invoke_fused_moe_kernel( # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, # and we can skip some invalid blocks. EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"]) - grid = lambda META: (triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),) + HAS_BIAS = B_bias is not None if (use_int8_w8a16 or use_int4_w4a16) and block_shape is not None and block_shape[1] > 0: assert B_scale is not None and B_scale.ndim == 3 assert B_zp is None or B_zp.ndim == 3 + # TODO: add splitk to this kernel + grid = lambda META: (triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),) + use_moe_wna16_cuda = should_moe_wna16_use_cuda( num_valid_tokens=num_tokens, group_size=block_shape[1], @@ -628,11 +647,17 @@ def invoke_fused_moe_kernel( **config, ) else: + grid = lambda META: ( + META["SPLIT_K"], + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]), + ) config = config.copy() config["SPLIT_K"] = 1 BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K") if block_shape is not None: BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1])) + if not do_split_k: + config["SPLIT_K"] = 1 fused_moe_kernel[grid]( A, B, @@ -922,8 +947,8 @@ def get_default_config( "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": block_shape[0], "BLOCK_SIZE_K": block_shape[1], - "GROUP_SIZE_M": 32, "SPLIT_K": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 if not current_platform.is_rocm() else 2, } @@ -946,16 +971,16 @@ def get_default_config( "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, "SPLIT_K": 1, + "GROUP_SIZE_M": 1, } else: config = { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, "SPLIT_K": 1, + "GROUP_SIZE_M": 8, } return config @@ -989,6 +1014,9 @@ def try_get_optimal_moe_config( else: # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, block_shape) + # Add SPLIT_K if not present + if "SPLIT_K" not in config: + config["SPLIT_K"] = 1 return config @@ -1780,6 +1808,7 @@ def fused_experts_impl( per_channel_quant=per_channel_quant, block_shape=block_shape, B_bias=w1_bias, + do_split_k=True, ) # Activation function with multiplication @@ -1973,6 +2002,7 @@ def apply( per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape, B_bias=self.w1_bias, + do_split_k=True, ) self.activation(activation, intermediate_cache2, intermediate_cache1.view(-1, N))