Skip to content

Commit c4e71f6

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
Thread method-scoped kernel registry through Program and Method
Summary: Certain kernels might make optimizations that are broadly optimal but sub optimal for a specific model. In those scenarios it is useful to expose a backdoor for the exception method to defer to a different implementation without forcing the root imlementation to have to handle all possible dispatches. This is just a proposal impl because things still get a little weird because ET today tends to have kernel impls get auto registered. Might need follow ups to allow generating boxed kernels separately from registering them into ETs generic kernel registry. Differential Revision: D98080033
1 parent 20e2582 commit c4e71f6

7 files changed

Lines changed: 250 additions & 9 deletions

File tree

runtime/executor/method.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -802,8 +802,19 @@ Error Method::resolve_operator(
802802
}
803803

804804
// Find a kernel with the matching name and tensor meta.
805-
Result<OpFunction> op_function =
806-
get_op_function_from_registry(operator_name, {meta, count});
805+
// Try method-scoped registry first (if provided), then fall back to global.
806+
auto resolve_op_function = [&]() -> Result<OpFunction> {
807+
if (!kernel_registry_.empty()) {
808+
Result<OpFunction> method_scoped_op_function =
809+
get_op_function_from_registry(
810+
operator_name, {meta, count}, kernel_registry_);
811+
if (method_scoped_op_function.ok()) {
812+
return method_scoped_op_function;
813+
}
814+
}
815+
return get_op_function_from_registry(operator_name, {meta, count});
816+
};
817+
Result<OpFunction> op_function = resolve_op_function();
807818
if (!op_function.ok()) {
808819
ET_LOG(
809820
Error,
@@ -831,7 +842,8 @@ Result<Method> Method::load(
831842
MemoryManager* memory_manager,
832843
EventTracer* event_tracer,
833844
const NamedDataMap* external_data_map,
834-
const LoadBackendOptionsMap* backend_options) {
845+
const LoadBackendOptionsMap* backend_options,
846+
Span<const Kernel> kernel_registry) {
835847
MemoryAllocator* temp_allocator = memory_manager->temp_allocator();
836848
if (temp_allocator == nullptr) {
837849
PlatformMemoryAllocator* platform_allocator =
@@ -844,7 +856,8 @@ Result<Method> Method::load(
844856
new (platform_allocator) PlatformMemoryAllocator();
845857
temp_allocator = platform_allocator;
846858
}
847-
Method method(program, memory_manager, event_tracer, temp_allocator);
859+
Method method(
860+
program, memory_manager, event_tracer, temp_allocator, kernel_registry);
848861
ET_LOG(Debug, "Loading method: %s.", s_plan->name()->c_str());
849862
Error err = method.init(s_plan, external_data_map, backend_options);
850863
if (err != Error::Ok) {

runtime/executor/method.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <executorch/runtime/core/named_data_map.h>
2222
#include <executorch/runtime/core/span.h>
2323
#include <executorch/runtime/executor/memory_manager.h>
24+
#include <executorch/runtime/kernel/operator_registry.h>
2425
#include <executorch/runtime/executor/merged_data_map.h>
2526
#include <executorch/runtime/executor/method_meta.h>
2627
#include <executorch/runtime/platform/compiler.h>
@@ -82,6 +83,7 @@ class Method final {
8283
merged_data_map_(std::move(rhs.merged_data_map_)),
8384
external_constants_(rhs.external_constants_),
8485
n_external_constants_(rhs.n_external_constants_),
86+
kernel_registry_(rhs.kernel_registry_),
8587
init_state_(rhs.init_state_) {
8688
// Required: clear out fields that the dtor looks at, so that we don't free
8789
// anything twice.
@@ -331,7 +333,8 @@ class Method final {
331333
const Program* program,
332334
MemoryManager* memory_manager,
333335
EventTracer* event_tracer,
334-
MemoryAllocator* temp_allocator)
336+
MemoryAllocator* temp_allocator,
337+
Span<const Kernel> kernel_registry = {})
335338
: step_state_(),
336339
program_(program),
337340
memory_manager_(memory_manager),
@@ -348,6 +351,7 @@ class Method final {
348351
merged_data_map_(nullptr),
349352
external_constants_(nullptr),
350353
n_external_constants_(0),
354+
kernel_registry_(kernel_registry),
351355
init_state_(InitializationState::Uninitialized) {}
352356

353357
/// Static factory used by Program.
@@ -357,7 +361,8 @@ class Method final {
357361
MemoryManager* memory_manager,
358362
EventTracer* event_tracer,
359363
const NamedDataMap* named_data_map,
360-
const LoadBackendOptionsMap* backend_options = nullptr);
364+
const LoadBackendOptionsMap* backend_options = nullptr,
365+
Span<const Kernel> kernel_registry = {});
361366

362367
/**
363368
* Initialize the method from its serialized representation.
@@ -403,6 +408,8 @@ class Method final {
403408
NamedData* external_constants_;
404409
size_t n_external_constants_ = 0;
405410

411+
Span<const Kernel> kernel_registry_;
412+
406413
InitializationState init_state_;
407414

408415
/**

runtime/executor/program.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,8 @@ Result<Method> Program::load_method(
355355
MemoryManager* memory_manager,
356356
EventTracer* event_tracer,
357357
const NamedDataMap* named_data_map,
358-
const LoadBackendOptionsMap* backend_options) const {
358+
const LoadBackendOptionsMap* backend_options,
359+
Span<const Kernel> kernel_registry) const {
359360
EXECUTORCH_SCOPE_PROF("Program::load_method");
360361
internal::event_tracer_create_event_block(event_tracer, "Default");
361362
internal::EventTracerProfileMethodScope event_tracer_scope =
@@ -378,7 +379,8 @@ Result<Method> Program::load_method(
378379
memory_manager,
379380
event_tracer,
380381
named_data_map,
381-
backend_options);
382+
backend_options,
383+
kernel_registry);
382384
}
383385

384386
Result<MethodMeta> Program::method_meta(const char* method_name) const {

runtime/executor/program.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <executorch/runtime/executor/method.h>
2222
#include <executorch/runtime/executor/method_meta.h>
2323
#include <executorch/runtime/executor/pte_data_map.h>
24+
#include <executorch/runtime/kernel/operator_registry.h>
2425
#include <executorch/runtime/platform/compiler.h>
2526

2627
// Forward declare flatbuffer types. This is a public header and must not
@@ -151,7 +152,8 @@ class Program final {
151152
MemoryManager* memory_manager,
152153
EventTracer* event_tracer = nullptr,
153154
const NamedDataMap* named_data_map = nullptr,
154-
const LoadBackendOptionsMap* backend_options = nullptr) const;
155+
const LoadBackendOptionsMap* backend_options = nullptr,
156+
Span<const Kernel> kernel_registry = {}) const;
155157

156158
/**
157159
* Gathers metadata for the named method.
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <cstring>
10+
11+
#include <executorch/extension/data_loader/file_data_loader.h>
12+
#include <executorch/extension/runner_util/inputs.h>
13+
#include <executorch/runtime/executor/program.h>
14+
#include <executorch/runtime/executor/test/managed_memory_manager.h>
15+
#include <executorch/runtime/kernel/kernel_runtime_context.h>
16+
#include <executorch/runtime/kernel/operator_registry.h>
17+
#include <executorch/runtime/platform/runtime.h>
18+
19+
#include <gtest/gtest.h>
20+
21+
using executorch::runtime::Error;
22+
using executorch::runtime::EValue;
23+
using executorch::runtime::Kernel;
24+
using executorch::runtime::KernelRuntimeContext;
25+
using executorch::runtime::Method;
26+
using executorch::runtime::Program;
27+
using executorch::runtime::Result;
28+
using executorch::runtime::Span;
29+
using executorch::runtime::testing::ManagedMemoryManager;
30+
using torch::executor::util::FileDataLoader;
31+
32+
constexpr size_t kDefaultNonConstMemBytes = 32 * 1024U;
33+
constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024U;
34+
35+
namespace {
36+
37+
// aten::add.out args: [self, other, out, out]
38+
void multiply_by_two(
39+
KernelRuntimeContext& /*context*/,
40+
Span<EValue*> args) {
41+
auto& in = args[0]->toTensor();
42+
auto& out = args[args.size() - 1]->toTensor();
43+
for (ssize_t i = 0; i < in.numel(); ++i) {
44+
out.mutable_data_ptr<float>()[i] = in.const_data_ptr<float>()[i] * 2.0f;
45+
}
46+
}
47+
48+
void multiply_by_three(
49+
KernelRuntimeContext& /*context*/,
50+
Span<EValue*> args) {
51+
auto& in = args[0]->toTensor();
52+
auto& out = args[args.size() - 1]->toTensor();
53+
for (ssize_t i = 0; i < in.numel(); ++i) {
54+
out.mutable_data_ptr<float>()[i] = in.const_data_ptr<float>()[i] * 3.0f;
55+
}
56+
}
57+
58+
} // namespace
59+
60+
class KernelRegistryTest : public ::testing::Test {
61+
protected:
62+
void SetUp() override {
63+
executorch::runtime::runtime_init();
64+
65+
const char* path = std::getenv("ET_MODULE_ADD_PATH");
66+
ASSERT_NE(path, nullptr)
67+
<< "ET_MODULE_ADD_PATH environment variable must be set";
68+
69+
Result<FileDataLoader> loader = FileDataLoader::from(path);
70+
ASSERT_EQ(loader.error(), Error::Ok);
71+
loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
72+
73+
Result<Program> program = Program::load(loader_.get());
74+
ASSERT_EQ(program.error(), Error::Ok);
75+
program_ = std::make_unique<Program>(std::move(program.get()));
76+
}
77+
78+
std::unique_ptr<FileDataLoader> loader_;
79+
std::unique_ptr<Program> program_;
80+
};
81+
82+
TEST_F(KernelRegistryTest, MethodScopedKernelOverridesGlobal) {
83+
// Create two fallback kernels for aten::add.out with different behavior.
84+
Kernel kernel_x2(
85+
"aten::add.out", multiply_by_two);
86+
Kernel kernel_x3(
87+
"aten::add.out", multiply_by_three);
88+
89+
Span<const Kernel> registry_x2(&kernel_x2, 1);
90+
Span<const Kernel> registry_x3(&kernel_x3, 1);
91+
92+
// Load two methods with different kernel registries.
93+
ManagedMemoryManager mmm_a(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
94+
Result<Method> method_a = program_->load_method(
95+
"forward",
96+
&mmm_a.get(),
97+
/*event_tracer=*/nullptr,
98+
/*named_data_map=*/nullptr,
99+
/*backend_options=*/nullptr,
100+
registry_x2);
101+
ASSERT_EQ(method_a.error(), Error::Ok);
102+
103+
ManagedMemoryManager mmm_b(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
104+
Result<Method> method_b = program_->load_method(
105+
"forward",
106+
&mmm_b.get(),
107+
/*event_tracer=*/nullptr,
108+
/*named_data_map=*/nullptr,
109+
/*backend_options=*/nullptr,
110+
registry_x3);
111+
ASSERT_EQ(method_b.error(), Error::Ok);
112+
113+
// Prepare inputs: tensor inputs + alpha scalar (input index 2).
114+
auto inputs_a = torch::executor::util::prepare_input_tensors(method_a.get());
115+
ASSERT_EQ(inputs_a.error(), Error::Ok);
116+
ASSERT_EQ(method_a->set_input(EValue(1.0), 2), Error::Ok);
117+
118+
auto inputs_b = torch::executor::util::prepare_input_tensors(method_b.get());
119+
ASSERT_EQ(inputs_b.error(), Error::Ok);
120+
ASSERT_EQ(method_b->set_input(EValue(1.0), 2), Error::Ok);
121+
122+
// Execute both methods.
123+
ASSERT_EQ(method_a->execute(), Error::Ok);
124+
ASSERT_EQ(method_b->execute(), Error::Ok);
125+
126+
// Check outputs: method_a should have 2.0, method_b should have 3.0.
127+
const auto& out_a = method_a->get_output(0).toTensor();
128+
const auto& out_b = method_b->get_output(0).toTensor();
129+
130+
ASSERT_GT(out_a.numel(), 0);
131+
for (ssize_t i = 0; i < out_a.numel(); ++i) {
132+
EXPECT_FLOAT_EQ(out_a.const_data_ptr<float>()[i], 2.0f)
133+
<< "method_a output[" << i << "] should be 2.0";
134+
}
135+
for (ssize_t i = 0; i < out_b.numel(); ++i) {
136+
EXPECT_FLOAT_EQ(out_b.const_data_ptr<float>()[i], 3.0f)
137+
<< "method_b output[" << i << "] should be 3.0";
138+
}
139+
}

runtime/executor/test/targets.bzl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,23 @@ def define_common_targets(is_fbcode = False):
223223
env = modules_env,
224224
)
225225

226+
runtime.cxx_test(
227+
name = "kernel_registry_test",
228+
srcs = [
229+
"kernel_registry_test.cpp",
230+
],
231+
deps = [
232+
":managed_memory_manager",
233+
"//executorch/extension/data_loader:file_data_loader",
234+
"//executorch/extension/runner_util:inputs",
235+
"//executorch/runtime/executor:program",
236+
"//executorch/runtime/kernel:kernel_runtime_context",
237+
"//executorch/runtime/kernel:operator_registry",
238+
"//executorch/runtime/platform:platform",
239+
],
240+
env = modules_env,
241+
)
242+
226243
runtime.cxx_test(
227244
name = "kernel_integration_test",
228245
srcs = [

runtime/kernel/test/operator_registry_test.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,67 @@ TEST_F(OperatorRegistryTest, GetOpFunctionUsesProvidedKernelList) {
438438
EXPECT_EQ(run_kernel(*fallback_func), 50);
439439
}
440440

441+
TEST_F(OperatorRegistryTest, ProvidedKernelListMissCanFallBackToGlobal) {
442+
std::array<char, kKernelKeyBufSize> buf;
443+
Error err = make_kernel_key(
444+
{{ScalarType::Long, {0, 1, 2, 3}}}, buf.data(), buf.size());
445+
ASSERT_EQ(err, Error::Ok);
446+
KernelKey long_key = KernelKey(buf.data());
447+
448+
Kernel global_kernel = Kernel(
449+
"test::provided_kernel_list_global_fallback",
450+
KernelKey{},
451+
[](KernelRuntimeContext& context, Span<EValue*> stack) {
452+
(void)context;
453+
*(stack[0]) = Scalar(50);
454+
});
455+
err = register_kernels({&global_kernel, 1});
456+
ASSERT_EQ(err, Error::Ok);
457+
458+
Kernel scoped_kernel = Kernel(
459+
"test::provided_kernel_list_global_fallback",
460+
long_key,
461+
[](KernelRuntimeContext& context, Span<EValue*> stack) {
462+
(void)context;
463+
*(stack[0]) = Scalar(100);
464+
});
465+
Span<const Kernel> scoped_registry(&scoped_kernel, 1);
466+
467+
Tensor::DimOrderType dims[] = {0, 1, 2, 3};
468+
auto dim_order_type = Span<Tensor::DimOrderType>(dims, 4);
469+
TensorMeta long_meta[] = {TensorMeta(ScalarType::Long, dim_order_type)};
470+
Span<const TensorMeta> long_kernel_key(long_meta);
471+
472+
TensorMeta float_meta[] = {TensorMeta(ScalarType::Float, dim_order_type)};
473+
Span<const TensorMeta> float_kernel_key(float_meta);
474+
475+
auto run_kernel = [](OpFunction func) {
476+
EValue value = Scalar(0);
477+
EValue* stack[] = {&value};
478+
KernelRuntimeContext context{};
479+
func(context, Span<EValue*>(stack));
480+
return value.toScalar().to<int64_t>();
481+
};
482+
483+
Result<OpFunction> scoped_func = get_op_function_from_registry(
484+
"test::provided_kernel_list_global_fallback",
485+
long_kernel_key,
486+
scoped_registry);
487+
ASSERT_EQ(scoped_func.error(), Error::Ok);
488+
EXPECT_EQ(run_kernel(*scoped_func), 100);
489+
490+
Result<OpFunction> scoped_miss = get_op_function_from_registry(
491+
"test::provided_kernel_list_global_fallback",
492+
float_kernel_key,
493+
scoped_registry);
494+
ASSERT_EQ(scoped_miss.error(), Error::OperatorMissing);
495+
496+
Result<OpFunction> global_func = get_op_function_from_registry(
497+
"test::provided_kernel_list_global_fallback", float_kernel_key);
498+
ASSERT_EQ(global_func.error(), Error::Ok);
499+
EXPECT_EQ(run_kernel(*global_func), 50);
500+
}
501+
441502
TEST_F(OperatorRegistryTest, DoubleRegisterKernelsDies) {
442503
std::array<char, kKernelKeyBufSize> buf_long_contiguous;
443504
Error err = make_kernel_key(

0 commit comments

Comments
 (0)