Skip to content

Commit 90641b3

Browse files
committed
add examples
1 parent 82b074d commit 90641b3

2 files changed

Lines changed: 103 additions & 0 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ pip3 install dltype
2828
## Usage
2929

3030
Type hints are evaluated in a context in source-code order, so any references to dimension symbols must exist before an expression is evaluated.
31+
Run [./examples.py](./examples.py) `uv run examples.py` to see some basic usage patterns.
3132

3233
## Supported syntax
3334

examples.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""Example usages."""
2+
3+
from contextlib import contextmanager
4+
from typing import TYPE_CHECKING, Annotated
5+
6+
import numpy as np
7+
8+
import dltype
9+
10+
if TYPE_CHECKING:
11+
from collections.abc import Iterator
12+
13+
14+
@contextmanager
15+
def _hide_internal_dltype_stacktrace(name: str) -> Iterator[None]:
16+
try:
17+
yield
18+
msg = "Expected block to raise"
19+
raise RuntimeError(msg)
20+
except dltype.DLTypeError as e:
21+
print(f"{name}: {e.__class__.__name__}: {e}") # noqa: T201
22+
23+
24+
"""
25+
Basic usage.
26+
"""
27+
28+
29+
@dltype.dltyped()
30+
def cat_1d(
31+
arr1: Annotated[np.ndarray, dltype.FloatTensor["len1"]],
32+
arr2: Annotated[np.ndarray, dltype.FloatTensor["len2"]],
33+
) -> Annotated[np.ndarray, dltype.FloatTensor["len1+len2"]]:
34+
"""Concatenate 2 arrays together on the first axis."""
35+
return np.concatenate((arr1, arr2), axis=0)
36+
37+
38+
@dltype.dltyped()
39+
def fixed_size_crop(
40+
arr1: Annotated[np.ndarray, dltype.FloatTensor["batch channels=3 height width"]],
41+
) -> Annotated[np.ndarray, dltype.FloatTensor["batch channels min(768,height) min(1024,width)"]]:
42+
"""Crop the top 1024x768 pixels."""
43+
return arr1[..., :768, :1024]
44+
45+
46+
@dltype.dltyped()
47+
def warning_for_missing_annotation(
48+
# >>> UserWarning: [no_annotation] is missing a DLType hint
49+
no_annotation: np.ndarray,
50+
) -> Annotated[np.ndarray, dltype.FloatTensor["batch channels w h"]]:
51+
"""Crop the top 1024x768 pixels."""
52+
return no_annotation
53+
54+
55+
B = dltype.VariableAxis("batch")
56+
C = dltype.ConstantAxis("channels", 3)
57+
W = dltype.VariableAxis("width")
58+
H = dltype.VariableAxis("height")
59+
N = dltype.AnonymousAxis("ndims")
60+
61+
# Saving an annotation as a type alias for later use
62+
ImgShape = dltype.Shape[B, C, W, H]
63+
Uint8Img = dltype.UInt8Tensor[ImgShape]
64+
NPImgArr = Annotated[np.ndarray, Uint8Img]
65+
66+
67+
@dltype.dltyped()
68+
def static_shape_stack(
69+
arr: Annotated[np.ndarray, dltype.IntTensor[dltype.Shape[B, C, N]]],
70+
# note the B*2, resolves to 2x the input batch dimension
71+
) -> Annotated[np.ndarray, dltype.IntTensor[dltype.Shape[B * 2, C, N]]]:
72+
"""
73+
Stack an array on top of itself.
74+
75+
Examples of using statically defined shapes.
76+
Static analyzers will catch invalid shape expressions.
77+
In addition to built in operators we also support ISQRT, min, and max (imported through dltype, not the python builtin).
78+
"""
79+
return np.concatenate((arr, arr), axis=0)
80+
81+
82+
if __name__ == "__main__":
83+
assert cat_1d(np.zeros((1)), np.ones((2))).shape == (3,)
84+
85+
with _hide_internal_dltype_stacktrace("bad dims"):
86+
# >>> DLTypeNDimsError: Invalid number of dimensions, tensor=arr2 expected ndims=1 actual=2
87+
cat_1d(np.zeros((1,)), np.zeros((1, 2)))
88+
89+
with _hide_internal_dltype_stacktrace("bad dtype"):
90+
# >>> DLTypeDtypeError: Invalid dtype, tensor=arr1 expected one of (...supported float types) got=int32
91+
cat_1d(np.zeros((1,), dtype=np.int32), np.zeros((1,)))
92+
93+
img = np.zeros((1, 3, 800, 2048))
94+
fixed_size_crop(img)
95+
96+
with _hide_internal_dltype_stacktrace("bad channels"):
97+
# >>> DLTypeShapeError: Invalid tensor shape, tensor=arr1 dim=1 expected=3 actual=1
98+
fixed_size_crop(img[:, :1, ...])
99+
100+
fixed_size_crop(img[..., :100])
101+
102+
static_shape_stack(np.zeros((1, 3, 5, 5, 9), dtype=np.int32))

0 commit comments

Comments
 (0)