Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
from executorch.backends.arm.common.pipeline_config import (
ArmPassPipelineConfig,
FuseDuplicateUsersConfig,
SoftmaxDecompositionConfig,
)
from executorch.backends.arm.tosa.specification import (
Expand Down Expand Up @@ -238,9 +237,6 @@ def configure_skip_passes(
case SoftmaxDecompositionConfig.STABLE:
skip_set.add(DecomposeMaskedFillPass)

if config.fuse_duplicate_users is FuseDuplicateUsersConfig.DISABLED:
skip_set.add(FuseDuplicateUsersPass)

self._skip_pass_types = tuple(skip_set)
skip_names = [skipped_pass.__name__ for skipped_pass in self._skip_pass_types]
logger.debug(f"Passes in skip list: {skip_names}")
Expand Down
14 changes: 1 addition & 13 deletions backends/arm/common/pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,12 @@ class SoftmaxDecompositionConfig(Enum):
STABLE = auto() # Stable softmax, no masked fill decomposition


class FuseDuplicateUsersConfig(Enum):
ENABLED = auto()
DISABLED = auto()


@dataclass
class ArmPassPipelineConfig:
softmax: SoftmaxDecompositionConfig = SoftmaxDecompositionConfig.MASKED
fuse_duplicate_users: FuseDuplicateUsersConfig = FuseDuplicateUsersConfig.ENABLED

def disable_fuse_duplicate_users(self) -> None:
self.fuse_duplicate_users = FuseDuplicateUsersConfig.DISABLED

def is_default(self) -> bool:
return (
self.softmax is SoftmaxDecompositionConfig.MASKED
and self.fuse_duplicate_users is FuseDuplicateUsersConfig.ENABLED
)
return self.softmax is SoftmaxDecompositionConfig.MASKED

def to_dict(self) -> dict[str, str]:
return {f.name: getattr(self, f.name).name for f in fields(self)}
Expand Down
9 changes: 3 additions & 6 deletions backends/arm/test/misc/test_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@

import warnings

from executorch.backends.arm.common.pipeline_config import (
FuseDuplicateUsersConfig,
SoftmaxDecompositionConfig,
)
from executorch.backends.arm.common.pipeline_config import SoftmaxDecompositionConfig
from executorch.backends.arm.ethosu import EthosUCompileSpec
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
from executorch.backends.arm.vgf import VgfCompileSpec
Expand Down Expand Up @@ -66,11 +63,11 @@ def test_compile_spec_vgf_no_quant():
EthosUCompileSpec._from_list(spec_list)


def test_compile_spec_vgf_defaults_to_enabled_fuse_duplicate_users():
def test_compile_spec_vgf_uses_default_pipeline_config():
compile_spec = VgfCompileSpec()
pipeline_config = compile_spec._get_pass_pipeline_config()

assert pipeline_config.fuse_duplicate_users == FuseDuplicateUsersConfig.ENABLED
assert pipeline_config.is_default()


def test_compile_spec_tosa_INT():
Expand Down
7 changes: 3 additions & 4 deletions backends/arm/test/misc/test_pass_pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,13 @@ def test_pipeline_config_override_outside_compile_spec():
override_compile_spec = TosaCompileSpec(
TosaSpecification.create_from_string("TOSA-1.00+INT")
)
override_config = ArmPassPipelineConfig()
override_config.disable_fuse_duplicate_users()
override_config = ArmPassPipelineConfig(softmax=SoftmaxDecompositionConfig.STABLE)
override_compile_spec.set_pass_pipeline_config(override_config)
override_manager = ArmPassManager(override_compile_spec)
skip_passes = override_manager._skip_pass_types

assert FuseDuplicateUsersPass in skip_passes
assert DecomposeSoftmaxPass not in skip_passes
assert FuseDuplicateUsersPass not in skip_passes
assert DecomposeMaskedFillPass in skip_passes


def test_softmax_config_masked_no_target():
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/misc/test_tosa_dialect_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.


import executorch.backends.arm.tosa.dialect # noqa: unused
import executorch.backends.arm.tosa.dialect # noqa: F401
import pytest
import torch
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.


import executorch.backends.arm.tosa.dialect # noqa: unused
import executorch.backends.arm.tosa.dialect # noqa: F401
import pytest
import torch
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
Expand Down
3 changes: 0 additions & 3 deletions backends/arm/tosa/compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
# LICENSE file in the root directory of this source tree.

from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
from executorch.backends.arm.common.pipeline_config import ( # noqa: unused
ArmPassPipelineConfig,
)
from executorch.backends.arm.tosa import TosaSpecification


Expand Down
3 changes: 0 additions & 3 deletions backends/arm/vgf/compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
import logging

from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
from executorch.backends.arm.common.pipeline_config import ( # noqa: unused
ArmPassPipelineConfig,
)
from executorch.backends.arm.tosa import ( # type: ignore[import-not-found]
TosaSpecification,
)
Expand Down
Loading