Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__pycache__/
.pyc
.codex
build
dist
*.egg-info
Expand Down
30 changes: 30 additions & 0 deletions lightllm/common/basemodel/triton_kernel/mtp_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional
import triton
import triton.language as tl
import torch
Expand Down Expand Up @@ -93,10 +94,15 @@ def _fwd_kernel_mtp_scatter_next_token_ids(
req_to_next_token_ids_stride,
all_next_token_ids,
all_next_token_ids_stride,
req_to_next_token_probs,
req_to_next_token_probs_stride,
all_next_token_probs,
all_next_token_probs_stride,
mtp_accept_len,
b_req_mtp_start_loc,
b_req_idx,
mtp_step,
HAS_HAS_NEXT_TOKEN_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):

Expand All @@ -106,6 +112,17 @@ def _fwd_kernel_mtp_scatter_next_token_ids(
cur_req_idx = tl.load(b_req_idx + req_start_loc)
offset = tl.arange(0, BLOCK_SIZE)

if HAS_HAS_NEXT_TOKEN_PROBS:
cur_next_token_probs = tl.load(
all_next_token_probs + (req_start_loc + accept_len - 1) * all_next_token_probs_stride + offset,
mask=offset < mtp_step,
other=0.0,
)
tl.store(
req_to_next_token_probs + cur_req_idx * req_to_next_token_probs_stride + offset,
cur_next_token_probs,
mask=offset < mtp_step,
)
scatter_next_token_ids = tl.load(
all_next_token_ids + (req_start_loc + accept_len - 1) * all_next_token_ids_stride + offset,
mask=offset < mtp_step,
Expand All @@ -125,23 +142,36 @@ def mtp_scatter_next_token_ids(
all_next_token_ids: torch.Tensor,
b_req_idx: torch.Tensor,
mtp_accept_len: torch.Tensor,
req_to_next_token_probs: Optional[torch.Tensor] = None,
all_next_token_probs: Optional[torch.Tensor] = None,
):
max_mtp_step = req_to_next_token_ids.shape[1]
BLOCK_SIZE = 16
assert max_mtp_step <= BLOCK_SIZE, f"max_mtp_step must be less than {BLOCK_SIZE}"
num_reqs = b_req_mtp_start_loc.shape[0]
mtp_step = all_next_token_ids.shape[1]
if req_to_next_token_probs is not None:
assert all_next_token_probs is not None
assert all_next_token_probs.shape == all_next_token_ids.shape

HAS_HAS_NEXT_TOKEN_PROBS = req_to_next_token_probs is not None

grid = (num_reqs,)
num_warps = 1
_fwd_kernel_mtp_scatter_next_token_ids[grid](
req_to_next_token_ids=req_to_next_token_ids,
req_to_next_token_ids_stride=req_to_next_token_ids.stride(0),
all_next_token_ids=all_next_token_ids,
all_next_token_ids_stride=all_next_token_ids.stride(0),
req_to_next_token_probs=req_to_next_token_probs,
req_to_next_token_probs_stride=req_to_next_token_probs.stride(0),
all_next_token_probs=all_next_token_probs,
all_next_token_probs_stride=all_next_token_probs.stride(0),
mtp_accept_len=mtp_accept_len,
b_req_mtp_start_loc=b_req_mtp_start_loc,
b_req_idx=b_req_idx,
mtp_step=mtp_step,
HAS_HAS_NEXT_TOKEN_PROBS=HAS_HAS_NEXT_TOKEN_PROBS,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
num_stages=1,
Expand Down
14 changes: 13 additions & 1 deletion lightllm/common/req_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter
from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, enable_dynamic_mtp_verify
from lightllm.utils.config_utils import get_vocab_size
from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager

Expand Down Expand Up @@ -116,6 +116,15 @@ def __init__(self, max_request_num):
dtype=torch.int64,
device="cuda",
)
if enable_dynamic_mtp_verify():
self.req_to_next_token_probs = torch.zeros(
(max_request_num + 1, 16),
dtype=torch.float32,
device="cuda",
)
else:
self.req_to_next_token_probs = None

self.req_to_exponential_decay_length_penalty = torch.zeros(
max_request_num + 1, dtype=torch.float32, device="cuda"
)
Expand All @@ -137,6 +146,9 @@ def init_req_sampling_params(self, req):

shm_param = req.sampling_param.shm_param
self.req_to_next_token_ids[req.req_idx][0:1].fill_(req.get_last_gen_token())
if enable_dynamic_mtp_verify():
self.req_to_next_token_probs[req.req_idx].fill_(0.0)
self.req_to_next_token_probs[req.req_idx][0:1].fill_(1.0)
self.req_to_presence_penalty[req.req_idx].fill_(shm_param.presence_penalty)
self.req_to_frequency_penalty[req.req_idx].fill_(shm_param.frequency_penalty)
self.req_to_repetition_penalty[req.req_idx].fill_(shm_param.repetition_penalty)
Expand Down
5 changes: 1 addition & 4 deletions lightllm/server/router/model_infer/infer_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,10 +435,7 @@ def __init__(
# mtp_step 用来记录一个请求 draft模型每步需要生成的token数量
# 正常模式下,这个值为0,在 mtp 模式下,这个值为 draft 模型每步需要生成的token数量
self.mtp_step: int = get_env_start_args().mtp_step
# current_mtp_step 用来记录当前的 MTP 验证长度(<= mtp_step)
# 在启用动态 MTP 验证时,每步会根据 prob 分布重新设置该值
# 静态模式下为 mtp_step,动态模式下为动态计算的 MTP 验证长度
self.current_mtp_step: int = self.mtp_step

if self.mtp_step > 0:
self.decode_need_token_num = self._mtp_decode_need_token_num
else:
Expand Down
19 changes: 12 additions & 7 deletions lightllm/server/router/model_infer/mode_backend/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
import threading
import torch.distributed as dist
import collections
from typing import List, Tuple, Callable, Optional
from transformers.configuration_utils import PretrainedConfig
from lightllm.utils.infer_utils import set_random_seed
Expand Down Expand Up @@ -775,14 +776,18 @@ def _update_mtp_accept_ratio(

return

def _update_mtp_verify_token_num(self, decode_reqs: List[InferReq]):
def _update_mtp_verify_token_num(
self, decode_reqs: List[InferReq], dynamic_mtp_run_reqs: Optional[List[InferReq]] = None
):
if self.is_master_in_dp:
for req in decode_reqs:
# 统计发送给主模型验证的 token 数量:1 个主 token + 当前 mtp_size 个 draft token
# 在静态 MTP 模式下,使用固定的 mtp_step;在动态 MTP 模式下,使用动态调整的 current_mtp_step
# current_mtp_step 在静态 MTP 模式下为 mtp_step,在动态 MTP 模式下会在推理过程中动态设置。
assert req.current_mtp_step >= 0
req.update_mtp_verify_token_num(verify_token_num=1 + req.current_mtp_step)
if dynamic_mtp_run_reqs is None:
for req in decode_reqs:
assert req.mtp_step > 0
req.update_mtp_verify_token_num(verify_token_num=1 + req.mtp_step)
else:
counter = collections.Counter([req.req_idx for req in dynamic_mtp_run_reqs])
for req in decode_reqs:
req.update_mtp_verify_token_num(verify_token_num=1 + counter[req.req_idx] - 1)
return

def _gen_argmax_token_ids(self, model_output: ModelOutput):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,17 +236,27 @@ def decode_mtp(
model_input, run_reqs = prepare_decode_inputs(decode_reqs)

with torch.cuda.stream(g_infer_context.get_overlap_stream()):
b_mtp_index_cpu = model_input.b_mtp_index

if self.enable_dynamic_mtp:
# 根据当前的 batch size 和 dynamic_batch_size 计算出需要裁剪的 batch size 的model_input
dynamic_batch_size = 10 # TODO: 需要根据实际情况计算出 dynamic_batch_size
trans_to_dynamic_model_input = None # TODO: 需要根据实际情况实现 trans_to_dynamic_model_input
model_input, selected_run_reqs = trans_to_dynamic_model_input(model_input, dynamic_batch_size)
# selected_run_reqs is a GPU int tensor: 0 means the entry was not selected, 1 means it was selected.

selected_run_reqs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
key="selected_run_reqs",
gpu_tensor=selected_run_reqs,
)
trans_dynamic_model_input_event = torch.cuda.Event()
trans_dynamic_model_input_event.record()

model_output = self.model.forward(model_input)
next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id)
# verify the next_token_ids
b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0]
b_req_mtp_start_loc = g_pin_mem_manager.gen_from_list(
key="b_req_mtp_start_loc",
data=b_req_mtp_start_loc,
dtype=torch.int32,
).cuda(non_blocking=True)

get_b_req_mtp_start_loc = None # TODO: 需要根据实际情况实现 get_b_req_mtp_start_loc
b_req_mtp_start_loc = get_b_req_mtp_start_loc(model_input.b_mtp_index, req_num=len(decode_reqs))
# b_req_mtp_start_loc 是一个 gpu tensor, 类型为 int, 表示每个请求的 mtp_start_loc, shape 为 len(decode_reqs)
mtp_accept_len, accepted_index = self._verify_mtp_v2(
new_next_token_ids=next_token_ids,
b_req_idx=model_input.b_req_idx,
Expand Down Expand Up @@ -277,22 +287,29 @@ def decode_mtp(

# dynamic_sizes_gpu is copied back to the CPU and consumed after the third stage to record each accepted request's new MTP verify length
if self.enable_dynamic_mtp:
draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0])
draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(
self.mtp_step, model_input.b_mtp_index.shape[0]
)
dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor)
# 异步拷贝回 CPU Pin Memory
dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu
)

dynamic_mtp_event = torch.cuda.Event()
dynamic_mtp_event.record()
dynamic_sizes_cpu  # TODO: use this to update statistics.
draft_probs_list = [e.view(-1, 1) for e in draft_probs_list]
draft_probs_list = [torch.ones_like(draft_probs_list[-1])] + draft_probs_list
all_next_token_probs = torch.cat(draft_probs_list, dim=-1) # [batch_size, mtp_step + 1]
else:
all_next_token_probs = None

mtp_scatter_next_token_ids(
req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
b_req_mtp_start_loc=b_req_mtp_start_loc,
all_next_token_ids=all_next_token_ids,
b_req_idx=model_input.b_req_idx,
mtp_accept_len=mtp_accept_len,
req_to_next_token_probs=self.model.req_manager.req_sampling_params_manager.req_to_next_token_probs,
all_next_token_probs=all_next_token_probs,
)

next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
Expand All @@ -315,22 +332,34 @@ def decode_mtp(

# 第二阶段
event_pack.notify_post_handle_and_wait_pre_post_handle()
self._update_mtp_verify_token_num(decode_reqs=decode_reqs)

if self.enable_dynamic_mtp:
trans_dynamic_model_input_event.synchronize()
selected_run_reqs_cpu_numpy = selected_run_reqs_cpu.numpy()
run_reqs = [run_reqs[i] for i in range(len(run_reqs)) if selected_run_reqs_cpu_numpy[i] == 1]

if self.enable_dynamic_mtp:
self._update_mtp_verify_token_num(decode_reqs=decode_reqs, dynamic_mtp_run_reqs=run_reqs)
else:
self._update_mtp_verify_token_num(decode_reqs=decode_reqs)

verify_event.synchronize()
accepted_index_cpu_numpy = accepted_index_cpu.numpy()
verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1]
if self.enable_dynamic_mtp:
dynamic_mtp_event.synchronize()
self._update_dynamic_mtp_size_cpu_part(
run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu
)
update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False)

# 第三阶段
event_pack.notify_forward_and_wait_post_handle()
sync_event.synchronize()

if self.enable_dynamic_mtp:
self._update_dynamic_mtp_size_cpu_part(
decode_reqs=decode_reqs,
run_reqs=run_reqs,
dynamic_sizes_cpu=dynamic_sizes_cpu,
accepted_index_cpu=accepted_index_cpu,
)

# 处理需要释放的内存索引
need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0]
if additional_mem_indexes_cpu is not None:
Expand Down Expand Up @@ -367,15 +396,19 @@ def _compute_dynamic_mtp_size_gpu_part(

def _update_dynamic_mtp_size_cpu_part(
self,
decode_reqs: List[InferReq],
run_reqs: List[InferReq],
dynamic_sizes_cpu: torch.Tensor,
accepted_index_cpu: torch.Tensor,
):
id_to_current_mtp_step = {}
assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0]
for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()):
if int(accepted) == 1:
req.current_mtp_step = int(new_size)
assert req.current_mtp_step <= req.mtp_step
assert int(new_size) <= req.mtp_step
id_to_current_mtp_step[req.req_idx] = int(new_size)
# TODO 将 id_to_current_mtp_step 的信息更新到 planner 中去
pass

def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor):
# spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。
Expand Down
Loading