From 74189c27672cd850d4b6370efcc2e8611e18da11 Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Mon, 20 Apr 2026 20:23:52 +0800 Subject: [PATCH 01/12] add vllm test script --- .gitignore | 1 + test/benchmark/service/benchmark_sharegpt.py | 29 +- test/speculative/bench_throughput.sh | 2 +- .../run_vllm_speculative_baseline.sh | 298 ++++++++++++++++++ 4 files changed, 325 insertions(+), 5 deletions(-) create mode 100755 test/speculative/run_vllm_speculative_baseline.sh diff --git a/.gitignore b/.gitignore index d572eac42..b1717ce67 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ __pycache__/ .pyc +.codex build dist *.egg-info diff --git a/test/benchmark/service/benchmark_sharegpt.py b/test/benchmark/service/benchmark_sharegpt.py index b056e69bd..9337e6527 100644 --- a/test/benchmark/service/benchmark_sharegpt.py +++ b/test/benchmark/service/benchmark_sharegpt.py @@ -215,7 +215,7 @@ async def send_request( "top_k": 1, "top_p": 1.0, "temperature": 0, - "stream": True, + # "stream": True, "ignore_eos": True, "max_tokens": output_len, } @@ -224,20 +224,41 @@ async def send_request( async with aiohttp.ClientSession(timeout=timeout) as session: async with session.post(url, headers=headers, json=data) as response: + response.raise_for_status() chunks = [] text = "" start_time = time.time() is_first = True + sse_buffer = "" async for chunk, _ in response.content.iter_chunks(): now_time = time.time() delta_time = now_time - start_time if is_first: is_first = False ttft = delta_time - text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "") - if delta_time < 0.005: - receive_n += 1 chunks.append(delta_time) + # OpenAI-compatible stream is SSE; one TCP chunk may contain + # partial/multiple events. Parse by complete lines safely. + sse_buffer += chunk.decode("utf-8", errors="ignore") + while "\n" in sse_buffer: + line, sse_buffer = sse_buffer.split("\n", 1) + line = line.strip() + if not line or not line.startswith("data:"): + continue + payload = line[5:].strip() + if payload == "[DONE]": + break + if not payload: + continue + try: + event = json.loads(payload) + except json.JSONDecodeError: + # In rare cases malformed/partial payload slips in; + # skip and continue to keep benchmark running. + continue + text += event.get("choices", [{}])[0].get("delta", {}).get("content", "") + if delta_time < 0.005: + receive_n += 1 start_time = now_time # print("messages", messages) # print("text", text) diff --git a/test/speculative/bench_throughput.sh b/test/speculative/bench_throughput.sh index 8e14f8189..4cfa90bcb 100644 --- a/test/speculative/bench_throughput.sh +++ b/test/speculative/bench_throughput.sh @@ -2,7 +2,7 @@ # 默认值 PORT=8088 NUM_PROMPTS=1000 -TOKENIZER="/mtc/models/qwen3-8b" +TOKENIZER="/mtc/models/qwen3-32b" DATASET="/data/nvme0/chenjunyi/project/lightllm/datasets/gsm8k.json" HISTORY_TURNS=1 CONCURRENCY=128 diff --git a/test/speculative/run_vllm_speculative_baseline.sh b/test/speculative/run_vllm_speculative_baseline.sh new file mode 100755 index 000000000..f2027f20c --- /dev/null +++ b/test/speculative/run_vllm_speculative_baseline.sh @@ -0,0 +1,298 @@ +#!/bin/bash + +# ============================================================================= +# vLLM Speculative Decoding Baseline Experiment Script +# Function: Run vLLM default draft-model speculative decoding baseline for +# different mtp steps (mapped to num_speculative_tokens), and collect +# throughput/latency metrics with the same benchmark script. +# ============================================================================= + +set -euo pipefail + +# Keep default GPU visibility aligned with existing LightLLM experiment scripts. +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-1,2,3,4,6}" +# Reduce allocator fragmentation risk during model warmup. +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +# ============================================================================= +# Configurable Parameters +# ============================================================================= +PROJECT_DIR="/data/nvme0/chenjunyi/project/lightllm" +BENCH_PY_SCRIPT="${PROJECT_DIR}/test/benchmark/service/benchmark_sharegpt.py" +DATASET="${PROJECT_DIR}/datasets/gsm8k.json" + +# Keep defaults close to existing LightLLM qwen3-32b setup. +MODEL_DIR="/mtc/models/qwen3-32b" +DRAFT_MODEL_DIR="/mtc/models/qwen3-32b-eagle3" +TOKENIZER="/mtc/models/qwen3-32b" + +SAMPLES=1000 +CONCURRENCY=256 +PORT=8088 +TP=4 +MAX_MODEL_LEN=16384 +MAX_NUM_BATCHED_TOKENS=200000 +MAX_NUM_SEQS=256 +GPU_MEMORY_UTILIZATION=0.6 +MAX_CUDAGRAPH_CAPTURE_SIZE=256 +ATTENTION_BACKEND="FLASH_ATTN" +DISABLE_CUSTOM_ALL_REDUCE=1 +MTP_STEPS=(5) + +RESULTS_DIR="${PROJECT_DIR}/experiment_results" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +DATASET_NAME=$(basename "${DATASET}" .json) +EXPERIMENT_SUBDIR="${RESULTS_DIR}/${DATASET_NAME}_${TIMESTAMP}_vllm_spec_default" +RESULTS_FILE="${EXPERIMENT_SUBDIR}/results.csv" + +usage() { + echo "Usage: $0 [options]" + echo "" + echo "Options:" + echo " --model-dir PATH Main model path (default: ${MODEL_DIR})" + echo " --draft-model-dir PATH Draft model path (default: ${DRAFT_MODEL_DIR})" + echo " --dataset PATH Dataset path (default: ${DATASET})" + echo " --tokenizer PATH Tokenizer path (default: ${TOKENIZER})" + echo " --samples NUM Number of prompts (default: ${SAMPLES})" + echo " --concurrency NUM Concurrency (default: ${CONCURRENCY})" + echo " --port PORT Service port (default: ${PORT})" + echo " --tp NUM Tensor parallel size (default: ${TP})" + echo " --mtp-steps LIST Comma-separated mtp steps (default: 5)" + echo " --num-speculative-tokens NUM Backward-compatible alias, equals one mtp step" + echo " --max-model-len NUM vLLM max model len (default: ${MAX_MODEL_LEN})" + echo " --max-num-batched-tokens NUM vLLM max batched tokens (default: ${MAX_NUM_BATCHED_TOKENS})" + echo " --max-num-seqs NUM vLLM max number of concurrent seqs (default: ${MAX_NUM_SEQS})" + echo " --max-cudagraph-capture-size NUM vLLM max cudagraph capture size (default: ${MAX_CUDAGRAPH_CAPTURE_SIZE})" + echo " --gpu-memory-utilization F GPU memory utilization (default: ${GPU_MEMORY_UTILIZATION})" + echo " --attention-backend NAME vLLM attention backend (default: ${ATTENTION_BACKEND})" + echo " --enable-custom-all-reduce Enable custom all-reduce (default: disabled)" + echo " --results-dir DIR Results base dir (default: ${RESULTS_DIR})" + echo " --help Show this help" + exit 1 +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --model-dir) + MODEL_DIR="$2" + shift 2 + ;; + --draft-model-dir) + DRAFT_MODEL_DIR="$2" + shift 2 + ;; + --dataset) + DATASET="$2" + shift 2 + ;; + --tokenizer) + TOKENIZER="$2" + shift 2 + ;; + --samples) + SAMPLES="$2" + shift 2 + ;; + --concurrency) + CONCURRENCY="$2" + shift 2 + ;; + --port) + PORT="$2" + shift 2 + ;; + --tp) + TP="$2" + shift 2 + ;; + --mtp-steps) + IFS=',' read -ra MTP_STEPS <<< "$2" + shift 2 + ;; + --num-speculative-tokens) + MTP_STEPS=("$2") + shift 2 + ;; + --max-model-len) + MAX_MODEL_LEN="$2" + shift 2 + ;; + --max-num-batched-tokens) + MAX_NUM_BATCHED_TOKENS="$2" + shift 2 + ;; + --max-num-seqs) + MAX_NUM_SEQS="$2" + shift 2 + ;; + --max-cudagraph-capture-size) + MAX_CUDAGRAPH_CAPTURE_SIZE="$2" + shift 2 + ;; + --gpu-memory-utilization) + GPU_MEMORY_UTILIZATION="$2" + shift 2 + ;; + --attention-backend) + ATTENTION_BACKEND="$2" + shift 2 + ;; + --enable-custom-all-reduce) + DISABLE_CUSTOM_ALL_REDUCE=0 + shift 1 + ;; + --results-dir) + RESULTS_DIR="$2" + shift 2 + ;; + --help) + usage + ;; + *) + echo "Unknown argument: $1" + usage + ;; + esac +done + +# Recompute result paths in case dataset/results-dir was overridden. +DATASET_NAME=$(basename "${DATASET}" .json) +EXPERIMENT_SUBDIR="${RESULTS_DIR}/${DATASET_NAME}_${TIMESTAMP}_vllm_spec_default" +RESULTS_FILE="${EXPERIMENT_SUBDIR}/results.csv" + +mkdir -p "${EXPERIMENT_SUBDIR}" + +echo "timestamp,engine,mode,mtp_step,dataset,samples,concurrency,throughput,avg_latency,avg_ttft,avg_inter_token_latency" > "${RESULTS_FILE}" + +wait_for_server() { + local max_attempts=600 + local attempt=0 + echo "Waiting for vLLM server to start..." + while [[ ${attempt} -lt ${max_attempts} ]]; do + if curl -s "http://localhost:${PORT}/health" > /dev/null 2>&1; then + echo "vLLM server started" + return 0 + fi + sleep 2 + attempt=$((attempt + 1)) + done + echo "vLLM server startup timeout" + return 1 +} + +extract_benchmark_metrics() { + local log_file="$1" + local throughput="" + local avg_latency="" + local avg_ttft="" + local avg_inter_token_latency="" + + throughput=$(grep -oP 'Throughput: \K[\d.]+' "$log_file" | tail -1) + avg_latency=$(grep -oP 'Average latency: \K[\d.]+' "$log_file" | tail -1) + avg_ttft=$(grep -oP 'Average time to first token: \K[\d.]+' "$log_file" | tail -1) + avg_inter_token_latency=$(grep -oP 'Average inter-token latency: \K[\d.]+' "$log_file" | tail -1) + + echo "${throughput:-NA},${avg_latency:-NA},${avg_ttft:-NA},${avg_inter_token_latency:-NA}" +} + +kill_vllm() { + echo "Stopping vLLM server..." + pkill -9 -f "vllm serve" 2>/dev/null || true + pkill -9 -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true + sleep 1 + echo "vLLM server stopped" +} + +trap 'kill_vllm' EXIT + +echo "==============================================" +echo "vLLM Speculative Baseline Started" +echo "==============================================" +echo "Model: ${MODEL_DIR}" +echo "Draft model: ${DRAFT_MODEL_DIR}" +echo "Tokenizer: ${TOKENIZER}" +echo "Dataset: ${DATASET}" +echo "Samples: ${SAMPLES}" +echo "Concurrency: ${CONCURRENCY}" +echo "TP: ${TP}" +echo "Port: ${PORT}" +echo "Max model len: ${MAX_MODEL_LEN}" +echo "Max batched tokens: ${MAX_NUM_BATCHED_TOKENS}" +echo "Max num seqs: ${MAX_NUM_SEQS}" +echo "Max cudagraph capture size: ${MAX_CUDAGRAPH_CAPTURE_SIZE}" +echo "GPU memory utilization: ${GPU_MEMORY_UTILIZATION}" +echo "Attention backend: ${ATTENTION_BACKEND}" +echo "Disable custom all reduce: ${DISABLE_CUSTOM_ALL_REDUCE}" +echo "MTP steps: ${MTP_STEPS[*]}" +echo "Results directory: ${EXPERIMENT_SUBDIR}" +echo "==============================================" + +for MTP_STEP in "${MTP_STEPS[@]}"; do + echo "" + echo "--- Running mtp step: ${MTP_STEP} ---" + + LOG_FILE="${EXPERIMENT_SUBDIR}/log_vllm_spec_default_step${MTP_STEP}_${TIMESTAMP}.txt" + BENCH_LOG="${EXPERIMENT_SUBDIR}/bench_vllm_spec_default_step${MTP_STEP}_${TIMESTAMP}.txt" + + SPECULATIVE_CONFIG=$(printf '{"model": "%s", "num_speculative_tokens": %s, "method": "draft_model"}' \ + "${DRAFT_MODEL_DIR}" "${MTP_STEP}") + CUSTOM_ALL_REDUCE_FLAG="" + if [[ "${DISABLE_CUSTOM_ALL_REDUCE}" == "1" ]]; then + CUSTOM_ALL_REDUCE_FLAG="--disable-custom-all-reduce" + fi + + kill_vllm + + echo "Starting vLLM server with speculative_config=${SPECULATIVE_CONFIG}" + ( + vllm serve "${MODEL_DIR}" \ + --host 0.0.0.0 \ + --port "${PORT}" \ + --served-model-name DeepSeek-R1 \ + -tp "${TP}" \ + --max_model_len "${MAX_MODEL_LEN}" \ + --max_num_batched_tokens "${MAX_NUM_BATCHED_TOKENS}" \ + --max_num_seqs "${MAX_NUM_SEQS}" \ + --max-cudagraph-capture-size "${MAX_CUDAGRAPH_CAPTURE_SIZE}" \ + --attention-backend "${ATTENTION_BACKEND}" \ + ${CUSTOM_ALL_REDUCE_FLAG} \ + --speculative_config "${SPECULATIVE_CONFIG}" + ) > "${LOG_FILE}" 2>&1 & + + SERVER_PID=$! + echo "vLLM PID: ${SERVER_PID}" + + if ! wait_for_server; then + echo "vLLM server failed to start for mtp step ${MTP_STEP}. Check log: ${LOG_FILE}" + RESULT_LINE="${TIMESTAMP},vllm,speculative_draft_model_default,${MTP_STEP},${DATASET},${SAMPLES},${CONCURRENCY},NA,NA,NA,NA" + echo "${RESULT_LINE}" >> "${RESULTS_FILE}" + continue + fi + + sleep 5 + + echo "Running benchmark with benchmark_sharegpt.py (OpenAI API mode)..." + python "${BENCH_PY_SCRIPT}" \ + --use_openai_api \ + --port "${PORT}" \ + --num-prompts "${SAMPLES}" \ + --tokenizer "${TOKENIZER}" \ + --dataset "${DATASET}" \ + --history-turns 1 \ + --concurrency "${CONCURRENCY}" 2>&1 | tee "${BENCH_LOG}" + + cat "${BENCH_LOG}" >> "${LOG_FILE}" + + BENCH_METRICS=$(extract_benchmark_metrics "${LOG_FILE}") + RESULT_LINE="${TIMESTAMP},vllm,speculative_draft_model_default,${MTP_STEP},${DATASET},${SAMPLES},${CONCURRENCY},${BENCH_METRICS}" + echo "${RESULT_LINE}" >> "${RESULTS_FILE}" + + echo "Completed mtp step ${MTP_STEP}: ${RESULT_LINE}" +done + +echo "" +echo "==============================================" +echo "All Experiments Completed" +echo "==============================================" +echo "Results file: ${RESULTS_FILE}" +cat "${RESULTS_FILE}" From df549356461870e5b93c310240818aeb5b1d166b Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Thu, 7 May 2026 21:10:43 +0800 Subject: [PATCH 02/12] remove 200000 token limit in test script --- test/speculative/qwen3-32b/dynamic_triton.sh | 6 ++++-- test/speculative/qwen3-32b/no_mtp_fa3.sh | 11 +++++++++++ test/speculative/qwen3-32b/static_fa3.sh | 6 ++++-- test/speculative/qwen3-32b/static_triton.sh | 6 ++++-- 4 files changed, 23 insertions(+), 6 deletions(-) create mode 100644 test/speculative/qwen3-32b/no_mtp_fa3.sh diff --git a/test/speculative/qwen3-32b/dynamic_triton.sh b/test/speculative/qwen3-32b/dynamic_triton.sh index 39145e5f5..ca70bf15e 100644 --- a/test/speculative/qwen3-32b/dynamic_triton.sh +++ b/test/speculative/qwen3-32b/dynamic_triton.sh @@ -16,8 +16,10 @@ done MODEL_DIR=/mtc/models/qwen3-32b DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 -LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ ---tp 4 --max_total_token_num 200000 \ +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ --model_dir ${MODEL_DIR} \ --mtp_mode eagle3 \ --disable_dynamic_prompt_cache \ diff --git a/test/speculative/qwen3-32b/no_mtp_fa3.sh b/test/speculative/qwen3-32b/no_mtp_fa3.sh new file mode 100644 index 000000000..c17562721 --- /dev/null +++ b/test/speculative/qwen3-32b/no_mtp_fa3.sh @@ -0,0 +1,11 @@ +MODEL_DIR=/mtc/models/qwen3-32b +DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 + +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ +--model_dir ${MODEL_DIR} \ +--disable_dynamic_prompt_cache \ +--graph_grow_step_size 1 \ +--llm_decode_att_backend triton \ No newline at end of file diff --git a/test/speculative/qwen3-32b/static_fa3.sh b/test/speculative/qwen3-32b/static_fa3.sh index c9712116e..44c67e03b 100644 --- a/test/speculative/qwen3-32b/static_fa3.sh +++ b/test/speculative/qwen3-32b/static_fa3.sh @@ -16,8 +16,10 @@ done MODEL_DIR=/mtc/models/qwen3-32b DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 -LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ ---tp 4 --max_total_token_num 200000 \ +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ --model_dir ${MODEL_DIR} \ --mtp_mode eagle3 \ --mtp_draft_model_dir ${DRAFT_MODEL_DIR} \ diff --git a/test/speculative/qwen3-32b/static_triton.sh b/test/speculative/qwen3-32b/static_triton.sh index 453c5678e..71964c9af 100644 --- a/test/speculative/qwen3-32b/static_triton.sh +++ b/test/speculative/qwen3-32b/static_triton.sh @@ -16,8 +16,10 @@ done MODEL_DIR=/mtc/models/qwen3-32b DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3 -LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ ---tp 4 --max_total_token_num 200000 \ +PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH + +LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \ +--tp 2 \ --model_dir ${MODEL_DIR} \ --mtp_mode eagle3 \ --disable_dynamic_prompt_cache \ From db0a2ca692c2525b8c5ca8b8a02ed635b2e1470e Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Thu, 14 May 2026 20:18:05 +0800 Subject: [PATCH 03/12] first commit for ema --- .../mode_backend/chunked_prefill/impl.py | 91 ++++---------- .../mode_backend/dynamic_mtp_planner.py | 112 ++++++++++++++++++ .../mode_backend/generic_pre_process.py | 40 +++++-- 3 files changed, 166 insertions(+), 77 deletions(-) create mode 100644 lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index c41dbb6d9..6b1578ae7 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -25,6 +25,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args, enable_dynamic_mtp_verify +from lightllm.server.router.model_infer.mode_backend.dynamic_mtp_planner import DynamicMTPPlanner from .control_state import ControlState logger = init_logger(__name__) @@ -45,6 +46,9 @@ def __init__(self) -> None: self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla self.enable_dynamic_mtp = enable_dynamic_mtp_verify() + self.dynamic_mtp_planner = ( + DynamicMTPPlanner(max_mtp_step=get_env_start_args().mtp_step) if self.enable_dynamic_mtp else None + ) else: self.prefill = self.prefill_normal self.decode = self.decode_normal @@ -233,7 +237,15 @@ def decode_mtp( """ MTP解码的通用流程,整合eagle和vanilla的共同逻辑 """ - model_input, run_reqs = prepare_decode_inputs(decode_reqs) + mtp_plan = None + if self.enable_dynamic_mtp: + mtp_plan = self.dynamic_mtp_planner.build_plan(decode_reqs) + model_input, run_reqs = prepare_decode_inputs( + decode_reqs, + mtp_decode_indexes=mtp_plan.selected_mtp_indexes, + ) + else: + model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): b_mtp_index_cpu = model_input.b_mtp_index @@ -260,32 +272,12 @@ def decode_mtp( verify_event = torch.cuda.Event() verify_event.record() - if self.enable_dynamic_mtp: - all_next_token_ids, additional_mem_indexes_cpu, draft_probs_list = self._draft_decode_func( - main_model_input=model_input, - main_model_output=model_output, - next_token_ids=next_token_ids, - b_req_mtp_start_loc=b_req_mtp_start_loc, - ) - else: - all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func( - main_model_input=model_input, - main_model_output=model_output, - next_token_ids=next_token_ids, - b_req_mtp_start_loc=b_req_mtp_start_loc, - ) - - # dynamic_sizes_gpu 用于第二阶段更新 req 的 mtp_size - if self.enable_dynamic_mtp: - draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0]) - dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor) - # 异步拷贝回 CPU Pin Memory - dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( - key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu - ) - - dynamic_mtp_event = torch.cuda.Event() - dynamic_mtp_event.record() + all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func( + main_model_input=model_input, + main_model_output=model_output, + next_token_ids=next_token_ids, + b_req_mtp_start_loc=b_req_mtp_start_loc, + ) mtp_scatter_next_token_ids( req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, @@ -320,11 +312,6 @@ def decode_mtp( verify_event.synchronize() accepted_index_cpu_numpy = accepted_index_cpu.numpy() verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1] - if self.enable_dynamic_mtp: - dynamic_mtp_event.synchronize() - self._update_dynamic_mtp_size_cpu_part( - run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu - ) update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) # 第三阶段 @@ -337,6 +324,9 @@ def decode_mtp( need_free_mem_indexes = torch.cat([need_free_mem_indexes, additional_mem_indexes_cpu], dim=0) self._update_mtp_accept_ratio(decode_reqs=decode_reqs, mtp_accept_len_cpu=mtp_accept_len_cpu) + if self.enable_dynamic_mtp: + self.dynamic_mtp_planner.update(decode_reqs, mtp_accept_len_cpu) + select_mask = torch.tensor(accepted_index_cpu, dtype=torch.bool, device="cpu") self._post_handle( run_reqs=verify_ok_reqs, @@ -355,28 +345,6 @@ def decode_mtp( event_pack.notify_pre_post_handle() return - def _compute_dynamic_mtp_size_gpu_part( - self, - draft_probs_tensor: torch.Tensor, - ) -> torch.Tensor: - rand_vals = torch.rand_like(draft_probs_tensor) - accepted_mask = draft_probs_tensor > rand_vals - valid_steps = torch.cumprod(accepted_mask.to(torch.int32), dim=0) - dynamic_mtp_sizes = valid_steps.sum(dim=0) - return dynamic_mtp_sizes - - def _update_dynamic_mtp_size_cpu_part( - self, - run_reqs: List[InferReq], - dynamic_sizes_cpu: torch.Tensor, - accepted_index_cpu: torch.Tensor, - ): - assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0] - for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()): - if int(accepted) == 1: - req.current_mtp_step = int(new_size) - assert req.current_mtp_step <= req.mtp_step - def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 draft_model_input = model_input @@ -442,9 +410,6 @@ def _draft_decode_eagle( all_next_token_ids = [] all_next_token_ids.append(next_token_ids) - # 用于收集每个 step 的 probs - draft_probs_list = [] if self.enable_dynamic_mtp else None - # process the draft model output for _step in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids @@ -453,12 +418,7 @@ def _draft_decode_eagle( draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) - # 收集 probs(如果需要) - if self.enable_dynamic_mtp: - draft_next_token_ids, draft_probs = self._gen_argmax_token_ids_and_prob(draft_model_output) - draft_probs_list.append(draft_probs) - else: - draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) + draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) draft_model_input.b_seq_len += 1 draft_model_input.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes[_step * num_reqs : (_step + 1) * num_reqs] @@ -478,7 +438,4 @@ def _draft_decode_eagle( all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1] - if self.enable_dynamic_mtp: - return all_next_token_ids, eagle_mem_indexes_cpu, draft_probs_list - else: - return all_next_token_ids, eagle_mem_indexes_cpu + return all_next_token_ids, eagle_mem_indexes_cpu diff --git a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py new file mode 100644 index 000000000..eb2a51162 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import dataclasses +import math +import threading +from typing import List, Sequence, TYPE_CHECKING + +if TYPE_CHECKING: + from lightllm.server.router.model_infer.infer_batch import InferReq + + +# Development-time knobs. Keep these local while the dynamic MTP planner is being +# tuned; move the stable subset to StartArgs once the policy settles. +EMA_ALPHA = 0.2 +BUDGET_SCALE = 1.0 +MIN_STEP = 1 +MAX_STEP = None + + +@dataclasses.dataclass +class MTPPlan: + planned_steps: List[int] + selected_mtp_indexes: List[List[int]] + budget: int + estimated_step: int + b_req_mtp_start_loc: List[int] + + +class DynamicMTPPlanner: + """ + Plans a uniform dynamic MTP verification length from historical acceptance. + + The plan is intentionally based on already available history so decode + preprocessing does not have to wait for the current draft pass to finish. + """ + + def __init__( + self, + max_mtp_step: int, + ema_alpha: float = EMA_ALPHA, + budget_scale: float = BUDGET_SCALE, + min_step: int = MIN_STEP, + max_step: int = None, + ) -> None: + assert max_mtp_step >= 0 + assert 0.0 < ema_alpha <= 1.0 + assert budget_scale > 0.0 + self.max_mtp_step = max_mtp_step + self.ema_alpha = ema_alpha + self.budget_scale = budget_scale + self.min_step = max(0, min(min_step, max_mtp_step)) + if max_step is None: + max_step = max_mtp_step if MAX_STEP is None else MAX_STEP + self.max_step = max(self.min_step, min(max_step, max_mtp_step)) + self._lock = threading.Lock() + self._ema_max_accept_step = float(self.max_step) + + def build_plan(self, reqs: Sequence[InferReq]) -> MTPPlan: + req_num = len(reqs) + if req_num == 0: + return MTPPlan( + planned_steps=[], + selected_mtp_indexes=[], + budget=0, + estimated_step=0, + b_req_mtp_start_loc=[], + ) + + with self._lock: + slot_limit = int(math.ceil(self._ema_max_accept_step * self.budget_scale)) + + slot_limit = min(max(slot_limit, self.min_step), self.max_step) + planned_steps = [slot_limit for _ in reqs] + + selected_mtp_indexes = [list(range(1, step + 1)) for step in planned_steps] + + start_locs = [] + cur_loc = 0 + for selected_indexes in selected_mtp_indexes: + start_locs.append(cur_loc) + cur_loc += 1 + len(selected_indexes) + + for req, step in zip(reqs, planned_steps): + req.current_mtp_step = step + + return MTPPlan( + planned_steps=planned_steps, + selected_mtp_indexes=selected_mtp_indexes, + budget=sum(planned_steps), + estimated_step=slot_limit, + b_req_mtp_start_loc=start_locs, + ) + + def update( + self, + reqs: Sequence[InferReq], + mtp_accept_len_cpu, + ) -> None: + if not reqs: + return + + accept_len_np = mtp_accept_len_cpu.numpy() + max_accept_step = 0 + for req_index in range(len(reqs)): + accept_len = int(accept_len_np[req_index]) + accept_step = max(0, accept_len - 1) + max_accept_step = max(max_accept_step, min(accept_step, self.max_step)) + + with self._lock: + self._ema_max_accept_step = ( + self.ema_alpha * max_accept_step + (1.0 - self.ema_alpha) * self._ema_max_accept_step + ) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 3d9d8815e..c4ae24faf 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -1,6 +1,6 @@ import torch import numpy as np -from typing import List, Tuple +from typing import List, Optional, Sequence, Tuple from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.common.basemodel.batch_objs import ModelInput @@ -94,7 +94,17 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> return model_input, run_reqs -def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[InferReq]]: +def prepare_decode_inputs( + req_objs: List[InferReq], + mtp_decode_steps: Optional[Sequence[int]] = None, + mtp_decode_indexes: Optional[Sequence[Sequence[int]]] = None, +) -> Tuple[ModelInput, List[InferReq]]: + if mtp_decode_steps is not None: + assert len(mtp_decode_steps) == len(req_objs) + if mtp_decode_indexes is not None: + assert mtp_decode_steps is None + assert len(mtp_decode_indexes) == len(req_objs) + run_reqs: List[InferReq] = [] total_token_num = 0 b_req_idx = [] @@ -102,7 +112,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_seq_len = [] b_q_seq_len = [] multimodal_params = [] - for req in req_objs: + for req_index, req in enumerate(req_objs): run_reqs.append(req) b_req_idx.append(req.req_idx) seq_len = req.get_cur_total_len() @@ -113,15 +123,25 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_mtp_index.append(0) multimodal_params.append(req.multimodal_params) # process the draft tokens. - # 动态 MTP 模式:使用动态 current_mtp_step 构建 batch - # 非动态 MTP 模式:current_mtp_step 为固定的 mtp_step - for step in range(req.current_mtp_step): + # 动态 MTP planner 可以提前给出本轮要填充进验证槽位的 draft index。 + # 当前 planner 使用连续 prefix index;后续非连续选择可在该接口后接 compact kernel。 + if mtp_decode_indexes is not None: + decode_indexes = [int(index) for index in mtp_decode_indexes[req_index]] + assert decode_indexes == list(range(1, len(decode_indexes) + 1)), ( + "Current MTP verify path requires contiguous prefix draft indexes. " + "Non-prefix indexes need a compact/remap kernel before decode." + ) + else: + decode_step = req.current_mtp_step if mtp_decode_steps is None else int(mtp_decode_steps[req_index]) + decode_indexes = range(1, decode_step + 1) + + for mtp_index in decode_indexes: run_reqs.append(req) b_req_idx.append(req.req_idx) - seq_len += 1 - b_seq_len.append(seq_len) - total_token_num += seq_len - b_mtp_index.append(step + 1) + mtp_seq_len = seq_len + int(mtp_index) + b_seq_len.append(mtp_seq_len) + total_token_num += mtp_seq_len + b_mtp_index.append(int(mtp_index)) multimodal_params.append(req.multimodal_params) b_q_seq_len.append(1) From c6e0a7270663585aa5a9ce22eac5ae0ab4fb54fe Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Sat, 16 May 2026 13:29:44 +0800 Subject: [PATCH 04/12] Revert "first commit for ema" This reverts commit db0a2ca692c2525b8c5ca8b8a02ed635b2e1470e. --- .../mode_backend/chunked_prefill/impl.py | 91 ++++++++++---- .../mode_backend/dynamic_mtp_planner.py | 112 ------------------ .../mode_backend/generic_pre_process.py | 40 ++----- 3 files changed, 77 insertions(+), 166 deletions(-) delete mode 100644 lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 6b1578ae7..c41dbb6d9 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -25,7 +25,6 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args, enable_dynamic_mtp_verify -from lightllm.server.router.model_infer.mode_backend.dynamic_mtp_planner import DynamicMTPPlanner from .control_state import ControlState logger = init_logger(__name__) @@ -46,9 +45,6 @@ def __init__(self) -> None: self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla self.enable_dynamic_mtp = enable_dynamic_mtp_verify() - self.dynamic_mtp_planner = ( - DynamicMTPPlanner(max_mtp_step=get_env_start_args().mtp_step) if self.enable_dynamic_mtp else None - ) else: self.prefill = self.prefill_normal self.decode = self.decode_normal @@ -237,15 +233,7 @@ def decode_mtp( """ MTP解码的通用流程,整合eagle和vanilla的共同逻辑 """ - mtp_plan = None - if self.enable_dynamic_mtp: - mtp_plan = self.dynamic_mtp_planner.build_plan(decode_reqs) - model_input, run_reqs = prepare_decode_inputs( - decode_reqs, - mtp_decode_indexes=mtp_plan.selected_mtp_indexes, - ) - else: - model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): b_mtp_index_cpu = model_input.b_mtp_index @@ -272,12 +260,32 @@ def decode_mtp( verify_event = torch.cuda.Event() verify_event.record() - all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func( - main_model_input=model_input, - main_model_output=model_output, - next_token_ids=next_token_ids, - b_req_mtp_start_loc=b_req_mtp_start_loc, - ) + if self.enable_dynamic_mtp: + all_next_token_ids, additional_mem_indexes_cpu, draft_probs_list = self._draft_decode_func( + main_model_input=model_input, + main_model_output=model_output, + next_token_ids=next_token_ids, + b_req_mtp_start_loc=b_req_mtp_start_loc, + ) + else: + all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func( + main_model_input=model_input, + main_model_output=model_output, + next_token_ids=next_token_ids, + b_req_mtp_start_loc=b_req_mtp_start_loc, + ) + + # dynamic_sizes_gpu 用于第二阶段更新 req 的 mtp_size + if self.enable_dynamic_mtp: + draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0]) + dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor) + # 异步拷贝回 CPU Pin Memory + dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( + key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu + ) + + dynamic_mtp_event = torch.cuda.Event() + dynamic_mtp_event.record() mtp_scatter_next_token_ids( req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, @@ -312,6 +320,11 @@ def decode_mtp( verify_event.synchronize() accepted_index_cpu_numpy = accepted_index_cpu.numpy() verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1] + if self.enable_dynamic_mtp: + dynamic_mtp_event.synchronize() + self._update_dynamic_mtp_size_cpu_part( + run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu + ) update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) # 第三阶段 @@ -324,9 +337,6 @@ def decode_mtp( need_free_mem_indexes = torch.cat([need_free_mem_indexes, additional_mem_indexes_cpu], dim=0) self._update_mtp_accept_ratio(decode_reqs=decode_reqs, mtp_accept_len_cpu=mtp_accept_len_cpu) - if self.enable_dynamic_mtp: - self.dynamic_mtp_planner.update(decode_reqs, mtp_accept_len_cpu) - select_mask = torch.tensor(accepted_index_cpu, dtype=torch.bool, device="cpu") self._post_handle( run_reqs=verify_ok_reqs, @@ -345,6 +355,28 @@ def decode_mtp( event_pack.notify_pre_post_handle() return + def _compute_dynamic_mtp_size_gpu_part( + self, + draft_probs_tensor: torch.Tensor, + ) -> torch.Tensor: + rand_vals = torch.rand_like(draft_probs_tensor) + accepted_mask = draft_probs_tensor > rand_vals + valid_steps = torch.cumprod(accepted_mask.to(torch.int32), dim=0) + dynamic_mtp_sizes = valid_steps.sum(dim=0) + return dynamic_mtp_sizes + + def _update_dynamic_mtp_size_cpu_part( + self, + run_reqs: List[InferReq], + dynamic_sizes_cpu: torch.Tensor, + accepted_index_cpu: torch.Tensor, + ): + assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0] + for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()): + if int(accepted) == 1: + req.current_mtp_step = int(new_size) + assert req.current_mtp_step <= req.mtp_step + def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 draft_model_input = model_input @@ -410,6 +442,9 @@ def _draft_decode_eagle( all_next_token_ids = [] all_next_token_ids.append(next_token_ids) + # 用于收集每个 step 的 probs + draft_probs_list = [] if self.enable_dynamic_mtp else None + # process the draft model output for _step in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids @@ -418,7 +453,12 @@ def _draft_decode_eagle( draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) - draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) + # 收集 probs(如果需要) + if self.enable_dynamic_mtp: + draft_next_token_ids, draft_probs = self._gen_argmax_token_ids_and_prob(draft_model_output) + draft_probs_list.append(draft_probs) + else: + draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) draft_model_input.b_seq_len += 1 draft_model_input.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes[_step * num_reqs : (_step + 1) * num_reqs] @@ -438,4 +478,7 @@ def _draft_decode_eagle( all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1] - return all_next_token_ids, eagle_mem_indexes_cpu + if self.enable_dynamic_mtp: + return all_next_token_ids, eagle_mem_indexes_cpu, draft_probs_list + else: + return all_next_token_ids, eagle_mem_indexes_cpu diff --git a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py deleted file mode 100644 index eb2a51162..000000000 --- a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py +++ /dev/null @@ -1,112 +0,0 @@ -from __future__ import annotations - -import dataclasses -import math -import threading -from typing import List, Sequence, TYPE_CHECKING - -if TYPE_CHECKING: - from lightllm.server.router.model_infer.infer_batch import InferReq - - -# Development-time knobs. Keep these local while the dynamic MTP planner is being -# tuned; move the stable subset to StartArgs once the policy settles. -EMA_ALPHA = 0.2 -BUDGET_SCALE = 1.0 -MIN_STEP = 1 -MAX_STEP = None - - -@dataclasses.dataclass -class MTPPlan: - planned_steps: List[int] - selected_mtp_indexes: List[List[int]] - budget: int - estimated_step: int - b_req_mtp_start_loc: List[int] - - -class DynamicMTPPlanner: - """ - Plans a uniform dynamic MTP verification length from historical acceptance. - - The plan is intentionally based on already available history so decode - preprocessing does not have to wait for the current draft pass to finish. - """ - - def __init__( - self, - max_mtp_step: int, - ema_alpha: float = EMA_ALPHA, - budget_scale: float = BUDGET_SCALE, - min_step: int = MIN_STEP, - max_step: int = None, - ) -> None: - assert max_mtp_step >= 0 - assert 0.0 < ema_alpha <= 1.0 - assert budget_scale > 0.0 - self.max_mtp_step = max_mtp_step - self.ema_alpha = ema_alpha - self.budget_scale = budget_scale - self.min_step = max(0, min(min_step, max_mtp_step)) - if max_step is None: - max_step = max_mtp_step if MAX_STEP is None else MAX_STEP - self.max_step = max(self.min_step, min(max_step, max_mtp_step)) - self._lock = threading.Lock() - self._ema_max_accept_step = float(self.max_step) - - def build_plan(self, reqs: Sequence[InferReq]) -> MTPPlan: - req_num = len(reqs) - if req_num == 0: - return MTPPlan( - planned_steps=[], - selected_mtp_indexes=[], - budget=0, - estimated_step=0, - b_req_mtp_start_loc=[], - ) - - with self._lock: - slot_limit = int(math.ceil(self._ema_max_accept_step * self.budget_scale)) - - slot_limit = min(max(slot_limit, self.min_step), self.max_step) - planned_steps = [slot_limit for _ in reqs] - - selected_mtp_indexes = [list(range(1, step + 1)) for step in planned_steps] - - start_locs = [] - cur_loc = 0 - for selected_indexes in selected_mtp_indexes: - start_locs.append(cur_loc) - cur_loc += 1 + len(selected_indexes) - - for req, step in zip(reqs, planned_steps): - req.current_mtp_step = step - - return MTPPlan( - planned_steps=planned_steps, - selected_mtp_indexes=selected_mtp_indexes, - budget=sum(planned_steps), - estimated_step=slot_limit, - b_req_mtp_start_loc=start_locs, - ) - - def update( - self, - reqs: Sequence[InferReq], - mtp_accept_len_cpu, - ) -> None: - if not reqs: - return - - accept_len_np = mtp_accept_len_cpu.numpy() - max_accept_step = 0 - for req_index in range(len(reqs)): - accept_len = int(accept_len_np[req_index]) - accept_step = max(0, accept_len - 1) - max_accept_step = max(max_accept_step, min(accept_step, self.max_step)) - - with self._lock: - self._ema_max_accept_step = ( - self.ema_alpha * max_accept_step + (1.0 - self.ema_alpha) * self._ema_max_accept_step - ) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index c4ae24faf..3d9d8815e 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -1,6 +1,6 @@ import torch import numpy as np -from typing import List, Optional, Sequence, Tuple +from typing import List, Tuple from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.common.basemodel.batch_objs import ModelInput @@ -94,17 +94,7 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> return model_input, run_reqs -def prepare_decode_inputs( - req_objs: List[InferReq], - mtp_decode_steps: Optional[Sequence[int]] = None, - mtp_decode_indexes: Optional[Sequence[Sequence[int]]] = None, -) -> Tuple[ModelInput, List[InferReq]]: - if mtp_decode_steps is not None: - assert len(mtp_decode_steps) == len(req_objs) - if mtp_decode_indexes is not None: - assert mtp_decode_steps is None - assert len(mtp_decode_indexes) == len(req_objs) - +def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[InferReq]]: run_reqs: List[InferReq] = [] total_token_num = 0 b_req_idx = [] @@ -112,7 +102,7 @@ def prepare_decode_inputs( b_seq_len = [] b_q_seq_len = [] multimodal_params = [] - for req_index, req in enumerate(req_objs): + for req in req_objs: run_reqs.append(req) b_req_idx.append(req.req_idx) seq_len = req.get_cur_total_len() @@ -123,25 +113,15 @@ def prepare_decode_inputs( b_mtp_index.append(0) multimodal_params.append(req.multimodal_params) # process the draft tokens. - # 动态 MTP planner 可以提前给出本轮要填充进验证槽位的 draft index。 - # 当前 planner 使用连续 prefix index;后续非连续选择可在该接口后接 compact kernel。 - if mtp_decode_indexes is not None: - decode_indexes = [int(index) for index in mtp_decode_indexes[req_index]] - assert decode_indexes == list(range(1, len(decode_indexes) + 1)), ( - "Current MTP verify path requires contiguous prefix draft indexes. " - "Non-prefix indexes need a compact/remap kernel before decode." - ) - else: - decode_step = req.current_mtp_step if mtp_decode_steps is None else int(mtp_decode_steps[req_index]) - decode_indexes = range(1, decode_step + 1) - - for mtp_index in decode_indexes: + # 动态 MTP 模式:使用动态 current_mtp_step 构建 batch + # 非动态 MTP 模式:current_mtp_step 为固定的 mtp_step + for step in range(req.current_mtp_step): run_reqs.append(req) b_req_idx.append(req.req_idx) - mtp_seq_len = seq_len + int(mtp_index) - b_seq_len.append(mtp_seq_len) - total_token_num += mtp_seq_len - b_mtp_index.append(int(mtp_index)) + seq_len += 1 + b_seq_len.append(seq_len) + total_token_num += seq_len + b_mtp_index.append(step + 1) multimodal_params.append(req.multimodal_params) b_q_seq_len.append(1) From 5317c5ae740023ebab69480e1069f657aeb45fa8 Mon Sep 17 00:00:00 2001 From: Junyi Chen Date: Sun, 17 May 2026 21:04:09 +0800 Subject: [PATCH 05/12] fix dynamic_mtp_planner --- .../model_infer/mode_backend/base_backend.py | 19 +- .../mode_backend/chunked_prefill/impl.py | 101 ++++--- .../mode_backend/dynamic_mtp_planner.py | 275 ++++++++++++++++++ 3 files changed, 340 insertions(+), 55 deletions(-) create mode 100644 lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 321055b4b..bde7fad52 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -775,14 +775,19 @@ def _update_mtp_accept_ratio( return - def _update_mtp_verify_token_num(self, decode_reqs: List[InferReq]): + def _update_mtp_verify_token_num( + self, + decode_reqs: List[InferReq], + verify_token_nums: Optional[List[int]] = None, + ): if self.is_master_in_dp: - for req in decode_reqs: - # 统计发送给主模型验证的 token 数量:1 个主 token + 当前 mtp_size 个 draft token - # 在静态 MTP 模式下,使用固定的 mtp_step;在动态 MTP 模式下,使用动态调整的 current_mtp_step - # current_mtp_step 在静态 MTP 模式下为 mtp_step,在动态 MTP 模式下会在推理过程中动态设置。 - assert req.current_mtp_step >= 0 - req.update_mtp_verify_token_num(verify_token_num=1 + req.current_mtp_step) + if verify_token_nums is None: + verify_token_nums = [1 + req.current_mtp_step for req in decode_reqs] + assert len(decode_reqs) == len(verify_token_nums) + for req, verify_token_num in zip(decode_reqs, verify_token_nums): + # 统计发送给主模型验证的 token 数量,动态 MTP 模式由 planner 传入实际裁剪后的行数。 + assert verify_token_num >= 1 + req.update_mtp_verify_token_num(verify_token_num=verify_token_num) return def _gen_argmax_token_ids(self, model_output: ModelOutput): diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index c41dbb6d9..71b22c08e 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -25,6 +25,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args, enable_dynamic_mtp_verify +from lightllm.server.router.model_infer.mode_backend.dynamic_mtp_planner import DynamicMTPPlanner from .control_state import ControlState logger = init_logger(__name__) @@ -45,6 +46,9 @@ def __init__(self) -> None: self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla self.enable_dynamic_mtp = enable_dynamic_mtp_verify() + self.dynamic_mtp_planner = ( + DynamicMTPPlanner(mtp_step=get_env_start_args().mtp_step) if self.enable_dynamic_mtp else None + ) else: self.prefill = self.prefill_normal self.decode = self.decode_normal @@ -233,9 +237,25 @@ def decode_mtp( """ MTP解码的通用流程,整合eagle和vanilla的共同逻辑 """ + if self.enable_dynamic_mtp: + # 让通用 pre-process 始终构建最大候选池,动态策略只在 forward 前裁剪。 + for req in decode_reqs: + req.current_mtp_step = req.mtp_step model_input, run_reqs = prepare_decode_inputs(decode_reqs) + dynamic_mtp_plan = None with torch.cuda.stream(g_infer_context.get_overlap_stream()): + if self.enable_dynamic_mtp: + model_input, run_reqs, dynamic_mtp_plan = self.dynamic_mtp_planner.trim_before_forward( + model_input=model_input, + run_reqs=run_reqs, + decode_reqs=decode_reqs, + ) + dynamic_mtp_start_event = None + if self.enable_dynamic_mtp: + dynamic_mtp_start_event = torch.cuda.Event(enable_timing=True) + dynamic_mtp_start_event.record() + b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) @@ -257,9 +277,10 @@ def decode_mtp( gpu_tensor=accepted_index, ) - verify_event = torch.cuda.Event() + verify_event = torch.cuda.Event(enable_timing=self.enable_dynamic_mtp) verify_event.record() + per_req_probs_cpu = None if self.enable_dynamic_mtp: all_next_token_ids, additional_mem_indexes_cpu, draft_probs_list = self._draft_decode_func( main_model_input=model_input, @@ -267,6 +288,13 @@ def decode_mtp( next_token_ids=next_token_ids, b_req_mtp_start_loc=b_req_mtp_start_loc, ) + draft_probs_tensor = torch.stack(draft_probs_list, dim=1) + request_start_rows = b_req_mtp_start_loc.to(torch.long) + per_req_probs = draft_probs_tensor[request_start_rows].mean(dim=1) + per_req_probs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( + key="dynamic_mtp_req_probs", + gpu_tensor=per_req_probs, + ) else: all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func( main_model_input=model_input, @@ -275,18 +303,6 @@ def decode_mtp( b_req_mtp_start_loc=b_req_mtp_start_loc, ) - # dynamic_sizes_gpu 用于第二阶段更新 req 的 mtp_size - if self.enable_dynamic_mtp: - draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0]) - dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor) - # 异步拷贝回 CPU Pin Memory - dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( - key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu - ) - - dynamic_mtp_event = torch.cuda.Event() - dynamic_mtp_event.record() - mtp_scatter_next_token_ids( req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, b_req_mtp_start_loc=b_req_mtp_start_loc, @@ -310,26 +326,34 @@ def decode_mtp( gpu_tensor=mtp_accept_len, ) - sync_event = torch.cuda.Event() + sync_event = torch.cuda.Event(enable_timing=self.enable_dynamic_mtp) sync_event.record() # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() - self._update_mtp_verify_token_num(decode_reqs=decode_reqs) + self._update_mtp_verify_token_num( + decode_reqs=decode_reqs, + verify_token_nums=dynamic_mtp_plan.per_req_rows if dynamic_mtp_plan is not None else None, + ) verify_event.synchronize() + dynamic_mtp_elapsed_ms = None accepted_index_cpu_numpy = accepted_index_cpu.numpy() verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1] - if self.enable_dynamic_mtp: - dynamic_mtp_event.synchronize() - self._update_dynamic_mtp_size_cpu_part( - run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu - ) update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) # 第三阶段 event_pack.notify_forward_and_wait_post_handle() sync_event.synchronize() + if self.enable_dynamic_mtp: + dynamic_mtp_elapsed_ms = dynamic_mtp_start_event.elapsed_time(sync_event) + self.dynamic_mtp_planner.update_after_verify( + plan=dynamic_mtp_plan, + decode_reqs=decode_reqs, + mtp_accept_len_cpu=mtp_accept_len_cpu, + elapsed_ms=dynamic_mtp_elapsed_ms, + per_req_probs_cpu=per_req_probs_cpu, + ) # 处理需要释放的内存索引 need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] @@ -355,28 +379,6 @@ def decode_mtp( event_pack.notify_pre_post_handle() return - def _compute_dynamic_mtp_size_gpu_part( - self, - draft_probs_tensor: torch.Tensor, - ) -> torch.Tensor: - rand_vals = torch.rand_like(draft_probs_tensor) - accepted_mask = draft_probs_tensor > rand_vals - valid_steps = torch.cumprod(accepted_mask.to(torch.int32), dim=0) - dynamic_mtp_sizes = valid_steps.sum(dim=0) - return dynamic_mtp_sizes - - def _update_dynamic_mtp_size_cpu_part( - self, - run_reqs: List[InferReq], - dynamic_sizes_cpu: torch.Tensor, - accepted_index_cpu: torch.Tensor, - ): - assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0] - for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()): - if int(accepted) == 1: - req.current_mtp_step = int(new_size) - assert req.current_mtp_step <= req.mtp_step - def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 draft_model_input = model_input @@ -405,17 +407,24 @@ def _draft_decode_vanilla( draft_next_token_ids = next_token_ids all_next_token_ids = [] all_next_token_ids.append(next_token_ids) + draft_probs_list = [] if self.enable_dynamic_mtp else None # process the draft model output for draft_model_idx in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) - draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) + if self.enable_dynamic_mtp: + draft_next_token_ids, draft_probs = self._gen_argmax_token_ids_and_prob(draft_model_output) + draft_probs_list.append(draft_probs) + else: + draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) all_next_token_ids.append(draft_next_token_ids) all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1] + if self.enable_dynamic_mtp: + return all_next_token_ids, None, draft_probs_list return all_next_token_ids, None def _draft_decode_eagle( @@ -441,8 +450,6 @@ def _draft_decode_eagle( draft_next_token_ids = next_token_ids all_next_token_ids = [] all_next_token_ids.append(next_token_ids) - - # 用于收集每个 step 的 probs draft_probs_list = [] if self.enable_dynamic_mtp else None # process the draft model output @@ -453,7 +460,6 @@ def _draft_decode_eagle( draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) - # 收集 probs(如果需要) if self.enable_dynamic_mtp: draft_next_token_ids, draft_probs = self._gen_argmax_token_ids_and_prob(draft_model_output) draft_probs_list.append(draft_probs) @@ -480,5 +486,4 @@ def _draft_decode_eagle( if self.enable_dynamic_mtp: return all_next_token_ids, eagle_mem_indexes_cpu, draft_probs_list - else: - return all_next_token_ids, eagle_mem_indexes_cpu + return all_next_token_ids, eagle_mem_indexes_cpu diff --git a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py new file mode 100644 index 000000000..b09a711bd --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py @@ -0,0 +1,275 @@ +import copy +import math +import random +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch + +from lightllm.common.basemodel.batch_objs import ModelInput +from lightllm.common.basemodel.infer_lock import g_infer_state_lock +from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context +from lightllm.utils.envs_utils import get_diverse_max_batch_shared_group_size + + +@dataclass +class DynamicMTPPlan: + req_num: int + original_batch_size: int + dynamic_batch_size: int + keep_indices: torch.Tensor + per_req_rows: List[int] + estimated_accept_mean: float + estimated_accept_std: float + + +class _EMAValue: + def __init__(self, decay: float, init_value: Optional[float] = None) -> None: + self.decay = decay + self.value = init_value + self.initialized = init_value is not None + + def update(self, new_value: float) -> float: + if not self.initialized: + self.value = new_value + self.initialized = True + else: + self.value = self.decay * self.value + (1.0 - self.decay) * new_value + return self.value + + def get(self, fallback: float) -> float: + return self.value if self.initialized else fallback + + +class DynamicMTPPlanner: + def __init__( + self, + mtp_step: int, + ema_decay: float = 0.9, + confidence_k: float = 1.0, + ) -> None: + self.mtp_step = mtp_step + self.max_rows_per_req = mtp_step + 1 + self.confidence_k = confidence_k + self.accept_mean = _EMAValue(ema_decay, init_value=float(self.max_rows_per_req)) + self.accept_second_moment = _EMAValue(ema_decay, init_value=float(self.max_rows_per_req**2)) + self.req_accept_mean: Dict[int, _EMAValue] = {} + self.req_prob: Dict[int, _EMAValue] = {} + self.latency_ms_by_batch_size: Dict[int, _EMAValue] = {} + self.accepted_token_speed = _EMAValue(ema_decay) + self.verify_row_speed = _EMAValue(ema_decay) + self.actual_speedup = _EMAValue(ema_decay) + self.single_token_speed_by_req_num: Dict[int, _EMAValue] = {} + self.last_plan: Optional[DynamicMTPPlan] = None + + def trim_before_forward( + self, + model_input: ModelInput, + run_reqs: List[InferReq], + decode_reqs: List[InferReq], + ): + plan = self._build_plan(model_input=model_input, decode_reqs=decode_reqs) + self.last_plan = plan + if plan.dynamic_batch_size == plan.original_batch_size: + return model_input, run_reqs, plan + + pruned_indices = self._invert_indices(plan.keep_indices, plan.original_batch_size) + if pruned_indices.numel() > 0: + pruned_mem_indexes = model_input.mem_indexes_cpu[pruned_indices] + g_infer_state_lock.acquire() + g_infer_context.req_manager.mem_manager.free(pruned_mem_indexes) + g_infer_state_lock.release() + + trimmed_input = copy.copy(model_input) + keep_indices = plan.keep_indices + keep_list = keep_indices.tolist() + + trimmed_input.batch_size = plan.dynamic_batch_size + trimmed_input.b_req_idx = model_input.b_req_idx[keep_indices].contiguous() + trimmed_input.b_mtp_index = model_input.b_mtp_index[keep_indices].contiguous() + trimmed_input.b_seq_len = model_input.b_seq_len[keep_indices].contiguous() + trimmed_input.mem_indexes_cpu = model_input.mem_indexes_cpu[keep_indices].contiguous() + trimmed_input.mem_indexes = None + trimmed_input.total_token_num = int(trimmed_input.b_seq_len.sum().item()) + trimmed_input.max_kv_seq_len = int(trimmed_input.b_seq_len.max().item()) + trimmed_input.multimodal_params = [model_input.multimodal_params[index] for index in keep_list] + trimmed_run_reqs = [run_reqs[index] for index in keep_list] + trimmed_input.b_mark_shared_group = self._build_mtp_shared_group_infos(trimmed_run_reqs) + trimmed_input.check_input() + return trimmed_input, trimmed_run_reqs, plan + + def update_after_verify( + self, + plan: DynamicMTPPlan, + decode_reqs: List[InferReq], + mtp_accept_len_cpu: torch.Tensor, + elapsed_ms: float, + per_req_probs_cpu: Optional[torch.Tensor] = None, + ) -> None: + if plan is None: + return + + accept_lens = [int(value) for value in mtp_accept_len_cpu.numpy()] + if not accept_lens: + return + + batch_mean = sum(accept_lens) / len(accept_lens) + batch_second = sum(value * value for value in accept_lens) / len(accept_lens) + self.accept_mean.update(batch_mean) + self.accept_second_moment.update(batch_second) + + for req, accept_len in zip(decode_reqs, accept_lens): + req_ema = self.req_accept_mean.get(req.req_id) + if req_ema is None: + req_ema = _EMAValue(self.accept_mean.decay, init_value=batch_mean) + self.req_accept_mean[req.req_id] = req_ema + req_ema.update(float(accept_len)) + + if per_req_probs_cpu is not None: + for req, req_prob in zip(decode_reqs, per_req_probs_cpu.numpy()): + req_prob_ema = self.req_prob.get(req.req_id) + if req_prob_ema is None: + req_prob_ema = _EMAValue(self.accept_mean.decay) + self.req_prob[req.req_id] = req_prob_ema + req_prob_ema.update(self._clip_prob(float(req_prob))) + + if elapsed_ms > 0: + current_output_speed = sum(accept_lens) / elapsed_ms + latency_ema = self.latency_ms_by_batch_size.get(plan.dynamic_batch_size) + if latency_ema is None: + latency_ema = _EMAValue(self.accept_mean.decay) + self.latency_ms_by_batch_size[plan.dynamic_batch_size] = latency_ema + latency_ema.update(elapsed_ms) + self.accepted_token_speed.update(current_output_speed) + self.verify_row_speed.update(plan.dynamic_batch_size / elapsed_ms) + if plan.dynamic_batch_size == plan.req_num: + single_token_speed = self.single_token_speed_by_req_num.get(plan.req_num) + if single_token_speed is None: + single_token_speed = _EMAValue(self.accept_mean.decay) + self.single_token_speed_by_req_num[plan.req_num] = single_token_speed + single_token_speed.update(plan.req_num / elapsed_ms) + baseline_speed = self.single_token_speed_by_req_num.get(plan.req_num) + if baseline_speed is not None and baseline_speed.get(0.0) > 0: + self.actual_speedup.update(current_output_speed / baseline_speed.get(0.0)) + + def get_stats_snapshot(self) -> Dict[str, float]: + return { + "accept_mean": self.accept_mean.get(float(self.max_rows_per_req)), + "accept_second_moment": self.accept_second_moment.get(float(self.max_rows_per_req**2)), + "accepted_token_speed": self.accepted_token_speed.get(0.0), + "verify_row_speed": self.verify_row_speed.get(0.0), + "actual_speedup": self.actual_speedup.get(0.0), + } + + def _build_plan(self, model_input: ModelInput, decode_reqs: List[InferReq]) -> DynamicMTPPlan: + req_num = len(decode_reqs) + original_batch_size = model_input.batch_size + if req_num == 0 or self.mtp_step == 0: + keep_indices = torch.arange(original_batch_size, dtype=torch.long, device="cpu") + return DynamicMTPPlan(req_num, original_batch_size, original_batch_size, keep_indices, [], 1.0, 0.0) + + mean = self.accept_mean.get(float(self.max_rows_per_req)) + second = self.accept_second_moment.get(float(self.max_rows_per_req**2)) + variance = max(0.0, second - mean * mean) + std = math.sqrt(variance) + budget = math.ceil(req_num * mean + self.confidence_k * math.sqrt(req_num) * std) + dynamic_batch_size = max(req_num, min(original_batch_size, budget)) + + per_req_rows = self._allocate_rows(decode_reqs=decode_reqs, dynamic_batch_size=dynamic_batch_size) + keep_indices, per_req_rows = self._build_keep_indices(model_input=model_input, per_req_rows=per_req_rows) + dynamic_batch_size = int(keep_indices.numel()) + + return DynamicMTPPlan( + req_num=req_num, + original_batch_size=original_batch_size, + dynamic_batch_size=dynamic_batch_size, + keep_indices=keep_indices, + per_req_rows=per_req_rows, + estimated_accept_mean=mean, + estimated_accept_std=std, + ) + + def _allocate_rows(self, decode_reqs: List[InferReq], dynamic_batch_size: int) -> List[int]: + req_num = len(decode_reqs) + per_req_rows = [1 for _ in range(req_num)] + remaining = dynamic_batch_size - req_num + if remaining <= 0: + return per_req_rows + + req_order = sorted( + range(req_num), + key=lambda index: self._req_prob(decode_reqs[index]), + reverse=True, + ) + + for req_index in req_order: + req_prob = self._req_prob(decode_reqs[req_index]) + for _ in range(self.mtp_step): + if remaining <= 0: + break + if random.random() >= req_prob: + break + per_req_rows[req_index] += 1 + remaining -= 1 + if remaining <= 0: + break + return per_req_rows + + def _req_prob(self, req: InferReq) -> float: + req_prob = self.req_prob.get(req.req_id) + if req_prob is not None: + return self._clip_prob(req_prob.get(1.0)) + fallback = self.accept_mean.get(float(self.max_rows_per_req)) / float(self.max_rows_per_req) + return self._clip_prob(fallback) + + def _clip_prob(self, value: float) -> float: + return min(1.0, max(0.0, value)) + + def _build_keep_indices(self, model_input: ModelInput, per_req_rows: List[int]): + keep_indices = [] + effective_per_req_rows = [0 for _ in per_req_rows] + req_index = -1 + cur_req_kept = 0 + cur_req_target = 0 + for index, mtp_index in enumerate(model_input.b_mtp_index.tolist()): + if mtp_index == 0: + req_index += 1 + cur_req_kept = 0 + cur_req_target = per_req_rows[req_index] + if cur_req_kept < cur_req_target: + keep_indices.append(index) + cur_req_kept += 1 + effective_per_req_rows[req_index] += 1 + return torch.tensor(keep_indices, dtype=torch.long, device="cpu"), effective_per_req_rows + + def _invert_indices(self, keep_indices: torch.Tensor, total_size: int) -> torch.Tensor: + keep_mask = torch.zeros((total_size,), dtype=torch.bool, device="cpu") + keep_mask[keep_indices] = True + return torch.nonzero(~keep_mask, as_tuple=False).view(-1) + + def _build_mtp_shared_group_infos(self, run_reqs: List[InferReq]) -> torch.Tensor: + max_batch_shared_group_size = get_diverse_max_batch_shared_group_size() + req_ids = [req.req_id for req in run_reqs] + b_mark_shared_group = [] + current_group = [] + for req_id in req_ids: + if not current_group: + current_group.append(req_id) + elif req_id == current_group[-1]: + current_group.append(req_id) + else: + b_mark_shared_group.extend([0 for _ in range(len(current_group))]) + b_mark_shared_group[-1] = len(current_group) + current_group.clear() + current_group.append(req_id) + + if len(current_group) == max_batch_shared_group_size: + b_mark_shared_group.extend([0 for _ in range(len(current_group))]) + b_mark_shared_group[-1] = len(current_group) + current_group.clear() + + if current_group: + b_mark_shared_group.extend([0 for _ in range(len(current_group))]) + b_mark_shared_group[-1] = len(current_group) + + return torch.tensor(b_mark_shared_group, dtype=torch.int32, device="cpu") From 7e67469c4fb497f504863f2a073dd83ad8da2dc6 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 18 May 2026 07:08:02 +0000 Subject: [PATCH 06/12] reback --- .../mode_backend/chunked_prefill/impl.py | 101 +++++++++--------- 1 file changed, 48 insertions(+), 53 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 71b22c08e..c41dbb6d9 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -25,7 +25,6 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args, enable_dynamic_mtp_verify -from lightllm.server.router.model_infer.mode_backend.dynamic_mtp_planner import DynamicMTPPlanner from .control_state import ControlState logger = init_logger(__name__) @@ -46,9 +45,6 @@ def __init__(self) -> None: self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla self.enable_dynamic_mtp = enable_dynamic_mtp_verify() - self.dynamic_mtp_planner = ( - DynamicMTPPlanner(mtp_step=get_env_start_args().mtp_step) if self.enable_dynamic_mtp else None - ) else: self.prefill = self.prefill_normal self.decode = self.decode_normal @@ -237,25 +233,9 @@ def decode_mtp( """ MTP解码的通用流程,整合eagle和vanilla的共同逻辑 """ - if self.enable_dynamic_mtp: - # 让通用 pre-process 始终构建最大候选池,动态策略只在 forward 前裁剪。 - for req in decode_reqs: - req.current_mtp_step = req.mtp_step model_input, run_reqs = prepare_decode_inputs(decode_reqs) - dynamic_mtp_plan = None with torch.cuda.stream(g_infer_context.get_overlap_stream()): - if self.enable_dynamic_mtp: - model_input, run_reqs, dynamic_mtp_plan = self.dynamic_mtp_planner.trim_before_forward( - model_input=model_input, - run_reqs=run_reqs, - decode_reqs=decode_reqs, - ) - dynamic_mtp_start_event = None - if self.enable_dynamic_mtp: - dynamic_mtp_start_event = torch.cuda.Event(enable_timing=True) - dynamic_mtp_start_event.record() - b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) @@ -277,10 +257,9 @@ def decode_mtp( gpu_tensor=accepted_index, ) - verify_event = torch.cuda.Event(enable_timing=self.enable_dynamic_mtp) + verify_event = torch.cuda.Event() verify_event.record() - per_req_probs_cpu = None if self.enable_dynamic_mtp: all_next_token_ids, additional_mem_indexes_cpu, draft_probs_list = self._draft_decode_func( main_model_input=model_input, @@ -288,13 +267,6 @@ def decode_mtp( next_token_ids=next_token_ids, b_req_mtp_start_loc=b_req_mtp_start_loc, ) - draft_probs_tensor = torch.stack(draft_probs_list, dim=1) - request_start_rows = b_req_mtp_start_loc.to(torch.long) - per_req_probs = draft_probs_tensor[request_start_rows].mean(dim=1) - per_req_probs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( - key="dynamic_mtp_req_probs", - gpu_tensor=per_req_probs, - ) else: all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func( main_model_input=model_input, @@ -303,6 +275,18 @@ def decode_mtp( b_req_mtp_start_loc=b_req_mtp_start_loc, ) + # dynamic_sizes_gpu 用于第二阶段更新 req 的 mtp_size + if self.enable_dynamic_mtp: + draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0]) + dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor) + # 异步拷贝回 CPU Pin Memory + dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( + key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu + ) + + dynamic_mtp_event = torch.cuda.Event() + dynamic_mtp_event.record() + mtp_scatter_next_token_ids( req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, b_req_mtp_start_loc=b_req_mtp_start_loc, @@ -326,34 +310,26 @@ def decode_mtp( gpu_tensor=mtp_accept_len, ) - sync_event = torch.cuda.Event(enable_timing=self.enable_dynamic_mtp) + sync_event = torch.cuda.Event() sync_event.record() # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() - self._update_mtp_verify_token_num( - decode_reqs=decode_reqs, - verify_token_nums=dynamic_mtp_plan.per_req_rows if dynamic_mtp_plan is not None else None, - ) + self._update_mtp_verify_token_num(decode_reqs=decode_reqs) verify_event.synchronize() - dynamic_mtp_elapsed_ms = None accepted_index_cpu_numpy = accepted_index_cpu.numpy() verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1] + if self.enable_dynamic_mtp: + dynamic_mtp_event.synchronize() + self._update_dynamic_mtp_size_cpu_part( + run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu + ) update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) # 第三阶段 event_pack.notify_forward_and_wait_post_handle() sync_event.synchronize() - if self.enable_dynamic_mtp: - dynamic_mtp_elapsed_ms = dynamic_mtp_start_event.elapsed_time(sync_event) - self.dynamic_mtp_planner.update_after_verify( - plan=dynamic_mtp_plan, - decode_reqs=decode_reqs, - mtp_accept_len_cpu=mtp_accept_len_cpu, - elapsed_ms=dynamic_mtp_elapsed_ms, - per_req_probs_cpu=per_req_probs_cpu, - ) # 处理需要释放的内存索引 need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] @@ -379,6 +355,28 @@ def decode_mtp( event_pack.notify_pre_post_handle() return + def _compute_dynamic_mtp_size_gpu_part( + self, + draft_probs_tensor: torch.Tensor, + ) -> torch.Tensor: + rand_vals = torch.rand_like(draft_probs_tensor) + accepted_mask = draft_probs_tensor > rand_vals + valid_steps = torch.cumprod(accepted_mask.to(torch.int32), dim=0) + dynamic_mtp_sizes = valid_steps.sum(dim=0) + return dynamic_mtp_sizes + + def _update_dynamic_mtp_size_cpu_part( + self, + run_reqs: List[InferReq], + dynamic_sizes_cpu: torch.Tensor, + accepted_index_cpu: torch.Tensor, + ): + assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0] + for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()): + if int(accepted) == 1: + req.current_mtp_step = int(new_size) + assert req.current_mtp_step <= req.mtp_step + def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 draft_model_input = model_input @@ -407,24 +405,17 @@ def _draft_decode_vanilla( draft_next_token_ids = next_token_ids all_next_token_ids = [] all_next_token_ids.append(next_token_ids) - draft_probs_list = [] if self.enable_dynamic_mtp else None # process the draft model output for draft_model_idx in range(self.mtp_step): draft_model_input.input_ids = draft_next_token_ids draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens # spec decode: MTP draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) - if self.enable_dynamic_mtp: - draft_next_token_ids, draft_probs = self._gen_argmax_token_ids_and_prob(draft_model_output) - draft_probs_list.append(draft_probs) - else: - draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) + draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) all_next_token_ids.append(draft_next_token_ids) all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1] - if self.enable_dynamic_mtp: - return all_next_token_ids, None, draft_probs_list return all_next_token_ids, None def _draft_decode_eagle( @@ -450,6 +441,8 @@ def _draft_decode_eagle( draft_next_token_ids = next_token_ids all_next_token_ids = [] all_next_token_ids.append(next_token_ids) + + # 用于收集每个 step 的 probs draft_probs_list = [] if self.enable_dynamic_mtp else None # process the draft model output @@ -460,6 +453,7 @@ def _draft_decode_eagle( draft_model_idx = _step % self.num_mtp_models draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) + # 收集 probs(如果需要) if self.enable_dynamic_mtp: draft_next_token_ids, draft_probs = self._gen_argmax_token_ids_and_prob(draft_model_output) draft_probs_list.append(draft_probs) @@ -486,4 +480,5 @@ def _draft_decode_eagle( if self.enable_dynamic_mtp: return all_next_token_ids, eagle_mem_indexes_cpu, draft_probs_list - return all_next_token_ids, eagle_mem_indexes_cpu + else: + return all_next_token_ids, eagle_mem_indexes_cpu From b39600c3a07cebc76e0e545870306e89dc28c998 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 18 May 2026 07:11:25 +0000 Subject: [PATCH 07/12] add static mtp_step --- .../router/model_infer/mode_backend/generic_pre_process.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 3d9d8815e..c50e0b7c3 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -113,9 +113,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_mtp_index.append(0) multimodal_params.append(req.multimodal_params) # process the draft tokens. - # 动态 MTP 模式:使用动态 current_mtp_step 构建 batch - # 非动态 MTP 模式:current_mtp_step 为固定的 mtp_step - for step in range(req.current_mtp_step): + for step in range(req.mtp_step): run_reqs.append(req) b_req_idx.append(req.req_idx) seq_len += 1 From 783be5a7f7b73b40b4c1c57425bb24b7f9ec844a Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 18 May 2026 07:41:22 +0000 Subject: [PATCH 08/12] fix --- .../model_infer/mode_backend/base_backend.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index bde7fad52..321055b4b 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -775,19 +775,14 @@ def _update_mtp_accept_ratio( return - def _update_mtp_verify_token_num( - self, - decode_reqs: List[InferReq], - verify_token_nums: Optional[List[int]] = None, - ): + def _update_mtp_verify_token_num(self, decode_reqs: List[InferReq]): if self.is_master_in_dp: - if verify_token_nums is None: - verify_token_nums = [1 + req.current_mtp_step for req in decode_reqs] - assert len(decode_reqs) == len(verify_token_nums) - for req, verify_token_num in zip(decode_reqs, verify_token_nums): - # 统计发送给主模型验证的 token 数量,动态 MTP 模式由 planner 传入实际裁剪后的行数。 - assert verify_token_num >= 1 - req.update_mtp_verify_token_num(verify_token_num=verify_token_num) + for req in decode_reqs: + # 统计发送给主模型验证的 token 数量:1 个主 token + 当前 mtp_size 个 draft token + # 在静态 MTP 模式下,使用固定的 mtp_step;在动态 MTP 模式下,使用动态调整的 current_mtp_step + # current_mtp_step 在静态 MTP 模式下为 mtp_step,在动态 MTP 模式下会在推理过程中动态设置。 + assert req.current_mtp_step >= 0 + req.update_mtp_verify_token_num(verify_token_num=1 + req.current_mtp_step) return def _gen_argmax_token_ids(self, model_output: ModelOutput): From 0054d08fb9904c24d6467951014ecae6b86ac52f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 18 May 2026 08:18:06 +0000 Subject: [PATCH 09/12] fix --- .../basemodel/triton_kernel/mtp_utils.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/lightllm/common/basemodel/triton_kernel/mtp_utils.py b/lightllm/common/basemodel/triton_kernel/mtp_utils.py index 2d70a68c0..843a77d96 100644 --- a/lightllm/common/basemodel/triton_kernel/mtp_utils.py +++ b/lightllm/common/basemodel/triton_kernel/mtp_utils.py @@ -1,3 +1,4 @@ +from typing import Optional import triton import triton.language as tl import torch @@ -93,10 +94,15 @@ def _fwd_kernel_mtp_scatter_next_token_ids( req_to_next_token_ids_stride, all_next_token_ids, all_next_token_ids_stride, + req_to_next_token_probs, + req_to_next_token_probs_stride, + all_next_token_probs, + all_next_token_probs_stride, mtp_accept_len, b_req_mtp_start_loc, b_req_idx, mtp_step, + HAS_HAS_NEXT_TOKEN_PROBS: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): @@ -106,6 +112,17 @@ def _fwd_kernel_mtp_scatter_next_token_ids( cur_req_idx = tl.load(b_req_idx + req_start_loc) offset = tl.arange(0, BLOCK_SIZE) + if HAS_HAS_NEXT_TOKEN_PROBS: + cur_next_token_probs = tl.load( + all_next_token_probs + (req_start_loc + accept_len - 1) * all_next_token_probs_stride + offset, + mask=offset < mtp_step, + other=0.0, + ) + tl.store( + req_to_next_token_probs + cur_req_idx * req_to_next_token_probs_stride + offset, + cur_next_token_probs, + mask=offset < mtp_step, + ) scatter_next_token_ids = tl.load( all_next_token_ids + (req_start_loc + accept_len - 1) * all_next_token_ids_stride + offset, mask=offset < mtp_step, @@ -125,12 +142,20 @@ def mtp_scatter_next_token_ids( all_next_token_ids: torch.Tensor, b_req_idx: torch.Tensor, mtp_accept_len: torch.Tensor, + req_to_next_token_probs: Optional[torch.Tensor] = None, + all_next_token_probs: Optional[torch.Tensor] = None, ): max_mtp_step = req_to_next_token_ids.shape[1] BLOCK_SIZE = 16 assert max_mtp_step <= BLOCK_SIZE, f"max_mtp_step must be less than {BLOCK_SIZE}" num_reqs = b_req_mtp_start_loc.shape[0] mtp_step = all_next_token_ids.shape[1] + if req_to_next_token_probs is not None: + assert all_next_token_probs is not None + assert all_next_token_probs.shape == all_next_token_ids.shape + + HAS_HAS_NEXT_TOKEN_PROBS = req_to_next_token_probs is not None + grid = (num_reqs,) num_warps = 1 _fwd_kernel_mtp_scatter_next_token_ids[grid]( @@ -138,10 +163,15 @@ def mtp_scatter_next_token_ids( req_to_next_token_ids_stride=req_to_next_token_ids.stride(0), all_next_token_ids=all_next_token_ids, all_next_token_ids_stride=all_next_token_ids.stride(0), + req_to_next_token_probs=req_to_next_token_probs, + req_to_next_token_probs_stride=req_to_next_token_probs.stride(0), + all_next_token_probs=all_next_token_probs, + all_next_token_probs_stride=all_next_token_probs.stride(0), mtp_accept_len=mtp_accept_len, b_req_mtp_start_loc=b_req_mtp_start_loc, b_req_idx=b_req_idx, mtp_step=mtp_step, + HAS_HAS_NEXT_TOKEN_PROBS=HAS_HAS_NEXT_TOKEN_PROBS, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_stages=1, From ccc0f78779e424221122fa46f812db7b0ae00824 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 18 May 2026 08:23:12 +0000 Subject: [PATCH 10/12] fix req_manager --- lightllm/common/req_manager.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 3a4e2b631..278b3509c 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -6,7 +6,7 @@ from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter -from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, enable_dynamic_mtp_verify from lightllm.utils.config_utils import get_vocab_size from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager @@ -116,6 +116,15 @@ def __init__(self, max_request_num): dtype=torch.int64, device="cuda", ) + if enable_dynamic_mtp_verify(): + self.req_to_next_token_probs = torch.zeros( + (max_request_num + 1, 16), + dtype=torch.float32, + device="cuda", + ) + else: + self.req_to_next_token_probs = None + self.req_to_exponential_decay_length_penalty = torch.zeros( max_request_num + 1, dtype=torch.float32, device="cuda" ) @@ -137,6 +146,9 @@ def init_req_sampling_params(self, req): shm_param = req.sampling_param.shm_param self.req_to_next_token_ids[req.req_idx][0:1].fill_(req.get_last_gen_token()) + if enable_dynamic_mtp_verify(): + self.req_to_next_token_probs[req.req_idx].fill_(0.0) + self.req_to_next_token_probs[req.req_idx][0:1].fill_(1.0) self.req_to_presence_penalty[req.req_idx].fill_(shm_param.presence_penalty) self.req_to_frequency_penalty[req.req_idx].fill_(shm_param.frequency_penalty) self.req_to_repetition_penalty[req.req_idx].fill_(shm_param.repetition_penalty) From 326515ce99139ea812ca7156f258b23ee666f057 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 18 May 2026 08:55:42 +0000 Subject: [PATCH 11/12] fix --- .../model_infer/mode_backend/base_backend.py | 19 ++++-- .../mode_backend/chunked_prefill/impl.py | 64 +++++++++++++------ 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 321055b4b..4203b093f 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -4,6 +4,7 @@ import time import threading import torch.distributed as dist +import collections from typing import List, Tuple, Callable, Optional from transformers.configuration_utils import PretrainedConfig from lightllm.utils.infer_utils import set_random_seed @@ -775,14 +776,18 @@ def _update_mtp_accept_ratio( return - def _update_mtp_verify_token_num(self, decode_reqs: List[InferReq]): + def _update_mtp_verify_token_num( + self, decode_reqs: List[InferReq], dynamic_mtp_run_reqs: Optional[List[InferReq]] = None + ): if self.is_master_in_dp: - for req in decode_reqs: - # 统计发送给主模型验证的 token 数量:1 个主 token + 当前 mtp_size 个 draft token - # 在静态 MTP 模式下,使用固定的 mtp_step;在动态 MTP 模式下,使用动态调整的 current_mtp_step - # current_mtp_step 在静态 MTP 模式下为 mtp_step,在动态 MTP 模式下会在推理过程中动态设置。 - assert req.current_mtp_step >= 0 - req.update_mtp_verify_token_num(verify_token_num=1 + req.current_mtp_step) + if dynamic_mtp_run_reqs is None: + for req in decode_reqs: + assert req.mtp_step > 0 + req.update_mtp_verify_token_num(verify_token_num=1 + req.mtp_step) + else: + counter = collections.Counter([req.req_idx for req in dynamic_mtp_run_reqs]) + for req in decode_reqs: + req.update_mtp_verify_token_num(verify_token_num=1 + counter[req.req_idx] - 1) return def _gen_argmax_token_ids(self, model_output: ModelOutput): diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index c41dbb6d9..d68d34918 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -236,17 +236,27 @@ def decode_mtp( model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): - b_mtp_index_cpu = model_input.b_mtp_index + + if self.enable_dynamic_mtp: + # 根据当前的 batch size 和 dynamic_batch_size 计算出需要裁剪的 batch size 的model_input + dynamic_batch_size = 10 # TODO: 需要根据实际情况计算出 dynamic_batch_size + trans_to_dynamic_model_input = None # TODO: 需要根据实际情况实现 trans_to_dynamic_model_input + model_input, selected_run_reqs = trans_to_dynamic_model_input(model_input, dynamic_batch_size) + # selected_run_reqs 是一个 gpu tensor, 类型为 int, 0, 表示没有选中, 1 表示选中。 + + selected_run_reqs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( + key="selected_run_reqs", + gpu_tensor=selected_run_reqs, + ) + trans_dynamic_model_input_event = torch.cuda.Event() + trans_dynamic_model_input_event.record() + model_output = self.model.forward(model_input) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) # verify the next_token_ids - b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0] - b_req_mtp_start_loc = g_pin_mem_manager.gen_from_list( - key="b_req_mtp_start_loc", - data=b_req_mtp_start_loc, - dtype=torch.int32, - ).cuda(non_blocking=True) - + get_b_req_mtp_start_loc = None # TODO: 需要根据实际情况实现 get_b_req_mtp_start_loc + b_req_mtp_start_loc = get_b_req_mtp_start_loc(model_input.b_mtp_index, req_num=len(decode_reqs)) + # b_req_mtp_start_loc 是一个 gpu tensor, 类型为 int, 表示每个请求的 mtp_start_loc, shape 为 len(decode_reqs) mtp_accept_len, accepted_index = self._verify_mtp_v2( new_next_token_ids=next_token_ids, b_req_idx=model_input.b_req_idx, @@ -277,15 +287,20 @@ def decode_mtp( # dynamic_sizes_gpu 用于第二阶段更新 req 的 mtp_size if self.enable_dynamic_mtp: - draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0]) + draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view( + self.mtp_step, model_input.b_mtp_index.shape[0] + ) dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor) # 异步拷贝回 CPU Pin Memory dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor( key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu ) - - dynamic_mtp_event = torch.cuda.Event() - dynamic_mtp_event.record() + dynamic_sizes_cpu # TODO, use to update statcis. + draft_probs_list = [e.view(-1, 1) for e in draft_probs_list] + draft_probs_list = [torch.ones_like(draft_probs_list[-1])] + draft_probs_list + all_next_token_probs = torch.cat(draft_probs_list, dim=-1) # [batch_size, mtp_step + 1] + else: + all_next_token_probs = None mtp_scatter_next_token_ids( req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, @@ -293,6 +308,8 @@ def decode_mtp( all_next_token_ids=all_next_token_ids, b_req_idx=model_input.b_req_idx, mtp_accept_len=mtp_accept_len, + req_to_next_token_probs=self.model.req_manager.req_sampling_params_manager.req_to_next_token_probs, + all_next_token_probs=all_next_token_probs, ) next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem( @@ -315,22 +332,33 @@ def decode_mtp( # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() - self._update_mtp_verify_token_num(decode_reqs=decode_reqs) + + if self.enable_dynamic_mtp: + trans_dynamic_model_input_event.synchronize() + selected_run_reqs_cpu_numpy = selected_run_reqs_cpu.numpy() + run_reqs = [run_reqs[i] for i in range(len(run_reqs)) if selected_run_reqs_cpu_numpy[i] == 1] + + if self.enable_dynamic_mtp: + self._update_mtp_verify_token_num(decode_reqs=decode_reqs, dynamic_mtp_run_reqs=run_reqs) + else: + self._update_mtp_verify_token_num(decode_reqs=decode_reqs) verify_event.synchronize() accepted_index_cpu_numpy = accepted_index_cpu.numpy() verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1] - if self.enable_dynamic_mtp: - dynamic_mtp_event.synchronize() - self._update_dynamic_mtp_size_cpu_part( - run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu - ) update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False) # 第三阶段 event_pack.notify_forward_and_wait_post_handle() sync_event.synchronize() + if self.enable_dynamic_mtp: + # TODO: 更新动态 mtp step 步的相关信息到 planer中,进行相关的信息统计。便于分析。 + # self._update_dynamic_mtp_size_cpu_part( + # run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu + # ) + pass + # 处理需要释放的内存索引 need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] if additional_mem_indexes_cpu is not None: From 19fbb69fbc1940af71b6e2c89895137b7bd4ddfb Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 18 May 2026 09:22:51 +0000 Subject: [PATCH 12/12] fix --- .../server/router/model_infer/infer_batch.py | 5 +---- .../mode_backend/chunked_prefill/impl.py | 19 ++++++++++++------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 46608da13..7b931f153 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -435,10 +435,7 @@ def __init__( # mtp_step 用来记录一个请求 draft模型每步需要生成的token数量 # 正常模式下,这个值为0,在 mtp 模式下,这个值为 draft 模型每步需要生成的token数量 self.mtp_step: int = get_env_start_args().mtp_step - # current_mtp_step 用来记录当前的 MTP 验证长度(<= mtp_step) - # 在启用动态 MTP 验证时,每步会根据 prob 分布重新设置该值 - # 静态模式下为 mtp_step,动态模式下为动态计算的 MTP 验证长度 - self.current_mtp_step: int = self.mtp_step + if self.mtp_step > 0: self.decode_need_token_num = self._mtp_decode_need_token_num else: diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index d68d34918..392c826b9 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -353,11 +353,12 @@ def decode_mtp( sync_event.synchronize() if self.enable_dynamic_mtp: - # TODO: 更新动态 mtp step 步的相关信息到 planer中,进行相关的信息统计。便于分析。 - # self._update_dynamic_mtp_size_cpu_part( - # run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu - # ) - pass + self._update_dynamic_mtp_size_cpu_part( + decode_reqs=decode_reqs, + run_reqs=run_reqs, + dynamic_sizes_cpu=dynamic_sizes_cpu, + accepted_index_cpu=accepted_index_cpu, + ) # 处理需要释放的内存索引 need_free_mem_indexes = model_input.mem_indexes_cpu[accepted_index_cpu == 0] @@ -395,15 +396,19 @@ def _compute_dynamic_mtp_size_gpu_part( def _update_dynamic_mtp_size_cpu_part( self, + decode_reqs: List[InferReq], run_reqs: List[InferReq], dynamic_sizes_cpu: torch.Tensor, accepted_index_cpu: torch.Tensor, ): + id_to_current_mtp_step = {} assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0] for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()): if int(accepted) == 1: - req.current_mtp_step = int(new_size) - assert req.current_mtp_step <= req.mtp_step + assert int(new_size) <= req.mtp_step + id_to_current_mtp_step[req.req_idx] = int(new_size) + # TODO 将 id_to_current_mtp_step 的信息更新到 planner 中去 + pass def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。