|
31 | 31 | partition, |
32 | 32 | setdiff1d, |
33 | 33 | sinc, |
| 34 | + union1d, |
34 | 35 | ) |
35 | 36 | from array_api_extra._lib._backends import NUMPY_VERSION, Backend |
36 | 37 | from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal |
@@ -1637,3 +1638,36 @@ def test_kind(self, xp: ModuleType, library: Backend): |
1637 | 1638 | expected = xp.asarray([False, True, False, True]) |
1638 | 1639 | res = isin(a, b, kind="sort") |
1639 | 1640 | xp_assert_equal(res, expected) |
| 1641 | + |
| 1642 | + |
| 1643 | +@pytest.mark.skip_xp_backend( |
| 1644 | + Backend.ARRAY_API_STRICTEST, |
| 1645 | + reason="data_dependent_shapes flag for unique_values is disabled", |
| 1646 | +) |
| 1647 | +class TestUnion1d: |
| 1648 | + def test_simple(self, xp: ModuleType): |
| 1649 | + a = xp.asarray([-1, 1, 0]) |
| 1650 | + b = xp.asarray([2, -2, 0]) |
| 1651 | + expected = xp.asarray([-2, -1, 0, 1, 2]) |
| 1652 | + res = union1d(a, b) |
| 1653 | + xp_assert_equal(res, expected) |
| 1654 | + |
| 1655 | + def test_2d(self, xp: ModuleType): |
| 1656 | + a = xp.asarray([[-1, 1, 0], [1, 2, 0]]) |
| 1657 | + b = xp.asarray([[1, 0, 1], [-2, -1, 0]]) |
| 1658 | + expected = xp.asarray([-2, -1, 0, 1, 2]) |
| 1659 | + res = union1d(a, b) |
| 1660 | + xp_assert_equal(res, expected) |
| 1661 | + |
| 1662 | + def test_3d(self, xp: ModuleType): |
| 1663 | + a = xp.asarray([[[-1, 0], [1, 2]], [[-1, 0], [1, 2]]]) |
| 1664 | + b = xp.asarray([[[0, 1], [-1, 2]], [[1, -2], [0, 2]]]) |
| 1665 | + expected = xp.asarray([-2, -1, 0, 1, 2]) |
| 1666 | + res = union1d(a, b) |
| 1667 | + xp_assert_equal(res, expected) |
| 1668 | + |
| 1669 | + @pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device") |
| 1670 | + def test_device(self, xp: ModuleType, device: Device): |
| 1671 | + a = xp.asarray([-1, 1, 0], device=device) |
| 1672 | + b = xp.asarray([2, -2, 0], device=device) |
| 1673 | + assert get_device(union1d(a, b)) == device |
0 commit comments