From 3da61114acf86410cff987368f27ac961c60723b Mon Sep 17 00:00:00 2001 From: Hardik Sharma Date: Wed, 25 Mar 2026 22:02:08 -0700 Subject: [PATCH] Move graph_builder and program_builder to executorch.backends.test (#18483) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/18483 Move `GraphBuilder` and `ProgramBuilder` from `executorch.backends.cadence.aot` to `executorch.backends.test` since they are general-purpose test utilities not specific to the Cadence backend. - Created new canonical modules at `executorch/backends/test/` - Old modules now re-export from the new location for backward compatibility - Updated all 12 downstream consumers to import from the new path - Updated BUCK targets: new targets in `backends/test/targets.bzl`, old targets now depend on new ones Reviewed By: DrJessop Differential Revision: D97995878 --- backends/cadence/aot/BUCK | 38 ++--- backends/cadence/aot/graph_builder.py | 137 +---------------- backends/cadence/aot/program_builder.py | 137 +---------------- .../aot/tests/test_decompose_ops_passes.py | 2 +- .../aot/tests/test_fusion_ops_passes.py | 2 +- .../cadence/aot/tests/test_graph_builder.py | 5 +- backends/cadence/aot/tests/test_idma_ops.py | 2 +- .../cadence/aot/tests/test_memory_passes.py | 4 +- .../cadence/aot/tests/test_program_builder.py | 2 +- .../cadence/aot/tests/test_quantizer_ops.py | 5 +- .../aot/tests/test_remove_ops_passes.py | 2 +- .../aot/tests/test_reorder_ops_passes.py | 2 +- .../aot/tests/test_replace_ops_passes.py | 5 +- .../aot/tests/test_simplify_ops_passes.py | 2 +- .../cadence/aot/tests/test_to_out_var_pass.py | 2 +- .../aot/tests/test_type_dispatch_passes.py | 2 +- backends/test/graph_builder.py | 142 ++++++++++++++++++ backends/test/program_builder.py | 138 +++++++++++++++++ backends/test/targets.bzl | 28 ++++ 19 files changed, 346 insertions(+), 311 deletions(-) create mode 100644 backends/test/graph_builder.py create mode 100644 backends/test/program_builder.py diff --git 
a/backends/cadence/aot/BUCK b/backends/cadence/aot/BUCK index 28ffe9788eb..3eb77d4470f 100644 --- a/backends/cadence/aot/BUCK +++ b/backends/cadence/aot/BUCK @@ -227,8 +227,7 @@ fbcode_target(_kind = runtime.python_library, ], typing = True, deps = [ - "fbcode//caffe2:torch", - "fbcode//executorch/exir:pass_base", + "//executorch/backends/test:graph_builder", ], ) @@ -239,11 +238,7 @@ fbcode_target(_kind = runtime.python_library, ], typing = True, deps = [ - ":graph_builder", - "fbcode//caffe2:torch", - "fbcode//executorch/exir:lib", - "fbcode//executorch/exir:pass_base", - "fbcode//executorch/exir/verification:verifier", + "//executorch/backends/test:program_builder", ], ) @@ -254,7 +249,7 @@ fbcode_target(_kind = python_unittest, ], typing = True, deps = [ - ":program_builder", + "//executorch/backends/test:program_builder", "//caffe2:torch", "//later:lib", ], @@ -398,7 +393,7 @@ fbcode_target(_kind = python_unittest, ":typing_stubs", ":type_dispatch", "//caffe2:torch", - "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/test:graph_builder", "//executorch/backends/cadence/aot:pass_utils", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", @@ -438,7 +433,7 @@ fbcode_target(_kind = python_unittest, deps = [ ":ops_registrations", "//caffe2:torch", - "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/test:graph_builder", "//executorch/backends/cadence/aot:pass_utils", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", @@ -459,7 +454,7 @@ fbcode_target(_kind = python_unittest, ":replace_ops", "//caffe2:torch", "//executorch/backends/cadence/aot:compiler", - "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/test:graph_builder", "//executorch/backends/cadence/aot:pass_utils", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", @@ -480,7 +475,7 @@ fbcode_target(_kind = python_unittest, "//caffe2:torch", ":typing_stubs", 
"//executorch/backends/cadence/aot:compiler", - "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/test:graph_builder", "//executorch/backends/cadence/aot:pass_utils", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", @@ -501,7 +496,7 @@ fbcode_target(_kind = python_unittest, "//caffe2:torch", "//executorch/backends/cadence/aot:compiler", "//executorch/backends/cadence/aot:fuse_ops", - "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/test:graph_builder", "//executorch/backends/cadence/aot:ops_registrations", "//executorch/backends/cadence/aot:pass_utils", "//executorch/exir/dialects:lib", @@ -522,7 +517,7 @@ fbcode_target(_kind = python_unittest, ":compiler", "//caffe2:torch", "//executorch/backends/cadence/aot:compiler", - "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/test:graph_builder", "//executorch/backends/cadence/aot:ops_registrations", "//executorch/backends/cadence/aot:pass_utils", "//executorch/backends/cadence/aot:remove_ops", @@ -542,7 +537,7 @@ fbcode_target(_kind = python_unittest, ":typing_stubs", "//caffe2:torch", "//executorch/backends/cadence/aot:compiler", - "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/test:graph_builder", "//executorch/backends/cadence/aot:ops_registrations", "//executorch/backends/cadence/aot:pass_utils", "//executorch/backends/cadence/aot:simplify_ops", @@ -562,7 +557,7 @@ fbcode_target(_kind = python_unittest, "//caffe2:torch", "//executorch/backends/cadence/aot:compiler", "//executorch/backends/cadence/aot:fuse_ops", - "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/test:graph_builder", "//executorch/backends/cadence/aot:ops_registrations", "//executorch/backends/cadence/aot:pass_utils", "//executorch/backends/cadence/aot:reorder_ops", @@ -632,11 +627,11 @@ fbcode_target(_kind = python_unittest, ":typing_stubs", ":ops_registrations", ":pass_utils", - ":program_builder", + 
"//executorch/backends/test:program_builder", "//caffe2:torch", "//executorch/exir:memory", "//executorch/exir/dialects:lib", - "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/test:graph_builder", "//executorch/exir/tests:models", ], ) @@ -648,8 +643,7 @@ fbcode_target(_kind = python_unittest, ], typing = True, deps = [ - ":program_builder", - "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/test:graph_builder", "//executorch/backends/cadence/aot:ops_registrations", "//executorch/runtime:runtime", "//later:lib", @@ -679,7 +673,7 @@ fbcode_target(_kind = python_unittest, deps = [ "fbsource//third-party/pypi/parameterized:parameterized", "//caffe2:torch", - "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/test:graph_builder", "//executorch/backends/cadence/aot/quantizer:quantizer", "//executorch/exir:pass_base", "//pytorch/ao:torchao", @@ -694,7 +688,7 @@ fbcode_target(_kind = python_unittest, typing = True, deps = [ ":ops_registrations", - ":program_builder", + "//executorch/backends/test:program_builder", ":to_out_var_pass", "//caffe2:torch", "//executorch/exir:lib", diff --git a/backends/cadence/aot/graph_builder.py b/backends/cadence/aot/graph_builder.py index f609ba55472..e8aac331ecd 100644 --- a/backends/cadence/aot/graph_builder.py +++ b/backends/cadence/aot/graph_builder.py @@ -6,137 +6,8 @@ # pyre-strict -import logging -from typing import Optional, Sequence, Union +# This module has moved to executorch.backends.test.graph_builder. +# This re-export exists for backward compatibility. 
+from executorch.backends.test.graph_builder import GraphBuilder, single_op_builder -import torch -from executorch.exir.pass_base import ( - Argument, - ExportPass, - NodeMetadata, - PassResult, - ProxyValue, -) -from torch._dispatch.python import enable_python_dispatcher -from torch._subclasses import FakeTensor, FakeTensorMode -from torch.fx.node import Target -from torch.utils import _pytree as pytree - - -class GraphBuilder(ExportPass): - """Utility class for creating a graph module with user-specified ops. - - This class allows us to create test graph modules with any ops we want - directly, rather than relying on decomposition or passes. - - Usage: - builder = GraphBuilder() - # To insert placeholders, use builder.placeholder. - x = builder.placeholder("x", torch.randn(1, 3, 224, 224)) - # To insert an op, use builder.call_operator. - op = builder.call_operator( - some_op - (x, other_args, ...), - ) - # Insert outputs as a list of ProxyValues using builder.output. - builder.output([op]) - # Get GraphModule from builder. - gm = builder.get_graph_module() - """ - - def __init__(self, fake_tensor_mode: Optional[FakeTensorMode] = None) -> None: - self.exporter = ExportPass() - self.tracer: ExportPass.ExportTracer = self.ExportTracer( - self, torch.fx.graph.CodeGen() - ) - self.fake_tensor_mode: FakeTensorMode = fake_tensor_mode or FakeTensorMode( - allow_fallback_kernels=False, - allow_non_fake_inputs=True, - ) - self.tracer.fake_tensor_mode = self.fake_tensor_mode - - # This will be called to create nodes in tracer. - self.interpreter = torch.fx.Interpreter( - torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) - ) - - # pyre-ignore[14]: Inconsistent override. 
- def placeholder( - self, target: str, fake_tensor: Union[FakeTensor, torch.Tensor] - ) -> ProxyValue: - if not isinstance(fake_tensor, FakeTensor): - fake_tensor = self.fake_tensor_mode.from_tensor(fake_tensor) - logging.debug(f"Creating placeholder {target} => {fake_tensor.shape}") - placeholder = super().placeholder(target, fake_tensor, NodeMetadata({})) - return placeholder - - # pyre-ignore[14]: Inconsistent override. - def output(self, results: list[ProxyValue]) -> ProxyValue: - logging.debug(f"Creating outputs {results}") - return super().output(results, NodeMetadata({})) - - def get_graph_module(self) -> torch.fx.GraphModule: - return torch.fx.GraphModule(self.tracer.root, self.tracer.graph) - - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: Optional[dict[str, Argument]] = None, - meta: Optional[NodeMetadata] = None, - ) -> ProxyValue: - if meta is None: - meta = NodeMetadata({}) - if kwargs is None: - kwargs = {} - return super().call_operator(op, args, kwargs, meta) - - def call_submodule( - self, graph_module: torch.fx.GraphModule, inputs: tuple[Argument, ...] - ) -> PassResult: - return ExportPass().call(graph_module) - - def call_getitem( - self, value: ProxyValue, key: int, meta: Optional[NodeMetadata] = None - ) -> ProxyValue: - return super().call_getitem(value, key, meta or NodeMetadata({})) - - def _fx( - self, - kind: str, - target: torch.fx.node.Target, - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - with self.fake_tensor_mode, enable_python_dispatcher(): - return super()._fx(kind, target, args, kwargs, meta) - - -def single_op_builder( - placeholders: Sequence[Union[torch.Tensor, FakeTensor]], - op: Target, - args: Sequence[Argument], - kwargs: Optional[dict[str, Argument]] = None, -) -> torch.fx.GraphModule: - """Create a graph module with a single op. - - Args: - placeholders: Placeholders to be used as inputs to the GraphModule. 
- op: The op to be inserted. - args: The args to be passed to the op. - kwargs: The kwargs to be passed to the op. - - Returns: - A graph module with a single op - """ - builder = GraphBuilder() - op_to_placeholder_dict = { - p: builder.placeholder(f"p_{i}", p) for i, p in enumerate(placeholders) - } - proxy_args, proxy_kwargs = pytree.tree_map_only( - (torch.Tensor, FakeTensor), lambda x: op_to_placeholder_dict[x], (args, kwargs) - ) - node = builder.call_operator(op, proxy_args, proxy_kwargs) - builder.output([node]) - return builder.get_graph_module() +__all__ = ["GraphBuilder", "single_op_builder"] diff --git a/backends/cadence/aot/program_builder.py b/backends/cadence/aot/program_builder.py index 0f4e2bc7850..8c50d27154c 100644 --- a/backends/cadence/aot/program_builder.py +++ b/backends/cadence/aot/program_builder.py @@ -2,137 +2,8 @@ # pyre-strict -from enum import auto, Enum -from typing import Optional +# This module has moved to executorch.backends.test.program_builder. +# This re-export exists for backward compatibility. 
+from executorch.backends.test.program_builder import IrMode, ProgramBuilder -from executorch.backends.cadence.aot.graph_builder import GraphBuilder -from executorch.exir import EdgeCompileConfig, EdgeProgramManager -from executorch.exir.pass_base import ProxyValue -from executorch.exir.verification.verifier import EXIREdgeDialectVerifier -from torch import Tensor -from torch._export.verifier import Verifier -from torch._ops import OpOverload -from torch._subclasses.fake_tensor import FakeTensorMode -from torch.export import ExportedProgram -from torch.export.exported_program import ModuleCallEntry, ModuleCallSignature -from torch.export.graph_signature import ( - ExportGraphSignature, - InputKind, - InputSpec, - OutputKind, - OutputSpec, - TensorArgument, -) -from torch.utils import _pytree as pytree - - -class IrMode(Enum): - EXIR = auto() - ATEN = auto() - - -class ProgramBuilder(GraphBuilder): - """Utility class to build a program from a graph module.""" - - def __init__( - self, - mode: Optional[IrMode] = None, - _core_aten_ops_exception_list: Optional[list[OpOverload]] = None, - fake_tensor_mode: Optional[FakeTensorMode] = None, - ) -> None: - self.input_specs: list[InputSpec] = [] - self.output_specs: list[OutputSpec] = [] - self.constants: dict[str, Tensor] = {} - self.state_dict: dict[str, Tensor] = {} - self.mode: IrMode = mode or IrMode.EXIR - self._core_aten_ops_exception_list: list[OpOverload] = ( - _core_aten_ops_exception_list or [] - ) - super().__init__(fake_tensor_mode=fake_tensor_mode) - - def insert_input_spec( - self, target: str, input_kind: InputKind, value: Tensor - ) -> None: - persistent: Optional[bool] = None - if input_kind == InputKind.BUFFER: - persistent = True - self.input_specs.append( - InputSpec( - input_kind, TensorArgument(target), target=target, persistent=persistent - ) - ) - if input_kind == InputKind.PARAMETER or input_kind == InputKind.BUFFER: - self.state_dict[target] = value - elif input_kind == InputKind.CONSTANT_TENSOR: 
- self.constants[target] = value - - def placeholder( - self, - target: str, - fake_tensor: Tensor, - input_kind: InputKind = InputKind.USER_INPUT, - ) -> ProxyValue: - placeholder = super().placeholder(target, fake_tensor) - self.insert_input_spec(target, input_kind, fake_tensor) - return placeholder - - def output( - self, - results: list[ProxyValue], - output_kinds: Optional[list[OutputKind]] = None, - output_targets: Optional[list[str | None]] = None, - ) -> ProxyValue: - if output_kinds is None: - output_kinds = [OutputKind.USER_OUTPUT] * len(results) - if output_targets is None: - output_targets = [None] * len(results) - for result, out_kind, target in zip(results, output_kinds, output_targets): - self.output_specs.append( - OutputSpec(out_kind, TensorArgument(result.node.name), target=target) - ) - return super().output(results) - - def get_verifiers(self) -> Optional[list[Verifier]]: - if self.mode == IrMode.ATEN: - return None - return [ - EXIREdgeDialectVerifier( - edge_compile_config=EdgeCompileConfig( - _check_ir_validity=False, - _core_aten_ops_exception_list=self._core_aten_ops_exception_list, - ), - core_aten_ops_exception_list=self._core_aten_ops_exception_list, - class_only=True, - ) - ] - - def get_program(self) -> ExportedProgram: - gm = self.get_graph_module() - graph_signature = ExportGraphSignature(self.input_specs, self.output_specs) - in_spec = pytree.tree_flatten((tuple(graph_signature.user_inputs), {}))[1] - out_spec = pytree.tree_flatten(graph_signature.user_outputs)[1] - return ExportedProgram( - root=gm, - graph=gm.graph, - graph_signature=graph_signature, - # pyre-ignore[6]: Incompatible parameter type. - constants=self.constants, - state_dict=self.state_dict, - range_constraints={}, - module_call_graph=[ - ModuleCallEntry( - "", - ModuleCallSignature( - inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec - ), - ) - ], - # pyre-ignore[6]: Incompatible parameter type. 
- verifiers=self.get_verifiers(), - ) - - def get_edge_program(self) -> EdgeProgramManager: - return EdgeProgramManager( - self.get_program(), - core_aten_ops_exception_list=self._core_aten_ops_exception_list, - ) +__all__ = ["IrMode", "ProgramBuilder"] diff --git a/backends/cadence/aot/tests/test_decompose_ops_passes.py b/backends/cadence/aot/tests/test_decompose_ops_passes.py index e4bdf42ff62..9472ad3b565 100644 --- a/backends/cadence/aot/tests/test_decompose_ops_passes.py +++ b/backends/cadence/aot/tests/test_decompose_ops_passes.py @@ -11,8 +11,8 @@ import torch from executorch.backends.cadence.aot.decompose_ops import DecomposeAtenApproxGeluPass -from executorch.backends.cadence.aot.graph_builder import single_op_builder from executorch.backends.cadence.aot.pass_utils import count_node +from executorch.backends.test.graph_builder import single_op_builder from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index b1e73dce94c..f5afbe243f8 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -28,9 +28,9 @@ FuseTransposeOrPermuteOpPairsPass, HierarchicalCSEPass, ) -from executorch.backends.cadence.aot.graph_builder import GraphBuilder from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match from executorch.backends.cadence.aot.typing_stubs import expand +from executorch.backends.test.graph_builder import GraphBuilder from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import PassResult, ProxyValue diff --git a/backends/cadence/aot/tests/test_graph_builder.py b/backends/cadence/aot/tests/test_graph_builder.py index 
c3506dc4c07..6e48f4ea668 100644 --- a/backends/cadence/aot/tests/test_graph_builder.py +++ b/backends/cadence/aot/tests/test_graph_builder.py @@ -11,11 +11,8 @@ import executorch.backends.cadence.aot.ops_registrations # noqa import torch -from executorch.backends.cadence.aot.graph_builder import ( - GraphBuilder, - single_op_builder, -) from executorch.backends.cadence.aot.pass_utils import count_node +from executorch.backends.test.graph_builder import GraphBuilder, single_op_builder from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, NodeMetadata from later.unittest import TestCase diff --git a/backends/cadence/aot/tests/test_idma_ops.py b/backends/cadence/aot/tests/test_idma_ops.py index 6320bfc482b..ca264f2458a 100644 --- a/backends/cadence/aot/tests/test_idma_ops.py +++ b/backends/cadence/aot/tests/test_idma_ops.py @@ -5,7 +5,7 @@ import executorch.backends.cadence.aot.ops_registrations # noqa import torch -from executorch.backends.cadence.aot.graph_builder import GraphBuilder +from executorch.backends.test.graph_builder import GraphBuilder from executorch.exir.dialects._ops import ops as exir_ops from later.unittest import TestCase diff --git a/backends/cadence/aot/tests/test_memory_passes.py b/backends/cadence/aot/tests/test_memory_passes.py index 6c8da2202d4..21c4212ac8c 100644 --- a/backends/cadence/aot/tests/test_memory_passes.py +++ b/backends/cadence/aot/tests/test_memory_passes.py @@ -14,7 +14,6 @@ import executorch.backends.cadence.aot.ops_registrations # noqa import torch from executorch.backends.cadence.aot import compiler -from executorch.backends.cadence.aot.graph_builder import GraphBuilder from executorch.backends.cadence.aot.memory_constraints import ( ConstraintsGenPass, MemConstraints, @@ -33,12 +32,13 @@ count_node, register_cadence_pass, ) -from executorch.backends.cadence.aot.program_builder import ProgramBuilder from executorch.backends.cadence.aot.typing_stubs import expand from 
executorch.backends.cadence.aot.utils import ( get_default_memory_config, MemoryConfig, ) +from executorch.backends.test.graph_builder import GraphBuilder +from executorch.backends.test.program_builder import ProgramBuilder from executorch.exir import EdgeProgramManager, ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.memory_planning import ( diff --git a/backends/cadence/aot/tests/test_program_builder.py b/backends/cadence/aot/tests/test_program_builder.py index a16d42e2378..8d9d421c00c 100644 --- a/backends/cadence/aot/tests/test_program_builder.py +++ b/backends/cadence/aot/tests/test_program_builder.py @@ -2,7 +2,7 @@ # pyre-strict import torch -from executorch.backends.cadence.aot.program_builder import IrMode, ProgramBuilder +from executorch.backends.test.program_builder import IrMode, ProgramBuilder from executorch.exir.dialects._ops import ops as exir_ops from later.unittest import TestCase from torch._export.verifier import SpecViolationError diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py index e8061bb266c..06e2c08f4f4 100644 --- a/backends/cadence/aot/tests/test_quantizer_ops.py +++ b/backends/cadence/aot/tests/test_quantizer_ops.py @@ -11,10 +11,6 @@ from typing import Callable import torch -from executorch.backends.cadence.aot.graph_builder import ( - GraphBuilder, - single_op_builder, -) from executorch.backends.cadence.aot.quantizer import quantizer as quantizer_module from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern from executorch.backends.cadence.aot.quantizer.quantizer import ( @@ -35,6 +31,7 @@ qconfig_A8W8, qconfig_A8W8sym, ) +from executorch.backends.test.graph_builder import GraphBuilder, single_op_builder from executorch.exir.pass_base import NodeMetadata from parameterized import parameterized from torch._ops import OpOverload diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py 
b/backends/cadence/aot/tests/test_remove_ops_passes.py index 3772c2dc19a..11bceff0a05 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -14,7 +14,6 @@ import executorch.backends.cadence.aot.ops_registrations # noqa import torch from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass -from executorch.backends.cadence.aot.graph_builder import GraphBuilder from executorch.backends.cadence.aot.pass_utils import count_node from executorch.backends.cadence.aot.remove_ops import ( @@ -36,6 +35,7 @@ RemoveZeroSizedConstantPadNd, ) from executorch.backends.cadence.aot.typing_stubs import expand +from executorch.backends.test.graph_builder import GraphBuilder from executorch.exir.dialects._ops import ops as exir_ops from pyre_extensions import none_throws diff --git a/backends/cadence/aot/tests/test_reorder_ops_passes.py b/backends/cadence/aot/tests/test_reorder_ops_passes.py index 998bfd7a676..4aa7f46c8a1 100644 --- a/backends/cadence/aot/tests/test_reorder_ops_passes.py +++ b/backends/cadence/aot/tests/test_reorder_ops_passes.py @@ -13,7 +13,6 @@ import executorch.backends.cadence.aot.ops_registrations # noqa import torch from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass -from executorch.backends.cadence.aot.graph_builder import GraphBuilder from executorch.backends.cadence.aot.pass_utils import ( count_node, get_compute_nodes_in_gm, @@ -28,6 +27,7 @@ PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, SinkOpsCloserToUsePass, ) +from executorch.backends.test.graph_builder import GraphBuilder from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import PassBase, PassResult from torch.utils import _pytree as pytree diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index bca07cbde5e..452a91b8af5 100644 --- 
a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -14,10 +14,6 @@ import executorch.backends.cadence.aot.ref_implementations # noqa import torch -from executorch.backends.cadence.aot.graph_builder import ( - GraphBuilder, - single_op_builder, -) from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match from executorch.backends.cadence.aot.replace_ops import ( MakeSliceAndCatDimOutermostPass, @@ -56,6 +52,7 @@ ) from executorch.backends.cadence.aot.typing_stubs import expand +from executorch.backends.test.graph_builder import GraphBuilder, single_op_builder from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, ProxyValue from torch.fx.passes.infra.pass_base import PassResult diff --git a/backends/cadence/aot/tests/test_simplify_ops_passes.py b/backends/cadence/aot/tests/test_simplify_ops_passes.py index a20b7fd535e..79d89c24713 100644 --- a/backends/cadence/aot/tests/test_simplify_ops_passes.py +++ b/backends/cadence/aot/tests/test_simplify_ops_passes.py @@ -12,13 +12,13 @@ import executorch.backends.cadence.aot.ops_registrations # noqa import torch -from executorch.backends.cadence.aot.graph_builder import single_op_builder from executorch.backends.cadence.aot.pass_utils import count_node from executorch.backends.cadence.aot.simplify_ops import ( BindOptionalArgsPass, SimplifySliceOpPass, ) from executorch.backends.cadence.aot.typing_stubs import expand +from executorch.backends.test.graph_builder import single_op_builder from executorch.exir.dialects._ops import ops as exir_ops from torch.fx.passes.infra.pass_base import PassBase, PassResult from torch.utils import _pytree as pytree diff --git a/backends/cadence/aot/tests/test_to_out_var_pass.py b/backends/cadence/aot/tests/test_to_out_var_pass.py index f9181f2c0bb..01c64136fa7 100644 --- a/backends/cadence/aot/tests/test_to_out_var_pass.py +++ 
b/backends/cadence/aot/tests/test_to_out_var_pass.py @@ -8,8 +8,8 @@ import executorch.backends.cadence.aot.ops_registrations # noqa import torch -from executorch.backends.cadence.aot.program_builder import ProgramBuilder from executorch.backends.cadence.aot.to_out_var_pass import CadenceToOutVarPass +from executorch.backends.test.program_builder import ProgramBuilder from executorch.exir import ExecutorchBackendConfig from executorch.exir.dialects._ops import ops as exir_ops from later.unittest import TestCase diff --git a/backends/cadence/aot/tests/test_type_dispatch_passes.py b/backends/cadence/aot/tests/test_type_dispatch_passes.py index f0847e8ca77..595b79dd8c6 100644 --- a/backends/cadence/aot/tests/test_type_dispatch_passes.py +++ b/backends/cadence/aot/tests/test_type_dispatch_passes.py @@ -10,10 +10,10 @@ import executorch.backends.cadence.aot.ops_registrations # noqa import torch -from executorch.backends.cadence.aot.graph_builder import single_op_builder from executorch.backends.cadence.aot.pass_utils import count_node from executorch.backends.cadence.aot.type_dispatch import CompileTimeTypeDispatchPass from executorch.backends.cadence.aot.typing_stubs import expand +from executorch.backends.test.graph_builder import single_op_builder from executorch.exir.dialects._ops import ops as exir_ops from torch.fx.passes.infra.pass_base import PassResult diff --git a/backends/test/graph_builder.py b/backends/test/graph_builder.py new file mode 100644 index 00000000000..f609ba55472 --- /dev/null +++ b/backends/test/graph_builder.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import logging +from typing import Optional, Sequence, Union + +import torch +from executorch.exir.pass_base import ( + Argument, + ExportPass, + NodeMetadata, + PassResult, + ProxyValue, +) +from torch._dispatch.python import enable_python_dispatcher +from torch._subclasses import FakeTensor, FakeTensorMode +from torch.fx.node import Target +from torch.utils import _pytree as pytree + + +class GraphBuilder(ExportPass): + """Utility class for creating a graph module with user-specified ops. + + This class allows us to create test graph modules with any ops we want + directly, rather than relying on decomposition or passes. + + Usage: + builder = GraphBuilder() + # To insert placeholders, use builder.placeholder. + x = builder.placeholder("x", torch.randn(1, 3, 224, 224)) + # To insert an op, use builder.call_operator. + op = builder.call_operator( + some_op + (x, other_args, ...), + ) + # Insert outputs as a list of ProxyValues using builder.output. + builder.output([op]) + # Get GraphModule from builder. + gm = builder.get_graph_module() + """ + + def __init__(self, fake_tensor_mode: Optional[FakeTensorMode] = None) -> None: + self.exporter = ExportPass() + self.tracer: ExportPass.ExportTracer = self.ExportTracer( + self, torch.fx.graph.CodeGen() + ) + self.fake_tensor_mode: FakeTensorMode = fake_tensor_mode or FakeTensorMode( + allow_fallback_kernels=False, + allow_non_fake_inputs=True, + ) + self.tracer.fake_tensor_mode = self.fake_tensor_mode + + # This will be called to create nodes in tracer. + self.interpreter = torch.fx.Interpreter( + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + ) + + # pyre-ignore[14]: Inconsistent override. 
+ def placeholder( + self, target: str, fake_tensor: Union[FakeTensor, torch.Tensor] + ) -> ProxyValue: + if not isinstance(fake_tensor, FakeTensor): + fake_tensor = self.fake_tensor_mode.from_tensor(fake_tensor) + logging.debug(f"Creating placeholder {target} => {fake_tensor.shape}") + placeholder = super().placeholder(target, fake_tensor, NodeMetadata({})) + return placeholder + + # pyre-ignore[14]: Inconsistent override. + def output(self, results: list[ProxyValue]) -> ProxyValue: + logging.debug(f"Creating outputs {results}") + return super().output(results, NodeMetadata({})) + + def get_graph_module(self) -> torch.fx.GraphModule: + return torch.fx.GraphModule(self.tracer.root, self.tracer.graph) + + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: Optional[dict[str, Argument]] = None, + meta: Optional[NodeMetadata] = None, + ) -> ProxyValue: + if meta is None: + meta = NodeMetadata({}) + if kwargs is None: + kwargs = {} + return super().call_operator(op, args, kwargs, meta) + + def call_submodule( + self, graph_module: torch.fx.GraphModule, inputs: tuple[Argument, ...] + ) -> PassResult: + return ExportPass().call(graph_module) + + def call_getitem( + self, value: ProxyValue, key: int, meta: Optional[NodeMetadata] = None + ) -> ProxyValue: + return super().call_getitem(value, key, meta or NodeMetadata({})) + + def _fx( + self, + kind: str, + target: torch.fx.node.Target, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + with self.fake_tensor_mode, enable_python_dispatcher(): + return super()._fx(kind, target, args, kwargs, meta) + + +def single_op_builder( + placeholders: Sequence[Union[torch.Tensor, FakeTensor]], + op: Target, + args: Sequence[Argument], + kwargs: Optional[dict[str, Argument]] = None, +) -> torch.fx.GraphModule: + """Create a graph module with a single op. + + Args: + placeholders: Placeholders to be used as inputs to the GraphModule. 
# Copyright (c) Meta Platforms, Inc. and affiliates.

# pyre-strict

"""Test utility for constructing ExportedPrograms node-by-node.

ProgramBuilder extends GraphBuilder with graph-signature bookkeeping
(input/output specs, state dict, constants) so tests can produce a full
ExportedProgram or EdgeProgramManager without going through torch.export.
"""

from enum import auto, Enum
from typing import Optional

from executorch.backends.test.graph_builder import GraphBuilder
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
from executorch.exir.pass_base import ProxyValue
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
from torch import Tensor
from torch._export.verifier import Verifier
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.export import ExportedProgram
from torch.export.exported_program import ModuleCallEntry, ModuleCallSignature
from torch.export.graph_signature import (
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    TensorArgument,
)
from torch.utils import _pytree as pytree


class IrMode(Enum):
    # Edge (EXIR) dialect: programs are checked with the edge-dialect verifier.
    EXIR = auto()
    # ATen dialect: no verifier is attached.
    ATEN = auto()


class ProgramBuilder(GraphBuilder):
    """Utility class to build a program from a graph module."""

    def __init__(
        self,
        mode: Optional[IrMode] = None,
        _core_aten_ops_exception_list: Optional[list[OpOverload]] = None,
        fake_tensor_mode: Optional[FakeTensorMode] = None,
    ) -> None:
        # Signature/state bookkeeping, filled in as placeholders and outputs
        # are registered and consumed later by get_program().
        self.input_specs: list[InputSpec] = []
        self.output_specs: list[OutputSpec] = []
        self.constants: dict[str, Tensor] = {}
        self.state_dict: dict[str, Tensor] = {}
        self.mode: IrMode = mode or IrMode.EXIR
        self._core_aten_ops_exception_list: list[OpOverload] = (
            _core_aten_ops_exception_list or []
        )
        super().__init__(fake_tensor_mode=fake_tensor_mode)

    def insert_input_spec(
        self, target: str, input_kind: InputKind, value: Tensor
    ) -> None:
        """Record the graph-signature entry (and backing tensor) for one input."""
        # Buffers are always registered as persistent; other kinds leave the
        # flag unset.
        is_buffer = input_kind == InputKind.BUFFER
        self.input_specs.append(
            InputSpec(
                input_kind,
                TensorArgument(target),
                target=target,
                persistent=True if is_buffer else None,
            )
        )
        # Parameters/buffers live in the state dict; constant tensors are
        # tracked separately.
        if input_kind in (InputKind.PARAMETER, InputKind.BUFFER):
            self.state_dict[target] = value
        elif input_kind == InputKind.CONSTANT_TENSOR:
            self.constants[target] = value

    def placeholder(
        self,
        target: str,
        fake_tensor: Tensor,
        input_kind: InputKind = InputKind.USER_INPUT,
    ) -> ProxyValue:
        """Add a placeholder node and register its input spec."""
        proxy = super().placeholder(target, fake_tensor)
        self.insert_input_spec(target, input_kind, fake_tensor)
        return proxy

    def output(
        self,
        results: list[ProxyValue],
        output_kinds: Optional[list[OutputKind]] = None,
        output_targets: Optional[list[str | None]] = None,
    ) -> ProxyValue:
        """Add the output node, recording one OutputSpec per result.

        Kinds default to USER_OUTPUT and targets default to None when not
        supplied.
        """
        kinds = (
            output_kinds
            if output_kinds is not None
            else [OutputKind.USER_OUTPUT] * len(results)
        )
        targets = (
            output_targets if output_targets is not None else [None] * len(results)
        )
        self.output_specs.extend(
            OutputSpec(kind, TensorArgument(result.node.name), target=target)
            for result, kind, target in zip(results, kinds, targets)
        )
        return super().output(results)

    def get_verifiers(self) -> Optional[list[Verifier]]:
        """Return the verifier classes for the current IR mode (None for ATen)."""
        if self.mode == IrMode.ATEN:
            return None
        return [
            EXIREdgeDialectVerifier(
                edge_compile_config=EdgeCompileConfig(
                    _check_ir_validity=False,
                    _core_aten_ops_exception_list=self._core_aten_ops_exception_list,
                ),
                core_aten_ops_exception_list=self._core_aten_ops_exception_list,
                class_only=True,
            )
        ]

    def get_program(self) -> ExportedProgram:
        """Assemble the accumulated graph and specs into an ExportedProgram."""
        gm = self.get_graph_module()
        signature = ExportGraphSignature(self.input_specs, self.output_specs)
        # Only the TreeSpec (second element of tree_flatten) is needed for the
        # module call signature.
        _, in_spec = pytree.tree_flatten((tuple(signature.user_inputs), {}))
        _, out_spec = pytree.tree_flatten(signature.user_outputs)
        return ExportedProgram(
            root=gm,
            graph=gm.graph,
            graph_signature=signature,
            # pyre-ignore[6]: Incompatible parameter type.
            constants=self.constants,
            state_dict=self.state_dict,
            range_constraints={},
            module_call_graph=[
                ModuleCallEntry(
                    "",
                    ModuleCallSignature(
                        inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
                    ),
                )
            ],
            # pyre-ignore[6]: Incompatible parameter type.
            verifiers=self.get_verifiers(),
        )

    def get_edge_program(self) -> EdgeProgramManager:
        """Wrap the built program in an EdgeProgramManager."""
        return EdgeProgramManager(
            self.get_program(),
            core_aten_ops_exception_list=self._core_aten_ops_exception_list,
        )
    """
    # graph_builder/program_builder are general-purpose test utilities moved
    # here from backends/cadence/aot; the old targets re-export these.
    if is_fbcode:
        runtime.python_library(
            name = "graph_builder",
            srcs = [
                "graph_builder.py",
            ],
            typing = True,
            deps = [
                "//caffe2:torch",
                "//executorch/exir:pass_base",
            ],
        )

        runtime.python_library(
            name = "program_builder",
            srcs = [
                "program_builder.py",
            ],
            typing = True,
            deps = [
                ":graph_builder",
                "//caffe2:torch",
                "//executorch/exir:lib",
                "//executorch/exir:pass_base",
                "//executorch/exir/verification:verifier",
            ],
        )

    if not runtime.is_oss and is_fbcode:
        modules_env = {
            "ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH": "$(location fbcode//executorch/test/models:exported_xnnp_delegated_programs[ModuleAddLarge.pte])",