From 5acef861b6d97512322c5d11cc7cf4ecfaa28103 Mon Sep 17 00:00:00 2001 From: goutamadwant Date: Tue, 30 Jun 2026 23:38:30 -0700 Subject: [PATCH] Add named manual kernel registration API Signed-off-by: goutamadwant --- codegen/gen.py | 55 +++++++++- codegen/templates/RegisterKernels.cpp | 6 +- codegen/templates/RegisterKernels.h | 4 +- codegen/test/test_executorch_gen.py | 103 +++++++++++++++++++ shim_et/xplat/executorch/codegen/codegen.bzl | 10 ++ tools/cmake/Codegen.cmake | 50 +++++++-- 6 files changed, 212 insertions(+), 16 deletions(-) diff --git a/codegen/gen.py b/codegen/gen.py index 9b102aa1bf6..1d905e6db9b 100644 --- a/codegen/gen.py +++ b/codegen/gen.py @@ -2,6 +2,7 @@ import argparse import os +import re from collections import defaultdict from dataclasses import dataclass from pathlib import Path @@ -79,6 +80,31 @@ from torchgen.selective_build.selector import SelectiveBuilder +DEFAULT_MANUAL_REGISTRATION_FUNCTION_NAME = "register_all_kernels" +MANUAL_REGISTRATION_LIB_NAME_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +def get_manual_registration_function_name( + *, + manual_registration: bool, + manual_registration_lib_name: str | None, +) -> str: + if not manual_registration_lib_name: + return DEFAULT_MANUAL_REGISTRATION_FUNCTION_NAME + + if not manual_registration: + raise ValueError( + "--manual-registration-lib-name requires --manual-registration" + ) + + if not MANUAL_REGISTRATION_LIB_NAME_PATTERN.fullmatch(manual_registration_lib_name): + raise ValueError( + "--manual-registration-lib-name must be a valid C++ identifier" + ) + + return f"register_{manual_registration_lib_name}_kernels" + + def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str: """ A wrapper function to basically get `sig.decl(include_context=True)`. @@ -329,6 +355,9 @@ def gen_unboxing( use_aten_lib: bool, kernel_index: ETKernelIndex, manual_registration: bool, + manual_registration_function_name: str = ( + DEFAULT_MANUAL_REGISTRATION_FUNCTION_NAME + ), add_exception_boundary: bool = False, ) -> None: # Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata)) @@ -364,6 +393,9 @@ def key_func( ), # Only write header once }, num_shards=1, + base_env={ + "manual_registration_function_name": manual_registration_function_name, + }, sharded_keys={"unboxed_kernels", "fn_header"}, ) @@ -484,6 +516,9 @@ def gen_headers( kernel_index: ETKernelIndex, cpu_fm: FileManager, use_aten_lib: bool, + manual_registration_function_name: str = ( + DEFAULT_MANUAL_REGISTRATION_FUNCTION_NAME + ), ) -> None: """Generate headers. @@ -534,6 +569,7 @@ def gen_headers( "RegisterKernels.h", lambda: { "generated_comment": "@" + "generated by torchgen/gen_executorch.py", + "manual_registration_function_name": manual_registration_function_name, }, ) headers = { @@ -958,7 +994,15 @@ def main() -> None: "--manual-registration", action="store_true", help="a boolean flag to indicate whether we want to manually call" - "register_kernels() or rely on static init. ", + "register_all_kernels() or rely on static init. ", + ) + parser.add_argument( + "--manual-registration-lib-name", + "--manual_registration_lib_name", + "--lib-name", + "--lib_name", + help="library name to include in the manual registration API name. " + "Requires --manual-registration.", ) parser.add_argument( "--generate", @@ -977,6 +1021,13 @@ def main() -> None: ) options = parser.parse_args() assert options.tags_path, "tags.yaml is required by codegen yaml parsing." + try: + manual_registration_function_name = get_manual_registration_function_name( + manual_registration=options.manual_registration, + manual_registration_lib_name=options.manual_registration_lib_name, + ) + except ValueError as exc: + parser.error(str(exc)) selector = get_custom_build_selector( options.op_registration_whitelist, @@ -1011,6 +1062,7 @@ def main() -> None: kernel_index=kernel_index, cpu_fm=cpu_fm, use_aten_lib=options.use_aten_lib, + manual_registration_function_name=manual_registration_function_name, ) if "sources" in options.generate: @@ -1021,6 +1073,7 @@ def main() -> None: use_aten_lib=options.use_aten_lib, kernel_index=kernel_index, manual_registration=options.manual_registration, + manual_registration_function_name=manual_registration_function_name, add_exception_boundary=options.add_exception_boundary, ) if custom_ops_native_functions: diff --git a/codegen/templates/RegisterKernels.cpp b/codegen/templates/RegisterKernels.cpp index 0d173f52bb3..80a32deb3e4 100644 --- a/codegen/templates/RegisterKernels.cpp +++ b/codegen/templates/RegisterKernels.cpp @@ -7,7 +7,7 @@ */ // ${generated_comment} -// This implements register_all_kernels() API that is declared in +// This implements ${manual_registration_function_name}() API that is declared in // RegisterKernels.h #include "RegisterKernels.h" #include @@ -16,14 +16,14 @@ namespace torch { namespace executor { -Error register_all_kernels() { +Error ${manual_registration_function_name}() { Kernel kernels_to_register[] = { ${unboxed_kernels} // Generated kernels }; Error success_with_kernel_reg = ::executorch::runtime::register_kernels({kernels_to_register}); if (success_with_kernel_reg != Error::Ok) { - ET_LOG(Error, "Failed register all kernels"); + ET_LOG(Error, "Failed to register kernels"); return success_with_kernel_reg; } return Error::Ok; diff --git a/codegen/templates/RegisterKernels.h b/codegen/templates/RegisterKernels.h index 3c7ecff50b5..4d4fd75dd52 100644 --- a/codegen/templates/RegisterKernels.h +++ b/codegen/templates/RegisterKernels.h @@ -7,7 +7,7 @@ */ // ${generated_comment} -// Exposing an API for registering all kernels at once. +// Exposing an API for registering generated kernels at once. #include #include #include @@ -16,7 +16,7 @@ namespace torch { namespace executor { -Error register_all_kernels(); +Error ${manual_registration_function_name}(); } // namespace executor } // namespace torch diff --git a/codegen/test/test_executorch_gen.py b/codegen/test/test_executorch_gen.py index cd445acee62..5a4d8ed6298 100644 --- a/codegen/test/test_executorch_gen.py +++ b/codegen/test/test_executorch_gen.py @@ -14,6 +14,9 @@ from executorch.codegen.gen import ( ComputeCodegenUnboxedKernels, gen_functions_declarations, + gen_headers, + gen_unboxing, + get_manual_registration_function_name, parse_yaml_files, translate_native_yaml, ) @@ -29,6 +32,7 @@ OperatorName, ) from torchgen.selective_build.selector import SelectiveBuilder +from torchgen.utils import FileManager TEST_YAML = """ @@ -318,6 +322,105 @@ def tearDown(self) -> None: pass +class TestManualRegistrationFunctionName(unittest.TestCase): + def test_default_function_name(self) -> None: + self.assertEqual( + get_manual_registration_function_name( + manual_registration=True, + manual_registration_lib_name=None, + ), + "register_all_kernels", + ) + + def test_named_function_name(self) -> None: + self.assertEqual( + get_manual_registration_function_name( + manual_registration=True, + manual_registration_lib_name="portable_ops_lib", + ), + "register_portable_ops_lib_kernels", + ) + + def test_named_function_requires_manual_registration(self) -> None: + with self.assertRaisesRegex( + ValueError, "--manual-registration-lib-name requires" + ): + get_manual_registration_function_name( + manual_registration=False, + manual_registration_lib_name="portable_ops_lib", + ) + + def test_named_function_requires_cpp_identifier(self) -> None: + with self.assertRaisesRegex(ValueError, "valid C\\+\\+ identifier"): + get_manual_registration_function_name( + manual_registration=True, + manual_registration_lib_name="portable-ops-lib", + ) + + +class TestManualRegistrationTemplates(unittest.TestCase): + def setUp(self) -> None: + self.template_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "templates" + ) + self.function_name = "register_portable_ops_lib_kernels" + + def test_register_kernels_header_uses_named_function(self) -> None: + with tempfile.TemporaryDirectory() as tempdir: + gen_headers( + native_functions=[], + gen_custom_ops_header=False, + custom_ops_native_functions=[], + selector=SelectiveBuilder.get_nop_selector(), + kernel_index=ETKernelIndex(index={}), # type: ignore[arg-type] + cpu_fm=FileManager(tempdir, self.template_dir, False), + use_aten_lib=False, + manual_registration_function_name=self.function_name, + ) + + with open(os.path.join(tempdir, "RegisterKernels.h")) as f: + header = f.read() + + self.assertIn(f"Error {self.function_name}();", header) + self.assertNotIn("Error register_all_kernels();", header) + + def test_register_kernels_cpp_uses_named_function(self) -> None: + with tempfile.TemporaryDirectory() as tempdir: + native_function, backend_index = NativeFunction.from_yaml( + { + "func": "custom_1::op_1() -> bool", + "dispatch": {"CPU": "kernel_1"}, + }, + loc=Location(__file__, 1), + valid_tags=set(), + ) + backend_indices: dict[ + DispatchKey, dict[OperatorName, BackendMetadata] + ] = { + DispatchKey.CPU: {}, + DispatchKey.QuantizedCPU: {}, + } + BackendIndex.grow_index(backend_indices, backend_index) + + gen_unboxing( + native_functions=[native_function], + cpu_fm=FileManager(tempdir, self.template_dir, False), + selector=SelectiveBuilder.from_yaml_dict( + {"include_all_operators": True} + ), + use_aten_lib=False, + kernel_index=ETKernelIndex.from_backend_indices(backend_indices), + manual_registration=True, + manual_registration_function_name=self.function_name, + ) + + with open(os.path.join(tempdir, "RegisterKernelsEverything.cpp")) as f: + source = f.read() + + self.assertIn(f"Error {self.function_name}() {{", source) + self.assertNotIn("Error register_all_kernels() {", source) + + class TestGenFunctionsDeclarations(unittest.TestCase): def setUp(self) -> None: ( diff --git a/shim_et/xplat/executorch/codegen/codegen.bzl b/shim_et/xplat/executorch/codegen/codegen.bzl index 318996784a1..f26d3fc704f 100644 --- a/shim_et/xplat/executorch/codegen/codegen.bzl +++ b/shim_et/xplat/executorch/codegen/codegen.bzl @@ -273,6 +273,7 @@ def _prepare_genrule_and_lib( custom_ops_yaml_path = None, custom_ops_requires_runtime_registration = True, manual_registration = False, + manual_registration_lib_name = None, aten_mode = False, support_exceptions = True): """ @@ -350,6 +351,10 @@ def _prepare_genrule_and_lib( genrule_cmd = genrule_cmd + [ "--manual_registration", ] + if manual_registration_lib_name: + genrule_cmd = genrule_cmd + [ + "--manual-registration-lib-name={}".format(manual_registration_lib_name), + ] if custom_ops_yaml_path: genrule_cmd = genrule_cmd + [ "--custom_ops_yaml_path=" + custom_ops_yaml_path, @@ -828,6 +833,7 @@ def executorch_generated_lib( visibility = [], aten_mode = False, manual_registration = False, + manual_registration_lib_name = None, use_default_aten_ops_lib = True, deps = [], xplat_deps = [], @@ -888,6 +894,9 @@ def executorch_generated_lib( xplat_deps: Additional xplat deps, can be used to provide custom operator library. fbcode_deps: Additional fbcode deps, can be used to provide custom operator library. compiler_flags: compiler_flags args to runtime.cxx_library + manual_registration_lib_name: Optional C++ identifier to use when + generating a named manual registration API. If omitted, manual + registration keeps using `register_all_kernels`. dtype_selective_build: In additional to operator selection, dtype selective build further selects the dtypes for each operator. Can be used with model or dict selective build APIs, where dtypes can be specified. @@ -999,6 +1008,7 @@ def executorch_generated_lib( custom_ops_requires_runtime_registration = custom_ops_requires_runtime_registration, aten_mode = aten_mode, manual_registration = manual_registration, + manual_registration_lib_name = manual_registration_lib_name, support_exceptions = support_exceptions, ) diff --git a/tools/cmake/Codegen.cmake b/tools/cmake/Codegen.cmake index 4253fa44dc5..9436540eddc 100644 --- a/tools/cmake/Codegen.cmake +++ b/tools/cmake/Codegen.cmake @@ -154,9 +154,9 @@ endfunction() # custom_ops_yaml. # # Invoked as generate_bindings_for_kernels( LIB_NAME lib_name FUNCTIONS_YAML -# functions_yaml CUSTOM_OPS_YAML custom_ops_yaml ) +# functions_yaml CUSTOM_OPS_YAML custom_ops_yaml [MANUAL_REGISTRATION] ) function(generate_bindings_for_kernels) - set(options ADD_EXCEPTION_BOUNDARY) + set(options ADD_EXCEPTION_BOUNDARY MANUAL_REGISTRATION) set(arg_names LIB_NAME FUNCTIONS_YAML CUSTOM_OPS_YAML DTYPE_SELECTIVE_BUILD) cmake_parse_arguments(GEN "${options}" "${arg_names}" "" ${ARGN}) @@ -165,6 +165,7 @@ function(generate_bindings_for_kernels) message(STATUS " FUNCTIONS_YAML: ${GEN_FUNCTIONS_YAML}") message(STATUS " CUSTOM_OPS_YAML: ${GEN_CUSTOM_OPS_YAML}") message(STATUS " ADD_EXCEPTION_BOUNDARY: ${GEN_ADD_EXCEPTION_BOUNDARY}") + message(STATUS " MANUAL_REGISTRATION: ${GEN_MANUAL_REGISTRATION}") message(STATUS " DTYPE_SELECTIVE_BUILD: ${GEN_DTYPE_SELECTIVE_BUILD}") # Command to generate selected_operators.yaml from custom_ops.yaml. @@ -205,11 +206,27 @@ function(generate_bindings_for_kernels) if(GEN_ADD_EXCEPTION_BOUNDARY) set(_gen_command "${_gen_command}" --add-exception-boundary) endif() + if(GEN_MANUAL_REGISTRATION) + list( + APPEND + _gen_command + --manual-registration + --manual-registration-lib-name=${GEN_LIB_NAME} + ) + endif() - set(_gen_command_sources - ${_out_dir}/RegisterCodegenUnboxedKernelsEverything.cpp - ${_out_dir}/Functions.h ${_out_dir}/NativeFunctions.h - ) + if(GEN_MANUAL_REGISTRATION) + set(_gen_command_sources + ${_out_dir}/RegisterKernelsEverything.cpp + ${_out_dir}/RegisterKernels.h ${_out_dir}/Functions.h + ${_out_dir}/NativeFunctions.h + ) + else() + set(_gen_command_sources + ${_out_dir}/RegisterCodegenUnboxedKernelsEverything.cpp + ${_out_dir}/Functions.h ${_out_dir}/NativeFunctions.h + ) + endif() if(GEN_FUNCTIONS_YAML) list(APPEND _gen_command --functions-yaml-path=${GEN_FUNCTIONS_YAML}) @@ -268,13 +285,15 @@ endfunction() # Generate a runtime lib for registering operators in Executorch function(gen_operators_lib) + set(options MANUAL_REGISTRATION) set(multi_arg_names LIB_NAME KERNEL_LIBS DEPS DTYPE_SELECTIVE_BUILD) - cmake_parse_arguments(GEN "" "" "${multi_arg_names}" ${ARGN}) + cmake_parse_arguments(GEN "${options}" "" "${multi_arg_names}" ${ARGN}) message(STATUS "Generating operator lib:") message(STATUS " LIB_NAME: ${GEN_LIB_NAME}") message(STATUS " KERNEL_LIBS: ${GEN_KERNEL_LIBS}") message(STATUS " DEPS: ${GEN_DEPS}") + message(STATUS " MANUAL_REGISTRATION: ${GEN_MANUAL_REGISTRATION}") message(STATUS " DTYPE_SELECTIVE_BUILD: ${GEN_DTYPE_SELECTIVE_BUILD}") set(_out_dir ${CMAKE_CURRENT_BINARY_DIR}/${GEN_LIB_NAME}) @@ -284,9 +303,17 @@ function(gen_operators_lib) add_library(${GEN_LIB_NAME}) - set(_srcs_list ${_out_dir}/RegisterCodegenUnboxedKernelsEverything.cpp - ${_out_dir}/Functions.h ${_out_dir}/NativeFunctions.h - ) + if(GEN_MANUAL_REGISTRATION) + set(_srcs_list + ${_out_dir}/RegisterKernelsEverything.cpp + ${_out_dir}/RegisterKernels.h ${_out_dir}/Functions.h + ${_out_dir}/NativeFunctions.h + ) + else() + set(_srcs_list ${_out_dir}/RegisterCodegenUnboxedKernelsEverything.cpp + ${_out_dir}/Functions.h ${_out_dir}/NativeFunctions.h + ) + endif() if(GEN_DTYPE_SELECTIVE_BUILD) list(APPEND _srcs_list ${_opvariant_h}) endif() @@ -373,6 +400,9 @@ function(gen_operators_lib) executorch_target_link_options_shared_lib(${GEN_LIB_NAME}) set(_generated_headers ${_out_dir}/Functions.h ${_out_dir}/NativeFunctions.h) + if(GEN_MANUAL_REGISTRATION) + list(APPEND _generated_headers ${_out_dir}/RegisterKernels.h) + endif() if(GEN_DTYPE_SELECTIVE_BUILD) list(APPEND _generated_headers ${_opvariant_h}) endif()