From a657df2793241de4b573ece816c9bddcddc8bc67 Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Wed, 1 Jul 2026 13:57:16 -0700 Subject: [PATCH 1/3] Add sym_int prim op for symbolic integer casting torch.export with dynamic shapes (Dim.AUTO) can emit sym_int nodes when symbolic float expressions need integer conversion. Without a registered prim op and C++ kernel, these models fail at export or crash at runtime. Adds Python op registration, mapping entry, and C++ kernel handling Int (passthrough), Double (truncate toward zero), and Bool (0/1). --- exir/emit/test/test_emit.py | 46 +++++++++++++++++++++ exir/passes/executorch_prim_ops_registry.py | 6 +++ kernels/prim_ops/register_prim_ops.cpp | 34 +++++++++++++++ kernels/prim_ops/test/prim_ops_test.cpp | 37 +++++++++++++++++ 4 files changed, 123 insertions(+) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 469f1c238c9..1f02cf25532 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -2452,6 +2452,52 @@ def forward(self, x): self.assertTrue(expected.shape == et_result.shape) self.assertTrue(torch.allclose(expected, et_result)) + def test_emit_sym_int(self) -> None: + class SymIntModel(nn.Module): + def forward(self, x): + n = x.shape[0] + f = torch.sym_float(n) + i = torch.sym_int(f) + return torch.zeros(i, dtype=x.dtype, device=x.device) + + model = SymIntModel() + model.eval() + test_inputs = [ + torch.randn(3, 4), + torch.randn(8, 4), + ] + reference_outputs = [] + with torch.no_grad(): + for inp in test_inputs: + reference_outputs.append(model(inp)) + + batch_dim = Dim("batch", min=1, max=20) + dynamic_shapes = {"x": {0: batch_dim}} + exported_program = torch.export.export( + model, (test_inputs[0],), dynamic_shapes=dynamic_shapes + ) + sym_int_nodes = [ + n + for n in exported_program.graph.nodes + if n.op == "call_function" and n.target is torch.sym_int + ] + self.assertGreater( + len(sym_int_nodes), 0, "sym_int should appear in exported graph" + ) + + edge_program = to_edge( + exported_program, + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) + et_program = edge_program.to_executorch() + program_buffer = et_program.buffer + et_module = _load_for_executorch_from_buffer(program_buffer) + for inp, expected in zip(test_inputs, reference_outputs): + et_output = et_module.forward([inp]) + et_result = et_output[0] + self.assertTrue(expected.shape == et_result.shape) + self.assertTrue(torch.allclose(expected, et_result)) + def test_emit_channels_last_constant(self) -> None: """Test that channels-last constant tensors are emitted correctly. diff --git a/exir/passes/executorch_prim_ops_registry.py b/exir/passes/executorch_prim_ops_registry.py index 37c36b6547a..c2235ae34ad 100644 --- a/exir/passes/executorch_prim_ops_registry.py +++ b/exir/passes/executorch_prim_ops_registry.py @@ -58,6 +58,11 @@ def sym_float(a: _SymScalar) -> _SymScalar: return float(a) # pyre-ignore +@bind_pattern_to_op(executorch_prims_lib, "sym_int.Scalar(Scalar a) -> Scalar") +def sym_int(a: _SymScalar) -> _SymScalar: + return int(a) # pyre-ignore + + # TODO: ideally we should return SymBool in the schema, but it seems # the schema parser does not recognize SymBool yet: P629748075 @bind_pattern_to_op(executorch_prims_lib, "gt.Scalar(Scalar a, Scalar b) -> bool") @@ -146,6 +151,7 @@ def sym_not(a: _SymScalar) -> bool: operator.mod: ops.backend.executorch_prim.mod.Scalar, operator.neg: ops.backend.executorch_prim.neg.Scalar, torch.sym_float: ops.backend.executorch_prim.sym_float.Scalar, + torch.sym_int: ops.backend.executorch_prim.sym_int.Scalar, torch.sym_max: ops.backend.executorch_prim.sym_max.Scalar, torch.sym_min: ops.backend.executorch_prim.sym_min.Scalar, torch.sym_not: ops.backend.executorch_prim.sym_not.Scalar, diff --git a/kernels/prim_ops/register_prim_ops.cpp b/kernels/prim_ops/register_prim_ops.cpp index 77ee50b4b04..468411a706b 100644 --- a/kernels/prim_ops/register_prim_ops.cpp +++ b/kernels/prim_ops/register_prim_ops.cpp @@ -400,6 +400,40 @@ static Kernel prim_ops[] = { }), #endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_SYM_INT_SCALAR) + // executorch_prim::sym_int.Scalar(Scalar) -> Scalar + Kernel( + "executorch_prim::sym_int.Scalar", + [](KernelRuntimeContext& context, Span stack) { + ET_KERNEL_CHECK_MSG( + context, + stack.size() == 2, + InvalidProgram, + /* void */, + "Expected %zu args, got %zu", + (size_t)2, + stack.size()); + EValue& a = *stack[0]; + EValue& out = *stack[1]; + if (a.isInt()) { + out = EValue(a.toInt()); + } else if (a.isDouble()) { + out = EValue(static_cast(a.toDouble())); + } else if (a.isBool()) { + out = EValue(static_cast(a.toBool())); + } else { + ET_KERNEL_CHECK_MSG( + context, + false, + InvalidType, + /* void */, + "sym_int only supports int, double, or bool inputs, got %zu", + (size_t)a.tag); + } + }), +#endif + #if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ defined(INCLUDE_EXECUTORCH_PRIM_EQ_SCALAR) // executorch_prim::eq.Scalar(Scalar, Scalar) -> bool diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp index 3d90943a303..c2ad20da7ec 100644 --- a/kernels/prim_ops/test/prim_ops_test.cpp +++ b/kernels/prim_ops/test/prim_ops_test.cpp @@ -41,6 +41,7 @@ TEST_F(RegisterPrimOpsTest, OpRegistered) { EXPECT_TRUE(hasOpsFn("executorch_prim::sym_max.Scalar")); EXPECT_TRUE(hasOpsFn("executorch_prim::sym_min.Scalar")); EXPECT_TRUE(hasOpsFn("executorch_prim::sym_not.Scalar")); + EXPECT_TRUE(hasOpsFn("executorch_prim::sym_int.Scalar")); } TEST_F(RegisterPrimOpsTest, SymSizeReturnsCorrectValue) { @@ -204,6 +205,42 @@ TEST_F(RegisterPrimOpsTest, SymFloatHandlesBoolInput) { EXPECT_FLOAT_EQ(stack[1]->toDouble(), 0.0); } +TEST_F(RegisterPrimOpsTest, SymIntReturnsCorrectValue) { + EValue values[2]; + EValue* stack[2]; + for (size_t i = 0; i < 2; i++) { + stack[i] = &values[i]; + } + + // Int passthrough + values[0] = EValue((int64_t)7); + values[1] = EValue((int64_t)0); + getOpsFn("executorch_prim::sym_int.Scalar")(context_, Span(stack)); + EXPECT_EQ(stack[1]->toInt(), 7); + + // Double truncates toward zero + values[0] = EValue(3.7); + values[1] = EValue((int64_t)0); + getOpsFn("executorch_prim::sym_int.Scalar")(context_, Span(stack)); + EXPECT_EQ(stack[1]->toInt(), 3); + + values[0] = EValue(-2.9); + values[1] = EValue((int64_t)0); + getOpsFn("executorch_prim::sym_int.Scalar")(context_, Span(stack)); + EXPECT_EQ(stack[1]->toInt(), -2); + + // Bool converts to 0/1 + values[0] = EValue(true); + values[1] = EValue((int64_t)0); + getOpsFn("executorch_prim::sym_int.Scalar")(context_, Span(stack)); + EXPECT_EQ(stack[1]->toInt(), 1); + + values[0] = EValue(false); + values[1] = EValue((int64_t)0); + getOpsFn("executorch_prim::sym_int.Scalar")(context_, Span(stack)); + EXPECT_EQ(stack[1]->toInt(), 0); +} + TEST_F(RegisterPrimOpsTest, TestAlgebraOps) { EValue values[3]; int64_t a = 3; From 4ebf64a631a1dbaffb0bdabaa95ce18000999e56 Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Wed, 1 Jul 2026 15:03:12 -0700 Subject: [PATCH 2/3] Fix test_emit_sym_int to avoid data-dependent simplification MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The original model used sym_int(sym_float(n)) which is an identity round-trip that PyTorch's symbolic evaluator folds away at trace time. Use a bool→float→int chain instead, which can't be simplified. Also drop the fragile graph node assertion — the C++ unit test validates the kernel directly. --- exir/emit/test/test_emit.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 1f02cf25532..09ba08b55cf 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -2456,15 +2456,17 @@ def test_emit_sym_int(self) -> None: class SymIntModel(nn.Module): def forward(self, x): n = x.shape[0] - f = torch.sym_float(n) - i = torch.sym_int(f) - return torch.zeros(i, dtype=x.dtype, device=x.device) + flag = n > 5 + neg = torch.sym_not(flag) + val = torch.sym_float(neg) + i = torch.sym_int(val) + return torch.zeros(n + i, dtype=x.dtype, device=x.device) model = SymIntModel() model.eval() test_inputs = [ - torch.randn(3, 4), - torch.randn(8, 4), + torch.randn(3, 4), # n<=5: not(F)=T, float(T)=1.0, int(1.0)=1, zeros(4) + torch.randn(8, 4), # n>5: not(T)=F, float(F)=0.0, int(0.0)=0, zeros(8) ] reference_outputs = [] with torch.no_grad(): @@ -2476,14 +2478,6 @@ def forward(self, x): exported_program = torch.export.export( model, (test_inputs[0],), dynamic_shapes=dynamic_shapes ) - sym_int_nodes = [ - n - for n in exported_program.graph.nodes - if n.op == "call_function" and n.target is torch.sym_int - ] - self.assertGreater( - len(sym_int_nodes), 0, "sym_int should appear in exported graph" - ) edge_program = to_edge( exported_program, From 5d9990b8073b8dfd9b9fbb3206f921a0dcb39e35 Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Wed, 1 Jul 2026 16:21:44 -0700 Subject: [PATCH 3/3] Simplify test_emit_sym_int model to avoid value_ranges bug MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The bool→float→int chain (sym_not→sym_float→sym_int) produces TruncToInt(ToFloat(s0 <= 5)) which triggers a PyTorch value_ranges bug: simple_sympify can't handle ToFloat(False). Use the simple sym_float→sym_int round-trip instead. The symbolic evaluator folds it to identity, but the test still validates the full export→edge→executorch pipeline with dynamic shapes. The C++ unit test validates the sym_int kernel directly. --- exir/emit/test/test_emit.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 09ba08b55cf..b2b473cc8c9 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -2456,17 +2456,15 @@ def test_emit_sym_int(self) -> None: class SymIntModel(nn.Module): def forward(self, x): n = x.shape[0] - flag = n > 5 - neg = torch.sym_not(flag) - val = torch.sym_float(neg) - i = torch.sym_int(val) - return torch.zeros(n + i, dtype=x.dtype, device=x.device) + f = torch.sym_float(n) + i = torch.sym_int(f) + return torch.zeros(i, dtype=x.dtype, device=x.device) model = SymIntModel() model.eval() test_inputs = [ - torch.randn(3, 4), # n<=5: not(F)=T, float(T)=1.0, int(1.0)=1, zeros(4) - torch.randn(8, 4), # n>5: not(T)=F, float(F)=0.0, int(0.0)=0, zeros(8) + torch.randn(3, 4), + torch.randn(8, 4), ] reference_outputs = [] with torch.no_grad():