Skip to content

Commit 06ab0fc

Browse files
dlangermdlangerm-stackav
authored andcommitted
add examples
1 parent 82b074d commit 06ab0fc

3 files changed

Lines changed: 107 additions & 0 deletions

File tree

.github/workflows/pr.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,7 @@ jobs:
7474
- name: Run benchmarks
7575
run: |
7676
uv run --frozen benchmark.py
77+
78+
- name: Run examples
79+
run: |
80+
uv run --frozen examples.py

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ pip3 install dltype
2828
## Usage
2929

3030
Type hints are evaluated in a context in source-code order, so any references to dimension symbols must exist before an expression is evaluated.
31+
Run [./examples.py](./examples.py) `uv run examples.py` to see some basic usage patterns.
3132

3233
## Supported syntax
3334

examples.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""Example usages."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Iterator # noqa: TC003
6+
from contextlib import contextmanager
7+
from typing import Annotated
8+
9+
import numpy as np
10+
11+
import dltype
12+
13+
14+
@contextmanager
15+
def _hide_internal_dltype_stacktrace(name: str) -> Iterator[None]:
16+
try:
17+
yield
18+
msg = "Expected block to raise"
19+
raise RuntimeError(msg)
20+
except dltype.DLTypeError as e:
21+
print(f"{name}: {e.__class__.__name__}: {e}") # noqa: T201
22+
23+
24+
"""
25+
Basic usage.
26+
"""
27+
28+
29+
@dltype.dltyped()
30+
def cat_1d(
31+
arr1: Annotated[np.ndarray, dltype.FloatTensor["len1"]],
32+
arr2: Annotated[np.ndarray, dltype.FloatTensor["len2"]],
33+
) -> Annotated[np.ndarray, dltype.FloatTensor["len1+len2"]]:
34+
"""Concatenate 2 arrays together on the first axis."""
35+
return np.concatenate((arr1, arr2), axis=0)
36+
37+
38+
@dltype.dltyped()
39+
def fixed_size_crop(
40+
arr1: Annotated[np.ndarray, dltype.FloatTensor["batch channels=3 height width"]],
41+
) -> Annotated[np.ndarray, dltype.FloatTensor["batch channels min(768,height) min(1024,width)"]]:
42+
"""Crop the top 1024x768 pixels."""
43+
return arr1[..., :768, :1024]
44+
45+
46+
@dltype.dltyped()
47+
def warning_for_missing_annotation(
48+
# >>> UserWarning: [no_annotation] is missing a DLType hint
49+
no_annotation: np.ndarray,
50+
) -> Annotated[np.ndarray, dltype.FloatTensor["batch channels w h"]]:
51+
"""Crop the top 1024x768 pixels."""
52+
return no_annotation
53+
54+
55+
B = dltype.VariableAxis("batch")
56+
C = dltype.ConstantAxis("channels", 3)
57+
W = dltype.VariableAxis("width")
58+
H = dltype.VariableAxis("height")
59+
N = dltype.AnonymousAxis("ndims")
60+
61+
# Saving an annotation as a type alias for later use
62+
ImgShape = dltype.Shape[B, C, W, H]
63+
Uint8Img = dltype.UInt8Tensor[ImgShape]
64+
NPImgArr = Annotated[np.ndarray, Uint8Img]
65+
66+
67+
@dltype.dltyped()
68+
def static_shape_stack(
69+
arr: Annotated[np.ndarray, dltype.IntTensor[dltype.Shape[B, C, N]]],
70+
# note the B*2, resolves to 2x the input batch dimension
71+
) -> Annotated[np.ndarray, dltype.IntTensor[dltype.Shape[B * 2, C, N]]]:
72+
"""
73+
Stack an array on top of itself.
74+
75+
Examples of using statically defined shapes.
76+
Static analyzers will catch invalid shape expressions.
77+
In addition to built in operators we also support ISQRT, min, and max (imported through dltype, not the python builtin).
78+
"""
79+
return np.concatenate((arr, arr), axis=0)
80+
81+
82+
if __name__ == "__main__":
83+
assert cat_1d(np.zeros((1)), np.ones((2))).shape == (3,)
84+
85+
with _hide_internal_dltype_stacktrace("bad dims"):
86+
# >>> DLTypeNDimsError: Invalid number of dimensions, tensor=arr2 expected ndims=1 actual=2
87+
cat_1d(np.zeros((1,)), np.zeros((1, 2)))
88+
89+
with _hide_internal_dltype_stacktrace("bad dtype"):
90+
# >>> DLTypeDtypeError: Invalid dtype, tensor=arr1 expected one of (...supported float types) got=int32
91+
cat_1d(np.zeros((1,), dtype=np.int32), np.zeros((1,)))
92+
93+
img = np.zeros((1, 3, 800, 2048))
94+
fixed_size_crop(img)
95+
96+
with _hide_internal_dltype_stacktrace("bad channels"):
97+
# >>> DLTypeShapeError: Invalid tensor shape, tensor=arr1 dim=1 expected=3 actual=1
98+
fixed_size_crop(img[:, :1, ...])
99+
100+
fixed_size_crop(img[..., :100])
101+
102+
static_shape_stack(np.zeros((1, 3, 5, 5, 9), dtype=np.int32))

0 commit comments

Comments
 (0)