diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 4d419830f67..92e82e6e7de 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -472,16 +472,16 @@ def register_fake( ) lib.define( - "quantized_softmax(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)" + "quantized_softmax(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)" ) lib.define( - "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)" + "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)" ) lib.define( - "quantized_softmax.out(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)" + "quantized_softmax.out(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)" ) lib.define( - "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)" + "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)" ) # pack float/bool mask tensor into a bitmask of type uint8 (each element holding 8 bool mask elements) @@ -2957,6 +2957,8 @@ def quantized_softmax_meta( input: torch.Tensor, mask: torch.Tensor, dim: int, + mask_type: int, + pos: torch.Tensor, in_scale: torch.Tensor, in_zero_point: torch.Tensor, out_scale: torch.Tensor, @@ -2970,6 +2972,8 @@ def quantized_softmax_per_tensor_meta( input: torch.Tensor, mask: torch.Tensor, dim: int, + mask_type: int, + pos: torch.Tensor, in_scale: float, in_zero_point: int, out_scale: float, diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index ae4d42f4898..f4cda91c8aa 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -378,6 +378,21 @@ def get_args_and_kwargs_softmax( with fake_mode: mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32) copy_node_metadata(mask_tensor, inputs_inputs[0]) + + # Default mask_type=0 (no masking) and dummy pos tensor + mask_type = 0 + pos_tensor = graph_module.graph.call_function( + torch.ops.aten.full.default, + ( + [1], + 0, + ), + {"dtype": torch.int64}, + ) + with fake_mode: + pos_tensor.meta["val"] = torch.full([1], 0, dtype=torch.int64) + copy_node_metadata(pos_tensor, inputs_inputs[0]) + # Make the scale and zero_point tensors in_scale = dequants_inputs[0].args[1] in_zero_point = dequants_inputs[0].args[2] @@ -389,6 +404,8 @@ def get_args_and_kwargs_softmax( inputs_inputs[0], mask_tensor, op_node.args[1], + mask_type, + pos_tensor, in_scale, in_zero_point, out_scale, diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 91a54906c14..8404fe25268 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -2480,6 +2480,8 @@ def quantized_softmax_per_tensor_common( input_tensor: torch.Tensor, mask: torch.Tensor | None, dim: int, + mask_type: int, + pos: torch.Tensor, in_scale: float, in_zero_point: int, out_scale: float, @@ -2492,6 +2494,8 @@ def quantized_softmax_per_tensor_common( - input_tensor (Tensor): The quantized input tensor - mask (Tensor): Mask tensor - dim (int): The dimension along which softmax is computed + - mask_type (int): Masking strategy (0=none, 1=position-based causal) + - pos (Tensor): Position tensor for causal masking - in_scale (float): The scale of the input quantization - in_zero_point (int): The zero point of the input quantization - out_scale (float): The scale of the output quantization @@ -2499,6 +2503,9 @@ def quantized_softmax_per_tensor_common( """ # TODO: T228751479 - Add support for mask parameter in softmax assert mask is None + assert ( + mask_type == 0 + ), f"Only mask_type=0 (no masking) is supported, got {mask_type}" supported_dtypes = [torch.int8, torch.uint8, torch.int16] if input_tensor.dtype not in supported_dtypes: raise ValueError( @@ -2531,6 +2538,8 @@ def quantized_softmax_per_tensor( input_tensor: torch.Tensor, mask: torch.Tensor | None, dim: int, + mask_type: int, + pos: torch.Tensor, in_scale: float, in_zero_point: int, out_scale: float, @@ -2540,6 +2549,8 @@ def quantized_softmax_per_tensor( input_tensor, mask, dim, + mask_type, + pos, in_scale, in_zero_point, out_scale, @@ -2552,6 +2563,8 @@ def quantized_softmax( input_tensor: torch.Tensor, mask: torch.Tensor | None, dim: int, + mask_type: int, + pos: torch.Tensor, in_scale: torch.Tensor, in_zero_point: torch.Tensor, out_scale: float, @@ -2561,6 +2574,8 @@ def quantized_softmax( input_tensor, mask, dim, + mask_type, + pos, float(in_scale.item()), int(in_zero_point.item()), out_scale, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 222fb27bfcd..63077d373a7 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -3152,6 +3152,8 @@ def test_quantized_softmax_per_tensor( input_tensor, mask, dim, + 0, # mask_type (no masking) + torch.zeros(1, dtype=torch.int64), # pos in_scale, in_zero_point, out_scale, @@ -3189,6 +3191,8 @@ def test_quantized_softmax(self) -> None: input_tensor, None, # mask 1, # dim + 0, # mask_type (no masking) + torch.zeros(1, dtype=torch.int64), # pos in_scale, in_zero_point, 0.004, # out_scale