[Model Runner V2] Spec decode rejection sampler greedy support#37238

Merged
WoosukKwon merged 1 commit into vllm-project:main from
TheEpicDolphin:gdelfin/mrv2-spec-decode-rejection-sample-greedy
Mar 18, 2026
Conversation

Collaborator

@TheEpicDolphin TheEpicDolphin commented Mar 16, 2026

Purpose

Following up on #35461, specifically with support for greedy sampling (temperature = 0).

To support this efficiently, I compute the local argmax/max from the target logits for greedy requests in _gather_draft_logits_and_target_argmax_kernel. Then, during _probabilistic_rejection_kernel, the target argmax token is sampled only for the greedy requests. This limits the performance impact of greedy requests on the rest of the batch.
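The greedy path described above boils down to a simple acceptance rule: compare each draft token to the target model's argmax at that position, accept the matching prefix, and emit the target argmax token at the first mismatch (or a bonus token if every draft was accepted). A minimal, self-contained Python sketch of that rule follows; the function name and list-based inputs are hypothetical, and the actual implementation runs as fused Triton kernels over batched tensors:

```python
def greedy_rejection_sample(draft_tokens, target_argmax):
    """Sketch of greedy (temperature=0) rejection sampling for one request.

    draft_tokens:  proposed tokens from the draft model (length k)
    target_argmax: argmax of the target logits at each position
                   (length k + 1; the extra entry is the bonus token)
    """
    out = []
    for d, t in zip(draft_tokens, target_argmax):
        if d == t:
            out.append(d)  # draft agrees with target argmax: accept
        else:
            out.append(t)  # first disagreement: emit target token, stop
            break
    else:
        # All drafts accepted: append the bonus token sampled from the
        # target's final-position logits.
        if len(target_argmax) > len(draft_tokens):
            out.append(target_argmax[len(draft_tokens)])
    return out
```

Per-request branching like this is exactly what the kernel avoids by precomputing the target argmax alongside the gather, so greedy requests ride the same batched rejection pass as probabilistic ones.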

@mergify mergify Bot added the v1 label Mar 16, 2026
@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-spec-decode-rejection-sample-greedy branch from d0bbcc1 to f6618c3 Compare March 16, 2026 23:17
Code excerpt from the diff under discussion:

```python
    num_reqs, num_speculative_steps + 1, dtype=torch.int64
)
# [num_reqs]
rejected_steps = sampled.new_empty(num_reqs)
# [num_reqs]
rejected_pos = pos.new_empty(num_reqs)
_probabilistic_rejection_sample_kernel[(num_reqs,)](
```
Collaborator Author

I felt it made more sense to compute this in _probabilistic_rejection_kernel rather than _compute_residual_logits_kernel, so I moved it here. I also renamed it from residual_pos to rejected_pos.
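For context, the quantity being discussed records where sampling diverged for each request. A toy sketch of deriving the first rejected position from a per-step acceptance mask (the function name and inputs are hypothetical; the real computation happens inside the Triton kernel across the whole batch):

```python
def first_rejected_pos(accepted):
    """Return the index of the first rejected draft token for one
    request, or len(accepted) if every draft token was accepted."""
    for i, ok in enumerate(accepted):
        if not ok:
            return i
    return len(accepted)
```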

@mergify
Contributor

mergify Bot commented Mar 16, 2026

Hi @TheEpicDolphin, the pre-commit checks have failed. Please run:

```shell
uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:

```shell
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
```

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for greedy sampling (temperature=0) in the speculative decoding rejection sampler. The changes are well-structured, introducing new Triton kernels to handle greedy and probabilistic paths efficiently. The logic for rejection sampling and resampling in the greedy case is sound. I've found one potential issue in a newly added but currently unused kernel that should be addressed.

Comment thread on vllm/v1/worker/gpu/spec_decode/rejection_sampler.py (outdated)
@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-spec-decode-rejection-sample-greedy branch 3 times, most recently from 2c46096 to f188893 Compare March 17, 2026 00:04
Collaborator

@WoosukKwon WoosukKwon left a comment


Thanks for the PR!

I think we can fuse more kernels to minimize the materialization of *_logits tensors, but we can probably follow up after this.

@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 18, 2026
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-spec-decode-rejection-sample-greedy branch from f188893 to 47f633e Compare March 18, 2026 20:32
@WoosukKwon WoosukKwon enabled auto-merge (squash) March 18, 2026 21:04
@WoosukKwon WoosukKwon merged commit 04244fd into vllm-project:main Mar 18, 2026
60 checks passed
@TheEpicDolphin TheEpicDolphin deleted the gdelfin/mrv2-spec-decode-rejection-sample-greedy branch March 18, 2026 23:02
ikaadil pushed a commit to ikaadil/vllm that referenced this pull request Mar 19, 2026
…project#37238)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com>
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
@geraldstanje1

geraldstanje1 commented Apr 9, 2026

Hi @TheEpicDolphin, what GPU was used in this test, and what vLLM settings (max_batched_tokens, etc.)? Also, have you done any tests with prefix caching enabled?

mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
@TheEpicDolphin
Collaborator Author

TheEpicDolphin commented Apr 10, 2026

@geraldstanje1 I used an H200, and I didn't have the chance to test with max_batched_tokens or prefix caching.

Also, please ignore the benchmark results from this PR; there was a bug skewing the acceptance rates that I fixed later. This PR has the most recent benchmark results for rejection sampling :)

@geraldstanje1

@TheEpicDolphin can you show how you ran those benchmarks in #38496? I assume you also had speculative decoding enabled.

@TheEpicDolphin
Collaborator Author

@geraldstanje1 Yep, I used spec decoding and compared the strict vs. probabilistic rejection sampling methods. I added the server/benchmark commands to #38496.


Labels

ready ONLY add when PR is ready to merge/full CI is needed v1

3 participants