Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions exir/tensor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -69,30 +69,32 @@
"""
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_size_oblivious,
guard_or_true,
)

for _, s in enumerate(stride):
if guard_or_false(s == 0):
raise ValueError("0 in strides is not supported for ExecuTorch.")
for s in stride:
torch._check(s != 0, lambda: "0 in strides is not supported for ExecuTorch.")

class K(NamedTuple):
stride: int

def __lt__(self, other):
return guard_size_oblivious(self.stride < other.stride)

def __gt__(self, other):
return guard_size_oblivious(self.stride > other.stride)

def __le__(self, other):
return guard_size_oblivious(self.stride <= other.stride)

def __ge__(self, other):
return guard_size_oblivious(self.stride >= other.stride)

def __eq__(self, other):
return guard_size_oblivious(self.stride == other.stride)
# For backed/concrete strides this is practically a `<` operation.
# For unbacked, we return True if `<` is statically known, then
# try to answer symbolically with stride-ordering semantics:
# u0 < u0 -> False
# u0 < u1 (no info) -> False
# u0 < 2 * u0 -> True (divisibility)
# 1 < u0 -> True (1 divides anything; unprovable equality treated optimistically)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stride of 1 will be in every tensor do we expect this to matter?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • This preserves the existing behavior. Pre-PR, guard_size_oblivious(1 < u0) assumed u0 >= 2 and returned True, so dim_order_from_stride((1, u0)) already produced (1, 0). The new code reaches the same answer via explicit reasoning (1 < u0 is unprovable statically, but u0 % 1 == 0 is always true and 1 != u0 is the optimistic tie-breaker). No behavior change for the realistic case, and no new exposure surface.

  • Fundamentally, dim_order is a static permutation frozen at export time, so it can't represent layouts whose orientation depends on a runtime value. If u0 == 1 at runtime the chosen dim_order is "wrong" relative to the u0 != 1 interpretation — but since u0 is dynamic we have to commit to one general case. And if u0 were always 1 in practice, it wouldn't be dynamic in the first place — so optimism toward u0 >= 2 is the right default.

return (
guard_or_false(
self.stride < other.stride
) # statically known inequality
or (
guard_or_false(other.stride % self.stride == 0)
and guard_or_true(self.stride != other.stride)
) # symbolic inequality (e.g. u0 < 2048 * u0)
)

sorted_dims = [
i[0] for i in sorted(enumerate(stride), key=lambda x: K(x[1]), reverse=True)
Expand Down
45 changes: 44 additions & 1 deletion exir/tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,52 @@ def test_dim_order_from_stride(self) -> None:
# dim[2] is broadcasting dim
# shape = (5, 1, 15, 10)
strides = (10, 10, 0, 1)
with self.assertRaises(ValueError):
# torch._check raises RuntimeError on concrete 0.
with self.assertRaises(RuntimeError):
dim_order = dim_order_from_stride(strides)

def test_dim_order_from_stride_unbacked(self) -> None:
"""
dim_order_from_stride should produce a sane permutation even when the
strides contain unbacked SymInts. The comparator falls back to
divisibility-based reasoning so common cases like (1, u0) and
(u0, 2 * u0) order correctly.
"""
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
u1 = shape_env.create_unbacked_symint()

# 1 < u0 should be True via divisibility (u0 % 1 == 0) + optimistic
# `1 != u0`. Descending sort puts u0 outer, stride 1 inner.
dim_order = dim_order_from_stride((1, u0))
self.assertEqual((1, 0), dim_order)

# u0 < 2 * u0 should be True via divisibility ((2*u0) % u0 == 0) and
# provable inequality (u0 != 0 after torch._check).
dim_order = dim_order_from_stride((u0, 2 * u0))
self.assertEqual((1, 0), dim_order)

# Mixed concrete + symbolic: (1, u0, 2 * u0). Descending stride order
# is (2*u0, u0, 1) -> indices (2, 1, 0).
dim_order = dim_order_from_stride((1, u0, 2 * u0))
self.assertEqual((2, 1, 0), dim_order)

# u0 < u1 (independent unbackeds) is genuinely ambiguous; stable sort
# preserves original order under reverse=True (no swap on ambiguous).
dim_order = dim_order_from_stride((u0, u1))
self.assertEqual((0, 1), dim_order)

# u0 < u0 is False both ways (symmetric); stable sort preserves order.
dim_order = dim_order_from_stride((u0, u0))
self.assertEqual((0, 1), dim_order)

# Unbacked stride of 0 (concrete 0 mixed with unbacked) -> RuntimeError
# via torch._check.
with self.assertRaises(RuntimeError):
dim_order_from_stride((u0, 0, 1))

def test_strides_from_dim_order(self) -> None:
sizes = []
dim_order = []
Expand Down
Loading