66import itertools
77import warnings
88from functools import lru_cache , wraps
9+ from types import EllipsisType
910from typing import (
1011 TYPE_CHECKING ,
1112 Annotated ,
@@ -51,16 +52,22 @@ class DLTypeAnnotation(NamedTuple):
5152 dltype_annotation : _tensor_type_base .TensorTypeBase | None
5253
5354 @classmethod
54- def from_hint (
55+ def from_hint ( # noqa: PLR0911
5556 cls ,
56- hint : type | None ,
57+ hint : type | EllipsisType | None ,
5758 name : str ,
5859 * ,
5960 optional : bool = False ,
61+ stack_offset : int = 0 ,
6062 ) -> tuple [DLTypeAnnotation | None , ...]:
6163 """Create a new _DLTypeAnnotation from a type hint."""
64+ if isinstance (hint , EllipsisType ):
65+ return (None ,)
66+
6267 if hint is None :
63- warnings .warn (f"[{ name } ] is missing a DLType hint" , category = UserWarning , stacklevel = 4 )
68+ warnings .warn (
69+ f"[{ name } ] is missing a DLType hint" , category = UserWarning , stacklevel = 4 + stack_offset
70+ )
6471 return (None ,)
6572
6673 _logger .debug ("Creating DLType from hint %r" , hint )
@@ -83,20 +90,28 @@ def from_hint(
8390
8491 # tuple handling special case
8592 if origin is tuple :
86- return tuple (itertools .chain (* [cls .from_hint (inner_hint , name ) for inner_hint in args ]))
93+ return tuple (
94+ itertools .chain (
95+ * [cls .from_hint (inner_hint , name , stack_offset = stack_offset + 1 ) for inner_hint in args ]
96+ )
97+ )
8798
8899 # Only process Annotated types, warn if the annotated type is a tensor
89100 if origin is not Annotated :
90101 if any (T in hint .mro () for T in _dtypes .SUPPORTED_TENSOR_TYPES ) if hint else False :
91- warnings .warn (f"[{ name } ] is missing a DLType hint" , category = UserWarning , stacklevel = 4 )
102+ warnings .warn (
103+ f"[{ name } ] is missing a DLType hint" , category = UserWarning , stacklevel = 4 + stack_offset
104+ )
92105 return (None ,)
93106
94107 # Ensure the annotation is a TensorTypeBase
95108 if len (args ) < n_expected_args or not isinstance (
96109 args [1 ],
97110 _tensor_type_base .TensorTypeBase ,
98111 ):
99- warnings .warn (f"[{ name } ] has an invalid DLType hint" , category = UserWarning , stacklevel = 4 )
112+ warnings .warn (
113+ f"[{ name } ] has an invalid DLType hint" , category = UserWarning , stacklevel = 4 + stack_offset
114+ )
100115 return (None ,)
101116
102117 # Ensure the base type is a supported tensor type
@@ -130,13 +145,14 @@ def get_dltype_scope(self) -> _dltype_context.EvaluatedDimensionT:
130145def _maybe_get_type_hints (
131146 existing_hints : dict [str , tuple [DLTypeAnnotation | None , ...]] | None ,
132147 func : Callable [P , R ],
148+ stack_offset : int = 0 ,
133149) -> dict [str , tuple [DLTypeAnnotation | None , ...]] | None :
134150 """Get the type hints for a function, or return an empty dict if not available."""
135151 if existing_hints is not None :
136152 return existing_hints
137153 try :
138154 return {
139- name : DLTypeAnnotation .from_hint (hint , name )
155+ name : DLTypeAnnotation .from_hint (hint , name , stack_offset = stack_offset )
140156 for name , hint in get_type_hints (func , include_extras = True ).items ()
141157 }
142158 except NameError :
0 commit comments