From 4127576d2e2603dc04b897450c485d5e6df07d4e Mon Sep 17 00:00:00 2001 From: Michael Vartanian Date: Thu, 26 Mar 2026 07:49:52 -0700 Subject: [PATCH] Sync Python-side quantized_softmax schema with C++ kernel (add mask_type and pos args) (#18495) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/18495 D88997196 added mask_type (int) and pos (Tensor) parameters to the C++ cadence::quantized_softmax kernels and custom_ops.yaml, but missed updating the Python-side op registrations, reference implementations, quantizer fusion pass, and tests. This caused an argument count mismatch at runtime (Expected 11 args received 9) when running quantized softmax on the Xtensa ISS. This diff completes the schema sync by updating: - ops_registrations.py — Updated all 4 lib.define() schemas and both register_fake meta functions to include int mask_type, Tensor pos after dim. - ref_implementations.py — Added mask_type and pos params to quantized_softmax_per_tensor_common, quantized_softmax_per_tensor, and quantized_softmax. Added assert mask_type == 0 guard consistent with existing assert mask is None. - quantizer/fusion_pass.py — Updated get_args_and_kwargs_softmax to emit mask_type=0 (no masking) and a dummy pos tensor (full([1], 0, dtype=int64)), matching the default behavior for standard softmax quantization. - tests/test_ref_implementations.py — Updated test_quantized_softmax_per_tensor and test_quantized_softmax call sites with the new args. Reviewed By: hsharma35 Differential Revision: D98145095 --- backends/cadence/aot/ops_registrations.py | 12 ++++++++---- backends/cadence/aot/quantizer/fusion_pass.py | 17 +++++++++++++++++ backends/cadence/aot/ref_implementations.py | 15 +++++++++++++++ .../aot/tests/test_ref_implementations.py | 4 ++++ 4 files changed, 44 insertions(+), 4 deletions(-) diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 4d419830f67..92e82e6e7de 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -472,16 +472,16 @@ def register_fake( ) lib.define( - "quantized_softmax(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)" + "quantized_softmax(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)" ) lib.define( - "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)" + "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)" ) lib.define( - "quantized_softmax.out(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)" + "quantized_softmax.out(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)" ) lib.define( - "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)" + "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)" ) # pack float/bool mask tensor into a bitmask of type uint8 (each element holding 8 bool mask elements) @@ -2957,6 +2957,8 @@ def quantized_softmax_meta( input: torch.Tensor, mask: torch.Tensor, dim: int, + mask_type: int, + pos: torch.Tensor, in_scale: torch.Tensor, in_zero_point: torch.Tensor, out_scale: torch.Tensor, @@ -2970,6 +2972,8 @@ def quantized_softmax_per_tensor_meta( input: torch.Tensor, mask: torch.Tensor, dim: int, + mask_type: int, + pos: torch.Tensor, in_scale: float, in_zero_point: int, out_scale: float, diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index ae4d42f4898..f4cda91c8aa 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -378,6 +378,21 @@ def get_args_and_kwargs_softmax( with fake_mode: mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32) copy_node_metadata(mask_tensor, inputs_inputs[0]) + + # Default mask_type=0 (no masking) and dummy pos tensor + mask_type = 0 + pos_tensor = graph_module.graph.call_function( + torch.ops.aten.full.default, + ( + [1], + 0, + ), + {"dtype": torch.int64}, + ) + with fake_mode: + pos_tensor.meta["val"] = torch.full([1], 0, dtype=torch.int64) + copy_node_metadata(pos_tensor, inputs_inputs[0]) + # Make the scale and zero_point tensors in_scale = dequants_inputs[0].args[1] in_zero_point = dequants_inputs[0].args[2] @@ -389,6 +404,8 @@ def get_args_and_kwargs_softmax( inputs_inputs[0], mask_tensor, op_node.args[1], + mask_type, + pos_tensor, in_scale, in_zero_point, out_scale, diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 91a54906c14..8404fe25268 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -2480,6 +2480,8 @@ def quantized_softmax_per_tensor_common( input_tensor: torch.Tensor, mask: torch.Tensor | None, dim: int, + mask_type: int, + pos: torch.Tensor, in_scale: float, in_zero_point: int, out_scale: float, @@ -2492,6 +2494,8 @@ def quantized_softmax_per_tensor_common( - input_tensor (Tensor): The quantized input tensor - mask (Tensor): Mask tensor - dim (int): The dimension along which softmax is computed + - mask_type (int): Masking strategy (0=none, 1=position-based causal) + - pos (Tensor): Position tensor for causal masking - in_scale (float): The scale of the input quantization - in_zero_point (int): The zero point of the input quantization - out_scale (float): The scale of the output quantization @@ -2499,6 +2503,9 @@ def quantized_softmax_per_tensor_common( """ # TODO: T228751479 - Add support for mask parameter in softmax assert mask is None + assert ( + mask_type == 0 + ), f"Only mask_type=0 (no masking) is supported, got {mask_type}" supported_dtypes = [torch.int8, torch.uint8, torch.int16] if input_tensor.dtype not in supported_dtypes: raise ValueError( @@ -2531,6 +2538,8 @@ def quantized_softmax_per_tensor( input_tensor: torch.Tensor, mask: torch.Tensor | None, dim: int, + mask_type: int, + pos: torch.Tensor, in_scale: float, in_zero_point: int, out_scale: float, @@ -2540,6 +2549,8 @@ def quantized_softmax_per_tensor( input_tensor, mask, dim, + mask_type, + pos, in_scale, in_zero_point, out_scale, @@ -2552,6 +2563,8 @@ def quantized_softmax( input_tensor: torch.Tensor, mask: torch.Tensor | None, dim: int, + mask_type: int, + pos: torch.Tensor, in_scale: torch.Tensor, in_zero_point: torch.Tensor, out_scale: float, @@ -2561,6 +2574,8 @@ def quantized_softmax( input_tensor, mask, dim, + mask_type, + pos, float(in_scale.item()), int(in_zero_point.item()), out_scale, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 222fb27bfcd..63077d373a7 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -3152,6 +3152,8 @@ def test_quantized_softmax_per_tensor( input_tensor, mask, dim, + 0, # mask_type (no masking) + torch.zeros(1, dtype=torch.int64), # pos in_scale, in_zero_point, out_scale, @@ -3189,6 +3191,8 @@ def test_quantized_softmax(self) -> None: input_tensor, None, # mask 1, # dim + 0, # mask_type (no masking) + torch.zeros(1, dtype=torch.int64), # pos in_scale, in_zero_point, 0.004, # out_scale