From 336b47cc11d0383ae2986a7a7bece20673d6eeeb Mon Sep 17 00:00:00 2001
From: Pierre Quinton
Date: Mon, 11 May 2026 10:05:51 +0200
Subject: [PATCH 01/20] Add DualConeProjector

---
 src/torchjd/_linalg/_dual_cone.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)
 create mode 100644 src/torchjd/_linalg/_dual_cone.py

diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/_linalg/_dual_cone.py
new file mode 100644
index 00000000..4ff1167e
--- /dev/null
+++ b/src/torchjd/_linalg/_dual_cone.py
@@ -0,0 +1,27 @@
+from abc import ABC, abstractmethod
+
+from torch import Tensor
+
+from ._matrix import PSDMatrix
+
+
+class DualConeProjector(ABC):
+    @abstractmethod
+    def project_weights(U: Tensor, G: PSDMatrix) -> Tensor:
+        r"""
+        Computes the weights `w` of the projection of `J^T u` onto the dual cone of
+        the rows of `J`, provided `G = J J^T` and `u`. In other words, this computes the `w` that
+        satisfies `\pi_J(J^T u) = J^T w`, with `\pi_J` defined in Equation 3 of [1].
+
+        By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic
+        program:
+            minimize v^T G v
+            subject to u \preceq v
+
+        Reference:
+        [1] `Jacobian Descent For Multi-Objective Optimization `_.
+
+        :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`.
+        :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite.
+        :return: A tensor of projection weights with the same shape as `U`.
+        """

From f1076a73f22318b89bbe71fe811a0dc5ea245e03 Mon Sep 17 00:00:00 2001
From: Pierre Quinton
Date: Mon, 11 May 2026 10:28:41 +0200
Subject: [PATCH 02/20] Implement and use QPSolverBased. Set as default.

---
 src/torchjd/_linalg/__init__.py             |  4 ++
 src/torchjd/_linalg/_dual_cone.py           | 44 ++++++++++++-
 src/torchjd/aggregation/_dualproj.py        | 17 +++--
 src/torchjd/aggregation/_upgrad.py          | 17 +++--
 src/torchjd/aggregation/_utils/dual_cone.py | 62 ------------------
 tests/unit/aggregation/test_dualproj.py     | 12 ++--
 tests/unit/aggregation/test_pcgrad.py       |  2 +-
 tests/unit/aggregation/test_upgrad.py       | 14 +++--
 .../_utils => linalg}/test_dual_cone.py     | 34 ++++++----
 9 files changed, 102 insertions(+), 104 deletions(-)
 delete mode 100644 src/torchjd/aggregation/_utils/dual_cone.py
 rename tests/unit/{aggregation/_utils => linalg}/test_dual_cone.py (72%)

diff --git a/src/torchjd/_linalg/__init__.py b/src/torchjd/_linalg/__init__.py
index 29b8cd0b..2035b7b2 100644
--- a/src/torchjd/_linalg/__init__.py
+++ b/src/torchjd/_linalg/__init__.py
@@ -1,3 +1,4 @@
+from ._dual_cone import DualConeProjector, QPSolverBased, projector_or_default
 from ._generalized_gramian import flatten, movedim, reshape
 from ._gramian import compute_gramian, normalize, regularize
 from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor
@@ -15,4 +16,7 @@
     "flatten",
     "reshape",
     "movedim",
+    "DualConeProjector",
+    "QPSolverBased",
+    "projector_or_default",
 ]

diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/_linalg/_dual_cone.py
index 4ff1167e..62aabbf3 100644
--- a/src/torchjd/_linalg/_dual_cone.py
+++ b/src/torchjd/_linalg/_dual_cone.py
@@ -1,5 +1,9 @@
 from abc import ABC, abstractmethod
+from typing import Literal, TypeAlias

+import numpy as np
+import torch
+from qpsolvers import solve_qp
 from torch import Tensor

 from ._matrix import PSDMatrix


 class DualConeProjector(ABC):
     @abstractmethod
-    def project_weights(U: Tensor, G: PSDMatrix) -> Tensor:
+    def project_weights(self, U: Tensor, G: PSDMatrix) -> Tensor:
         r"""
Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`, provided `G = J J^T` and `u`. In other words, this computes the `w` that @@ -25,3 +29,41 @@ def project_weights(U: Tensor, G: PSDMatrix) -> Tensor: :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite. :return: A tensor of projection weights with the same shape as `U`. """ + + +def projector_or_default(projector: DualConeProjector | None) -> DualConeProjector: + if projector is None: + return QPSolverBased("quadprog") + return projector + + +class QPSolverBased(DualConeProjector): + SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"] + + def __init__(self, solver: SUPPORTED_SOLVER) -> None: + self.solver = solver + + def project_weights(self, U: Tensor, G: Tensor) -> Tensor: + + G_ = _to_array(G) + U_ = _to_array(U) + + W = np.apply_along_axis(lambda u: self._project_weight_vector(u, G_), axis=-1, arr=U_) + + return torch.as_tensor(W, device=G.device, dtype=G.dtype) + + def _project_weight_vector(self, u: np.ndarray, G: np.ndarray) -> np.ndarray: + + m = G.shape[0] + w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=self.solver) + + if w is None: # This may happen when G has large values. + raise ValueError("Failed to solve the quadratic programming problem.") + + return w + + +def _to_array(tensor: Tensor) -> np.ndarray: + """Transforms a tensor into a numpy array with float64 dtype.""" + + return tensor.cpu().detach().numpy().astype(np.float64) diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index e379f127..3dc33d05 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -1,12 +1,11 @@ from torch import Tensor -from torchjd._linalg import normalize, regularize +from torchjd._linalg import DualConeProjector, normalize, projector_or_default, regularize from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._mixins import _NonDifferentiable -from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import _GramianWeighting @@ -32,18 +31,18 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: super().__init__() self.pref_vector = pref_vector self.norm_eps = norm_eps self.reg_eps = reg_eps - self.solver: SUPPORTED_SOLVER = solver + self.projector = projector_or_default(projector) def forward(self, gramian: PSDMatrix, /) -> Tensor: u = self.weighting(gramian) G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - w = project_weights(u, G, self.solver) + w = self.projector.project_weights(u, G) return w @property @@ -102,12 +101,10 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: - self._solver: SUPPORTED_SOLVER = solver - super().__init__( - DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), + DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, projector=projector), ) @property @@ -137,7 +134,7 @@ def reg_eps(self, value: float) -> None: def __repr__(self) -> str: return ( f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, 
norm_eps=" - f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})" + f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self.projector)})" ) def __str__(self) -> str: diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index c1e4807e..29a4e654 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -1,13 +1,12 @@ import torch from torch import Tensor -from torchjd._linalg import normalize, regularize +from torchjd._linalg import DualConeProjector, normalize, projector_or_default, regularize from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._mixins import _NonDifferentiable -from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import _GramianWeighting @@ -33,18 +32,18 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: super().__init__() self.pref_vector = pref_vector self.norm_eps = norm_eps self.reg_eps = reg_eps - self.solver: SUPPORTED_SOLVER = solver + self.projector = projector_or_default(projector) def forward(self, gramian: PSDMatrix, /) -> Tensor: U = torch.diag(self.weighting(gramian)) G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - W = project_weights(U, G, self.solver) + W = self.projector.project_weights(U, G) return torch.sum(W, dim=0) @property @@ -105,12 +104,10 @@ def __init__( pref_vector: Tensor | None = None, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", + projector: DualConeProjector | None = None, ) -> None: - self._solver: SUPPORTED_SOLVER = solver - super().__init__( - UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), + UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, projector=projector), ) @property @@ -140,7 +137,7 @@ def reg_eps(self, value: float) -> None: def __repr__(self) -> str: return ( f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" - f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})" + f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self.projector)})" ) def __str__(self) -> str: diff --git a/src/torchjd/aggregation/_utils/dual_cone.py b/src/torchjd/aggregation/_utils/dual_cone.py deleted file mode 100644 index b076366b..00000000 --- a/src/torchjd/aggregation/_utils/dual_cone.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Literal, TypeAlias - -import numpy as np -import torch -from qpsolvers import solve_qp -from torch import Tensor - -SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"] - - -def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor: - """ - Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the - rows of a matrix whose Gramian is provided. - - :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`. - :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite. - :param solver: The quadratic programming solver to use. - :return: A tensor of projection weights with the same shape as `U`. 
- """ - - G_ = _to_array(G) - U_ = _to_array(U) - - W = np.apply_along_axis(lambda u: _project_weight_vector(u, G_, solver), axis=-1, arr=U_) - - return torch.as_tensor(W, device=G.device, dtype=G.dtype) - - -def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: SUPPORTED_SOLVER) -> np.ndarray: - r""" - Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`, - given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies - `\pi_J(J^T u) = J^T w`, with `\pi_J` defined in Equation 3 of [1]. - - By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic program: - minimize v^T G v - subject to u \preceq v - - Reference: - [1] `Jacobian Descent For Multi-Objective Optimization `_. - - :param u: The vector of weights `u` of shape `[m]` corresponding to the vector `J^T u` to - project. - :param G: The Gramian matrix of `J`, equal to `J J^T`, and of shape `[m, m]`. It must be - symmetric and positive definite. - :param solver: The quadratic programming solver to use. - """ - - m = G.shape[0] - w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=solver) - - if w is None: # This may happen when G has large values. - raise ValueError("Failed to solve the quadratic programming problem.") - - return w - - -def _to_array(tensor: Tensor) -> np.ndarray: - """Transforms a tensor into a numpy array with float64 dtype.""" - - return tensor.cpu().detach().numpy().astype(np.float64) diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index 34fe8d46..7852fa59 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -3,6 +3,7 @@ from torch import Tensor from utils.tensors import ones_ +from torchjd._linalg import QPSolverBased from torchjd.aggregation import ConstantWeighting, DualProj from torchjd.aggregation._dualproj import DualProjWeighting @@ -47,9 +48,12 @@ def test_non_differentiable(aggregator: DualProj, matrix: Tensor) -> None: def test_representations() -> None: - A = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") + A = DualProj( + pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=QPSolverBased("quadprog") + ) assert ( - repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" + repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=" + "QPSolverBased('quadprog'))" ) assert str(A) == "DualProj" @@ -57,11 +61,11 @@ def test_representations() -> None: pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), norm_eps=0.0001, reg_eps=0.0001, - solver="quadprog", + projector=QPSolverBased("quadprog"), ) assert ( repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, " - "solver='quadprog')" + "projector=QPSolverBased('quadprog'))" ) assert str(A) == "DualProj([1., 2., 3.])" diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index b776071d..f7961e8c 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -55,7 +55,7 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]) -> None: ones_((2,)), norm_eps=0.0, reg_eps=0.0, - solver="quadprog", + projector="quadprog", ) result = pc_grad_weighting(gramian) diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 075680a0..c04d32f1 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ 
b/tests/unit/aggregation/test_upgrad.py @@ -3,6 +3,7 @@ from torch import Tensor from utils.tensors import ones_ +from torchjd._linalg import QPSolverBased from torchjd.aggregation import ConstantWeighting, UPGrad from torchjd.aggregation._upgrad import UPGradWeighting @@ -53,19 +54,24 @@ def test_non_differentiable(aggregator: UPGrad, matrix: Tensor) -> None: def test_representations() -> None: - A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") - assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" + A = UPGrad( + pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=QPSolverBased("quadprog") + ) + assert ( + repr(A) + == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=QPSolverBased('quadprog'))" + ) assert str(A) == "UPGrad" A = UPGrad( pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), norm_eps=0.0001, reg_eps=0.0001, - solver="quadprog", + projector=QPSolverBased("quadprog"), ) assert ( repr(A) == "UPGrad(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, " - "solver='quadprog')" + "projector=QPSolverBased('quadprog'))" ) assert str(A) == "UPGrad([1., 2., 3.])" diff --git a/tests/unit/aggregation/_utils/test_dual_cone.py b/tests/unit/linalg/test_dual_cone.py similarity index 72% rename from tests/unit/aggregation/_utils/test_dual_cone.py rename to tests/unit/linalg/test_dual_cone.py index 68a8a75d..c39029fa 100644 --- a/tests/unit/aggregation/_utils/test_dual_cone.py +++ b/tests/unit/linalg/test_dual_cone.py @@ -4,11 +4,12 @@ from torch.testing import assert_close from utils.tensors import rand_, randn_ -from torchjd.aggregation._utils.dual_cone import _project_weight_vector, project_weights +from torchjd._linalg import DualConeProjector, QPSolverBased +@mark.parametrize("projector", [QPSolverBased("quadprog")]) @mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)]) -def test_solution_weights(shape: tuple[int, int]) -> None: +def test_solution_weights(projector: DualConeProjector, shape: tuple[int, int]) -> None: r""" Tests that `_project_weights` returns valid weights corresponding to the projection onto the dual cone of a matrix with the specified shape. @@ -34,7 +35,7 @@ def test_solution_weights(shape: tuple[int, int]) -> None: G = J @ J.T u = rand_(shape[0]) - w = project_weights(u, G, "quadprog") + w = projector.project_weights(u, G) dual_gap = w - u # Dual feasibility @@ -52,9 +53,12 @@ def test_solution_weights(shape: tuple[int, int]) -> None: assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0) +@mark.parametrize("projector", [QPSolverBased("quadprog")]) @mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)]) @mark.parametrize("scaling", [2 ** (-4), 2 ** (-2), 2**2, 2**4]) -def test_scale_invariant(shape: tuple[int, int], scaling: float) -> None: +def test_scale_invariant( + projector: DualConeProjector, shape: tuple[int, int], scaling: float +) -> None: """ Tests that `_project_weights` is invariant under scaling. 
""" @@ -63,14 +67,15 @@ def test_scale_invariant(shape: tuple[int, int], scaling: float) -> None: G = J @ J.T u = rand_(shape[0]) - w = project_weights(u, G, "quadprog") - w_scaled = project_weights(u, scaling * G, "quadprog") + w = projector.project_weights(u, G) + w_scaled = projector.project_weights(u, scaling * G) assert_close(w_scaled, w) +@mark.parametrize("projector", [QPSolverBased("quadprog")]) @mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)]) -def test_tensorization_shape(shape: tuple[int, ...]) -> None: +def test_tensorization_shape(projector: DualConeProjector, shape: tuple[int, ...]) -> None: """ Tests that applying `_project_weights` on a tensor is equivalent to applying it on the tensor reshaped as matrix and to reshape the result back to the original tensor's shape. @@ -82,16 +87,21 @@ def test_tensorization_shape(shape: tuple[int, ...]) -> None: G = matrix @ matrix.T - W_tensor = project_weights(U_tensor, G, "quadprog") - W_matrix = project_weights(U_matrix, G, "quadprog") + W_tensor = projector.project_weights(U_tensor, G) + W_matrix = projector.project_weights(U_matrix, G) assert_close(W_matrix.reshape(shape), W_tensor) -def test_project_weight_vector_failure() -> None: - """Tests that `_project_weight_vector` raises an error when the input G has too large values.""" +def test_qp_solver_based_failure() -> None: + """ + Tests that `QPSolverBased._project_weight_vector` raises an error when the input G has too large + values. + """ + + projector = QPSolverBased("quadprog") large_J = np.random.randn(10, 100) * 1e5 large_G = large_J @ large_J.T with raises(ValueError): - _project_weight_vector(np.ones(10), large_G, "quadprog") + projector._project_weight_vector(np.ones(10), large_G) From 3343ebb8dec5386b9cf5190dd7dce56fa9c3531d Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 11 May 2026 11:16:23 +0200 Subject: [PATCH 03/20] add getters and setters for projectors in UPGrad and DualProj --- src/torchjd/_linalg/_dual_cone.py | 3 +++ src/torchjd/aggregation/_dualproj.py | 18 +++++++++++++++++- src/torchjd/aggregation/_upgrad.py | 18 +++++++++++++++++- 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/_linalg/_dual_cone.py index 62aabbf3..7cbb673a 100644 --- a/src/torchjd/_linalg/_dual_cone.py +++ b/src/torchjd/_linalg/_dual_cone.py @@ -43,6 +43,9 @@ class QPSolverBased(DualConeProjector): def __init__(self, solver: SUPPORTED_SOLVER) -> None: self.solver = solver + def __repr__(self) -> str: + return f"QPSolverBased({repr(self.solver)})" + def project_weights(self, U: Tensor, G: Tensor) -> Tensor: G_ = _to_array(G) diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 3dc33d05..e6117a77 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -76,6 +76,14 @@ def reg_eps(self, value: float) -> None: self._reg_eps = value + @property + def projector(self) -> DualConeProjector: + return self._projector + + @projector.setter + def projector(self, value: DualConeProjector | None) -> None: + self._projector = projector_or_default(value) + class DualProj(_NonDifferentiable, GramianWeightedAggregator): r""" @@ -131,10 +139,18 @@ def reg_eps(self) -> float: def reg_eps(self, value: float) -> None: self.gramian_weighting.reg_eps = value + @property + def projector(self) -> DualConeProjector: + return self.gramian_weighting.projector + + @projector.setter + def projector(self, value: DualConeProjector | None) -> 
None: + self.gramian_weighting.projector = value + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" - f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self.projector)})" + f"{self.norm_eps}, reg_eps={self.reg_eps}, projector={repr(self.projector)})" ) def __str__(self) -> str: diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 29a4e654..a2d28515 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -79,6 +79,14 @@ def reg_eps(self, value: float) -> None: self._reg_eps = value + @property + def projector(self) -> DualConeProjector: + return self._projector + + @projector.setter + def projector(self, value: DualConeProjector | None) -> None: + self._projector = projector_or_default(value) + class UPGrad(_NonDifferentiable, GramianWeightedAggregator): r""" @@ -134,10 +142,18 @@ def reg_eps(self) -> float: def reg_eps(self, value: float) -> None: self.gramian_weighting.reg_eps = value + @property + def projector(self) -> DualConeProjector: + return self.gramian_weighting.projector + + @projector.setter + def projector(self, value: DualConeProjector | None) -> None: + self.gramian_weighting.projector = value + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" - f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self.projector)})" + f"{self.norm_eps}, reg_eps={self.reg_eps}, projector={repr(self.projector)})" ) def __str__(self) -> str: From 25d0c916e468539b28e38bc06c3943745c2be4cc Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 11 May 2026 11:43:45 +0200 Subject: [PATCH 04/20] fix typing --- tests/unit/aggregation/test_pcgrad.py | 4 ++-- tests/unit/linalg/test_dual_cone.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index f7961e8c..819f2be0 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -3,7 +3,7 @@ from torch.testing import assert_close from utils.tensors import ones_, randn_ -from torchjd._linalg import compute_gramian +from torchjd._linalg import QPSolverBased, compute_gramian from torchjd.aggregation import PCGrad from torchjd.aggregation._pcgrad import PCGradWeighting from torchjd.aggregation._upgrad import UPGradWeighting @@ -55,7 +55,7 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]) -> None: ones_((2,)), norm_eps=0.0, reg_eps=0.0, - projector="quadprog", + projector=QPSolverBased("quadprog"), ) result = pc_grad_weighting(gramian) diff --git a/tests/unit/linalg/test_dual_cone.py b/tests/unit/linalg/test_dual_cone.py index c39029fa..6faa25af 100644 --- a/tests/unit/linalg/test_dual_cone.py +++ b/tests/unit/linalg/test_dual_cone.py @@ -1,10 +1,12 @@ +from typing import cast + import numpy as np import torch from pytest import mark, raises from torch.testing import assert_close from utils.tensors import rand_, randn_ -from torchjd._linalg import DualConeProjector, QPSolverBased +from torchjd._linalg import DualConeProjector, PSDMatrix, QPSolverBased, compute_gramian @mark.parametrize("projector", [QPSolverBased("quadprog")]) @@ -32,7 +34,7 @@ def test_solution_weights(projector: DualConeProjector, shape: tuple[int, int]) """ J = randn_(shape) - G = J @ J.T + G = compute_gramian(J) u = rand_(shape[0]) w = projector.project_weights(u, G) @@ -64,11 +66,12 @@ def test_scale_invariant( """ J 
= randn_(shape) - G = J @ J.T + G = compute_gramian(J) + scaled_G = cast(PSDMatrix, scaling * G) u = rand_(shape[0]) w = projector.project_weights(u, G) - w_scaled = projector.project_weights(u, scaling * G) + w_scaled = projector.project_weights(u, scaled_G) assert_close(w_scaled, w) @@ -85,7 +88,7 @@ def test_tensorization_shape(projector: DualConeProjector, shape: tuple[int, ... U_tensor = randn_(shape) U_matrix = U_tensor.reshape([-1, shape[-1]]) - G = matrix @ matrix.T + G = compute_gramian(matrix) W_tensor = projector.project_weights(U_tensor, G) W_matrix = projector.project_weights(U_matrix, G) From c63571d82abe3dd6c963495a1899b9ea0221b3d0 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 11 May 2026 11:46:58 +0200 Subject: [PATCH 05/20] Make PCGrad test use default Projector --- tests/unit/aggregation/test_pcgrad.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index 819f2be0..ca939116 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -3,7 +3,7 @@ from torch.testing import assert_close from utils.tensors import ones_, randn_ -from torchjd._linalg import QPSolverBased, compute_gramian +from torchjd._linalg import compute_gramian from torchjd.aggregation import PCGrad from torchjd.aggregation._pcgrad import PCGradWeighting from torchjd.aggregation._upgrad import UPGradWeighting @@ -55,7 +55,6 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]) -> None: ones_((2,)), norm_eps=0.0, reg_eps=0.0, - projector=QPSolverBased("quadprog"), ) result = pc_grad_weighting(gramian) From b2e6d2a6034ee81fb016f9054f83bbeeda4a14f3 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 12 May 2026 10:30:14 +0200 Subject: [PATCH 06/20] rename to __call__ --- src/torchjd/_linalg/_dual_cone.py | 4 ++-- src/torchjd/aggregation/_dualproj.py | 2 +- src/torchjd/aggregation/_upgrad.py | 2 +- tests/unit/linalg/test_dual_cone.py | 10 +++++----- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/_linalg/_dual_cone.py index 7cbb673a..9f2db7c8 100644 --- a/src/torchjd/_linalg/_dual_cone.py +++ b/src/torchjd/_linalg/_dual_cone.py @@ -11,7 +11,7 @@ class DualConeProjector(ABC): @abstractmethod - def project_weights(self, U: Tensor, G: PSDMatrix) -> Tensor: + def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor: r""" Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`, provided `G = J J^T` and `u`. 
In other words, this computes the `w` that
@@ -46,7 +46,7 @@ def __init__(self, solver: SUPPORTED_SOLVER) -> None:
     def __repr__(self) -> str:
         return f"QPSolverBased({repr(self.solver)})"

-    def project_weights(self, U: Tensor, G: Tensor) -> Tensor:
+    def __call__(self, U: Tensor, G: Tensor) -> Tensor:

diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py
index e6117a77..b25a5874 100644
--- a/src/torchjd/aggregation/_dualproj.py
+++ b/src/torchjd/aggregation/_dualproj.py
@@ -42,7 +42,7 @@ def __init__(
     def forward(self, gramian: PSDMatrix, /) -> Tensor:
         u = self.weighting(gramian)
         G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
-        w = self.projector.project_weights(u, G)
+        w = self.projector(u, G)
         return w

     @property
diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py
index a2d28515..b08c3802 100644
--- a/src/torchjd/aggregation/_upgrad.py
+++ b/src/torchjd/aggregation/_upgrad.py
@@ -43,7 +43,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
         U = torch.diag(self.weighting(gramian))
         G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
-        W = self.projector.project_weights(U, G)
+        W = self.projector(U, G)
         return torch.sum(W, dim=0)

     @property
diff --git a/tests/unit/linalg/test_dual_cone.py b/tests/unit/linalg/test_dual_cone.py
index 6faa25af..93d0f439 100644
--- a/tests/unit/linalg/test_dual_cone.py
+++ b/tests/unit/linalg/test_dual_cone.py
@@ -37,7 +37,7 @@ def test_solution_weights(projector: DualConeProjector, shape: tuple[int, int])
     G = compute_gramian(J)
     u = rand_(shape[0])

-    w = projector.project_weights(u, G)
+    w = projector(u, G)
     dual_gap = w - u

     # Dual feasibility
@@ -70,8 +70,8 @@ def test_scale_invariant(
     scaled_G = cast(PSDMatrix, scaling * G)
     u = rand_(shape[0])

-    w = projector.project_weights(u, G)
-    w_scaled = projector.project_weights(u, scaled_G)
+    w = projector(u, G)
+    w_scaled = projector(u, scaled_G)

     assert_close(w_scaled, w)

@@ -90,8 +90,8 @@ def test_tensorization_shape(projector: DualConeProjector, shape: tuple[int, ...
     G = compute_gramian(matrix)

-    W_tensor = projector.project_weights(U_tensor, G)
-    W_matrix = projector.project_weights(U_matrix, G)
+    W_tensor = projector(U_tensor, G)
+    W_matrix = projector(U_matrix, G)

     assert_close(W_matrix.reshape(shape), W_tensor)

From 2f05e890a6dcf6dd27d60dc52b8a5cab0ade832b Mon Sep 17 00:00:00 2001
From: Pierre Quinton
Date: Tue, 12 May 2026 10:55:20 +0200
Subject: [PATCH 07/20] Move normalization and regularization to QPSolverBased.
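
With this change, the projector owns the numerical-stability parameters, so they are
configured on the projector rather than on the weighting. Illustrative sketch of the
resulting usage (not part of this diff):

    from torchjd._linalg import QPSolverBased
    from torchjd.aggregation import UPGrad

    # norm_eps and reg_eps are now keyword-only arguments of the projector.
    aggregator = UPGrad(projector=QPSolverBased(norm_eps=1e-4, reg_eps=1e-4))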
--- src/torchjd/_linalg/_dual_cone.py | 26 ++++++++++++++++++++++--- src/torchjd/aggregation/_dualproj.py | 23 +++------------------- src/torchjd/aggregation/_upgrad.py | 23 +++------------------- tests/unit/aggregation/test_dualproj.py | 16 ++++----------- tests/unit/aggregation/test_pcgrad.py | 5 ++--- tests/unit/aggregation/test_upgrad.py | 16 ++++----------- tests/unit/linalg/test_dual_cone.py | 8 ++++---- 7 files changed, 43 insertions(+), 74 deletions(-) diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/_linalg/_dual_cone.py index 9f2db7c8..f69e504d 100644 --- a/src/torchjd/_linalg/_dual_cone.py +++ b/src/torchjd/_linalg/_dual_cone.py @@ -6,6 +6,7 @@ from qpsolvers import solve_qp from torch import Tensor +from ._gramian import normalize, regularize from ._matrix import PSDMatrix @@ -33,20 +34,39 @@ def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor: def projector_or_default(projector: DualConeProjector | None) -> DualConeProjector: if projector is None: - return QPSolverBased("quadprog") + return QPSolverBased(solver="quadprog") return projector class QPSolverBased(DualConeProjector): + """ + + :param norm_eps: A small value to avoid division by zero when normalizing. + :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to + numerical errors when computing the gramian, it might not exactly be positive definite. + This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian + ensures that it is positive definite. + """ + SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"] - def __init__(self, solver: SUPPORTED_SOLVER) -> None: + def __init__( + self, + *, + norm_eps: float = 0.0001, + reg_eps: float = 0.0001, + solver: SUPPORTED_SOLVER = "quadprog", + ) -> None: + self.norm_eps = norm_eps + self.reg_eps = reg_eps self.solver = solver def __repr__(self) -> str: return f"QPSolverBased({repr(self.solver)})" - def __call__(self, U: Tensor, G: Tensor) -> Tensor: + def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor: + + G = regularize(normalize(G, self.norm_eps), self.reg_eps) G_ = _to_array(G) U_ = _to_array(U) diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index b25a5874..f1e9aeec 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -1,6 +1,6 @@ from torch import Tensor -from torchjd._linalg import DualConeProjector, normalize, projector_or_default, regularize +from torchjd._linalg import DualConeProjector, projector_or_default from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator @@ -18,31 +18,21 @@ class DualProjWeighting(_NonDifferentiable, _GramianWeighting): :param pref_vector: The preference vector to use. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. - :param norm_eps: A small value to avoid division by zero when normalizing. - :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to - numerical errors when computing the gramian, it might not exactly be positive definite. - This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian - ensures that it is positive definite. :param solver: The solver used to optimize the underlying optimization problem. 
""" def __init__( self, pref_vector: Tensor | None = None, - norm_eps: float = 0.0001, - reg_eps: float = 0.0001, projector: DualConeProjector | None = None, ) -> None: super().__init__() self.pref_vector = pref_vector - self.norm_eps = norm_eps - self.reg_eps = reg_eps self.projector = projector_or_default(projector) def forward(self, gramian: PSDMatrix, /) -> Tensor: u = self.weighting(gramian) - G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - w = self.projector(u, G) + w = self.projector(u, gramian) return w @property @@ -94,11 +84,6 @@ class DualProj(_NonDifferentiable, GramianWeightedAggregator): :param pref_vector: The preference vector used to combine the rows. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. - :param norm_eps: A small value to avoid division by zero when normalizing. - :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to - numerical errors when computing the gramian, it might not exactly be positive definite. - This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian - ensures that it is positive definite. :param solver: The solver used to optimize the underlying optimization problem. """ @@ -107,12 +92,10 @@ class DualProj(_NonDifferentiable, GramianWeightedAggregator): def __init__( self, pref_vector: Tensor | None = None, - norm_eps: float = 0.0001, - reg_eps: float = 0.0001, projector: DualConeProjector | None = None, ) -> None: super().__init__( - DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, projector=projector), + DualProjWeighting(pref_vector, projector=projector), ) @property diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index b08c3802..47eef4d5 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -1,7 +1,7 @@ import torch from torch import Tensor -from torchjd._linalg import DualConeProjector, normalize, projector_or_default, regularize +from torchjd._linalg import DualConeProjector, projector_or_default from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator @@ -19,31 +19,21 @@ class UPGradWeighting(_NonDifferentiable, _GramianWeighting): :param pref_vector: The preference vector to use. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. - :param norm_eps: A small value to avoid division by zero when normalizing. - :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to - numerical errors when computing the gramian, it might not exactly be positive definite. - This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian - ensures that it is positive definite. :param solver: The solver used to optimize the underlying optimization problem. 
""" def __init__( self, pref_vector: Tensor | None = None, - norm_eps: float = 0.0001, - reg_eps: float = 0.0001, projector: DualConeProjector | None = None, ) -> None: super().__init__() self.pref_vector = pref_vector - self.norm_eps = norm_eps - self.reg_eps = reg_eps self.projector = projector_or_default(projector) def forward(self, gramian: PSDMatrix, /) -> Tensor: U = torch.diag(self.weighting(gramian)) - G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - W = self.projector(U, G) + W = self.projector(U, gramian) return torch.sum(W, dim=0) @property @@ -97,11 +87,6 @@ class UPGrad(_NonDifferentiable, GramianWeightedAggregator): :param pref_vector: The preference vector used to combine the projected rows. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. - :param norm_eps: A small value to avoid division by zero when normalizing. - :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to - numerical errors when computing the gramian, it might not exactly be positive definite. - This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian - ensures that it is positive definite. :param solver: The solver used to optimize the underlying optimization problem. """ @@ -110,12 +95,10 @@ class UPGrad(_NonDifferentiable, GramianWeightedAggregator): def __init__( self, pref_vector: Tensor | None = None, - norm_eps: float = 0.0001, - reg_eps: float = 0.0001, projector: DualConeProjector | None = None, ) -> None: super().__init__( - UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, projector=projector), + UPGradWeighting(pref_vector, projector=projector), ) @property diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index 7852fa59..0e9d0ba5 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -48,24 +48,16 @@ def test_non_differentiable(aggregator: DualProj, matrix: Tensor) -> None: def test_representations() -> None: - A = DualProj( - pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=QPSolverBased("quadprog") - ) - assert ( - repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=" - "QPSolverBased('quadprog'))" - ) + A = DualProj(pref_vector=None, projector=QPSolverBased()) + assert repr(A) == "DualProj(pref_vector=None, projector=QPSolverBased('quadprog'))" assert str(A) == "DualProj" A = DualProj( pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), - norm_eps=0.0001, - reg_eps=0.0001, - projector=QPSolverBased("quadprog"), + projector=QPSolverBased(), ) assert ( - repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, " - "projector=QPSolverBased('quadprog'))" + repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), projector=QPSolverBased('quadprog'))" ) assert str(A) == "DualProj([1., 2., 3.])" diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index ca939116..89565664 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -3,7 +3,7 @@ from torch.testing import assert_close from utils.tensors import ones_, randn_ -from torchjd._linalg import compute_gramian +from torchjd._linalg import QPSolverBased, compute_gramian from torchjd.aggregation import PCGrad from torchjd.aggregation._pcgrad import PCGradWeighting from torchjd.aggregation._upgrad import UPGradWeighting @@ -53,8 +53,7 @@ 
def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]) -> None: pc_grad_weighting = PCGradWeighting() upgrad_sum_weighting = UPGradWeighting( ones_((2,)), - norm_eps=0.0, - reg_eps=0.0, + projector=QPSolverBased(norm_eps=0.0, reg_eps=0.0), ) result = pc_grad_weighting(gramian) diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index c04d32f1..e79cf1f8 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -54,24 +54,16 @@ def test_non_differentiable(aggregator: UPGrad, matrix: Tensor) -> None: def test_representations() -> None: - A = UPGrad( - pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=QPSolverBased("quadprog") - ) - assert ( - repr(A) - == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, projector=QPSolverBased('quadprog'))" - ) + A = UPGrad(pref_vector=None, projector=QPSolverBased()) + assert repr(A) == "UPGrad(pref_vector=None, projector=QPSolverBased('quadprog'))" assert str(A) == "UPGrad" A = UPGrad( pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), - norm_eps=0.0001, - reg_eps=0.0001, - projector=QPSolverBased("quadprog"), + projector=QPSolverBased(), ) assert ( - repr(A) == "UPGrad(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, " - "projector=QPSolverBased('quadprog'))" + repr(A) == "UPGrad(pref_vector=tensor([1., 2., 3.]), projector=QPSolverBased('quadprog'))" ) assert str(A) == "UPGrad([1., 2., 3.])" diff --git a/tests/unit/linalg/test_dual_cone.py b/tests/unit/linalg/test_dual_cone.py index 93d0f439..c79d5266 100644 --- a/tests/unit/linalg/test_dual_cone.py +++ b/tests/unit/linalg/test_dual_cone.py @@ -9,7 +9,7 @@ from torchjd._linalg import DualConeProjector, PSDMatrix, QPSolverBased, compute_gramian -@mark.parametrize("projector", [QPSolverBased("quadprog")]) +@mark.parametrize("projector", [QPSolverBased(reg_eps=0.0, norm_eps=0.0)]) @mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)]) def test_solution_weights(projector: DualConeProjector, shape: tuple[int, int]) -> None: r""" @@ -55,7 +55,7 @@ def test_solution_weights(projector: DualConeProjector, shape: tuple[int, int]) assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0) -@mark.parametrize("projector", [QPSolverBased("quadprog")]) +@mark.parametrize("projector", [QPSolverBased(reg_eps=0.0, norm_eps=0.0)]) @mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)]) @mark.parametrize("scaling", [2 ** (-4), 2 ** (-2), 2**2, 2**4]) def test_scale_invariant( @@ -76,7 +76,7 @@ def test_scale_invariant( assert_close(w_scaled, w) -@mark.parametrize("projector", [QPSolverBased("quadprog")]) +@mark.parametrize("projector", [QPSolverBased(reg_eps=0.0, norm_eps=0.0)]) @mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)]) def test_tensorization_shape(projector: DualConeProjector, shape: tuple[int, ...]) -> None: """ @@ -102,7 +102,7 @@ def test_qp_solver_based_failure() -> None: values. """ - projector = QPSolverBased("quadprog") + projector = QPSolverBased() large_J = np.random.randn(10, 100) * 1e5 large_G = large_J @ large_J.T From 470ccd05577fe4e3eda19dc0de7a9464c818d665 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 12 May 2026 11:02:35 +0200 Subject: [PATCH 08/20] Rename QPBased to QuadprogProjector. Remove parameter solver. Will complexify class architecture later if needed. 
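
Since "quadprog" was the only supported value, the solver parameter carried no
information; the backend is now encoded in the class name instead. Illustrative usage
after the rename (a sketch mirroring the updated tests, not part of this diff):

    import torch
    from torchjd._linalg import QuadprogProjector, compute_gramian

    J = torch.randn(5, 7)
    G = compute_gramian(J)  # G = J J^T, of shape [5, 5]
    u = torch.rand(5)

    projector = QuadprogProjector(norm_eps=0.0001, reg_eps=0.0001)
    w = projector(u, G)  # projection weights, same shape as u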
--- src/torchjd/_linalg/__init__.py | 4 ++-- src/torchjd/_linalg/_dual_cone.py | 13 ++++--------- tests/unit/aggregation/test_dualproj.py | 6 +++--- tests/unit/aggregation/test_pcgrad.py | 4 ++-- tests/unit/aggregation/test_upgrad.py | 6 +++--- tests/unit/linalg/test_dual_cone.py | 10 +++++----- 6 files changed, 19 insertions(+), 24 deletions(-) diff --git a/src/torchjd/_linalg/__init__.py b/src/torchjd/_linalg/__init__.py index 2035b7b2..fce72e8e 100644 --- a/src/torchjd/_linalg/__init__.py +++ b/src/torchjd/_linalg/__init__.py @@ -1,4 +1,4 @@ -from ._dual_cone import DualConeProjector, QPSolverBased, projector_or_default +from ._dual_cone import DualConeProjector, QuadprogProjector, projector_or_default from ._generalized_gramian import flatten, movedim, reshape from ._gramian import compute_gramian, normalize, regularize from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor @@ -17,6 +17,6 @@ "reshape", "movedim", "DualConeProjector", - "QPSolverBased", + "QuadprogProjector", "projector_or_default", ] diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/_linalg/_dual_cone.py index f69e504d..b0927fe4 100644 --- a/src/torchjd/_linalg/_dual_cone.py +++ b/src/torchjd/_linalg/_dual_cone.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Literal, TypeAlias import numpy as np import torch @@ -34,11 +33,11 @@ def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor: def projector_or_default(projector: DualConeProjector | None) -> DualConeProjector: if projector is None: - return QPSolverBased(solver="quadprog") + return QuadprogProjector() return projector -class QPSolverBased(DualConeProjector): +class QuadprogProjector(DualConeProjector): """ :param norm_eps: A small value to avoid division by zero when normalizing. @@ -48,21 +47,17 @@ class QPSolverBased(DualConeProjector): ensures that it is positive definite. """ - SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"] - def __init__( self, *, norm_eps: float = 0.0001, reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", ) -> None: self.norm_eps = norm_eps self.reg_eps = reg_eps - self.solver = solver def __repr__(self) -> str: - return f"QPSolverBased({repr(self.solver)})" + return "QuadprogProjector()" def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor: @@ -78,7 +73,7 @@ def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor: def _project_weight_vector(self, u: np.ndarray, G: np.ndarray) -> np.ndarray: m = G.shape[0] - w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=self.solver) + w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver="quadprog") if w is None: # This may happen when G has large values. 
raise ValueError("Failed to solve the quadratic programming problem.") diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index 0e9d0ba5..da6eb974 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -3,7 +3,7 @@ from torch import Tensor from utils.tensors import ones_ -from torchjd._linalg import QPSolverBased +from torchjd._linalg import QuadprogProjector from torchjd.aggregation import ConstantWeighting, DualProj from torchjd.aggregation._dualproj import DualProjWeighting @@ -48,13 +48,13 @@ def test_non_differentiable(aggregator: DualProj, matrix: Tensor) -> None: def test_representations() -> None: - A = DualProj(pref_vector=None, projector=QPSolverBased()) + A = DualProj(pref_vector=None, projector=QuadprogProjector()) assert repr(A) == "DualProj(pref_vector=None, projector=QPSolverBased('quadprog'))" assert str(A) == "DualProj" A = DualProj( pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), - projector=QPSolverBased(), + projector=QuadprogProjector(), ) assert ( repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), projector=QPSolverBased('quadprog'))" diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index 89565664..6d22359f 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -3,7 +3,7 @@ from torch.testing import assert_close from utils.tensors import ones_, randn_ -from torchjd._linalg import QPSolverBased, compute_gramian +from torchjd._linalg import QuadprogProjector, compute_gramian from torchjd.aggregation import PCGrad from torchjd.aggregation._pcgrad import PCGradWeighting from torchjd.aggregation._upgrad import UPGradWeighting @@ -53,7 +53,7 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]) -> None: pc_grad_weighting = PCGradWeighting() upgrad_sum_weighting = UPGradWeighting( ones_((2,)), - projector=QPSolverBased(norm_eps=0.0, reg_eps=0.0), + projector=QuadprogProjector(norm_eps=0.0, reg_eps=0.0), ) result = pc_grad_weighting(gramian) diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index e79cf1f8..0ba71877 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -3,7 +3,7 @@ from torch import Tensor from utils.tensors import ones_ -from torchjd._linalg import QPSolverBased +from torchjd._linalg import QuadprogProjector from torchjd.aggregation import ConstantWeighting, UPGrad from torchjd.aggregation._upgrad import UPGradWeighting @@ -54,13 +54,13 @@ def test_non_differentiable(aggregator: UPGrad, matrix: Tensor) -> None: def test_representations() -> None: - A = UPGrad(pref_vector=None, projector=QPSolverBased()) + A = UPGrad(pref_vector=None, projector=QuadprogProjector()) assert repr(A) == "UPGrad(pref_vector=None, projector=QPSolverBased('quadprog'))" assert str(A) == "UPGrad" A = UPGrad( pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), - projector=QPSolverBased(), + projector=QuadprogProjector(), ) assert ( repr(A) == "UPGrad(pref_vector=tensor([1., 2., 3.]), projector=QPSolverBased('quadprog'))" diff --git a/tests/unit/linalg/test_dual_cone.py b/tests/unit/linalg/test_dual_cone.py index c79d5266..e42956c1 100644 --- a/tests/unit/linalg/test_dual_cone.py +++ b/tests/unit/linalg/test_dual_cone.py @@ -6,10 +6,10 @@ from torch.testing import assert_close from utils.tensors import rand_, randn_ -from torchjd._linalg import DualConeProjector, PSDMatrix, QPSolverBased, 
compute_gramian +from torchjd._linalg import DualConeProjector, PSDMatrix, QuadprogProjector, compute_gramian -@mark.parametrize("projector", [QPSolverBased(reg_eps=0.0, norm_eps=0.0)]) +@mark.parametrize("projector", [QuadprogProjector(reg_eps=0.0, norm_eps=0.0)]) @mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)]) def test_solution_weights(projector: DualConeProjector, shape: tuple[int, int]) -> None: r""" @@ -55,7 +55,7 @@ def test_solution_weights(projector: DualConeProjector, shape: tuple[int, int]) assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0) -@mark.parametrize("projector", [QPSolverBased(reg_eps=0.0, norm_eps=0.0)]) +@mark.parametrize("projector", [QuadprogProjector(reg_eps=0.0, norm_eps=0.0)]) @mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)]) @mark.parametrize("scaling", [2 ** (-4), 2 ** (-2), 2**2, 2**4]) def test_scale_invariant( @@ -76,7 +76,7 @@ def test_scale_invariant( assert_close(w_scaled, w) -@mark.parametrize("projector", [QPSolverBased(reg_eps=0.0, norm_eps=0.0)]) +@mark.parametrize("projector", [QuadprogProjector(reg_eps=0.0, norm_eps=0.0)]) @mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)]) def test_tensorization_shape(projector: DualConeProjector, shape: tuple[int, ...]) -> None: """ @@ -102,7 +102,7 @@ def test_qp_solver_based_failure() -> None: values. """ - projector = QPSolverBased() + projector = QuadprogProjector() large_J = np.random.randn(10, 100) * 1e5 large_G = large_J @ large_J.T From a8ccd16f77b7e22a2a72d91ffd7953ee682144d9 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 12 May 2026 11:10:17 +0200 Subject: [PATCH 09/20] Fix representations --- src/torchjd/_linalg/_dual_cone.py | 2 +- src/torchjd/aggregation/_dualproj.py | 4 ++-- src/torchjd/aggregation/_upgrad.py | 4 ++-- tests/unit/aggregation/test_dualproj.py | 12 ++++++++---- tests/unit/aggregation/test_upgrad.py | 12 ++++++++---- 5 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/_linalg/_dual_cone.py index b0927fe4..2af85789 100644 --- a/src/torchjd/_linalg/_dual_cone.py +++ b/src/torchjd/_linalg/_dual_cone.py @@ -57,7 +57,7 @@ def __init__( self.reg_eps = reg_eps def __repr__(self) -> str: - return "QuadprogProjector()" + return f"QuadprogProjector(norm_eps={self.norm_eps}, reg_eps={self.reg_eps})" def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor: diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index f1e9aeec..c69f99eb 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -132,8 +132,8 @@ def projector(self, value: DualConeProjector | None) -> None: def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" - f"{self.norm_eps}, reg_eps={self.reg_eps}, projector={repr(self.projector)})" + f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, projector=" + f"{repr(self.projector)})" ) def __str__(self) -> str: diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 47eef4d5..6daa25ed 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -135,8 +135,8 @@ def projector(self, value: DualConeProjector | None) -> None: def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" - f"{self.norm_eps}, reg_eps={self.reg_eps}, projector={repr(self.projector)})" + 
f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, projector=" + f"{repr(self.projector)})" ) def __str__(self) -> str: diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index da6eb974..f225aae6 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -48,16 +48,20 @@ def test_non_differentiable(aggregator: DualProj, matrix: Tensor) -> None: def test_representations() -> None: - A = DualProj(pref_vector=None, projector=QuadprogProjector()) - assert repr(A) == "DualProj(pref_vector=None, projector=QPSolverBased('quadprog'))" + A = DualProj(pref_vector=None, projector=QuadprogProjector(norm_eps=0.001, reg_eps=0.01)) + assert ( + repr(A) == "DualProj(pref_vector=None, projector=QuadprogProjector(norm_eps=0.001, " + "reg_eps=0.01))" + ) assert str(A) == "DualProj" A = DualProj( pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), - projector=QuadprogProjector(), + projector=QuadprogProjector(norm_eps=0.001, reg_eps=0.01), ) assert ( - repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), projector=QPSolverBased('quadprog'))" + repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), projector=QuadprogProjector(" + "norm_eps=0.001, reg_eps=0.01))" ) assert str(A) == "DualProj([1., 2., 3.])" diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 0ba71877..88f54d82 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -54,16 +54,20 @@ def test_non_differentiable(aggregator: UPGrad, matrix: Tensor) -> None: def test_representations() -> None: - A = UPGrad(pref_vector=None, projector=QuadprogProjector()) - assert repr(A) == "UPGrad(pref_vector=None, projector=QPSolverBased('quadprog'))" + A = UPGrad(pref_vector=None, projector=QuadprogProjector(norm_eps=0.001, reg_eps=0.01)) + assert ( + repr(A) == "UPGrad(pref_vector=None, projector=QuadprogProjector(norm_eps=0.001, " + "reg_eps=0.01))" + ) assert str(A) == "UPGrad" A = UPGrad( pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), - projector=QuadprogProjector(), + projector=QuadprogProjector(norm_eps=0.001, reg_eps=0.01), ) assert ( - repr(A) == "UPGrad(pref_vector=tensor([1., 2., 3.]), projector=QPSolverBased('quadprog'))" + repr(A) == "UPGrad(pref_vector=tensor([1., 2., 3.]), projector=QuadprogProjector(" + "norm_eps=0.001, reg_eps=0.01))" ) assert str(A) == "UPGrad([1., 2., 3.])" From 5780f6bf961029bd2e1378570439e85977ab7b3c Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 12 May 2026 11:33:17 +0200 Subject: [PATCH 10/20] Improve docstring of `QuadprogProjector` --- src/torchjd/_linalg/_dual_cone.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/_linalg/_dual_cone.py index 2af85789..b785d1ee 100644 --- a/src/torchjd/_linalg/_dual_cone.py +++ b/src/torchjd/_linalg/_dual_cone.py @@ -38,7 +38,9 @@ def projector_or_default(projector: DualConeProjector | None) -> DualConeProject class QuadprogProjector(DualConeProjector): - """ + r""" + Solves the quadratic program defined in :meth:`DualConeProjector.__call__` using the + ``quadprog`` QP solver. :param norm_eps: A small value to avoid division by zero when normalizing. :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. 
Due to
        numerical errors when computing the gramian, it might not exactly be positive definite.
        This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
        ensures that it is positive definite.
    """

From a8ccd16f77b7e22a2a72d91ffd7953ee682144d9 Mon Sep 17 00:00:00 2001
From: Pierre Quinton
Date: Tue, 12 May 2026 14:34:52 +0200
Subject: [PATCH 11/20] Expose projectors

---
 docs/source/docs/linalg/dual_cone.rst |  9 +++++++++
 docs/source/docs/linalg/index.rst     |  1 +
 src/torchjd/linalg/__init__.py        | 14 ++++++++++++--
 3 files changed, 22 insertions(+), 2 deletions(-)
 create mode 100644 docs/source/docs/linalg/dual_cone.rst

diff --git a/docs/source/docs/linalg/dual_cone.rst b/docs/source/docs/linalg/dual_cone.rst
new file mode 100644
index 00000000..f7db87ad
--- /dev/null
+++ b/docs/source/docs/linalg/dual_cone.rst
@@ -0,0 +1,9 @@
+:hide-toc:
+
+Dual Cone Projectors
+====================
+
+.. autoclass:: torchjd.linalg.DualConeProjector
+    :members: __call__
+
+.. autoclass:: torchjd.linalg.QuadprogProjector

diff --git a/docs/source/docs/linalg/index.rst b/docs/source/docs/linalg/index.rst
index 4446ccea..94fcce20 100644
--- a/docs/source/docs/linalg/index.rst
+++ b/docs/source/docs/linalg/index.rst
@@ -10,3 +10,4 @@ linalg

     matrix.rst
     psd_matrix.rst
+    dual_cone.rst

diff --git a/src/torchjd/linalg/__init__.py b/src/torchjd/linalg/__init__.py
index f8238104..4e83639c 100644
--- a/src/torchjd/linalg/__init__.py
+++ b/src/torchjd/linalg/__init__.py
@@ -3,6 +3,16 @@
 properties.
 """

-from torchjd._linalg._matrix import Matrix, PSDMatrix
+from torchjd._linalg import (
+    DualConeProjector,
+    Matrix,
+    PSDMatrix,
+    QuadprogProjector,
+)

-__all__ = ["Matrix", "PSDMatrix"]
+__all__ = [
+    "DualConeProjector",
+    "Matrix",
+    "PSDMatrix",
+    "QuadprogProjector",
+]

From 83b528efe6c2eefcda851c5abd7659547394a957 Mon Sep 17 00:00:00 2001
From: Pierre Quinton
Date: Tue, 12 May 2026 14:41:21 +0200
Subject: [PATCH 12/20] Add default specification in docs.

---
 src/torchjd/_linalg/_dual_cone.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/_linalg/_dual_cone.py
index b785d1ee..0ffe5338 100644
--- a/src/torchjd/_linalg/_dual_cone.py
+++ b/src/torchjd/_linalg/_dual_cone.py
@@ -10,6 +10,12 @@


 class DualConeProjector(ABC):
+    """
+    Abstract class whose instances are responsible for projecting vectors onto the dual cone of
+    the rows of a matrix (or rather, for solving the dual form of this problem). The current
+    default :class:`~torchjd.linalg.DualConeProjector` is :class:`~torchjd.linalg.QuadprogProjector`.
+    """
+
     @abstractmethod
     def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor:
         r"""

From 563faff82736bc41907efc7f6270881b08fe55b5 Mon Sep 17 00:00:00 2001
From: Pierre Quinton
Date: Tue, 12 May 2026 14:49:12 +0200
Subject: [PATCH 13/20] Update docstring parameter solver to projector, update docstring of package linalg.

---
 src/torchjd/aggregation/_dualproj.py | 4 ++--
 src/torchjd/aggregation/_upgrad.py   | 4 ++--
 src/torchjd/linalg/__init__.py       | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py
index c69f99eb..9c540de9 100644
--- a/src/torchjd/aggregation/_dualproj.py
+++ b/src/torchjd/aggregation/_dualproj.py
@@ -18,7 +18,7 @@ class DualProjWeighting(_NonDifferentiable, _GramianWeighting):
     :param pref_vector: The preference vector to use. If not provided, defaults to
         :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
-    :param solver: The solver used to optimize the underlying optimization problem.
+    :param projector: The :class:`~torchjd.linalg.DualConeProjector` used to compute the projection.
""" def __init__( @@ -84,7 +84,7 @@ class DualProj(_NonDifferentiable, GramianWeightedAggregator): :param pref_vector: The preference vector used to combine the rows. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. - :param solver: The solver used to optimize the underlying optimization problem. + :param projector: The :class:`~torchjd.linalg.DualConeProjector` used tocompute the projection. """ gramian_weighting: DualProjWeighting diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 6daa25ed..78bac958 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -19,7 +19,7 @@ class UPGradWeighting(_NonDifferentiable, _GramianWeighting): :param pref_vector: The preference vector to use. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. - :param solver: The solver used to optimize the underlying optimization problem. + :param projector: The :class:`~torchjd.linalg.DualConeProjector` used tocompute the projection. """ def __init__( @@ -87,7 +87,7 @@ class UPGrad(_NonDifferentiable, GramianWeightedAggregator): :param pref_vector: The preference vector used to combine the projected rows. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. - :param solver: The solver used to optimize the underlying optimization problem. + :param projector: The :class:`~torchjd.linalg.DualConeProjector` used tocompute the projection. """ gramian_weighting: UPGradWeighting diff --git a/src/torchjd/linalg/__init__.py b/src/torchjd/linalg/__init__.py index 4e83639c..15476b73 100644 --- a/src/torchjd/linalg/__init__.py +++ b/src/torchjd/linalg/__init__.py @@ -1,6 +1,6 @@ """ -This module provides type annotation classes representing tensors with specific structural -properties. +This module provides utilitary linear algebra methods as well as types to represent specific +structural properties. """ from torchjd._linalg import ( From e849d89c89aae4886ca161d5a77dcafe0702087f Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 12 May 2026 14:56:34 +0200 Subject: [PATCH 14/20] remove norm_eps and reg_eps setters and getters from aggregators/weighting. They could be added to Projectors if needed. 
--- src/torchjd/aggregation/_dualproj.py | 38 -------------------------- src/torchjd/aggregation/_upgrad.py | 40 ---------------------------- 2 files changed, 78 deletions(-) diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 9c540de9..603d3458 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -44,28 +44,6 @@ def pref_vector(self, value: Tensor | None) -> None: self.weighting = pref_vector_to_weighting(value, default=MeanWeighting()) self._pref_vector = value - @property - def norm_eps(self) -> float: - return self._norm_eps - - @norm_eps.setter - def norm_eps(self, value: float) -> None: - if value < 0: - raise ValueError(f"norm_eps must be non-negative, but got {value}.") - - self._norm_eps = value - - @property - def reg_eps(self) -> float: - return self._reg_eps - - @reg_eps.setter - def reg_eps(self, value: float) -> None: - if value < 0: - raise ValueError(f"reg_eps must be non-negative, but got {value}.") - - self._reg_eps = value - @property def projector(self) -> DualConeProjector: return self._projector @@ -106,22 +84,6 @@ def pref_vector(self) -> Tensor | None: def pref_vector(self, value: Tensor | None) -> None: self.gramian_weighting.pref_vector = value - @property - def norm_eps(self) -> float: - return self.gramian_weighting.norm_eps - - @norm_eps.setter - def norm_eps(self, value: float) -> None: - self.gramian_weighting.norm_eps = value - - @property - def reg_eps(self) -> float: - return self.gramian_weighting.reg_eps - - @reg_eps.setter - def reg_eps(self, value: float) -> None: - self.gramian_weighting.reg_eps = value - @property def projector(self) -> DualConeProjector: return self.gramian_weighting.projector diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 78bac958..001f0db5 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -45,30 +45,6 @@ def pref_vector(self, value: Tensor | None) -> None: self.weighting = pref_vector_to_weighting(value, default=MeanWeighting()) self._pref_vector = value - @property - def norm_eps(self) -> float: - return self._norm_eps - - @norm_eps.setter - def norm_eps(self, value: float) -> None: - - if value < 0: - raise ValueError(f"norm_eps must be non-negative, but got {value}.") - - self._norm_eps = value - - @property - def reg_eps(self) -> float: - return self._reg_eps - - @reg_eps.setter - def reg_eps(self, value: float) -> None: - - if value < 0: - raise ValueError(f"reg_eps must be non-negative, but got {value}.") - - self._reg_eps = value - @property def projector(self) -> DualConeProjector: return self._projector @@ -109,22 +85,6 @@ def pref_vector(self) -> Tensor | None: def pref_vector(self, value: Tensor | None) -> None: self.gramian_weighting.pref_vector = value - @property - def norm_eps(self) -> float: - return self.gramian_weighting.norm_eps - - @norm_eps.setter - def norm_eps(self, value: float) -> None: - self.gramian_weighting.norm_eps = value - - @property - def reg_eps(self) -> float: - return self.gramian_weighting.reg_eps - - @reg_eps.setter - def reg_eps(self, value: float) -> None: - self.gramian_weighting.reg_eps = value - @property def projector(self) -> DualConeProjector: return self.gramian_weighting.projector From 01ffc2b22854f85e692af7fd34f4285dd6c79891 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 12 May 2026 14:59:02 +0200 Subject: [PATCH 15/20] Add changelog entry. 
--- CHANGELOG.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ebeeb67d..675cce12 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,22 @@ changelog does not include internal changes that do not affect the user. ### Changed +- **BREAKING**: Moved normalization, regularization, and QP solver configuration from `UPGrad`, + `UPGradWeighting`, `DualProj`, and `DualProjWeighting` to a new `projector` parameter accepting a + `DualConeProjector`. The `norm_eps`, `reg_eps`, and `solver` constructor parameters of these + classes have been removed. The default projector is `QuadprogProjector`, which accepts `norm_eps` + and `reg_eps` as keyword-only arguments. To update: + ```python + # Before + from torchjd.aggregation import UPGrad + aggregator = UPGrad(pref_vector=torch.tensor([0.7, 0.3]), norm_eps=0.001, reg_eps=0.001, solver="quadprog") + + # After + from torchjd.aggregation import UPGrad + from torchjd.linalg import QuadprogProjector + aggregator = UPGrad(pref_vector=torch.tensor([0.7, 0.3]), projector=QuadprogProjector(norm_eps=0.001, reg_eps=0.001)) + ``` + The `solver` parameter has been removed; the default projector uses `quadprog` internally. - `CAGrad`, `CAGradWeighting`, and `NashMTL` are now always importable from `torchjd.aggregation`, even when their optional dependencies are not installed. Attempting to instantiate them without the required dependencies now raises an `ImportError` with installation instructions, instead of From 42f227baf7cadea2caad70a7a87650b966664766 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 12 May 2026 17:24:18 +0200 Subject: [PATCH 16/20] remove eps setters/getters tests. --- tests/unit/aggregation/test_dualproj.py | 41 +------------------------ tests/unit/aggregation/test_upgrad.py | 39 +---------------------- 2 files changed, 2 insertions(+), 78 deletions(-) diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index f225aae6..c0259e18 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -1,11 +1,10 @@ import torch -from pytest import mark, raises +from pytest import mark from torch import Tensor from utils.tensors import ones_ from torchjd._linalg import QuadprogProjector from torchjd.aggregation import ConstantWeighting, DualProj -from torchjd.aggregation._dualproj import DualProjWeighting from ._asserts import ( assert_expected_structure, @@ -73,41 +72,3 @@ def test_pref_vector_setter_updates_value() -> None: assert A.pref_vector is new_pref assert isinstance(A.gramian_weighting.weighting, ConstantWeighting) assert A.gramian_weighting.weighting.weights is new_pref - - -def test_norm_eps_setter_updates_value() -> None: - A = DualProj() - A.norm_eps = 0.25 - assert A.norm_eps == 0.25 - assert A.gramian_weighting.norm_eps == 0.25 - - -def test_reg_eps_setter_updates_value() -> None: - A = DualProj() - A.reg_eps = 0.25 - assert A.reg_eps == 0.25 - assert A.gramian_weighting.reg_eps == 0.25 - - -def test_norm_eps_setter_rejects_negative() -> None: - A = DualProj() - with raises(ValueError, match="norm_eps"): - A.norm_eps = -1e-9 - - -def test_reg_eps_setter_rejects_negative() -> None: - A = DualProj() - with raises(ValueError, match="reg_eps"): - A.reg_eps = -1e-9 - - -def test_weighting_norm_eps_setter_rejects_negative() -> None: - W = DualProjWeighting() - with raises(ValueError, match="norm_eps"): - W.norm_eps = -1e-9 - - -def test_weighting_reg_eps_setter_rejects_negative() -> None: - W = 
DualProjWeighting() - with raises(ValueError, match="reg_eps"): - W.reg_eps = -1e-9 diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 88f54d82..579d99a7 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -1,11 +1,10 @@ import torch -from pytest import mark, raises +from pytest import mark from torch import Tensor from utils.tensors import ones_ from torchjd._linalg import QuadprogProjector from torchjd.aggregation import ConstantWeighting, UPGrad -from torchjd.aggregation._upgrad import UPGradWeighting from ._asserts import ( assert_expected_structure, @@ -79,39 +78,3 @@ def test_pref_vector_setter_updates_value() -> None: assert A.pref_vector is new_pref assert isinstance(A.gramian_weighting.weighting, ConstantWeighting) assert A.gramian_weighting.weighting.weights is new_pref - - -def test_norm_eps_setter_updates_value() -> None: - A = UPGrad() - A.norm_eps = 0.25 - assert A.norm_eps == 0.25 - - -def test_reg_eps_setter_updates_value() -> None: - A = UPGrad() - A.reg_eps = 0.25 - assert A.reg_eps == 0.25 - - -def test_norm_eps_setter_rejects_negative() -> None: - A = UPGrad() - with raises(ValueError, match="norm_eps"): - A.norm_eps = -1e-9 - - -def test_reg_eps_setter_rejects_negative() -> None: - A = UPGrad() - with raises(ValueError, match="reg_eps"): - A.reg_eps = -1e-9 - - -def test_weighting_norm_eps_setter_rejects_negative() -> None: - W = UPGradWeighting() - with raises(ValueError, match="norm_eps"): - W.norm_eps = -1e-9 - - -def test_weighting_reg_eps_setter_rejects_negative() -> None: - W = UPGradWeighting() - with raises(ValueError, match="reg_eps"): - W.reg_eps = -1e-9 From 89bc6d1ba0c4da0ecbe8156769dac885b0849be5 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 12 May 2026 17:26:03 +0200 Subject: [PATCH 17/20] fix interactive plotter --- tests/plots/interactive_plotter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/plots/interactive_plotter.py b/tests/plots/interactive_plotter.py index 1c2b6240..c8b3871a 100644 --- a/tests/plots/interactive_plotter.py +++ b/tests/plots/interactive_plotter.py @@ -11,6 +11,7 @@ from typing_extensions import Unpack from plots._utils import Plotter, angle_to_coord, coord_to_angle +from torchjd._linalg import QuadprogProjector from torchjd.aggregation import ( IMTLG, MGDA, @@ -61,7 +62,7 @@ def main() -> None: "AlignedMTL-RMSE": lambda: AlignedMTL(scale_mode="rmse"), str(CAGrad(c=0.5)): lambda: CAGrad(c=0.5), str(ConFIG()): lambda: ConFIG(), - str(DualProj()): lambda: DualProj(reg_eps=1e-7), + str(DualProj()): lambda: DualProj(projector=QuadprogProjector(reg_eps=1e-7)), str(GradDrop()): lambda: GradDrop(), str(GradVac()): lambda: GradVac(), str(IMTLG()): lambda: IMTLG(), @@ -72,7 +73,7 @@ def main() -> None: str(Random()): lambda: Random(), str(Sum()): lambda: Sum(), str(TrimmedMean(trim_number=1)): lambda: TrimmedMean(trim_number=1), - str(UPGrad()): lambda: UPGrad(reg_eps=1e-7), + str(UPGrad()): lambda: UPGrad(projector=QuadprogProjector(reg_eps=1e-7)), } aggregator_strings = list(aggregator_factories.keys()) From 33d171ac96963a88d4b8e217f3fe09c084d6bf14 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Tue, 12 May 2026 21:42:30 +0200 Subject: [PATCH 18/20] add getters/setters tests for projectors. 
---
 tests/unit/aggregation/test_dualproj.py | 13 +++++++++++++
 tests/unit/aggregation/test_upgrad.py   | 13 +++++++++++++
 2 files changed, 26 insertions(+)

diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py
index c0259e18..190eaaeb 100644
--- a/tests/unit/aggregation/test_dualproj.py
+++ b/tests/unit/aggregation/test_dualproj.py
@@ -72,3 +72,16 @@ def test_pref_vector_setter_updates_value() -> None:
     assert A.pref_vector is new_pref
     assert isinstance(A.gramian_weighting.weighting, ConstantWeighting)
     assert A.gramian_weighting.weighting.weights is new_pref
+
+
+def test_projector_getter_returns_default() -> None:
+    A = DualProj()
+    assert isinstance(A.projector, QuadprogProjector)
+
+
+def test_projector_setter_updates_value() -> None:
+    A = DualProj()
+    new_projector = QuadprogProjector(norm_eps=0.001, reg_eps=0.01)
+    A.projector = new_projector
+    assert A.projector is new_projector
+    assert A.gramian_weighting.projector is new_projector
diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py
index 579d99a7..b639763c 100644
--- a/tests/unit/aggregation/test_upgrad.py
+++ b/tests/unit/aggregation/test_upgrad.py
@@ -78,3 +78,16 @@ def test_pref_vector_setter_updates_value() -> None:
     assert A.pref_vector is new_pref
     assert isinstance(A.gramian_weighting.weighting, ConstantWeighting)
     assert A.gramian_weighting.weighting.weights is new_pref
+
+
+def test_projector_getter_returns_default() -> None:
+    A = UPGrad()
+    assert isinstance(A.projector, QuadprogProjector)
+
+
+def test_projector_setter_updates_value() -> None:
+    A = UPGrad()
+    new_projector = QuadprogProjector(norm_eps=0.001, reg_eps=0.01)
+    A.projector = new_projector
+    assert A.projector is new_projector
+    assert A.gramian_weighting.projector is new_projector

From e8291aea88190d02d65d139cbc44b04ea077d7dd Mon Sep 17 00:00:00 2001
From: Valérian Rey
Date: Wed, 13 May 2026 19:42:57 +0200
Subject: [PATCH 19/20] Improve changelog

---
 CHANGELOG.md | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 675cce12..e2bd218f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,22 +10,21 @@ changelog does not include internal changes that do not affect the user.

 ### Changed

-- **BREAKING**: Moved normalization, regularization, and QP solver configuration from `UPGrad`,
-  `UPGradWeighting`, `DualProj`, and `DualProjWeighting` to a new `projector` parameter accepting a
-  `DualConeProjector`. The `norm_eps`, `reg_eps`, and `solver` constructor parameters of these
-  classes have been removed. The default projector is `QuadprogProjector`, which accepts `norm_eps`
-  and `reg_eps` as keyword-only arguments. To update:
+- **BREAKING**: Removed `norm_eps`, `reg_eps` and `solver` parameters from the `__init__` of
+  `UPGrad`, `UPGradWeighting`, `DualProj` and `DualProjWeighting` in favor of a `projector`
+  parameter of type `DualConeProjector`.
To update:
   ```python
   # Before
   from torchjd.aggregation import UPGrad
-  aggregator = UPGrad(pref_vector=torch.tensor([0.7, 0.3]), norm_eps=0.001, reg_eps=0.001, solver="quadprog")
+  aggregator = UPGrad(norm_eps=1e-6, reg_eps=1e-6, solver="quadprog")

   # After
   from torchjd.aggregation import UPGrad
   from torchjd.linalg import QuadprogProjector
-  aggregator = UPGrad(pref_vector=torch.tensor([0.7, 0.3]), projector=QuadprogProjector(norm_eps=0.001, reg_eps=0.001))
+  aggregator = UPGrad(projector=QuadprogProjector(norm_eps=1e-6, reg_eps=1e-6))
   ```
-  The `solver` parameter has been removed; the default projector uses `quadprog` internally.
+  If you used the default `norm_eps`, `reg_eps` and `solver`, you don't have to change anything and
+  you will get the same results.
 - `CAGrad`, `CAGradWeighting`, and `NashMTL` are now always importable from `torchjd.aggregation`,
   even when their optional dependencies are not installed. Attempting to instantiate them without
   the required dependencies now raises an `ImportError` with installation instructions, instead of
@@ -37,6 +36,9 @@

 ### Added

+- Added a new abstraction: the `DualConeProjector` abstract base class and its concrete
+  `QuadprogProjector` implementation, to perform the projection of the gradients onto the dual
+  cone, as required in `UPGrad` and `DualProj`. These classes can be found in `torchjd.linalg`.
 - Made `WeightedAggregator` and `GramianWeightedAggregator` public. These abstract base classes are
   now importable from `torchjd.aggregation` and documented. They can be extended to easily
   implement custom `Aggregator`s.

From cf0439cf40ddcfb01950d9669888af747dab50ca Mon Sep 17 00:00:00 2001
From: Valérian Rey
Date: Wed, 13 May 2026 19:48:45 +0200
Subject: [PATCH 20/20] Fix formatting of maths in docstrings of projectors

---
 src/torchjd/_linalg/_dual_cone.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/src/torchjd/_linalg/_dual_cone.py b/src/torchjd/_linalg/_dual_cone.py
index 0ffe5338..750e97e9 100644
--- a/src/torchjd/_linalg/_dual_cone.py
+++ b/src/torchjd/_linalg/_dual_cone.py
@@ -19,21 +19,27 @@ class DualConeProjector(ABC):
     @abstractmethod
     def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor:
         r"""
-        Computes the weights `w` of the projection of `J^T u` onto the dual cone of
-        the rows of `J`, provided `G = J J^T` and `u`. In other words, this computes the `w` that
-        satisfies `\pi_J(J^T u) = J^T w`, with `\pi_J` defined in Equation 3 of [1].
+        Computes the weights :math:`w` of the projection of :math:`J^\top u` onto the dual cone of
+        the rows of :math:`J`, provided :math:`G = J J^\top` and :math:`u`. In other words, this
+        computes the :math:`w` that satisfies :math:`\pi_J(J^\top u) = J^\top w`, with
+        :math:`\pi_J` defined in Equation 3 of [1].

-        By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic
-        program:
-            minimize v^T G v
-            subject to u \preceq v
+        By Proposition 1 of [1], this is equivalent to solving for :math:`v` the following
+        quadratic program:
+
+        .. math::
+
+            \min_{v} \quad & v^\top G v \\
+            \text{subject to} \quad & u \preceq v

         Reference:
         [1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/abs/2406.16232>`_.

-        :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`.
-        :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite.
-        :return: A tensor of projection weights with the same shape as `U`.
+ :param U: The tensor of weights corresponding to the vectors to project, of shape + ``[..., m]``. + :param G: The Gramian matrix of shape ``[m, m]``. It must be symmetric and positive + definite. + :return: A tensor of projection weights with the same shape as ``U``. """
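For reference, the quadratic program reformatted in this last patch can be checked numerically. The following standalone sketch mirrors the `solve_qp` call that `QuadprogProjector` makes internally; the 2-row Jacobian `J` and the uniform weights `u` are illustrative values, not taken from the library or its tests.

```python
import numpy as np
from qpsolvers import solve_qp

# Illustrative Jacobian whose two rows conflict (negative inner product),
# so that projecting J^T u onto the dual cone of the rows of J is non-trivial.
J = np.array([[1.0, 0.0], [-0.5, 1.0]])
G = J @ J.T          # Gramian G = J J^T, symmetric positive definite here
u = np.full(2, 0.5)  # uniform weights [1/m, ..., 1/m], as with a mean weighting

# Solve: minimize v^T G v subject to u <= v. qpsolvers' convention is
# minimize 1/2 x^T P x + q^T x subject to A x <= b, hence the sign flips.
m = G.shape[0]
w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver="quadprog")

# By the KKT conditions, G w >= 0 at the optimum, which is exactly the dual
# cone membership of J^T w: J (J^T w) = G w must be non-negative.
assert w is not None and (G @ w >= -1e-8).all()
print(w)
```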
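Finally, a minimal usage sketch of the API as it stands at the end of the series, consistent with the changelog example and the interactive-plotter fix; the Jacobian values are made up, and the call `aggregator(jacobian)` assumes the usual `Aggregator` calling convention on a matrix of gradients.

```python
import torch

from torchjd.aggregation import UPGrad
from torchjd.linalg import QuadprogProjector

# Tighter regularization than the default, as in the interactive plotter.
aggregator = UPGrad(projector=QuadprogProjector(reg_eps=1e-7))

# Toy Jacobian: two objectives (rows), three parameters (columns).
jacobian = torch.tensor([[1.0, 2.0, -1.0], [0.5, -1.0, 1.0]])
update = aggregator(jacobian)  # a single update direction of shape [3]
print(update)
```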