diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 469f1c238c9..b2b473cc8c9 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -2452,6 +2452,44 @@ def forward(self, x): self.assertTrue(expected.shape == et_result.shape) self.assertTrue(torch.allclose(expected, et_result)) + def test_emit_sym_int(self) -> None: + class SymIntModel(nn.Module): + def forward(self, x): + n = x.shape[0] + f = torch.sym_float(n) + i = torch.sym_int(f) + return torch.zeros(i, dtype=x.dtype, device=x.device) + + model = SymIntModel() + model.eval() + test_inputs = [ + torch.randn(3, 4), + torch.randn(8, 4), + ] + reference_outputs = [] + with torch.no_grad(): + for inp in test_inputs: + reference_outputs.append(model(inp)) + + batch_dim = Dim("batch", min=1, max=20) + dynamic_shapes = {"x": {0: batch_dim}} + exported_program = torch.export.export( + model, (test_inputs[0],), dynamic_shapes=dynamic_shapes + ) + + edge_program = to_edge( + exported_program, + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) + et_program = edge_program.to_executorch() + program_buffer = et_program.buffer + et_module = _load_for_executorch_from_buffer(program_buffer) + for inp, expected in zip(test_inputs, reference_outputs): + et_output = et_module.forward([inp]) + et_result = et_output[0] + self.assertTrue(expected.shape == et_result.shape) + self.assertTrue(torch.allclose(expected, et_result)) + def test_emit_channels_last_constant(self) -> None: """Test that channels-last constant tensors are emitted correctly. diff --git a/exir/passes/executorch_prim_ops_registry.py b/exir/passes/executorch_prim_ops_registry.py index 37c36b6547a..c2235ae34ad 100644 --- a/exir/passes/executorch_prim_ops_registry.py +++ b/exir/passes/executorch_prim_ops_registry.py @@ -58,6 +58,11 @@ def sym_float(a: _SymScalar) -> _SymScalar: return float(a) # pyre-ignore +@bind_pattern_to_op(executorch_prims_lib, "sym_int.Scalar(Scalar a) -> Scalar") +def sym_int(a: _SymScalar) -> _SymScalar: + return int(a) # pyre-ignore + + # TODO: ideally we should return SymBool in the schema, but it seems # the schema parser does not recognize SymBool yet: P629748075 @bind_pattern_to_op(executorch_prims_lib, "gt.Scalar(Scalar a, Scalar b) -> bool") @@ -146,6 +151,7 @@ def sym_not(a: _SymScalar) -> bool: operator.mod: ops.backend.executorch_prim.mod.Scalar, operator.neg: ops.backend.executorch_prim.neg.Scalar, torch.sym_float: ops.backend.executorch_prim.sym_float.Scalar, + torch.sym_int: ops.backend.executorch_prim.sym_int.Scalar, torch.sym_max: ops.backend.executorch_prim.sym_max.Scalar, torch.sym_min: ops.backend.executorch_prim.sym_min.Scalar, torch.sym_not: ops.backend.executorch_prim.sym_not.Scalar, diff --git a/kernels/prim_ops/register_prim_ops.cpp b/kernels/prim_ops/register_prim_ops.cpp index 77ee50b4b04..468411a706b 100644 --- a/kernels/prim_ops/register_prim_ops.cpp +++ b/kernels/prim_ops/register_prim_ops.cpp @@ -400,6 +400,40 @@ static Kernel prim_ops[] = { }), #endif +#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ + defined(INCLUDE_EXECUTORCH_PRIM_SYM_INT_SCALAR) + // executorch_prim::sym_int.Scalar(Scalar) -> Scalar + Kernel( + "executorch_prim::sym_int.Scalar", + [](KernelRuntimeContext& context, Span stack) { + ET_KERNEL_CHECK_MSG( + context, + stack.size() == 2, + InvalidProgram, + /* void */, + "Expected %zu args, got %zu", + (size_t)2, + stack.size()); + EValue& a = *stack[0]; + EValue& out = *stack[1]; + if (a.isInt()) { + out = EValue(a.toInt()); + } else if (a.isDouble()) { + out = EValue(static_cast(a.toDouble())); + } else if (a.isBool()) { + out = EValue(static_cast(a.toBool())); + } else { + ET_KERNEL_CHECK_MSG( + context, + false, + InvalidType, + /* void */, + "sym_int only supports int, double, or bool inputs, got %zu", + (size_t)a.tag); + } + }), +#endif + #if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \ defined(INCLUDE_EXECUTORCH_PRIM_EQ_SCALAR) // executorch_prim::eq.Scalar(Scalar, Scalar) -> bool diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp index 3d90943a303..c2ad20da7ec 100644 --- a/kernels/prim_ops/test/prim_ops_test.cpp +++ b/kernels/prim_ops/test/prim_ops_test.cpp @@ -41,6 +41,7 @@ TEST_F(RegisterPrimOpsTest, OpRegistered) { EXPECT_TRUE(hasOpsFn("executorch_prim::sym_max.Scalar")); EXPECT_TRUE(hasOpsFn("executorch_prim::sym_min.Scalar")); EXPECT_TRUE(hasOpsFn("executorch_prim::sym_not.Scalar")); + EXPECT_TRUE(hasOpsFn("executorch_prim::sym_int.Scalar")); } TEST_F(RegisterPrimOpsTest, SymSizeReturnsCorrectValue) { @@ -204,6 +205,42 @@ TEST_F(RegisterPrimOpsTest, SymFloatHandlesBoolInput) { EXPECT_FLOAT_EQ(stack[1]->toDouble(), 0.0); } +TEST_F(RegisterPrimOpsTest, SymIntReturnsCorrectValue) { + EValue values[2]; + EValue* stack[2]; + for (size_t i = 0; i < 2; i++) { + stack[i] = &values[i]; + } + + // Int passthrough + values[0] = EValue((int64_t)7); + values[1] = EValue((int64_t)0); + getOpsFn("executorch_prim::sym_int.Scalar")(context_, Span(stack)); + EXPECT_EQ(stack[1]->toInt(), 7); + + // Double truncates toward zero + values[0] = EValue(3.7); + values[1] = EValue((int64_t)0); + getOpsFn("executorch_prim::sym_int.Scalar")(context_, Span(stack)); + EXPECT_EQ(stack[1]->toInt(), 3); + + values[0] = EValue(-2.9); + values[1] = EValue((int64_t)0); + getOpsFn("executorch_prim::sym_int.Scalar")(context_, Span(stack)); + EXPECT_EQ(stack[1]->toInt(), -2); + + // Bool converts to 0/1 + values[0] = EValue(true); + values[1] = EValue((int64_t)0); + getOpsFn("executorch_prim::sym_int.Scalar")(context_, Span(stack)); + EXPECT_EQ(stack[1]->toInt(), 1); + + values[0] = EValue(false); + values[1] = EValue((int64_t)0); + getOpsFn("executorch_prim::sym_int.Scalar")(context_, Span(stack)); + EXPECT_EQ(stack[1]->toInt(), 0); +} + TEST_F(RegisterPrimOpsTest, TestAlgebraOps) { EValue values[3]; int64_t a = 3;