diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 145e953394cd..1b0f1b5e8779 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -19,9 +19,6 @@ # pylint: disable=no-value-for-parameter, unused-variable # pylint: disable=unexpected-keyword-arg, unused-import, too-many-function-args # ruff: noqa: RUF005 -# F821: _qnn and _expr references are in unreachable code paths (guarded by NotImplementedError) -# and will be resolved when quantization and vision op support are added. -# ruff: noqa: F821 """Tensorflow lite frontend.""" import functools @@ -547,9 +544,7 @@ def get_tensors(self, tensors_idx_list): qnn_params = dict() qnn_params["scale"] = relax.const(scale, "float32") qnn_params["zero_point"] = relax.const(zero_point, "int32") - raise NotImplementedError( - "Quantized TFLite models are not yet supported in the Relax frontend" - ) + qnn_params["axis"] = int(tflite_qnn_params.QuantizedDimension()) return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params)) return return_list @@ -654,20 +649,22 @@ def quantize(self, expr, tensor_to_quantize): """Helper function to quantize a tensor with Relax""" tensor_type = tensor_to_quantize.tensor.Type() tensor_type_str = self.get_tensor_type_str(tensor_type) - quantized = _qnn.op.quantize( + quantized = relax.op.quantize( data=expr, - output_scale=tensor_to_quantize.qnn_params["scale"], - output_zero_point=tensor_to_quantize.qnn_params["zero_point"], + scale=tensor_to_quantize.qnn_params["scale"], + zero_point=tensor_to_quantize.qnn_params["zero_point"], + axis=tensor_to_quantize.qnn_params["axis"], out_dtype=tensor_type_str, ) return quantized def dequantize(self, expr, tensor): """Helper function to dequantize a tensor with Relax""" - dequantized = _qnn.op.dequantize( + dequantized = relax.op.dequantize( data=expr, - input_scale=tensor.qnn_params["scale"], - input_zero_point=tensor.qnn_params["zero_point"], + scale=tensor.qnn_params["scale"], + zero_point=tensor.qnn_params["zero_point"], + axis=tensor.qnn_params["axis"], ) return dequantized @@ -778,20 +775,15 @@ def convert_reshape(self, op): "TFLite reshape requires input and output scale and zero points to be equal" ) - out = relax.op.reshape(in_expr, shape=relax.ShapeExpr(target_shape)) if input_tensor.qnn_params and input_tensor_type_str == "uint8": output_tensor = output_tensors[0] if not self.has_same_qnn_params(input_tensor, output_tensor): - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) - out = _qnn.op.requantize( - out, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) + in_f32 = self.dequantize(in_expr, input_tensor) + out = relax.op.reshape(in_f32, shape=relax.ShapeExpr(target_shape)) + out = self.quantize(out, output_tensor) + return out + out = relax.op.reshape(in_expr, shape=relax.ShapeExpr(target_shape)) return out def _convert_resize(self, method, op): @@ -1101,8 +1093,6 @@ def convert_shape(self, op): def convert_relu(self, op): """Convert TFLite ReLU""" - from tflite.ActivationFunctionType import ActivationFunctionType - input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = input_tensors[0] @@ -1113,32 +1103,12 @@ def 
convert_relu(self, op):
         output_tensor = output_tensors[0]

         if input_tensor.qnn_params:
-            # Quantize a float value to an quantized integer value
-            scale_val = get_scalar_from_constant(input_tensor.qnn_params["scale"])
-            zero_point_val = get_scalar_from_constant(input_tensor.qnn_params["zero_point"])
-
-            output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
-            out = self.convert_qnn_fused_activation_function(
-                expr=in_expr,
-                fused_activation_fn=ActivationFunctionType.RELU,
-                scale=scale_val,
-                zero_point=zero_point_val,
-                dtype=output_tensor_type_str,
-            )
+            in_f32 = self.dequantize(in_expr, input_tensor)
+            out = relax.op.nn.relu(in_f32)
+            out = self.quantize(out, output_tensor)
         else:
             out = relax.op.nn.relu(in_expr)

-        if output_tensor.qnn_params:
-            output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
-            out = _qnn.op.requantize(
-                out,
-                input_scale=input_tensor.qnn_params["scale"],
-                input_zero_point=input_tensor.qnn_params["zero_point"],
-                output_scale=output_tensor.qnn_params["scale"],
-                output_zero_point=output_tensor.qnn_params["zero_point"],
-                out_dtype=output_tensor_type_str,
-            )
-
         return out

     def convert_hard_swish(self, op):
@@ -1174,8 +1144,6 @@ def _hard_swish(data):

     def convert_relu6(self, op):
         """Convert TFLite ReLU6"""
-        from tflite.ActivationFunctionType import ActivationFunctionType
-
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 1, "input tensors length should be 1"
         input_tensor = input_tensors[0]
@@ -1186,32 +1154,12 @@ def convert_relu6(self, op):
         output_tensor = output_tensors[0]

         if input_tensor.qnn_params:
-            # Quantize a float value to an quantized integer value
-            scale_val = get_scalar_from_constant(input_tensor.qnn_params["scale"])
-            zero_point_val = get_scalar_from_constant(input_tensor.qnn_params["zero_point"])
-
-            output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
-            out = self.convert_qnn_fused_activation_function(
-                expr=in_expr,
-                fused_activation_fn=ActivationFunctionType.RELU6,
-                scale=scale_val,
-                zero_point=zero_point_val,
-                dtype=output_tensor_type_str,
-            )
+            in_f32 = self.dequantize(in_expr, input_tensor)
+            out = relax.op.clip(in_f32, min=0, max=6)
+            out = self.quantize(out, output_tensor)
         else:
             out = relax.op.clip(in_expr, min=0, max=6)

-        if output_tensor.qnn_params:
-            output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
-            out = _qnn.op.requantize(
-                out,
-                input_scale=input_tensor.qnn_params["scale"],
-                input_zero_point=input_tensor.qnn_params["zero_point"],
-                output_scale=output_tensor.qnn_params["scale"],
-                output_zero_point=output_tensor.qnn_params["zero_point"],
-                out_dtype=output_tensor_type_str,
-            )
-
         return out

     def convert_leaky_relu(self, op):
@@ -1255,36 +1203,12 @@ def convert_relu_n1_to_1(self, op):
         output_tensor = output_tensors[0]

         if input_tensor.qnn_params:
-            # Quantize a float value to an quantized integer value
-            scale_val = get_scalar_from_constant(input_tensor.qnn_params["scale"])
-            zero_point_val = get_scalar_from_constant(input_tensor.qnn_params["zero_point"])
-
-            def quantize(x):
-                return float(round(x / scale_val) + zero_point_val)
-
-            # Get min/max of the input dtype. This will be used to ensure that
-            # clip a_min/a_max are not beyond the dtype range.
-            input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
-            qmin = float(tvm.tir.min_value(input_tensor_type_str).value)
-            qmax = float(tvm.tir.max_value(input_tensor_type_str).value)
-
-            out = relax.op.clip(
-                in_expr, min=max(qmin, quantize(-1.0)), max=min(qmax, quantize(1.0))
-            )
+            in_f32 = self.dequantize(in_expr, input_tensor)
+            out = relax.op.clip(in_f32, min=-1, max=1)
+            out = self.quantize(out, output_tensor)
         else:
             out = relax.op.clip(in_expr, min=-1, max=1)

-        if output_tensor.qnn_params:
-            output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
-            out = _qnn.op.requantize(
-                out,
-                input_scale=input_tensor.qnn_params["scale"],
-                input_zero_point=input_tensor.qnn_params["zero_point"],
-                output_scale=output_tensor.qnn_params["scale"],
-                output_zero_point=output_tensor.qnn_params["zero_point"],
-                out_dtype=output_tensor_type_str,
-            )
-
         return out

     def convert_log_softmax(self, op):
@@ -1330,18 +1254,12 @@ def convert_concatenation(self, op):
         if not input_tensors[0].qnn_params:
             out = relax.op.concat(in_exprs, axis=concatenation_axis)
         else:
-            input_scales = [input_tensor.qnn_params["scale"] for input_tensor in input_tensors]
-            input_zero_points = [
-                input_tensor.qnn_params["zero_point"] for input_tensor in input_tensors
+            in_f32s = [
+                self.dequantize(expr, tensor)
+                for expr, tensor in zip(in_exprs, input_tensors)
             ]
-            out = _qnn.op.concat(
-                in_exprs,
-                input_scales=input_scales,
-                input_zero_points=input_zero_points,
-                output_scale=output_tensor.qnn_params["scale"],
-                output_zero_point=output_tensor.qnn_params["zero_point"],
-                axis=concatenation_axis,
-            )
+            out = relax.op.concat(in_f32s, axis=concatenation_axis)
+            out = self.quantize(out, output_tensor)

         # Handle fused activations
         if output_tensor.qnn_params:
@@ -2441,24 +2359,16 @@ def _convert_reduce(self, relax_op, op):
             keep_dims = False

         if input_tensor.qnn_params:
-            in_expr = relax.op.cast(in_expr, "int32")
+            in_expr = self.dequantize(in_expr, input_tensor)

         out = relax_op(in_expr, axis, keep_dims)

-        # Finally if the reduce is quantized. Add a requantize at the end.
+        # Finally, if the reduce is quantized, quantize the output.
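+        # Note: computing the reduce in float32 and re-quantizing can differ
+        # from TFLite's integer reference kernels by small rounding amounts.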
         output_tensors = self.get_output_tensors(op)
         assert len(output_tensors) == 1, "output tensors length should be 1"
         output_tensor = output_tensors[0]
-        output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())

         if output_tensor.qnn_params:
-            out = _qnn.op.requantize(
-                out,
-                input_scale=input_tensor.qnn_params["scale"],
-                input_zero_point=input_tensor.qnn_params["zero_point"],
-                output_scale=output_tensor.qnn_params["scale"],
-                output_zero_point=output_tensor.qnn_params["zero_point"],
-                out_dtype=output_tensor_type_str,
-            )
+            out = self.quantize(out, output_tensor)

         return out

@@ -2575,20 +2485,24 @@ def convert_fully_connected(self, op):
         )

         weight_expr = self.get_tensor_expr(weight_tensor)
-        weight_shape = weight_expr.struct_info.shape
         weight_expr = relax.op.permute_dims(weight_expr, [1, 0])

         if input_tensor.qnn_params:
-            out = _qnn.op.dense(
-                in_expr,
+            # Dequantize input and weight (OC moves from axis 0 to 1 after the transpose)
+            in_f32 = self.dequantize(in_expr, input_tensor)
+            weight_axis = weight_tensor.qnn_params["axis"]
+            if weight_axis != 0:
+                raise tvm.error.OpAttributeInvalid(
+                    f"FC weight QuantizedDimension() must be 0 (output-channel "
+                    f"axis in [OC,IC] layout), got {weight_axis}"
+                )
+            w_f32 = relax.op.dequantize(
                 weight_expr,
-                input_zero_point=input_tensor.qnn_params["zero_point"],
-                kernel_zero_point=weight_tensor.qnn_params["zero_point"],
-                input_scale=input_tensor.qnn_params["scale"],
-                kernel_scale=weight_tensor.qnn_params["scale"],
-                units=weight_shape[0],
-                out_dtype="int64" if output_tensor_type_str == "int16" else "int32",
+                scale=weight_tensor.qnn_params["scale"],
+                zero_point=weight_tensor.qnn_params["zero_point"],
+                axis=1,
             )
+            out = relax.op.matmul(in_f32, w_f32)
         else:
             out = relax.op.matmul(in_expr, weight_expr)

@@ -2612,27 +2526,27 @@ def convert_fully_connected(self, op):
                 dtype=bias_tensor_type_str,
                 source_name=bias_tensor.tensor.Name(),
             )
+            if bias_tensor.qnn_params:
+                bias_expr = self.dequantize(bias_expr, bias_tensor)
+            elif input_tensor.qnn_params and bias_tensor_type in (
+                TensorType.INT32,
+                TensorType.INT64,
+            ):
+                bias_scale = relax.op.multiply(
+                    input_tensor.qnn_params["scale"],
+                    weight_tensor.qnn_params["scale"],
+                )
+                bias_expr = relax.op.dequantize(
+                    bias_expr,
+                    scale=bias_scale,
+                    zero_point=relax.const(0, "int32"),
+                    axis=0,
+                )
             out = relax.op.add(out, bias_expr)

-        # Finally if the dense is quantized. Add a requantize at the end.
+        # Finally, if the dense is quantized, quantize the output.
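+        # As in convert_conv, the float matmul result is mapped back to the
+        # output tensor's scale/zero_point in a single quantize step.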
if output_tensor.qnn_params: - data_scale = input_tensor.qnn_params["scale"] - weight_scale = weight_tensor.qnn_params["scale"] - data_scale_val = get_scalar_from_constant(data_scale) - weight_scale_val = get_scalar_from_constant(weight_scale) - new_input_scale_val = data_scale_val * weight_scale_val - new_input_scale = relax.const(new_input_scale_val, "float32") - new_input_zero_point = relax.const(0, "int32") - - # Requantize - out = _qnn.op.requantize( - out, - input_scale=new_input_scale, - input_zero_point=new_input_zero_point, - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) + out = self.quantize(out, output_tensor) # Call activation function output_scale_val = get_scalar_from_constant(output_tensor.qnn_params["scale"]) @@ -2844,15 +2758,35 @@ def convert_conv(self, op, conv_type): ) if input_tensor.qnn_params: - qnn_conv2d_params = dict(params) - qnn_conv2d_params["input_zero_point"] = input_tensor.qnn_params["zero_point"] - qnn_conv2d_params["kernel_zero_point"] = weight_tensor.qnn_params["zero_point"] - qnn_conv2d_params["out_dtype"] = ( - "int64" if output_tensor_type_str == "int16" else "int32" + # Dequantize input activation + in_f32 = self.dequantize(in_expr, input_tensor) + # Dequantize weight with per-channel axis remap. + # TFLite weight original layout: [OC, KH, KW, IC] + # After transpose to HWIO: [KH, KW, IC, OC] + # QuantizedDimension() == 0 (OC in original) → axis 3 in HWIO. + weight_axis = weight_tensor.qnn_params["axis"] + if is_depthwise_conv: + if weight_axis != 0: + raise tvm.error.OpNotImplemented( + "Per-channel quantized depthwise convolution is not supported " + "because the channel axis changes semantics after the " + "[1,KH,KW,C*M] → [KH,KW,C,M] reshape." + ) + else: + if weight_axis != 0: + raise tvm.error.OpAttributeInvalid( + f"Conv2D weight QuantizedDimension() must be 0 (output-channel " + f"axis in [OC,KH,KW,IC] layout), got {weight_axis}" + ) + weight_axis = 3 + w_f32 = relax.op.dequantize( + weight_expr, + scale=weight_tensor.qnn_params["scale"], + zero_point=weight_tensor.qnn_params["zero_point"], + axis=weight_axis, ) - qnn_conv2d_params["input_scale"] = input_tensor.qnn_params["scale"] - qnn_conv2d_params["kernel_scale"] = weight_tensor.qnn_params["scale"] - out = _qnn.op.conv2d(in_expr, weight_expr, **qnn_conv2d_params) + # Float convolution + out = relax.op.nn.conv2d(in_f32, w_f32, **params) else: out = relax.op.nn.conv2d(in_expr, weight_expr, **params) @@ -2875,37 +2809,31 @@ def convert_conv(self, op, conv_type): dtype=bias_tensor_type_str, source_name=bias_tensor.tensor.Name(), ) + # For quantized conv, INT32/INT64 bias must be dequantized + # to float32 before adding to the float conv output. + if bias_tensor.qnn_params: + bias_expr = self.dequantize(bias_expr, bias_tensor) + elif input_tensor.qnn_params and bias_tensor_type in ( + TensorType.INT32, + TensorType.INT64, + ): + bias_expr = relax.op.dequantize( + bias_expr, + scale=relax.op.multiply( + input_tensor.qnn_params["scale"], + weight_tensor.qnn_params["scale"], + ), + zero_point=relax.const(0, "int32"), + axis=0, + ) out = relax.op.add(out, bias_expr) # Handle fused activation. if output_tensor.qnn_params: - # Calculate the intermediate scale and zero point of the int32 output. 
- data_scale = input_tensor.qnn_params["scale"] - data_scale_val = get_scalar_from_constant(data_scale) - - weight_scale = weight_tensor.qnn_params["scale"] - # If weight scale is scalar, it is per-tensor quantization - if isinstance(weight_scale, float): - weight_scale_val = get_scalar_from_constant(weight_scale) - else: - weight_scale_val = get_tensor_from_constant(weight_scale) - - new_input_scale_val = data_scale_val * weight_scale_val - new_input_scale = relax.const(new_input_scale_val, "float32") - new_input_zero_point = relax.const(0, "int32") - - # Finally requantize - out = _qnn.op.requantize( - out, - input_scale=new_input_scale, - input_zero_point=new_input_zero_point, - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - axis=3, - ) + # Quantize the float output using the output tensor's qnn params. + out = self.quantize(out, output_tensor) - # Call activation function + # Call quantized activation function output_scale_val = get_scalar_from_constant(output_tensor.qnn_params["scale"]) output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params["zero_point"]) out = self.convert_qnn_fused_activation_function( @@ -4385,25 +4313,27 @@ def convert_transpose_conv(self, op): padding = (0, 0, 0, 0) if input_tensor.qnn_params: - input_zero_point = input_tensor.qnn_params["zero_point"] - kernel_zero_point = weights_tensor.qnn_params["zero_point"] - input_scale = input_tensor.qnn_params["scale"] - kernel_scale = weights_tensor.qnn_params["scale"] - out_dtype = "int64" if output_tensor_type_str == "int16" else "int32" - out = _qnn.op.conv2d_transpose( - in_expr, + in_f32 = self.dequantize(in_expr, input_tensor) + weight_axis = weights_tensor.qnn_params["axis"] + if weight_axis != 0: + raise tvm.error.OpAttributeInvalid( + f"TransposeConv weight QuantizedDimension() must be 0 " + f"(output-channel axis in OHWI layout), got {weight_axis}" + ) + w_f32 = relax.op.dequantize( weight_expr_iohw, - input_zero_point, - kernel_zero_point, - input_scale, - kernel_scale, + scale=weights_tensor.qnn_params["scale"], + zero_point=weights_tensor.qnn_params["zero_point"], + axis=1, + ) + out = relax.op.nn.conv2d_transpose( + in_f32, + w_f32, strides=(stride_h, stride_w), padding=padding, - channels=int(out_channels), - kernel_size=(int(kernel_h), int(kernel_w)), data_layout="NHWC", kernel_layout="IOHW", - out_dtype=out_dtype, + out_dtype="float32", ) else: out = relax.op.nn.conv2d_transpose( @@ -4435,34 +4365,26 @@ def convert_transpose_conv(self, op): dtype=bias_tensor_type_str, source_name=bias_tensor.tensor.Name(), ) - channel_axis = 3 - out = relax.op.nn.bias_add(out, bias_expr, axis=channel_axis) + if bias_tensor.qnn_params: + bias_expr = self.dequantize(bias_expr, bias_tensor) + elif input_tensor.qnn_params and bias_tensor_type in ( + TensorType.INT32, + TensorType.INT64, + ): + bias_scale = relax.op.multiply( + input_tensor.qnn_params["scale"], + weights_tensor.qnn_params["scale"], + ) + bias_expr = relax.op.dequantize( + bias_expr, + scale=bias_scale, + zero_point=relax.const(0, "int32"), + axis=0, + ) + out = relax.op.add(out, bias_expr) if output_tensor.qnn_params: - # Calculate the intermediate scale and zero point of the int32 output. 
- data_scale = input_tensor.qnn_params["scale"] - data_scale_val = get_scalar_from_constant(data_scale) - - weight_scale = weights_tensor.qnn_params["scale"] - # If weight scale is scalar, it is per-tensor quantization - if isinstance(weight_scale, float): - weight_scale_val = get_scalar_from_constant(weight_scale) - else: - weight_scale_val = get_tensor_from_constant(weight_scale) - - new_input_scale_val = data_scale_val * weight_scale_val - new_input_scale = relax.const(new_input_scale_val, "float32") - new_input_zero_point = relax.const(0, "int32") - - out = _qnn.op.requantize( - out, - input_scale=new_input_scale, - input_zero_point=new_input_zero_point, - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - axis=3, - ) + out = self.quantize(out, output_tensor) return out def convert_quantize(self, op): @@ -4477,7 +4399,6 @@ def convert_quantize(self, op): output_tensors = self.get_output_tensors(op) assert len(output_tensors) == 1, "output tensors length should be 1" output_tensor = output_tensors[0] - output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) # The output must be quantized assert output_tensor.qnn_params @@ -4486,14 +4407,8 @@ def convert_quantize(self, op): if input_tensor_type_str == "float32": out = self.quantize(in_expr, output_tensor) else: - out = _qnn.op.requantize( - in_expr, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) + in_f32 = self.dequantize(in_expr, input_tensor) + out = self.quantize(in_f32, output_tensor) return out def convert_dequantize(self, op): @@ -4642,23 +4557,11 @@ def convert_detection_postprocess(self, op): ) if inputs[0].qnn_params: - loc_prob = _qnn.op.dequantize( - data=loc_prob, - input_scale=inputs[0].qnn_params["scale"], - input_zero_point=inputs[0].qnn_params["zero_point"], - ) + loc_prob = self.dequantize(loc_prob, inputs[0]) if inputs[1].qnn_params: - cls_pred = _qnn.op.dequantize( - data=cls_pred, - input_scale=inputs[1].qnn_params["scale"], - input_zero_point=inputs[1].qnn_params["zero_point"], - ) + cls_pred = self.dequantize(cls_pred, inputs[1]) if inputs[2].qnn_params: - anchor_expr = _qnn.op.dequantize( - data=anchor_expr, - input_scale=inputs[2].qnn_params["scale"], - input_zero_point=inputs[2].qnn_params["zero_point"], - ) + anchor_expr = self.dequantize(anchor_expr, inputs[2]) # loc_prob coords are in yxhw format # need to convert to xywh diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index a53906d2f147..0ca614113c72 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -2400,8 +2400,8 @@ def _convert_detection_postprocess_with_options( converter.exp_tab = tflite_frontend.ExprTable() converter.get_input_tensors = lambda op: inputs converter.get_expr = lambda tensor_idx: {0: loc, 1: cls}[tensor_idx] - converter.get_tensor_value = ( - lambda tensor: _DETECTION_POSTPROCESS_ANCHORS if tensor.tensor_idx == 2 else None + converter.get_tensor_value = lambda tensor: ( + _DETECTION_POSTPROCESS_ANCHORS if tensor.tensor_idx == 2 else None ) converter.get_tensor_type_str = lambda tensor_type: "float32" op = _StubDetectionPostprocessOp(custom_options) @@ -3655,7 +3655,9 @@ def _get_tflite_schema_enum(enum_name): 
_tfl_add_options = _get_tflite_schema_module("AddOptions") _tfl_buffer = _get_tflite_schema_module("Buffer") _tfl_conv2d_options = _get_tflite_schema_module("Conv2DOptions") +_tfl_depthwise_conv2d_options = _get_tflite_schema_module("DepthwiseConv2DOptions") _tfl_dilate_options = _get_tflite_schema_module("DilateOptions") +_tfl_transpose_conv_options = _get_tflite_schema_module("TransposeConvOptions") # ── StableHLO BuiltinOptions2 schema modules ──────────────────────────── _tfl_stablehlo_concat_opts = _get_tflite_schema_module("StablehloConcatenateOptions") @@ -3673,6 +3675,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_model = _get_tflite_schema_module("Model") _tfl_operator = _get_tflite_schema_module("Operator") _tfl_operator_code = _get_tflite_schema_module("OperatorCode") +_tfl_quantization_parameters = _get_tflite_schema_module("QuantizationParameters") _tfl_sparsity_parameters = _get_tflite_schema_module("SparsityParameters") _tfl_subgraph = _get_tflite_schema_module("SubGraph") _tfl_tensor = _get_tflite_schema_module("Tensor") @@ -3704,6 +3707,20 @@ def _tflite_int32_vector(builder, start_vector_fn, values): return builder.EndVector() +def _tflite_int64_vector(builder, start_vector_fn, values): + start_vector_fn(builder, len(values)) + for value in reversed(values): + builder.PrependInt64(value) + return builder.EndVector() + + +def _tflite_float32_vector(builder, start_vector_fn, values): + start_vector_fn(builder, len(values)) + for value in reversed(values): + builder.PrependFloat32(value) + return builder.EndVector() + + def _tflite_offset_vector(builder, start_vector_fn, offsets): start_vector_fn(builder, len(offsets)) for offset in reversed(offsets): @@ -3735,7 +3752,7 @@ def _tflite_shape(builder, shape): return _tflite_int32_vector(builder, _tfl_tensor.TensorStartShapeVector, shape) -def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None): +def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None, quantization=None): """Helper to build a TFLite tensor.""" if tensor_type is None: tensor_type = _tfl_tensor_type.FLOAT32 @@ -3747,6 +3764,8 @@ def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None): _tfl_tensor.TensorAddShape(builder, shape_vec) if sparsity is not None: _tfl_tensor.TensorAddSparsity(builder, sparsity) + if quantization is not None: + _tfl_tensor.TensorAddQuantization(builder, quantization) _tfl_tensor.TensorAddType(builder, tensor_type) return _tfl_tensor.TensorEnd(builder) @@ -3763,6 +3782,24 @@ def _build_buffer(builder, data=None): return _tfl_buffer.BufferEnd(builder) +def _build_quantization_parameters(builder, *, scale, zero_point, quantized_dimension): + scale_vec = _tflite_float32_vector( + builder, _tfl_quantization_parameters.QuantizationParametersStartScaleVector, scale + ) + zero_point_vec = _tflite_int64_vector( + builder, + _tfl_quantization_parameters.QuantizationParametersStartZeroPointVector, + zero_point, + ) + _tfl_quantization_parameters.QuantizationParametersStart(builder) + _tfl_quantization_parameters.QuantizationParametersAddScale(builder, scale_vec) + _tfl_quantization_parameters.QuantizationParametersAddZeroPoint(builder, zero_point_vec) + _tfl_quantization_parameters.QuantizationParametersAddQuantizedDimension( + builder, quantized_dimension + ) + return _tfl_quantization_parameters.QuantizationParametersEnd(builder) + + def _build_operator( builder, opcode_index, @@ -4970,6 +5007,1045 @@ def test_stablehlo_dynamic_slice_out_of_bounds_unsupported(): 
from_tflite(tflite_model) +def test_tensor_quantization_parameters_are_parsed(): + """Tensor quantization metadata is kept without requiring quantized op support.""" + builder = flatbuffers.Builder(1024) + + per_tensor_quantization = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + per_axis_quantization = _build_quantization_parameters( + builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=3 + ) + per_tensor = _build_tensor( + builder, + 0, + [1, 4], + tensor_type=_tfl_tensor_type.UINT8, + quantization=per_tensor_quantization, + ) + per_axis = _build_tensor( + builder, + 1, + [1, 2, 3, 2], + tensor_type=_tfl_tensor_type.INT8, + quantization=per_axis_quantization, + ) + subgraph = _build_subgraph( + builder, tensors=[per_tensor, per_axis], operators=[], inputs=[0, 1], outputs=[0, 1] + ) + buffers = [_build_buffer(builder), _build_buffer(builder)] + buf = _finish_tflite_model(builder, subgraph=subgraph, operator_codes=[], buffers=buffers) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + converter = tflite_frontend.OperatorConverter( + tflite_model, tflite_model.Subgraphs(0), tflite_frontend.ExprTable(), None + ) + per_tensor_wrapper, per_axis_wrapper = converter.get_tensors([0, 1]) + + np.testing.assert_allclose(per_tensor_wrapper.qnn_params["scale"].data.numpy(), 0.5) + np.testing.assert_equal(per_tensor_wrapper.qnn_params["zero_point"].data.numpy(), 3) + assert per_tensor_wrapper.qnn_params["axis"] == 0 + + np.testing.assert_allclose( + per_axis_wrapper.qnn_params["scale"].data.numpy(), np.array([0.25, 0.75]) + ) + np.testing.assert_equal(per_axis_wrapper.qnn_params["zero_point"].data.numpy(), 0) + assert per_axis_wrapper.qnn_params["axis"] == 3 + + mod = from_tflite(tflite_model) + assert len(mod["main"].params) == 2 + + +def test_quantize_op_uses_relax_quantize(): + """TFLite QUANTIZE float32 -> int8 uses R.quantize.""" + builder = flatbuffers.Builder(1024) + + input_data = np.array([1.0, 2.0], dtype=np.float32) + output_qparams = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + input_tensor = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.FLOAT32) + output_tensor = _build_tensor( + builder, + 1, + [2], + tensor_type=_tfl_tensor_type.INT8, + quantization=output_qparams, + ) + + quantize_op = _build_operator( + builder, + 0, + [0], + [1], + ) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, output_tensor], + operators=[quantize_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [ + _build_operator_code(builder, _tfl_builtin_operator.QUANTIZE), # QUANTIZE + ] + input_buffer = _build_buffer(builder, input_data.tobytes()) + output_buffer = _build_buffer(builder) + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[input_buffer, output_buffer], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2,), dtype="float32")) -> R.Tensor((2,), dtype="int8"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2,), dtype="int8") = R.quantize( + x, + R.const(0.5, "float32"), + R.const(3, "int32"), + 
axis=0, + out_dtype="int8", + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantize_op_requantize_uses_dq_q(): + """TFLite QUANTIZE with quantized input uses DQ→Q (requantize).""" + builder = flatbuffers.Builder(1024) + + input_data = np.array([10, 20], dtype=np.int8) + input_qparams = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[1], quantized_dimension=0 + ) + output_qparams = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + input_tensor = _build_tensor( + builder, + 0, + [2], + tensor_type=_tfl_tensor_type.INT8, + quantization=input_qparams, + ) + output_tensor = _build_tensor( + builder, + 1, + [2], + tensor_type=_tfl_tensor_type.INT8, + quantization=output_qparams, + ) + + quantize_op = _build_operator( + builder, + 0, + [0], + [1], + ) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, output_tensor], + operators=[quantize_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [ + _build_operator_code(builder, _tfl_builtin_operator.QUANTIZE), + ] + input_buffer = _build_buffer(builder, input_data.tobytes()) + output_buffer = _build_buffer(builder) + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[input_buffer, output_buffer], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((2,), dtype="int8"), + ) -> R.Tensor((2,), dtype="int8"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.25, "float32"), + R.const(1, "int32"), + out_dtype="float32", + axis=0, + ) + gv: R.Tensor((2,), dtype="int8") = R.quantize( + lv, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_dequantize_op_uses_relax_dequantize(): + """TFLite DEQUANTIZE int8 -> float32 uses R.dequantize.""" + builder = flatbuffers.Builder(1024) + + input_data = np.array([10, 20], dtype=np.int8) + input_qparams = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + input_tensor = _build_tensor( + builder, + 0, + [2], + tensor_type=_tfl_tensor_type.INT8, + quantization=input_qparams, + ) + output_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.FLOAT32) + + dequantize_op = _build_operator( + builder, + 0, + [0], + [1], + ) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, output_tensor], + operators=[dequantize_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [ + _build_operator_code(builder, _tfl_builtin_operator.DEQUANTIZE), # DEQUANTIZE + ] + input_buffer = _build_buffer(builder, input_data.tobytes()) + output_buffer = _build_buffer(builder) + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[input_buffer, output_buffer], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + 
@R.function + def main(x: R.Tensor((2,), dtype="int8")) -> R.Tensor((2,), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2,), dtype="float32") = R.dequantize( + x, + R.const(0.5, "float32"), + R.const(3, "int32"), + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_conv2d_per_tensor_uses_qdq(): + """Quantized Conv2D with per-tensor quantization uses DQ→conv2d→Q.""" + import flatbuffers + import tflite.Model + + builder = flatbuffers.Builder(2048) + + # Per-tensor quantization parameters + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + wt_q = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[0], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + # Tensors + input_tensor = _build_tensor( + builder, + 0, + [1, 4, 4, 1], + tensor_type=_tfl_tensor_type.INT8, + quantization=in_q, + ) + weight_tensor = _build_tensor( + builder, + 1, + [2, 3, 3, 1], + tensor_type=_tfl_tensor_type.INT8, + quantization=wt_q, + ) + output_tensor = _build_tensor( + builder, + 2, + [1, 2, 2, 2], + tensor_type=_tfl_tensor_type.INT8, + quantization=out_q, + ) + + # Conv2D options (strides=1, padding=VALID) + _tfl_conv2d_options.Conv2DOptionsStart(builder) + _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, 1) # VALID + _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) # NONE + conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) + + conv_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.Conv2DOptions, + builtin_options=conv_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[input_tensor, weight_tensor, output_tensor], + operators=[conv_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [ + _build_operator_code(builder, _tfl_builtin_operator.CONV_2D), + ] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder), _build_buffer(builder), _build_buffer(builder)], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[1, 2, 3, 0], + ) + lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( + lv1, + R.const(0.25, "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=3, + ) + lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( + lv, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + gv: R.Tensor((1, 2, 2, 2), dtype="int8") = 
R.quantize( + lv3, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_concat_uses_qdq(): + """Quantized CONCATENATION uses DQ each input → concat → Q.""" + import flatbuffers + import tflite.Model + + builder = flatbuffers.Builder(1024) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + + t0 = _build_tensor(builder, 0, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t1 = _build_tensor(builder, 1, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) + t2 = _build_tensor(builder, 2, [1, 4], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) + + tflite.ConcatenationOptionsStart(builder) + tflite.ConcatenationOptionsAddAxis(builder, 1) + tflite.ConcatenationOptionsAddFusedActivationFunction(builder, 0) + concat_opts = tflite.ConcatenationOptionsEnd(builder) + + concat_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.ConcatenationOptions, + builtin_options=concat_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t0, t1, t2], + operators=[concat_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONCATENATION)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 3, + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 2), dtype="int8"), + tvmgen_tensor_1: R.Tensor((1, 2), dtype="int8"), + ) -> R.Tensor((1, 4), dtype="int8"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((1, 2), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((1, 2), dtype="float32") = R.dequantize( + tvmgen_tensor_1, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv2: R.Tensor((1, 4), dtype="float32") = R.concat((lv, lv1), axis=1) + gv: R.Tensor((1, 4), dtype="int8") = R.quantize( + lv2, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_quantized_conv2d_with_int32_bias_dequantizes_bias(): + """Conv2D with INT32 bias dequantizes bias with in_scale x wt_scale.""" + import flatbuffers + import tflite.Model + + builder = flatbuffers.Builder(2048) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[3], quantized_dimension=0 + ) + wt_q = _build_quantization_parameters( + builder, scale=[0.25], zero_point=[0], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_in = _build_tensor( + builder, 0, [1, 4, 4, 1], tensor_type=_tfl_tensor_type.INT8, quantization=in_q + ) + t_wt = _build_tensor( + builder, 1, [2, 3, 3, 1], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q + ) + t_bi = _build_tensor(builder, 2, [2], 
tensor_type=_tfl_tensor_type.INT32) + t_ou = _build_tensor( + builder, 3, [1, 2, 2, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q + ) + + _tfl_conv2d_options.Conv2DOptionsStart(builder) + _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, 1) + _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) + conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) + + conv_op = _build_operator( + builder, + 0, + [0, 1, 2], + [3], + builtin_options_type=_tfl_builtin_options.Conv2DOptions, + builtin_options=conv_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_in, t_wt, t_bi, t_ou], + operators=[conv_op], + inputs=[0, 1, 2], + outputs=[3], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 4, + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), + tvmgen_tensor_2: R.Tensor((2,), dtype="int32"), + ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[1, 2, 3, 0], + ) + lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( + lv1, + R.const(0.25, "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=3, + ) + lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( + lv, + lv2, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + lv4: R.Tensor((), dtype="float32") = R.multiply( + R.const(0.5, "float32"), + R.const(0.25, "float32"), + ) + lv5: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_2, + lv4, + R.const(0, "int32"), + out_dtype="float32", + axis=0, + ) + lv6: R.Tensor((1, 2, 2, 2), dtype="float32") = R.add(lv3, lv5) + gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( + lv6, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_per_channel_depthwise_conv_unsupported(): + """Per-channel quantized depthwise Conv2D raises OpNotImplemented.""" + import flatbuffers + import tflite.Model + + builder = flatbuffers.Builder(1024) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[0], quantized_dimension=0 + ) + # Per-channel weight: 2 channels, scale vector length 2 + wt_q = _build_quantization_parameters( + builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=3 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[0], quantized_dimension=0 + ) + + t_in = _build_tensor( + builder, 0, [1, 4, 4, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q + ) + t_wt = _build_tensor( + builder, 
1, [1, 3, 3, 2], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q + ) + t_ou = _build_tensor( + builder, 2, [1, 2, 2, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q + ) + + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsStart(builder) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddStrideH(builder, 1) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddStrideW(builder, 1) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddDepthMultiplier(builder, 1) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddPadding(builder, 1) + _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddFusedActivationFunction(builder, 0) + dw_opts = _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsEnd(builder) + + dw_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options_type=_tfl_builtin_options.DepthwiseConv2DOptions, + builtin_options=dw_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_in, t_wt, t_ou], + operators=[dw_op], + inputs=[0, 1], + outputs=[2], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.DEPTHWISE_CONV_2D)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 3, + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="Per-channel"): + from_tflite(tflite_model) + + +def test_uint8_reshape_requantize_uses_dq_reshape_q(): + """uint8 RESHAPE with different qparams uses DQ→reshape→Q.""" + import flatbuffers + import numpy as np + import tflite.Model + + builder = flatbuffers.Builder(1024) + + in_q = _build_quantization_parameters( + builder, scale=[0.5], zero_point=[128], quantized_dimension=0 + ) + out_q = _build_quantization_parameters( + builder, scale=[1.0], zero_point=[100], quantized_dimension=0 + ) + + t_in = _build_tensor(builder, 0, [1, 4], tensor_type=_tfl_tensor_type.UINT8, quantization=in_q) + t_ou = _build_tensor(builder, 1, [2, 2], tensor_type=_tfl_tensor_type.UINT8, quantization=out_q) + + # Use ReshapeOptions with static new_shape [2, 2] + new_shape_np = np.array([2, 2], dtype=np.int32) + new_shape_vec = _tflite_int32_vector( + builder, tflite.ReshapeOptionsStartNewShapeVector, new_shape_np + ) + tflite.ReshapeOptionsStart(builder) + tflite.ReshapeOptionsAddNewShape(builder, new_shape_vec) + reshape_opts = tflite.ReshapeOptionsEnd(builder) + + reshape_op = _build_operator( + builder, + 0, + [0], + [1], + builtin_options_type=_tfl_builtin_options.ReshapeOptions, + builtin_options=reshape_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_in, t_ou], + operators=[reshape_op], + inputs=[0], + outputs=[1], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.RESHAPE)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder), _build_buffer(builder)], + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4), dtype="uint8"), + ) -> R.Tensor((2, 2), dtype="uint8"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = 
R.dequantize(
+                    tvmgen_tensor_0,
+                    R.const(0.5, "float32"),
+                    R.const(128, "int32"),
+                    out_dtype="float32",
+                    axis=0,
+                )
+                lv1: R.Tensor((2, 2), dtype="float32") = R.reshape(
+                    lv,
+                    R.shape([2, 2]),
+                )
+                gv: R.Tensor((2, 2), dtype="uint8") = R.quantize(
+                    lv1,
+                    R.const(1.0, "float32"),
+                    R.const(100, "int32"),
+                    out_dtype="uint8",
+                    axis=0,
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_transpose_conv_with_int32_bias_dequantizes_bias():
+    """TRANSPOSE_CONV with INT32 bias dequantizes bias before adding."""
+    import struct
+
+    import flatbuffers
+    import tflite.Model
+
+    builder = flatbuffers.Builder(2048)
+
+    in_q = _build_quantization_parameters(
+        builder, scale=[0.5], zero_point=[3], quantized_dimension=0
+    )
+    wt_q = _build_quantization_parameters(
+        builder, scale=[0.25], zero_point=[0], quantized_dimension=0
+    )
+    out_q = _build_quantization_parameters(
+        builder, scale=[1.0], zero_point=[0], quantized_dimension=0
+    )
+
+    t_in = _build_tensor(
+        builder, 0, [1, 1, 1, 1], tensor_type=_tfl_tensor_type.INT8, quantization=in_q
+    )
+    t_wt = _build_tensor(
+        builder, 1, [1, 1, 1, 1], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q
+    )
+    t_bi = _build_tensor(builder, 2, [1], tensor_type=_tfl_tensor_type.INT32)
+    t_ou = _build_tensor(
+        builder, 3, [1, 1, 1, 1], tensor_type=_tfl_tensor_type.INT8, quantization=out_q
+    )
+    oshape_data = struct.pack("<4i", 1, 1, 1, 1)
+    t_os = _build_tensor(builder, 4, [4], tensor_type=_tfl_tensor_type.INT32)
+
+    _tfl_transpose_conv_options.TransposeConvOptionsStart(builder)
+    _tfl_transpose_conv_options.TransposeConvOptionsAddPadding(builder, 1)  # VALID
+    _tfl_transpose_conv_options.TransposeConvOptionsAddStrideH(builder, 1)
+    _tfl_transpose_conv_options.TransposeConvOptionsAddStrideW(builder, 1)
+    tc_opts = _tfl_transpose_conv_options.TransposeConvOptionsEnd(builder)
+
+    # TFLite TRANSPOSE_CONV input order: output_shape, weights, data, bias
+    tc_op = _build_operator(
+        builder,
+        0,
+        [4, 1, 0, 2],
+        [3],
+        builtin_options_type=_tfl_builtin_options.TransposeConvOptions,
+        builtin_options=tc_opts,
+    )
+    subgraph = _build_subgraph(
+        builder,
+        tensors=[t_in, t_wt, t_bi, t_ou, t_os],
+        operators=[tc_op],
+        inputs=[0, 1, 2],
+        outputs=[3],
+    )
+    operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.TRANSPOSE_CONV)]
+    buf = _finish_tflite_model(
+        builder,
+        subgraph=subgraph,
+        operator_codes=operator_codes,
+        buffers=[_build_buffer(builder)] * 4 + [_build_buffer(builder, oshape_data)],
+    )
+
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(buf, 0)
+    mod = from_tflite(tflite_model)
+    mod["main"] = mod["main"].without_attr("params")
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            tvmgen_tensor_0: R.Tensor((1, 1, 1, 1), dtype="int8"),
+            tvmgen_tensor_1: R.Tensor((1, 1, 1, 1), dtype="int8"),
+            tvmgen_tensor_2: R.Tensor((1,), dtype="int32"),
+        ) -> R.Tensor((1, 1, 1, 1), dtype="int8"):
+            R.func_attr({"num_input": 3})
+            with R.dataflow():
+                lv: R.Tensor((1, 1, 1, 1), dtype="float32") = R.dequantize(
+                    tvmgen_tensor_0,
+                    R.const(0.5, "float32"),
+                    R.const(3, "int32"),
+                    out_dtype="float32",
+                    axis=0,
+                )
+                lv1: R.Tensor((1, 1, 1, 1), dtype="int8") = R.permute_dims(
+                    tvmgen_tensor_1,
+                    axes=[3, 0, 1, 2],
+                )
+                lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.dequantize(
+                    lv1,
+                    R.const(0.25, "float32"),
+                    R.const(0, "int32"),
+                    out_dtype="float32",
+                    axis=1,
+                )
+                lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.nn.conv2d_transpose(
+                    lv,
+                    lv2,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    data_layout="NHWC",
+                    kernel_layout="IOHW",
+                    out_dtype="float32",
+                )
+                lv4: R.Tensor((), dtype="float32") = R.multiply(
+                    R.const(0.5, "float32"),
+                    R.const(0.25, "float32"),
+                )
+                lv5: R.Tensor((1,), dtype="float32") = R.dequantize(
+                    tvmgen_tensor_2,
+                    lv4,
+                    R.const(0, "int32"),
+                    out_dtype="float32",
+                    axis=0,
+                )
+                lv6: R.Tensor((1, 1, 1, 1), dtype="float32") = R.add(lv3, lv5)
+                gv: R.Tensor((1, 1, 1, 1), dtype="int8") = R.quantize(
+                    lv6,
+                    R.const(1.0, "float32"),
+                    R.const(0, "int32"),
+                    out_dtype="int8",
+                    axis=0,
+                )
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_quantized_fully_connected_with_int32_bias_dequantizes_bias():
+    """Quantized FullyConnected with INT32 bias dequantizes bias with in_scale x wt_scale."""
+    import flatbuffers
+    import tflite.Model
+
+    builder = flatbuffers.Builder(2048)
+
+    in_q = _build_quantization_parameters(
+        builder, scale=[0.5], zero_point=[3], quantized_dimension=0
+    )
+    wt_q = _build_quantization_parameters(
+        builder, scale=[0.25], zero_point=[0], quantized_dimension=0
+    )
+    out_q = _build_quantization_parameters(
+        builder, scale=[1.0], zero_point=[0], quantized_dimension=0
+    )
+
+    t_in = _build_tensor(builder, 0, [1, 4], tensor_type=_tfl_tensor_type.INT8, quantization=in_q)
+    t_wt = _build_tensor(builder, 1, [2, 4], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q)
+    t_bi = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32)
+    t_ou = 
_build_tensor(builder, 3, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) + + _tfl_fully_connected_options.FullyConnectedOptionsStart(builder) + _tfl_fully_connected_options.FullyConnectedOptionsAddFusedActivationFunction(builder, 0) + _tfl_fully_connected_options.FullyConnectedOptionsAddWeightsFormat( + builder, _tfl_fc_weights_format.DEFAULT + ) + _tfl_fully_connected_options.FullyConnectedOptionsAddKeepNumDims(builder, 0) + fc_opts = _tfl_fully_connected_options.FullyConnectedOptionsEnd(builder) + + fc_op = _build_operator( + builder, + 0, + [0, 1, 2], + [3], + builtin_options_type=_tfl_builtin_options.FullyConnectedOptions, + builtin_options=fc_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_in, t_wt, t_bi, t_ou], + operators=[fc_op], + inputs=[0, 1, 2], + outputs=[3], + ) + operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.FULLY_CONNECTED)] + buf = _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)] * 4, + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((1, 4), dtype="int8"), + tvmgen_tensor_1: R.Tensor((2, 4), dtype="int8"), + tvmgen_tensor_2: R.Tensor((2,), dtype="int32"), + ) -> R.Tensor((1, 2), dtype="int8"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = R.dequantize( + tvmgen_tensor_0, + R.const(0.5, "float32"), + R.const(3, "int32"), + out_dtype="float32", + axis=0, + ) + lv1: R.Tensor((4, 2), dtype="int8") = R.permute_dims( + tvmgen_tensor_1, + axes=[1, 0], + ) + lv2: R.Tensor((4, 2), dtype="float32") = R.dequantize( + lv1, + R.const(0.25, "float32"), + R.const(0, "int32"), + out_dtype="float32", + axis=1, + ) + lv3: R.Tensor((1, 2), dtype="float32") = R.matmul(lv, lv2, out_dtype="void") + lv4: R.Tensor((), dtype="float32") = R.multiply( + R.const(0.5, "float32"), + R.const(0.25, "float32"), + ) + lv5: R.Tensor((2,), dtype="float32") = R.dequantize( + tvmgen_tensor_2, + lv4, + R.const(0, "int32"), + out_dtype="float32", + axis=0, + ) + lv6: R.Tensor((1, 2), dtype="float32") = R.add(lv3, lv5) + gv: R.Tensor((1, 2), dtype="int8") = R.quantize( + lv6, + R.const(1.0, "float32"), + R.const(0, "int32"), + out_dtype="int8", + axis=0, + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + def _build_csr_sparsity( builder, *,