[Model Runner V2] Spec decode rejection sampler greedy support#37238
Conversation
```python
    num_reqs, num_speculative_steps + 1, dtype=torch.int64
)
# [num_reqs]
rejected_steps = sampled.new_empty(num_reqs)
_probabilistic_rejection_sample_kernel[(num_reqs,)](
# [num_reqs]
rejected_pos = pos.new_empty(num_reqs)
```
I felt it made more sense to compute this in _probabilistic_rejection_kernel rather than _compute_residual_logits_kernel, so I moved it here. I also renamed it from residual_pos to rejected_pos.
Hi @TheEpicDolphin, the pre-commit checks have failed. Please run:

```shell
uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
Code Review
This pull request adds support for greedy sampling (temperature=0) in the speculative decoding rejection sampler. The changes are well-structured, introducing new Triton kernels to handle greedy and probabilistic paths efficiently. The logic for rejection sampling and resampling in the greedy case is sound. I've found one potential issue in a newly added but currently unused kernel that should be addressed.
WoosukKwon left a comment
Thanks for the PR!
I think we can fuse more kernels to minimize the materialization of *_logits tensors, but we can probably follow up after this.
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
f188893 to
47f633e
Compare
…project#37238) Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai> Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com>
Hi @TheEpicDolphin, what GPU was used in this test, and what vLLM settings regarding max_batched_tokens etc.? Also, have you done any tests with prefix caching enabled?
@geraldstanje1 I used an H200, and I didn't have the chance to test with max_batched_tokens or prefix caching. Also, please ignore the benchmark results from this PR; there was a bug skewing the acceptance rates that I fixed later. This PR has the most recent benchmark results for rejection sampling :)
@TheEpicDolphin can you show how you run those benchmarks in #38496? I assume you also have speculative decoding enabled?
@geraldstanje1 Yep, I used spec decoding and compared strict vs. probabilistic rejection sampling methods. I added the server/benchmark commands to #38496.
Purpose
Following up on #35461, specifically with support for greedy sampling (temperature = 0).
To support this efficiently, I get the local argmax/max from the target logits for greedy requests in _gather_draft_logits_and_target_argmax_kernel. Then, during _probabilistic_rejection_kernel, the target argmax token is sampled only for the greedy requests. This limits the performance impact of greedy requests on the rest of the batch.
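For reference, the greedy (temperature = 0) path can be sketched in plain PyTorch. This is a hypothetical, non-Triton illustration, not the PR's actual kernel code; the function and tensor names are made up for clarity. For a greedy request, a draft token is accepted only if it equals the target model's argmax at that position; at the first mismatch, the target argmax token is emitted instead, and if every draft token matches, the bonus token (the target argmax at the final position) is appended.

```python
import torch

def greedy_rejection_sample(draft_tokens: torch.Tensor,
                            target_argmax: torch.Tensor) -> torch.Tensor:
    """Reference greedy rejection sampling for a single request.

    draft_tokens:  [num_spec] tokens proposed by the draft model.
    target_argmax: [num_spec + 1] argmax of the target logits at each
                   position (one extra slot for the bonus token).
    Returns the longest matching prefix of draft_tokens, followed by one
    token taken from the target model's argmax.
    """
    num_spec = draft_tokens.numel()
    # Positions where the draft token equals the target argmax.
    matches = draft_tokens == target_argmax[:num_spec]
    # Index of the first mismatch (num_spec if everything matched).
    rejected_pos = int(matches.cumprod(dim=0).sum())
    # Accepted prefix plus the target token at the first rejected
    # position (or the bonus token when nothing was rejected).
    return torch.cat([draft_tokens[:rejected_pos],
                      target_argmax[rejected_pos:rejected_pos + 1]])

# Example: draft proposes [5, 7, 9]; target argmax is [5, 7, 2, 8].
# The first two tokens match, 9 != 2 is rejected, so the output is [5, 7, 2].
out = greedy_rejection_sample(torch.tensor([5, 7, 9]),
                              torch.tensor([5, 7, 2, 8]))
```

In the PR itself this per-request logic runs inside the fused Triton kernel, so greedy requests only add an argmax lookup rather than a full probabilistic accept/reject pass.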