diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index f5b88b0c6ad5..41e3abdef927 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -128,6 +128,7 @@ def __init__(self, model, subgraph, exp_tab, ctx):
             "BITCAST": self.convert_bitcast,
             "BROADCAST_TO": self.convert_broadcast_to,
             "BROADCAST_ARGS": self.convert_broadcast_args,
+            "BUCKETIZE": self.convert_bucketize,
             "CAST": self.convert_cast,
             "CEIL": functools.partial(self._convert_unary_elemwise, relax_op=_op.ceil),
             "CONCATENATION": self.convert_concatenation,
@@ -213,6 +214,7 @@ def __init__(self, model, subgraph, exp_tab, ctx):
             "RELU": self.convert_relu,
             "RELU6": self.convert_relu6,
             "RELU_N1_TO_1": self.convert_relu_n1_to_1,
+            "RELU_0_TO_1": self.convert_relu_0_to_1,
             "RESHAPE": self.convert_reshape,
             "RESIZE_BILINEAR": self.convert_resize_bilinear,
             "RESIZE_NEAREST_NEIGHBOR": self.convert_resize_nearest_neighbor,
@@ -227,6 +229,7 @@ def __init__(self, model, subgraph, exp_tab, ctx):
                 self._convert_segment_op, op_name="SEGMENT_SUM", reduction="add"
             ),
             "SHAPE": self.convert_shape,
+            "SIGN": functools.partial(self._convert_unary_elemwise, relax_op=_op.sign),
             "SIN": functools.partial(self._convert_unary_elemwise, relax_op=_op.sin),
             "SLICE": self.convert_slice,
             "SOFTMAX": self.convert_softmax,
@@ -249,6 +252,7 @@ def __init__(self, model, subgraph, exp_tab, ctx):
             "TRANSPOSE_CONV": self.convert_transpose_conv,
             "TRANSPOSE": self.convert_transpose,
             "UNPACK": self.convert_unpack,
+            "UNIQUE": self.convert_unique,
             "UNSORTED_SEGMENT_MIN": functools.partial(
                 self._convert_segment_op, op_name="UNSORTED_SEGMENT_MIN", reduction="min"
            ),
@@ -977,7 +981,12 @@ def convert_tanh(self, op):
         return out
 
     def convert_range(self, op):
-        """Convert TFLite Range"""
+        """Convert TFLite Range.
+
+        Handles both constant and dynamic scalar inputs. When all three operands
+        are compile-time constants the output shape is fully static; when any
+        operand is a dynamic Relax expr the shape is symbolic.
+        """
         from tflite.TensorType import TensorType
 
@@ -986,28 +995,25 @@ def convert_range(self, op):
 
         start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2]
 
-        def get_scalar_value(tensor):
+        def get_scalar_or_expr(tensor):
+            """Return a Python scalar for constants, a Relax expr for dynamic inputs."""
             if self.has_expr(tensor.tensor_idx):
                 expr = self.get_expr(tensor.tensor_idx)
                 if isinstance(expr, relax.Constant):
                     value = expr.data.numpy()
-                else:
-                    # relax.op.arange currently expects scalar-like values here.
-                    # Keep dynamic scalar RANGE explicit until frontend support is added.
-                    raise tvm.error.OpNotImplemented(
-                        "TFLite RANGE with dynamic scalar inputs is not supported in Relax frontend yet."
-                    )
-            else:
-                value = self.get_tensor_value(tensor)
-
+                    assert value.size == 1, "RANGE scalar input must have exactly one element"
+                    return value.item()
+                # Dynamic: pass the 0-d tensor expr directly to relax.op.arange.
+                return expr
+            # TFLite RANGE operands are scalar tensors in the flatbuffer.
+            value = self.get_tensor_value(tensor)
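+            # get_tensor_value reads the tensor's constant buffer and returns a
+            # numpy array, so the same size-1 check applies on this path too.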
+            assert value.size == 1, "RANGE scalar input must have exactly one element"
+            return value.item()
 
-        start_value = get_scalar_value(start)
-        limit_value = get_scalar_value(limit)
-        delta_value = get_scalar_value(delta)
-
+        start_value = get_scalar_or_expr(start)
+        limit_value = get_scalar_or_expr(limit)
+        delta_value = get_scalar_or_expr(delta)
+
         # out type inference
         if delta.tensor.Type() == TensorType.FLOAT32:
             out_type = self.get_tensor_type_str(delta.tensor.Type())
@@ -1229,6 +1235,46 @@ def quantize(x):
 
         return out
 
+    def convert_relu_0_to_1(self, op):
+        """Convert TFLite RELU_0_TO_1, which clips the input to [0, 1]."""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        output_tensor = output_tensors[0]
+
+        if input_tensor.qnn_params:
+            scale_val = get_scalar_from_constant(input_tensor.qnn_params["scale"])
+            zero_point_val = get_scalar_from_constant(input_tensor.qnn_params["zero_point"])
+
+            def quantize(x):
+                return float(math.floor(x / scale_val + 0.5) + zero_point_val)
+
+            input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
+            qmin = float(tvm.tir.min_value(input_tensor_type_str).value)
+            qmax = float(tvm.tir.max_value(input_tensor_type_str).value)
+            out = relax.op.clip(
+                in_expr, min=max(qmin, quantize(0.0)), max=min(qmax, quantize(1.0))
+            )
+        else:
+            out = relax.op.clip(in_expr, min=0, max=1)
+
+        if output_tensor.qnn_params:
+            output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
+            out = _qnn.op.requantize(
+                out,
+                input_scale=input_tensor.qnn_params["scale"],
+                input_zero_point=input_tensor.qnn_params["zero_point"],
+                output_scale=output_tensor.qnn_params["scale"],
+                output_zero_point=output_tensor.qnn_params["zero_point"],
+                out_dtype=output_tensor_type_str,
+            )
+
+        return out
+
     def convert_log_softmax(self, op):
         """Convert TFLite LOG_SOFTMAX"""
         input_tensors = self.get_input_tensors(op)
@@ -2802,6 +2848,32 @@ def convert_broadcast_args(self, op):
             relax.op.where(s1_is_one, s0, relax.op.maximum(s0, s1)),
         )
 
+    def convert_bucketize(self, op):
+        """Convert TFLite BUCKETIZE → relax.op.bucketize.
+
+        Boundaries are stored as a repeated float in BucketizeOptions, not as a
+        tensor input, so we materialise them as a compile-time constant.
+        """
+        from tflite.BuiltinOptions import BuiltinOptions
+        from tflite.BucketizeOptions import BucketizeOptions
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+        in_expr = self.get_tensor_expr(input_tensors[0])
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.BucketizeOptions
+        op_options = op.BuiltinOptions()
+        bucket_options = BucketizeOptions()
+        bucket_options.Init(op_options.Bytes, op_options.Pos)
+
+        boundaries = [
+            bucket_options.Boundaries(i) for i in range(bucket_options.BoundariesLength())
+        ]
+        boundaries_const = relax.const(np.array(boundaries, dtype="float32"))
+
+        # TFLite assigns bucket i when boundaries[i-1] <= x < boundaries[i],
+        # i.e. the lower boundary is inclusive, which is right=True.
+        out = relax.op.bucketize(in_expr, boundaries_const, right=True)
+        return out
+
     def convert_cast(self, op):
         """Convert TFLite CAST"""
 
@@ -3125,6 +3197,47 @@ def convert_unpack(self, op):
 
         return squeezed
 
+    def convert_unique(self, op):
+        """Convert TFLite UNIQUE → relax.op.unique.
+
+        TFLite always emits two outputs: unique values and the per-element index
+        back into the unique values.
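+        relax.op.unique reports that per-element index as ``inverse_indices``.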
+        The index dtype (int32 or int64) is encoded in UniqueOptions.
+        """
+        from tflite.BuiltinOptions import BuiltinOptions
+        from tflite.UniqueOptions import UniqueOptions
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+        in_expr = self.get_tensor_expr(input_tensors[0])
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 2, "output tensors length should be 2"
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.UniqueOptions
+        op_options = op.BuiltinOptions()
+        unique_options = UniqueOptions()
+        unique_options.Init(op_options.Bytes, op_options.Pos)
+
+        idx_dtype = self.get_tensor_type_str(output_tensors[1].tensor.Type())
+
+        # relax.op.unique returns only the requested outputs; with just
+        # return_inverse=True that is (values, inverse_indices). TFLite expects
+        # (values, indices) where indices map each input element to its position
+        # in the unique output, which is exactly inverse_indices.
+        out = relax.op.unique(
+            in_expr,
+            sorted=False,
+            return_index=False,
+            return_inverse=True,
+            return_counts=False,
+            axis=None,
+        )
+        values = relax.TupleGetItem(out, 0)
+        inverse_indices = relax.TupleGetItem(out, 1)
+        if idx_dtype != "int32":
+            inverse_indices = relax.op.astype(inverse_indices, idx_dtype)
+        return relax.Tuple([values, inverse_indices])
+
     def convert_unidirectional_sequence_lstm(self, op):
         """Long Short Term Memory for TFLite implementation."""
 
@@ -4415,7 +4528,18 @@ def convert_densify(self, op):
         self.set_prefetched_node(output_tensor.tensor_idx, dense_weight)
 
     def convert_fake_quant(self, op):
-        """Convert TFLite FAKE_QUANT"""
+        """Convert TFLite FAKE_QUANT.
+
+        Implements the same nudging logic as the TFLite reference kernel:
+        https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/fake_quant.cc
+
+        Changes vs the previous implementation:
+        * Degenerate range (opt_min == opt_max, scale == 0): early-return a
+          passthrough clip rather than dividing by zero.
+        * Factor ``quant_max - quant_min`` out as ``num_levels`` so the scale
+          denominator is explicit for both narrow_range and standard configs.
+        """
 
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 1, "input tensors length should be 1"
@@ -4440,7 +4564,13 @@ def convert_fake_quant(self, op):
         quant_min = 1 if narrow_range else 0
         quant_max = (1 << num_bits) - 1
 
-        scale = (opt_max - opt_min) / (quant_max - quant_min)
+        num_levels = quant_max - quant_min  # 254 for narrow int8, 255 for standard int8
+
+        # Guard degenerate range: scale == 0 would cause division by zero.
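+        # The reference kernel nudges (opt_min, opt_max) so that 0.0 lands on an
+        # integer step; with opt_min == opt_max there is no range to nudge, and
+        # every input pins to the single representable value.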
+        if opt_max == opt_min:
+            return relax.op.clip(in_expr, opt_min, opt_max)
+
+        scale = (opt_max - opt_min) / num_levels
 
         zero_point_from_min = quant_min - opt_min / scale
         if zero_point_from_min <= quant_min:
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index e4c237887e6e..7fea5b7fe89c 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -551,8 +551,8 @@ def func(self):
     verify(Range)
 
 
-def test_range_dynamic_scalar_inputs_not_supported():
-    """RANGE conversion currently rejects dynamic scalar inputs."""
+def test_range_dynamic():
+    """RANGE with dynamic scalar inputs lowers to relax.op.arange."""
 
     class RangeDynamic(tf.Module):
         @tf.function(
@@ -565,8 +565,7 @@ class RangeDynamic(tf.Module):
         def func(self, start, limit, delta):
             return tf.range(start, limit, delta, dtype=tf.int32)
 
-    with pytest.raises(tvm.error.OpNotImplemented, match="dynamic scalar inputs"):
-        verify(RangeDynamic)
+    verify(RangeDynamic)
 
 
 def test_tile_ir():
     """TILE conversion with explicit Relax IR structural check."""
@@ -4250,5 +4249,141 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_sign():
+    """SIGN → relax.op.sign (unary elemwise, float and int)."""
+
+    class Sign(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(3, 4), dtype=tf.float32)])
+        def func(self, x):
+            return tf.math.sign(x)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3, 4), dtype="float32")) -> R.Tensor((3, 4), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((3, 4), dtype="float32") = R.sign(x)
+                R.output(gv)
+            return gv
+
+    verify(Sign, Expected)
+
+    class SignInt(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(5,), dtype=tf.int32)])
+        def func(self, x):
+            return tf.math.sign(x)
+
+    verify(SignInt)
+
+
+def test_unique():
+    """UNIQUE → relax.op.unique, two-output (values, inverse_indices)."""
+
+    class Unique(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(6,), dtype=tf.float32)])
+        def func(self, x):
+            y, idx = tf.unique(x)
+            return y, idx
+
+    verify(Unique)
+
+    class UniqueInt(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(8,), dtype=tf.int32)])
+        def func(self, x):
+            y, idx = tf.unique(x)
+            return y, idx
+
+    verify(UniqueInt)
+
+
+def test_bucketize():
+    """BUCKETIZE → relax.op.bucketize with constant boundaries from BucketizeOptions."""
+
+    class Bucketize(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(5,), dtype=tf.float32)])
+        def func(self, x):
+            return tf.raw_ops.Bucketize(input=x, boundaries=[0.0, 1.0, 2.0])
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((5,), dtype="float32")) -> R.Tensor((5,), dtype="int32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="float32") = R.const(
+                    np.array([0.0, 1.0, 2.0], dtype="float32"), "float32"
+                )
+                gv: R.Tensor((5,), dtype="int32") = R.bucketize(x, lv, right=True)
+                R.output(gv)
+            return gv
+
+    verify(Bucketize, Expected)
+
+    class BucketizeEmpty(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(4,), dtype=tf.float32)])
+        def func(self, x):
+            return tf.raw_ops.Bucketize(input=x, boundaries=[])
+
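+    # With no boundaries, every element lands in bucket 0 (all-zero output);
+    # this exercises the empty-boundaries path in the converter.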
+    verify(BucketizeEmpty)
+
+
+def test_fake_quant():
+    """FAKE_QUANT: standard range, narrow range, and degenerate (min == max)."""
+
+    class FakeQuantStandard(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.float32)])
+        def func(self, x):
+            return tf.quantization.fake_quant_with_min_max_args(
+                x, min=-1.0, max=1.0, num_bits=8, narrow_range=False
+            )
+
+    verify(FakeQuantStandard)
+
+    class FakeQuantNarrowRange(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.float32)])
+        def func(self, x):
+            return tf.quantization.fake_quant_with_min_max_args(
+                x, min=-1.0, max=1.0, num_bits=8, narrow_range=True
+            )
+
+    verify(FakeQuantNarrowRange)
+
+    class FakeQuant4Bit(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(3, 3), dtype=tf.float32)])
+        def func(self, x):
+            return tf.quantization.fake_quant_with_min_max_args(
+                x, min=0.0, max=15.0, num_bits=4, narrow_range=False
+            )
+
+    verify(FakeQuant4Bit)
+
+    # Degenerate range (min == max → scale == 0). The fix must emit a plain
+    # clip rather than dividing by zero. We check the IR directly to confirm
+    # that the output is exactly R.clip(x, min=v, max=v) and that no division
+    # node is present.
+    class FakeQuantDegenerate(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)])
+        def func(self, x):
+            return tf.quantization.fake_quant_with_min_max_args(
+                x, min=0.5, max=0.5, num_bits=8, narrow_range=False
+            )
+
+    @I.ir_module
+    class ExpectedDegenerate:
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((2, 3), dtype="float32") = R.clip(x, min=0.5, max=0.5)
+                R.output(gv)
+            return gv
+
+    mod = verify(FakeQuantDegenerate, ExpectedDegenerate)
+    # Double-check: no division node must appear in the compiled IR.
+    ir_text = mod.script()
+    assert "R.divide(" not in ir_text, "Degenerate FAKE_QUANT must not emit a division node"
+
+
 if __name__ == "__main__":
     pytest.main(["-s", __file__])