Skip to content

Commit c7637bb

Browse files
Replace deprecated jax context manager and JIT-traced non-array arguments (#260)
* Replace deprecated jax.experimental.enable_x64 contextmanager * Avoid None attributes in JIT-wrapped functions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Inline comment * Drop Python 3.10 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * CI test on Python 3.11 * Python 3.11 for readthedocs * Fix linting issues --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 10b635d commit c7637bb

11 files changed

Lines changed: 22 additions & 20 deletions

File tree

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
runs-on: ubuntu-latest
1414
strategy:
1515
matrix:
16-
python-version: ["3.10"]
16+
python-version: ["3.11"]
1717
x64: ["0"]
1818
include:
1919
- python-version: "3.13"

.readthedocs.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ build:
55
apt_packages:
66
- fonts-liberation
77
tools:
8-
python: "3.10"
8+
python: "3.11"
99

1010
python:
1111
install:

pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "tinygp"
33
description = "The tiniest of Gaussian Process libraries"
44
authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }]
55
readme = "README.md"
6-
requires-python = ">=3.10"
6+
requires-python = ">=3.11"
77
license = { text = "MIT" }
88
classifiers = [
99
"Development Status :: 4 - Beta",
@@ -45,11 +45,11 @@ source = "vcs"
4545
version-file = "src/tinygp/tinygp_version.py"
4646

4747
[tool.black]
48-
target-version = ["py39"]
48+
target-version = ["py312"]
4949
line-length = 88
5050

5151
[tool.ruff]
52-
target-version = "py39"
52+
target-version = "py312"
5353
line-length = 88
5454

5555
[tool.ruff.lint]
@@ -60,6 +60,7 @@ ignore = [
6060
"PLR0913", # Allow many arguments to functions
6161
"PLR0915", # Allow many statements
6262
"PLR2004", # Allow magic numbers in comparisons
63+
"B905", # Allow zip() without explicit `strict=` parameter
6364
]
6465
exclude = []
6566

src/tinygp/gp.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
__all__ = ["GaussianProcess"]
44

5-
from collections.abc import Sequence
5+
from collections.abc import Callable, Sequence
66
from functools import partial
77
from typing import (
88
TYPE_CHECKING,
99
Any,
10-
Callable,
1110
NamedTuple,
1211
)
1312

src/tinygp/kernels/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
]
1313

1414
from abc import abstractmethod
15-
from collections.abc import Sequence
16-
from typing import TYPE_CHECKING, Any, Callable, Union
15+
from collections.abc import Callable, Sequence
16+
from typing import TYPE_CHECKING, Any
1717

1818
import equinox as eqx
1919
import jax
@@ -24,7 +24,7 @@
2424
if TYPE_CHECKING:
2525
from tinygp.solvers.solver import Solver
2626

27-
Axis = Union[int, Sequence[int]]
27+
Axis = int | Sequence[int]
2828

2929

3030
class Kernel(eqx.Module):

src/tinygp/means.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
__all__ = ["Mean", "Conditioned"]
1414

1515
from abc import abstractmethod
16-
from typing import Callable
16+
from collections.abc import Callable
1717

1818
import equinox as eqx
1919
import jax
@@ -39,18 +39,18 @@ class Mean(MeanBase):
3939
signature.
4040
"""
4141

42-
value: JAXArray | None = None
42+
value: JAXArray
4343
func: Callable[[JAXArray], JAXArray] | None = eqx.field(default=None, static=True)
4444

4545
def __init__(self, value: JAXArray | Callable[[JAXArray], JAXArray]):
4646
if callable(value):
4747
self.func = value
48+
self.value = jax.numpy.zeros(()) # avoids undefined traced values
4849
else:
4950
self.value = value
5051

5152
def __call__(self, X: JAXArray) -> JAXArray:
52-
if self.value is None:
53-
assert self.func is not None
53+
if self.func is not None:
5454
return self.func(X)
5555
return self.value
5656

src/tinygp/solvers/quasisep/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818

1919
import dataclasses
2020
from abc import abstractmethod
21+
from collections.abc import Callable
2122
from functools import wraps
22-
from typing import Any, Callable
23+
from typing import Any
2324

2425
import equinox as eqx
2526
import jax

src/tinygp/solvers/quasisep/general.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616

1717
__all__ = ["GeneralQSM"]
1818

19+
from collections.abc import Callable
1920
from functools import wraps
20-
from typing import Any, Callable
21+
from typing import Any
2122

2223
import equinox as eqx
2324
import jax

src/tinygp/transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88

99
__all__ = ["Transform", "Linear", "Cholesky", "Subspace"]
1010

11-
from collections.abc import Sequence
11+
from collections.abc import Callable, Sequence
1212
from functools import partial
13-
from typing import Any, Callable
13+
from typing import Any
1414

1515
import equinox as eqx
1616
import jax.numpy as jnp

tests/test_gp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def data(random):
2424
def test_sample(data):
2525
X, _ = data
2626

27-
with jax.experimental.enable_x64(True):
27+
with jax.enable_x64(True):
2828
gp = GaussianProcess(
2929
kernels.Matern32(1.5), X, diag=0.01, mean=lambda x: jnp.sum(x)
3030
)

0 commit comments

Comments
 (0)