Skip to content

Commit 559036a

Browse files
Andrew Pullin authored and facebook-github-bot committed
Skip redundant run_decompositions when no ops match decomp table (pytorch#18496)
Summary: Adds an early-exit check to `_gen_edge_manager_for_partitioners`: before calling `program.run_decompositions(table)`, scan the graph for ops that appear in the decomposition table. If none are found, skip the call entirely.

Each `run_decompositions` call performs a full re-export of the program via `make_fx()`, re-tracing every node through FakeTensor dispatch. On the EDGE_DO_NOT_DECOMP path `run_decompositions` is called up to 3 times; the early exit eliminates at least one redundant call where the previous pass already decomposed all matching ops.

The check recursively walks control-flow submodules (cond/map/scan) to avoid incorrectly skipping when decomposable ops are nested.

## Benchmark

Model: small CNN feature extractor (~50K params, 9 conv layers with LayerNorm, targeting Ethos-U55 via the ARM/TOSA lowering pipeline). Graph: ~1200 nodes.

- lower() before: 82 s
- lower() after: 71 s
- Delta: -11 s (-13 %)

Differential Revision: D96489903
1 parent ae07d06 commit 559036a

1 file changed

Lines changed: 36 additions & 5 deletions

File tree

exir/program/_program.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,7 +1161,33 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP(
11611161
return check_op_support is None
11621162

11631163

1164-
def _gen_edge_manager_for_partitioners(
1164+
def _has_decomposable_ops(
    program: "ExportedProgram",
    decomp_table: dict,
) -> bool:
    """Check whether any op in the program's graphs is in the decomposition table.

    Used as a cheap pre-check before ``program.run_decompositions(decomp_table)``:
    when no ``call_function`` node targets an op present in the table, the
    caller may skip the decomposition pass.

    Every ``torch.fx.GraphModule`` reachable from ``program.graph_module`` is
    scanned, including GraphModules that sit beneath a plain ``nn.Module``
    container that is not itself a GraphModule. Missing a nested graph would
    produce a false "nothing to decompose" answer, so the walk deliberately
    covers all descendants rather than only direct GraphModule children.

    Args:
        program: Object exposing the root graph module as ``graph_module``.
        decomp_table: Mapping whose keys are the ops to look for; only key
            membership is consulted.

    Returns:
        True if at least one ``call_function`` node targets a key of
        ``decomp_table``. Also True when ``decomp_table`` is empty
        (functionalization-only path), since whether the graph needs
        functionalization cannot be determined cheaply here. False otherwise.
    """
    if not decomp_table:
        return True  # empty table = functionalize, can't skip cheaply

    # Module.modules() yields the root and every descendant exactly once, so
    # GraphModules nested at any depth (and under any container) are visited.
    for module in program.graph_module.modules():
        if not isinstance(module, torch.fx.GraphModule):
            continue  # plain containers have no graph to scan
        for node in module.graph.nodes:
            if node.op == "call_function" and node.target in decomp_table:
                return True
    return False
1188+
1189+
1190+
def _gen_edge_manager_for_partitioners( # noqa: C901
11651191
partitioner: Dict[str, List[Partitioner]],
11661192
aten_programs: Dict[str, ExportedProgram],
11671193
config: EdgeCompileConfig,
@@ -1196,7 +1222,8 @@ def _gen_edge_manager_for_partitioners(
11961222
table = _default_decomposition_table()
11971223
for op in config.preserve_ops:
11981224
table.pop(op, None)
1199-
program = program.run_decompositions(table)
1225+
if _has_decomposable_ops(program, table):
1226+
program = program.run_decompositions(table)
12001227

12011228
# Process each partitioner individually using their specific requirements
12021229
for curr_partitioner in partitioners_for_program:
@@ -1216,7 +1243,8 @@ def _gen_edge_manager_for_partitioners(
12161243
if table.pop(op, None) is not None:
12171244
ops_needing_preservation.append(op)
12181245

1219-
program = program.run_decompositions(table)
1246+
if _has_decomposable_ops(program, table):
1247+
program = program.run_decompositions(table)
12201248
final_ops_to_preserve.update(ops_needing_preservation)
12211249
else:
12221250
# EDGE_DO_NOT_DECOMP path for the partitioner
@@ -1230,7 +1258,8 @@ def _gen_edge_manager_for_partitioners(
12301258
table.pop(op, None)
12311259

12321260
# First pass of decompositions with this partitioner's preserved ops
1233-
program = program.run_decompositions(table)
1261+
if _has_decomposable_ops(program, table):
1262+
program = program.run_decompositions(table)
12341263

12351264
# Filter ops using EDGE_DO_NOT_DECOMP
12361265
temp_partitioner_dict = {name: [curr_partitioner]}
@@ -1243,7 +1272,9 @@ def _gen_edge_manager_for_partitioners(
12431272
final_ops_to_preserve.update(preserved_ops)
12441273

12451274
# Second pass of decompositions with this partitioner's preserved ops after filtering
1246-
program = program.run_decompositions(_default_decomposition_table())
1275+
full_table = _default_decomposition_table()
1276+
if _has_decomposable_ops(program, full_table):
1277+
program = program.run_decompositions(full_table)
12471278

12481279
# Restore ops from edge_no_decomp_namespace to aten ops
12491280
_restore_transformed_ops_to_aten_ops(program)

0 commit comments

Comments
 (0)