From bb8197e45efa2306222a0f86b2a95510f639c8d5 Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Tue, 12 May 2026 23:55:29 -0700 Subject: [PATCH 1/2] Handle rank-changing views in RemovePermutesAroundElementwiseOps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Extend RemovePermutesAroundElementwiseOps to cancel permute pairs across rank-changing squeeze/unsqueeze view boundaries. When a permute's sole user is a view_copy that adds or removes a single size-1 dimension, the pass adapts the expected permutation to the new rank and continues traversal. This enables removing permutes that sit on opposite sides of an unsqueeze→elementwise→squeeze chain (e.g. the NHWC↔NTC layout conversion around convolutions in the cascade detector model). Key changes: - Accept extra_permutable_ops constructor parameter for backend-specific ops - Track per-node expected permutations across view boundaries - Run dimension updates before edges_in bypass to preserve original metadata - Handle view_copy, unsqueeze_copy, squeeze_copy rank changes - Treat aten.full.default as a compile-time constant Note: The PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView pass is removed from the Arm pass manager, since it doesn't actually help anymore. Differential Revision: D104775244 --- backends/arm/_passes/arm_pass_manager.py | 5 - ...ve_permutes_around_elementwise_tosa_ops.py | 16 +- backends/cadence/aot/remove_ops.py | 20 +- .../remove_permutes_around_elementwise_ops.py | 421 ++++++++++++++---- backends/transforms/targets.bzl | 1 + .../test/test_permute_optimization_passes.py | 79 ++++ 6 files changed, 442 insertions(+), 100 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index bad535efe6f..1cbd100ae6b 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -161,10 +161,6 @@ from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import ( FuseCascadedTransposeOrPermuteOps, ) -from executorch.backends.transforms.postpone_permute_below_squeeze_view import ( - PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, -) - from executorch.exir import ExportedProgram from executorch.exir.pass_base import ExportPass from executorch.exir.pass_manager import PassManager @@ -538,7 +534,6 @@ def _tosa_pipeline( RewritePadPass(), FuseViewCopyTransformPass(), RemovePermutesAroundElementwiseTosaOps(), - PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(), FuseCascadedTransposeOrPermuteOps(), ConvertPermuteSingletonToViewPass(), RewriteHighRankSingletonPermutePass(), diff --git a/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py b/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py index e000b3d6fe8..fa6f6f7988c 100644 --- a/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py +++ b/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py @@ -11,13 +11,15 @@ class RemovePermutesAroundElementwiseTosaOps(RemovePermutesAroundElementwiseOps): - permutable_ops = { - *RemovePermutesAroundElementwiseOps.permutable_ops, - *TableOps.unary_table_ops.keys(), - *TableOps.special_table_ops, - exir_ops.backend.tosa.RESCALE.default, - exir_ops.backend.tosa.TABLE.default, - } + def __init__(self) -> None: + super().__init__( + extra_permutable_ops={ + *TableOps.unary_table_ops.keys(), + *TableOps.special_table_ops, + exir_ops.backend.tosa.RESCALE.default, + exir_ops.backend.tosa.TABLE.default, + } + ) def permute_subgraph(self, subgraph): # Original function will always permute constant nodes which is wrong for table ops diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index e532d088e5c..c221c3a5a18 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -603,16 +603,16 @@ def maybe_remove_or_replace(self, node: Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=2)) class RemovePermutesAroundElementwiseOps(_SharedRemovePermutesAroundElementwiseOps): - permutable_ops: set[EdgeOpOverload] = ( - _SharedRemovePermutesAroundElementwiseOps.permutable_ops - | { - exir_ops.edge.cadence.quantize_per_tensor.default, - exir_ops.edge.cadence.dequantize_per_tensor.default, - exir_ops.edge.cadence.quantized_relu.per_tensor, - exir_ops.edge.cadence.requantize.per_tensor, - exir_ops.edge.cadence.quantized_add.per_tensor, - } - ) + def __init__(self) -> None: + super().__init__( + extra_permutable_ops={ + exir_ops.edge.cadence.quantize_per_tensor.default, + exir_ops.edge.cadence.dequantize_per_tensor.default, + exir_ops.edge.cadence.quantized_relu.per_tensor, + exir_ops.edge.cadence.requantize.per_tensor, + exir_ops.edge.cadence.quantized_add.per_tensor, + } + ) @register_cadence_pass(CadencePassAttribute(opt_level=2)) diff --git a/backends/transforms/remove_permutes_around_elementwise_ops.py b/backends/transforms/remove_permutes_around_elementwise_ops.py index 28b739c8d91..4ed8fa4e21d 100644 --- a/backends/transforms/remove_permutes_around_elementwise_ops.py +++ b/backends/transforms/remove_permutes_around_elementwise_ops.py @@ -13,7 +13,6 @@ import torch import torch.fx from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass, PassResult @@ -39,20 +38,151 @@ class Subgraph: constant_edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field( default_factory=set ) + # Per-node expected end permutation (may differ from end_permute + # when the subgraph contains rank-changing views). + node_end_permute: dict[torch.fx.Node, list[int]] = field(default_factory=dict) + # Per-node expected start permutation for upstream traversal. + node_start_permute: dict[torch.fx.Node, list[int]] = field( + default_factory=dict + ) + + # Ops explicitly listed as permutable. Includes non-pointwise ops + # that need special dimension-argument handling (cat, mean, sum, slice) + # and quantize/dequantize ops not tagged as pointwise in ATen. + # In addition, view_copy ops that insert/remove a single size-1 + # dimension are accepted as rank-changing boundaries + # (see _is_squeeze_unsqueeze_view). + _base_permutable_ops: set = set() + _base_ops_initialized = False + + def __init__(self, extra_permutable_ops: set | None = None) -> None: + super().__init__() + self._extra_permutable_ops = extra_permutable_ops or set() + + @classmethod + def _get_base_permutable_ops(cls) -> set: + if not cls._base_ops_initialized: + cls._base_permutable_ops = { + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.sum.dim_IntList, + exir_ops.edge.aten.slice_copy.Tensor, + } + try: + cls._base_permutable_ops.add( + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + ) + cls._base_permutable_ops.add( + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + ) + except (AttributeError, RuntimeError): + pass + cls._base_ops_initialized = True + return cls._base_permutable_ops + + def _is_in_permutable_ops(self, target) -> bool: + return target in self._get_base_permutable_ops() or target in self._extra_permutable_ops + + _VIEW_OPS = ( + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.view.default, + ) + + _UNSQUEEZE_OPS = ( + exir_ops.edge.aten.unsqueeze_copy.default, + ) + + _SQUEEZE_OPS = ( + exir_ops.edge.aten.squeeze_copy.dim, + ) - permutable_ops: set[EdgeOpOverload] = { - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.hardtanh.default, - exir_ops.edge.aten.clamp.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - # Ops that require special handling. - exir_ops.edge.aten.cat.default, - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten.sum.dim_IntList, - exir_ops.edge.aten.slice_copy.Tensor, - } + @staticmethod + def _find_extra_one(longer: list[int], shorter: list[int]) -> int: + """If longer has exactly one more element of value 1, return its index. Else -1.""" + if len(longer) != len(shorter) + 1: + return -1 + for i in range(len(shorter)): + if longer[i] != shorter[i]: + if longer[i] == 1 and shorter[i:] == longer[i + 1 :]: + return i + return -1 + return len(shorter) if longer[-1] == 1 else -1 + + def _is_squeeze_unsqueeze_view(self, node: torch.fx.Node) -> bool: + """Check if a node is a squeeze, unsqueeze, or view_copy that only + adds or removes a single dim of size 1.""" + if node.target in self._UNSQUEEZE_OPS or node.target in self._SQUEEZE_OPS: + return True + if node.target not in self._VIEW_OPS: + return False + if node.meta.get("val") is None: + return False + inp = node.args[0] + if not isinstance(inp, torch.fx.Node) or inp.meta.get("val") is None: + return False + in_shape = list(inp.meta["val"].shape) + out_shape = list(node.meta["val"].shape) + if len(out_shape) == len(in_shape) + 1: + return self._find_extra_one(out_shape, in_shape) != -1 + if len(in_shape) == len(out_shape) + 1: + return self._find_extra_one(in_shape, out_shape) != -1 + return False + + def _adapt_permute_across_view( + self, permute: list[int], node: torch.fx.Node + ) -> list[int] | None: + """Adjust a permutation across a squeeze/unsqueeze boundary. + + Adapts from input-rank to output-rank space (downstream direction). + Returns the adjusted permutation, or None if not possible. + """ + # Handle explicit unsqueeze_copy(dim) + if node.target in self._UNSQUEEZE_OPS: + dim = cast(int, node.args[1]) + rank = len(permute) + index = dim if dim >= 0 else dim + rank + 1 + new_perm = [x + 1 if x >= index else x for x in permute] + new_perm.insert(index, index) + return new_perm + + # Handle explicit squeeze_copy(dim) + if node.target in self._SQUEEZE_OPS: + dim = cast(int, node.args[1]) + rank = len(permute) + index = dim if dim >= 0 else dim + rank + # Find where the squeezed input dim appears in the permutation + if index not in permute: + return None + new_perm = [x - 1 if x > index else x for x in permute if x != index] + return new_perm + + # Handle view_copy (squeeze/unsqueeze-like reshape) + inp = node.args[0] + assert isinstance(inp, torch.fx.Node) + in_shape = list(inp.meta["val"].shape) + out_shape = list(node.meta["val"].shape) + + if len(out_shape) == len(in_shape) + 1: + # unsqueeze: insert identity mapping at the new dim + index = self._find_extra_one(out_shape, in_shape) + if index == -1: + return None + new_perm = [x + 1 if x >= index else x for x in permute] + new_perm.insert(index, index) + return new_perm + elif len(in_shape) == len(out_shape) + 1: + # squeeze via view_copy: find the squeezed dim and remove it + index = self._find_extra_one(in_shape, out_shape) + if index == -1 or index not in permute: + return None + new_perm = [x - 1 if x > index else x for x in permute if x != index] + return new_perm + return None def call(self, graph_module: torch.fx.GraphModule) -> PassResult: subgraphs_found: list[RemovePermutesAroundElementwiseOps.Subgraph] = [] @@ -66,16 +196,55 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # Expected end permutation for the subgraph. end_permute = [start_permute.index(i) for i in range(len(start_permute))] + # Try direct users first (same-rank matching) for user in node.users: - if user.target not in self.permutable_ops: + if not self.is_node_permutable(user): continue - # Create a separate subgraph for each user since there may be cases - # where only a portion of the users are permutable. subgraph = self.Subgraph(start_permute, end_permute) if self.visit(user, subgraph, processed_nodes): subgraphs_found.append(subgraph) - for node in subgraph.nodes: - processed_nodes.add(node) + for n in subgraph.nodes: + processed_nodes.add(n) + + # Also try: permute → view(squeeze/unsqueeze) → chain → ... + # If the permute's sole user is a squeeze/unsqueeze view, + # adapt the permutation across the view and search for a + # matching end permute at the new rank. + users = list(node.users.keys()) + if ( + len(users) == 1 + and self._is_squeeze_unsqueeze_view(users[0]) + and node not in processed_nodes + ): + view_node = users[0] + adapted_start = self._adapt_permute_across_view( + start_permute, view_node + ) + if adapted_start is not None: + adapted_end = [ + adapted_start.index(i) for i in range(len(adapted_start)) + ] + for view_user in view_node.users: + if not self.is_node_permutable(view_user): + continue + subgraph = self.Subgraph( + adapted_start, adapted_end + ) + # Include the view in the subgraph + subgraph.nodes.add(view_node) + subgraph.node_end_permute[view_node] = adapted_end + # Use the ORIGINAL start_permute for the view node + # so update_view_copy can remap its shape correctly + subgraph.node_start_permute[view_node] = start_permute + # The start permute feeds into the view + subgraph.edges_in.add((node, view_node)) + if self.visit( + view_user, subgraph, processed_nodes, + adapted_end, adapted_start + ): + subgraphs_found.append(subgraph) + for n in subgraph.nodes: + processed_nodes.add(n) modified = False for subgraph in subgraphs_found: @@ -94,40 +263,85 @@ def visit( # noqa: C901 node: torch.fx.Node, subgraph: Subgraph, processed_nodes: set[torch.fx.Node], + current_end_permute: list[int] | None = None, + current_start_permute: list[int] | None = None, ) -> bool: + if current_end_permute is None: + current_end_permute = subgraph.end_permute + if current_start_permute is None: + current_start_permute = subgraph.start_permute + if node in subgraph.nodes: return True if node in processed_nodes or not self.is_node_permutable(node): return False subgraph.nodes.add(node) + subgraph.node_end_permute[node] = current_end_permute + subgraph.node_start_permute[node] = current_start_permute + + # If this is a squeeze/unsqueeze view, adapt permutations for + # traversal across the rank change boundary. + downstream_end = current_end_permute + downstream_start = current_start_permute + if self._is_squeeze_unsqueeze_view(node): + # Adapt end permute for downstream (input-rank → output-rank) + adapted_end = self._adapt_permute_across_view( + current_end_permute, node + ) + if adapted_end is None: + return False + downstream_end = adapted_end + + # Adapt start permute for downstream too (so upstream checks + # on the other side of the view use the correct rank) + adapted_start = self._adapt_permute_across_view( + current_start_permute, node + ) + if adapted_start is None: + return False + downstream_start = adapted_start # Traverse downstream: for user in node.users: - # Output should either go to a matching permute or another permutable op. if user.target == exir_ops.edge.aten.permute_copy.default: - if self.get_permutation(user) != subgraph.end_permute: + user_perm = self.get_permutation(user) + if user_perm == downstream_end: + subgraph.edges_out.add((node, user)) + else: + # Check if permute → view(squeeze/unsqueeze) forms an + # end boundary at a different rank. + user_users = list(user.users.keys()) + if ( + len(user_users) == 1 + and self._is_squeeze_unsqueeze_view(user_users[0]) + ): + view_after = user_users[0] + # Adapt the expected end permute across the view + adapted = self._adapt_permute_across_view( + downstream_end, view_after + ) + if adapted is not None and user_perm == adapted: + # Include both the permute and the view as end edges + subgraph.edges_out.add((node, user)) + # Mark the view for inclusion so it gets preserved + continue return False - subgraph.edges_out.add((node, user)) elif user.op == "output": - # Graph output requires the data in its original layout. - # Removing permutes here would silently change the output - # format, so treat this as an invalid subgraph boundary. return False - elif not self.visit(user, subgraph, processed_nodes): + elif not self.visit( + user, subgraph, processed_nodes, downstream_end, downstream_start + ): return False # Traverse upstream: for inp in node.all_input_nodes: - # Input should either come from a matching permute or another permutable op. if inp.target == exir_ops.edge.aten.permute_copy.default: - if self.get_permutation(inp) != subgraph.start_permute: + if self.get_permutation(inp) != current_start_permute: return False subgraph.edges_in.add((inp, node)) elif self._is_constant(inp): - # Only accept the constant if we can insert a compensating - # permute or view. Otherwise reject the subgraph. const_rank = self._get_node_rank(inp) - permute_rank = len(subgraph.end_permute) + permute_rank = len(current_end_permute) if const_rank is None: return False if const_rank > permute_rank: @@ -135,7 +349,10 @@ def visit( # noqa: C901 if const_rank < permute_rank and inp.meta.get("val") is None: return False subgraph.constant_edges_in.add((inp, node)) - elif not self.visit(inp, subgraph, processed_nodes): + elif not self.visit( + inp, subgraph, processed_nodes, + current_end_permute, current_start_permute + ): return False return True @@ -143,13 +360,19 @@ def visit( # noqa: C901 def _is_constant(self, node: torch.fx.Node) -> bool: """Check if a node's value is available at compile time. Only considers direct constants (get_attr, parameter/buffer/constant - placeholders) — does not recurse into call_function chains to avoid - stack overflow on deep graphs.""" + placeholders, full ops producing scalar constants) — does not recurse + into call_function chains to avoid stack overflow on deep graphs.""" if node.op == "get_attr": return True if node.op == "placeholder": target = str(node.target) return target.startswith(("b_", "p_", "c_")) + # full.default creates scalar constants (e.g. epsilon in LayerNorm) + if ( + node.op == "call_function" + and node.target == exir_ops.edge.aten.full.default + ): + return True return False def _get_node_rank(self, node: torch.fx.Node) -> int | None: @@ -160,25 +383,48 @@ def _get_node_rank(self, node: torch.fx.Node) -> int | None: return None def is_node_permutable(self, node: torch.fx.Node) -> bool: - if node.target not in self.permutable_ops: - return False - if node.target in ( - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten.sum.dim_IntList, - ): - # keepdim should be True. - if len(node.args) >= 3: - if not node.args[2]: - return False - elif "keepdim" in node.kwargs: - if not node.kwargs["keepdim"]: + if self._is_in_permutable_ops(node.target): + if node.target in ( + exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.sum.dim_IntList, + ): + if len(node.args) >= 3: + if not node.args[2]: + return False + elif "keepdim" in node.kwargs: + if not node.kwargs["keepdim"]: + return False + else: return False - else: - # Default keepdim is False. - return False - return True + return True + if self._is_squeeze_unsqueeze_view(node): + return True + return self._is_in_permutable_ops(node.target) def permute_subgraph(self, subgraph: Subgraph) -> None: + # Handle dimension related node arguments FIRST, before + # bypassing permutes (which changes node inputs/metadata). + for node in subgraph.nodes: + node_start_perm = subgraph.node_start_permute.get( + node, subgraph.start_permute + ) + if node.target == exir_ops.edge.aten.cat.default: + self.update_cat(node, node_start_perm) + elif node.target in ( + exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.sum.dim_IntList, + ): + self.update_mean_dim(node, node_start_perm) + elif node.target == exir_ops.edge.aten.slice_copy.Tensor: + self.update_slice_copy(node, node_start_perm) + elif node.target in self._VIEW_OPS: + self.update_view_copy(node, node_start_perm) + elif node.target in self._UNSQUEEZE_OPS or node.target in self._SQUEEZE_OPS: + dim = cast(int, node.args[1]) + rank = len(node_start_perm) + index = dim if dim >= 0 else dim + rank + node.update_arg(1, node_start_perm[index]) + # Skip incoming permutes. for inp, out in subgraph.edges_in: assert inp.target == exir_ops.edge.aten.permute_copy.default @@ -188,38 +434,30 @@ def permute_subgraph(self, subgraph: Subgraph) -> None: out.replace_input_with(inp, cast(torch.fx.Node, inp.kwargs["input"])) # Insert compensating permute on constant inputs. - # Since the subgraph's start permutes are being removed, the subgraph - # will operate in the un-permuted (original) layout. Constants that - # were in the permuted layout need end_permute (the inverse of - # start_permute) to convert back to the original layout. for const_node, user_node in subgraph.constant_edges_in: graph = const_node.graph const_rank = self._get_node_rank(const_node) - permute_rank = len(subgraph.end_permute) + # Use the node-specific end_permute for the correct rank + node_end_perm = subgraph.node_end_permute.get( + user_node, subgraph.end_permute + ) + permute_rank = len(node_end_perm) with graph.inserting_after(const_node): if const_rank is not None and const_rank == permute_rank: new_node = graph.create_node( "call_function", exir_ops.edge.aten.permute_copy.default, - args=(const_node, subgraph.end_permute), + args=(const_node, node_end_perm), ) elif ( const_rank is not None and const_rank < permute_rank and const_node.meta.get("val") is not None ): - # Rank mismatch (e.g. rank-1 bias with rank-4 permute). - # The constant is broadcastable and its shape is smaller - # than the permute rank, so we can't apply the permute - # directly. Instead, use view_copy to rearrange the - # shape according to the end_permute restricted to - # the trailing dimensions. original_shape = list(const_node.meta["val"].shape) - # Pad shape to match permute rank for reordering padded = [1] * (permute_rank - const_rank) + original_shape - target_shape = [padded[d] for d in subgraph.end_permute] - # Strip leading 1s back to original rank + target_shape = [padded[d] for d in node_end_perm] target_shape = target_shape[permute_rank - const_rank :] new_node = graph.create_node( "call_function", @@ -227,7 +465,6 @@ def permute_subgraph(self, subgraph: Subgraph) -> None: args=(const_node, target_shape), ) else: - # Cannot determine rank or handle this case; skip. continue user_node.replace_input_with(const_node, new_node) @@ -236,18 +473,6 @@ def permute_subgraph(self, subgraph: Subgraph) -> None: assert out.target == exir_ops.edge.aten.permute_copy.default out.replace_all_uses_with(inp) - # Handle dimension related node arguments. - for node in subgraph.nodes: - if node.target == exir_ops.edge.aten.cat.default: - self.update_cat(node, subgraph.start_permute) - elif node.target in ( - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten.sum.dim_IntList, - ): - self.update_mean_dim(node, subgraph.start_permute) - elif node.target == exir_ops.edge.aten.slice_copy.Tensor: - self.update_slice_copy(node, subgraph.start_permute) - def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None: if len(node.args) >= 2: node.update_arg(1, start_permute[cast(int, node.args[1])]) @@ -274,6 +499,46 @@ def update_slice_copy(self, node: torch.fx.Node, start_permute: list[int]) -> No else: node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])]) + def update_view_copy( + self, node: torch.fx.Node, start_permute: list[int] + ) -> None: + """Adjust view_copy shape arg after permute removal. + + After removing the start permute, the view's input is in the original + (un-permuted) layout. Recompute the view's target shape accordingly. + """ + if node.meta.get("val") is None: + return + inp = node.args[0] + if not isinstance(inp, torch.fx.Node) or inp.meta.get("val") is None: + return + + in_shape = list(inp.meta["val"].shape) + out_shape = list(node.meta["val"].shape) + + # Compute un-permuted input shape + inverse_permute = [start_permute.index(i) for i in range(len(start_permute))] + unpermuted_in = [in_shape[inverse_permute[i]] for i in range(len(in_shape))] + + if len(out_shape) == len(in_shape) + 1: + # unsqueeze: find the inserted dim in the permuted output, + # then determine where it goes in the un-permuted layout + index = self._find_extra_one(out_shape, in_shape) + if index != -1: + new_shape = list(unpermuted_in) + new_shape.insert(index, 1) + node.update_arg(1, new_shape) + elif len(in_shape) == len(out_shape) + 1: + # squeeze: find the removed dim in the permuted input, + # map it to the un-permuted position, and remove it + index = self._find_extra_one(in_shape, out_shape) + if index != -1: + # Map the squeezed dim from permuted to un-permuted space + unpermuted_index = start_permute[index] + new_shape = list(unpermuted_in) + del new_shape[unpermuted_index] + node.update_arg(1, new_shape) + def get_permutation(self, permute_node: torch.fx.Node) -> list[int] | None: assert permute_node.target == exir_ops.edge.aten.permute_copy.default raw_permute: list[int] diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl index 5c3343469ce..4122ef41d55 100644 --- a/backends/transforms/targets.bzl +++ b/backends/transforms/targets.bzl @@ -376,6 +376,7 @@ def define_common_targets(): ":fuse_cascaded_transpose_or_permute_ops", ":fuse_cascaded_view_ops", ":postpone_permute_below_squeeze_view", + ":remove_permutes_around_elementwise_ops", ":replace_nop_transpose_or_permute_with_view", ], ) diff --git a/backends/transforms/test/test_permute_optimization_passes.py b/backends/transforms/test/test_permute_optimization_passes.py index bb326f125bc..0db61186766 100644 --- a/backends/transforms/test/test_permute_optimization_passes.py +++ b/backends/transforms/test/test_permute_optimization_passes.py @@ -22,6 +22,9 @@ from executorch.backends.transforms.replace_nop_transpose_or_permute_with_view import ( ReplaceNopTransposeOrPermuteWithViewPass, ) +from executorch.backends.transforms.remove_permutes_around_elementwise_ops import ( + RemovePermutesAroundElementwiseOps, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import PassResult from torch.utils import _pytree as pytree @@ -380,6 +383,82 @@ def test_replace_nop_transpose_with_view_float(self) -> None: gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" ) + +# ────────────────────────────────────────────────────────────────────── +# Tests for RemovePermutesAroundElementwiseOps cross-view handling +# ────────────────────────────────────────────────────────────────────── + + +class RemovePermutesAcrossViewTest(unittest.TestCase): + def test_permute_view_squeeze_elementwise_view_unsqueeze_permute(self) -> None: + """permute(3D) → view(unsqueeze) → mul(4D) → view(squeeze) → permute(3D) + should have both permutes removed.""" + builder = GraphBuilder() + x_data = torch.randn(1, 128, 16) + x = builder.placeholder("x", x_data) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 1]) + ) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(p1, [1, 16, 1, 128]) + ) + mul = builder.call_operator( + op=exir_ops.edge.aten.mul.Tensor, args=(v1, v1) + ) + v2 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mul, [1, 16, 128]) + ) + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(v2, [0, 2, 1]) + ) + builder.output([p2]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + p = RemovePermutesAroundElementwiseOps() + result = cast(PassResult, p(original)) + self.assertTrue(result.modified) + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default), 0 + ) + validate_numerics( + gm_before, result.graph_module, [x_data], "RemovePermutesAcrossView", + ) + + def test_4d_permute_squeeze_clamp_3d_permute(self) -> None: + """Cascade detector conv→LN boundary: permute_4D([0,3,1,2]) → + view(squeeze) → hardtanh → permute_3D([0,2,1]). + The two permutes should cancel across the squeeze+clamp.""" + builder = GraphBuilder() + x_data = torch.randn(1, 1, 16, 128) + x = builder.placeholder("x", x_data) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 3, 1, 2]) + ) + v1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(p1, [1, 128, 16]) + ) + clamp = builder.call_operator( + op=exir_ops.edge.aten.hardtanh.default, args=(v1,) + ) + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(clamp, [0, 2, 1]) + ) + builder.output([p2]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + p = RemovePermutesAroundElementwiseOps() + result = cast(PassResult, p(original)) + self.assertTrue(result.modified) + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default), 0 + ) + validate_numerics( + gm_before, result.graph_module, [x_data], + "4D_permute_squeeze_clamp_3D_permute", + ) + def test_replace_nop_transpose_with_view_int(self) -> None: x = torch.randint(low=0, high=100, size=(2, 1, 5), dtype=torch.int64) gm = single_op_builder( From 7699a000de21368125218e53d6552a605ed68817 Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Tue, 12 May 2026 23:55:29 -0700 Subject: [PATCH 2/2] Handle rank-changing views in FuseCascadedTransposeOrPermuteOps (#19539) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Extend FuseCascadedTransposeOrPermuteOps to fuse permute→view_copy→permute patterns where a squeeze/unsqueeze view sits between two permutes. The pass computes the combined effect of permute+view+permute: if the permutation components cancel out (identity), the entire chain is replaced with a single view_copy. This handles patterns like permute_3D([0,2,1]) → view(unsqueeze) → permute_4D([0,2,3,1]) which composes to a simple view_copy (the permutations cancel, leaving only the reshape). Differential Revision: D104775245 --- .../fuse_cascaded_transpose_or_permute_ops.py | 123 ++++++++++++- .../test/test_permute_optimization_passes.py | 165 ++++++++++-------- 2 files changed, 210 insertions(+), 78 deletions(-) diff --git a/backends/transforms/fuse_cascaded_transpose_or_permute_ops.py b/backends/transforms/fuse_cascaded_transpose_or_permute_ops.py index b8d6c75a174..9d0798f4cc1 100644 --- a/backends/transforms/fuse_cascaded_transpose_or_permute_ops.py +++ b/backends/transforms/fuse_cascaded_transpose_or_permute_ops.py @@ -20,7 +20,8 @@ class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface): """ Fuse a chain of transpose and permute ops into a single permute or a no-op. - Handles branches and chains permutes. + Handles branches and chains of permutes, including permute-view-permute + patterns where a squeeze/unsqueeze view sits between two permutes. """ transpose_or_permute_target = { @@ -28,20 +29,31 @@ class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface): exir_ops.edge.aten.permute_copy.default, } + _VIEW_OPS = { + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.view.default, + } + @property def targets(self) -> list[EdgeOpOverload]: return list(self.transpose_or_permute_target) def maybe_remove_or_replace(self, node: Node) -> bool: - # Fuse with the parent node if it's also a permute or a transpose. Since the - # pass interface traverses all ops in order the pass will properly fuse a chain - # of permutes. parent_node = get_arg(node, "input", Node) - if parent_node.target not in self.transpose_or_permute_target: - return False - input_of_parent = get_arg(parent_node, "input", Node) - # Compute combined effect of permutes. + # Case 1: Direct permute/transpose → permute/transpose + if parent_node.target in self.transpose_or_permute_target: + return self._fuse_direct(node, parent_node) + + # Case 2: permute → view_copy(squeeze/unsqueeze) → permute + if parent_node.target in self._VIEW_OPS: + return self._fuse_across_view(node, parent_node) + + return False + + def _fuse_direct(self, node: Node, parent_node: Node) -> bool: + """Fuse two adjacent permute/transpose ops.""" + input_of_parent = get_arg(parent_node, "input", Node) dims = list(range(node.meta["val"].ndim)) if parent_node.target == exir_ops.edge.aten.transpose_copy.int: @@ -54,7 +66,6 @@ def maybe_remove_or_replace(self, node: Node) -> bool: else: dims = get_permuted_dims(node, dims) - # If combined effect is identity replace the node with input. if dims == sorted(dims): node.replace_all_uses_with(input_of_parent) else: @@ -67,3 +78,97 @@ def maybe_remove_or_replace(self, node: Node) -> bool: node.replace_all_uses_with(new_permute) return True + + def _fuse_across_view(self, node: Node, view_node: Node) -> bool: + """Fuse permute -> view(squeeze/unsqueeze) -> permute into a view_copy.""" + # view_node must have exactly one user (this permute node) + if len(list(view_node.users.keys())) != 1: + return False + # view_node's parent must be a permute/transpose + view_input = get_arg(view_node, "input", Node) + if view_input.target not in self.transpose_or_permute_target: + return False + # The view must be a squeeze or unsqueeze (rank differs by 1) + if view_node.meta.get("val") is None or view_input.meta.get("val") is None: + return False + view_in_shape = list(view_input.meta["val"].shape) + view_out_shape = list(view_node.meta["val"].shape) + if abs(len(view_in_shape) - len(view_out_shape)) != 1: + return False + + # Get the input before the first permute + input_of_first_permute = get_arg(view_input, "input", Node) + if input_of_first_permute.meta.get("val") is None: + return False + + # Compute the combined effect on the original input dimensions + # Start with identity dims for the original input + original_ndim = input_of_first_permute.meta["val"].ndim + dims = list(range(original_ndim)) + + # Apply first permute + if view_input.target == exir_ops.edge.aten.transpose_copy.int: + dims = get_transposed_dims(view_input, dims) + else: + dims = get_permuted_dims(view_input, dims) + + # Apply the view (squeeze/unsqueeze) + if len(view_out_shape) == len(view_in_shape) + 1: + # unsqueeze: insert a new dim + index = self._find_extra_one(view_out_shape, view_in_shape) + if index == -1: + return False + dims = [x + 1 if x >= index else x for x in dims] + dims.insert(index, -1) # -1 marks the inserted dim + elif len(view_in_shape) == len(view_out_shape) + 1: + # squeeze: remove a dim + index = self._find_extra_one(view_in_shape, view_out_shape) + if index == -1: + return False + if dims[index] != -1: + # Safe: permutation preserves dimension sizes, so a size-1 + # intermediate dim necessarily originated from a size-1 input dim. + pass + del dims[index] + + # Apply second permute (node) + if node.target == exir_ops.edge.aten.transpose_copy.int: + node_dims = list(range(len(dims))) + node_dims = get_transposed_dims(node, node_dims) + dims = [dims[d] for d in node_dims] + else: + perm = list(node.args[1]) + dims = [dims[d] for d in perm] + + # Check if the combined effect (ignoring -1 inserted dims) is identity + real_dims = [d for d in dims if d != -1] + output_shape = list(node.meta["val"].shape) + + if real_dims == sorted(real_dims): + # Combined permutations are identity — replace with view_copy + # (the only remaining effect is the squeeze/unsqueeze reshape) + if output_shape == list(input_of_first_permute.meta["val"].shape): + # Total no-op: replace with input + node.replace_all_uses_with(input_of_first_permute) + else: + with node.graph.inserting_before(node): + new_view = node.graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(input_of_first_permute, output_shape), + ) + new_view.meta = node.meta + node.replace_all_uses_with(new_view) + return True + + return False + + @staticmethod + def _find_extra_one(longer, shorter): + if len(longer) != len(shorter) + 1: + return -1 + for i in range(len(shorter)): + if longer[i] != shorter[i]: + if longer[i] == 1 and shorter[i:] == longer[i + 1:]: + return i + return -1 + return len(shorter) if longer[-1] == 1 else -1 diff --git a/backends/transforms/test/test_permute_optimization_passes.py b/backends/transforms/test/test_permute_optimization_passes.py index 0db61186766..d1b74dc3078 100644 --- a/backends/transforms/test/test_permute_optimization_passes.py +++ b/backends/transforms/test/test_permute_optimization_passes.py @@ -19,12 +19,12 @@ from executorch.backends.transforms.postpone_permute_below_squeeze_view import ( PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, ) -from executorch.backends.transforms.replace_nop_transpose_or_permute_with_view import ( - ReplaceNopTransposeOrPermuteWithViewPass, -) from executorch.backends.transforms.remove_permutes_around_elementwise_ops import ( RemovePermutesAroundElementwiseOps, ) +from executorch.backends.transforms.replace_nop_transpose_or_permute_with_view import ( + ReplaceNopTransposeOrPermuteWithViewPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import PassResult from torch.utils import _pytree as pytree @@ -122,15 +122,12 @@ def test_cascaded_permutes_multiple_users(self) -> None: permute1 = builder.call_operator( op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) ) - # permute2 reverses permute1 => identity permute2 = builder.call_operator( op=exir_ops.edge.aten.permute_copy.default, args=(permute1, [0, 3, 1, 2]) ) - # permute3: different permutation permute3 = builder.call_operator( op=exir_ops.edge.aten.permute_copy.default, args=(permute1, [0, 2, 1, 3]) ) - # permute4 -> permute5: chained permute4 = builder.call_operator( op=exir_ops.edge.aten.permute_copy.default, args=(permute1, [3, 2, 0, 1]) ) @@ -151,6 +148,38 @@ def test_cascaded_permutes_multiple_users(self) -> None: "FuseCascadedTransposeOrPermuteOps", ) + def test_permute_view_permute_fuse(self) -> None: + """permute_3D([0,2,1]) → view(unsqueeze) → permute_4D([0,2,3,1]) should + be replaced with a single view_copy (permutations cancel out).""" + builder = GraphBuilder() + x_data = torch.randn(1, 40, 18) + x = builder.placeholder("x", x_data) + p1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 1]) + ) + v = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(p1, [1, 18, 1, 40]) + ) + p2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(v, [0, 2, 3, 1]) + ) + builder.output([p2]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + p = FuseCascadedTransposeOrPermuteOps() + result = cast(PassResult, p(original)) + self.assertTrue(result.modified) + gm = result.graph_module + + self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0) + self.assertGreaterEqual( + count_node(gm, exir_ops.edge.aten.view_copy.default), 1 + ) + validate_numerics( + gm_before, gm, [x_data], "FuseCascadedAcrossView", + ) + # ────────────────────────────────────────────────────────────────────── # Tests for FuseCascadedViewOps @@ -250,7 +279,6 @@ def test_permute3_view4_chains(self) -> None: self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 2) self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 2) - # Verify order: views before permutes targets = get_compute_nodes(gm) view_indices = [ i @@ -350,7 +378,6 @@ def test_negative_not_squeeze_like(self) -> None: count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default), 2, ) - # Order unchanged: view, permute, view, permute targets = get_compute_nodes(result.graph_module) self.assertEqual(targets[0], exir_ops.edge.aten.view_copy.default) self.assertEqual(targets[1], exir_ops.edge.aten.permute_copy.default) @@ -383,6 +410,67 @@ def test_replace_nop_transpose_with_view_float(self) -> None: gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" ) + def test_replace_nop_transpose_with_view_int(self) -> None: + x = torch.randint(low=0, high=100, size=(2, 1, 5), dtype=torch.int64) + gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.transpose_copy.int, + args=(x, 1, 0), + ) + gm_before = copy.deepcopy(gm) + + p = ReplaceNopTransposeOrPermuteWithViewPass() + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + gm_after = result.graph_module + self.assertEqual(count_node(gm_after, exir_ops.edge.aten.transpose_copy.int), 0) + self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1) + validate_numerics( + gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" + ) + + def test_replace_nop_permute_5d(self) -> None: + x = torch.randn(3, 1, 3, 1, 4) + gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.permute_copy.default, + args=(x, [0, 2, 4, 1, 3]), + ) + gm_before = copy.deepcopy(gm) + + p = ReplaceNopTransposeOrPermuteWithViewPass() + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + gm_after = result.graph_module + self.assertEqual( + count_node(gm_after, exir_ops.edge.aten.permute_copy.default), 0 + ) + self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1) + validate_numerics( + gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" + ) + + def test_replace_nop_permute_3d(self) -> None: + x = torch.randn(1, 3, 4) + gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.permute_copy.default, + args=(x, [1, 2, 0]), + ) + gm_before = copy.deepcopy(gm) + + p = ReplaceNopTransposeOrPermuteWithViewPass() + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + gm_after = result.graph_module + self.assertEqual( + count_node(gm_after, exir_ops.edge.aten.permute_copy.default), 0 + ) + self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1) + validate_numerics( + gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" + ) + # ────────────────────────────────────────────────────────────────────── # Tests for RemovePermutesAroundElementwiseOps cross-view handling @@ -458,64 +546,3 @@ def test_4d_permute_squeeze_clamp_3d_permute(self) -> None: gm_before, result.graph_module, [x_data], "4D_permute_squeeze_clamp_3D_permute", ) - - def test_replace_nop_transpose_with_view_int(self) -> None: - x = torch.randint(low=0, high=100, size=(2, 1, 5), dtype=torch.int64) - gm = single_op_builder( - placeholders=(x,), - op=exir_ops.edge.aten.transpose_copy.int, - args=(x, 1, 0), - ) - gm_before = copy.deepcopy(gm) - - p = ReplaceNopTransposeOrPermuteWithViewPass() - result = cast(PassResult, p(gm)) - self.assertTrue(result.modified) - gm_after = result.graph_module - self.assertEqual(count_node(gm_after, exir_ops.edge.aten.transpose_copy.int), 0) - self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1) - validate_numerics( - gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" - ) - - def test_replace_nop_permute_5d(self) -> None: - x = torch.randn(3, 1, 3, 1, 4) - gm = single_op_builder( - placeholders=(x,), - op=exir_ops.edge.aten.permute_copy.default, - args=(x, [0, 2, 4, 1, 3]), - ) - gm_before = copy.deepcopy(gm) - - p = ReplaceNopTransposeOrPermuteWithViewPass() - result = cast(PassResult, p(gm)) - self.assertTrue(result.modified) - gm_after = result.graph_module - self.assertEqual( - count_node(gm_after, exir_ops.edge.aten.permute_copy.default), 0 - ) - self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1) - validate_numerics( - gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" - ) - - def test_replace_nop_permute_3d(self) -> None: - x = torch.randn(1, 3, 4) - gm = single_op_builder( - placeholders=(x,), - op=exir_ops.edge.aten.permute_copy.default, - args=(x, [1, 2, 0]), - ) - gm_before = copy.deepcopy(gm) - - p = ReplaceNopTransposeOrPermuteWithViewPass() - result = cast(PassResult, p(gm)) - self.assertTrue(result.modified) - gm_after = result.graph_module - self.assertEqual( - count_node(gm_after, exir_ops.edge.aten.permute_copy.default), 0 - ) - self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1) - validate_numerics( - gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" - )