Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,16 +472,16 @@ def register_fake(
)

lib.define(
"quantized_softmax(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)"
"quantized_softmax(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)"
)
lib.define(
"quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)"
"quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)"
)
lib.define(
"quantized_softmax.out(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
"quantized_softmax.out(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
)
lib.define(
"quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
"quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
)

# pack float/bool mask tensor into a bitmask of type uint8 (each element holding 8 bool mask elements)
Expand Down Expand Up @@ -2957,6 +2957,8 @@ def quantized_softmax_meta(
input: torch.Tensor,
mask: torch.Tensor,
dim: int,
mask_type: int,
pos: torch.Tensor,
in_scale: torch.Tensor,
in_zero_point: torch.Tensor,
out_scale: torch.Tensor,
Expand All @@ -2970,6 +2972,8 @@ def quantized_softmax_per_tensor_meta(
input: torch.Tensor,
mask: torch.Tensor,
dim: int,
mask_type: int,
pos: torch.Tensor,
in_scale: float,
in_zero_point: int,
out_scale: float,
Expand Down
17 changes: 17 additions & 0 deletions backends/cadence/aot/quantizer/fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,21 @@ def get_args_and_kwargs_softmax(
with fake_mode:
mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32)
copy_node_metadata(mask_tensor, inputs_inputs[0])

# Default mask_type=0 (no masking) and dummy pos tensor
mask_type = 0
pos_tensor = graph_module.graph.call_function(
torch.ops.aten.full.default,
(
[1],
0,
),
{"dtype": torch.int64},
)
with fake_mode:
pos_tensor.meta["val"] = torch.full([1], 0, dtype=torch.int64)
copy_node_metadata(pos_tensor, inputs_inputs[0])

# Make the scale and zero_point tensors
in_scale = dequants_inputs[0].args[1]
in_zero_point = dequants_inputs[0].args[2]
Expand All @@ -389,6 +404,8 @@ def get_args_and_kwargs_softmax(
inputs_inputs[0],
mask_tensor,
op_node.args[1],
mask_type,
pos_tensor,
in_scale,
in_zero_point,
out_scale,
Expand Down
15 changes: 15 additions & 0 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2480,6 +2480,8 @@ def quantized_softmax_per_tensor_common(
input_tensor: torch.Tensor,
mask: torch.Tensor | None,
dim: int,
mask_type: int,
pos: torch.Tensor,
in_scale: float,
in_zero_point: int,
out_scale: float,
Expand All @@ -2492,13 +2494,18 @@ def quantized_softmax_per_tensor_common(
- input_tensor (Tensor): The quantized input tensor
- mask (Tensor): Mask tensor
- dim (int): The dimension along which softmax is computed
- mask_type (int): Masking strategy (0=none, 1=position-based causal)
- pos (Tensor): Position tensor for causal masking
- in_scale (float): The scale of the input quantization
- in_zero_point (int): The zero point of the input quantization
- out_scale (float): The scale of the output quantization
- out_zero_point (int): The zero point of the output quantization
"""
# TODO: T228751479 - Add support for mask parameter in softmax
assert mask is None
assert (
mask_type == 0
), f"Only mask_type=0 (no masking) is supported, got {mask_type}"
supported_dtypes = [torch.int8, torch.uint8, torch.int16]
if input_tensor.dtype not in supported_dtypes:
raise ValueError(
Expand Down Expand Up @@ -2531,6 +2538,8 @@ def quantized_softmax_per_tensor(
input_tensor: torch.Tensor,
mask: torch.Tensor | None,
dim: int,
mask_type: int,
pos: torch.Tensor,
in_scale: float,
in_zero_point: int,
out_scale: float,
Expand All @@ -2540,6 +2549,8 @@ def quantized_softmax_per_tensor(
input_tensor,
mask,
dim,
mask_type,
pos,
in_scale,
in_zero_point,
out_scale,
Expand All @@ -2552,6 +2563,8 @@ def quantized_softmax(
input_tensor: torch.Tensor,
mask: torch.Tensor | None,
dim: int,
mask_type: int,
pos: torch.Tensor,
in_scale: torch.Tensor,
in_zero_point: torch.Tensor,
out_scale: float,
Expand All @@ -2561,6 +2574,8 @@ def quantized_softmax(
input_tensor,
mask,
dim,
mask_type,
pos,
float(in_scale.item()),
int(in_zero_point.item()),
out_scale,
Expand Down
4 changes: 4 additions & 0 deletions backends/cadence/aot/tests/test_ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3152,6 +3152,8 @@ def test_quantized_softmax_per_tensor(
input_tensor,
mask,
dim,
0, # mask_type (no masking)
torch.zeros(1, dtype=torch.int64), # pos
in_scale,
in_zero_point,
out_scale,
Expand Down Expand Up @@ -3189,6 +3191,8 @@ def test_quantized_softmax(self) -> None:
input_tensor,
None, # mask
1, # dim
0, # mask_type (no masking)
torch.zeros(1, dtype=torch.int64), # pos
in_scale,
in_zero_point,
0.004, # out_scale
Expand Down
Loading