diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 7d85906cffdd..3cb1ab6756ad 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -2327,6 +2327,72 @@ def _normalize_constant_axes(axes: list[int], rank: int, op_name: str) -> list[i return normalized_axes +def _try_get_static_int(value: int | tirx.PrimExpr) -> int | None: + """Return a Python int for statically known integer values.""" + + if isinstance(value, int): + return value + if isinstance(value, tirx.IntImm): + return int(value.value) + return None + + +def _get_static_tensor_dim(expr: relax.Expr, axis: int) -> int | None: + """Return a statically known tensor dimension when available.""" + + if isinstance(expr, relax.Constant): + return int(expr.data.numpy().shape[axis]) + + struct_info = expr.struct_info + if not isinstance(struct_info, relax.TensorStructInfo): + return None + if not isinstance(struct_info.shape, relax.ShapeExpr): + return None + return _try_get_static_int(struct_info.shape.values[axis]) + + +def _canonicalize_onnx_slice_index(index: int, dim: int, step: int, is_start: bool) -> int: + """Canonicalize a Slice bound using ONNX's static clamp rules.""" + + if index < 0: + index += dim + + lower_bound = 0 + upper_bound = dim + if step < 0: + lower_bound = 0 if is_start else -1 + upper_bound = dim - 1 + + return min(max(index, lower_bound), upper_bound) + + +def _normalize_empty_constant_slice_ranges( + data: relax.Expr, + axes: list[int], + starts: list[int | tirx.PrimExpr], + ends: list[int | tirx.PrimExpr], + steps: list[int | tirx.PrimExpr], +) -> list[int | tirx.PrimExpr]: + """Normalize statically empty ONNX Slice ranges for Relax strided_slice.""" + + normalized_ends = list(ends) + for i, axis in enumerate(axes): + start = _try_get_static_int(starts[i]) + end = _try_get_static_int(ends[i]) + step = _try_get_static_int(steps[i]) + dim = _get_static_tensor_dim(data, axis) + if start is None or end is None or step is None or dim is None: + continue + + start = _canonicalize_onnx_slice_index(start, dim, step, is_start=True) + end = _canonicalize_onnx_slice_index(end, dim, step, is_start=False) + is_empty = start >= end if step > 0 else start <= end + if is_empty: + normalized_ends[i] = starts[i] + + return normalized_ends + + def _as_int64_tensor(bb: relax.BlockBuilder, expr: relax.Expr) -> relax.Expr: """Convert a tensor-like expression to an int64 tensor expression.""" @@ -2490,6 +2556,7 @@ def _impl_v13(cls, bb, inputs, attr, params): assume_inbound = not all( [isinstance(param, tirx.IntImm | int) for param in [*starts, *ends, *steps]] ) + ends = _normalize_empty_constant_slice_ranges(data, axes, starts, ends, steps) starts = get_prim_value_list(starts) ends = get_prim_value_list(ends) steps = get_prim_value_list(steps) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 52a4064cc8f5..5b73ec35183f 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -2916,6 +2916,7 @@ def verify_slice(data_shape, output_shape, starts, ends, axes=None, steps=None): steps=[-1, -3, -2], axes=[0, 1, 2], ) + verify_slice([10], [0], starts=[0], ends=[5], axes=[0], steps=[-1]) verify_slice([20, 10, 5], [10, 5], starts=[0, 0], ends=[3, 10], axes=[1, 2]) verify_slice([20, 10, 5], [10, 5], starts=[0, 0], ends=[3, 10], axes=[1, 2])