diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 2a6839bd5c3..d9b213a3b59 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -318,14 +318,8 @@ def quantized_add_per_tensor( f"X and Y dtypes need to be in {supported_dtypes}. Got {dtype}" ) - if dtype == torch.uint8: - X = X.to(torch.int8) - Y = Y.to(torch.int8) - - # TODO(agrebenisan): This should be done in fixed point arithmetic, but to match the quantized_add_out.cpp - # reference implementation, we'll do it in floating point. - dequant_X = X_scale * (X - X_zero_point) - dequant_Y = Y_scale * (Y - Y_zero_point) + dequant_X = X_scale * (X.float() - X_zero_point) + dequant_Y = Y_scale * (Y.float() - Y_zero_point) # q_min/q_max are unused args return quantize_per_tensor( @@ -447,12 +441,8 @@ def quantized_mul_per_tensor( f"X and Y dtypes need to be in {supported_dtypes}. Got {dtype}" ) - if dtype == torch.uint8: - X = X.to(torch.int8) - Y = Y.to(torch.int8) - - dequant_X = X_scale * (X - X_zero_point) - dequant_Y = Y_scale * (Y - Y_zero_point) + dequant_X = X_scale * (X.float() - X_zero_point) + dequant_Y = Y_scale * (Y.float() - Y_zero_point) return quantize_per_tensor( dequant_X * dequant_Y,