Skip to content

Commit 0de9b35

Browse files
dlangerm-stackdlangerm-stackav
authored andcommitted
fix non-class arguments breaking core mro check
1 parent 964cec5 commit 0de9b35

4 files changed

Lines changed: 37 additions & 9 deletions

File tree

dltype/_lib/_core.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import itertools
77
import warnings
88
from functools import lru_cache, wraps
9+
from types import EllipsisType
910
from typing import (
1011
TYPE_CHECKING,
1112
Annotated,
@@ -51,16 +52,22 @@ class DLTypeAnnotation(NamedTuple):
5152
dltype_annotation: _tensor_type_base.TensorTypeBase | None
5253

5354
@classmethod
54-
def from_hint(
55+
def from_hint( # noqa: PLR0911
5556
cls,
56-
hint: type | None,
57+
hint: type | EllipsisType | None,
5758
name: str,
5859
*,
5960
optional: bool = False,
61+
stack_offset: int = 0,
6062
) -> tuple[DLTypeAnnotation | None, ...]:
6163
"""Create a new _DLTypeAnnotation from a type hint."""
64+
if isinstance(hint, EllipsisType):
65+
return (None,)
66+
6267
if hint is None:
63-
warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4)
68+
warnings.warn(
69+
f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4 + stack_offset
70+
)
6471
return (None,)
6572

6673
_logger.debug("Creating DLType from hint %r", hint)
@@ -83,20 +90,28 @@ def from_hint(
8390

8491
# tuple handling special case
8592
if origin is tuple:
86-
return tuple(itertools.chain(*[cls.from_hint(inner_hint, name) for inner_hint in args]))
93+
return tuple(
94+
itertools.chain(
95+
*[cls.from_hint(inner_hint, name, stack_offset=stack_offset + 1) for inner_hint in args]
96+
)
97+
)
8798

8899
# Only process Annotated types, warn if the annotated type is a tensor
89100
if origin is not Annotated:
90101
if any(T in hint.mro() for T in _dtypes.SUPPORTED_TENSOR_TYPES) if hint else False:
91-
warnings.warn(f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4)
102+
warnings.warn(
103+
f"[{name}] is missing a DLType hint", category=UserWarning, stacklevel=4 + stack_offset
104+
)
92105
return (None,)
93106

94107
# Ensure the annotation is a TensorTypeBase
95108
if len(args) < n_expected_args or not isinstance(
96109
args[1],
97110
_tensor_type_base.TensorTypeBase,
98111
):
99-
warnings.warn(f"[{name}] has an invalid DLType hint", category=UserWarning, stacklevel=4)
112+
warnings.warn(
113+
f"[{name}] has an invalid DLType hint", category=UserWarning, stacklevel=4 + stack_offset
114+
)
100115
return (None,)
101116

102117
# Ensure the base type is a supported tensor type
@@ -130,13 +145,14 @@ def get_dltype_scope(self) -> _dltype_context.EvaluatedDimensionT:
130145
def _maybe_get_type_hints(
131146
existing_hints: dict[str, tuple[DLTypeAnnotation | None, ...]] | None,
132147
func: Callable[P, R],
148+
stack_offset: int = 0,
133149
) -> dict[str, tuple[DLTypeAnnotation | None, ...]] | None:
134150
"""Get the type hints for a function, or return an empty dict if not available."""
135151
if existing_hints is not None:
136152
return existing_hints
137153
try:
138154
return {
139-
name: DLTypeAnnotation.from_hint(hint, name)
155+
name: DLTypeAnnotation.from_hint(hint, name, stack_offset=stack_offset)
140156
for name, hint in get_type_hints(func, include_extras=True).items()
141157
}
142158
except NameError:

dltype/tests/dltype_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,3 +1694,15 @@ class CheckedNT(NamedTuple):
16941694
checked(bad_arr)
16951695
Checked(arg=bad_arr)
16961696
CheckedNT(arg=bad_arr)
1697+
1698+
1699+
def test_tuple_ellipsis() -> None:
1700+
1701+
with pytest.warns(UserWarning, match="is missing a DLType hint"):
1702+
1703+
@dltype.dltyped()
1704+
def tuple_function( # pyright: ignore[reportUnusedFunction]
1705+
tensor: tuple[torch.Tensor, ...],
1706+
tensor1: tuple[Annotated[torch.Tensor, dltype.FloatTensor["1 2 3"]]],
1707+
) -> None:
1708+
"""A function that takes a tensor and returns a tensor."""

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ license-files = ["LICENSE"]
3434
name = "dltype"
3535
readme = "README.md"
3636
requires-python = ">=3.10"
37-
version = "0.13.0"
37+
version = "0.13.1"
3838

3939
[project.optional-dependencies]
4040
jax = ["jax>=0.6.2"]

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)