From bae5446ad0d37fee88ccb1cdde2fb25b52e0ba3a Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 01:04:14 -0300 Subject: [PATCH 1/7] Add lazy LinOp algebra foundation --- spacecore/__init__.py | 27 +- spacecore/linop/__init__.py | 20 + spacecore/linop/_algebra.py | 568 +++++++++++++++++++++++++++ spacecore/linop/_base.py | 93 +++++ tests/integration/test_public_api.py | 3 + tests/linops/test_algebra_linop.py | 244 ++++++++++++ 6 files changed, 954 insertions(+), 1 deletion(-) create mode 100644 spacecore/linop/_algebra.py create mode 100644 tests/linops/test_algebra_linop.py diff --git a/spacecore/__init__.py b/spacecore/__init__.py index d130500..4371730 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -6,7 +6,23 @@ from .backend import TorchOps as TorchOps except ImportError: pass -from .linop import DenseLinOp, SparseLinOp, BlockDiagonalLinOp, SumToSingleLinOp, StackedLinOp, LinOp +from .linop import ( + BlockDiagonalLinOp, + ComposedLinOp, + DenseLinOp, + IdentityLinOp, + LinOp, + MatrixFreeLinOp, + ScaledLinOp, + SparseLinOp, + StackedLinOp, + SumLinOp, + SumToSingleLinOp, + ZeroLinOp, + make_composed, + make_scaled, + make_sum, +) from .space import ( BackendCheck, DTypeCheck, @@ -42,8 +58,17 @@ "NumpyOps", "LinOp", + "ComposedLinOp", "DenseLinOp", + "IdentityLinOp", + "MatrixFreeLinOp", + "ScaledLinOp", "SparseLinOp", + "SumLinOp", + "ZeroLinOp", + "make_composed", + "make_scaled", + "make_sum", "BlockDiagonalLinOp", "SumToSingleLinOp", "StackedLinOp", diff --git a/spacecore/linop/__init__.py b/spacecore/linop/__init__.py index 7a6f537..f810bb1 100644 --- a/spacecore/linop/__init__.py +++ b/spacecore/linop/__init__.py @@ -1,12 +1,32 @@ from ._base import LinOp +from ._algebra import ( + ComposedLinOp, + IdentityLinOp, + MatrixFreeLinOp, + ScaledLinOp, + SumLinOp, + ZeroLinOp, + make_composed, + make_scaled, + make_sum, +) from ._dense import DenseLinOp from ._sparse import SparseLinOp from .product import ProductLinOp, StackedLinOp, SumToSingleLinOp, BlockDiagonalLinOp __all__ = [ "LinOp", + "ComposedLinOp", "DenseLinOp", + "IdentityLinOp", + "MatrixFreeLinOp", + "ScaledLinOp", "SparseLinOp", + "SumLinOp", + "ZeroLinOp", + "make_composed", + "make_scaled", + "make_sum", "ProductLinOp", "SumToSingleLinOp", "BlockDiagonalLinOp", diff --git a/spacecore/linop/_algebra.py b/spacecore/linop/_algebra.py new file mode 100644 index 0000000..f104c2f --- /dev/null +++ b/spacecore/linop/_algebra.py @@ -0,0 +1,568 @@ +from __future__ import annotations + +from numbers import Number +from typing import Any, Sequence + +from ._base import LinOp, Domain, Codomain +from ..backend import Context, jax_pytree_class + + +def is_scalar_like(value: Any) -> bool: + """Return whether ``value`` can be used as a scalar multiplier for a ``LinOp``.""" + if isinstance(value, Number): + return True + shape = getattr(value, "shape", None) + if shape is not None: + return tuple(shape) == () + ndim = getattr(value, "ndim", None) + return ndim == 0 + + +def _conjugate_scalar(value: Any) -> Any: + if hasattr(value, "conjugate"): + return value.conjugate() + if hasattr(value, "conj"): + return value.conj() + return value + + +def _same_context(left: LinOp, right: LinOp) -> bool: + return ( + left.ctx == right.ctx + and left.ctx.dtype == right.ctx.dtype + and left.ctx.enable_checks == right.ctx.enable_checks + ) + + +def _require_same_context(ops: Sequence[LinOp]) -> Context: + ctx = ops[0].ctx + for i, op in enumerate(ops[1:], start=1): + if not _same_context(ops[0], op): + raise ValueError( + "All LinOp operands in an algebraic expression must have the same ctx; " + f"operand 0 has ctx {ctx!r}, operand {i} has ctx {op.ctx!r}." + ) + return ctx + + +def _require_linop(op: Any, name: str) -> LinOp: + if not isinstance(op, LinOp): + raise TypeError(f"{name} must be a LinOp, got {type(op).__name__}.") + return op + + +def _scalar_equal(value: Any, target: Any) -> bool: + try: + return bool(value == target) + except Exception: + return False + + +def _is_zero_scalar(value: Any) -> bool: + return _scalar_equal(value, 0) + + +def _is_one_scalar(value: Any) -> bool: + return _scalar_equal(value, 1) + + +def _flatten_sum_terms(ops: Sequence[LinOp]) -> tuple[LinOp, ...]: + terms: list[LinOp] = [] + for i, op in enumerate(ops): + op = _require_linop(op, f"ops[{i}]") + if isinstance(op, SumLinOp): + terms.extend(_flatten_sum_terms(op.parts)) + else: + terms.append(op) + return tuple(terms) + + +def make_sum(ops: Sequence[LinOp]) -> LinOp: + """ + Return a locally simplified lazy sum of linear operators. + + This factory performs only local algebraic canonicalization: nested + ``SumLinOp`` nodes are flattened and ``ZeroLinOp`` terms are removed. It + does not collect like terms, reorder operands, or attempt full symbolic + optimization. All operands must have the same context, domain, and codomain + before a simplified operator is returned. + """ + if not ops: + raise ValueError("make_sum requires a nonempty sequence of LinOp operands.") + + terms = _flatten_sum_terms(ops) + ctx = _require_same_context(terms) + domain = terms[0].domain + codomain = terms[0].codomain + for i, op in enumerate(terms[1:], start=1): + if op.domain != domain or op.codomain != codomain: + raise ValueError( + "All SumLinOp operands must have the same domain and codomain; " + f"operand 0 maps {domain!r} -> {codomain!r}, " + f"operand {i} maps {op.domain!r} -> {op.codomain!r}." + ) + + nonzero_terms = tuple(op for op in terms if not isinstance(op, ZeroLinOp)) + if not nonzero_terms: + return ZeroLinOp(domain, codomain, ctx) + if len(nonzero_terms) == 1: + return nonzero_terms[0] + return SumLinOp(nonzero_terms) + + +def make_scaled(scalar: Any, op: LinOp) -> LinOp: + """ + Return a locally simplified scalar multiple of a linear operator. + + This factory performs only local algebraic canonicalization: zero and unit + scalars are simplified, and nested ``ScaledLinOp`` nodes are folded into one + scalar. It does not distribute scaling over sums or perform full symbolic + optimization. Complex scalars retain the usual conjugated coefficient in + ``rapply`` through ``ScaledLinOp``. + """ + op = _require_linop(op, "op") + if not is_scalar_like(scalar): + raise TypeError(f"scalar must be scalar-like, got {type(scalar).__name__}.") + + if _is_zero_scalar(scalar): + return ZeroLinOp(op.domain, op.codomain, op.ctx) + if _is_one_scalar(scalar): + return op + if isinstance(op, ScaledLinOp): + return make_scaled(scalar * op.scalar, op.op) + return ScaledLinOp(scalar, op) + + +def make_composed(left: LinOp, right: LinOp) -> LinOp: + """ + Return a locally simplified composition of two linear operators. + + This factory performs only local algebraic canonicalization: identity + factors are removed and compositions with zero maps become zero maps. It + preserves the binary ``ComposedLinOp`` representation and does not flatten + multi-factor chains or attempt full symbolic optimization. Operands must + have the same context and compatible middle spaces before a simplified + operator is returned. + """ + left = _require_linop(left, "left") + right = _require_linop(right, "right") + _require_same_context((left, right)) + if right.codomain != left.domain: + raise ValueError( + "ComposedLinOp requires right.codomain == left.domain; " + f"got {right.codomain!r} and {left.domain!r}." + ) + + if isinstance(right, IdentityLinOp): + return left + if isinstance(left, IdentityLinOp): + return right + if isinstance(left, ZeroLinOp): + return ZeroLinOp(right.domain, left.codomain, left.ctx) + if isinstance(right, ZeroLinOp): + return ZeroLinOp(right.domain, left.codomain, left.ctx) + return ComposedLinOp(left, right) + + +@jax_pytree_class +class ScaledLinOp(LinOp[Domain, Codomain]): + """ + Lazy scalar multiple of a linear operator. + + ``ScaledLinOp(alpha, A)`` represents the mathematical operator + ``alpha * A``. Its context is exactly ``A.ctx``; its domain is ``A.domain`` + and its codomain is ``A.codomain``. No dense matrix representation is + formed. + + The forward action is ``apply(x) = alpha * A.apply(x)`` for + ``x in A.domain``. The reverse action is + ``rapply(y) = conj(alpha) * A.rapply(y)`` for ``y in A.codomain``, so + complex scalars use the conjugated coefficient. + """ + + def __init__(self, scalar: Any, op: LinOp[Domain, Codomain]) -> None: + op = _require_linop(op, "op") + if not is_scalar_like(scalar): + raise TypeError(f"scalar must be scalar-like, got {type(scalar).__name__}.") + super().__init__(op.domain, op.codomain, op.ctx) + self.scalar = scalar + self.op = op + + def apply(self, x: Any) -> Any: + """Return ``scalar * op.apply(x)``.""" + return self.scalar * self.op.apply(x) + + def rapply(self, y: Any) -> Any: + """Return ``conj(scalar) * op.rapply(y)``.""" + return _conjugate_scalar(self.scalar) * self.op.rapply(y) + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.scalar == other.scalar and self.op == other.op + return False + + def tree_flatten(self): + children = (self.scalar, self.op) + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + scalar, op = children + return cls(scalar, op) + + def _convert(self, new_ctx: Context) -> ScaledLinOp: + return ScaledLinOp(self.scalar, self.op.convert(new_ctx)) + + +@jax_pytree_class +class SumLinOp(LinOp[Domain, Codomain]): + """ + Lazy finite sum of linear operators with common spaces. + + ``SumLinOp((A1, ..., Ak))`` represents ``A1 + ... + Ak`` for a nonempty + sequence of ``LinOp`` instances. All operands must have the same ``ctx``, + the same domain, and the same codomain before construction. The resulting + operator has that shared context, domain, and codomain. + + The forward action is ``apply(x) = sum_i Ai.apply(x)`` for the shared + domain element ``x``. The reverse action is + ``rapply(y) = sum_i Ai.rapply(y)`` for the shared codomain element ``y``. + """ + + def __init__(self, ops: Sequence[LinOp[Domain, Codomain]]) -> None: + if not ops: + raise ValueError("SumLinOp requires a nonempty sequence of LinOp operands.") + parts = tuple(_require_linop(op, f"ops[{i}]") for i, op in enumerate(ops)) + ctx = _require_same_context(parts) + domain = parts[0].domain + codomain = parts[0].codomain + for i, op in enumerate(parts[1:], start=1): + if op.domain != domain or op.codomain != codomain: + raise ValueError( + "All SumLinOp operands must have the same domain and codomain; " + f"operand 0 maps {domain!r} -> {codomain!r}, " + f"operand {i} maps {op.domain!r} -> {op.codomain!r}." + ) + super().__init__(domain, codomain, ctx) + self.ops_tuple = parts + + @property + def parts(self) -> tuple[LinOp[Domain, Codomain], ...]: + """Operators in this lazy sum.""" + return self.ops_tuple + + def apply(self, x: Any) -> Any: + """Return ``sum_i ops[i].apply(x)``.""" + acc = self.ops_tuple[0].apply(x) + for op in self.ops_tuple[1:]: + acc = self.codomain.add(acc, op.apply(x)) + return acc + + def rapply(self, y: Any) -> Any: + """Return ``sum_i ops[i].rapply(y)``.""" + acc = self.ops_tuple[0].rapply(y) + for op in self.ops_tuple[1:]: + acc = self.domain.add(acc, op.rapply(y)) + return acc + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.ops_tuple == other.ops_tuple + return False + + def tree_flatten(self): + children = self.ops_tuple + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + return cls(tuple(children)) + + def _convert(self, new_ctx: Context) -> SumLinOp: + return SumLinOp(tuple(op.convert(new_ctx) for op in self.ops_tuple)) + + +@jax_pytree_class +class ComposedLinOp(LinOp[Domain, Codomain]): + """ + Lazy composition of two linear operators. + + ``ComposedLinOp(A, B)`` represents ``A @ B = A circ B``. The operands must + have the same ``ctx`` before construction, and ``B.codomain`` must equal + ``A.domain``. The resulting operator has domain ``B.domain`` and codomain + ``A.codomain``. + + The forward action is ``apply(x) = A.apply(B.apply(x))`` for + ``x in B.domain``. The reverse action is ``rapply(z) = B.rapply(A.rapply(z))`` + for ``z in A.codomain``. + """ + + def __init__(self, left: LinOp, right: LinOp) -> None: + left = _require_linop(left, "left") + right = _require_linop(right, "right") + _require_same_context((left, right)) + if right.codomain != left.domain: + raise ValueError( + "ComposedLinOp requires right.codomain == left.domain; " + f"got {right.codomain!r} and {left.domain!r}." + ) + super().__init__(right.domain, left.codomain, left.ctx) + self.left = left + self.right = right + + def apply(self, x: Any) -> Any: + """Return ``left.apply(right.apply(x))``.""" + return self.left.apply(self.right.apply(x)) + + def rapply(self, z: Any) -> Any: + """Return ``right.rapply(left.rapply(z))``.""" + return self.right.rapply(self.left.rapply(z)) + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.left == other.left and self.right == other.right + return False + + def tree_flatten(self): + children = (self.left, self.right) + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + left, right = children + return cls(left, right) + + def _convert(self, new_ctx: Context) -> ComposedLinOp: + return ComposedLinOp(self.left.convert(new_ctx), self.right.convert(new_ctx)) + + +@jax_pytree_class +class ZeroLinOp(LinOp[Domain, Codomain]): + """ + Lazy zero map between two spaces. + + ``ZeroLinOp(X, Y)`` represents the linear map ``0 : X -> Y``. The context is + resolved from the optional ``ctx`` argument and the two spaces, then both + spaces are converted to that context. Its domain is ``X`` and its codomain + is ``Y`` in the resolved context. + + The forward action is ``apply(x) = 0_Y`` for ``x in X``. The reverse action + is ``rapply(y) = 0_X`` for ``y in Y``. + """ + + def __init__( + self, + dom: Domain, + cod: Codomain, + ctx: Context | str | None = None, + ) -> None: + super().__init__(dom, cod, ctx) + + def apply(self, x: Any) -> Any: + """Return the zero element of the codomain.""" + if self._enable_checks: + self.domain._check_member(x) + return self.codomain.zeros() + + def rapply(self, y: Any) -> Any: + """Return the zero element of the domain.""" + if self._enable_checks: + self.codomain._check_member(y) + return self.domain.zeros() + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.domain == other.domain and self.codomain == other.codomain + return False + + def tree_flatten(self): + children = () + aux = (self.domain, self.codomain, self.ctx) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + domain, codomain, ctx = aux + return cls(domain, codomain, ctx) + + def _convert(self, new_ctx: Context) -> ZeroLinOp: + return ZeroLinOp(self.domain.convert(new_ctx), self.codomain.convert(new_ctx), new_ctx) + + +@jax_pytree_class +class IdentityLinOp(LinOp[Domain, Domain]): + """ + Lazy identity map on a space. + + ``IdentityLinOp(X)`` represents the identity operator ``I_X : X -> X``. The + context is resolved from the optional ``ctx`` argument and the space, and the + resulting operator has domain and codomain equal to ``X`` in that context. + + The forward action is ``apply(x) = x`` for ``x in X``. The reverse action is + ``rapply(x) = x`` for ``x in X``. + """ + + def __init__(self, space: Domain, ctx: Context | str | None = None) -> None: + super().__init__(space, space, ctx) + + def apply(self, x: Any) -> Any: + """Return ``x`` after domain validation.""" + if self._enable_checks: + self.domain._check_member(x) + return x + + def rapply(self, x: Any) -> Any: + """Return ``x`` after codomain validation.""" + if self._enable_checks: + self.codomain._check_member(x) + return x + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.domain == other.domain + return False + + def tree_flatten(self): + children = () + aux = (self.domain, self.ctx) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + domain, ctx = aux + return cls(domain, ctx) + + def _convert(self, new_ctx: Context) -> IdentityLinOp: + return IdentityLinOp(self.domain.convert(new_ctx), new_ctx) + + +@jax_pytree_class +class MatrixFreeLinOp(LinOp[Domain, Codomain]): + """ + Linear operator defined by user-supplied forward and reverse callables. + + ``MatrixFreeLinOp(apply, rapply, X, Y)`` represents a matrix-free map + ``A : X -> Y`` without storing or materializing a matrix. The context is + resolved from the optional ``ctx`` argument and the spaces, then the spaces + are converted to that context. + + The forward action is ``apply(x) = apply_fn(x)`` for ``x in X``. The reverse + action is ``rapply(y) = rapply_fn(y)`` for ``y in Y``. When checks are + enabled, inputs and callable outputs are validated against the corresponding + domain and codomain. + """ + + def __init__( + self, + apply: Any, + rapply: Any, + dom: Domain, + cod: Codomain, + ctx: Context | str | None = None, + ) -> None: + if not callable(apply): + raise TypeError(f"apply must be callable, got {type(apply).__name__}.") + if not callable(rapply): + raise TypeError(f"rapply must be callable, got {type(rapply).__name__}.") + super().__init__(dom, cod, ctx) + self.apply_fn = apply + self.rapply_fn = rapply + + def apply(self, x: Any) -> Any: + """Return ``apply_fn(x)``.""" + if self._enable_checks: + self.domain._check_member(x) + y = self.apply_fn(x) + if self._enable_checks: + self.codomain._check_member(y) + return y + + def rapply(self, y: Any) -> Any: + """Return ``rapply_fn(y)``.""" + if self._enable_checks: + self.codomain._check_member(y) + x = self.rapply_fn(y) + if self._enable_checks: + self.domain._check_member(x) + return x + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return ( + self.domain == other.domain + and self.codomain == other.codomain + and self.apply_fn is other.apply_fn + and self.rapply_fn is other.rapply_fn + ) + return False + + def tree_flatten(self): + children = () + aux = (self.apply_fn, self.rapply_fn, self.domain, self.codomain, self.ctx) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + apply_fn, rapply_fn, domain, codomain, ctx = aux + return cls(apply_fn, rapply_fn, domain, codomain, ctx) + + def _convert(self, new_ctx: Context) -> MatrixFreeLinOp: + return MatrixFreeLinOp( + self.apply_fn, + self.rapply_fn, + self.domain.convert(new_ctx), + self.codomain.convert(new_ctx), + new_ctx, + ) + + +@jax_pytree_class +class _AdjointViewLinOp(LinOp[Codomain, Domain]): + """ + Hermitian-adjoint view of a linear operator. + + ``A.H`` represents the adjoint view ``A*``. Its context is exactly + ``A.ctx``; its domain is ``A.codomain`` and its codomain is ``A.domain``. + ``A.H.H`` returns ``A`` rather than constructing another wrapper. + + The forward action is ``apply(y) = A.rapply(y)`` for ``y in A.codomain``. + The reverse action is ``rapply(x) = A.apply(x)`` for ``x in A.domain``. + """ + + def __init__(self, op: LinOp[Domain, Codomain]) -> None: + op = _require_linop(op, "op") + super().__init__(op.codomain, op.domain, op.ctx) + self.op = op + + def apply(self, y: Any) -> Any: + """Return ``op.rapply(y)``.""" + return self.op.rapply(y) + + def rapply(self, x: Any) -> Any: + """Return ``op.apply(x)``.""" + return self.op.apply(x) + + @property + def H(self) -> LinOp[Domain, Codomain]: + """Original operator viewed as the adjoint of this adjoint view.""" + return self.op + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.op == other.op + return False + + def tree_flatten(self): + children = (self.op,) + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + return cls(children[0]) + + def _convert(self, new_ctx: Context) -> _AdjointViewLinOp: + return _AdjointViewLinOp(self.op.convert(new_ctx)) diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index dfc06b9..05707d5 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import abstractmethod +from numbers import Number from typing import Any, Generic, TypeVar from ..space import Space @@ -31,6 +32,16 @@ def __init__(self, dom: Domain, cod: Codomain, ctx: Context | str | None = None) self.cod = cod.convert(self.ctx) self._enable_checks = self.ctx.enable_checks + @property + def domain(self) -> Domain: + """Domain space of this linear operator.""" + return self.dom + + @property + def codomain(self) -> Codomain: + """Codomain space of this linear operator.""" + return self.cod + @abstractmethod def apply(self, x: Any) -> Any: """ @@ -52,8 +63,90 @@ def rapply(self, y: Any) -> Any: """ def __call__(self, x: Any) -> Any: + """Apply this linear operator to ``x``.""" return self.apply(x) + def adjoint_apply(self, y: Any) -> Any: + """Apply the adjoint of this linear operator to ``y``.""" + return self.rapply(y) + + @property + def H(self) -> LinOp: + """Hermitian-adjoint view of this linear operator.""" + from ._algebra import _AdjointViewLinOp + + return _AdjointViewLinOp(self) + + def __add__(self, other: Any) -> LinOp: + """Return the lazy sum ``self + other`` of two compatible operators.""" + from ._algebra import make_sum + + if not isinstance(other, LinOp): + return NotImplemented + return make_sum((self, other)) + + def __radd__(self, other: Any) -> LinOp: + """Return the lazy sum ``other + self`` of two compatible operators.""" + from ._algebra import make_sum + + if isinstance(other, Number) and other == 0: + return self + if not isinstance(other, LinOp): + return NotImplemented + return make_sum((other, self)) + + def __neg__(self) -> LinOp: + """Return the lazy negation ``-self``.""" + from ._algebra import make_scaled + + return make_scaled(-1, self) + + def __sub__(self, other: Any) -> LinOp: + """Return the lazy difference ``self - other`` of two compatible operators.""" + from ._algebra import make_scaled, make_sum + + if not isinstance(other, LinOp): + return NotImplemented + return make_sum((self, make_scaled(-1, other))) + + def __rsub__(self, other: Any) -> LinOp: + """Return the lazy difference ``other - self`` of two compatible operators.""" + from ._algebra import make_scaled, make_sum + + if isinstance(other, Number) and other == 0: + return make_scaled(-1, self) + if not isinstance(other, LinOp): + return NotImplemented + return make_sum((other, make_scaled(-1, self))) + + def __mul__(self, scalar: Any) -> LinOp: + """Return the lazy right scalar multiple ``self * scalar``.""" + from ._algebra import is_scalar_like, make_scaled + + if not is_scalar_like(scalar): + return NotImplemented + return make_scaled(scalar, self) + + def __rmul__(self, scalar: Any) -> LinOp: + """Return the lazy left scalar multiple ``scalar * self``.""" + from ._algebra import is_scalar_like, make_scaled + + if not is_scalar_like(scalar): + return NotImplemented + return make_scaled(scalar, self) + + def __matmul__(self, other: Any) -> LinOp: + """Return the lazy composition ``self @ other`` of two compatible operators.""" + from ._algebra import make_composed + + if not isinstance(other, LinOp): + return NotImplemented + return make_composed(self, other) + + def adjoint(self) -> LinOp: + """Return the Hermitian-adjoint view of this linear operator.""" + return self.H + def assert_domain(self, x: Any) -> None: self.dom.check_member(x) diff --git a/tests/integration/test_public_api.py b/tests/integration/test_public_api.py index 26ceb65..45b6f4a 100644 --- a/tests/integration/test_public_api.py +++ b/tests/integration/test_public_api.py @@ -19,6 +19,9 @@ def test_expected_names_are_exported(): sc = importlib.import_module("spacecore") expected = { "Context", "BackendOps", "NumpyOps", "DenseLinOp", "SparseLinOp", + "ScaledLinOp", "SumLinOp", "ComposedLinOp", "ZeroLinOp", + "IdentityLinOp", "MatrixFreeLinOp", "make_sum", "make_scaled", + "make_composed", "BlockDiagonalLinOp", "StackedLinOp", "SumToSingleLinOp", "VectorSpace", "HermitianSpace", "ProductSpace", "Space", "DenseArray", "SparseArray", "ArrayLike", diff --git a/tests/linops/test_algebra_linop.py b/tests/linops/test_algebra_linop.py new file mode 100644 index 0000000..fd2a8f8 --- /dev/null +++ b/tests/linops/test_algebra_linop.py @@ -0,0 +1,244 @@ +import importlib + +import numpy as np +import pytest + + +def _ctx(dtype=np.float64, enable_checks=True): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=dtype, enable_checks=enable_checks) + + +def _op(matrix, dom_shape, cod_shape, ctx=None): + sc = importlib.import_module("spacecore") + ctx = ctx or _ctx() + dom = sc.VectorSpace(dom_shape, ctx) + cod = sc.VectorSpace(cod_shape, ctx) + return sc.DenseLinOp(ctx.asarray(matrix), dom, cod, ctx) + + +def test_algebra_linops_inherit_from_linop(): + sc = importlib.import_module("spacecore") + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,)) + + assert isinstance(2.0 * A, sc.LinOp) + assert isinstance(A + A, sc.LinOp) + assert isinstance(A @ A, sc.LinOp) + assert isinstance(A.H, sc.LinOp) + assert isinstance(sc.ZeroLinOp(A.domain, A.codomain, A.ctx), sc.LinOp) + assert isinstance(sc.IdentityLinOp(A.domain, A.ctx), sc.LinOp) + assert isinstance(sc.MatrixFreeLinOp(A.apply, A.rapply, A.domain, A.codomain, A.ctx), sc.LinOp) + assert issubclass(sc.ScaledLinOp, sc.LinOp) + assert issubclass(sc.SumLinOp, sc.LinOp) + assert issubclass(sc.ComposedLinOp, sc.LinOp) + assert issubclass(sc.ZeroLinOp, sc.LinOp) + assert issubclass(sc.IdentityLinOp, sc.LinOp) + assert issubclass(sc.MatrixFreeLinOp, sc.LinOp) + assert not hasattr(sc, "AdjointLinOp") + + +def test_context_mismatch_raises_clear_error(): + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), _ctx(enable_checks=True)) + B = _op([[5.0, 6.0], [7.0, 8.0]], (2,), (2,), _ctx(enable_checks=False)) + + with pytest.raises(ValueError, match="same ctx"): + _ = A + B + with pytest.raises(ValueError, match="same ctx"): + _ = A @ B + + +def test_sum_requires_matching_domain_and_codomain(): + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,)) + bad_cod = _op([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], (2,), (3,), A.ctx) + bad_dom = _op([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], (3,), (2,), A.ctx) + + with pytest.raises(ValueError, match="same domain and codomain"): + _ = A + bad_cod + with pytest.raises(ValueError, match="same domain and codomain"): + _ = A + bad_dom + + +def test_composition_requires_matching_middle_space(): + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,)) + B = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), A.ctx) + C = _op([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], (2,), (3,), A.ctx) + + assert (A @ B).domain == B.domain + assert (A @ B).codomain == A.codomain + with pytest.raises(ValueError, match="right.codomain == left.domain"): + _ = A @ C + + +def test_scaled_sum_subtraction_and_negation_are_numerically_correct(): + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), ctx) + B = _op([[5.0, 1.0], [-2.0, 3.0]], (2,), (2,), ctx) + x = ctx.asarray([2.0, -1.0]) + y = ctx.asarray([1.0, 3.0]) + dense_a = np.array([[1.0, 2.0], [3.0, 4.0]]) + dense_b = np.array([[5.0, 1.0], [-2.0, 3.0]]) + + expr = 2.0 * A + B - (-A) + + assert expr.domain == A.domain + assert expr.codomain == A.codomain + assert np.allclose(expr.apply(x), (3.0 * dense_a + dense_b) @ np.asarray(x)) + assert np.allclose(expr.rapply(y), (3.0 * dense_a + dense_b).T @ np.asarray(y)) + assert np.allclose((-A).apply(x), -dense_a @ np.asarray(x)) + assert np.allclose((A * 3.0).apply(x), 3.0 * dense_a @ np.asarray(x)) + + +def test_complex_scaled_adjoint_conjugates_scalar(): + ctx = _ctx(np.complex128) + A = _op([[1.0 + 1.0j, 2.0], [3.0j, 4.0 - 2.0j]], (2,), (2,), ctx) + y = ctx.asarray([1.0 - 1.0j, 2.0 + 3.0j]) + dense = np.array([[1.0 + 1.0j, 2.0], [3.0j, 4.0 - 2.0j]]) + alpha = 2.0 + 3.0j + + op = alpha * A + + assert np.allclose(op.rapply(y), np.conj(alpha) * dense.conj().T @ np.asarray(y)) + + +def test_composition_apply_and_adjoint_are_numerically_correct(): + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], (2,), (3,), ctx) + B = _op([[2.0, -1.0], [0.5, 3.0]], (2,), (2,), ctx) + x = ctx.asarray([4.0, -2.0]) + z = ctx.asarray([1.0, -1.0, 2.0]) + dense_a = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + dense_b = np.array([[2.0, -1.0], [0.5, 3.0]]) + + op = A @ B + + assert op.domain == B.domain + assert op.codomain == A.codomain + assert np.allclose(op.apply(x), dense_a @ dense_b @ np.asarray(x)) + assert np.allclose(op.rapply(z), dense_b.T @ dense_a.T @ np.asarray(z)) + + +def test_H_swaps_spaces_and_double_H_returns_original(): + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], (2,), (3,), ctx) + x = ctx.asarray([7.0, 8.0]) + y = ctx.asarray([1.0, -1.0, 2.0]) + + AH = A.H + AHH = AH.H + + assert AH.ctx == A.ctx + assert AH.domain == A.codomain + assert AH.codomain == A.domain + assert np.allclose(AH.apply(y), A.rapply(y)) + assert np.allclose(AH.rapply(x), A.apply(x)) + assert AHH is A + assert np.allclose(AHH.apply(x), A.apply(x)) + assert np.allclose(AHH.rapply(y), A.rapply(y)) + + +def test_zero_identity_and_matrix_free_rapply_are_numerically_correct(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + x = ctx.asarray([7.0, 8.0]) + y = ctx.asarray([1.0, -1.0, 2.0]) + + zero = sc.ZeroLinOp(dom, cod, ctx) + identity = sc.IdentityLinOp(dom, ctx) + matrix_free = sc.MatrixFreeLinOp( + lambda v: ctx.asarray(dense @ np.asarray(v)), + lambda w: ctx.asarray(dense.T @ np.asarray(w)), + dom, + cod, + ctx, + ) + + assert np.allclose(zero.apply(x), np.zeros(3)) + assert np.allclose(zero.rapply(y), np.zeros(2)) + assert np.allclose(identity.apply(x), np.asarray(x)) + assert np.allclose(identity.rapply(x), np.asarray(x)) + assert np.allclose(matrix_free.apply(x), dense @ np.asarray(x)) + assert np.allclose(matrix_free.rapply(y), dense.T @ np.asarray(y)) + + +def test_sum_factory_flattens_nested_sums_and_removes_zero_terms(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), ctx) + B = _op([[5.0, 1.0], [-2.0, 3.0]], (2,), (2,), ctx) + Z = sc.ZeroLinOp(A.domain, A.codomain, ctx) + x = ctx.asarray([2.0, -1.0]) + y = ctx.asarray([1.0, 3.0]) + + nested = sc.SumLinOp((A, B)) + simplified = nested + Z + zero_sum = Z + Z + + assert isinstance(simplified, sc.SumLinOp) + assert simplified.parts == (A, B) + assert A + Z is A + assert Z + A is A + assert isinstance(zero_sum, sc.ZeroLinOp) + assert zero_sum.domain == A.domain + assert zero_sum.codomain == A.codomain + + unsimplified = sc.SumLinOp((nested, Z)) + assert np.allclose(simplified.apply(x), unsimplified.apply(x)) + assert np.allclose(simplified.rapply(y), unsimplified.rapply(y)) + + +def test_scaling_factory_simplifies_zero_one_and_nested_scaling(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), ctx) + x = ctx.asarray([2.0, -1.0]) + y = ctx.asarray([1.0, 3.0]) + dense = np.array([[1.0, 2.0], [3.0, 4.0]]) + + zero = 0 * A + unit = 1 * A + nested = 2 * (3 * A) + + assert isinstance(zero, sc.ZeroLinOp) + assert unit is A + assert isinstance(nested, sc.ScaledLinOp) + assert nested.scalar == 6 + assert nested.op is A + assert np.allclose(zero.apply(x), np.zeros(2)) + assert np.allclose(zero.rapply(y), np.zeros(2)) + assert np.allclose(nested.apply(x), 6 * dense @ np.asarray(x)) + assert np.allclose(nested.rapply(y), 6 * dense.T @ np.asarray(y)) + + +def test_composition_factory_simplifies_identity_and_zero_factors(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], (2,), (3,), ctx) + id_domain = sc.IdentityLinOp(A.domain, ctx) + id_codomain = sc.IdentityLinOp(A.codomain, ctx) + left_zero = sc.ZeroLinOp(A.codomain, sc.VectorSpace((4,), ctx), ctx) + right_zero = sc.ZeroLinOp(sc.VectorSpace((5,), ctx), A.domain, ctx) + x = ctx.asarray([7.0, 8.0]) + y = ctx.asarray([1.0, -1.0, 2.0]) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + + assert A @ id_domain is A + assert id_codomain @ A is A + + left_simplified = left_zero @ A + right_simplified = A @ right_zero + + assert isinstance(left_simplified, sc.ZeroLinOp) + assert left_simplified.domain == A.domain + assert left_simplified.codomain == left_zero.codomain + assert isinstance(right_simplified, sc.ZeroLinOp) + assert right_simplified.domain == right_zero.domain + assert right_simplified.codomain == A.codomain + + unsimplified_left = sc.ComposedLinOp(left_zero, A) + assert np.allclose((A @ id_domain).apply(x), dense @ np.asarray(x)) + assert np.allclose((id_codomain @ A).rapply(y), dense.T @ np.asarray(y)) + assert np.allclose(left_simplified.apply(x), unsimplified_left.apply(x)) + assert np.allclose(left_simplified.rapply(ctx.asarray([1.0, 2.0, 3.0, 4.0])), np.zeros(2)) From d5cf13982eaa88e8c3643b4445af082af29cbdb2 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 01:04:50 -0300 Subject: [PATCH 2/7] Fix LinOp base equality protocol --- spacecore/linop/_base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index 05707d5..6cd45aa 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -153,12 +153,14 @@ def assert_domain(self, x: Any) -> None: def assert_codomain(self, y: Any) -> None: self.cod.check_member(y) - def __eq__(self, x: Any) -> bool: - raise NotImplementedError() + def __eq__(self, other: Any) -> bool: + return NotImplemented + @abstractmethod def tree_flatten(self): - raise NotImplementedError() + ... @classmethod + @abstractmethod def tree_unflatten(cls, aux, children): - raise NotImplementedError() + ... From c454332b4f02d19cae37909c41c41ff596833201 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 01:05:37 -0300 Subject: [PATCH 3/7] Document Space membership check convention --- spacecore/space/_base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/spacecore/space/_base.py b/spacecore/space/_base.py index 26eaf6f..d206bb4 100644 --- a/spacecore/space/_base.py +++ b/spacecore/space/_base.py @@ -16,6 +16,11 @@ class Space(ContextBound): A Space owns the *geometry* (inner product, norm) and the basic linear structure (add/scale/axpy) for its elements. + Membership validation is exposed through ``check_member``, which respects + the space's ``enable_checks`` policy. Internal code paths that have already + checked that policy may call ``_check_member`` to run the concrete checks + exactly once. + Solvers should use only this API. """ From 2d54d6458f066e16a3156b4510fb974df432a26f Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 01:06:04 -0300 Subject: [PATCH 4/7] Simplify scaled zero LinOps --- spacecore/linop/_algebra.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spacecore/linop/_algebra.py b/spacecore/linop/_algebra.py index f104c2f..a1e7cc8 100644 --- a/spacecore/linop/_algebra.py +++ b/spacecore/linop/_algebra.py @@ -128,6 +128,8 @@ def make_scaled(scalar: Any, op: LinOp) -> LinOp: return ZeroLinOp(op.domain, op.codomain, op.ctx) if _is_one_scalar(scalar): return op + if isinstance(op, ZeroLinOp): + return op if isinstance(op, ScaledLinOp): return make_scaled(scalar * op.scalar, op.op) return ScaledLinOp(scalar, op) From 57283f806fad0a05644884cf6b0209b663c17da7 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 01:07:59 -0300 Subject: [PATCH 5/7] Add dense materialization for LinOps --- spacecore/linop/_algebra.py | 20 ++++++++ spacecore/linop/_base.py | 23 +++++++++ spacecore/linop/_dense.py | 8 +++ tests/linops/test_to_dense.py | 93 +++++++++++++++++++++++++++++++++++ 4 files changed, 144 insertions(+) create mode 100644 tests/linops/test_to_dense.py diff --git a/spacecore/linop/_algebra.py b/spacecore/linop/_algebra.py index a1e7cc8..cebd8b7 100644 --- a/spacecore/linop/_algebra.py +++ b/spacecore/linop/_algebra.py @@ -375,6 +375,14 @@ def rapply(self, y: Any) -> Any: self.codomain._check_member(y) return self.domain.zeros() + def to_dense(self) -> Any: + """ + Return the dense tensor representation of the zero map. + + The returned array has shape ``self.codomain.shape + self.domain.shape``. + """ + return self.ops.zeros(tuple(self.codomain.shape) + tuple(self.domain.shape), dtype=self.dtype) + def __eq__(self, other: Any) -> bool: if type(other) is type(self): return self.domain == other.domain and self.codomain == other.codomain @@ -422,6 +430,18 @@ def rapply(self, x: Any) -> Any: self.codomain._check_member(x) return x + def to_dense(self) -> Any: + """ + Return the dense tensor representation of this identity map. + + The returned array has shape ``self.codomain.shape + self.domain.shape``. + """ + size = 1 + for dim in self.domain.shape: + size *= dim + eye = self.ops.eye(size, dtype=self.dtype) + return self.ops.reshape(eye, tuple(self.codomain.shape) + tuple(self.domain.shape)) + def __eq__(self, other: Any) -> bool: if type(other) is type(self): return self.domain == other.domain diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index 6cd45aa..f59d02e 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import abstractmethod +from math import prod from numbers import Number from typing import Any, Generic, TypeVar @@ -147,6 +148,28 @@ def adjoint(self) -> LinOp: """Return the Hermitian-adjoint view of this linear operator.""" return self.H + def to_dense(self) -> Any: + """ + Materialize this operator as a dense backend array. + + The returned array has shape ``self.codomain.shape + self.domain.shape``. + The default implementation applies the operator to each standard basis + vector of the domain, stacks the flattened outputs as matrix columns, + and reshapes the result back to tensor-operator form. Subclasses that + already store the matrix should override this method for efficiency. + """ + domain_size = prod(self.domain.shape) + codomain_size = prod(self.codomain.shape) + zero = self.ops.zeros((domain_size,), dtype=self.dtype) + columns = [] + for i in range(domain_size): + basis_vector = self.ops.index_set(zero, i, 1, copy=True) + x = self.domain.unflatten(basis_vector) + y = self.apply(x) + columns.append(self.codomain.flatten(y)) + matrix = self.ops.stack(tuple(columns), axis=1) + return self.ops.reshape(matrix, tuple(self.codomain.shape) + tuple(self.domain.shape)) + def assert_domain(self, x: Any) -> None: self.dom.check_member(x) diff --git a/spacecore/linop/_dense.py b/spacecore/linop/_dense.py index a751799..9b36a09 100644 --- a/spacecore/linop/_dense.py +++ b/spacecore/linop/_dense.py @@ -88,6 +88,14 @@ def _rapply_unchecked(self, y: DenseArray) -> DenseArray: return x1 if self._dom_is_flat else x1.reshape(self.dom.shape) return self.dom.unflatten(x1) + def to_dense(self) -> DenseArray: + """ + Return the stored dense tensor representation of this operator. + + The returned array has shape ``self.codomain.shape + self.domain.shape``. + """ + return self.A + def __eq__(self, x: Any) -> bool: if type(x) is type(self): return (self.dom == x.dom diff --git a/tests/linops/test_to_dense.py b/tests/linops/test_to_dense.py new file mode 100644 index 0000000..5ca33f2 --- /dev/null +++ b/tests/linops/test_to_dense.py @@ -0,0 +1,93 @@ +import importlib + +import numpy as np +import scipy.sparse as sps + + +def _ctx(): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=np.float64) + + +def _assert_to_dense_matches_apply(op, x): + dense = op.to_dense() + matrix = dense.reshape((np.prod(op.codomain.shape), np.prod(op.domain.shape))) + y_from_dense = matrix @ op.domain.flatten(x) + y_from_apply = op.codomain.flatten(op.apply(x)) + assert np.allclose(y_from_dense, y_from_apply) + + +def test_dense_linop_to_dense_returns_stored_matrix_and_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + A = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.DenseLinOp(A, dom, cod, ctx) + + assert op.to_dense() is A + _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) + + +def test_sparse_linop_to_dense_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.SparseLinOp(sps.csr_matrix(dense), dom, cod, ctx) + + assert np.allclose(op.to_dense(), dense) + _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) + + +def test_identity_linop_to_dense_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2, 2), ctx) + op = sc.IdentityLinOp(space, ctx) + + assert np.allclose(op.to_dense().reshape((4, 4)), np.eye(4)) + _assert_to_dense_matches_apply(op, ctx.asarray([[1.0, 2.0], [3.0, 4.0]])) + + +def test_zero_linop_to_dense_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + op = sc.ZeroLinOp(dom, cod, ctx) + + assert np.allclose(op.to_dense(), np.zeros((3, 2))) + _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) + + +def test_matrix_free_linop_to_dense_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.MatrixFreeLinOp( + lambda x: ctx.asarray(dense @ np.asarray(x)), + lambda y: ctx.asarray(dense.T @ np.asarray(y)), + dom, + cod, + ctx, + ) + + assert np.allclose(op.to_dense(), dense) + _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) + + +def test_sum_linop_to_dense_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + A = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), dom, cod, ctx) + B = sc.DenseLinOp(ctx.asarray([[0.5, 1.0], [-1.0, 2.0], [3.0, -0.5]]), dom, cod, ctx) + op = A + B + + assert np.allclose(op.to_dense(), A.to_dense() + B.to_dense()) + _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) From 89f007bc3dc38ca59912d132ba96a09800cb547e Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 01:11:35 -0300 Subject: [PATCH 6/7] Add algebra layer regression tests --- spacecore/linop/_base.py | 7 +- tests/linops/test_algebra.py | 285 +++++++++++++++++++++++++++++++++++ 2 files changed, 290 insertions(+), 2 deletions(-) create mode 100644 tests/linops/test_algebra.py diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index f59d02e..d1a05be 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -76,7 +76,11 @@ def H(self) -> LinOp: """Hermitian-adjoint view of this linear operator.""" from ._algebra import _AdjointViewLinOp - return _AdjointViewLinOp(self) + view = getattr(self, "_adjoint_view", None) + if view is None: + view = _AdjointViewLinOp(self) + self._adjoint_view = view + return view def __add__(self, other: Any) -> LinOp: """Return the lazy sum ``self + other`` of two compatible operators.""" @@ -159,7 +163,6 @@ def to_dense(self) -> Any: already store the matrix should override this method for efficiency. """ domain_size = prod(self.domain.shape) - codomain_size = prod(self.codomain.shape) zero = self.ops.zeros((domain_size,), dtype=self.dtype) columns = [] for i in range(domain_size): diff --git a/tests/linops/test_algebra.py b/tests/linops/test_algebra.py new file mode 100644 index 0000000..770b20c --- /dev/null +++ b/tests/linops/test_algebra.py @@ -0,0 +1,285 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import has_jax, has_torch, jax_complex_dtype, jax_real_dtype +from tests._helpers import to_numpy, torch_complex_dtype + + +def _backend_params(): + params = [pytest.param("numpy", np.complex128, id="numpy")] + params.append( + pytest.param( + "jax", + jax_complex_dtype(), + marks=pytest.mark.skipif(not has_jax(), reason="jax is not installed"), + id="jax", + ) + ) + params.append( + pytest.param( + "torch", + torch_complex_dtype(), + marks=pytest.mark.skipif(not has_torch(), reason="torch is not installed"), + id="torch", + ) + ) + return params + + +def _ops_for_backend(name): + sc = importlib.import_module("spacecore") + if name == "numpy": + return sc.NumpyOps() + if name == "jax": + return sc.JaxOps() + if name == "torch": + return sc.TorchOps() + raise ValueError(f"Unknown backend {name!r}.") + + +def _ctx(dtype=np.complex128, enable_checks=True): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=dtype, enable_checks=enable_checks) + + +def _spaces(ctx): + sc = importlib.import_module("spacecore") + return sc.VectorSpace((2,), ctx), sc.VectorSpace((3,), ctx) + + +def _matrix(): + return np.array( + [ + [1.0 + 2.0j, 3.0 - 1.0j], + [-2.0 + 0.5j, 0.25 + 4.0j], + [1.5 - 3.0j, -0.75 + 2.0j], + ] + ) + + +def _square_matrix(): + return np.array([[2.0 - 1.0j, -0.5 + 0.25j], [1.25 + 2.0j, -3.0 + 0.5j]]) + + +def _dense_linop(ctx): + sc = importlib.import_module("spacecore") + dom, cod = _spaces(ctx) + return sc.DenseLinOp(ctx.asarray(_matrix()), dom, cod, ctx) + + +def _dense_same_shape(ctx, scale=1.0): + sc = importlib.import_module("spacecore") + dom, cod = _spaces(ctx) + return sc.DenseLinOp(ctx.asarray(scale * _matrix()), dom, cod, ctx) + + +def _dense_square(ctx): + sc = importlib.import_module("spacecore") + dom = sc.VectorSpace((2,), ctx) + return sc.DenseLinOp(ctx.asarray(_square_matrix()), dom, dom, ctx) + + +def _xy(ctx): + x = ctx.asarray([2.0 - 1.0j, -0.5 + 0.25j]) + y = ctx.asarray([1.0 + 0.5j, -2.0j, 0.75 - 1.25j]) + return x, y + + +def _assert_adjoint_identity(op, x, y, ctx): + lhs = ctx.ops.vdot(op.apply(x), y) + rhs = ctx.ops.vdot(x, op.rapply(y)) + np.testing.assert_allclose(to_numpy(lhs), to_numpy(rhs), rtol=1e-6, atol=1e-6) + + +def _adjoint_cases(ctx): + sc = importlib.import_module("spacecore") + A = _dense_linop(ctx) + B = _dense_same_shape(ctx, scale=0.5 - 0.25j) + C = _dense_square(ctx) + dom, cod = _spaces(ctx) + x, y = _xy(ctx) + z = ctx.asarray([-1.0 + 0.5j, 2.0 - 0.25j]) + + matrix = ctx.asarray(_matrix()) + matrix_free = sc.MatrixFreeLinOp( + lambda v: matrix @ v, + lambda w: ctx.ops.conj(ctx.ops.transpose(matrix)) @ w, + dom, + cod, + ctx, + ) + + return [ + ((2.0 + 3.0j) * A, x, y), + (A + B, x, y), + (A @ C, z, y), + (sc.ZeroLinOp(dom, cod, ctx), x, y), + (sc.IdentityLinOp(dom, ctx), x, x), + (matrix_free, x, y), + (A.H, y, x), + ] + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +@pytest.mark.parametrize("case_index", range(7)) +def test_complex_adjoint_identity_for_algebra_classes(backend_name, dtype, case_index): + sc = importlib.import_module("spacecore") + ctx = sc.Context(_ops_for_backend(backend_name), dtype=dtype) + op, x, y = _adjoint_cases(ctx)[case_index] + + _assert_adjoint_identity(op, x, y, ctx) + + +def test_simplification_canonicalizations(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _dense_linop(ctx) + B = _dense_same_shape(ctx, scale=2.0) + C = _dense_same_shape(ctx, scale=-1.0) + Z = sc.ZeroLinOp(A.domain, A.codomain, ctx) + + assert sc.make_sum((A, Z)) is A + assert isinstance(sc.make_sum((Z, Z)), sc.ZeroLinOp) + assert sc.make_sum((A,)) is A + flattened = sc.make_sum((sc.make_sum((A, B)), C)) + assert isinstance(flattened, sc.SumLinOp) + assert flattened.parts == (A, B, C) + + scaled_zero = sc.make_scaled(0, A) + assert isinstance(scaled_zero, sc.ZeroLinOp) + assert scaled_zero.domain == A.domain + assert scaled_zero.codomain == A.codomain + assert sc.make_scaled(1, A) is A + assert sc.make_scaled(7.0, Z) is Z + folded = sc.make_scaled(2, sc.make_scaled(3, A)) + assert isinstance(folded, sc.ScaledLinOp) + assert folded.scalar == 6 + assert folded.op is A + + I_dom = sc.IdentityLinOp(A.domain, ctx) + I_cod = sc.IdentityLinOp(A.codomain, ctx) + assert sc.make_composed(I_cod, A) is A + assert sc.make_composed(A, I_dom) is A + + out = sc.VectorSpace((4,), ctx) + left_zero = sc.ZeroLinOp(A.codomain, out, ctx) + composed_zero = sc.make_composed(left_zero, A) + assert isinstance(composed_zero, sc.ZeroLinOp) + assert composed_zero.domain == A.domain + assert composed_zero.codomain == out + + +@pytest.mark.parametrize("case_index", range(7)) +def test_double_adjoint_view_returns_literal_original(case_index): + ctx = _ctx() + op, _, _ = _adjoint_cases(ctx)[case_index] + + assert op.H.H is op + + +def test_identity_linop_apply_is_literal_input_when_checks_disabled(): + sc = importlib.import_module("spacecore") + ctx = _ctx(enable_checks=False) + space = sc.VectorSpace((2,), ctx) + op = sc.IdentityLinOp(space, ctx) + x = ctx.asarray([1.0 + 2.0j, 3.0 - 4.0j]) + + assert op.apply(x) is x + assert op.rapply(x) is x + + +def test_identity_linop_apply_equals_input_when_checks_enabled(): + sc = importlib.import_module("spacecore") + ctx = _ctx(enable_checks=True) + space = sc.VectorSpace((2,), ctx) + op = sc.IdentityLinOp(space, ctx) + x = ctx.asarray([1.0 + 2.0j, 3.0 - 4.0j]) + + np.testing.assert_allclose(op.apply(x), x) + np.testing.assert_allclose(op.rapply(x), x) + + +def test_python_sum_starts_from_zero_and_accumulates_linops(): + ctx = _ctx() + A = _dense_same_shape(ctx, scale=1.0) + B = _dense_same_shape(ctx, scale=0.5) + C = _dense_same_shape(ctx, scale=-2.0) + x, _ = _xy(ctx) + + op = sum([A, B, C]) + expected = A.apply(x) + B.apply(x) + C.apply(x) + + np.testing.assert_allclose(op.apply(x), expected) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +@pytest.mark.parametrize("case_index", range(7)) +def test_jax_pytree_roundtrip_for_algebra_classes(case_index): + import jax + + ctx = _ctx() + op, _, _ = _adjoint_cases(ctx)[case_index] + leaves, treedef = jax.tree.flatten(op) + rebuilt = jax.tree.unflatten(treedef, leaves) + + assert rebuilt == op + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_jit_algebra_expression_matches_eager(): + import jax + + sc = importlib.import_module("spacecore") + ctx = sc.Context(sc.JaxOps(), dtype=jax_real_dtype(), enable_checks=False) + X = sc.VectorSpace((2,), ctx) + Y = sc.VectorSpace((3,), ctx) + A = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), X, Y, ctx) + B = sc.DenseLinOp(ctx.asarray([[0.5, -1.0], [2.0, 1.0], [-0.5, 3.0]]), X, Y, ctx) + C = sc.DenseLinOp(ctx.asarray([[2.0, -1.0], [0.25, 1.5]]), X, X, ctx) + expr = (2 * A + B) @ C + x = ctx.asarray([1.0, -2.0]) + + apply_jit = jax.jit(lambda op, z: op.apply(z)) + + np.testing.assert_allclose(to_numpy(apply_jit(expr, x)), to_numpy(expr.apply(x))) + + +def test_factories_enforce_same_context_dtype(): + sc = importlib.import_module("spacecore") + ctx32 = sc.Context(sc.NumpyOps(), dtype=np.float32) + ctx64 = sc.Context(sc.NumpyOps(), dtype=np.float64) + X32 = sc.VectorSpace((2,), ctx32) + Y32 = sc.VectorSpace((2,), ctx32) + X64 = sc.VectorSpace((2,), ctx64) + Y64 = sc.VectorSpace((2,), ctx64) + A32 = sc.DenseLinOp(ctx32.asarray([[1.0, 2.0], [3.0, 4.0]]), X32, Y32, ctx32) + A64 = sc.DenseLinOp(ctx64.asarray([[1.0, 2.0], [3.0, 4.0]]), X64, Y64, ctx64) + + with pytest.raises(ValueError, match="same ctx"): + sc.make_sum((A32, A64)) + with pytest.raises(ValueError, match="same ctx"): + sc.make_composed(A32, A64) + + +def test_factories_enforce_domain_and_codomain_compatibility(): + sc = importlib.import_module("spacecore") + ctx = _ctx(dtype=np.float64) + X = sc.VectorSpace((2,), ctx) + Y = sc.VectorSpace((3,), ctx) + Z = sc.VectorSpace((4,), ctx) + A = sc.DenseLinOp(ctx.asarray(np.ones((3, 2))), X, Y, ctx) + B = sc.DenseLinOp(ctx.asarray(np.ones((4, 2))), X, Z, ctx) + + with pytest.raises(ValueError, match="same domain and codomain"): + sc.make_sum((A, B)) + with pytest.raises(ValueError, match="right.codomain == left.domain"): + sc.make_composed(A, B) + + +def test_base_linop_equality_protocol_does_not_raise(): + A = _dense_linop(_ctx()) + + assert (A == None) is False # noqa: E711 + assert A in [A] From 5e816579facf561931ca8a64251c09110555d051 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 01:22:00 -0300 Subject: [PATCH 7/7] Polish LinOp dense materialization --- spacecore/linop/_base.py | 4 ++-- spacecore/linop/_sparse.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index d1a05be..f006e9a 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -163,10 +163,10 @@ def to_dense(self) -> Any: already store the matrix should override this method for efficiency. """ domain_size = prod(self.domain.shape) - zero = self.ops.zeros((domain_size,), dtype=self.dtype) + eye = self.ops.eye(domain_size, dtype=self.dtype) columns = [] for i in range(domain_size): - basis_vector = self.ops.index_set(zero, i, 1, copy=True) + basis_vector = eye[:, i] x = self.domain.unflatten(basis_vector) y = self.apply(x) columns.append(self.codomain.flatten(y)) diff --git a/spacecore/linop/_sparse.py b/spacecore/linop/_sparse.py index 20780ab..7a8162e 100644 --- a/spacecore/linop/_sparse.py +++ b/spacecore/linop/_sparse.py @@ -87,6 +87,22 @@ def _rapply_unchecked(self, y: DenseArray) -> DenseArray: return x1 if self._dom_is_flat else x1.reshape(self.dom.shape) return self.dom.unflatten(x1) + def to_dense(self) -> DenseArray: + """ + Materialize the stored sparse matrix as a dense operator tensor. + + The returned array has shape ``self.codomain.shape + self.domain.shape``. + """ + if hasattr(self.A, "toarray"): + dense = self.A.toarray() + elif hasattr(self.A, "todense"): + dense = self.A.todense() + elif hasattr(self.A, "to_dense"): + dense = self.A.to_dense() + else: + dense = super().to_dense().reshape((self._cod_size, self._dom_size)) + return self.ops.reshape(dense, tuple(self.codomain.shape) + tuple(self.domain.shape)) + def __eq__(self, x: Any) -> bool: if type(x) is type(self): return (self.dom == x.dom