diff --git a/backends/arm/scripts/pre-push b/backends/arm/scripts/pre-push index fa8e0cd2ccd..e1a950237bd 100755 --- a/backends/arm/scripts/pre-push +++ b/backends/arm/scripts/pre-push @@ -302,18 +302,6 @@ for COMMIT in ${COMMITS}; do fi fi - # Test name checks - test_files=$(echo $commit_files | grep -oE 'backends/arm/test/\S+') - if [ "$test_files" ]; then - - # Check that the test name follows the specified convention - python ./backends/arm/scripts/testname_rules/validate_test_names.py $test_files - if [ $? -ne 0 ]; then - echo -e "${ERROR} Failed op test name check." >&2 - FAILED=1 - fi - fi - echo "" # Newline to visually separate commit processing done diff --git a/backends/arm/scripts/testname_rules/collect_testname_resources.py b/backends/arm/scripts/testname_rules/collect_testname_resources.py deleted file mode 100644 index cd37036f3cb..00000000000 --- a/backends/arm/scripts/testname_rules/collect_testname_resources.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import ast -import pathlib -import re - -from executorch.exir.dialects.edge.spec.utils import SAMPLE_INPUT - -# Add all targets and TOSA profiles we support here. -TARGETS = [ - "tosa_FP", - "tosa_INT", - "tosa_INT+FP", - "u55_INT", - "u85_INT", - "vgf_INT", - "vgf_FP", - "vgf_quant", - "vgf_no_quant", - "no_target", -] - -# Add edge ops which we lower but which are not included in exir/dialects/edge/edge.yaml here. -_CUSTOM_EDGE_OPS = [ - "linspace.default", - "cond.default", - "eye.default", - "expm1.default", - "erfinv.default", - "gather.default", - "vector_norm.default", - "hardsigmoid.default", - "hardswish.default", - "linear.default", - "maximum.default", - "mean.default", - "log1p.default", - "multihead_attention.default", - "adaptive_avg_pool2d.default", - "bitwise_right_shift.Tensor", - "bitwise_right_shift.Scalar", - "bitwise_left_shift.Tensor", - "bitwise_left_shift.Scalar", - "native_group_norm.default", - "silu.default", - "sdpa.default", - "sum.default", - "unbind.int", - "unflatten.int", - "unfold_copy.default", - "_native_batch_norm_legit_no_training.default", - "_native_batch_norm_legit.no_stats", - "alias_copy.default", - "pixel_shuffle.default", - "pixel_unshuffle.default", - "while_loop.default", - "matmul.default", - "upsample_bilinear2d.vec", - "upsample_nearest2d.vec", - "index_put.default", - "conv_transpose2d.default", - "index_copy.default", -] -_ALL_EDGE_OPS = SAMPLE_INPUT.keys() | _CUSTOM_EDGE_OPS - -_NON_ARM_PASSES = ["quantize_io_pass"] - -_MODEL_ENTRY_PATTERN = re.compile(r"^\s*(?:[-*]|\d+\.)\s+(?P.+?)\s*$") -_NUMERIC_SERIES_PATTERN = re.compile(r"(\d+)(?=[a-z])") -_CAMEL_BOUNDARY = re.compile( - r"(? set[str]: - names: set[str] = set() - names.update(_extract_pass_names_from_init(init_path)) - names.update(_NON_ARM_PASSES) - return {_separate_numeric_series(_strip_pass_suffix(name)) for name in names} - - -def _extract_pass_names_from_init(init_path: pathlib.Path) -> set[str]: - source = init_path.read_text(encoding="utf-8") - module = ast.parse(source, filename=str(init_path)) - names: set[str] = set() - - for node in module.body: - if not isinstance(node, ast.ImportFrom): - continue - for alias in node.names: - candidate = alias.asname or alias.name - if not candidate or not candidate.endswith("Pass"): - continue - if candidate == "ArmPass": - continue - names.add(_camel_to_snake(candidate)) - return names - - -def _strip_pass_suffix(name: str) -> str: - return name[:-5] if name.endswith("_pass") else name - - -def _separate_numeric_series(name: str) -> str: - def repl(match: re.Match[str]) -> str: - next_index = match.end() - next_char = match.string[next_index] if next_index < len(match.string) else "" - if next_char == "d": # Avoid creating patterns like 3_d - return match.group(1) - return f"{match.group(1)}_" - - return _NUMERIC_SERIES_PATTERN.sub(repl, name) - - -def _collect_arm_models(models_md: pathlib.Path) -> set[str]: - models: set[str] = set() - for line in models_md.read_text(encoding="utf-8").splitlines(): - stripped = line.strip() - if not stripped or stripped.startswith("#"): - continue - match = _MODEL_ENTRY_PATTERN.match(line) - if not match: - continue - base, alias, is_parent = _split_model_entry(match.group("entry")) - if is_parent: - continue - if alias: - models.add(_normalize_model_entry(alias)) - else: - models.add(_normalize_model_entry(base)) - - if not models: - raise RuntimeError(f"No supported models found in {models_md}") - return models - - -def _normalize_op_name(edge_name: str) -> str: - op, overload = edge_name.split(".") - - # There are ops where we want to keep "copy" in the name - # Add them in this list as we encounter them - ignore_copy_list = {"index_copy"} - - op = op.lower() - op = op.removeprefix("_") - - if op not in ignore_copy_list: - op = op.removesuffix("_copy") - - op = op.removesuffix("_with_indices") - - overload = overload.lower() - if overload == "default": - return op - else: - return f"{op}_{overload}" - - -def _split_model_entry(entry: str) -> tuple[str, str | None, bool]: - entry = entry.strip() - if not entry: - return "", None, False - is_parent = entry.endswith(":") - if is_parent: - entry = entry[:-1].rstrip() - if "(" in entry and entry.endswith(")"): - base, _, rest = entry.partition("(") - alias = rest[:-1].strip() - return base.strip(), alias or None, is_parent - return entry, None, is_parent - - -def _normalize_model_entry(name: str) -> str: - cleaned = name.lower() - cleaned = re.sub(r"[^a-z0-9\s]", "", cleaned) - cleaned = re.sub(r"\s+", " ", cleaned).strip() - return cleaned.replace(" ", "_") - - -def _camel_to_snake(name: str) -> str: - if not name: - return "" - name = name.replace("-", "_").replace(" ", "_") - return _CAMEL_BOUNDARY.sub("_", name).lower() - - -OP_NAME_MAP = {_normalize_op_name(edge_name): edge_name for edge_name in _ALL_EDGE_OPS} -OP_LIST = sorted({_normalize_op_name(edge_name) for edge_name in _ALL_EDGE_OPS}) -PASS_LIST = sorted( - _collect_arm_passes(pathlib.Path("backends/arm/_passes/__init__.py")) -) -MODEL_LIST = sorted(_collect_arm_models(pathlib.Path("backends/arm/MODELS.md"))) diff --git a/backends/arm/scripts/testname_rules/collect_tests.py b/backends/arm/scripts/testname_rules/collect_tests.py deleted file mode 100644 index d71d2723bbf..00000000000 --- a/backends/arm/scripts/testname_rules/collect_tests.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -from __future__ import annotations - -import ast -import logging -import pathlib - - -LOGGER = logging.getLogger(__name__) - -ALLOWED_DIRNAMES = ("misc", "passes", "models", "quantizer", "ops") - - -class TestCollector(ast.NodeVisitor): - def __init__(self, path: pathlib.Path, collected: list[str]): - self.path = path - self._collected = collected - self._class_stack: list[str] = [] - - def visit_ClassDef(self, node: ast.ClassDef): - self._class_stack.append(node.name) - self.generic_visit(node) - self._class_stack.pop() - - def visit_FunctionDef(self, node: ast.FunctionDef): - self._record_test(node) - self.generic_visit(node) - - def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): - self._record_test(node) - self.generic_visit(node) - - def _record_test(self, node: ast.AST): - name = getattr(node, "name", "") - if name.startswith("test_"): - self._collected.append(str(self.path) + "::" + name) - - -def collect_test_files(test_root: pathlib.Path): - search_dirs = [] - for dirname in ALLOWED_DIRNAMES: - dir_path = test_root / dirname - if dir_path.is_dir(): - search_dirs.append(dir_path) - else: - LOGGER.warning("skipped missing directory %s", dir_path) - - file_paths: list[pathlib.Path] = [] - for dir_path in search_dirs: - file_paths.extend(dir_path.rglob("test_*.py")) - return sorted(file_paths) - - -def collect_tests(file_paths: list[pathlib.Path]) -> list[str]: - tests: list[str] = [] - for file_path in file_paths: - try: - source = file_path.read_text(encoding="utf-8") - except OSError as error: - LOGGER.warning("failed to read %s: %s", file_path, error) - continue - - try: - tree = ast.parse(source, filename=str(file_path)) - except SyntaxError as error: - LOGGER.warning("failed to parse %s: %s", file_path, error) - continue - - TestCollector(file_path, tests).visit(tree) - - return tests diff --git a/backends/arm/scripts/testname_rules/parse_test_names.py b/backends/arm/scripts/testname_rules/parse_test_names.py deleted file mode 100644 index 4b9fa991768..00000000000 --- a/backends/arm/scripts/testname_rules/parse_test_names.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -import difflib -import logging - -from executorch.backends.arm.scripts.testname_rules.collect_testname_resources import ( - MODEL_LIST, - OP_LIST, - PASS_LIST, - TARGETS, -) - - -logger = logging.getLogger(__name__) - - -class TestNameViolation: - def __init__(self, test_name: str, message: str): - self.test_name = test_name - self.message = message - - def __str__(self) -> str: - msg_indented = "\n".join(" " + line for line in self.message.splitlines()) - return f"Invalid test name for {self.test_name}\n{msg_indented}\n" - - def __repr__(self) -> str: - return self.__str__() - - -def _match_allowed_op_prefix( - test_name: str, -) -> tuple[str | None, str | None, bool, bool]: - test_name = test_name.removeprefix("test_") - is_delegated = "not_delegated" not in test_name - is_16x8_quantized = False - - op = None - target = None - for potential_target in TARGETS: - index = test_name.find(potential_target) - if index != -1: - op = test_name[: index - 1] - target = potential_target - if ("16a8w" in test_name) or ("a16w8" in test_name): - is_16x8_quantized = True - break - - if op is not None: - # Special case for convolution - op = op.removesuffix("_1d") - op = op.removesuffix("_2d") - op = op.removesuffix("_3d") - - # Remove suffix for 16 bit activation and 8 bit weight test cases - op = op.removesuffix("_16a8w") - - return op, target, is_16x8_quantized, is_delegated - - -def _match_allowed_model_prefix(token: str, allowed_models: list[str]) -> str | None: - for allowed_model in allowed_models: - if token == allowed_model: - return allowed_model - return None - - -def _match_allowed_pass_prefix(token: str, allowed_passes: list[str]) -> str | None: - for allowed_pass in allowed_passes: - if token == allowed_pass: - return allowed_pass - return None - - -def _extract_target(name: str) -> str | None: - # The target is the last supported target token in the name, optionally - # followed by extra suffix data such as "_not_delegated". - for target in TARGETS: - marker = f"_{target}" - idx = name.rfind(marker) - if idx == -1: - continue - - suffix_idx = idx + len(marker) - if suffix_idx == len(name) or name[suffix_idx] == "_": - return target - return None - - -def _parse_test_name_tokens(name: str) -> tuple[str, str | None]: - rest = name[5:] - target = _extract_target(name) - token = rest - if target: - idx = rest.rfind(target) - token = rest[:idx].rstrip("_") - - return token, target - - -def _get_parsing_info(kind: str, name: str) -> str: - token, target = _parse_test_name_tokens(name) - - return f"{kind} token parsed as '{token}'\n" f"TARGET token parsed as '{target}'" - - -def parse_op_test(test_name: str) -> tuple[str, str, bool, bool] | TestNameViolation: - matched_op, target, quantized_16x8, delegated = _match_allowed_op_prefix(test_name) - - if "reject" in test_name: - return TestNameViolation( - test_name, - "Use 'not_delegated' instead of 'reject' in test names", - ) - - if not matched_op: - parsing_info = _get_parsing_info("OP", test_name) - return TestNameViolation( - test_name, - ( - f"Expected test_OP_TARGET_*\n" - f"OP token not found or invalid\n" - f"{parsing_info}" - ), - ) - - if target is None: - parsing_info = _get_parsing_info("OP", test_name) - return TestNameViolation( - test_name, - ( - f"Expected test_OP_TARGET_*\n" - f"TARGET is None (valid targets: {TARGETS}))\n" - f"{parsing_info}" - ), - ) - - if matched_op not in OP_LIST: - parsing_info = _get_parsing_info("OP", test_name) - closest_match = difflib.get_close_matches(matched_op, OP_LIST, n=1, cutoff=0.0)[ - 0 - ] - return TestNameViolation( - test_name, - ( - f"Expected test_OP_TARGET_*\n" - f"OP '{matched_op}' not recognized (closest match: {closest_match})\n" - f"{parsing_info}" - ), - ) - - result = (matched_op, target, quantized_16x8, delegated) - logger.debug('Parsed op test "%s": %s', test_name, result) - - return result - - -def parse_model_test(test_name: str) -> tuple[str, str] | TestNameViolation: - token, target = _parse_test_name_tokens(test_name) - - if not token: - parsing_info = _get_parsing_info("MODEL", test_name) - return TestNameViolation( - test_name, - ( - f"Expected test_MODEL_TARGET_*\n" - f"MODEL token not found or invalid\n" - f"{parsing_info}\n" - ), - ) - - if not target: - parsing_info = _get_parsing_info("MODEL", test_name) - return TestNameViolation( - test_name, - ( - f"Expected test_MODEL_TARGET_*\n" - f"TARGET token not found (valid targets: {TARGETS})\n" - f"{parsing_info}" - ), - ) - - matched_model = _match_allowed_model_prefix(token, MODEL_LIST) - if matched_model is None: - parsing_info = _get_parsing_info("MODEL", test_name) - closest_match = difflib.get_close_matches(token, MODEL_LIST, n=1, cutoff=0.0)[0] - return TestNameViolation( - test_name, - ( - f"Expected test_MODEL_TARGET_*\n" - f"MODEL {token} not recognized (closest match: {closest_match})\n" - f"{parsing_info}" - ), - ) - - result = (token, target) - logger.debug('Parsed model test "%s": %s', test_name, result) - - return result - - -def parse_pass_test(test_name: str) -> tuple[str, str] | TestNameViolation: - pass_, target = _parse_test_name_tokens(test_name) - - if not pass_: - parsing_info = _get_parsing_info("PASS", test_name) - return TestNameViolation( - test_name, - ( - f"Expected test_PASS_TARGET_*\n" - f"PASS token not found or invalid\n" - f"{parsing_info}\n" - ), - ) - - if not target: - parsing_info = _get_parsing_info("PASS", test_name) - return TestNameViolation( - test_name, - ( - f"Expected test_PASS_TARGET_*\n" - f"TARGET token not found (valid targets: {TARGETS})\n" - f"{parsing_info}" - ), - ) - - matched_pass = _match_allowed_pass_prefix(pass_, PASS_LIST) - if matched_pass is None: - parsing_info = _get_parsing_info("PASS", test_name) - closest_match = difflib.get_close_matches(pass_, PASS_LIST, n=1, cutoff=0.0)[0] - return TestNameViolation( - test_name, - ( - f"Expected test_PASS_TARGET_* with PASS in PASS_LIST\n" - f"PASS '{pass_} not recognized (closest match: {closest_match})'\n" - f"{parsing_info}" - ), - ) - - result = (pass_, target) - logger.debug('Parsed pass test "%s": %s', test_name, result) - return result - - -def parse_general_test(test_name: str) -> tuple[str, str] | TestNameViolation: - name, target = _parse_test_name_tokens(test_name) - - if not name: - parsing_info = _get_parsing_info("NAME", test_name) - return TestNameViolation( - test_name, - f"Expected test_*_TARGET_*\n" "Invalid NAME token\n" f"{parsing_info}", - ) - - if not target: - parsing_info = _get_parsing_info("NAME", test_name) - return TestNameViolation( - test_name, - ( - "Expected test_*_TARGET_*\n" - f"TARGET token not found (valid targets: {TARGETS})\n" - f"{parsing_info}" - ), - ) - - result = (name, target) - logger.debug('Parsed general test "%s": %s', test_name, result) - return result diff --git a/backends/arm/scripts/testname_rules/validate_test_names.py b/backends/arm/scripts/testname_rules/validate_test_names.py deleted file mode 100644 index 5a64b8b6cd7..00000000000 --- a/backends/arm/scripts/testname_rules/validate_test_names.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -import logging -import sys -from pathlib import Path - -from executorch.backends.arm.scripts.testname_rules.collect_tests import ( - collect_test_files, - collect_tests, -) - -from executorch.backends.arm.scripts.testname_rules.parse_test_names import ( - parse_general_test, - parse_model_test, - parse_op_test, - parse_pass_test, - TestNameViolation, -) - -LOGGER = logging.getLogger(__name__) - - -def _is_in_path(child: str, parent: str) -> bool: - """Returns True if 'child' path is inside 'parent' path.""" - child = Path(child).resolve() - parent = Path(parent).resolve() - return parent == child or parent in child.parents - - -def check_test_name_validations( - tests: list[str], -) -> list[TestNameViolation]: - violations: list[TestNameViolation] = [] - for test in tests: - path, test_name = test.split("::") - result: tuple[str, str, bool, bool] | tuple[str, str] | TestNameViolation | None - - if _is_in_path(path, "backends/arm/test/ops"): - result = parse_op_test(test_name) - elif _is_in_path(path, "backends/arm/test/models"): - result = parse_model_test(test_name) - elif _is_in_path(path, "backends/arm/test/passes"): - result = parse_pass_test(test_name) - elif _is_in_path(path, "backends/arm/test/quantizer") or _is_in_path( - path, "backends/arm/test/misc" - ): - result = parse_general_test(test_name) - else: - result = None - - if isinstance(result, TestNameViolation): - violations.append(result) - - return violations - - -def main() -> int: - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - tests: list[str] = [] - path_list: list[Path] = [Path(path) for path in sys.argv[1:]] - - if path_list == []: - tests = collect_tests(collect_test_files(Path("backends/arm/test"))) - else: - tests = collect_tests(path_list) - - violations = check_test_name_validations(tests) - - for entry in violations: - LOGGER.error("%s", entry) - - LOGGER.info("Total tests needing renaming: %d", len(violations)) - - if violations: - LOGGER.info( - "Please follow the test naming convention: https://confluence.arm.com/display/MLENG/Executorch+naming+conventions" - ) - return 1 - - return 0 - - -if __name__ == "__main__": - raise SystemExit(main())