diff --git a/CHANGELOG.md b/CHANGELOG.md index ebeeb67d5..e2bd218fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,21 @@ changelog does not include internal changes that do not affect the user. ### Changed +- **BREAKING**: Removed `norm_eps`, `reg_eps` and `solver` parameters from the `__init__` of + `UPGrad`, `UPGradWeighting`, `DualProj` and `DualProjWeighting` in favor of a `projector` + parameter of type `DualConeProjector`. To update: + ```python + # Before + from torchjd.aggregation import UPGrad + aggregator = UPGrad(norm_eps=1e-6, reg_eps=1e-6, solver="quadprog") + + # After + from torchjd.aggregation import UPGrad + from torchjd.linalg import QuadprogProjector + aggregator = UPGrad(projector=QuadprogProjector(norm_eps=1e-6, reg_eps=1e-6)) + ``` + If you used the default `norm_eps`, `reg_eps` and `solver`, you don't have to change anything and + you will get the same results. - `CAGrad`, `CAGradWeighting`, and `NashMTL` are now always importable from `torchjd.aggregation`, even when their optional dependencies are not installed. Attempting to instantiate them without the required dependencies now raises an `ImportError` with installation instructions, instead of @@ -21,6 +36,9 @@ changelog does not include internal changes that do not affect the user. ### Added +- Added a new abstraction: the `DualConeProjector` abstract base class and its concrete + `QuadprogProjector` implementation, to do the projection of the gradients onto the dual cone, as + required in `UPGrad` and `DualProj`. These classes can be found in `torchjd.linalg`. - Made `WeightedAggregator` and `GramianWeightedAggregator` public. These abstract base classes are now importable from `torchjd.aggregation` and documented. They can be extended to easily implement custom `Aggregator`s. 
diff --git a/docs/source/docs/linalg/dual_cone.rst b/docs/source/docs/linalg/dual_cone.rst new file mode 100644 index 000000000..f7db87ad0 --- /dev/null +++ b/docs/source/docs/linalg/dual_cone.rst @@ -0,0 +1,9 @@ +:hide-toc: + +Dual Cone Projectors +==================== + +.. autoclass:: torchjd.linalg.DualConeProjector + :members: __call__ + +.. autoclass:: torchjd.linalg.QuadprogProjector diff --git a/docs/source/docs/linalg/index.rst b/docs/source/docs/linalg/index.rst index 4446ccea7..94fcce205 100644 --- a/docs/source/docs/linalg/index.rst +++ b/docs/source/docs/linalg/index.rst @@ -10,3 +10,4 @@ linalg matrix.rst psd_matrix.rst + dual_cone.rst diff --git a/src/torchjd/_linalg/__init__.py b/src/torchjd/_linalg/__init__.py index 29b8cd0b3..fce72e8e6 100644 --- a/src/torchjd/_linalg/__init__.py +++ b/src/torchjd/_linalg/__init__.py @@ -1,3 +1,4 @@ +from ._dual_cone import DualConeProjector, QuadprogProjector, projector_or_default from ._generalized_gramian import flatten, movedim, reshape from ._gramian import compute_gramian, normalize, regularize from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor @@ -15,4 +16,7 @@ "flatten", "reshape", "movedim", + "DualConeProjector", + "QuadprogProjector", + "projector_or_default", ] diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/_linalg/_dual_cone.py new file mode 100644 index 000000000..750e97e99 --- /dev/null +++ b/src/torchjd/_linalg/_dual_cone.py @@ -0,0 +1,101 @@ +from abc import ABC, abstractmethod + +import numpy as np +import torch +from qpsolvers import solve_qp +from torch import Tensor + +from ._gramian import normalize, regularize +from ._matrix import PSDMatrix + + +class DualConeProjector(ABC): + """ + Abstract class whose instances are responsible for projecting vectors onto the dual cone of the + rows of a matrix, or rather the dual form of this problem. 
The current default + :class:`~torchjd.linalg.DualConeProjector` is :class:`~torchjd.linalg.QuadprogProjector`. + """ + + @abstractmethod + def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor: + r""" + Computes the weights :math:`w` of the projection of :math:`J^\top u` onto the dual cone of + the rows of :math:`J`, provided :math:`G = J J^\top` and :math:`u`. In other words, this + computes the :math:`w` that satisfies :math:`\pi_J(J^\top u) = J^\top w`, with + :math:`\pi_J` defined in Equation 3 of [1]. + + By Proposition 1 of [1], this is equivalent to solving for :math:`v` the following + quadratic program: + + .. math:: + + \min_{v} \quad & v^\top G v \\ + \text{subject to} \quad & u \preceq v + + Reference: + [1] `Jacobian Descent For Multi-Objective Optimization `_. + + :param U: The tensor of weights corresponding to the vectors to project, of shape + ``[..., m]``. + :param G: The Gramian matrix of shape ``[m, m]``. It must be symmetric and positive + definite. + :return: A tensor of projection weights with the same shape as ``U``. + """ + + +def projector_or_default(projector: DualConeProjector | None) -> DualConeProjector: + if projector is None: + return QuadprogProjector() + return projector + + +class QuadprogProjector(DualConeProjector): + r""" + Solves the quadratic program defined in :meth:`DualConeProjector.__call__` using the + ``quadprog`` QP solver. + + :param norm_eps: A small value to avoid division by zero when normalizing. + :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to + numerical errors when computing the gramian, it might not exactly be positive definite. + This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian + ensures that it is positive definite. 
+ """ + + def __init__( + self, + *, + norm_eps: float = 0.0001, + reg_eps: float = 0.0001, + ) -> None: + self.norm_eps = norm_eps + self.reg_eps = reg_eps + + def __repr__(self) -> str: + return f"QuadprogProjector(norm_eps={self.norm_eps}, reg_eps={self.reg_eps})" + + def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor: + + G = regularize(normalize(G, self.norm_eps), self.reg_eps) + + G_ = _to_array(G) + U_ = _to_array(U) + + W = np.apply_along_axis(lambda u: self._project_weight_vector(u, G_), axis=-1, arr=U_) + + return torch.as_tensor(W, device=G.device, dtype=G.dtype) + + def _project_weight_vector(self, u: np.ndarray, G: np.ndarray) -> np.ndarray: + + m = G.shape[0] + w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver="quadprog") + + if w is None: # This may happen when G has large values. + raise ValueError("Failed to solve the quadratic programming problem.") + + return w + + +def _to_array(tensor: Tensor) -> np.ndarray: + """Transforms a tensor into a numpy array with float64 dtype.""" + + return tensor.cpu().detach().numpy().astype(np.float64) diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index e379f1276..603d34581 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -1,12 +1,11 @@ from torch import Tensor -from torchjd._linalg import normalize, regularize +from torchjd._linalg import DualConeProjector, projector_or_default from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._mixins import _NonDifferentiable -from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import _GramianWeighting @@ -19,31 +18,21 @@ class DualProjWeighting(_NonDifferentiable, _GramianWeighting): :param pref_vector: The preference vector to use. 
If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. - :param norm_eps: A small value to avoid division by zero when normalizing. - :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to - numerical errors when computing the gramian, it might not exactly be positive definite. - This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian - ensures that it is positive definite. - :param solver: The solver used to optimize the underlying optimization problem. + :param projector: The :class:`~torchjd.linalg.DualConeProjector` used tocompute the projection. """ def __init__( self, pref_vector: Tensor | None = None, - norm_eps: float = 0.0001, - reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: super().__init__() self.pref_vector = pref_vector - self.norm_eps = norm_eps - self.reg_eps = reg_eps - self.solver: SUPPORTED_SOLVER = solver + self.projector = projector_or_default(projector) def forward(self, gramian: PSDMatrix, /) -> Tensor: u = self.weighting(gramian) - G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - w = project_weights(u, G, self.solver) + w = self.projector(u, gramian) return w @property @@ -56,26 +45,12 @@ def pref_vector(self, value: Tensor | None) -> None: self._pref_vector = value @property - def norm_eps(self) -> float: - return self._norm_eps + def projector(self) -> DualConeProjector: + return self._projector - @norm_eps.setter - def norm_eps(self, value: float) -> None: - if value < 0: - raise ValueError(f"norm_eps must be non-negative, but got {value}.") - - self._norm_eps = value - - @property - def reg_eps(self) -> float: - return self._reg_eps - - @reg_eps.setter - def reg_eps(self, value: float) -> None: - if value < 0: - raise ValueError(f"reg_eps must be non-negative, but got {value}.") - - self._reg_eps = value + 
@projector.setter + def projector(self, value: DualConeProjector | None) -> None: + self._projector = projector_or_default(value) class DualProj(_NonDifferentiable, GramianWeightedAggregator): @@ -87,12 +62,7 @@ class DualProj(_NonDifferentiable, GramianWeightedAggregator): :param pref_vector: The preference vector used to combine the rows. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. - :param norm_eps: A small value to avoid division by zero when normalizing. - :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to - numerical errors when computing the gramian, it might not exactly be positive definite. - This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian - ensures that it is positive definite. - :param solver: The solver used to optimize the underlying optimization problem. + :param projector: The :class:`~torchjd.linalg.DualConeProjector` used tocompute the projection. 
""" gramian_weighting: DualProjWeighting @@ -100,14 +70,10 @@ class DualProj(_NonDifferentiable, GramianWeightedAggregator): def __init__( self, pref_vector: Tensor | None = None, - norm_eps: float = 0.0001, - reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: - self._solver: SUPPORTED_SOLVER = solver - super().__init__( - DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), + DualProjWeighting(pref_vector, projector=projector), ) @property @@ -119,25 +85,17 @@ def pref_vector(self, value: Tensor | None) -> None: self.gramian_weighting.pref_vector = value @property - def norm_eps(self) -> float: - return self.gramian_weighting.norm_eps - - @norm_eps.setter - def norm_eps(self, value: float) -> None: - self.gramian_weighting.norm_eps = value - - @property - def reg_eps(self) -> float: - return self.gramian_weighting.reg_eps + def projector(self) -> DualConeProjector: + return self.gramian_weighting.projector - @reg_eps.setter - def reg_eps(self, value: float) -> None: - self.gramian_weighting.reg_eps = value + @projector.setter + def projector(self, value: DualConeProjector | None) -> None: + self.gramian_weighting.projector = value def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" - f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})" + f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, projector=" + f"{repr(self.projector)})" ) def __str__(self) -> str: diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index c1e4807e3..001f0db59 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -1,13 +1,12 @@ import torch from torch import Tensor -from torchjd._linalg import normalize, regularize +from torchjd._linalg import DualConeProjector, projector_or_default from torchjd.linalg import PSDMatrix from 
._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._mixins import _NonDifferentiable -from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import _GramianWeighting @@ -20,31 +19,21 @@ class UPGradWeighting(_NonDifferentiable, _GramianWeighting): :param pref_vector: The preference vector to use. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. - :param norm_eps: A small value to avoid division by zero when normalizing. - :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to - numerical errors when computing the gramian, it might not exactly be positive definite. - This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian - ensures that it is positive definite. - :param solver: The solver used to optimize the underlying optimization problem. + :param projector: The :class:`~torchjd.linalg.DualConeProjector` used tocompute the projection. 
""" def __init__( self, pref_vector: Tensor | None = None, - norm_eps: float = 0.0001, - reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: super().__init__() self.pref_vector = pref_vector - self.norm_eps = norm_eps - self.reg_eps = reg_eps - self.solver: SUPPORTED_SOLVER = solver + self.projector = projector_or_default(projector) def forward(self, gramian: PSDMatrix, /) -> Tensor: U = torch.diag(self.weighting(gramian)) - G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - W = project_weights(U, G, self.solver) + W = self.projector(U, gramian) return torch.sum(W, dim=0) @property @@ -57,28 +46,12 @@ def pref_vector(self, value: Tensor | None) -> None: self._pref_vector = value @property - def norm_eps(self) -> float: - return self._norm_eps + def projector(self) -> DualConeProjector: + return self._projector - @norm_eps.setter - def norm_eps(self, value: float) -> None: - - if value < 0: - raise ValueError(f"norm_eps must be non-negative, but got {value}.") - - self._norm_eps = value - - @property - def reg_eps(self) -> float: - return self._reg_eps - - @reg_eps.setter - def reg_eps(self, value: float) -> None: - - if value < 0: - raise ValueError(f"reg_eps must be non-negative, but got {value}.") - - self._reg_eps = value + @projector.setter + def projector(self, value: DualConeProjector | None) -> None: + self._projector = projector_or_default(value) class UPGrad(_NonDifferentiable, GramianWeightedAggregator): @@ -90,12 +63,7 @@ class UPGrad(_NonDifferentiable, GramianWeightedAggregator): :param pref_vector: The preference vector used to combine the projected rows. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. - :param norm_eps: A small value to avoid division by zero when normalizing. - :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. 
Due to - numerical errors when computing the gramian, it might not exactly be positive definite. - This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian - ensures that it is positive definite. - :param solver: The solver used to optimize the underlying optimization problem. + :param projector: The :class:`~torchjd.linalg.DualConeProjector` used tocompute the projection. """ gramian_weighting: UPGradWeighting @@ -103,14 +71,10 @@ class UPGrad(_NonDifferentiable, GramianWeightedAggregator): def __init__( self, pref_vector: Tensor | None = None, - norm_eps: float = 0.0001, - reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: - self._solver: SUPPORTED_SOLVER = solver - super().__init__( - UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), + UPGradWeighting(pref_vector, projector=projector), ) @property @@ -122,25 +86,17 @@ def pref_vector(self, value: Tensor | None) -> None: self.gramian_weighting.pref_vector = value @property - def norm_eps(self) -> float: - return self.gramian_weighting.norm_eps - - @norm_eps.setter - def norm_eps(self, value: float) -> None: - self.gramian_weighting.norm_eps = value - - @property - def reg_eps(self) -> float: - return self.gramian_weighting.reg_eps + def projector(self) -> DualConeProjector: + return self.gramian_weighting.projector - @reg_eps.setter - def reg_eps(self, value: float) -> None: - self.gramian_weighting.reg_eps = value + @projector.setter + def projector(self, value: DualConeProjector | None) -> None: + self.gramian_weighting.projector = value def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" - f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})" + f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, projector=" + f"{repr(self.projector)})" ) def __str__(self) -> str: diff --git 
a/src/torchjd/aggregation/_utils/dual_cone.py b/src/torchjd/aggregation/_utils/dual_cone.py deleted file mode 100644 index b076366be..000000000 --- a/src/torchjd/aggregation/_utils/dual_cone.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Literal, TypeAlias - -import numpy as np -import torch -from qpsolvers import solve_qp -from torch import Tensor - -SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"] - - -def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor: - """ - Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the - rows of a matrix whose Gramian is provided. - - :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`. - :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite. - :param solver: The quadratic programming solver to use. - :return: A tensor of projection weights with the same shape as `U`. - """ - - G_ = _to_array(G) - U_ = _to_array(U) - - W = np.apply_along_axis(lambda u: _project_weight_vector(u, G_, solver), axis=-1, arr=U_) - - return torch.as_tensor(W, device=G.device, dtype=G.dtype) - - -def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: SUPPORTED_SOLVER) -> np.ndarray: - r""" - Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`, - given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies - `\pi_J(J^T u) = J^T w`, with `\pi_J` defined in Equation 3 of [1]. - - By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic program: - minimize v^T G v - subject to u \preceq v - - Reference: - [1] `Jacobian Descent For Multi-Objective Optimization `_. - - :param u: The vector of weights `u` of shape `[m]` corresponding to the vector `J^T u` to - project. - :param G: The Gramian matrix of `J`, equal to `J J^T`, and of shape `[m, m]`. It must be - symmetric and positive definite. 
- :param solver: The quadratic programming solver to use. - """ - - m = G.shape[0] - w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=solver) - - if w is None: # This may happen when G has large values. - raise ValueError("Failed to solve the quadratic programming problem.") - - return w - - -def _to_array(tensor: Tensor) -> np.ndarray: - """Transforms a tensor into a numpy array with float64 dtype.""" - - return tensor.cpu().detach().numpy().astype(np.float64) diff --git a/src/torchjd/linalg/__init__.py b/src/torchjd/linalg/__init__.py index f8238104e..15476b738 100644 --- a/src/torchjd/linalg/__init__.py +++ b/src/torchjd/linalg/__init__.py @@ -1,8 +1,18 @@ """ -This module provides type annotation classes representing tensors with specific structural -properties. +This module provides utility linear algebra methods as well as types to represent specific +structural properties. """ -from torchjd._linalg._matrix import Matrix, PSDMatrix +from torchjd._linalg import ( + DualConeProjector, + Matrix, + PSDMatrix, + QuadprogProjector, +) -__all__ = ["Matrix", "PSDMatrix"] +__all__ = [ + "DualConeProjector", + "Matrix", + "PSDMatrix", + "QuadprogProjector", +] diff --git a/tests/plots/interactive_plotter.py b/tests/plots/interactive_plotter.py index 1c2b62400..c8b3871a9 100644 --- a/tests/plots/interactive_plotter.py +++ b/tests/plots/interactive_plotter.py @@ -11,6 +11,7 @@ from typing_extensions import Unpack from plots._utils import Plotter, angle_to_coord, coord_to_angle +from torchjd._linalg import QuadprogProjector from torchjd.aggregation import ( IMTLG, MGDA, @@ -61,7 +62,7 @@ def main() -> None: "AlignedMTL-RMSE": lambda: AlignedMTL(scale_mode="rmse"), str(CAGrad(c=0.5)): lambda: CAGrad(c=0.5), str(ConFIG()): lambda: ConFIG(), - str(DualProj()): lambda: DualProj(reg_eps=1e-7), + str(DualProj()): lambda: DualProj(projector=QuadprogProjector(reg_eps=1e-7)), str(GradDrop()): lambda: GradDrop(), str(GradVac()): lambda: GradVac(), str(IMTLG()): lambda: 
IMTLG(), @@ -72,7 +73,7 @@ def main() -> None: str(Random()): lambda: Random(), str(Sum()): lambda: Sum(), str(TrimmedMean(trim_number=1)): lambda: TrimmedMean(trim_number=1), - str(UPGrad()): lambda: UPGrad(reg_eps=1e-7), + str(UPGrad()): lambda: UPGrad(projector=QuadprogProjector(reg_eps=1e-7)), } aggregator_strings = list(aggregator_factories.keys()) diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index 34fe8d462..190eaaeb4 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -1,10 +1,10 @@ import torch -from pytest import mark, raises +from pytest import mark from torch import Tensor from utils.tensors import ones_ +from torchjd._linalg import QuadprogProjector from torchjd.aggregation import ConstantWeighting, DualProj -from torchjd.aggregation._dualproj import DualProjWeighting from ._asserts import ( assert_expected_structure, @@ -47,21 +47,20 @@ def test_non_differentiable(aggregator: DualProj, matrix: Tensor) -> None: def test_representations() -> None: - A = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") + A = DualProj(pref_vector=None, projector=QuadprogProjector(norm_eps=0.001, reg_eps=0.01)) assert ( - repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" + repr(A) == "DualProj(pref_vector=None, projector=QuadprogProjector(norm_eps=0.001, " + "reg_eps=0.01))" ) assert str(A) == "DualProj" A = DualProj( pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), - norm_eps=0.0001, - reg_eps=0.0001, - solver="quadprog", + projector=QuadprogProjector(norm_eps=0.001, reg_eps=0.01), ) assert ( - repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, " - "solver='quadprog')" + repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), projector=QuadprogProjector(" + "norm_eps=0.001, reg_eps=0.01))" ) assert str(A) == "DualProj([1., 2., 3.])" @@ -75,39 +74,14 @@ def 
test_pref_vector_setter_updates_value() -> None: assert A.gramian_weighting.weighting.weights is new_pref -def test_norm_eps_setter_updates_value() -> None: +def test_projector_getter_returns_default() -> None: A = DualProj() - A.norm_eps = 0.25 - assert A.norm_eps == 0.25 - assert A.gramian_weighting.norm_eps == 0.25 + assert isinstance(A.projector, QuadprogProjector) -def test_reg_eps_setter_updates_value() -> None: +def test_projector_setter_updates_value() -> None: A = DualProj() - A.reg_eps = 0.25 - assert A.reg_eps == 0.25 - assert A.gramian_weighting.reg_eps == 0.25 - - -def test_norm_eps_setter_rejects_negative() -> None: - A = DualProj() - with raises(ValueError, match="norm_eps"): - A.norm_eps = -1e-9 - - -def test_reg_eps_setter_rejects_negative() -> None: - A = DualProj() - with raises(ValueError, match="reg_eps"): - A.reg_eps = -1e-9 - - -def test_weighting_norm_eps_setter_rejects_negative() -> None: - W = DualProjWeighting() - with raises(ValueError, match="norm_eps"): - W.norm_eps = -1e-9 - - -def test_weighting_reg_eps_setter_rejects_negative() -> None: - W = DualProjWeighting() - with raises(ValueError, match="reg_eps"): - W.reg_eps = -1e-9 + new_projector = QuadprogProjector(norm_eps=0.001, reg_eps=0.01) + A.projector = new_projector + assert A.projector is new_projector + assert A.gramian_weighting.projector is new_projector diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index b776071d3..6d22359f3 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -3,7 +3,7 @@ from torch.testing import assert_close from utils.tensors import ones_, randn_ -from torchjd._linalg import compute_gramian +from torchjd._linalg import QuadprogProjector, compute_gramian from torchjd.aggregation import PCGrad from torchjd.aggregation._pcgrad import PCGradWeighting from torchjd.aggregation._upgrad import UPGradWeighting @@ -53,9 +53,7 @@ def test_equivalence_upgrad_sum_two_rows(shape: 
tuple[int, int]) -> None: pc_grad_weighting = PCGradWeighting() upgrad_sum_weighting = UPGradWeighting( ones_((2,)), - norm_eps=0.0, - reg_eps=0.0, - solver="quadprog", + projector=QuadprogProjector(norm_eps=0.0, reg_eps=0.0), ) result = pc_grad_weighting(gramian) diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 075680a02..b639763c6 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -1,10 +1,10 @@ import torch -from pytest import mark, raises +from pytest import mark from torch import Tensor from utils.tensors import ones_ +from torchjd._linalg import QuadprogProjector from torchjd.aggregation import ConstantWeighting, UPGrad -from torchjd.aggregation._upgrad import UPGradWeighting from ._asserts import ( assert_expected_structure, @@ -53,19 +53,20 @@ def test_non_differentiable(aggregator: UPGrad, matrix: Tensor) -> None: def test_representations() -> None: - A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") - assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" + A = UPGrad(pref_vector=None, projector=QuadprogProjector(norm_eps=0.001, reg_eps=0.01)) + assert ( + repr(A) == "UPGrad(pref_vector=None, projector=QuadprogProjector(norm_eps=0.001, " + "reg_eps=0.01))" + ) assert str(A) == "UPGrad" A = UPGrad( pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), - norm_eps=0.0001, - reg_eps=0.0001, - solver="quadprog", + projector=QuadprogProjector(norm_eps=0.001, reg_eps=0.01), ) assert ( - repr(A) == "UPGrad(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, " - "solver='quadprog')" + repr(A) == "UPGrad(pref_vector=tensor([1., 2., 3.]), projector=QuadprogProjector(" + "norm_eps=0.001, reg_eps=0.01))" ) assert str(A) == "UPGrad([1., 2., 3.])" @@ -79,37 +80,14 @@ def test_pref_vector_setter_updates_value() -> None: assert A.gramian_weighting.weighting.weights is new_pref -def 
test_norm_eps_setter_updates_value() -> None: - A = UPGrad() - A.norm_eps = 0.25 - assert A.norm_eps == 0.25 - - -def test_reg_eps_setter_updates_value() -> None: - A = UPGrad() - A.reg_eps = 0.25 - assert A.reg_eps == 0.25 - - -def test_norm_eps_setter_rejects_negative() -> None: +def test_projector_getter_returns_default() -> None: A = UPGrad() - with raises(ValueError, match="norm_eps"): - A.norm_eps = -1e-9 + assert isinstance(A.projector, QuadprogProjector) -def test_reg_eps_setter_rejects_negative() -> None: +def test_projector_setter_updates_value() -> None: A = UPGrad() - with raises(ValueError, match="reg_eps"): - A.reg_eps = -1e-9 - - -def test_weighting_norm_eps_setter_rejects_negative() -> None: - W = UPGradWeighting() - with raises(ValueError, match="norm_eps"): - W.norm_eps = -1e-9 - - -def test_weighting_reg_eps_setter_rejects_negative() -> None: - W = UPGradWeighting() - with raises(ValueError, match="reg_eps"): - W.reg_eps = -1e-9 + new_projector = QuadprogProjector(norm_eps=0.001, reg_eps=0.01) + A.projector = new_projector + assert A.projector is new_projector + assert A.gramian_weighting.projector is new_projector diff --git a/tests/unit/aggregation/_utils/test_dual_cone.py b/tests/unit/linalg/test_dual_cone.py similarity index 68% rename from tests/unit/aggregation/_utils/test_dual_cone.py rename to tests/unit/linalg/test_dual_cone.py index 68a8a75d7..e42956c1f 100644 --- a/tests/unit/aggregation/_utils/test_dual_cone.py +++ b/tests/unit/linalg/test_dual_cone.py @@ -1,14 +1,17 @@ +from typing import cast + import numpy as np import torch from pytest import mark, raises from torch.testing import assert_close from utils.tensors import rand_, randn_ -from torchjd.aggregation._utils.dual_cone import _project_weight_vector, project_weights +from torchjd._linalg import DualConeProjector, PSDMatrix, QuadprogProjector, compute_gramian +@mark.parametrize("projector", [QuadprogProjector(reg_eps=0.0, norm_eps=0.0)]) @mark.parametrize("shape", [(5, 7), (9, 
37), (2, 14), (32, 114), (50, 100)]) -def test_solution_weights(shape: tuple[int, int]) -> None: +def test_solution_weights(projector: DualConeProjector, shape: tuple[int, int]) -> None: r""" Tests that `_project_weights` returns valid weights corresponding to the projection onto the dual cone of a matrix with the specified shape. @@ -31,10 +34,10 @@ def test_solution_weights(shape: tuple[int, int]) -> None: """ J = randn_(shape) - G = J @ J.T + G = compute_gramian(J) u = rand_(shape[0]) - w = project_weights(u, G, "quadprog") + w = projector(u, G) dual_gap = w - u # Dual feasibility @@ -52,25 +55,30 @@ def test_solution_weights(shape: tuple[int, int]) -> None: assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0) +@mark.parametrize("projector", [QuadprogProjector(reg_eps=0.0, norm_eps=0.0)]) @mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)]) @mark.parametrize("scaling", [2 ** (-4), 2 ** (-2), 2**2, 2**4]) -def test_scale_invariant(shape: tuple[int, int], scaling: float) -> None: +def test_scale_invariant( + projector: DualConeProjector, shape: tuple[int, int], scaling: float +) -> None: """ Tests that `_project_weights` is invariant under scaling. """ J = randn_(shape) - G = J @ J.T + G = compute_gramian(J) + scaled_G = cast(PSDMatrix, scaling * G) u = rand_(shape[0]) - w = project_weights(u, G, "quadprog") - w_scaled = project_weights(u, scaling * G, "quadprog") + w = projector(u, G) + w_scaled = projector(u, scaled_G) assert_close(w_scaled, w) +@mark.parametrize("projector", [QuadprogProjector(reg_eps=0.0, norm_eps=0.0)]) @mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)]) -def test_tensorization_shape(shape: tuple[int, ...]) -> None: +def test_tensorization_shape(projector: DualConeProjector, shape: tuple[int, ...]) -> None: """ Tests that applying `_project_weights` on a tensor is equivalent to applying it on the tensor reshaped as matrix and to reshape the result back to the original tensor's shape. 
@@ -80,18 +88,23 @@ def test_tensorization_shape(shape: tuple[int, ...]) -> None: U_tensor = randn_(shape) U_matrix = U_tensor.reshape([-1, shape[-1]]) - G = matrix @ matrix.T + G = compute_gramian(matrix) - W_tensor = project_weights(U_tensor, G, "quadprog") - W_matrix = project_weights(U_matrix, G, "quadprog") + W_tensor = projector(U_tensor, G) + W_matrix = projector(U_matrix, G) assert_close(W_matrix.reshape(shape), W_tensor) -def test_project_weight_vector_failure() -> None: - """Tests that `_project_weight_vector` raises an error when the input G has too large values.""" +def test_qp_solver_based_failure() -> None: + """ + Tests that `QuadprogProjector._project_weight_vector` raises an error when the input G has too large + values. + """ + + projector = QuadprogProjector() large_J = np.random.randn(10, 100) * 1e5 large_G = large_J @ large_J.T with raises(ValueError): - _project_weight_vector(np.ones(10), large_G, "quadprog") + projector._project_weight_vector(np.ones(10), large_G)