
Commit 2cc1f22

mvartani-meta authored and facebook-github-bot committed
Sync Python-side quantized_softmax schema with C++ kernel (add mask_type and pos args) (pytorch#18495)
Summary:
D88997196 added mask_type (int) and pos (Tensor) parameters to the C++ cadence::quantized_softmax kernels and custom_ops.yaml, but missed updating the Python-side op registrations, reference implementations, quantizer fusion pass, and tests. This caused an argument count mismatch at runtime ("Expected 11 args received 9") when running quantized softmax on the Xtensa ISS.

This diff completes the schema sync by updating:

- ops_registrations.py — Updated all 4 lib.define() schemas and both register_fake meta functions to include int mask_type, Tensor pos after dim.
- ref_implementations.py — Added mask_type and pos params to quantized_softmax_per_tensor_common, quantized_softmax_per_tensor, and quantized_softmax. Added an assert mask_type == 0 guard, consistent with the existing assert mask is None.
- quantizer/fusion_pass.py — Updated get_args_and_kwargs_softmax to emit mask_type=0 (no masking) and a dummy pos tensor (full([1], 0, dtype=int64)), matching the default behavior for standard softmax quantization.
- tests/test_ref_implementations.py — Updated the test_quantized_softmax_per_tensor and test_quantized_softmax call sites with the new args.

Reviewed By: hsharma35

Differential Revision: D98145095
1 parent 0fc4d6d commit 2cc1f22

4 files changed

Lines changed: 44 additions & 4 deletions


backends/cadence/aot/ops_registrations.py

Lines changed: 8 additions & 4 deletions
@@ -472,16 +472,16 @@ def register_fake(
 )
 
 lib.define(
-    "quantized_softmax(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)"
+    "quantized_softmax(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)"
 )
 lib.define(
-    "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)"
+    "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)"
 )
 lib.define(
-    "quantized_softmax.out(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
+    "quantized_softmax.out(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
 )
 lib.define(
-    "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
+    "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
 )
 
 # pack float/bool mask tensor into a bitmask of type uint8 (each element holding 8 bool mask elements)

@@ -2957,6 +2957,8 @@ def quantized_softmax_meta(
     input: torch.Tensor,
     mask: torch.Tensor,
     dim: int,
+    mask_type: int,
+    pos: torch.Tensor,
     in_scale: torch.Tensor,
     in_zero_point: torch.Tensor,
     out_scale: torch.Tensor,

@@ -2970,6 +2972,8 @@ def quantized_softmax_per_tensor_meta(
     input: torch.Tensor,
     mask: torch.Tensor,
     dim: int,
+    mask_type: int,
+    pos: torch.Tensor,
     in_scale: float,
     in_zero_point: int,
     out_scale: float,
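The register_fake meta functions must accept the same arguments as the schemas they back, which is why both are updated alongside the lib.define() strings. Below is a minimal sketch, not the code in ops_registrations.py, of a per-tensor meta function carrying the new parameters; it assumes the fake output simply mirrors the input's shape and dtype, the usual pattern for a shape- and dtype-preserving op like softmax, and the function name is illustrative only.

import torch

def quantized_softmax_per_tensor_meta_sketch(
    input: torch.Tensor,
    mask: torch.Tensor,
    dim: int,
    mask_type: int,     # new arg: 0 = no masking
    pos: torch.Tensor,  # new arg: position tensor, only meaningful when masking is used
    in_scale: float,
    in_zero_point: int,
    out_scale: float,
    out_zero_point: int,
) -> torch.Tensor:
    # Softmax preserves shape and (quantized) dtype, so the fake tensor only
    # needs to carry over the input's metadata.
    return input.new_empty(input.shape, dtype=input.dtype)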

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 17 additions & 0 deletions
@@ -378,6 +378,21 @@ def get_args_and_kwargs_softmax(
     with fake_mode:
         mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32)
     copy_node_metadata(mask_tensor, inputs_inputs[0])
+
+    # Default mask_type=0 (no masking) and dummy pos tensor
+    mask_type = 0
+    pos_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            0,
+        ),
+        {"dtype": torch.int64},
+    )
+    with fake_mode:
+        pos_tensor.meta["val"] = torch.full([1], 0, dtype=torch.int64)
+    copy_node_metadata(pos_tensor, inputs_inputs[0])
+
     # Make the scale and zero_point tensors
     in_scale = dequants_inputs[0].args[1]
     in_zero_point = dequants_inputs[0].args[2]

@@ -389,6 +404,8 @@ def get_args_and_kwargs_softmax(
         inputs_inputs[0],
         mask_tensor,
         op_node.args[1],
+        mask_type,
+        pos_tensor,
         in_scale,
         in_zero_point,
         out_scale,
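For reference, here are the eager-mode equivalents of the two defaults the fusion pass now emits as graph nodes, along with the resulting positional order of the fused op's arguments. The variable names are illustrative and not the pass's internals.

import torch

mask_type = 0                                # 0 = no masking, the default for standard softmax
pos = torch.full([1], 0, dtype=torch.int64)  # dummy position tensor, ignored when mask_type=0

# Positional order passed to the fused quantized_softmax op:
# (input, mask, dim, mask_type, pos, in_scale, in_zero_point, out_scale, out_zero_point)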

backends/cadence/aot/ref_implementations.py

Lines changed: 15 additions & 0 deletions
@@ -2480,6 +2480,8 @@ def quantized_softmax_per_tensor_common(
     input_tensor: torch.Tensor,
     mask: torch.Tensor | None,
     dim: int,
+    mask_type: int,
+    pos: torch.Tensor,
     in_scale: float,
     in_zero_point: int,
     out_scale: float,

@@ -2492,13 +2494,18 @@ def quantized_softmax_per_tensor_common(
     - input_tensor (Tensor): The quantized input tensor
     - mask (Tensor): Mask tensor
     - dim (int): The dimension along which softmax is computed
+    - mask_type (int): Masking strategy (0=none, 1=position-based causal)
+    - pos (Tensor): Position tensor for causal masking
     - in_scale (float): The scale of the input quantization
     - in_zero_point (int): The zero point of the input quantization
     - out_scale (float): The scale of the output quantization
     - out_zero_point (int): The zero point of the output quantization
     """
     # TODO: T228751479 - Add support for mask parameter in softmax
     assert mask is None
+    assert (
+        mask_type == 0
+    ), f"Only mask_type=0 (no masking) is supported, got {mask_type}"
     supported_dtypes = [torch.int8, torch.uint8, torch.int16]
     if input_tensor.dtype not in supported_dtypes:
         raise ValueError(

@@ -2531,6 +2538,8 @@ def quantized_softmax_per_tensor(
     input_tensor: torch.Tensor,
     mask: torch.Tensor | None,
     dim: int,
+    mask_type: int,
+    pos: torch.Tensor,
     in_scale: float,
     in_zero_point: int,
     out_scale: float,

@@ -2540,6 +2549,8 @@
         input_tensor,
         mask,
         dim,
+        mask_type,
+        pos,
         in_scale,
         in_zero_point,
         out_scale,

@@ -2552,6 +2563,8 @@ def quantized_softmax(
     input_tensor: torch.Tensor,
     mask: torch.Tensor | None,
     dim: int,
+    mask_type: int,
+    pos: torch.Tensor,
     in_scale: torch.Tensor,
     in_zero_point: torch.Tensor,
     out_scale: float,

@@ -2561,6 +2574,8 @@
         input_tensor,
         mask,
         dim,
+        mask_type,
+        pos,
         float(in_scale.item()),
         int(in_zero_point.item()),
         out_scale,
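To make the role of the new parameters concrete, here is a self-contained sketch that mirrors the updated per-tensor signature and the added guard. The dequantize / softmax / requantize body is the conventional reading of a quantized softmax reference; it is not copied from ref_implementations.py, which may differ in its numerics and rounding.

import torch

def quantized_softmax_per_tensor_sketch(
    input_tensor: torch.Tensor,
    mask: torch.Tensor | None,
    dim: int,
    mask_type: int,
    pos: torch.Tensor,
    in_scale: float,
    in_zero_point: int,
    out_scale: float,
    out_zero_point: int,
) -> torch.Tensor:
    # Same guards this diff adds: masking is not supported yet.
    assert mask is None
    assert mask_type == 0, f"Only mask_type=0 (no masking) is supported, got {mask_type}"
    # Dequantize, apply float softmax, then requantize to the input's integer dtype.
    x = (input_tensor.to(torch.float32) - in_zero_point) * in_scale
    y = torch.softmax(x, dim=dim)
    q = torch.round(y / out_scale) + out_zero_point
    info = torch.iinfo(input_tensor.dtype)
    return q.clamp(info.min, info.max).to(input_tensor.dtype)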

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 4 additions & 0 deletions
@@ -3152,6 +3152,8 @@ def test_quantized_softmax_per_tensor(
             input_tensor,
             mask,
             dim,
+            0,  # mask_type (no masking)
+            torch.zeros(1, dtype=torch.int64),  # pos
             in_scale,
             in_zero_point,
             out_scale,

@@ -3189,6 +3191,8 @@ def test_quantized_softmax(self) -> None:
             input_tensor,
             None,  # mask
             1,  # dim
+            0,  # mask_type (no masking)
+            torch.zeros(1, dtype=torch.int64),  # pos
            in_scale,
            in_zero_point,
            0.004,  # out_scale
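The updated argument positions can be exercised with the sketch shown after the ref_implementations.py diff above; the values below are illustrative and not taken from the actual tests.

import torch

x = torch.tensor([[10, 20, 30]], dtype=torch.int8)
out = quantized_softmax_per_tensor_sketch(
    x,
    None,                               # mask
    1,                                  # dim
    0,                                  # mask_type (no masking)
    torch.zeros(1, dtype=torch.int64),  # pos
    0.1,                                # in_scale
    0,                                  # in_zero_point
    0.004,                              # out_scale
    -128,                               # out_zero_point
)
assert out.shape == x.shape and out.dtype == torch.int8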
