From d72f45bad10324f30a6f7351059ba06b60ceac4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Tue, 24 Mar 2026 14:38:18 +0100 Subject: [PATCH] Arm backend: Remove test name convention and validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit It has been decided to remove test name convention because the benefit of it is not worth the maintenance cost. This patch removes test the validation script that was run by the pre-push hook. Signed-off-by: Martin Lindström Change-Id: I129db66db0d0c129313e41db4a719a77d74a8d96 --- backends/arm/scripts/pre-push | 12 - .../collect_testname_resources.py | 194 ------------- .../scripts/testname_rules/collect_tests.py | 74 ----- .../testname_rules/parse_test_names.py | 272 ------------------ .../testname_rules/validate_test_names.py | 90 ------ 5 files changed, 642 deletions(-) delete mode 100644 backends/arm/scripts/testname_rules/collect_testname_resources.py delete mode 100644 backends/arm/scripts/testname_rules/collect_tests.py delete mode 100644 backends/arm/scripts/testname_rules/parse_test_names.py delete mode 100644 backends/arm/scripts/testname_rules/validate_test_names.py diff --git a/backends/arm/scripts/pre-push b/backends/arm/scripts/pre-push index fa8e0cd2ccd..e1a950237bd 100755 --- a/backends/arm/scripts/pre-push +++ b/backends/arm/scripts/pre-push @@ -302,18 +302,6 @@ for COMMIT in ${COMMITS}; do fi fi - # Test name checks - test_files=$(echo $commit_files | grep -oE 'backends/arm/test/\S+') - if [ "$test_files" ]; then - - # Check that the test name follows the specified convention - python ./backends/arm/scripts/testname_rules/validate_test_names.py $test_files - if [ $? -ne 0 ]; then - echo -e "${ERROR} Failed op test name check." >&2 - FAILED=1 - fi - fi - echo "" # Newline to visually separate commit processing done diff --git a/backends/arm/scripts/testname_rules/collect_testname_resources.py b/backends/arm/scripts/testname_rules/collect_testname_resources.py deleted file mode 100644 index cd37036f3cb..00000000000 --- a/backends/arm/scripts/testname_rules/collect_testname_resources.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import ast -import pathlib -import re - -from executorch.exir.dialects.edge.spec.utils import SAMPLE_INPUT - -# Add all targets and TOSA profiles we support here. -TARGETS = [ - "tosa_FP", - "tosa_INT", - "tosa_INT+FP", - "u55_INT", - "u85_INT", - "vgf_INT", - "vgf_FP", - "vgf_quant", - "vgf_no_quant", - "no_target", -] - -# Add edge ops which we lower but which are not included in exir/dialects/edge/edge.yaml here. -_CUSTOM_EDGE_OPS = [ - "linspace.default", - "cond.default", - "eye.default", - "expm1.default", - "erfinv.default", - "gather.default", - "vector_norm.default", - "hardsigmoid.default", - "hardswish.default", - "linear.default", - "maximum.default", - "mean.default", - "log1p.default", - "multihead_attention.default", - "adaptive_avg_pool2d.default", - "bitwise_right_shift.Tensor", - "bitwise_right_shift.Scalar", - "bitwise_left_shift.Tensor", - "bitwise_left_shift.Scalar", - "native_group_norm.default", - "silu.default", - "sdpa.default", - "sum.default", - "unbind.int", - "unflatten.int", - "unfold_copy.default", - "_native_batch_norm_legit_no_training.default", - "_native_batch_norm_legit.no_stats", - "alias_copy.default", - "pixel_shuffle.default", - "pixel_unshuffle.default", - "while_loop.default", - "matmul.default", - "upsample_bilinear2d.vec", - "upsample_nearest2d.vec", - "index_put.default", - "conv_transpose2d.default", - "index_copy.default", -] -_ALL_EDGE_OPS = SAMPLE_INPUT.keys() | _CUSTOM_EDGE_OPS - -_NON_ARM_PASSES = ["quantize_io_pass"] - -_MODEL_ENTRY_PATTERN = re.compile(r"^\s*(?:[-*]|\d+\.)\s+(?P.+?)\s*$") -_NUMERIC_SERIES_PATTERN = re.compile(r"(\d+)(?=[a-z])") -_CAMEL_BOUNDARY = re.compile( - r"(? set[str]: - names: set[str] = set() - names.update(_extract_pass_names_from_init(init_path)) - names.update(_NON_ARM_PASSES) - return {_separate_numeric_series(_strip_pass_suffix(name)) for name in names} - - -def _extract_pass_names_from_init(init_path: pathlib.Path) -> set[str]: - source = init_path.read_text(encoding="utf-8") - module = ast.parse(source, filename=str(init_path)) - names: set[str] = set() - - for node in module.body: - if not isinstance(node, ast.ImportFrom): - continue - for alias in node.names: - candidate = alias.asname or alias.name - if not candidate or not candidate.endswith("Pass"): - continue - if candidate == "ArmPass": - continue - names.add(_camel_to_snake(candidate)) - return names - - -def _strip_pass_suffix(name: str) -> str: - return name[:-5] if name.endswith("_pass") else name - - -def _separate_numeric_series(name: str) -> str: - def repl(match: re.Match[str]) -> str: - next_index = match.end() - next_char = match.string[next_index] if next_index < len(match.string) else "" - if next_char == "d": # Avoid creating patterns like 3_d - return match.group(1) - return f"{match.group(1)}_" - - return _NUMERIC_SERIES_PATTERN.sub(repl, name) - - -def _collect_arm_models(models_md: pathlib.Path) -> set[str]: - models: set[str] = set() - for line in models_md.read_text(encoding="utf-8").splitlines(): - stripped = line.strip() - if not stripped or stripped.startswith("#"): - continue - match = _MODEL_ENTRY_PATTERN.match(line) - if not match: - continue - base, alias, is_parent = _split_model_entry(match.group("entry")) - if is_parent: - continue - if alias: - models.add(_normalize_model_entry(alias)) - else: - models.add(_normalize_model_entry(base)) - - if not models: - raise RuntimeError(f"No supported models found in {models_md}") - return models - - -def _normalize_op_name(edge_name: str) -> str: - op, overload = edge_name.split(".") - - # There are ops where we want to keep "copy" in the name - # Add them in this list as we encounter them - ignore_copy_list = {"index_copy"} - - op = op.lower() - op = op.removeprefix("_") - - if op not in ignore_copy_list: - op = op.removesuffix("_copy") - - op = op.removesuffix("_with_indices") - - overload = overload.lower() - if overload == "default": - return op - else: - return f"{op}_{overload}" - - -def _split_model_entry(entry: str) -> tuple[str, str | None, bool]: - entry = entry.strip() - if not entry: - return "", None, False - is_parent = entry.endswith(":") - if is_parent: - entry = entry[:-1].rstrip() - if "(" in entry and entry.endswith(")"): - base, _, rest = entry.partition("(") - alias = rest[:-1].strip() - return base.strip(), alias or None, is_parent - return entry, None, is_parent - - -def _normalize_model_entry(name: str) -> str: - cleaned = name.lower() - cleaned = re.sub(r"[^a-z0-9\s]", "", cleaned) - cleaned = re.sub(r"\s+", " ", cleaned).strip() - return cleaned.replace(" ", "_") - - -def _camel_to_snake(name: str) -> str: - if not name: - return "" - name = name.replace("-", "_").replace(" ", "_") - return _CAMEL_BOUNDARY.sub("_", name).lower() - - -OP_NAME_MAP = {_normalize_op_name(edge_name): edge_name for edge_name in _ALL_EDGE_OPS} -OP_LIST = sorted({_normalize_op_name(edge_name) for edge_name in _ALL_EDGE_OPS}) -PASS_LIST = sorted( - _collect_arm_passes(pathlib.Path("backends/arm/_passes/__init__.py")) -) -MODEL_LIST = sorted(_collect_arm_models(pathlib.Path("backends/arm/MODELS.md"))) diff --git a/backends/arm/scripts/testname_rules/collect_tests.py b/backends/arm/scripts/testname_rules/collect_tests.py deleted file mode 100644 index d71d2723bbf..00000000000 --- a/backends/arm/scripts/testname_rules/collect_tests.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -from __future__ import annotations - -import ast -import logging -import pathlib - - -LOGGER = logging.getLogger(__name__) - -ALLOWED_DIRNAMES = ("misc", "passes", "models", "quantizer", "ops") - - -class TestCollector(ast.NodeVisitor): - def __init__(self, path: pathlib.Path, collected: list[str]): - self.path = path - self._collected = collected - self._class_stack: list[str] = [] - - def visit_ClassDef(self, node: ast.ClassDef): - self._class_stack.append(node.name) - self.generic_visit(node) - self._class_stack.pop() - - def visit_FunctionDef(self, node: ast.FunctionDef): - self._record_test(node) - self.generic_visit(node) - - def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): - self._record_test(node) - self.generic_visit(node) - - def _record_test(self, node: ast.AST): - name = getattr(node, "name", "") - if name.startswith("test_"): - self._collected.append(str(self.path) + "::" + name) - - -def collect_test_files(test_root: pathlib.Path): - search_dirs = [] - for dirname in ALLOWED_DIRNAMES: - dir_path = test_root / dirname - if dir_path.is_dir(): - search_dirs.append(dir_path) - else: - LOGGER.warning("skipped missing directory %s", dir_path) - - file_paths: list[pathlib.Path] = [] - for dir_path in search_dirs: - file_paths.extend(dir_path.rglob("test_*.py")) - return sorted(file_paths) - - -def collect_tests(file_paths: list[pathlib.Path]) -> list[str]: - tests: list[str] = [] - for file_path in file_paths: - try: - source = file_path.read_text(encoding="utf-8") - except OSError as error: - LOGGER.warning("failed to read %s: %s", file_path, error) - continue - - try: - tree = ast.parse(source, filename=str(file_path)) - except SyntaxError as error: - LOGGER.warning("failed to parse %s: %s", file_path, error) - continue - - TestCollector(file_path, tests).visit(tree) - - return tests diff --git a/backends/arm/scripts/testname_rules/parse_test_names.py b/backends/arm/scripts/testname_rules/parse_test_names.py deleted file mode 100644 index 4b9fa991768..00000000000 --- a/backends/arm/scripts/testname_rules/parse_test_names.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -import difflib -import logging - -from executorch.backends.arm.scripts.testname_rules.collect_testname_resources import ( - MODEL_LIST, - OP_LIST, - PASS_LIST, - TARGETS, -) - - -logger = logging.getLogger(__name__) - - -class TestNameViolation: - def __init__(self, test_name: str, message: str): - self.test_name = test_name - self.message = message - - def __str__(self) -> str: - msg_indented = "\n".join(" " + line for line in self.message.splitlines()) - return f"Invalid test name for {self.test_name}\n{msg_indented}\n" - - def __repr__(self) -> str: - return self.__str__() - - -def _match_allowed_op_prefix( - test_name: str, -) -> tuple[str | None, str | None, bool, bool]: - test_name = test_name.removeprefix("test_") - is_delegated = "not_delegated" not in test_name - is_16x8_quantized = False - - op = None - target = None - for potential_target in TARGETS: - index = test_name.find(potential_target) - if index != -1: - op = test_name[: index - 1] - target = potential_target - if ("16a8w" in test_name) or ("a16w8" in test_name): - is_16x8_quantized = True - break - - if op is not None: - # Special case for convolution - op = op.removesuffix("_1d") - op = op.removesuffix("_2d") - op = op.removesuffix("_3d") - - # Remove suffix for 16 bit activation and 8 bit weight test cases - op = op.removesuffix("_16a8w") - - return op, target, is_16x8_quantized, is_delegated - - -def _match_allowed_model_prefix(token: str, allowed_models: list[str]) -> str | None: - for allowed_model in allowed_models: - if token == allowed_model: - return allowed_model - return None - - -def _match_allowed_pass_prefix(token: str, allowed_passes: list[str]) -> str | None: - for allowed_pass in allowed_passes: - if token == allowed_pass: - return allowed_pass - return None - - -def _extract_target(name: str) -> str | None: - # The target is the last supported target token in the name, optionally - # followed by extra suffix data such as "_not_delegated". - for target in TARGETS: - marker = f"_{target}" - idx = name.rfind(marker) - if idx == -1: - continue - - suffix_idx = idx + len(marker) - if suffix_idx == len(name) or name[suffix_idx] == "_": - return target - return None - - -def _parse_test_name_tokens(name: str) -> tuple[str, str | None]: - rest = name[5:] - target = _extract_target(name) - token = rest - if target: - idx = rest.rfind(target) - token = rest[:idx].rstrip("_") - - return token, target - - -def _get_parsing_info(kind: str, name: str) -> str: - token, target = _parse_test_name_tokens(name) - - return f"{kind} token parsed as '{token}'\n" f"TARGET token parsed as '{target}'" - - -def parse_op_test(test_name: str) -> tuple[str, str, bool, bool] | TestNameViolation: - matched_op, target, quantized_16x8, delegated = _match_allowed_op_prefix(test_name) - - if "reject" in test_name: - return TestNameViolation( - test_name, - "Use 'not_delegated' instead of 'reject' in test names", - ) - - if not matched_op: - parsing_info = _get_parsing_info("OP", test_name) - return TestNameViolation( - test_name, - ( - f"Expected test_OP_TARGET_*\n" - f"OP token not found or invalid\n" - f"{parsing_info}" - ), - ) - - if target is None: - parsing_info = _get_parsing_info("OP", test_name) - return TestNameViolation( - test_name, - ( - f"Expected test_OP_TARGET_*\n" - f"TARGET is None (valid targets: {TARGETS}))\n" - f"{parsing_info}" - ), - ) - - if matched_op not in OP_LIST: - parsing_info = _get_parsing_info("OP", test_name) - closest_match = difflib.get_close_matches(matched_op, OP_LIST, n=1, cutoff=0.0)[ - 0 - ] - return TestNameViolation( - test_name, - ( - f"Expected test_OP_TARGET_*\n" - f"OP '{matched_op}' not recognized (closest match: {closest_match})\n" - f"{parsing_info}" - ), - ) - - result = (matched_op, target, quantized_16x8, delegated) - logger.debug('Parsed op test "%s": %s', test_name, result) - - return result - - -def parse_model_test(test_name: str) -> tuple[str, str] | TestNameViolation: - token, target = _parse_test_name_tokens(test_name) - - if not token: - parsing_info = _get_parsing_info("MODEL", test_name) - return TestNameViolation( - test_name, - ( - f"Expected test_MODEL_TARGET_*\n" - f"MODEL token not found or invalid\n" - f"{parsing_info}\n" - ), - ) - - if not target: - parsing_info = _get_parsing_info("MODEL", test_name) - return TestNameViolation( - test_name, - ( - f"Expected test_MODEL_TARGET_*\n" - f"TARGET token not found (valid targets: {TARGETS})\n" - f"{parsing_info}" - ), - ) - - matched_model = _match_allowed_model_prefix(token, MODEL_LIST) - if matched_model is None: - parsing_info = _get_parsing_info("MODEL", test_name) - closest_match = difflib.get_close_matches(token, MODEL_LIST, n=1, cutoff=0.0)[0] - return TestNameViolation( - test_name, - ( - f"Expected test_MODEL_TARGET_*\n" - f"MODEL {token} not recognized (closest match: {closest_match})\n" - f"{parsing_info}" - ), - ) - - result = (token, target) - logger.debug('Parsed model test "%s": %s', test_name, result) - - return result - - -def parse_pass_test(test_name: str) -> tuple[str, str] | TestNameViolation: - pass_, target = _parse_test_name_tokens(test_name) - - if not pass_: - parsing_info = _get_parsing_info("PASS", test_name) - return TestNameViolation( - test_name, - ( - f"Expected test_PASS_TARGET_*\n" - f"PASS token not found or invalid\n" - f"{parsing_info}\n" - ), - ) - - if not target: - parsing_info = _get_parsing_info("PASS", test_name) - return TestNameViolation( - test_name, - ( - f"Expected test_PASS_TARGET_*\n" - f"TARGET token not found (valid targets: {TARGETS})\n" - f"{parsing_info}" - ), - ) - - matched_pass = _match_allowed_pass_prefix(pass_, PASS_LIST) - if matched_pass is None: - parsing_info = _get_parsing_info("PASS", test_name) - closest_match = difflib.get_close_matches(pass_, PASS_LIST, n=1, cutoff=0.0)[0] - return TestNameViolation( - test_name, - ( - f"Expected test_PASS_TARGET_* with PASS in PASS_LIST\n" - f"PASS '{pass_} not recognized (closest match: {closest_match})'\n" - f"{parsing_info}" - ), - ) - - result = (pass_, target) - logger.debug('Parsed pass test "%s": %s', test_name, result) - return result - - -def parse_general_test(test_name: str) -> tuple[str, str] | TestNameViolation: - name, target = _parse_test_name_tokens(test_name) - - if not name: - parsing_info = _get_parsing_info("NAME", test_name) - return TestNameViolation( - test_name, - f"Expected test_*_TARGET_*\n" "Invalid NAME token\n" f"{parsing_info}", - ) - - if not target: - parsing_info = _get_parsing_info("NAME", test_name) - return TestNameViolation( - test_name, - ( - "Expected test_*_TARGET_*\n" - f"TARGET token not found (valid targets: {TARGETS})\n" - f"{parsing_info}" - ), - ) - - result = (name, target) - logger.debug('Parsed general test "%s": %s', test_name, result) - return result diff --git a/backends/arm/scripts/testname_rules/validate_test_names.py b/backends/arm/scripts/testname_rules/validate_test_names.py deleted file mode 100644 index 5a64b8b6cd7..00000000000 --- a/backends/arm/scripts/testname_rules/validate_test_names.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -import logging -import sys -from pathlib import Path - -from executorch.backends.arm.scripts.testname_rules.collect_tests import ( - collect_test_files, - collect_tests, -) - -from executorch.backends.arm.scripts.testname_rules.parse_test_names import ( - parse_general_test, - parse_model_test, - parse_op_test, - parse_pass_test, - TestNameViolation, -) - -LOGGER = logging.getLogger(__name__) - - -def _is_in_path(child: str, parent: str) -> bool: - """Returns True if 'child' path is inside 'parent' path.""" - child = Path(child).resolve() - parent = Path(parent).resolve() - return parent == child or parent in child.parents - - -def check_test_name_validations( - tests: list[str], -) -> list[TestNameViolation]: - violations: list[TestNameViolation] = [] - for test in tests: - path, test_name = test.split("::") - result: tuple[str, str, bool, bool] | tuple[str, str] | TestNameViolation | None - - if _is_in_path(path, "backends/arm/test/ops"): - result = parse_op_test(test_name) - elif _is_in_path(path, "backends/arm/test/models"): - result = parse_model_test(test_name) - elif _is_in_path(path, "backends/arm/test/passes"): - result = parse_pass_test(test_name) - elif _is_in_path(path, "backends/arm/test/quantizer") or _is_in_path( - path, "backends/arm/test/misc" - ): - result = parse_general_test(test_name) - else: - result = None - - if isinstance(result, TestNameViolation): - violations.append(result) - - return violations - - -def main() -> int: - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - tests: list[str] = [] - path_list: list[Path] = [Path(path) for path in sys.argv[1:]] - - if path_list == []: - tests = collect_tests(collect_test_files(Path("backends/arm/test"))) - else: - tests = collect_tests(path_list) - - violations = check_test_name_validations(tests) - - for entry in violations: - LOGGER.error("%s", entry) - - LOGGER.info("Total tests needing renaming: %d", len(violations)) - - if violations: - LOGGER.info( - "Please follow the test naming convention: https://confluence.arm.com/display/MLENG/Executorch+naming+conventions" - ) - return 1 - - return 0 - - -if __name__ == "__main__": - raise SystemExit(main())