Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lefthook.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ pre-commit:
- name: mypy
glob: "*.{py,pyi}"
run: pixi {run} mypy
- name: pyrefly
glob: "*.{py, pyi}"
run: pixi {run} pyrefly
- name: typos
stage_fixed: true
run: pixi {run} typos
Expand Down
61 changes: 61 additions & 0 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

43 changes: 41 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ actionlint = ">=1.7.12,<2"
blacken-docs = ">=1.20.0,<2"
pytest = ">=9.0.2,<10"
validate-pyproject = ">=0.25,<0.26"
pyrefly = ">=0.61.1,<0.62"
# NOTE: don't add cupy, jax, pytorch, or sparse here,
# as they slow down mypy and are not portable across target OSs

Expand All @@ -117,6 +118,7 @@ hooks = { cmd = "lefthook install", description = "Install pre-commit hooks" }
pre-commit = { cmd = "lefthook run pre-commit", description = "Run pre-commit checks" }
pylint = { cmd = "pylint array_api_extra", cwd = "src", description = "Lint with pylint" }
mypy = { cmd = "mypy", description = "Type check with mypy" }
pyrefly = { cmd = "pyrefly check", description = "Type check with pyrefly" }
pyright = { cmd = "basedpyright", description = "Type check with basedpyright" }
ruff-check = { cmd = "ruff check --fix", description = "Lint with ruff" }
ruff-format = { cmd = "ruff format", description = "Format with ruff" }
Expand Down Expand Up @@ -257,7 +259,7 @@ run.source = ["array_api_extra"]
# mypy

[tool.mypy]
files = ["src", "tests"]
files = ["src", "tests", "vendor_tests"]
python_version = "3.11"
warn_unused_configs = true
strict = true
Expand All @@ -273,10 +275,46 @@ ignore_missing_imports = true
module = ["tests/*"]
disable_error_code = ["no-untyped-def"] # test(...) without -> None

[[tool.mypy.overrides]]
module = ["vendor_tests/*"]
disable_error_code = ["no-untyped-def"] # test(...) without -> None

[[tool.mypy.overrides]]
module = ["vendor_tests/array_api_compat/*"]
ignore_errors = true

# pyrefly

[tool.pyrefly.errors]
# Redundant with mypy checks
missing-import = false
# extra checks from scipy/scipy-stubs
implicit-abstract-class = "error"
implicitly-defined-attribute = "error"
missing-override-decorator = "error"
missing-source = "ignore"
not-required-key-access = "error"
open-unpacking = "error"
unannotated-attribute = "error"
unannotated-parameter = "error"
unannotated-return = "error"
untyped-import = "error"
unused-ignore = "error"
variance-mismatch = "error"

[[tool.pyrefly.sub-config]]
matches = "tests/*.py"
errors = { unannotated-return = false }

[[tool.pyrefly.sub-config]]
matches = "vendor_tests/*.py"
errors = { unannotated-return = false }

# pyright

[tool.basedpyright]
include = ["src", "tests"]
include = ["src", "tests", "vendor_tests"]
exclude = ["vendor_tests/array_api_compat"]
pythonVersion = "3.11"
pythonPlatform = "All"
typeCheckingMode = "all"
Expand All @@ -302,6 +340,7 @@ reportUnknownLambdaType = false

executionEnvironments = [
{ root = "tests", reportPrivateUsage = false, reportUnknownArgumentType = false },
{ root = "vendor_tests", reportPrivateUsage = false, reportUnknownArgumentType = false },
{ root = "src" },
]

Expand Down
2 changes: 1 addition & 1 deletion src/array_api_extra/_lib/_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class _AtOp(Enum):
MAX = "max"

# @override from Python 3.12
def __str__(self) -> str: # pyright: ignore[reportImplicitOverride]
def __str__(self) -> str: # pyright: ignore[reportImplicitOverride] # pyrefly: ignore[missing-override-decorator]
"""
Return string representation (useful for pytest logs).
Expand Down
8 changes: 4 additions & 4 deletions src/array_api_extra/_lib/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
P = ParamSpec("P")


@overload
@overload # pyrefly: ignore[invalid-param-spec]
Comment thread
lucascolley marked this conversation as resolved.
def lazy_apply( # type: ignore[valid-type]
func: Callable[P, Array | ArrayLike],
*args: Array | complex | None,
Expand All @@ -42,7 +42,7 @@ def lazy_apply( # type: ignore[valid-type]
) -> Array: ... # numpydoc ignore=GL08


@overload
@overload # pyrefly: ignore[invalid-param-spec]
def lazy_apply( # type: ignore[valid-type]
func: Callable[P, Sequence[Array | ArrayLike]],
*args: Array | complex | None,
Expand All @@ -54,7 +54,7 @@ def lazy_apply( # type: ignore[valid-type]
) -> tuple[Array, ...]: ... # numpydoc ignore=GL08


def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
def lazy_apply( # type: ignore[valid-type] # pyrefly: ignore[invalid-param-spec] # numpydoc ignore=GL07,SA04
func: Callable[P, Array | ArrayLike | Sequence[Array | ArrayLike]],
*args: Array | complex | None,
shape: tuple[int | None, ...] | Sequence[tuple[int | None, ...]] | None = None,
Expand Down Expand Up @@ -240,7 +240,7 @@ def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
if is_dask_namespace(xp):
import dask

metas: list[Array] = [arg._meta for arg in array_args] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]
metas: list[Array] = [arg._meta for arg in array_args] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue] # pyrefly: ignore[missing-attribute]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to avoid long lines, then this should also do the trick I believe:

Suggested change
metas: list[Array] = [arg._meta for arg in array_args] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue] # pyrefly: ignore[missing-attribute]
# pyrefly: ignore[missing-attribute]
metas: list[Array] = [arg._meta for arg in array_args] # pylint: disable=protected-access # pyright: ignore[reportAttributeAccessIssue]

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could do, but tbh I don't mind long lines as long as they are just comments to the machine

meta_xp = array_namespace(*metas)

wrapped = dask.delayed( # type: ignore[attr-defined] # pyright: ignore[reportPrivateImportUsage]
Expand Down
4 changes: 2 additions & 2 deletions src/array_api_extra/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def test_myfunc(xp):
f = func

try:
f._lazy_xp_function = tags # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess]
f._lazy_xp_function = tags # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess] # pyrefly: ignore[missing-attribute]
except AttributeError: # @cython.vectorize
_ufuncs_tags[f] = tags

Expand Down Expand Up @@ -461,7 +461,7 @@ class CountingDaskScheduler(SchedulerGetCallable):
max_count: int
msg: str

def __init__(self, max_count: int, msg: str): # numpydoc ignore=GL08
def __init__(self, max_count: int, msg: str) -> None: # numpydoc ignore=GL08
self.count = 0
self.max_count = max_count
self.msg = msg
Expand Down
5 changes: 4 additions & 1 deletion tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,10 @@ def test_lazy_apply_none_shape_in_args(xp: ModuleType, library: Backend):
mxp = np if library is Backend.DASK else xp
int_type = xp.asarray(0).dtype

ctx: contextlib.AbstractContextManager[object]
ctx: (
contextlib.AbstractContextManager[object]
| contextlib.AbstractContextManager[None]
)
Comment on lines -222 to +225
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pyrefly is not happy for contextlib.nullcontext() to fall under contextlib.AbstractContextManager[object]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds like a bug in Pyrefly, because the first typar of AbstractContextManager is supposed to be covariant: https://github.com/python/typeshed/blob/eec9fe9aa9ce87df8987de5c1401f743e179378a/stdlib/contextlib.pyi#L48

Do you want to report it or should I?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Go for it if you're happy to!

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if library.like(Backend.JAX):
ctx = pytest.raises(ValueError, match="Output shape must be fully known")
elif library is Backend.ARRAY_API_STRICTEST:
Expand Down
7 changes: 5 additions & 2 deletions vendor_tests/_array_api_compat_vendor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""This file is a hook imported by `src/array_api_extra/_lib/_compat.py`."""
# pyright: reportUnknownParameterType=false, reportMissingParameterType=false

from .array_api_compat import * # noqa: F403
from types import ModuleType
from typing import Any

from .array_api_compat import * # type: ignore[import-not-found] # noqa: F403
from .array_api_compat import array_namespace as array_namespace_compat


# Let unit tests check with `is` that we are picking up the function from this module
# and not from the original array_api_compat module.
def array_namespace(*xs, **kwargs): # numpydoc ignore=GL08
def array_namespace(*xs: Any | complex | None, **kwargs) -> ModuleType: # pyrefly: ignore[unannotated-parameter] # numpydoc ignore=GL08
return array_namespace_compat(*xs, **kwargs)
20 changes: 15 additions & 5 deletions vendor_tests/test_vendor.py
Comment thread
lucascolley marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# pyright: reportAttributeAccessIssue=false

from typing import Any
from typing import Any, cast

import array_api_strict as xp
from numpy.testing import assert_array_equal

from vendor_tests.array_api_compat.common._typing import ( # type: ignore[import-not-found]
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should be cloning the vendored libraries into place for the static analysis — not sure

Array,
)


def test_vendor_compat():
from ._array_api_compat_vendor import ( # type: ignore[attr-defined]
Expand Down Expand Up @@ -35,6 +39,7 @@ def test_vendor_compat():
to_device(x, device(x))
assert is_array_api_obj(x)
assert is_array_api_strict_namespace(xp)
x = cast(Array, x)
Comment on lines 40 to +42
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

otherwise x stays narrowed by TypeGuard[_ArrayApiObj], where:

Argument type is partially unknown
  Argument corresponds to parameter "x" in function "is_cupy_array"
  Argument type is "ndarray[tuple[Any, ...], dtype[Any]] | Unknown | Array | SupportsArrayNamespace[Any]

assert not is_cupy_array(x)
assert not is_cupy_namespace(xp)
assert not is_dask_array(x)
Expand All @@ -53,15 +58,18 @@ def test_vendor_compat():


def test_vendor_extra():
from .array_api_extra import atleast_nd
from .array_api_extra import atleast_nd # type: ignore[import-not-found]

x = xp.asarray(1)
x = cast(Array, x)
y = atleast_nd(x, ndim=0)
assert_array_equal(y, x) # pyright: ignore[reportUnknownArgumentType]
assert_array_equal(y, x)


def test_vendor_extra_testing():
from .array_api_extra.testing import lazy_xp_function
from .array_api_extra.testing import ( # type: ignore[import-not-found]
lazy_xp_function,
)

def f(x: Any) -> Any:
return x
Expand All @@ -71,6 +79,8 @@ def f(x: Any) -> Any:

def test_vendor_extra_uses_vendor_compat():
from ._array_api_compat_vendor import array_namespace as n1
from .array_api_extra._lib._utils._compat import array_namespace as n2
from .array_api_extra._lib._utils._compat import ( # type: ignore[import-not-found]
array_namespace as n2,
)

assert n1 is n2