diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py index 1a1a179f456..21cea5f88d1 100644 --- a/backends/arm/_passes/arm_pass.py +++ b/backends/arm/_passes/arm_pass.py @@ -20,6 +20,25 @@ class ArmPass(ExportPass): """Base class for Arm passes.""" + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + if getattr(cls, "targeted_ops", None) is not None: + return + # Only auto-discover targeted_ops for passes that use the standard + # call_operator() pattern. Passes that override call() use _TARGET_OPS + # for their own graph manipulation logic, not as a fast-copy declaration. + if "call" in cls.__dict__: + return + for attr in ("_TARGET_OPS", "_supported_ops"): + ops = getattr(cls, attr, None) + if ops: + cls.targeted_ops = set(ops) if not isinstance(ops, set) else ops # type: ignore[attr-defined] + return + edge = getattr(cls, "_EDGE_OPS", None) + aten = getattr(cls, "_ATEN_OPS", None) + if edge or aten: + cls.targeted_ops = {*(edge or ()), *(aten or ())} # type: ignore[attr-defined] + def __init__(self, tfa_pass: bool = False, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.submodule_depth = 0 @@ -78,6 +97,34 @@ def get_name(pass_) -> str: f"Cannot get name for pass: {pass_}. It must be an instance of ExportPass or have a __name__ attribute." ) + def should_run(self, graph_module: GraphModule) -> bool: + """Skip this pass if the graph contains none of its targeted ops. + + Subclasses that define a ``targeted_ops`` class attribute (a set of + op overloads) get this check for free via inheritance. Passes + without ``targeted_ops`` always run (the default). + + Recursively checks control flow submodules (cond/while_loop) so + passes are not incorrectly skipped when targeted ops are nested. + + """ + targeted = getattr(self, "targeted_ops", None) + if targeted is None: + return True + + from executorch.exir.graph_module import get_control_flow_submodules + + def _has_targeted_op(gm: GraphModule) -> bool: + for node in gm.graph.nodes: + if node.op == "call_function" and node.target in targeted: + return True + for _, submod, _ in get_control_flow_submodules(gm): + if _has_targeted_op(submod): + return True + return False + + return _has_targeted_op(graph_module) + def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False): if not updated: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/cast_to_int32_pass.py b/backends/arm/_passes/cast_to_int32_pass.py index 609526b9ecc..6b117da3fb1 100644 --- a/backends/arm/_passes/cast_to_int32_pass.py +++ b/backends/arm/_passes/cast_to_int32_pass.py @@ -6,9 +6,7 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes.arm_pass import ArmPass - from executorch.backends.arm.tosa.specification import get_context_spec from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult diff --git a/backends/arm/_passes/conv1d_unsqueeze_pass.py b/backends/arm/_passes/conv1d_unsqueeze_pass.py index 58c3c0c35a2..1c7b2618976 100644 --- a/backends/arm/_passes/conv1d_unsqueeze_pass.py +++ b/backends/arm/_passes/conv1d_unsqueeze_pass.py @@ -9,10 +9,8 @@ from typing import Set, Type from executorch.backends.arm._passes import ArmPass - from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass - from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -35,6 +33,8 @@ class Conv1dUnsqueezePass(ArmPass): SizeAdjustInputPass, } + targeted_ops = {exir_ops.edge.aten.convolution.default} + def call_operator(self, op, args, kwargs, meta): if op != exir_ops.edge.aten.convolution.default: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py index 69056cb47f4..429b82f85ba 100644 --- a/backends/arm/_passes/convert_expand_copy_to_repeat.py +++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py @@ -8,7 +8,6 @@ from typing import cast, Set, Type import torch - from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import ( UnsqueezeBeforeRepeatPass, @@ -58,6 +57,8 @@ class ConvertExpandCopyToRepeatPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {UnsqueezeBeforeRepeatPass} + targeted_ops = {exir_ops.edge.aten.expand_copy.default} + expand_copy = exir_ops.edge.aten.expand_copy.default repeat = exir_ops.edge.aten.repeat.default diff --git a/backends/arm/_passes/convert_full_like_to_full_pass.py b/backends/arm/_passes/convert_full_like_to_full_pass.py index 1e26f24250a..710c84569ff 100644 --- a/backends/arm/_passes/convert_full_like_to_full_pass.py +++ b/backends/arm/_passes/convert_full_like_to_full_pass.py @@ -9,7 +9,6 @@ from executorch.backends.arm._passes.fuse_constant_ops_pass import ( ComputeConstantOpsAOTPass, ) - from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -36,6 +35,8 @@ class ConvertFullLikeToFullPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} + targeted_ops = {exir_ops.edge.aten.full_like.default} + def call_operator(self, op, args, kwargs, meta): if op not in [ exir_ops.edge.aten.full_like.default, diff --git a/backends/arm/_passes/convert_permute_singleton_to_view_pass.py b/backends/arm/_passes/convert_permute_singleton_to_view_pass.py index 7447cf037bc..763b736c5f3 100644 --- a/backends/arm/_passes/convert_permute_singleton_to_view_pass.py +++ b/backends/arm/_passes/convert_permute_singleton_to_view_pass.py @@ -7,10 +7,8 @@ from typing import Sequence, Set, Tuple, Type from executorch.backends.arm._passes.arm_pass import ArmPass - from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass - from torch._ops import OpOverload @@ -35,6 +33,8 @@ class ConvertPermuteSingletonToViewPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = set(_PERMUTE_TARGETS) + def call_operator(self, op, args, kwargs, meta): if op not in _PERMUTE_TARGETS: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/convert_split_to_slice.py b/backends/arm/_passes/convert_split_to_slice.py index 03c5c794d6a..cb88a19ca50 100644 --- a/backends/arm/_passes/convert_split_to_slice.py +++ b/backends/arm/_passes/convert_split_to_slice.py @@ -21,6 +21,11 @@ class ConvertSplitToSlicePass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = { + exir_ops.edge.aten.split_with_sizes_copy.default, + exir_ops.edge.aten.split_copy.Tensor, + } + split_ops = ( exir_ops.edge.aten.split_with_sizes_copy.default, exir_ops.edge.aten.split_copy.Tensor, diff --git a/backends/arm/_passes/convert_squeezes_to_view.py b/backends/arm/_passes/convert_squeezes_to_view.py index 2058c3407e3..0eaca3fa3b3 100644 --- a/backends/arm/_passes/convert_squeezes_to_view.py +++ b/backends/arm/_passes/convert_squeezes_to_view.py @@ -24,6 +24,11 @@ class ConvertSqueezesToViewPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransformPass} + targeted_ops = { + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.unsqueeze_copy.default, + } + def call_operator(self, op, args, kwargs, meta): if op not in [ exir_ops.edge.aten.squeeze_copy.dims, diff --git a/backends/arm/_passes/convert_to_clamp_pass.py b/backends/arm/_passes/convert_to_clamp_pass.py index effb46f25c4..e47ee5f3b15 100644 --- a/backends/arm/_passes/convert_to_clamp_pass.py +++ b/backends/arm/_passes/convert_to_clamp_pass.py @@ -6,11 +6,9 @@ from typing import Set, Tuple, Type from executorch.backends.arm._passes import ArmPass - from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( QuantizeClampArgumentsPass, ) - from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -32,6 +30,8 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]: class ConvertToClampPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {QuantizeClampArgumentsPass} + targeted_ops = edge_operators + def call_operator(self, op, args, kwargs, meta): if op not in edge_operators or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_acosh_pass.py b/backends/arm/_passes/decompose_acosh_pass.py index 3ce6d73abc3..50ad46c6dee 100644 --- a/backends/arm/_passes/decompose_acosh_pass.py +++ b/backends/arm/_passes/decompose_acosh_pass.py @@ -37,6 +37,8 @@ class DecomposeAcoshPass(ArmPass): MatchArgDtypePass, } + targeted_ops = {edge_acosh_op} + def call_operator(self, op, args, kwargs, meta, updated=False): if op is not edge_acosh_op: diff --git a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py index eda9dd28bf9..2c68edd727b 100644 --- a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py @@ -7,12 +7,10 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.decompose_avg_pool2d_pass import ( DecomposeAvgPool2dPass, ) - from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, NodeMetadata @@ -48,6 +46,8 @@ class DecomposeAdaptiveAvgPool2dPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2dPass} + targeted_ops = {*edge_ops, *aten_ops} + def call_operator(self, op, args, kwargs, meta, updated=False): if op not in (edge_ops + aten_ops) or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_add_sub_alpha_pass.py b/backends/arm/_passes/decompose_add_sub_alpha_pass.py index d7db9c5bcf9..62f3d8a0d6e 100644 --- a/backends/arm/_passes/decompose_add_sub_alpha_pass.py +++ b/backends/arm/_passes/decompose_add_sub_alpha_pass.py @@ -60,6 +60,8 @@ class DecomposeAddSubAlphaPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = {*_ADD_OPS, *_SUB_OPS} + def call_operator(self, op, args, kwargs, meta, updated: bool | None = False): if op not in _ADD_OPS + _SUB_OPS: return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_addmm_pass.py b/backends/arm/_passes/decompose_addmm_pass.py index d1368602d5d..179e6fac166 100644 --- a/backends/arm/_passes/decompose_addmm_pass.py +++ b/backends/arm/_passes/decompose_addmm_pass.py @@ -6,7 +6,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass @@ -50,6 +49,8 @@ class DecomposeAddmmPass(ArmPass): MatchArgDtypePass, } + targeted_ops = {edge_addmm, aten_addmm} + def call_operator(self, op, args, kwargs, meta): if op not in [edge_addmm, aten_addmm] or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_as_strided_copy_pass.py b/backends/arm/_passes/decompose_as_strided_copy_pass.py index a60d1b19fd9..56851bfdbac 100644 --- a/backends/arm/_passes/decompose_as_strided_copy_pass.py +++ b/backends/arm/_passes/decompose_as_strided_copy_pass.py @@ -6,7 +6,6 @@ from typing import Dict, Optional, Set, Tuple, Type import torch - from executorch.backends.arm._passes import ArmPass from executorch.backends.arm.common.as_strided_utils import ( contiguous_strides, diff --git a/backends/arm/_passes/decompose_asin_and_acos_pass.py b/backends/arm/_passes/decompose_asin_and_acos_pass.py index 707e6ec070d..7ddab0daa6a 100644 --- a/backends/arm/_passes/decompose_asin_and_acos_pass.py +++ b/backends/arm/_passes/decompose_asin_and_acos_pass.py @@ -9,7 +9,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( ConvertFullLikeToFullPass, @@ -72,6 +71,8 @@ class DecomposeAsinAndAcosPass(ArmPass): ReplaceScalarWithTensorByProfilePass, } + targeted_ops = {*edge_asin_op, *edge_acos_op} + def _build_polynomial( self, coefficients: list[float], variable: torch.Tensor, meta: dict[str, str] ) -> torch.Tensor: diff --git a/backends/arm/_passes/decompose_asinh_pass.py b/backends/arm/_passes/decompose_asinh_pass.py index 822b793d203..1efad8bebcf 100644 --- a/backends/arm/_passes/decompose_asinh_pass.py +++ b/backends/arm/_passes/decompose_asinh_pass.py @@ -37,6 +37,8 @@ class DecomposeAsinhPass(ArmPass): MatchArgDtypePass, } + targeted_ops = {*edge_asinh_op} + def call_operator(self, op, args, kwargs, meta): if op not in edge_asinh_op: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_atan_pass.py b/backends/arm/_passes/decompose_atan_pass.py index a7ca90e7b43..7d37a2fdbe3 100644 --- a/backends/arm/_passes/decompose_atan_pass.py +++ b/backends/arm/_passes/decompose_atan_pass.py @@ -50,6 +50,8 @@ class DecomposeAtanPass(ArmPass): ReplaceScalarWithTensorByProfilePass, } + targeted_ops = {edge_atan} + def _rational_approximation(self, z, ops, meta): """Creates a (2,1) Padé approximation for atan(x) on [-1, 1].""" diff --git a/backends/arm/_passes/decompose_atanh_pass.py b/backends/arm/_passes/decompose_atanh_pass.py index 014da39d7bd..c705b68a66f 100644 --- a/backends/arm/_passes/decompose_atanh_pass.py +++ b/backends/arm/_passes/decompose_atanh_pass.py @@ -47,6 +47,8 @@ class DecomposeAtanhPass(ArmPass): ReplaceScalarWithTensorByProfilePass, } + targeted_ops = {edge_atanh} + def call_operator(self, op, args, kwargs, meta): if op is not edge_atanh: return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_avg_pool2d_pass.py b/backends/arm/_passes/decompose_avg_pool2d_pass.py index a3fe049b8bb..206beea171f 100644 --- a/backends/arm/_passes/decompose_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_avg_pool2d_pass.py @@ -42,6 +42,8 @@ def get_decomposition(op) -> tuple: class DecomposeAvgPool2dPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} + targeted_ops = {*edge_div_ops, *aten_div_ops} + def call_operator(self, op, args, kwargs, meta): if op not in (edge_div_ops + aten_div_ops) or not self.allowed_to_transform( meta diff --git a/backends/arm/_passes/decompose_cosh_pass.py b/backends/arm/_passes/decompose_cosh_pass.py index 70d4247d9e0..0cbcd36ec82 100644 --- a/backends/arm/_passes/decompose_cosh_pass.py +++ b/backends/arm/_passes/decompose_cosh_pass.py @@ -35,6 +35,8 @@ class DecomposeCoshPass(ArmPass): MatchArgDtypePass, } + targeted_ops = {edge_cosh} + def call_operator(self, op, args, kwargs, meta, updated=False): if op is not edge_cosh: return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_cosine_similarity_pass.py b/backends/arm/_passes/decompose_cosine_similarity_pass.py index 6ceb50fdf55..bebfe126a70 100644 --- a/backends/arm/_passes/decompose_cosine_similarity_pass.py +++ b/backends/arm/_passes/decompose_cosine_similarity_pass.py @@ -10,7 +10,6 @@ from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( ConvertFullLikeToFullPass, ) - from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass @@ -43,6 +42,8 @@ class DecomposeCosineSimilarityPass(ArmPass): InsertTableOpsPass, } + targeted_ops = {*torch_cosine_similarity} + def call_operator(self, op, args, kwargs, meta): if op not in torch_cosine_similarity or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_div_pass.py b/backends/arm/_passes/decompose_div_pass.py index 651e58a563c..f3c16b0b9b3 100644 --- a/backends/arm/_passes/decompose_div_pass.py +++ b/backends/arm/_passes/decompose_div_pass.py @@ -41,6 +41,8 @@ class DecomposeDivPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} + targeted_ops = {*edge_div_ops, *aten_div_ops} + def call_operator(self, op, args, kwargs, meta): if op not in (edge_div_ops + aten_div_ops) or not self.allowed_to_transform( meta diff --git a/backends/arm/_passes/decompose_div_tensor_mode.py b/backends/arm/_passes/decompose_div_tensor_mode.py index 774557b816f..9a046748ec3 100644 --- a/backends/arm/_passes/decompose_div_tensor_mode.py +++ b/backends/arm/_passes/decompose_div_tensor_mode.py @@ -58,6 +58,8 @@ class DecomposeDivTensorModePass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivPass} + targeted_ops = {*edge_div_mode_ops, *aten_div_mode_ops} + def call_operator(self, op, args, kwargs, meta): if op not in ( edge_div_mode_ops + aten_div_mode_ops diff --git a/backends/arm/_passes/decompose_elu_pass.py b/backends/arm/_passes/decompose_elu_pass.py index d212c3ec9cb..a23362671ba 100644 --- a/backends/arm/_passes/decompose_elu_pass.py +++ b/backends/arm/_passes/decompose_elu_pass.py @@ -60,6 +60,8 @@ class DecomposeEluPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = {*edge_elu_ops} + def call_operator(self, op, args, kwargs, meta): if op not in edge_elu_ops: return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_expm1_pass.py b/backends/arm/_passes/decompose_expm1_pass.py index c1cb0b83166..36c7d333125 100644 --- a/backends/arm/_passes/decompose_expm1_pass.py +++ b/backends/arm/_passes/decompose_expm1_pass.py @@ -88,6 +88,8 @@ class DecomposeExpm1Pass(ArmPass): MatchArgRanksPass, } + targeted_ops = {*edge_expm1_ops} + def call_operator(self, op, args, kwargs, meta): if op not in edge_expm1_ops: return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_floor_divide_pass.py b/backends/arm/_passes/decompose_floor_divide_pass.py index 20e63f48023..bf29cbe4b92 100644 --- a/backends/arm/_passes/decompose_floor_divide_pass.py +++ b/backends/arm/_passes/decompose_floor_divide_pass.py @@ -54,6 +54,8 @@ class DecomposeFloorDividePass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass} + targeted_ops = {*edge_floor_divide_ops, *aten_floor_divide_ops} + def call_operator(self, op, args, kwargs, meta): if op not in (edge_floor_divide_ops + aten_floor_divide_ops): return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_gelu_pass.py b/backends/arm/_passes/decompose_gelu_pass.py index 7815b5fa44f..23bdf7f81a4 100644 --- a/backends/arm/_passes/decompose_gelu_pass.py +++ b/backends/arm/_passes/decompose_gelu_pass.py @@ -89,6 +89,8 @@ class DecomposeGeluPass(ArmPass): MatchArgRanksPass, } + targeted_ops = {*torch_gelu, *edge_gelu} + def call_operator(self, op, args, kwargs, meta): if op not in torch_gelu + edge_gelu: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_glu_pass.py b/backends/arm/_passes/decompose_glu_pass.py index 68efaedd784..757b7255d2a 100644 --- a/backends/arm/_passes/decompose_glu_pass.py +++ b/backends/arm/_passes/decompose_glu_pass.py @@ -44,6 +44,8 @@ class DecomposeGluPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} + targeted_ops = {edge_glu, aten_glu} + def call_operator(self, op, args, kwargs, meta): if op not in [edge_glu, aten_glu] or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_grouped_conv_pass.py b/backends/arm/_passes/decompose_grouped_conv_pass.py index ed0adbe83d7..fc990b43324 100644 --- a/backends/arm/_passes/decompose_grouped_conv_pass.py +++ b/backends/arm/_passes/decompose_grouped_conv_pass.py @@ -47,6 +47,11 @@ class DecomposeGroupedConvPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = {Conv1dUnsqueezePass} + targeted_ops = { + exir_ops.edge.aten.convolution.default, + torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv2d.default, + } @staticmethod def _get_decomposition(op): @@ -207,7 +212,6 @@ def _get_meta_copy( # Get quantization params of the weights and slice them. w_qarg = new_qparams[1] if DecomposeGroupedConvPass._is_per_channel_qparams(w_qarg): - # For transpose conv, axis=1 corresponds to output channels and # does not align with grouped slicing. # Per-channel quantization on axis=0 on the other hand could align here but @@ -288,7 +292,6 @@ def call_operator(self, op, args, kwargs, meta): for i, (input_slice, filter_slice, bias_slice) in enumerate( zip(input_slices, weight_slices, bias_slices) ): - meta_copy = DecomposeGroupedConvPass._get_meta_copy( meta, i, diff --git a/backends/arm/_passes/decompose_index_select_to_gather_pass.py b/backends/arm/_passes/decompose_index_select_to_gather_pass.py index 5947e8c5499..90b98b9c29f 100644 --- a/backends/arm/_passes/decompose_index_select_to_gather_pass.py +++ b/backends/arm/_passes/decompose_index_select_to_gather_pass.py @@ -7,7 +7,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( ConvertExpandCopyToRepeatPass, diff --git a/backends/arm/_passes/decompose_int16_activation_conv_pass.py b/backends/arm/_passes/decompose_int16_activation_conv_pass.py index 14ccc709fe8..0aee7487945 100644 --- a/backends/arm/_passes/decompose_int16_activation_conv_pass.py +++ b/backends/arm/_passes/decompose_int16_activation_conv_pass.py @@ -9,7 +9,6 @@ import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.quant_args import QuantArgs - from executorch.backends.arm.tosa.specification import get_context_spec from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -30,6 +29,7 @@ def __init__(self) -> None: super().__init__() _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = {exir_ops.edge.aten.convolution.default} def bias_view_shape( self, bias: torch.Tensor, activation_rank: int diff --git a/backends/arm/_passes/decompose_int_pow_pass.py b/backends/arm/_passes/decompose_int_pow_pass.py index 2df8d3b2522..44d7e3e8a0a 100644 --- a/backends/arm/_passes/decompose_int_pow_pass.py +++ b/backends/arm/_passes/decompose_int_pow_pass.py @@ -20,6 +20,7 @@ class DecomposeIntPowPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = {exir_ops.edge.aten.pow.Tensor_Scalar} def call_operator(self, op, args, kwargs, meta): if op != exir_ops.edge.aten.pow.Tensor_Scalar: diff --git a/backends/arm/_passes/decompose_leaky_relu_pass.py b/backends/arm/_passes/decompose_leaky_relu_pass.py index eb8b5bda61a..8b07ac7caa4 100644 --- a/backends/arm/_passes/decompose_leaky_relu_pass.py +++ b/backends/arm/_passes/decompose_leaky_relu_pass.py @@ -48,6 +48,8 @@ class DecomposeLeakyReLUPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = {*edge_ops, *torch_ops} + def call_operator(self, op, args, kwargs, meta): if op not in (edge_ops + torch_ops) or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py index 8b165658c37..7f6698f7aa1 100644 --- a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py +++ b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py @@ -40,6 +40,7 @@ class DecomposeLinalgVectorNormPass(ArmPass): } torch_linalg_vector_norm = (torch.ops.aten.linalg_vector_norm.default,) + targeted_ops = torch_linalg_vector_norm def call_operator(self, op, args, kwargs, meta): if op not in self.torch_linalg_vector_norm or not self.allowed_to_transform( diff --git a/backends/arm/_passes/decompose_linear_pass.py b/backends/arm/_passes/decompose_linear_pass.py index 146fb4e648f..9bcd1c74605 100644 --- a/backends/arm/_passes/decompose_linear_pass.py +++ b/backends/arm/_passes/decompose_linear_pass.py @@ -32,6 +32,8 @@ class DecomposeLinearPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {InsertRescaleInt32Pass} + targeted_ops = {exir_ops.edge.aten.linear.default} + def call(self, graph_module): for node in graph_module.graph.nodes: if node.op != "call_function": diff --git a/backends/arm/_passes/decompose_logit_pass.py b/backends/arm/_passes/decompose_logit_pass.py index fa82ff4f579..d03163eae11 100644 --- a/backends/arm/_passes/decompose_logit_pass.py +++ b/backends/arm/_passes/decompose_logit_pass.py @@ -6,7 +6,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass diff --git a/backends/arm/_passes/decompose_masked_fill_pass.py b/backends/arm/_passes/decompose_masked_fill_pass.py index 748aee3fc49..a759f0e176f 100644 --- a/backends/arm/_passes/decompose_masked_fill_pass.py +++ b/backends/arm/_passes/decompose_masked_fill_pass.py @@ -7,7 +7,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( ConvertFullLikeToFullPass, @@ -44,6 +43,8 @@ class DecomposeMaskedFillPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {ConvertFullLikeToFullPass} + targeted_ops = {*edge_ops, *aten_ops} + def call_operator(self, op, args, kwargs, meta, updated=False): if op not in (*aten_ops, *edge_ops): return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py b/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py index 72fe53d57b9..012d065c1be 100644 --- a/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py +++ b/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py @@ -8,7 +8,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass from executorch.exir.dialects._ops import ops as exir_ops @@ -55,6 +54,7 @@ class DecomposeMaxPool2dPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = { SizeAdjustInputPass, } + targeted_ops = set(EDGE_MAXPOOL2D) def call_operator(self, op, args, kwargs, meta): # Only intercept EXIR edge max_pool2d ops diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index dec890c5561..d1bd7951d83 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -89,6 +89,12 @@ class DecomposeMeanDimPass(ArmPass): DecomposeSumPass, SizeAdjustInputPass, } + targeted_ops = { + exir_ops.edge.aten.mean.dim, + torch.ops.aten.mean.dim, + exir_ops.edge.aten.mean.default, + torch.ops.aten.mean.default, + } def __init__(self, graph_module, tosa_spec, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/backends/arm/_passes/decompose_ne_pass.py b/backends/arm/_passes/decompose_ne_pass.py index 95dfc0e1179..c2f2db87b5a 100644 --- a/backends/arm/_passes/decompose_ne_pass.py +++ b/backends/arm/_passes/decompose_ne_pass.py @@ -58,6 +58,8 @@ class DecomposeNotEqualPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = {*edge_ne_ops, *aten_ne_ops} + def call_operator(self, op, args, kwargs, meta): if op not in (edge_ne_ops + aten_ne_ops) or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_remainder_pass.py b/backends/arm/_passes/decompose_remainder_pass.py index 38185b85149..1bb8006b0cb 100644 --- a/backends/arm/_passes/decompose_remainder_pass.py +++ b/backends/arm/_passes/decompose_remainder_pass.py @@ -49,6 +49,7 @@ class DecomposeRemainderPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass} + targeted_ops = set(_decomposition_ops.keys()) def call_operator(self, op, args, kwargs, meta, updated=False): supported_ops = ( diff --git a/backends/arm/_passes/decompose_select_scatter_pass.py b/backends/arm/_passes/decompose_select_scatter_pass.py index 4b4db8d208c..a4d7cd86bef 100644 --- a/backends/arm/_passes/decompose_select_scatter_pass.py +++ b/backends/arm/_passes/decompose_select_scatter_pass.py @@ -6,7 +6,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.convert_int64_const_ops_to_int32 import ( ConvertInt64ConstOpsToInt32Pass, @@ -66,6 +65,8 @@ class DecomposeSelectScatterPass(ArmPass): ConvertInt64ConstOpsToInt32Pass, } + targeted_ops = {*edge_scatter_ops, *aten_scatter_ops} + def call_operator(self, op, args, kwargs, meta): if op not in (edge_scatter_ops + aten_scatter_ops): return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_sign_pass.py b/backends/arm/_passes/decompose_sign_pass.py index 111d1ca5ee3..589e307ce8f 100644 --- a/backends/arm/_passes/decompose_sign_pass.py +++ b/backends/arm/_passes/decompose_sign_pass.py @@ -6,7 +6,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -51,6 +50,8 @@ class DecomposeSignPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = {edge_sign, aten_sign} + def call_operator(self, op, args, kwargs, meta): if op not in (edge_sign, aten_sign) or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_sinh_pass.py b/backends/arm/_passes/decompose_sinh_pass.py index 71ac0a34f08..361f34b508c 100644 --- a/backends/arm/_passes/decompose_sinh_pass.py +++ b/backends/arm/_passes/decompose_sinh_pass.py @@ -40,6 +40,8 @@ class DecomposeSinhPass(ArmPass): MatchArgDtypePass, } + targeted_ops = {edge_sinh} + def call_operator(self, op, args, kwargs, meta): if op is not edge_sinh: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_slice_scatter_pass.py b/backends/arm/_passes/decompose_slice_scatter_pass.py index 24cdfeb96a5..c086709c377 100644 --- a/backends/arm/_passes/decompose_slice_scatter_pass.py +++ b/backends/arm/_passes/decompose_slice_scatter_pass.py @@ -6,7 +6,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.accumulate_index_put_pass import ( AccumulateIndexPutPass, @@ -72,6 +71,8 @@ class DecomposeSliceScatterPass(ArmPass): RewriteIndexPutPass, } + targeted_ops = {*edge_slice_scatter_ops, *aten_slice_scatter_ops} + def call_operator(self, op, args, kwargs, meta): if op not in (edge_slice_scatter_ops + aten_slice_scatter_ops): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_softmax_pass.py b/backends/arm/_passes/decompose_softmax_pass.py index 85343b229dd..6680a1aac5e 100644 --- a/backends/arm/_passes/decompose_softmax_pass.py +++ b/backends/arm/_passes/decompose_softmax_pass.py @@ -74,6 +74,7 @@ class DecomposeSoftmaxPass(ArmPass): DecomposeSumPass, InsertTableOpsPass, } + targeted_ops = {*torch_softmax, *edge_softmax} def __init__(self, skip_safe_softmax: bool = False, **kwargs): super().__init__(**kwargs) diff --git a/backends/arm/_passes/decompose_softmax_unstable_pass.py b/backends/arm/_passes/decompose_softmax_unstable_pass.py index 4e3eb4f003b..f60ea93f009 100644 --- a/backends/arm/_passes/decompose_softmax_unstable_pass.py +++ b/backends/arm/_passes/decompose_softmax_unstable_pass.py @@ -64,6 +64,7 @@ class DecomposeSoftmaxUnstablePass(ArmPass): DecomposeSumPass, InsertTableOpsPass, } + targeted_ops = {*torch_softmax, *edge_softmax} def call_operator(self, op, args, kwargs, meta): if op not in (torch_softmax + edge_softmax) or not self.allowed_to_transform( diff --git a/backends/arm/_passes/decompose_sqrt_pass.py b/backends/arm/_passes/decompose_sqrt_pass.py index 86e5d6681bd..be682f01b1a 100644 --- a/backends/arm/_passes/decompose_sqrt_pass.py +++ b/backends/arm/_passes/decompose_sqrt_pass.py @@ -30,6 +30,8 @@ def get_sqrt_decomposition(op) -> Union[Tuple, torch._ops.OpOverload]: class DecomposeSqrtPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} + targeted_ops = {*edge_sqrt_ops, *aten_sqrt_ops} + def call_operator(self, op, args, kwargs, meta): """Decomposes `sqrt(x)` into `pow(x, 0.5)` for backend support.""" diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py index cff1cca3ba3..2be2c545646 100644 --- a/backends/arm/_passes/decompose_sum_pass.py +++ b/backends/arm/_passes/decompose_sum_pass.py @@ -44,6 +44,10 @@ class DecomposeSumPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = { + exir_ops.edge.aten.sum.dim_IntList, + torch.ops.aten.sum.dim_IntList, + } def call_operator(self, op, args, kwargs, meta): if op not in [ diff --git a/backends/arm/_passes/decompose_tan_pass.py b/backends/arm/_passes/decompose_tan_pass.py index 87b347dbbad..3020d7c4462 100644 --- a/backends/arm/_passes/decompose_tan_pass.py +++ b/backends/arm/_passes/decompose_tan_pass.py @@ -18,6 +18,8 @@ class DecomposeTanPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivPass} + targeted_ops = {edge_tan_op} + def call_operator(self, op, args, kwargs, meta, updated=False): if op != edge_tan_op: return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_tril_pass.py b/backends/arm/_passes/decompose_tril_pass.py index 3101b24e95b..85c52ede57f 100644 --- a/backends/arm/_passes/decompose_tril_pass.py +++ b/backends/arm/_passes/decompose_tril_pass.py @@ -54,6 +54,7 @@ class DecomposeTrilPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} + targeted_ops = {torch.ops.aten.tril.default} def call_operator(self, op, args, kwargs, meta): handled_ops = [torch.ops.aten.tril.default] diff --git a/backends/arm/_passes/decompose_unfold_to_gather_pass.py b/backends/arm/_passes/decompose_unfold_to_gather_pass.py index d0e3897080a..0e86940dd56 100644 --- a/backends/arm/_passes/decompose_unfold_to_gather_pass.py +++ b/backends/arm/_passes/decompose_unfold_to_gather_pass.py @@ -8,7 +8,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( ReplaceScalarWithTensorByProfilePass, diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index fcf61cf5129..083d290b32b 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -56,6 +56,11 @@ class DecomposeVarPass(ArmPass): DecomposeMeanDimPass, DecomposeSumPass, } + targeted_ops = { + exir_ops.edge.aten.var.correction, + torch.ops.aten.var.correction, + torch.ops.aten.var.dim, + } def call_operator(self, op, args, kwargs, meta): if op not in ( diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 8cab19dc551..16227c67fc8 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -14,7 +14,6 @@ from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_output_qparams, ) - from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.dialects._ops import ops as exir_ops @@ -34,6 +33,8 @@ class InsertRescalePass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = {*DQ_OPS} + def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule): dq_args = QuantArgs.from_operator(node.target, node.args) q_args = QuantArgs.from_operator(user.target, user.args) @@ -92,7 +93,7 @@ class InsertRescaleInt32Pass(ArmPass): # decomposition. _passes_required_after: Set[Type[ExportPass]] = {DecomposeSumPass} - included_targets = [ + targeted_ops = { exir_ops.edge.aten.abs.default, exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.eq.Tensor, @@ -105,7 +106,9 @@ class InsertRescaleInt32Pass(ArmPass): exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.sum.dim_IntList, - ] + } + + included_targets = list(targeted_ops) def _int32_qargs(self, s): """Helper creator function for INT32-based QuantArgs.""" @@ -554,8 +557,12 @@ def _get_output_qparams_map(self, node: Node): def _rescale_cond_submodules(self, node: Node, graph_module: GraphModule) -> bool: modified = False - if_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[1].target)) # type: ignore - else_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[2].target)) # type: ignore + if_graph: GraphModule = cast( + GraphModule, graph_module.get_submodule(node.args[1].target) # type: ignore[union-attr, arg-type] + ) + else_graph: GraphModule = cast( + GraphModule, graph_module.get_submodule(node.args[2].target) # type: ignore[union-attr, arg-type] + ) input_qparams_map = self._get_input_qparams_map(node, 3) if input_qparams_map: modified |= self._rescale_submodule_inputs(if_graph, input_qparams_map) @@ -569,8 +576,12 @@ def _rescale_cond_submodules(self, node: Node, graph_module: GraphModule) -> boo def _rescale_while_submodules(self, node: Node, graph_module: GraphModule): modified = False - cond_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[0].target)) # type: ignore - body_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[1].target)) # type: ignore + cond_graph: GraphModule = cast( + GraphModule, graph_module.get_submodule(node.args[0].target) # type: ignore[union-attr, arg-type] + ) + body_graph: GraphModule = cast( + GraphModule, graph_module.get_submodule(node.args[1].target) # type: ignore[union-attr, arg-type] + ) input_qparams_map = self._get_input_qparams_map(node, 2) if input_qparams_map: diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index 78702bf9035..efa0f0ddad6 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -12,12 +12,9 @@ from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.transforms.utils import create_constant_placeholder - from executorch.exir import ExportedProgram - from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload - from executorch.exir.pass_base import ExportPass, PassResult from torch.export.graph_signature import InputKind from torch.fx import GraphModule diff --git a/backends/arm/_passes/remove_getitem_pass.py b/backends/arm/_passes/remove_getitem_pass.py index 3ce157d3fd8..122a8330203 100644 --- a/backends/arm/_passes/remove_getitem_pass.py +++ b/backends/arm/_passes/remove_getitem_pass.py @@ -7,8 +7,14 @@ from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.transforms import remove_getitem_op +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass class RemoveGetItemPass(ArmPass, remove_getitem_op.RemoveGetItemPass): _passes_required_after: Set[Type[ExportPass]] = set() + + targeted_ops = { + exir_ops.edge.aten.max_pool2d_with_indices.default, + exir_ops.edge.aten.max.dim, + } diff --git a/backends/arm/_passes/remove_noop_pass.py b/backends/arm/_passes/remove_noop_pass.py index c7fe469c8b8..fba8d37ddc8 100644 --- a/backends/arm/_passes/remove_noop_pass.py +++ b/backends/arm/_passes/remove_noop_pass.py @@ -9,7 +9,6 @@ from typing import Set, Type from executorch.backends.arm._passes import ArmPass - from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -21,14 +20,16 @@ class RemoveNoopPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = { + exir_ops.edge.dim_order_ops._clone_dim_order.default, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + exir_ops.edge.aten.alias_copy.default, + exir_ops.edge.aten.copy.default, + exir_ops.edge.aten.detach_copy.default, + } + def call_operator(self, op, args, kwargs, meta): - if op not in ( - exir_ops.edge.dim_order_ops._clone_dim_order.default, - exir_ops.edge.dim_order_ops._to_dim_order_copy.default, - exir_ops.edge.aten.alias_copy.default, - exir_ops.edge.aten.copy.default, - exir_ops.edge.aten.detach_copy.default, - ): + if op not in self.targeted_ops: return super().call_operator(op, args, kwargs, meta) input_dtype = args[0].data.dtype diff --git a/backends/arm/_passes/rewrite_bool_bitwise_to_logical_pass.py b/backends/arm/_passes/rewrite_bool_bitwise_to_logical_pass.py index 8c6bf6f39ec..0ec72fabcac 100644 --- a/backends/arm/_passes/rewrite_bool_bitwise_to_logical_pass.py +++ b/backends/arm/_passes/rewrite_bool_bitwise_to_logical_pass.py @@ -33,6 +33,8 @@ class RewriteBoolBitwiseToLogicalPass(ArmPass): exir_ops.edge.aten.bitwise_xor.Scalar: exir_ops.edge.aten.logical_xor.default, } + targeted_ops = set(_TARGET_TO_LOGICAL.keys()) + def call_operator(self, op, args, kwargs, meta): if op not in self._TARGET_TO_LOGICAL: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py b/backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py index 9119567b7aa..bf29bb1bc1c 100644 --- a/backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py +++ b/backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py @@ -6,7 +6,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -24,6 +23,8 @@ class RewriteLeLtToGeGtPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = {*OP_MAP} + def call_operator(self, op, args, kwargs, meta): if not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/size_adjust_input_pass.py b/backends/arm/_passes/size_adjust_input_pass.py index a7898b0d0d5..8fd0cb95bfa 100644 --- a/backends/arm/_passes/size_adjust_input_pass.py +++ b/backends/arm/_passes/size_adjust_input_pass.py @@ -203,6 +203,8 @@ class SizeAdjustInputPass(ArmPass): RewriteConvPass, } + targeted_ops = set(valid_operators) + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: graph = graph_module.graph modified_graph = False diff --git a/exir/pass_base.py b/exir/pass_base.py index 5350c917230..78d86b282d2 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -9,7 +9,7 @@ import operator import traceback -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext from typing import ( Any, Callable, @@ -27,9 +27,7 @@ import torch from executorch.exir import memory - from executorch.exir.delegate import executorch_call_delegate, is_lowered_module - from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.error import ExportError, ExportErrorType from torch import fx @@ -157,6 +155,113 @@ class ExportPassBaseError(RuntimeError): pass +# Namespaces of ops that are safe to cache in the FakeTensor dispatch cache. +# By default, FakeTensorMode only caches ops in {"aten", "prim", "prims"}. +# ExecuTorch passes commonly use quantization and TOSA ops that are +# deterministic and shape-preserving, so we extend caching to cover them +# during pass execution to avoid redundant FakeTensor dispatches. +_EXTRA_CACHEABLE_NAMESPACES: frozenset[str] = frozenset( + { + "quantized_decomposed", + "tosa", + "dim_order_ops", + "cortex_m", + } +) + + +@contextmanager +# pyre-ignore[3] +def _extend_faketensor_cache_builtins(): # noqa: C901 + """Temporarily extend FakeTensor dispatch cache to cover ExecuTorch ops. + + The FakeTensor dispatch cache (``FakeTensorMode``) only caches "builtin" + ops whose namespace is in ``{"aten", "prim", "prims"}``. ExecuTorch + passes operate on graphs that contain ``quantized_decomposed``, ``tosa``, + and other non-builtin ops that are nonetheless safe to cache -- they are + deterministic and their output metadata depends only on input metadata. + + Without caching these ops, every pass re-dispatches them through the full + PyTorch stack (~0.5 ms each), leading to tens of seconds of overhead + across 50+ passes on a ~1200-node graph. + + This context manager monkey-patches ``torch._library.utils.is_builtin`` + so that the cache also covers the extra namespaces, then restores the + original function on exit. + """ + import torch._library.utils as _library_utils + + _original_is_builtin = _library_utils.is_builtin + + def _extended_is_builtin(op: torch._ops.OpOverload) -> bool: + if not isinstance(op, torch._ops.OpOverload): + raise AssertionError(f"op must be OpOverload, got {type(op)}") + return op.namespace in {"aten", "prim", "prims"} or ( + op.namespace in _EXTRA_CACHEABLE_NAMESPACES + ) + + _library_utils.is_builtin = _extended_is_builtin # pyre-ignore[8] + + # Evict negative cache entries ("non-builtin" bypass entries) that were + # stored before the extension was active. FakeTensorMode stores + # _DispatchCacheBypassEntry objects as negative cache hits — once stored, + # _validate_cache_key is never re-evaluated for that key. We must evict + # these so the first dispatch under the extension re-evaluates is_builtin + # and creates a proper positive cache entry instead. + # + # There are TWO caches that can hold negative entries: + # 1. FakeTensorMode.cache -- the global (class-level) cache, used when + # the dispatch has no SymInt inputs. + # 2. shape_env.fake_tensor_cache -- per-ShapeEnv cache, used when the + # dispatch involves SymInt/SymFloat inputs (cache_on_shape_env=True). + # We must evict from both. + try: + from torch._subclasses.fake_tensor import ( + _DispatchCacheBypassEntry, + FakeTensorMode, + ) + + def _is_nonbuiltin_bypass(v: object) -> bool: + return ( + isinstance(v, _DispatchCacheBypassEntry) and v.reason == "non-builtin" + ) + + # 1. Evict from the global class-level cache. + FakeTensorMode.cache = { + k: v + for k, v in FakeTensorMode.cache.items() + if not _is_nonbuiltin_bypass(v) + } + + # 2. Evict from the per-ShapeEnv cache of the currently active + # FakeTensorMode (if any). When ExportPass enters _fx(), the + # FakeTensorMode is already on the dispatch stack before this CM + # is entered, so we can reach its shape_env cache. + try: + from torch.utils._python_dispatch import _get_current_dispatch_mode_stack + + for mode in _get_current_dispatch_mode_stack(): + if isinstance(mode, FakeTensorMode): + se = getattr(mode, "shape_env", None) + if se is not None: + se_cache = getattr(se, "fake_tensor_cache", None) + if se_cache: + se.fake_tensor_cache = { + k: v + for k, v in se_cache.items() + if not _is_nonbuiltin_bypass(v) + } + except (ImportError, AttributeError): + pass + except (ImportError, AttributeError): + pass # Graceful degradation if internals change + + try: + yield + finally: + _library_utils.is_builtin = _original_is_builtin # pyre-ignore[8] + + class _ExportPassBase(PassBase): """ Interpreter-based pass class to help users maintain the IR spec while writing @@ -290,12 +395,45 @@ def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]: node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value) + # Types whose nodes are eligible for the fast-copy optimisation in + # ``run_node``. Subclass interpreters (e.g. ``ExportPass``) extend + # this tuple to include dialect-specific overload types such as + # ``EdgeOpOverload``. + _OPERATOR_TARGET_TYPES: Tuple[type, ...] = ( + torch._ops.OpOverload, + torch._ops.OpOverloadPacket, + ) + class ExportInterpreter(fx.Interpreter): def __init__(self, callback: "_ExportPassBase", gm: fx.GraphModule) -> None: super().__init__(gm) self.callback = callback self.node: torch.fx.Node = next(iter(gm.graph.nodes)) + # --- fast-copy bookkeeping --------------------------------- + # When the owning pass declares ``targeted_ops``, cold nodes + # (those whose target is *not* in the set) can be copied into + # the new graph without an expensive FakeTensor dispatch. + targeted: Optional[Set[Any]] = getattr(callback, "targeted_ops", None) + self._targeted_ops: Optional[Set[Any]] = targeted if targeted else None + + # Fast-copy relies on the existing ``n.meta["val"]`` being + # correct for cold nodes. If the pass overrides ``call()`` + # it may modify the graph (e.g. insert nodes with metadata + # copied from unrelated ops) before calling ``super().call()``, + # which would make cold-node metadata unreliable. Disable the + # optimisation in that case. + call_overridden = type(callback).call is not _ExportPassBase.call + self._fast_copy_enabled: bool = ( + self._targeted_ops is not None and not call_overridden + ) + + # Maps old-graph nodes to their new-graph equivalents so that + # ``_fast_copy_node`` can remap arguments (including get_attr + # nodes that are stored in ``self.env`` as raw tensors rather + # than ProxyValues). + self._node_remap: Dict[torch.fx.Node, torch.fx.Node] = {} + def placeholder( # pyre-fixme[14] self, target: str, @@ -389,10 +527,113 @@ def call_method( # pyre-fixme[14] ) -> None: raise ExportPassBaseError("call_method is not supported.") + # -- fast-copy helpers ------------------------------------------ + + def _fast_copy_node(self, n: torch.fx.Node) -> "ProxyValue": + """Copy *n* into the new graph without FakeTensor dispatch. + + This is the fast path for "cold" nodes — nodes whose target is + not in the pass's ``targeted_ops``. Instead of running the + full ``_fx`` pipeline (unwrap → dispatch → create_proxy → + set_metadata), we use ``graph.node_copy`` to clone the node + directly and reuse the original ``val`` metadata. + + Typical savings: ~0.4 ms → ~0.02 ms per node. + """ + + tracer = self.callback.tracer + + def _arg_transform(old_node: torch.fx.Node) -> torch.fx.Node: + # 1. Check the remap dict (populated for processed nodes + # whose result is a ProxyValue). + new_node = self._node_remap.get(old_node) + if new_node is not None: + return new_node + # 2. Fallback: extract from ProxyValue in env. + pv = self.env.get(old_node) + if pv is not None and hasattr(pv, "proxy"): + mapped = pv.proxy.node + self._node_remap[old_node] = mapped + return mapped + # 3. For get_attr / placeholder nodes that were processed + # via the normal path but returned raw tensors (not + # ProxyValue), they won't be in _node_remap. Copy + # them into the new graph on demand. + if old_node.op in ("get_attr", "placeholder"): + copied = tracer.graph.node_copy( + old_node, lambda x: self._node_remap.get(x, x) + ) + self._node_remap[old_node] = copied + # For get_attr, also register the attribute on the + # new module so GraphModule.__init__ can find it. + if old_node.op == "get_attr": + val = self.fetch_attr(old_node.target) + target_atoms = old_node.target.split(".") + root = tracer.root + for atom in target_atoms[:-1]: + if not hasattr(root, atom): + setattr(root, atom, torch.nn.Module()) + root = getattr(root, atom) + setattr(root, target_atoms[-1], val) + return copied + return old_node + + new_node = tracer.graph.node_copy(n, _arg_transform) + # node_copy already does copy.copy(node.meta) + + val = n.meta.get("val") + proxy = torch.fx.Proxy(new_node, tracer) + result = ProxyValue(val, proxy) + self._node_remap[n] = new_node + return result + def run_node(self, n: torch.fx.Node) -> Argument: self.node = n self.callback.node_debug_str = n.format_node() - return super().run_node(n) + + # Fast-copy path: skip the full interpreter dispatch for cold + # call_function nodes whose operator is not targeted by this + # pass. This avoids the expensive FakeTensor re-dispatch and + # proxy reconstruction for nodes the pass will not modify. + if ( + self._fast_copy_enabled + and n.op == "call_function" + and isinstance(n.target, self.callback._OPERATOR_TARGET_TYPES) + and n.target not in self._targeted_ops # type: ignore[operator] + and n.meta.get("val") is not None + ): + return self._fast_copy_node(n) + + result = super().run_node(n) + + # Record old→new node mapping for fast-copy arg remapping. + if self._fast_copy_enabled and isinstance(result, ProxyValue): + self._node_remap[n] = result.proxy.node + + # After a hot node runs through full dispatch, verify that + # it did not change output shapes. If it did, downstream + # cold nodes' original ``val`` metadata would be stale, so + # we disable the fast-copy optimisation for the remainder + # of this interpreter walk. + if ( + self._fast_copy_enabled + and n.op == "call_function" + and self._targeted_ops is not None + and n.target in self._targeted_ops + and isinstance(result, ProxyValue) + ): + original_val = n.meta.get("val") + new_val = result.data + if isinstance(original_val, torch.Tensor) and isinstance( + new_val, torch.Tensor + ): + if ( + original_val.shape != new_val.shape + or original_val.dtype != new_val.dtype + ): + self._fast_copy_enabled = False + + return result def __init__(self) -> None: self.interpreter = torch.fx.Interpreter( @@ -601,13 +842,17 @@ def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue: def call_submodule( self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...] ) -> PassResult: - prev_tracer, self.tracer = self.tracer, self.ExportTracer( - self, graph_module.graph._codegen + prev_tracer, self.tracer = ( + self.tracer, + self.ExportTracer(self, graph_module.graph._codegen), ) self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode interpreter = self.ExportInterpreter(self, graph_module) - prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( - torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + prev_interpreter, self.interpreter = ( + self.interpreter, + torch.fx.Interpreter( + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + ), ) inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs) with fx_traceback.preserve_node_meta(): @@ -622,12 +867,33 @@ def call_submodule( True, ) + def should_run(self, graph_module: fx.GraphModule) -> bool: + """Override to declare when this pass can be skipped entirely. + + When this method returns False, the expensive FakeTensor graph + re-interpretation is bypassed and the original graph module is returned + unchanged. Subclasses should override this to inspect the graph cheaply + (e.g. checking whether any node targets an op this pass cares about). + + The default implementation returns True so existing passes are + unaffected. + """ + return True + def call(self, graph_module: fx.GraphModule) -> PassResult: if not getattr(self, "_initialized", False): raise ExportPassBaseError( "ExportPass is not initialized with __init__().", ) + if not getattr(self, "_skip_should_run", False) and not self.should_run( + graph_module + ): + return PassResult(graph_module, False) + + prev_skip = getattr(self, "_skip_should_run", False) + self._skip_should_run = True + inputs = self.inputs(graph_module) fake_tensor_mode = None @@ -647,12 +913,22 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: self.fake_tensor_mode = fake_tensor_mode with fake_tensor_mode, dispatcher_mode: # type: ignore[assignment, union-attr] - result = self.call_submodule(graph_module, tuple(inputs)) + with _extend_faketensor_cache_builtins(): + result = self.call_submodule(graph_module, tuple(inputs)) + self._skip_should_run = prev_skip return result class ExportPass(_ExportPassBase): + # Extend operator target types to include the Edge dialect overloads so + # that the fast-copy optimisation in ``run_node`` also covers Edge ops. + _OPERATOR_TARGET_TYPES: Tuple[type, ...] = ( + torch._ops.OpOverload, + torch._ops.OpOverloadPacket, + EdgeOpOverload, + ) + class ExportTracer(_ExportPassBase.ExportTracer): def create_arg(self, a: Argument) -> torch.fx.Node: if isinstance(a, torch.nn.Module): diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index f683384f8f9..26091f76893 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -542,6 +542,79 @@ class NullPass(ExportPass): self.assertEqual(new_node.op, old_node.op) self.assertEqual(new_node.target, old_node.target) + def test_export_pass_should_run_skip(self) -> None: + """Test that should_run=False skips FakeTensor re-interpretation.""" + + class Foo(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + x + + class AlwaysSkipPass(ExportPass): + def should_run(self, graph_module) -> bool: + return False + + def call_operator(self, op, args, kwargs, meta): + raise AssertionError("call_operator should never be reached") + + prog = to_edge(export(Foo(), (torch.ones(3, 2),), strict=True)) + original_gm = prog.exported_program().graph_module + + result = AlwaysSkipPass()(original_gm) + self.assertIsNotNone(result) + self.assertFalse(result.modified) + self.assertIs(result.graph_module, original_gm) + + def test_export_pass_should_run_op_predicate(self) -> None: + """Test should_run with op-based predicate: skip when irrelevant ops.""" + + class Foo(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + x + + class MulOnlyPass(ExportPass): + """A pass that only cares about mul ops.""" + + def should_run(self, graph_module) -> bool: + return any( + node.target == torch.ops.aten.mul.Tensor + for node in graph_module.graph.nodes + if node.op == "call_function" + ) + + def call_operator(self, op, args, kwargs, meta): + raise AssertionError("call_operator should never be reached") + + # Foo only has add ops, so MulOnlyPass should be skipped + prog = to_edge(export(Foo(), (torch.ones(3, 2),), strict=True)) + gm = prog.exported_program().graph_module + + result = MulOnlyPass()(gm) + self.assertIsNotNone(result) + self.assertFalse(result.modified) + self.assertIs(result.graph_module, gm) + + def test_export_pass_should_run_true_still_runs(self) -> None: + """Test that should_run=True (default) still runs the pass normally.""" + + call_count = 0 + + class Foo(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + x + + class CountingPass(ExportPass): + def call_operator(self, op, args, kwargs, meta): + nonlocal call_count + call_count += 1 + return super().call_operator(op, args, kwargs, meta) + + prog = to_edge(export(Foo(), (torch.ones(3, 2),), strict=True)) + gm = prog.exported_program().graph_module + + result = CountingPass()(gm) + self.assertIsNotNone(result) + self.assertGreater(call_count, 0) + def test_export_scalar_to_tensor_pass(self) -> None: # Build a graph with a scalar argument where schema expects tensor graph = torch.fx.Graph()