Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 46 additions & 16 deletions aphrodite/modeling/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def fused_moe_kernel_gptq_awq(
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SPLIT_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
Expand Down Expand Up @@ -323,8 +323,8 @@ def fused_moe_kernel(
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SPLIT_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
Expand Down Expand Up @@ -363,7 +363,7 @@ def fused_moe_kernel(
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0)
pid = tl.program_id(axis=1)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
Expand All @@ -372,6 +372,7 @@ def fused_moe_kernel(
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
pid_k = tl.program_id(axis=0)

# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
Expand Down Expand Up @@ -406,7 +407,7 @@ def fused_moe_kernel(
return

offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
offs_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)

b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
Expand Down Expand Up @@ -441,21 +442,22 @@ def fused_moe_kernel(
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
k_remaining = K - k * BLOCK_SIZE_K * SPLIT_K
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
# We accumulate along the K dimension.
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
k_start = pid_k * BLOCK_SIZE_K + k * BLOCK_SIZE_K * SPLIT_K
offs_ks = k_start // group_k
a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0)
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
Expand All @@ -470,9 +472,11 @@ def fused_moe_kernel(
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if HAS_BIAS:
a_ptrs += (BLOCK_SIZE_K * SPLIT_K) * stride_ak
b_ptrs += (BLOCK_SIZE_K * SPLIT_K) * stride_bk

# Only add bias in the first k partition
if pid_k == 0 and HAS_BIAS:
accumulator = accumulator + bias[None, :]
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
Expand All @@ -492,7 +496,18 @@ def fused_moe_kernel(
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
if SPLIT_K == 1:
tl.store(c_ptrs, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed")


def _zero_output(*args, **kwargs):
    """Pre-run hook that clears ``args[2]`` in place when split-K is active.

    The ``SPLIT_K`` value is read from the keyword arguments; a ``KeyError``
    propagates if it was not supplied by keyword. When it differs from 1, the
    third positional argument has ``zero_()`` called on it. Nothing is
    returned.

    Why: with ``SPLIT_K != 1`` the kernel combines partial results with
    ``tl.atomic_add`` rather than a plain store, so the destination must start
    from zero.
    """
    split_k_active = kwargs["SPLIT_K"] != 1
    if not split_k_active:
        # A single K partition writes with a plain store; nothing to clear.
        return
    # NOTE(review): presumably args[2] is the output tensor C -- confirm
    # against the kernel's positional parameter order.
    destination = args[2]
    destination.zero_()


# Attach _zero_output as a pre-run hook on fused_moe_kernel. With SPLIT_K != 1
# the kernel accumulates into its output via tl.atomic_add, so the hook zeroes
# that buffer first.
# NOTE(review): presumably the hook fires on every launch, including any
# benchmarking/autotune runs -- confirm against the Triton version in use.
fused_moe_kernel.add_pre_run_hook(_zero_output)


def invoke_fused_moe_kernel(
Expand All @@ -515,6 +530,7 @@ def invoke_fused_moe_kernel(
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
do_split_k: bool = False,
block_shape: list[int] | None = None,
B_bias: torch.Tensor | None = None,
) -> None:
Expand Down Expand Up @@ -544,12 +560,15 @@ def invoke_fused_moe_kernel(
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# and we can skip some invalid blocks.
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
grid = lambda META: (triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),)

HAS_BIAS = B_bias is not None
if (use_int8_w8a16 or use_int4_w4a16) and block_shape is not None and block_shape[1] > 0:
assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3

# TODO: add splitk to this kernel
grid = lambda META: (triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),)

use_moe_wna16_cuda = should_moe_wna16_use_cuda(
num_valid_tokens=num_tokens,
group_size=block_shape[1],
Expand Down Expand Up @@ -628,11 +647,17 @@ def invoke_fused_moe_kernel(
**config,
)
else:
grid = lambda META: (
META["SPLIT_K"],
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
)
config = config.copy()
config["SPLIT_K"] = 1
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
if block_shape is not None:
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
if not do_split_k:
config["SPLIT_K"] = 1
Comment on lines 655 to +660
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The line `config["SPLIT_K"] = 1` unconditionally sets `SPLIT_K` to 1. This overrides any tuned `SPLIT_K` value from the configuration and effectively disables the split-K optimization, because the kernel is always launched with a grid dimension of 1 along the K-split axis. It also makes the subsequent `if not do_split_k:` check have no effect: `SPLIT_K` is already 1 regardless of `do_split_k`, so passing `do_split_k=True` changes nothing.

To fix this, remove the unconditional assignment so that `SPLIT_K` is set to 1 only when `do_split_k` is `False`.

Suggested change
config["SPLIT_K"] = 1
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
if block_shape is not None:
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
if not do_split_k:
config["SPLIT_K"] = 1
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
if block_shape is not None:
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
if not do_split_k:
config["SPLIT_K"] = 1

fused_moe_kernel[grid](
A,
B,
Expand Down Expand Up @@ -922,8 +947,8 @@ def get_default_config(
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_shape[0],
"BLOCK_SIZE_K": block_shape[1],
"GROUP_SIZE_M": 32,
"SPLIT_K": 1,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3 if not current_platform.is_rocm() else 2,
}
Expand All @@ -946,16 +971,16 @@ def get_default_config(
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"SPLIT_K": 1,
"GROUP_SIZE_M": 1,
}
else:
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"SPLIT_K": 1,
"GROUP_SIZE_M": 8,
}
return config

Expand Down Expand Up @@ -989,6 +1014,9 @@ def try_get_optimal_moe_config(
else:
# Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, block_shape)
# Add SPLIT_K if not present
if "SPLIT_K" not in config:
config["SPLIT_K"] = 1
return config


Expand Down Expand Up @@ -1780,6 +1808,7 @@ def fused_experts_impl(
per_channel_quant=per_channel_quant,
block_shape=block_shape,
B_bias=w1_bias,
do_split_k=True,
)

# Activation function with multiplication
Expand Down Expand Up @@ -1973,6 +2002,7 @@ def apply(
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=self.w1_bias,
do_split_k=True,
)

self.activation(activation, intermediate_cache2, intermediate_cache1.view(-1, N))
Expand Down