Sym ite prim op#20681
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20681
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New FailuresAs of commit bba0931 with merge base 71a80d7 ( NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
There was a problem hiding this comment.
Pull request overview
This PR adds ExecuTorch support for the torch.sym_ite symbolic conditional by registering a new executorch_prim::sym_ite.Scalar prim op, wiring it into the EXIR prim-op registry, and adding C++/Python coverage to validate both registration and end-to-end emission/runtime behavior.
Changes:
- Register a new prim op kernel:
executorch_prim::sym_ite.Scalar(bool b, Scalar t, Scalar f) -> Scalarin the prim ops registry. - Extend prim-ops C++ tests to validate
sym_iteregistration and runtime selection across int/double/bool. - Bind
torch.sym_iteto the ExecuTorch prim op and add an emit test exercising dynamic shape behavior.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| kernels/prim_ops/test/prim_ops_test.cpp | Adds registration + correctness tests for executorch_prim::sym_ite.Scalar. |
| kernels/prim_ops/register_prim_ops.cpp | Registers the sym_ite prim op kernel implementation. |
| exir/passes/executorch_prim_ops_registry.py | Binds the sym_ite pattern and maps torch.sym_ite to the backend op overload. |
| exir/emit/test/test_emit.py | Adds an end-to-end emit/runtime test that uses torch.sym_ite with dynamic shapes. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| exported_program = torch.export.export( | ||
| model, (test_inputs[0],), dynamic_shapes=dynamic_shapes | ||
| ) | ||
|
|
||
| edge_program = to_edge( | ||
| exported_program, | ||
| compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), | ||
| ) |
Implements the ternary sym_ite(condition, true_val, false_val) op, needed by torch.export with Dim.AUTO when models contain conditional symbolic shape logic.
torch.sym_ite enforces type(t) == type(f). The original test passed a SymInt (n) and a plain int (6), which fails the type check. Use a second dynamic shape dimension (m) so both branches are SymInt.
| torch.randn(3, 4), | ||
| torch.randn(8, 4), | ||
| ] | ||
| ] |
| @bind_pattern_to_op( | ||
| executorch_prims_lib, "sym_ite.Scalar(Scalar b, Scalar t, Scalar f) -> Scalar" | ||
| ) | ||
| def sym_ite(b: _SymScalar, t: _SymScalar, f: _SymScalar) -> _SymScalar: | ||
| return t if b else f # pyre-ignore |
Adds Python op registration, mapping entry, and C++ kernel handling with tests