Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions extension/llm/runner/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def define_common_targets():
visibility = ["PUBLIC"],
exported_deps = [
":text_decoder_runner" + aten_suffix,
"//executorch/extension/llm/sampler:sampler" + aten_suffix,
"//pytorch/tokenizers:headers",
"//executorch/extension/module:module" + aten_suffix,
"//executorch/extension/tensor:tensor" + aten_suffix,
Expand Down
81 changes: 81 additions & 0 deletions extension/llm/runner/text_token_generator.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand All @@ -10,10 +10,14 @@
#pragma once

#include <atomic>
#include <memory>
#include <vector>

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/sampler/logit_processor.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <pytorch/tokenizers/tokenizer.h>

namespace executorch {
Expand All @@ -38,6 +42,20 @@
ignore_eos_ = ignore_eos;
}

void add_logit_processor(std::shared_ptr<LogitProcessor> processor) {
if (processor) {
logit_processors_.push_back(std::move(processor));
}
}

// Drop every registered processor; subsequent decoding steps sample the
// raw logits untouched.
void clear_logit_processors() {
  logit_processors_.clear();
}

// Number of processors currently registered in the chain.
size_t num_logit_processors() const {
  return logit_processors_.size();
}

virtual ~TextTokenGenerator() = default;

/**
Expand Down Expand Up @@ -109,6 +127,10 @@

prev_token = cur_token;

if (!logit_processors_.empty()) {
ET_CHECK_OK_OR_RETURN_ERROR(apply_logit_processors_(logits_tensor));
}
Comment on lines +130 to +132

stats_->on_sampling_begin();
cur_token =
text_decoder_runner_->logits_to_token(logits_tensor, temperature);
Expand Down Expand Up @@ -177,6 +199,63 @@
}

private:
// Run every registered LogitProcessor, in registration order, over the
// logits of the most recent position in `logits_tensor`.
//
// Float logits are mutated in place; Half/BFloat16/UInt16 logits are
// widened into a temporary float buffer, processed, and narrowed back,
// since LogitProcessor::process() only operates on float*.
inline ::executorch::runtime::Error apply_logit_processors_(
    ::executorch::aten::Tensor& logits_tensor) {
  // Require at least [.., vocab]; a 3-D tensor is treated as
  // [batch, seq, vocab] below -- assumption based on the dim==3 branch,
  // TODO(review): confirm against the decoder's actual output layout.
  ET_CHECK_OR_RETURN_ERROR(
      logits_tensor.dim() >= 2,
      InvalidArgument,
      "LogitProcessor expects logits with dim >= 2, got %d",
      static_cast<int>(logits_tensor.dim()));

  // Vocabulary size is the innermost dimension.
  const int32_t vocab_size = logits_tensor.size(logits_tensor.dim() - 1);
  // Element offset of the row handed to the processors (0 for 2-D input).
  int32_t offset = 0;
  if (logits_tensor.dim() == 3) {
    const int32_t num_tokens = logits_tensor.size(1);
    ET_CHECK_OR_RETURN_ERROR(
        num_tokens > 0,
        InvalidArgument,
        "LogitProcessor expects non-empty sequence dimension");
    // Only the last token's logits reach the sampler, so only that row
    // is processed.
    offset = (num_tokens - 1) * vocab_size;
  }

  // Fast path: float logits can be rewritten directly, no copy needed.
  if (logits_tensor.scalar_type() ==::executorch::aten::ScalarType::Float) {
    auto* logits = logits_tensor.mutable_data_ptr<float>() + offset;
    for (auto& processor : logit_processors_) {
      processor->process(logits);
    }
    return ::executorch::runtime::Error::Ok;
  }

  // Context object consumed by ET_SWITCH_*. NOTE(review): fail() aborts
  // via ET_CHECK_MSG instead of propagating an Error on an unsupported
  // dtype -- confirm a hard abort is the intended failure mode here.
  struct {
    [[noreturn]] void fail(torch::executor::Error /* error */) {
      ET_CHECK_MSG(false, "Unsupported dtype in apply_logit_processors_");
    }
  } ctx;

  // Slow path: widen to float, run the chain, narrow back in place.
  std::vector<float> temp(vocab_size);
  ET_SWITCH_THREE_TYPES(
      Half,
      BFloat16,
      UInt16,
      logits_tensor.scalar_type(),
      ctx,
      "apply_logit_processors_",
      CTYPE,
      [&]() {
        auto* logits = logits_tensor.mutable_data_ptr<CTYPE>() + offset;
        for (int32_t i = 0; i < vocab_size; ++i) {
          temp[i] = static_cast<float>(logits[i]);
        }
        for (auto& processor : logit_processors_) {
          processor->process(temp.data());
        }
        for (int32_t i = 0; i < vocab_size; ++i) {
          logits[i] = static_cast<CTYPE>(temp[i]);
        }
      });
  return ::executorch::runtime::Error::Ok;
}

/**
* Note: TextTokenGenerator does not own the tokenizer_ and
* text_decoder_runner_. The lifecycle of these objects should be managed
Expand All @@ -189,6 +268,8 @@
bool use_kv_cache_;
bool ignore_eos_ = false;

std::vector<std::shared_ptr<LogitProcessor>> logit_processors_;

// state machine
std::atomic<bool> should_stop_{false};

Expand Down
57 changes: 57 additions & 0 deletions extension/llm/sampler/logit_processor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>

#include <executorch/runtime/platform/compiler.h>

namespace executorch {
namespace extension {
namespace llm {

/**
* Interface for in-place logit transformations applied between the model's
* forward pass and the sampler. Examples include:
* - Grammar / constrained-decoding masks (set disallowed tokens to -inf)
* - Logit bias (additive per-token bias)
* - Custom debug instrumentation
*
* A `TextTokenGenerator` may be configured with a chain of processors. They
* are invoked in order on every decoding step, before the sampler sees the
* logits. Each processor mutates the buffer in place; later processors
* observe earlier processors' modifications.
*
* Implementations must be cheap to call repeatedly — `process()` runs on the
* critical path of every generated token.
*/
class ET_EXPERIMENTAL LogitProcessor {
public:
explicit LogitProcessor(int32_t vocab_size) : vocab_size_(vocab_size) {}
virtual ~LogitProcessor() = default;

/**
* Modify logits in place for the current decoding step.
*
* @param logits Mutable pointer to the logits buffer for the current
* step. Must contain at least `vocab_size` elements.
*/
virtual void process(float* logits) = 0;

int32_t vocab_size() const {
return vocab_size_;
}

private:
int32_t vocab_size_;
};

} // namespace llm
} // namespace extension
} // namespace executorch
1 change: 1 addition & 0 deletions extension/llm/sampler/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def define_common_targets():
runtime.cxx_library(
name = "sampler" + aten_suffix,
exported_headers = [
"logit_processor.h",
"sampler.h",
"util.h",
],
Expand Down
10 changes: 10 additions & 0 deletions extension/llm/sampler/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,13 @@ def define_common_targets():
"//caffe2:torch-cpp",
],
)

# Unit test for the LogitProcessor interface and its sample implementations.
runtime.cxx_test(
    name = "test_logit_processor",
    srcs = [
        "test_logit_processor.cpp",
    ],
    deps = [
        # logit_processor.h is exported by the sampler library target.
        "//executorch/extension/llm/sampler:sampler",
    ],
)
136 changes: 136 additions & 0 deletions extension/llm/sampler/test/test_logit_processor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/llm/sampler/logit_processor.h>

#include <limits>
#include <memory>
#include <vector>

#include <gtest/gtest.h>

using ::executorch::extension::llm::LogitProcessor;

namespace {

// Adds a fixed bias to every logit slot. Records how many times it was
// invoked so tests can verify chain ordering.
// Shifts every logit slot by a constant bias and counts invocations so
// tests can verify how often the chain ran it.
class AddBiasProcessor : public LogitProcessor {
 public:
  AddBiasProcessor(int32_t vocab_size, float bias)
      : LogitProcessor(vocab_size), shift_(bias) {}

  void process(float* logits) override {
    ++invocations_;
    float* const end = logits + vocab_size();
    for (float* slot = logits; slot != end; ++slot) {
      *slot += shift_;
    }
  }

  int call_count() const {
    return invocations_;
  }

 private:
  float shift_;
  int invocations_ = 0;
};

// Scales every logit slot by a constant factor.
class MultiplyProcessor : public LogitProcessor {
 public:
  MultiplyProcessor(int32_t vocab_size, float factor)
      : LogitProcessor(vocab_size), scale_(factor) {}

  void process(float* logits) override {
    const int32_t n = vocab_size();
    for (int32_t slot = 0; slot < n; ++slot) {
      logits[slot] *= scale_;
    }
  }

 private:
  float scale_;
};

// Forces one token's logit to -inf, the pattern grammar/constraint
// processors follow. Out-of-range banned tokens are a no-op.
class MaskTokenProcessor : public LogitProcessor {
 public:
  MaskTokenProcessor(int32_t vocab_size, int32_t banned_token)
      : LogitProcessor(vocab_size), banned_(banned_token) {}

  void process(float* logits) override {
    const bool in_range = banned_ >= 0 && banned_ < vocab_size();
    if (in_range) {
      logits[banned_] = -std::numeric_limits<float>::infinity();
    }
  }

 private:
  int32_t banned_;
};

} // namespace

// A single processor sees the buffer and may mutate it in place.
// A lone processor sees the buffer and may rewrite it in place.
TEST(LogitProcessorTest, SingleProcessorMutatesLogits) {
  std::vector<float> buffer{1.0f, 2.0f, 3.0f, 4.0f};
  AddBiasProcessor shifter(static_cast<int32_t>(buffer.size()), 10.0f);

  shifter.process(buffer.data());

  EXPECT_EQ(buffer, (std::vector<float>{11.0f, 12.0f, 13.0f, 14.0f}));
  EXPECT_EQ(shifter.call_count(), 1);
}

// Multiply(×2) then Add(+1) gives (x*2)+1, which differs from
// Add(+1) then Multiply(×2) = (x+1)*2. Non-commutative operations
// verify that processors run in registration order.
// Multiply(x2) followed by Add(+1) yields (x*2)+1; the reverse order
// would yield (x+1)*2. Non-commutative stages therefore prove the chain
// runs in registration order.
TEST(LogitProcessorTest, ProcessorChainAppliesInOrder) {
  std::vector<float> logits{1.0f, 2.0f, 3.0f, 4.0f};
  const auto n = static_cast<int32_t>(logits.size());

  std::vector<std::shared_ptr<LogitProcessor>> pipeline;
  pipeline.push_back(std::make_shared<MultiplyProcessor>(n, 2.0f));
  pipeline.push_back(std::make_shared<AddBiasProcessor>(n, 1.0f));

  for (const auto& stage : pipeline) {
    // NOLINTNEXTLINE(facebook-hte-Deprecated)
    stage->process(logits.data());
  }

  // (x*2)+1, NOT (x+1)*2.
  EXPECT_EQ(logits, (std::vector<float>{3.0f, 5.0f, 7.0f, 9.0f}));
}

// A masking processor zeroes (well, -inf's) a specific token slot. This is
// the pattern grammar processors will follow.
// A masking processor pins a chosen slot to -inf — the pattern grammar
// processors will follow to steer the sampler away from a token.
TEST(LogitProcessorTest, MaskTokenDrivesArgmaxAway) {
  constexpr float kNegInf = -std::numeric_limits<float>::infinity();
  std::vector<float> logits{0.1f, 0.2f, 0.99f, 0.4f}; // argmax is slot 2

  MaskTokenProcessor masker(
      static_cast<int32_t>(logits.size()), /*banned_token=*/2);
  masker.process(logits.data());

  EXPECT_EQ(logits, (std::vector<float>{0.1f, 0.2f, kNegInf, 0.4f}));
}

// Banned indices outside [0, vocab) must leave the buffer untouched.
TEST(LogitProcessorTest, MaskTokenOutOfRangeIsNoOp) {
  const std::vector<float> original{1.0f, 2.0f, 3.0f};
  std::vector<float> logits = original;
  const auto n = static_cast<int32_t>(logits.size());

  // Index past the end of the vocabulary.
  MaskTokenProcessor over(n, /*banned_token=*/99);
  over.process(logits.data());
  EXPECT_EQ(logits, original);

  // Negative index.
  MaskTokenProcessor negative(n, /*banned_token=*/-1);
  negative.process(logits.data());
  EXPECT_EQ(logits, original);
}
Loading