Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3448,13 +3448,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.p_split = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}).set_env("LLAMA_ARG_DRAFT_P_SPLIT"));

// Registers the `--draft-p-min` option (value placeholder "P").
// The help text embeds the current params.speculative.p_min as the displayed default.
// Tagged for the SPECULATIVE, SERVER and CLI examples and given the env name
// LLAMA_ARG_DRAFT_P_MIN via set_env().
add_opt(common_arg(
{"--draft-p-min"}, "P",
string_format("minimum speculative decoding probability (greedy) (default: %.2f)", (double)params.speculative.p_min),
[](common_params & params, const std::string & value) {
// std::stof throws std::invalid_argument / std::out_of_range on unparsable or
// out-of-range text; the parsed value is stored as-is, with no range check here.
params.speculative.p_min = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_P_MIN"));

// Registers the `--draft-p-accept` option (value placeholder "P").
// Per its help text, this is the probability threshold at which a non-argmax draft token
// is still accepted, with 0.0 meaning greedy match only. The help text embeds the current
// params.speculative.p_accept as the displayed default.
// Tagged for the SPECULATIVE, SERVER and CLI examples and given the env name
// LLAMA_ARG_DRAFT_P_ACCEPT via set_env().
add_opt(common_arg(
{"--draft-p-accept"}, "P",
string_format("MTP draft acceptance probability threshold - accept non-argmax draft token if main model assigns it at least this probability (default: %.2f, 0.0 = greedy match only)", (double)params.speculative.p_accept),
[](common_params & params, const std::string & value) {
// std::stof throws std::invalid_argument / std::out_of_range on unparsable or
// out-of-range text; the parsed value is stored as-is, with no range check here.
// NOTE(review): values outside [0, 1] are not rejected — confirm whether that is intended.
params.speculative.p_accept = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_P_ACCEPT"));

add_opt(common_arg(
{"-cd", "--ctx-size-draft"}, "N",
string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx),
Expand Down
2 changes: 2 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@ struct common_params_speculative {
int32_t draft_block_size = 3;
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
float p_accept = 0.0f; // min probability for main model to accept a non-argmax MTP draft token


// ngram-based speculative decoding

Expand Down
30 changes: 21 additions & 9 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
return id;
}

// Samples a token at each index in `idxs`, compares it with the matching entry of `draft`,
// and collects the tokens to keep in `result`, which is returned.
// NOTE(review): this listing is a flattened diff. The two signature lines below look like the
// before/after versions of one line, and parts of the body are elided at the "Expand All"
// markers. Confirm every note here against the real file.
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first, float p_accept) {
// idxs must hold exactly one more entry than draft
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

std::vector<llama_token> result;
Expand All @@ -614,14 +614,27 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
size_t i = 0;
for (; i < draft.size(); i++) {
// sample at the index paired with draft position i
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

// NOTE(review): this accept/push pair runs before the draft comparison and is repeated in
// both branches below; it looks like the pre-change lines of the diff — confirm, since
// keeping both would accept and record `id` twice.
common_sampler_accept(gsmpl, id, true);

result.push_back(id);

if (draft[i] != id) {
// sampled token differs from the draft token; the threshold path is only taken for p_accept > 0
if (p_accept > 0.0f) {
// use already-computed GPU-side candidate probabilities
// avoids 262K float softmax on CPU per rejection
// NOTE(review): relies on cur_p->data[j].p already holding probabilities after the
// sample call above — confirm for every sampler chain configuration.
llama_token_data_array * cur_p = common_sampler_get_candidates(gsmpl, false);
// linear scan of the candidate list for the draft token
for (size_t j = 0; j < cur_p->size; j++) {
// keep the draft token when its candidate probability is at least p_accept
if (cur_p->data[j].id == draft[i] && cur_p->data[j].p >= p_accept) {
common_sampler_accept(gsmpl, draft[i], true);
result.push_back(draft[i]);
goto next_draft_token;
// NOTE(review): unreachable — the goto above always transfers control first.
break;
}
}
}
// draft token not kept: record the sampled token and stop checking further draft tokens
common_sampler_accept(gsmpl, id, true);
result.push_back(id);
break;
}
// sampled token equals the draft token: record it and move on to the next draft position
common_sampler_accept(gsmpl, id, true);
result.push_back(id);
// jump target for the threshold path; the trailing `;` gives the label a statement to attach to
next_draft_token:;
}

// reached only when the loop ran to completion without hitting `break`
if (i == draft.size()) {
Expand All @@ -635,13 +648,12 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
return result;
}

// Convenience overload: builds idxs = [0, 1, ..., draft.size()] and forwards to the
// idxs-taking overload, returning its result unchanged.
// NOTE(review): the paired signature lines and paired return lines below look like the
// before/after versions from a flattened diff (the later of each pair threads p_accept
// through) — confirm against the real file.
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first, float p_accept) {
// one index per draft token plus one extra entry
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
}

return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first, p_accept);
}

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
Expand Down
4 changes: 2 additions & 2 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false, float p_accept = 0.0f);

// assume idxs == [ 0, 1, 2, ..., draft.size() ]
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false, float p_accept = 0.0f);

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);

Expand Down
2 changes: 1 addition & 1 deletion tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2938,7 +2938,7 @@ struct server_context_impl {
const size_t n_draft = slot.drafted.size();

// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted, false, params_base.speculative.p_accept);

// For MTP speculation, h_prev for the next draft must come from the LAST ACCEPTED
// batch output - not embeddings_ith(-1), which would point at a rejected draft's
Expand Down
Loading