From a9c97460dd83ce3dbe2c4bccefb35700c96d74d1 Mon Sep 17 00:00:00 2001 From: Andrew Pullin Date: Fri, 17 Apr 2026 09:03:00 -0700 Subject: [PATCH] Minor speedup for model lowering: Skip redundant run_decompositions when no ops match decomp table (#18496) Summary: Adds an early-exit check to _gen_edge_manager_for_partitioners: before calling program.run_decompositions(table), scan the graph for ops that appear in the decomposition table. If none are found, skip the call entirely. Each run_decompositions call performs a full re-export of the program via make_fx(), re-tracing every node through FakeTensor dispatch. On the EDGE_DO_NOT_DECOMP path this function is called up to 3 times; the early-exit eliminates at least one redundant call where the previous pass already decomposed all matching ops. The check recursively walks control flow submodules (cond/map/scan) to avoid incorrectly skipping when decomposable ops are nested. ## Benchmark Model: small CNN feature extractor (~50K params, 9 conv layers with LayerNorm, targeting Ethos-U55 via the ARM/TOSA lowering pipeline). Graph: ~1200 nodes. lower() before: 82 s lower() after: 71 s Delta: -11 s (-13 %) Differential Revision: D96489903 --- exir/program/_program.py | 41 +++++++++++++++++++++++++++++++++++----- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/exir/program/_program.py b/exir/program/_program.py index c68d0eed945..12690a8b86f 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1163,7 +1163,33 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP( return check_op_support is None -def _gen_edge_manager_for_partitioners( +def _has_decomposable_ops( + program: "ExportedProgram", + decomp_table: dict, +) -> bool: + """Check if any ops in the graph match the decomposition table. + + Returns True if the graph contains at least one op that appears in the + decomposition table, meaning run_decompositions would actually decompose + something. Returns True for empty tables (functionalization-only path) + since we can't cheaply determine if the graph needs functionalization. + """ + if not decomp_table: + return True # empty table = functionalize, can't skip cheaply + + def _graph_has_match(gm: torch.fx.GraphModule) -> bool: + for node in gm.graph.nodes: + if node.op == "call_function" and node.target in decomp_table: + return True + for _, submod, _ in get_control_flow_submodules(gm): + if _graph_has_match(submod): + return True + return False + + return _graph_has_match(program.graph_module) + + +def _gen_edge_manager_for_partitioners( # noqa: C901 partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram], config: EdgeCompileConfig, @@ -1198,7 +1224,8 @@ def _gen_edge_manager_for_partitioners( table = _default_decomposition_table() for op in config.preserve_ops: table.pop(op, None) - program = program.run_decompositions(table) + if _has_decomposable_ops(program, table): + program = program.run_decompositions(table) # Process each partitioner individually using their specific requirements for curr_partitioner in partitioners_for_program: @@ -1218,7 +1245,8 @@ def _gen_edge_manager_for_partitioners( if table.pop(op, None) is not None: ops_needing_preservation.append(op) - program = program.run_decompositions(table) + if _has_decomposable_ops(program, table): + program = program.run_decompositions(table) final_ops_to_preserve.update(ops_needing_preservation) else: # EDGE_DO_NOT_DECOMP path for the partitioner @@ -1232,7 +1260,8 @@ def _gen_edge_manager_for_partitioners( table.pop(op, None) # First pass of decompositions with this partitioner's preserved ops - program = program.run_decompositions(table) + if _has_decomposable_ops(program, table): + program = program.run_decompositions(table) # Filter ops using EDGE_DO_NOT_DECOMP temp_partitioner_dict = {name: [curr_partitioner]} @@ -1245,7 +1274,9 @@ def _gen_edge_manager_for_partitioners( final_ops_to_preserve.update(preserved_ops) # Second pass of decompositions with this partitioner's preserved ops after filtering - program = program.run_decompositions(_default_decomposition_table()) + full_table = _default_decomposition_table() + if _has_decomposable_ops(program, full_table): + program = program.run_decompositions(full_table) # Restore ops from edge_no_decomp_namespace to aten ops _restore_transformed_ops_to_aten_ops(program)