diff --git a/.gitignore b/.gitignore index d572eac42..b1717ce67 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ __pycache__/ .pyc +.codex build dist *.egg-info diff --git a/lightllm/common/basemodel/triton_kernel/mtp_utils.py b/lightllm/common/basemodel/triton_kernel/mtp_utils.py index 2d70a68c0..843a77d96 100644 --- a/lightllm/common/basemodel/triton_kernel/mtp_utils.py +++ b/lightllm/common/basemodel/triton_kernel/mtp_utils.py @@ -1,3 +1,4 @@ +from typing import Optional import triton import triton.language as tl import torch @@ -93,10 +94,15 @@ def _fwd_kernel_mtp_scatter_next_token_ids( req_to_next_token_ids_stride, all_next_token_ids, all_next_token_ids_stride, + req_to_next_token_probs, + req_to_next_token_probs_stride, + all_next_token_probs, + all_next_token_probs_stride, mtp_accept_len, b_req_mtp_start_loc, b_req_idx, mtp_step, + HAS_HAS_NEXT_TOKEN_PROBS: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): @@ -106,6 +112,17 @@ def _fwd_kernel_mtp_scatter_next_token_ids( cur_req_idx = tl.load(b_req_idx + req_start_loc) offset = tl.arange(0, BLOCK_SIZE) + if HAS_HAS_NEXT_TOKEN_PROBS: + cur_next_token_probs = tl.load( + all_next_token_probs + (req_start_loc + accept_len - 1) * all_next_token_probs_stride + offset, + mask=offset < mtp_step, + other=0.0, + ) + tl.store( + req_to_next_token_probs + cur_req_idx * req_to_next_token_probs_stride + offset, + cur_next_token_probs, + mask=offset < mtp_step, + ) scatter_next_token_ids = tl.load( all_next_token_ids + (req_start_loc + accept_len - 1) * all_next_token_ids_stride + offset, mask=offset < mtp_step, @@ -125,12 +142,20 @@ def mtp_scatter_next_token_ids( all_next_token_ids: torch.Tensor, b_req_idx: torch.Tensor, mtp_accept_len: torch.Tensor, + req_to_next_token_probs: Optional[torch.Tensor] = None, + all_next_token_probs: Optional[torch.Tensor] = None, ): max_mtp_step = req_to_next_token_ids.shape[1] BLOCK_SIZE = 16 assert max_mtp_step <= BLOCK_SIZE, f"max_mtp_step must be less than {BLOCK_SIZE}" num_reqs = b_req_mtp_start_loc.shape[0] mtp_step = all_next_token_ids.shape[1] + if req_to_next_token_probs is not None: + assert all_next_token_probs is not None + assert all_next_token_probs.shape == all_next_token_ids.shape + + HAS_HAS_NEXT_TOKEN_PROBS = req_to_next_token_probs is not None + grid = (num_reqs,) num_warps = 1 _fwd_kernel_mtp_scatter_next_token_ids[grid]( @@ -138,10 +163,15 @@ def mtp_scatter_next_token_ids( req_to_next_token_ids_stride=req_to_next_token_ids.stride(0), all_next_token_ids=all_next_token_ids, all_next_token_ids_stride=all_next_token_ids.stride(0), + req_to_next_token_probs=req_to_next_token_probs, + req_to_next_token_probs_stride=req_to_next_token_probs.stride(0), + all_next_token_probs=all_next_token_probs, + all_next_token_probs_stride=all_next_token_probs.stride(0), mtp_accept_len=mtp_accept_len, b_req_mtp_start_loc=b_req_mtp_start_loc, b_req_idx=b_req_idx, mtp_step=mtp_step, + HAS_HAS_NEXT_TOKEN_PROBS=HAS_HAS_NEXT_TOKEN_PROBS, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_stages=1, diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 3a4e2b631..278b3509c 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -6,7 +6,7 @@ from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter -from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, enable_dynamic_mtp_verify from lightllm.utils.config_utils import get_vocab_size from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager @@ -116,6 +116,15 @@ def __init__(self, max_request_num): dtype=torch.int64, device="cuda", ) + if enable_dynamic_mtp_verify(): + self.req_to_next_token_probs = torch.zeros( + (max_request_num + 1, 16), + dtype=torch.float32, + device="cuda", + ) + else: + self.req_to_next_token_probs = None + self.req_to_exponential_decay_length_penalty = torch.zeros( max_request_num + 1, dtype=torch.float32, device="cuda" ) @@ -137,6 +146,9 @@ def init_req_sampling_params(self, req): shm_param = req.sampling_param.shm_param self.req_to_next_token_ids[req.req_idx][0:1].fill_(req.get_last_gen_token()) + if enable_dynamic_mtp_verify(): + self.req_to_next_token_probs[req.req_idx].fill_(0.0) + self.req_to_next_token_probs[req.req_idx][0:1].fill_(1.0) self.req_to_presence_penalty[req.req_idx].fill_(shm_param.presence_penalty) self.req_to_frequency_penalty[req.req_idx].fill_(shm_param.frequency_penalty) self.req_to_repetition_penalty[req.req_idx].fill_(shm_param.repetition_penalty) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 46608da13..7b931f153 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -435,10 +435,7 @@ def __init__( # mtp_step 用来记录一个请求 draft模型每步需要生成的token数量 # 正常模式下,这个值为0,在 mtp 模式下,这个值为 draft 模型每步需要生成的token数量 self.mtp_step: int = get_env_start_args().mtp_step - # current_mtp_step 用来记录当前的 MTP 验证长度(<= mtp_step) - # 在启用动态 MTP 验证时,每步会根据 prob 分布重新设置该值 - # 静态模式下为 mtp_step,动态模式下为动态计算的 MTP 验证长度 - self.current_mtp_step: int = self.mtp_step + if self.mtp_step > 0: self.decode_need_token_num = self._mtp_decode_need_token_num else: diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 321055b4b..4203b093f 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -4,6 +4,7 @@ import time import threading import torch.distributed as dist +import collections from typing import List, Tuple, Callable, Optional from transformers.configuration_utils import PretrainedConfig from lightllm.utils.infer_utils import set_random_seed @@ -775,14 +776,18 @@ def _update_mtp_accept_ratio( return - def _update_mtp_verify_token_num(self, decode_reqs: List[InferReq]): + def _update_mtp_verify_token_num( + self, decode_reqs: List[InferReq], dynamic_mtp_run_reqs: Optional[List[InferReq]] = None + ): if self.is_master_in_dp: - for req in decode_reqs: - # 统计发送给主模型验证的 token 数量:1 个主 token + 当前 mtp_size 个 draft token - # 在静态 MTP 模式下,使用固定的 mtp_step;在动态 MTP 模式下,使用动态调整的 current_mtp_step - # current_mtp_step 在静态 MTP 模式下为 mtp_step,在动态 MTP 模式下会在推理过程中动态设置。 - assert req.current_mtp_step >= 0 - req.update_mtp_verify_token_num(verify_token_num=1 + req.current_mtp_step) + if dynamic_mtp_run_reqs is None: + for req in decode_reqs: + assert req.mtp_step > 0 + req.update_mtp_verify_token_num(verify_token_num=1 + req.mtp_step) + else: + counter = collections.Counter([req.req_idx for req in dynamic_mtp_run_reqs]) + for req in decode_reqs: + req.update_mtp_verify_token_num(verify_token_num=1 + counter[req.req_idx] - 1) return def _gen_argmax_token_ids(self, model_output: ModelOutput): diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index c41dbb6d9..392c826b9 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -236,17 +236,27 @@ def decode_mtp( model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): - b_mtp_index_cpu = model_input.b_mtp_index + + if self.enable_dynamic_mtp: + # 根据当前的 batch size 和 dynamic_batch_size 计算出需要裁剪的 batch size 的model_input + dynamic_batch_size = 10 # TODO: 需要根据实际情况计算出 dynamic_batch_size + trans_to_dynamic_model_input = None # TODO: 需要根据实际情况实现 trans_to_dynamic_model_input + model_input, selected_run_reqs = trans_to_dynamic_model_input(model_input, dynamic_batch_size) + # selected_run_reqs 是一个 gpu tensor, 类型为 int, 0, 表示没有选中, 1 表示选中。 + + selected_run_reqs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( + key="selected_run_reqs", + gpu_tensor=selected_run_reqs, + ) + trans_dynamic_model_input_event = torch.cuda.Event() + trans_dynamic_model_input_event.record() + model_output = self.model.forward(model_input) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) # verify the next_token_ids - b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0] - b_req_mtp_start_loc = g_pin_mem_manager.gen_from_list( - key="b_req_mtp_start_loc", - data=b_req_mtp_start_loc, - dtype=torch.int32, - ).cuda(non_blocking=True) - + get_b_req_mtp_start_loc = None # TODO: 需要根据实际情况实现 get_b_req_mtp_start_loc + b_req_mtp_start_loc = get_b_req_mtp_start_loc(model_input.b_mtp_index, req_num=len(decode_reqs)) + # b_req_mtp_start_loc 是一个 gpu tensor, 类型为 int, 表示每个请求的 mtp_start_loc, shape 为 len(decode_reqs) mtp_accept_len, accepted_index = self._verify_mtp_v2( new_next_token_ids=next_token_ids, b_req_idx=model_input.b_req_idx, @@ -277,15 +287,20 @@ def decode_mtp( # dynamic_sizes_gpu 用于第二阶段更新 req 的 mtp_size if self.enable_dynamic_mtp: - draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0]) + draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view( + self.mtp_step, model_input.b_mtp_index.shape[0] + ) dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor) # 异步拷贝回 CPU Pin Memory dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu ) - - dynamic_mtp_event = torch.cuda.Event() - dynamic_mtp_event.record() + dynamic_sizes_cpu # TODO, use to update statcis. + draft_probs_list = [e.view(-1, 1) for e in draft_probs_list] + draft_probs_list = [torch.ones_like(draft_probs_list[-1])] + draft_probs_list + all_next_token_probs = torch.cat(draft_probs_list, dim=-1) # [batch_size, mtp_step + 1] + else: + all_next_token_probs = None mtp_scatter_next_token_ids( req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, @@ -293,6 +308,8 @@ def decode_mtp( all_next_token_ids=all_next_token_ids, b_req_idx=model_input.b_req_idx, mtp_accept_len=mtp_accept_len, + req_to_next_token_probs=self.model.req_manager.req_sampling_params_manager.req_to_next_token_probs, + all_next_token_probs=all_next_token_probs, ) next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem( @@ -315,22 +332,34 @@ def decode_mtp( # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() - self._update_mtp_verify_token_num(decode_reqs=decode_reqs) + + if self.enable_dynamic_mtp: + trans_dynamic_model_input_event.synchronize() + selected_run_reqs_cpu_numpy = selected_run_reqs_cpu.numpy() + run_reqs = [run_reqs[i] for i in range(len(run_reqs)) if selected_run_reqs_cpu_numpy[i] == 1] + + if self.enable_dynamic_mtp: + self._update_mtp_verify_token_num(decode_reqs=decode_reqs, dynamic_mtp_run_reqs=run_reqs) + else: + self._update_mtp_verify_token_num(decode_reqs=decode_reqs) verify_event.synchronize() accepted_index_cpu_numpy = accepted_index_cpu.numpy() verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1] - if self.enable_dynamic_mtp: - dynamic_mtp_event.synchronize() - self._update_dynamic_mtp_size_cpu_part( - run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu - ) update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) # 第三阶段 event_pack.notify_forward_and_wait_post_handle() sync_event.synchronize() + if self.enable_dynamic_mtp: + self._update_dynamic_mtp_size_cpu_part( + decode_reqs=decode_reqs, + run_reqs=run_reqs, + dynamic_sizes_cpu=dynamic_sizes_cpu, + accepted_index_cpu=accepted_index_cpu, + ) + # 处理需要释放的内存索引 need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] if additional_mem_indexes_cpu is not None: @@ -367,15 +396,19 @@ def _compute_dynamic_mtp_size_gpu_part( def _update_dynamic_mtp_size_cpu_part( self, + decode_reqs: List[InferReq], run_reqs: List[InferReq], dynamic_sizes_cpu: torch.Tensor, accepted_index_cpu: torch.Tensor, ): + id_to_current_mtp_step = {} assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0] for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()): if int(accepted) == 1: - req.current_mtp_step = int(new_size) - assert req.current_mtp_step <= req.mtp_step + assert int(new_size) <= req.mtp_step + id_to_current_mtp_step[req.req_idx] = int(new_size) + # TODO 将 id_to_current_mtp_step 的信息更新到 planner 中去 + pass def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 diff --git a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py new file mode 100644 index 000000000..b09a711bd --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py @@ -0,0 +1,275 @@ +import copy +import math +import random +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch + +from lightllm.common.basemodel.batch_objs import ModelInput +from lightllm.common.basemodel.infer_lock import g_infer_state_lock +from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context +from lightllm.utils.envs_utils import get_diverse_max_batch_shared_group_size + + +@dataclass +class DynamicMTPPlan: + req_num: int + original_batch_size: int + dynamic_batch_size: int + keep_indices: torch.Tensor + per_req_rows: List[int] + estimated_accept_mean: float + estimated_accept_std: float + + +class _EMAValue: + def __init__(self, decay: float, init_value: Optional[float] = None) -> None: + self.decay = decay + self.value = init_value + self.initialized = init_value is not None + + def update(self, new_value: float) -> float: + if not self.initialized: + self.value = new_value + self.initialized = True + else: + self.value = self.decay * self.value + (1.0 - self.decay) * new_value + return self.value + + def get(self, fallback: float) -> float: + return self.value if self.initialized else fallback + + +class DynamicMTPPlanner: + def __init__( + self, + mtp_step: int, + ema_decay: float = 0.9, + confidence_k: float = 1.0, + ) -> None: + self.mtp_step = mtp_step + self.max_rows_per_req = mtp_step + 1 + self.confidence_k = confidence_k + self.accept_mean = _EMAValue(ema_decay, init_value=float(self.max_rows_per_req)) + self.accept_second_moment = _EMAValue(ema_decay, init_value=float(self.max_rows_per_req**2)) + self.req_accept_mean: Dict[int, _EMAValue] = {} + self.req_prob: Dict[int, _EMAValue] = {} + self.latency_ms_by_batch_size: Dict[int, _EMAValue] = {} + self.accepted_token_speed = _EMAValue(ema_decay) + self.verify_row_speed = _EMAValue(ema_decay) + self.actual_speedup = _EMAValue(ema_decay) + self.single_token_speed_by_req_num: Dict[int, _EMAValue] = {} + self.last_plan: Optional[DynamicMTPPlan] = None + + def trim_before_forward( + self, + model_input: ModelInput, + run_reqs: List[InferReq], + decode_reqs: List[InferReq], + ): + plan = self._build_plan(model_input=model_input, decode_reqs=decode_reqs) + self.last_plan = plan + if plan.dynamic_batch_size == plan.original_batch_size: + return model_input, run_reqs, plan + + pruned_indices = self._invert_indices(plan.keep_indices, plan.original_batch_size) + if pruned_indices.numel() > 0: + pruned_mem_indexes = model_input.mem_indexes_cpu[pruned_indices] + g_infer_state_lock.acquire() + g_infer_context.req_manager.mem_manager.free(pruned_mem_indexes) + g_infer_state_lock.release() + + trimmed_input = copy.copy(model_input) + keep_indices = plan.keep_indices + keep_list = keep_indices.tolist() + + trimmed_input.batch_size = plan.dynamic_batch_size + trimmed_input.b_req_idx = model_input.b_req_idx[keep_indices].contiguous() + trimmed_input.b_mtp_index = model_input.b_mtp_index[keep_indices].contiguous() + trimmed_input.b_seq_len = model_input.b_seq_len[keep_indices].contiguous() + trimmed_input.mem_indexes_cpu = model_input.mem_indexes_cpu[keep_indices].contiguous() + trimmed_input.mem_indexes = None + trimmed_input.total_token_num = int(trimmed_input.b_seq_len.sum().item()) + trimmed_input.max_kv_seq_len = int(trimmed_input.b_seq_len.max().item()) + trimmed_input.multimodal_params = [model_input.multimodal_params[index] for index in keep_list] + trimmed_run_reqs = [run_reqs[index] for index in keep_list] + trimmed_input.b_mark_shared_group = self._build_mtp_shared_group_infos(trimmed_run_reqs) + trimmed_input.check_input() + return trimmed_input, trimmed_run_reqs, plan + + def update_after_verify( + self, + plan: DynamicMTPPlan, + decode_reqs: List[InferReq], + mtp_accept_len_cpu: torch.Tensor, + elapsed_ms: float, + per_req_probs_cpu: Optional[torch.Tensor] = None, + ) -> None: + if plan is None: + return + + accept_lens = [int(value) for value in mtp_accept_len_cpu.numpy()] + if not accept_lens: + return + + batch_mean = sum(accept_lens) / len(accept_lens) + batch_second = sum(value * value for value in accept_lens) / len(accept_lens) + self.accept_mean.update(batch_mean) + self.accept_second_moment.update(batch_second) + + for req, accept_len in zip(decode_reqs, accept_lens): + req_ema = self.req_accept_mean.get(req.req_id) + if req_ema is None: + req_ema = _EMAValue(self.accept_mean.decay, init_value=batch_mean) + self.req_accept_mean[req.req_id] = req_ema + req_ema.update(float(accept_len)) + + if per_req_probs_cpu is not None: + for req, req_prob in zip(decode_reqs, per_req_probs_cpu.numpy()): + req_prob_ema = self.req_prob.get(req.req_id) + if req_prob_ema is None: + req_prob_ema = _EMAValue(self.accept_mean.decay) + self.req_prob[req.req_id] = req_prob_ema + req_prob_ema.update(self._clip_prob(float(req_prob))) + + if elapsed_ms > 0: + current_output_speed = sum(accept_lens) / elapsed_ms + latency_ema = self.latency_ms_by_batch_size.get(plan.dynamic_batch_size) + if latency_ema is None: + latency_ema = _EMAValue(self.accept_mean.decay) + self.latency_ms_by_batch_size[plan.dynamic_batch_size] = latency_ema + latency_ema.update(elapsed_ms) + self.accepted_token_speed.update(current_output_speed) + self.verify_row_speed.update(plan.dynamic_batch_size / elapsed_ms) + if plan.dynamic_batch_size == plan.req_num: + single_token_speed = self.single_token_speed_by_req_num.get(plan.req_num) + if single_token_speed is None: + single_token_speed = _EMAValue(self.accept_mean.decay) + self.single_token_speed_by_req_num[plan.req_num] = single_token_speed + single_token_speed.update(plan.req_num / elapsed_ms) + baseline_speed = self.single_token_speed_by_req_num.get(plan.req_num) + if baseline_speed is not None and baseline_speed.get(0.0) > 0: + self.actual_speedup.update(current_output_speed / baseline_speed.get(0.0)) + + def get_stats_snapshot(self) -> Dict[str, float]: + return { + "accept_mean": self.accept_mean.get(float(self.max_rows_per_req)), + "accept_second_moment": self.accept_second_moment.get(float(self.max_rows_per_req**2)), + "accepted_token_speed": self.accepted_token_speed.get(0.0), + "verify_row_speed": self.verify_row_speed.get(0.0), + "actual_speedup": self.actual_speedup.get(0.0), + } + + def _build_plan(self, model_input: ModelInput, decode_reqs: List[InferReq]) -> DynamicMTPPlan: + req_num = len(decode_reqs) + original_batch_size = model_input.batch_size + if req_num == 0 or self.mtp_step == 0: + keep_indices = torch.arange(original_batch_size, dtype=torch.long, device="cpu") + return DynamicMTPPlan(req_num, original_batch_size, original_batch_size, keep_indices, [], 1.0, 0.0) + + mean = self.accept_mean.get(float(self.max_rows_per_req)) + second = self.accept_second_moment.get(float(self.max_rows_per_req**2)) + variance = max(0.0, second - mean * mean) + std = math.sqrt(variance) + budget = math.ceil(req_num * mean + self.confidence_k * math.sqrt(req_num) * std) + dynamic_batch_size = max(req_num, min(original_batch_size, budget)) + + per_req_rows = self._allocate_rows(decode_reqs=decode_reqs, dynamic_batch_size=dynamic_batch_size) + keep_indices, per_req_rows = self._build_keep_indices(model_input=model_input, per_req_rows=per_req_rows) + dynamic_batch_size = int(keep_indices.numel()) + + return DynamicMTPPlan( + req_num=req_num, + original_batch_size=original_batch_size, + dynamic_batch_size=dynamic_batch_size, + keep_indices=keep_indices, + per_req_rows=per_req_rows, + estimated_accept_mean=mean, + estimated_accept_std=std, + ) + + def _allocate_rows(self, decode_reqs: List[InferReq], dynamic_batch_size: int) -> List[int]: + req_num = len(decode_reqs) + per_req_rows = [1 for _ in range(req_num)] + remaining = dynamic_batch_size - req_num + if remaining <= 0: + return per_req_rows + + req_order = sorted( + range(req_num), + key=lambda index: self._req_prob(decode_reqs[index]), + reverse=True, + ) + + for req_index in req_order: + req_prob = self._req_prob(decode_reqs[req_index]) + for _ in range(self.mtp_step): + if remaining <= 0: + break + if random.random() >= req_prob: + break + per_req_rows[req_index] += 1 + remaining -= 1 + if remaining <= 0: + break + return per_req_rows + + def _req_prob(self, req: InferReq) -> float: + req_prob = self.req_prob.get(req.req_id) + if req_prob is not None: + return self._clip_prob(req_prob.get(1.0)) + fallback = self.accept_mean.get(float(self.max_rows_per_req)) / float(self.max_rows_per_req) + return self._clip_prob(fallback) + + def _clip_prob(self, value: float) -> float: + return min(1.0, max(0.0, value)) + + def _build_keep_indices(self, model_input: ModelInput, per_req_rows: List[int]): + keep_indices = [] + effective_per_req_rows = [0 for _ in per_req_rows] + req_index = -1 + cur_req_kept = 0 + cur_req_target = 0 + for index, mtp_index in enumerate(model_input.b_mtp_index.tolist()): + if mtp_index == 0: + req_index += 1 + cur_req_kept = 0 + cur_req_target = per_req_rows[req_index] + if cur_req_kept < cur_req_target: + keep_indices.append(index) + cur_req_kept += 1 + effective_per_req_rows[req_index] += 1 + return torch.tensor(keep_indices, dtype=torch.long, device="cpu"), effective_per_req_rows + + def _invert_indices(self, keep_indices: torch.Tensor, total_size: int) -> torch.Tensor: + keep_mask = torch.zeros((total_size,), dtype=torch.bool, device="cpu") + keep_mask[keep_indices] = True + return torch.nonzero(~keep_mask, as_tuple=False).view(-1) + + def _build_mtp_shared_group_infos(self, run_reqs: List[InferReq]) -> torch.Tensor: + max_batch_shared_group_size = get_diverse_max_batch_shared_group_size() + req_ids = [req.req_id for req in run_reqs] + b_mark_shared_group = [] + current_group = [] + for req_id in req_ids: + if not current_group: + current_group.append(req_id) + elif req_id == current_group[-1]: + current_group.append(req_id) + else: + b_mark_shared_group.extend([0 for _ in range(len(current_group))]) + b_mark_shared_group[-1] = len(current_group) + current_group.clear() + current_group.append(req_id) + + if len(current_group) == max_batch_shared_group_size: + b_mark_shared_group.extend([0 for _ in range(len(current_group))]) + b_mark_shared_group[-1] = len(current_group) + current_group.clear() + + if current_group: + b_mark_shared_group.extend([0 for _ in range(len(current_group))]) + b_mark_shared_group[-1] = len(current_group) + + return torch.tensor(b_mark_shared_group, dtype=torch.int32, device="cpu") diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 3d9d8815e..c50e0b7c3 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -113,9 +113,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_mtp_index.append(0) multimodal_params.append(req.multimodal_params) # process the draft tokens. - # 动态 MTP 模式:使用动态 current_mtp_step 构建 batch - # 非动态 MTP 模式:current_mtp_step 为固定的 mtp_step - for step in range(req.current_mtp_step): + for step in range(req.mtp_step): run_reqs.append(req) b_req_idx.append(req.req_idx) seq_len += 1 diff --git a/test/benchmark/service/benchmark_sharegpt.py b/test/benchmark/service/benchmark_sharegpt.py index b056e69bd..9337e6527 100644 --- a/test/benchmark/service/benchmark_sharegpt.py +++ b/test/benchmark/service/benchmark_sharegpt.py @@ -215,7 +215,7 @@ async def send_request( "top_k": 1, "top_p": 1.0, "temperature": 0, - "stream": True, + # "stream": True, "ignore_eos": True, "max_tokens": output_len, } @@ -224,20 +224,41 @@ async def send_request( async with aiohttp.ClientSession(timeout=timeout) as session: async with session.post(url, headers=headers, json=data) as response: + response.raise_for_status() chunks = [] text = "" start_time = time.time() is_first = True + sse_buffer = "" async for chunk, _ in response.content.iter_chunks(): now_time = time.time() delta_time = now_time - start_time if is_first: is_first = False ttft = delta_time - text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "") - if delta_time < 0.005: - receive_n += 1 chunks.append(delta_time) + # OpenAI-compatible stream is SSE; one TCP chunk may contain + # partial/multiple events. Parse by complete lines safely. + sse_buffer += chunk.decode("utf-8", errors="ignore") + while "\n" in sse_buffer: + line, sse_buffer = sse_buffer.split("\n", 1) + line = line.strip() + if not line or not line.startswith("data:"): + continue + payload = line[5:].strip() + if payload == "[DONE]": + break + if not payload: + continue + try: + event = json.loads(payload) + except json.JSONDecodeError: + # In rare cases malformed/partial payload slips in; + # skip and continue to keep benchmark running. + continue + text += event.get("choices", [{}])[0].get("delta", {}).get("content", "") + if delta_time < 0.005: + receive_n += 1 start_time = now_time # print("messages", messages) # print("text", text) diff --git a/test/speculative/bench_throughput.sh b/test/speculative/bench_throughput.sh index 8e14f8189..4cfa90bcb 100644 --- a/test/speculative/bench_throughput.sh +++ b/test/speculative/bench_throughput.sh @@ -2,7 +2,7 @@ # 默认值 PORT=8088 NUM_PROMPTS=1000 -TOKENIZER="/mtc/models/qwen3-8b" +TOKENIZER="/mtc/models/qwen3-32b" DATASET="/data/nvme0/chenjunyi/project/lightllm/datasets/gsm8k.json" HISTORY_TURNS=1 CONCURRENCY=128 diff --git a/test/speculative/qwen3-32b/dynamic_triton.sh b/test/speculative/qwen3-32b/dynamic_triton.sh index 39145e5f5..ca70bf15e 100644 --- a/test/speculative/qwen3-32b/dynamic_triton.sh +++ b/test/speculative/qwen3-32b/dynamic_triton.sh @@ -16,8 +16,10 @@ done MODEL_DIR=/mtc/models/qwen3-32b DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 -LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ ---tp 4 --max_total_token_num 200000 \ +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ --model_dir ${MODEL_DIR} \ --mtp_mode eagle3 \ --disable_dynamic_prompt_cache \ diff --git a/test/speculative/qwen3-32b/no_mtp_fa3.sh b/test/speculative/qwen3-32b/no_mtp_fa3.sh new file mode 100644 index 000000000..c17562721 --- /dev/null +++ b/test/speculative/qwen3-32b/no_mtp_fa3.sh @@ -0,0 +1,11 @@ +MODEL_DIR=/mtc/models/qwen3-32b +DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 + +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ +--model_dir ${MODEL_DIR} \ +--disable_dynamic_prompt_cache \ +--graph_grow_step_size 1 \ +--llm_decode_att_backend triton \ No newline at end of file diff --git a/test/speculative/qwen3-32b/static_fa3.sh b/test/speculative/qwen3-32b/static_fa3.sh index c9712116e..44c67e03b 100644 --- a/test/speculative/qwen3-32b/static_fa3.sh +++ b/test/speculative/qwen3-32b/static_fa3.sh @@ -16,8 +16,10 @@ done MODEL_DIR=/mtc/models/qwen3-32b DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 -LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ ---tp 4 --max_total_token_num 200000 \ +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ --model_dir ${MODEL_DIR} \ --mtp_mode eagle3 \ --mtp_draft_model_dir ${DRAFT_MODEL_DIR} \ diff --git a/test/speculative/qwen3-32b/static_triton.sh b/test/speculative/qwen3-32b/static_triton.sh index 453c5678e..71964c9af 100644 --- a/test/speculative/qwen3-32b/static_triton.sh +++ b/test/speculative/qwen3-32b/static_triton.sh @@ -16,8 +16,10 @@ done MODEL_DIR=/mtc/models/qwen3-32b DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 -LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ ---tp 4 --max_total_token_num 200000 \ +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ --model_dir ${MODEL_DIR} \ --mtp_mode eagle3 \ --disable_dynamic_prompt_cache \ diff --git a/test/speculative/run_vllm_speculative_baseline.sh b/test/speculative/run_vllm_speculative_baseline.sh new file mode 100755 index 000000000..f2027f20c --- /dev/null +++ b/test/speculative/run_vllm_speculative_baseline.sh @@ -0,0 +1,298 @@ +#!/bin/bash + +# ============================================================================= +# vLLM Speculative Decoding Baseline Experiment Script +# Function: Run vLLM default draft-model speculative decoding baseline for +# different mtp steps (mapped to num_speculative_tokens), and collect +# throughput/latency metrics with the same benchmark script. +# ============================================================================= + +set -euo pipefail + +# Keep default GPU visibility aligned with existing LightLLM experiment scripts. +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-1,2,3,4,6}" +# Reduce allocator fragmentation risk during model warmup. +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +# ============================================================================= +# Configurable Parameters +# ============================================================================= +PROJECT_DIR="/data/nvme0/chenjunyi/project/lightllm" +BENCH_PY_SCRIPT="${PROJECT_DIR}/test/benchmark/service/benchmark_sharegpt.py" +DATASET="${PROJECT_DIR}/datasets/gsm8k.json" + +# Keep defaults close to existing LightLLM qwen3-32b setup. +MODEL_DIR="/mtc/models/qwen3-32b" +DRAFT_MODEL_DIR="/mtc/models/qwen3-32b-eagle3" +TOKENIZER="/mtc/models/qwen3-32b" + +SAMPLES=1000 +CONCURRENCY=256 +PORT=8088 +TP=4 +MAX_MODEL_LEN=16384 +MAX_NUM_BATCHED_TOKENS=200000 +MAX_NUM_SEQS=256 +GPU_MEMORY_UTILIZATION=0.6 +MAX_CUDAGRAPH_CAPTURE_SIZE=256 +ATTENTION_BACKEND="FLASH_ATTN" +DISABLE_CUSTOM_ALL_REDUCE=1 +MTP_STEPS=(5) + +RESULTS_DIR="${PROJECT_DIR}/experiment_results" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +DATASET_NAME=$(basename "${DATASET}" .json) +EXPERIMENT_SUBDIR="${RESULTS_DIR}/${DATASET_NAME}_${TIMESTAMP}_vllm_spec_default" +RESULTS_FILE="${EXPERIMENT_SUBDIR}/results.csv" + +usage() { + echo "Usage: $0 [options]" + echo "" + echo "Options:" + echo " --model-dir PATH Main model path (default: ${MODEL_DIR})" + echo " --draft-model-dir PATH Draft model path (default: ${DRAFT_MODEL_DIR})" + echo " --dataset PATH Dataset path (default: ${DATASET})" + echo " --tokenizer PATH Tokenizer path (default: ${TOKENIZER})" + echo " --samples NUM Number of prompts (default: ${SAMPLES})" + echo " --concurrency NUM Concurrency (default: ${CONCURRENCY})" + echo " --port PORT Service port (default: ${PORT})" + echo " --tp NUM Tensor parallel size (default: ${TP})" + echo " --mtp-steps LIST Comma-separated mtp steps (default: 5)" + echo " --num-speculative-tokens NUM Backward-compatible alias, equals one mtp step" + echo " --max-model-len NUM vLLM max model len (default: ${MAX_MODEL_LEN})" + echo " --max-num-batched-tokens NUM vLLM max batched tokens (default: ${MAX_NUM_BATCHED_TOKENS})" + echo " --max-num-seqs NUM vLLM max number of concurrent seqs (default: ${MAX_NUM_SEQS})" + echo " --max-cudagraph-capture-size NUM vLLM max cudagraph capture size (default: ${MAX_CUDAGRAPH_CAPTURE_SIZE})" + echo " --gpu-memory-utilization F GPU memory utilization (default: ${GPU_MEMORY_UTILIZATION})" + echo " --attention-backend NAME vLLM attention backend (default: ${ATTENTION_BACKEND})" + echo " --enable-custom-all-reduce Enable custom all-reduce (default: disabled)" + echo " --results-dir DIR Results base dir (default: ${RESULTS_DIR})" + echo " --help Show this help" + exit 1 +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --model-dir) + MODEL_DIR="$2" + shift 2 + ;; + --draft-model-dir) + DRAFT_MODEL_DIR="$2" + shift 2 + ;; + --dataset) + DATASET="$2" + shift 2 + ;; + --tokenizer) + TOKENIZER="$2" + shift 2 + ;; + --samples) + SAMPLES="$2" + shift 2 + ;; + --concurrency) + CONCURRENCY="$2" + shift 2 + ;; + --port) + PORT="$2" + shift 2 + ;; + --tp) + TP="$2" + shift 2 + ;; + --mtp-steps) + IFS=',' read -ra MTP_STEPS <<< "$2" + shift 2 + ;; + --num-speculative-tokens) + MTP_STEPS=("$2") + shift 2 + ;; + --max-model-len) + MAX_MODEL_LEN="$2" + shift 2 + ;; + --max-num-batched-tokens) + MAX_NUM_BATCHED_TOKENS="$2" + shift 2 + ;; + --max-num-seqs) + MAX_NUM_SEQS="$2" + shift 2 + ;; + --max-cudagraph-capture-size) + MAX_CUDAGRAPH_CAPTURE_SIZE="$2" + shift 2 + ;; + --gpu-memory-utilization) + GPU_MEMORY_UTILIZATION="$2" + shift 2 + ;; + --attention-backend) + ATTENTION_BACKEND="$2" + shift 2 + ;; + --enable-custom-all-reduce) + DISABLE_CUSTOM_ALL_REDUCE=0 + shift 1 + ;; + --results-dir) + RESULTS_DIR="$2" + shift 2 + ;; + --help) + usage + ;; + *) + echo "Unknown argument: $1" + usage + ;; + esac +done + +# Recompute result paths in case dataset/results-dir was overridden. +DATASET_NAME=$(basename "${DATASET}" .json) +EXPERIMENT_SUBDIR="${RESULTS_DIR}/${DATASET_NAME}_${TIMESTAMP}_vllm_spec_default" +RESULTS_FILE="${EXPERIMENT_SUBDIR}/results.csv" + +mkdir -p "${EXPERIMENT_SUBDIR}" + +echo "timestamp,engine,mode,mtp_step,dataset,samples,concurrency,throughput,avg_latency,avg_ttft,avg_inter_token_latency" > "${RESULTS_FILE}" + +wait_for_server() { + local max_attempts=600 + local attempt=0 + echo "Waiting for vLLM server to start..." + while [[ ${attempt} -lt ${max_attempts} ]]; do + if curl -s "http://localhost:${PORT}/health" > /dev/null 2>&1; then + echo "vLLM server started" + return 0 + fi + sleep 2 + attempt=$((attempt + 1)) + done + echo "vLLM server startup timeout" + return 1 +} + +extract_benchmark_metrics() { + local log_file="$1" + local throughput="" + local avg_latency="" + local avg_ttft="" + local avg_inter_token_latency="" + + throughput=$(grep -oP 'Throughput: \K[\d.]+' "$log_file" | tail -1) + avg_latency=$(grep -oP 'Average latency: \K[\d.]+' "$log_file" | tail -1) + avg_ttft=$(grep -oP 'Average time to first token: \K[\d.]+' "$log_file" | tail -1) + avg_inter_token_latency=$(grep -oP 'Average inter-token latency: \K[\d.]+' "$log_file" | tail -1) + + echo "${throughput:-NA},${avg_latency:-NA},${avg_ttft:-NA},${avg_inter_token_latency:-NA}" +} + +kill_vllm() { + echo "Stopping vLLM server..." + pkill -9 -f "vllm serve" 2>/dev/null || true + pkill -9 -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true + sleep 1 + echo "vLLM server stopped" +} + +trap 'kill_vllm' EXIT + +echo "==============================================" +echo "vLLM Speculative Baseline Started" +echo "==============================================" +echo "Model: ${MODEL_DIR}" +echo "Draft model: ${DRAFT_MODEL_DIR}" +echo "Tokenizer: ${TOKENIZER}" +echo "Dataset: ${DATASET}" +echo "Samples: ${SAMPLES}" +echo "Concurrency: ${CONCURRENCY}" +echo "TP: ${TP}" +echo "Port: ${PORT}" +echo "Max model len: ${MAX_MODEL_LEN}" +echo "Max batched tokens: ${MAX_NUM_BATCHED_TOKENS}" +echo "Max num seqs: ${MAX_NUM_SEQS}" +echo "Max cudagraph capture size: ${MAX_CUDAGRAPH_CAPTURE_SIZE}" +echo "GPU memory utilization: ${GPU_MEMORY_UTILIZATION}" +echo "Attention backend: ${ATTENTION_BACKEND}" +echo "Disable custom all reduce: ${DISABLE_CUSTOM_ALL_REDUCE}" +echo "MTP steps: ${MTP_STEPS[*]}" +echo "Results directory: ${EXPERIMENT_SUBDIR}" +echo "==============================================" + +for MTP_STEP in "${MTP_STEPS[@]}"; do + echo "" + echo "--- Running mtp step: ${MTP_STEP} ---" + + LOG_FILE="${EXPERIMENT_SUBDIR}/log_vllm_spec_default_step${MTP_STEP}_${TIMESTAMP}.txt" + BENCH_LOG="${EXPERIMENT_SUBDIR}/bench_vllm_spec_default_step${MTP_STEP}_${TIMESTAMP}.txt" + + SPECULATIVE_CONFIG=$(printf '{"model": "%s", "num_speculative_tokens": %s, "method": "draft_model"}' \ + "${DRAFT_MODEL_DIR}" "${MTP_STEP}") + CUSTOM_ALL_REDUCE_FLAG="" + if [[ "${DISABLE_CUSTOM_ALL_REDUCE}" == "1" ]]; then + CUSTOM_ALL_REDUCE_FLAG="--disable-custom-all-reduce" + fi + + kill_vllm + + echo "Starting vLLM server with speculative_config=${SPECULATIVE_CONFIG}" + ( + vllm serve "${MODEL_DIR}" \ + --host 0.0.0.0 \ + --port "${PORT}" \ + --served-model-name DeepSeek-R1 \ + -tp "${TP}" \ + --max_model_len "${MAX_MODEL_LEN}" \ + --max_num_batched_tokens "${MAX_NUM_BATCHED_TOKENS}" \ + --max_num_seqs "${MAX_NUM_SEQS}" \ + --max-cudagraph-capture-size "${MAX_CUDAGRAPH_CAPTURE_SIZE}" \ + --attention-backend "${ATTENTION_BACKEND}" \ + ${CUSTOM_ALL_REDUCE_FLAG} \ + --speculative_config "${SPECULATIVE_CONFIG}" + ) > "${LOG_FILE}" 2>&1 & + + SERVER_PID=$! + echo "vLLM PID: ${SERVER_PID}" + + if ! wait_for_server; then + echo "vLLM server failed to start for mtp step ${MTP_STEP}. Check log: ${LOG_FILE}" + RESULT_LINE="${TIMESTAMP},vllm,speculative_draft_model_default,${MTP_STEP},${DATASET},${SAMPLES},${CONCURRENCY},NA,NA,NA,NA" + echo "${RESULT_LINE}" >> "${RESULTS_FILE}" + continue + fi + + sleep 5 + + echo "Running benchmark with benchmark_sharegpt.py (OpenAI API mode)..." + python "${BENCH_PY_SCRIPT}" \ + --use_openai_api \ + --port "${PORT}" \ + --num-prompts "${SAMPLES}" \ + --tokenizer "${TOKENIZER}" \ + --dataset "${DATASET}" \ + --history-turns 1 \ + --concurrency "${CONCURRENCY}" 2>&1 | tee "${BENCH_LOG}" + + cat "${BENCH_LOG}" >> "${LOG_FILE}" + + BENCH_METRICS=$(extract_benchmark_metrics "${LOG_FILE}") + RESULT_LINE="${TIMESTAMP},vllm,speculative_draft_model_default,${MTP_STEP},${DATASET},${SAMPLES},${CONCURRENCY},${BENCH_METRICS}" + echo "${RESULT_LINE}" >> "${RESULTS_FILE}" + + echo "Completed mtp step ${MTP_STEP}: ${RESULT_LINE}" +done + +echo "" +echo "==============================================" +echo "All Experiments Completed" +echo "==============================================" +echo "Results file: ${RESULTS_FILE}" +cat "${RESULTS_FILE}"