
Commit ad7b73c

Andrew Pullin authored and meta-codesync[bot] committed
Add should_run() + fast-copy infrastructure with targeted_ops annotations (pytorch#18497)
Summary:
Pull Request resolved: pytorch#18497

Adds infrastructure for skipping and fast-copying unchanged nodes during ExportPass execution, then annotates ~60 ARM backend passes to use it.

## Changes

### 1. should_run() hook on ExportPass / ArmPass
Subclasses that declare a `targeted_ops` class attribute (a set of op overloads) can be skipped entirely when the graph contains none of their target ops. ArmPass provides a default implementation via inheritance.

### 2. Fast-copy for cold nodes
When a pass declares `targeted_ops`, nodes whose ops are NOT in the set are copied into the new graph via `graph.node_copy()` instead of full FakeTensor dispatch. Per-node cost drops from ~0.4 ms to ~0.02 ms (~20x). Includes a safety guard: nodes without `val` metadata (e.g. nodes inserted by `call()` overrides before `super().call()`) fall back to full dispatch instead of propagating None.

### 3. FakeTensor cache extension
Context manager `_extend_faketensor_cache_builtins()` temporarily extends the FakeTensor dispatch cache to cover ExecuTorch op namespaces (quantized_decomposed, tosa, dim_order_ops, cortex_m). Avoids redundant re-dispatches for non-builtin ops across 50+ passes.

### 4. __init_subclass__ auto-discovery on ArmPass
Subclasses with existing `_TARGET_OPS`, `_supported_ops`, or `_EDGE_OPS`/`_ATEN_OPS` attributes get `targeted_ops` populated automatically at class definition time; no manual annotation is needed.

### 5. targeted_ops annotations on ~60 ARM passes
Each annotation is a one-liner declaring the ops the pass checks in `call_operator()`. Combined with should_run() and fast-copy, this achieves the measured speedup below.

## Benchmark
Model: small CNN feature extractor (~50K params, 9 conv layers with LayerNorm, targeting Ethos-U55 via the ARM/TOSA lowering pipeline).
Graph: ~1200 nodes, 146 ExportPass invocations.

lower() before:  186 s
lower() after:   100 s
Passes skipped:  53 of 146
Delta:           -86 s (-46 %)

Adds should_run() hook to ExportPass that subclasses can override to skip execution when a pass has no work to do. ArmPass implements a default that checks a targeted_ops class attribute against the graph's call_function nodes.

Also adds:
- _fast_copy_node path in ExportInterpreter.run_node that uses graph.node_copy instead of full FakeTensor dispatch for cold nodes in passes that declare targeted_ops. Per-node cost drops from ~0.4ms to ~0.02ms.
- _extend_faketensor_cache_builtins context manager that extends FakeTensor dispatch cache to cover ExecuTorch ops (quantized_decomposed, tosa, etc.)
- __init_subclass__ on ArmPass for auto-discovery of targeted_ops from existing _TARGET_OPS, _supported_ops, _EDGE_OPS/_ATEN_OPS attributes
- targeted_ops annotations on ~60 ARM pass subclasses

Measured on SleepNet featurizer (U55 lowering):
lower(): 185s -> 96s = -89s (-48%)

Differential Revision: D97528110
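The skip and fast-copy behaviour described in the commit message can be imitated in plain Python. The sketch below is only an illustration under stated assumptions: it does not use torch or ExecuTorch, and `Node`, `ToyPass`, `run_pass`, and the handling counters are hypothetical names that merely mimic what `should_run()`, `graph.node_copy()`, and the `val` safety guard are described as doing.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    """Hypothetical stand-in for an FX node: an op name plus metadata."""
    op: str
    meta: dict = field(default_factory=dict)


@dataclass
class ToyPass:
    """Hypothetical pass; targeted_ops=None means nothing was declared."""
    targeted_ops: Optional[frozenset] = None

    def should_run(self, nodes) -> bool:
        # A pass without a declaration always runs.
        if self.targeted_ops is None:
            return True
        return any(n.op in self.targeted_ops for n in nodes)


def run_pass(p: ToyPass, nodes):
    """Return (status, counts); counts tallies how each node was handled."""
    counts = {"full_dispatch": 0, "fast_copy": 0}
    if not p.should_run(nodes):
        return "skipped", counts
    for n in nodes:
        hot = p.targeted_ops is None or n.op in p.targeted_ops
        # Safety guard: a cold node without 'val' metadata is not copied;
        # it falls back to full dispatch.
        if hot or "val" not in n.meta:
            counts["full_dispatch"] += 1
        else:
            counts["fast_copy"] += 1
    return "ran", counts


graph = [
    Node("conv", {"val": "t0"}),
    Node("relu", {"val": "t1"}),
    Node("add"),  # e.g. inserted before super().call(); no 'val' yet
]

skip_result = run_pass(ToyPass(frozenset({"softmax"})), graph)
conv_result = run_pass(ToyPass(frozenset({"conv"})), graph)
plain_result = run_pass(ToyPass(), graph)
```

In this toy graph the pass targeting `softmax` is skipped outright, the pass targeting `conv` fully dispatches two nodes (the hot `conv` and the `val`-less `add`) and fast-copies `relu`, and the undeclared pass dispatches all three nodes.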
1 parent 980c012 commit ad7b73c

62 files changed

Lines changed: 548 additions & 55 deletions


backends/arm/_passes/arm_pass.py

Lines changed: 47 additions & 0 deletions
@@ -20,6 +20,25 @@
 class ArmPass(ExportPass):
     """Base class for Arm passes."""
 
+    def __init_subclass__(cls, **kwargs) -> None:
+        super().__init_subclass__(**kwargs)
+        if getattr(cls, "targeted_ops", None) is not None:
+            return
+        # Only auto-discover targeted_ops for passes that use the standard
+        # call_operator() pattern. Passes that override call() use _TARGET_OPS
+        # for their own graph manipulation logic, not as a fast-copy declaration.
+        if "call" in cls.__dict__:
+            return
+        for attr in ("_TARGET_OPS", "_supported_ops"):
+            ops = getattr(cls, attr, None)
+            if ops:
+                cls.targeted_ops = set(ops) if not isinstance(ops, set) else ops  # type: ignore[attr-defined]
+                return
+        edge = getattr(cls, "_EDGE_OPS", None)
+        aten = getattr(cls, "_ATEN_OPS", None)
+        if edge or aten:
+            cls.targeted_ops = {*(edge or ()), *(aten or ())}  # type: ignore[attr-defined]
+
     def __init__(self, tfa_pass: bool = False, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.submodule_depth = 0
@@ -78,6 +97,34 @@ def get_name(pass_) -> str:
             f"Cannot get name for pass: {pass_}. It must be an instance of ExportPass or have a __name__ attribute."
         )
 
+    def should_run(self, graph_module: GraphModule) -> bool:
+        """Skip this pass if the graph contains none of its targeted ops.
+
+        Subclasses that define a ``targeted_ops`` class attribute (a set of
+        op overloads) get this check for free via inheritance. Passes
+        without ``targeted_ops`` always run (the default).
+
+        Recursively checks control flow submodules (cond/while_loop) so
+        passes are not incorrectly skipped when targeted ops are nested.
+
+        """
+        targeted = getattr(self, "targeted_ops", None)
+        if targeted is None:
+            return True
+
+        from executorch.exir.graph_module import get_control_flow_submodules
+
+        def _has_targeted_op(gm: GraphModule) -> bool:
+            for node in gm.graph.nodes:
+                if node.op == "call_function" and node.target in targeted:
+                    return True
+            for _, submod, _ in get_control_flow_submodules(gm):
+                if _has_targeted_op(submod):
+                    return True
+            return False
+
+        return _has_targeted_op(graph_module)
+
     def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
         if not updated:
             return super().call_operator(op, args, kwargs, meta)
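The auto-discovery rules in the `__init_subclass__` hunk above are easier to see on toy classes. The following is a sketch only: `BasePass` and its subclasses are hypothetical, ops are plain strings rather than op overloads, and the hook body is copied from the diff without the ExportPass machinery.

```python
class BasePass:
    """Hypothetical stand-in for ArmPass; only the discovery hook is modeled."""

    def __init_subclass__(cls, **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        # An explicit (or inherited) declaration wins.
        if getattr(cls, "targeted_ops", None) is not None:
            return
        # Classes that define their own call() are left alone.
        if "call" in cls.__dict__:
            return
        for attr in ("_TARGET_OPS", "_supported_ops"):
            ops = getattr(cls, attr, None)
            if ops:
                cls.targeted_ops = set(ops) if not isinstance(ops, set) else ops
                return
        edge = getattr(cls, "_EDGE_OPS", None)
        aten = getattr(cls, "_ATEN_OPS", None)
        if edge or aten:
            cls.targeted_ops = {*(edge or ()), *(aten or ())}


class FromTargetOps(BasePass):
    _TARGET_OPS = ("edge.add", "edge.mul")


class FromEdgeAndAten(BasePass):
    _EDGE_OPS = ("edge.exp",)
    _ATEN_OPS = ("aten.exp",)


class OverridesCall(BasePass):
    _TARGET_OPS = ("edge.sub",)

    def call(self, graph_module):
        return graph_module


class Explicit(BasePass):
    targeted_ops = {"edge.conv"}
    _TARGET_OPS = ("edge.relu",)


class Child(FromTargetOps):
    # The inherited targeted_ops is found by getattr(), so this tuple is
    # not picked up by the hook.
    _TARGET_OPS = ("edge.div",)
```

One consequence visible in the sketch: because the first check uses `getattr`, a subclass of an already-annotated pass keeps the parent's `targeted_ops` even if it redefines `_TARGET_OPS`. Whether any real ARM pass relies on that is not something the diff shows.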

backends/arm/_passes/cast_to_int32_pass.py

Lines changed: 0 additions & 2 deletions
@@ -6,9 +6,7 @@
 from typing import Set, Type
 
 import torch
-
 from executorch.backends.arm._passes.arm_pass import ArmPass
-
 from executorch.backends.arm.tosa.specification import get_context_spec
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult

backends/arm/_passes/conv1d_unsqueeze_pass.py

Lines changed: 2 additions & 2 deletions
@@ -9,10 +9,8 @@
 from typing import Set, Type
 
 from executorch.backends.arm._passes import ArmPass
-
 from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
 from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
-
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
 
@@ -35,6 +33,8 @@ class Conv1dUnsqueezePass(ArmPass):
         SizeAdjustInputPass,
     }
 
+    targeted_ops = {exir_ops.edge.aten.convolution.default}
+
     def call_operator(self, op, args, kwargs, meta):
         if op != exir_ops.edge.aten.convolution.default:
             return super().call_operator(op, args, kwargs, meta)
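With a one-line annotation like the one above, the inherited `should_run()` check amounts to a membership scan over `call_function` nodes, including nodes nested in control-flow submodules. A minimal sketch of that scan, assuming toy stand-ins: `ToyNode`, `ToyGraphModule`, and `ToyConv1dPass` are hypothetical, and a plain `submodules` list replaces `get_control_flow_submodules()`.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyNode:
    op: str      # e.g. "call_function" or "placeholder"
    target: str  # op name; a string here instead of an op overload


@dataclass
class ToyGraphModule:
    nodes: List[ToyNode]
    # Stand-in for the cond/while_loop submodules of a real GraphModule.
    submodules: List["ToyGraphModule"] = field(default_factory=list)


def has_targeted_op(gm: ToyGraphModule, targeted) -> bool:
    """True if gm, or any nested submodule, calls one of the targeted ops."""
    for node in gm.nodes:
        if node.op == "call_function" and node.target in targeted:
            return True
    return any(has_targeted_op(sub, targeted) for sub in gm.submodules)


class ToyConv1dPass:
    targeted_ops = {"aten.convolution.default"}

    def should_run(self, gm: ToyGraphModule) -> bool:
        targeted = getattr(self, "targeted_ops", None)
        if targeted is None:
            return True  # no declaration: always run
        return has_targeted_op(gm, targeted)


conv = ToyNode("call_function", "aten.convolution.default")
relu = ToyNode("call_function", "aten.relu.default")
# A placeholder that merely shares the target name does not count.
arg = ToyNode("placeholder", "aten.convolution.default")

flat_without_conv = ToyGraphModule([arg, relu])
nested_with_conv = ToyGraphModule([relu], [ToyGraphModule([conv])])

p = ToyConv1dPass()
runs_on_flat = p.should_run(flat_without_conv)
runs_on_nested = p.should_run(nested_with_conv)
```

The nested case is the one the docstring in `arm_pass.py` calls out: without the recursion, a convolution inside a branch submodule would be missed and the pass skipped incorrectly.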

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,6 @@
 from typing import cast, Set, Type
 
 import torch
-
 from executorch.backends.arm._passes.arm_pass import ArmPass
 from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
     UnsqueezeBeforeRepeatPass,
@@ -58,6 +57,8 @@ class ConvertExpandCopyToRepeatPass(ArmPass):
 
     _passes_required_after: Set[Type[ExportPass]] = {UnsqueezeBeforeRepeatPass}
 
+    targeted_ops = {exir_ops.edge.aten.expand_copy.default}
+
     expand_copy = exir_ops.edge.aten.expand_copy.default
     repeat = exir_ops.edge.aten.repeat.default

backends/arm/_passes/convert_full_like_to_full_pass.py

Lines changed: 2 additions & 1 deletion
@@ -9,7 +9,6 @@
 from executorch.backends.arm._passes.fuse_constant_ops_pass import (
     ComputeConstantOpsAOTPass,
 )
-
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
 
@@ -36,6 +35,8 @@ class ConvertFullLikeToFullPass(ArmPass):
 
     _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass}
 
+    targeted_ops = {exir_ops.edge.aten.full_like.default}
+
     def call_operator(self, op, args, kwargs, meta):
         if op not in [
             exir_ops.edge.aten.full_like.default,

backends/arm/_passes/convert_permute_singleton_to_view_pass.py

Lines changed: 2 additions & 2 deletions
@@ -7,10 +7,8 @@
 from typing import Sequence, Set, Tuple, Type
 
 from executorch.backends.arm._passes.arm_pass import ArmPass
-
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
-
 from torch._ops import OpOverload
 
 
@@ -35,6 +33,8 @@ class ConvertPermuteSingletonToViewPass(ArmPass):
 
     _passes_required_after: Set[Type[ExportPass]] = set()
 
+    targeted_ops = set(_PERMUTE_TARGETS)
+
     def call_operator(self, op, args, kwargs, meta):
         if op not in _PERMUTE_TARGETS:
             return super().call_operator(op, args, kwargs, meta)

backends/arm/_passes/convert_split_to_slice.py

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,11 @@ class ConvertSplitToSlicePass(ArmPass):
 
     _passes_required_after: Set[Type[ExportPass]] = set()
 
+    targeted_ops = {
+        exir_ops.edge.aten.split_with_sizes_copy.default,
+        exir_ops.edge.aten.split_copy.Tensor,
+    }
+
     split_ops = (
         exir_ops.edge.aten.split_with_sizes_copy.default,
         exir_ops.edge.aten.split_copy.Tensor,
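In this hunk the same two ops appear in both `targeted_ops` and the existing `split_ops` tuple. Going by the commit message, an op that `call_operator()` checks but that is missing from `targeted_ops` would presumably be fast-copied and never reach `call_operator()`, so the two lists need to stay in sync. One way to guard against drift is a small test helper; the sketch below is an assumption, not part of the commit, and `undeclared_ops` and the toy classes are hypothetical.

```python
def undeclared_ops(pass_cls, checked_attr: str) -> set:
    """Ops listed under checked_attr that are absent from targeted_ops."""
    declared = getattr(pass_cls, "targeted_ops", None)
    if declared is None:
        # No declaration means no fast-copy, so nothing can be missed.
        return set()
    return set(getattr(pass_cls, checked_attr, ())) - set(declared)


class InSync:
    # Strings stand in for op overloads such as exir_ops.edge.aten.split_copy.Tensor.
    targeted_ops = {"split_with_sizes_copy.default", "split_copy.Tensor"}
    split_ops = ("split_with_sizes_copy.default", "split_copy.Tensor")


class Drifted:
    targeted_ops = {"split_with_sizes_copy.default"}
    split_ops = ("split_with_sizes_copy.default", "split_copy.Tensor")


class Undeclared:
    split_ops = ("split_copy.Tensor",)


in_sync_missing = undeclared_ops(InSync, "split_ops")
drifted_missing = undeclared_ops(Drifted, "split_ops")
undeclared_missing = undeclared_ops(Undeclared, "split_ops")
```

A unit test could assert that the returned set is empty for every pass that keeps a separate op list.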

backends/arm/_passes/convert_squeezes_to_view.py

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,11 @@ class ConvertSqueezesToViewPass(ArmPass):
 
     _passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransformPass}
 
+    targeted_ops = {
+        exir_ops.edge.aten.squeeze_copy.dims,
+        exir_ops.edge.aten.unsqueeze_copy.default,
+    }
+
     def call_operator(self, op, args, kwargs, meta):
         if op not in [
             exir_ops.edge.aten.squeeze_copy.dims,

backends/arm/_passes/convert_to_clamp_pass.py

Lines changed: 2 additions & 2 deletions
@@ -6,11 +6,9 @@
 from typing import Set, Tuple, Type
 
 from executorch.backends.arm._passes import ArmPass
-
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     QuantizeClampArgumentsPass,
 )
-
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
 
@@ -32,6 +30,8 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]:
 class ConvertToClampPass(ArmPass):
     _passes_required_after: Set[Type[ExportPass]] = {QuantizeClampArgumentsPass}
 
+    targeted_ops = edge_operators
+
     def call_operator(self, op, args, kwargs, meta):
         if op not in edge_operators or not self.allowed_to_transform(meta):
             return super().call_operator(op, args, kwargs, meta)

backends/arm/_passes/decompose_acosh_pass.py

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,8 @@ class DecomposeAcoshPass(ArmPass):
         MatchArgDtypePass,
     }
 
+    targeted_ops = {edge_acosh_op}
+
     def call_operator(self, op, args, kwargs, meta, updated=False):
 
         if op is not edge_acosh_op:
