diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 9a63452e325..dbaa7435e46 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -45,6 +45,40 @@ logger = logging.getLogger(__name__) +# For Torch 2.12 and later overloads. +_QDQ_TORCH_OVERLOADS = ( + ("quantize_per_tensor", ("tensor", "tensor2", "default")), + ("dequantize_per_tensor", ("tensor", "tensor2", "default")), + ("quantize_per_channel", ("default",)), + ("dequantize_per_channel", ("default",)), +) + +# For backward compatibility with Torch versions older than 2.12. +_QDQ_BACKWARD_COMPAT_OVERLOADS = ( + ("quantize_per_tensor", ("out",)), + ("dequantize_per_tensor", ("out",)), + ("quantize_per_channel", ("out",)), + ("dequantize_per_channel", ("out",)), +) + + +def _get_qdq_memory_format_ops() -> tuple[object, ...]: + qdq_ops = [] + backward_compat = dict(_QDQ_BACKWARD_COMPAT_OVERLOADS) + ns = torch.ops.quantized_decomposed + for op_name, overload_names in _QDQ_TORCH_OVERLOADS: + op_packet = getattr(ns, op_name, None) + if op_packet is None: + continue + for overload_name in overload_names + backward_compat[op_name]: + if hasattr(op_packet, overload_name): + qdq_ops.append(getattr(op_packet, overload_name)) + + return tuple(qdq_ops) + + +_QDQ_MEMORY_FORMAT_OPS = _get_qdq_memory_format_ops() + # Copied from PyTorch. # From torch/testing/_internal/common_utils.py:torch_to_numpy_dtype_dict # To avoid a dependency on _internal stuff. @@ -358,12 +392,7 @@ def __torch_function__(self, func, types, args=..., kwargs=None): # This is a hack since Q/DQ ops does not handle channels last input correctly: the simplest and most robust # workaround is to simply run them in channels first format and then convert back to channels last. - if func in ( - torch.ops.quantized_decomposed.quantize_per_tensor.out, - torch.ops.quantized_decomposed.dequantize_per_tensor.out, - torch.ops.quantized_decomposed.quantize_per_channel.out, - torch.ops.quantized_decomposed.dequantize_per_channel.out, - ): + if func in _QDQ_MEMORY_FORMAT_OPS: input_dim_order = args[0].dim_order() if input_dim_order in (NHWC_ORDER, NNHWC_ORDER):