Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,6 @@
from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import (
FuseCascadedTransposeOrPermuteOps,
)
from executorch.backends.transforms.postpone_permute_below_squeeze_view import (
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
)

from executorch.exir import ExportedProgram
from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_manager import PassManager
Expand Down Expand Up @@ -538,7 +534,6 @@ def _tosa_pipeline(
RewritePadPass(),
FuseViewCopyTransformPass(),
RemovePermutesAroundElementwiseTosaOps(),
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(),
FuseCascadedTransposeOrPermuteOps(),
ConvertPermuteSingletonToViewPass(),
RewriteHighRankSingletonPermutePass(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@


class RemovePermutesAroundElementwiseTosaOps(RemovePermutesAroundElementwiseOps):
permutable_ops = {
*RemovePermutesAroundElementwiseOps.permutable_ops,
*TableOps.unary_table_ops.keys(),
*TableOps.special_table_ops,
exir_ops.backend.tosa.RESCALE.default,
exir_ops.backend.tosa.TABLE.default,
}
def __init__(self) -> None:
super().__init__(
extra_permutable_ops={
*TableOps.unary_table_ops.keys(),
*TableOps.special_table_ops,
exir_ops.backend.tosa.RESCALE.default,
exir_ops.backend.tosa.TABLE.default,
}
)

def permute_subgraph(self, subgraph):
# Original function will always permute constant nodes which is wrong for table ops
Expand Down
20 changes: 10 additions & 10 deletions backends/cadence/aot/remove_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,16 +603,16 @@ def maybe_remove_or_replace(self, node: Node) -> bool:

@register_cadence_pass(CadencePassAttribute(opt_level=2))
class RemovePermutesAroundElementwiseOps(_SharedRemovePermutesAroundElementwiseOps):
permutable_ops: set[EdgeOpOverload] = (
_SharedRemovePermutesAroundElementwiseOps.permutable_ops
| {
exir_ops.edge.cadence.quantize_per_tensor.default,
exir_ops.edge.cadence.dequantize_per_tensor.default,
exir_ops.edge.cadence.quantized_relu.per_tensor,
exir_ops.edge.cadence.requantize.per_tensor,
exir_ops.edge.cadence.quantized_add.per_tensor,
}
)
def __init__(self) -> None:
super().__init__(
extra_permutable_ops={
exir_ops.edge.cadence.quantize_per_tensor.default,
exir_ops.edge.cadence.dequantize_per_tensor.default,
exir_ops.edge.cadence.quantized_relu.per_tensor,
exir_ops.edge.cadence.requantize.per_tensor,
exir_ops.edge.cadence.quantized_add.per_tensor,
}
)


@register_cadence_pass(CadencePassAttribute(opt_level=2))
Expand Down
123 changes: 114 additions & 9 deletions backends/transforms/fuse_cascaded_transpose_or_permute_ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -20,28 +20,40 @@
class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface):
"""
Fuse a chain of transpose and permute ops into a single permute or a no-op.
Handles branches and chains permutes.
Handles branches and chains of permutes, including permute-view-permute
patterns where a squeeze/unsqueeze view sits between two permutes.
"""

transpose_or_permute_target = {
exir_ops.edge.aten.transpose_copy.int,
exir_ops.edge.aten.permute_copy.default,
}

_VIEW_OPS = {
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.aten.view.default,
}

@property
def targets(self) -> list[EdgeOpOverload]:
    """Return the op overloads this pass visits, as a fresh list.

    The list is built from the class-level ``transpose_or_permute_target``
    set, so callers get a new list object on every access.
    """
    return [*self.transpose_or_permute_target]

def maybe_remove_or_replace(self, node: Node) -> bool:
# Fuse with the parent node if it's also a permute or a transpose. Since the
# pass interface traverses all ops in order the pass will properly fuse a chain
# of permutes.
parent_node = get_arg(node, "input", Node)
if parent_node.target not in self.transpose_or_permute_target:
return False
input_of_parent = get_arg(parent_node, "input", Node)

# Compute combined effect of permutes.
# Case 1: Direct permute/transpose → permute/transpose
if parent_node.target in self.transpose_or_permute_target:
return self._fuse_direct(node, parent_node)

# Case 2: permute → view_copy(squeeze/unsqueeze) → permute
if parent_node.target in self._VIEW_OPS:
return self._fuse_across_view(node, parent_node)

return False

def _fuse_direct(self, node: Node, parent_node: Node) -> bool:
"""Fuse two adjacent permute/transpose ops."""
input_of_parent = get_arg(parent_node, "input", Node)
dims = list(range(node.meta["val"].ndim))

if parent_node.target == exir_ops.edge.aten.transpose_copy.int:
Expand All @@ -54,7 +66,6 @@
else:
dims = get_permuted_dims(node, dims)

# If combined effect is identity replace the node with input.
if dims == sorted(dims):
node.replace_all_uses_with(input_of_parent)
else:
Expand All @@ -67,3 +78,97 @@
node.replace_all_uses_with(new_permute)

return True

def _fuse_across_view(self, node: Node, view_node: Node) -> bool:
    """Fuse permute -> view(squeeze/unsqueeze) -> permute into a view_copy.

    ``node`` is the second permute/transpose and ``view_node`` is its input.
    The pattern is rewritten only when the two permutes cancel out once the
    size-1 dim added or removed by the view is ignored. In that case ``node``
    is replaced by the original input (total no-op) or by a single
    ``view_copy`` of the original input to ``node``'s output shape.

    Args:
        node: The trailing permute/transpose node being visited.
        view_node: The view node feeding ``node``.

    Returns:
        True if the uses of ``node`` were rewired, False if the pattern did
        not match and the graph was left untouched.
    """
    # The view must feed only this permute.
    if len(view_node.users) != 1:
        return False

    # The view's producer must itself be a permute/transpose.
    first_permute = get_arg(view_node, "input", Node)
    if first_permute.target not in self.transpose_or_permute_target:
        return False

    # Shape metadata is needed to classify the view as squeeze/unsqueeze.
    if view_node.meta.get("val") is None or first_permute.meta.get("val") is None:
        return False
    view_in_shape = list(first_permute.meta["val"].shape)
    view_out_shape = list(view_node.meta["val"].shape)
    if abs(len(view_in_shape) - len(view_out_shape)) != 1:
        return False

    # Tensor feeding the first permute; the rewrite reads from it directly.
    original_input = get_arg(first_permute, "input", Node)
    if original_input.meta.get("val") is None:
        return False

    # Track where each original dim ends up, starting from identity.
    dims = list(range(original_input.meta["val"].ndim))
    if first_permute.target == exir_ops.edge.aten.transpose_copy.int:
        dims = get_transposed_dims(first_permute, dims)
    else:
        dims = get_permuted_dims(first_permute, dims)

    dims = self._track_dims_through_view(dims, view_in_shape, view_out_shape)
    if dims is None:
        # Rank changed by one, but not by inserting/removing a size-1 dim.
        return False

    dims = self._track_dims_through_permute(node, dims)

    # Ignore the -1 placeholder of an unsqueezed dim: only the relative order
    # of the original dims decides whether the permutes cancel.
    real_dims = [d for d in dims if d != -1]
    if real_dims != sorted(real_dims):
        return False

    output_shape = list(node.meta["val"].shape)
    if output_shape == list(original_input.meta["val"].shape):
        # Total no-op: forward the original input.
        node.replace_all_uses_with(original_input)
    else:
        # Only the squeeze/unsqueeze reshape remains.
        with node.graph.inserting_before(node):
            new_view = node.graph.call_function(
                exir_ops.edge.aten.view_copy.default,
                args=(original_input, output_shape),
            )
        new_view.meta = node.meta
        node.replace_all_uses_with(new_view)
    return True

def _track_dims_through_view(
    self, dims: list[int], view_in_shape: list, view_out_shape: list
):
    """Apply a squeeze/unsqueeze view to the tracked dim order.

    Args:
        dims: Tracked original-dim labels, one per dim of the view input.
        view_in_shape: Shape entering the view.
        view_out_shape: Shape leaving the view; the caller guarantees its
            rank differs from ``view_in_shape`` by exactly one.

    Returns:
        The updated dim list (an unsqueezed dim is marked with ``-1``), or
        ``None`` when the view is not a pure size-1 insertion/removal.
    """
    if len(view_out_shape) == len(view_in_shape) + 1:
        # Unsqueeze: a size-1 dim is inserted at `index`.
        index = self._find_extra_one(view_out_shape, view_in_shape)
        if index == -1:
            return None
        # Strictly increasing relabeling, so it cannot change the ordering
        # check done by the caller; kept to leave room for the new dim.
        dims = [x + 1 if x >= index else x for x in dims]
        dims.insert(index, -1)  # -1 marks the inserted dim
        return dims

    # Squeeze: a size-1 dim is removed at `index`.
    index = self._find_extra_one(view_in_shape, view_out_shape)
    if index == -1:
        return None
    # Safe to drop: permutation preserves dimension sizes, so a size-1
    # intermediate dim necessarily originated from a size-1 input dim.
    return dims[:index] + dims[index + 1:]

def _track_dims_through_permute(self, node: Node, dims: list[int]) -> list[int]:
    """Reorder the tracked dims by the permute/transpose ``node``.

    Args:
        node: The trailing permute/transpose node.
        dims: Tracked dim labels at the input of ``node``.

    Returns:
        The dim labels in the order they appear at the output of ``node``.
    """
    if node.target == exir_ops.edge.aten.transpose_copy.int:
        # Presumably returns the positions after the transpose's swap;
        # confirm against the get_transposed_dims helper.
        order = get_transposed_dims(node, list(range(len(dims))))
    else:
        order = list(node.args[1])
    return [dims[d] for d in order]

@staticmethod
def _find_extra_one(longer, shorter) -> int:
    """Locate the single size-1 dim that ``longer`` has and ``shorter`` lacks.

    Args:
        longer: Shape expected to hold exactly one more dim than ``shorter``.
        shorter: Shape expected to equal ``longer`` with one size-1 dim removed.

    Returns:
        The index in ``longer`` of the extra size-1 dim, or ``-1`` if the
        ranks do not differ by exactly one or the shapes differ by anything
        other than one size-1 dim. When several positions would qualify
        (adjacent size-1 dims), the first position where the shapes diverge
        is returned.
    """
    # Normalize to lists so the slice comparison below does not fail just
    # because one shape is a tuple and the other a list.
    longer = list(longer)
    shorter = list(shorter)
    if len(longer) != len(shorter) + 1:
        return -1
    for i, (long_dim, short_dim) in enumerate(zip(longer, shorter)):
        if long_dim != short_dim:
            # First divergence: it must be the extra 1, and everything after
            # it must line up with the rest of `shorter`.
            if long_dim == 1 and shorter[i:] == longer[i + 1:]:
                return i
            return -1
    # Shapes agree on the common prefix, so the extra dim is the trailing one.
    return len(shorter) if longer[-1] == 1 else -1
Loading
Loading