Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions exir/emit/test/test_emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2452,6 +2452,44 @@ def forward(self, x):
self.assertTrue(expected.shape == et_result.shape)
self.assertTrue(torch.allclose(expected, et_result))

def test_emit_sym_int(self) -> None:
class SymIntModel(nn.Module):
def forward(self, x):
n = x.shape[0]
f = torch.sym_float(n)
i = torch.sym_int(f)
return torch.zeros(i, dtype=x.dtype, device=x.device)

model = SymIntModel()
model.eval()
test_inputs = [
torch.randn(3, 4),
torch.randn(8, 4),
]
reference_outputs = []
with torch.no_grad():
for inp in test_inputs:
reference_outputs.append(model(inp))

batch_dim = Dim("batch", min=1, max=20)
dynamic_shapes = {"x": {0: batch_dim}}
exported_program = torch.export.export(
model, (test_inputs[0],), dynamic_shapes=dynamic_shapes
)

edge_program = to_edge(
exported_program,
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)
Comment on lines +2476 to +2483
et_program = edge_program.to_executorch()
program_buffer = et_program.buffer
et_module = _load_for_executorch_from_buffer(program_buffer)
for inp, expected in zip(test_inputs, reference_outputs):
et_output = et_module.forward([inp])
et_result = et_output[0]
self.assertTrue(expected.shape == et_result.shape)
self.assertTrue(torch.allclose(expected, et_result))

def test_emit_channels_last_constant(self) -> None:
"""Test that channels-last constant tensors are emitted correctly.

Expand Down
6 changes: 6 additions & 0 deletions exir/passes/executorch_prim_ops_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def sym_float(a: _SymScalar) -> _SymScalar:
return float(a) # pyre-ignore


@bind_pattern_to_op(executorch_prims_lib, "sym_int.Scalar(Scalar a) -> Scalar")
def sym_int(a: _SymScalar) -> _SymScalar:
return int(a) # pyre-ignore


# TODO: ideally we should return SymBool in the schema, but it seems
# the schema parser does not recognize SymBool yet: P629748075
@bind_pattern_to_op(executorch_prims_lib, "gt.Scalar(Scalar a, Scalar b) -> bool")
Expand Down Expand Up @@ -146,6 +151,7 @@ def sym_not(a: _SymScalar) -> bool:
operator.mod: ops.backend.executorch_prim.mod.Scalar,
operator.neg: ops.backend.executorch_prim.neg.Scalar,
torch.sym_float: ops.backend.executorch_prim.sym_float.Scalar,
torch.sym_int: ops.backend.executorch_prim.sym_int.Scalar,
torch.sym_max: ops.backend.executorch_prim.sym_max.Scalar,
torch.sym_min: ops.backend.executorch_prim.sym_min.Scalar,
torch.sym_not: ops.backend.executorch_prim.sym_not.Scalar,
Expand Down
34 changes: 34 additions & 0 deletions kernels/prim_ops/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,40 @@ static Kernel prim_ops[] = {
}),
#endif

#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \
defined(INCLUDE_EXECUTORCH_PRIM_SYM_INT_SCALAR)
// executorch_prim::sym_int.Scalar(Scalar) -> Scalar
Kernel(
"executorch_prim::sym_int.Scalar",
[](KernelRuntimeContext& context, Span<EValue*> stack) {
ET_KERNEL_CHECK_MSG(
context,
stack.size() == 2,
InvalidProgram,
/* void */,
"Expected %zu args, got %zu",
(size_t)2,
stack.size());
EValue& a = *stack[0];
EValue& out = *stack[1];
if (a.isInt()) {
out = EValue(a.toInt());
} else if (a.isDouble()) {
out = EValue(static_cast<int64_t>(a.toDouble()));
} else if (a.isBool()) {
Comment on lines +421 to +423
out = EValue(static_cast<int64_t>(a.toBool()));
} else {
ET_KERNEL_CHECK_MSG(
context,
false,
InvalidType,
/* void */,
"sym_int only supports int, double, or bool inputs, got %zu",
(size_t)a.tag);
}
}),
#endif

#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \
defined(INCLUDE_EXECUTORCH_PRIM_EQ_SCALAR)
// executorch_prim::eq.Scalar(Scalar, Scalar) -> bool
Expand Down
37 changes: 37 additions & 0 deletions kernels/prim_ops/test/prim_ops_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ TEST_F(RegisterPrimOpsTest, OpRegistered) {
EXPECT_TRUE(hasOpsFn("executorch_prim::sym_max.Scalar"));
EXPECT_TRUE(hasOpsFn("executorch_prim::sym_min.Scalar"));
EXPECT_TRUE(hasOpsFn("executorch_prim::sym_not.Scalar"));
EXPECT_TRUE(hasOpsFn("executorch_prim::sym_int.Scalar"));
}

TEST_F(RegisterPrimOpsTest, SymSizeReturnsCorrectValue) {
Expand Down Expand Up @@ -204,6 +205,42 @@ TEST_F(RegisterPrimOpsTest, SymFloatHandlesBoolInput) {
EXPECT_FLOAT_EQ(stack[1]->toDouble(), 0.0);
}

TEST_F(RegisterPrimOpsTest, SymIntReturnsCorrectValue) {
EValue values[2];
EValue* stack[2];
for (size_t i = 0; i < 2; i++) {
stack[i] = &values[i];
}

// Int passthrough
values[0] = EValue((int64_t)7);
values[1] = EValue((int64_t)0);
getOpsFn("executorch_prim::sym_int.Scalar")(context_, Span<EValue*>(stack));
EXPECT_EQ(stack[1]->toInt(), 7);

// Double truncates toward zero
values[0] = EValue(3.7);
values[1] = EValue((int64_t)0);
getOpsFn("executorch_prim::sym_int.Scalar")(context_, Span<EValue*>(stack));
EXPECT_EQ(stack[1]->toInt(), 3);

values[0] = EValue(-2.9);
values[1] = EValue((int64_t)0);
getOpsFn("executorch_prim::sym_int.Scalar")(context_, Span<EValue*>(stack));
EXPECT_EQ(stack[1]->toInt(), -2);

// Bool converts to 0/1
values[0] = EValue(true);
values[1] = EValue((int64_t)0);
getOpsFn("executorch_prim::sym_int.Scalar")(context_, Span<EValue*>(stack));
EXPECT_EQ(stack[1]->toInt(), 1);

values[0] = EValue(false);
values[1] = EValue((int64_t)0);
getOpsFn("executorch_prim::sym_int.Scalar")(context_, Span<EValue*>(stack));
EXPECT_EQ(stack[1]->toInt(), 0);
}

TEST_F(RegisterPrimOpsTest, TestAlgebraOps) {
EValue values[3];
int64_t a = 3;
Expand Down
Loading