Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
336b47c
Add DualConeProjector
PierreQuinton May 11, 2026
f1076a7
Implement and use QPSolverBased. Set as default.
PierreQuinton May 11, 2026
3343ebb
add getters and setters for projectors in UPGrad and DualProj
PierreQuinton May 11, 2026
25d0c91
fix typing
PierreQuinton May 11, 2026
c63571d
Make PCGrad test use default Projector
PierreQuinton May 11, 2026
f7f4fd6
Merge branch 'main' into add-weight-projector
ValerianRey May 12, 2026
b2e6d2a
rename to __call__
PierreQuinton May 12, 2026
2f05e89
Move normalization and regularization to QPBasedProjector.
PierreQuinton May 12, 2026
470ccd0
Rename QPBased to QuadprogProjector. Remove parameter solver. Will co…
PierreQuinton May 12, 2026
a8ccd16
Fix representations
PierreQuinton May 12, 2026
5780f6b
Improve docstring of `QuadprogProjector`
PierreQuinton May 12, 2026
70f814f
Expose projectors
PierreQuinton May 12, 2026
83b528e
Add default specification in docs.
PierreQuinton May 12, 2026
563faff
Update docstring parameter solver to projector, update docstring of p…
PierreQuinton May 12, 2026
e849d89
remove norm_eps and reg_eps setters and getters from aggregators/weig…
PierreQuinton May 12, 2026
01ffc2b
Add changelog entry.
PierreQuinton May 12, 2026
42f227b
remove eps setters/getters tests.
PierreQuinton May 12, 2026
89bc6d1
fix interactive plotter
PierreQuinton May 12, 2026
33d171a
add getters/setters tests for projectors.
PierreQuinton May 12, 2026
e8291ae
Improve changelog
ValerianRey May 13, 2026
cf0439c
Fix formatting of maths in docstrings of projectors
ValerianRey May 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@ changelog does not include internal changes that do not affect the user.

### Changed

- **BREAKING**: Removed `norm_eps`, `reg_eps` and `solver` parameters from the `__init__` of
`UPGrad`, `UPGradWeighting`, `DualProj` and `DualProjWeighting` in favor of a `projector`
parameter of type `DualConeProjector`. To update:
```python
# Before
from torchjd.aggregation import UPGrad
aggregator = UPGrad(norm_eps=1e-6, reg_eps=1e-6, solver="quadprog")

# After
from torchjd.aggregation import UPGrad
from torchjd.linalg import QuadprogProjector
aggregator = UPGrad(projector=QuadprogProjector(norm_eps=1e-6, reg_eps=1e-6))
```
If you used the default `norm_eps`, `reg_eps` and `solver`, you don't have to change anything and
you will get the same results.
- `CAGrad`, `CAGradWeighting`, and `NashMTL` are now always importable from `torchjd.aggregation`,
even when their optional dependencies are not installed. Attempting to instantiate them without the
required dependencies now raises an `ImportError` with installation instructions, instead of
Expand All @@ -21,6 +36,9 @@ changelog does not include internal changes that do not affect the user.

### Added

- Added a new abstraction: the `DualConeProjector` abstract base class and its concrete
  `QuadprogProjector` implementation, to do the projection of the gradients onto the dual cone, as
  required in `UPGrad` and `DualProj`. These classes can be found in `torchjd.linalg`.
- Made `WeightedAggregator` and `GramianWeightedAggregator` public. These abstract base classes are
now importable from `torchjd.aggregation` and documented. They can be extended to easily implement
custom `Aggregator`s.
Expand Down
9 changes: 9 additions & 0 deletions docs/source/docs/linalg/dual_cone.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
:hide-toc:

Dual Cone Projectors
====================

.. autoclass:: torchjd.linalg.DualConeProjector
:members: __call__

.. autoclass:: torchjd.linalg.QuadprogProjector
1 change: 1 addition & 0 deletions docs/source/docs/linalg/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ linalg

matrix.rst
psd_matrix.rst
dual_cone.rst
4 changes: 4 additions & 0 deletions src/torchjd/_linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._dual_cone import DualConeProjector, QuadprogProjector, projector_or_default
from ._generalized_gramian import flatten, movedim, reshape
from ._gramian import compute_gramian, normalize, regularize
from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor
Expand All @@ -15,4 +16,7 @@
"flatten",
"reshape",
"movedim",
"DualConeProjector",
"QuadprogProjector",
"projector_or_default",
]
101 changes: 101 additions & 0 deletions src/torchjd/_linalg/_dual_cone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from abc import ABC, abstractmethod

import numpy as np
import torch
from qpsolvers import solve_qp
from torch import Tensor

from ._gramian import normalize, regularize
from ._matrix import PSDMatrix


class DualConeProjector(ABC):
"""
Abstract class whose instances are responsible for projecting vectors onto the dual cone of the
rows of a matrix, or rather the dual form of this problem. The current default
:class:`~torchjd.linalg.DualConeProjector` is :class:`~torchjd.linalg.QuadprogProjector`.
Comment on lines +15 to +16
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather not talk about subclasses in the docstring of the main class unless that's really necessary to make things understandable.

Suggested change
rows of a matrix, or rather the dual form of this problem. The current default
:class:`~torchjd.linalg.DualConeProjector` is :class:`~torchjd.linalg.QuadprogProjector`.
rows of a matrix, or rather the dual form of this problem.

"""

@abstractmethod
def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor:
r"""
Computes the weights :math:`w` of the projection of :math:`J^\top u` onto the dual cone of
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Computes the weights :math:`w` of the projection of :math:`J^\top u` onto the dual cone of
Computes for each vector :math:`u` in the provided tensor ``U``
the weights :math:`w` of the projection of :math:`J^\top u` onto the dual cone of

the rows of :math:`J`, provided :math:`G = J J^\top` and :math:`u`. In other words, this
computes the :math:`w` that satisfies :math:`\pi_J(J^\top u) = J^\top w`, with
:math:`\pi_J` defined in Equation 3 of [1].

By Proposition 1 of [1], this is equivalent to solving for :math:`v` the following
quadratic program:

.. math::

\min_{v} \quad & v^\top G v \\
\text{subject to} \quad & u \preceq v

Reference:
[1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_.

:param U: The tensor of weights corresponding to the vectors to project, of shape
``[..., m]``.
:param G: The Gramian matrix of shape ``[m, m]``. It must be symmetric and positive
definite.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The gramian is only PSD in general I think, even if some solvers only take PD gramians, it's the role of the projector to regularize from PSD to PD and then to apply the solver. So IMO we should talk about PSD here.

Suggested change
definite.
semi-definite.

:return: A tensor of projection weights with the same shape as ``U``.
"""


def projector_or_default(projector: DualConeProjector | None) -> DualConeProjector:
if projector is None:
return QuadprogProjector()
return projector


class QuadprogProjector(DualConeProjector):
r"""
Solves the quadratic program defined in :meth:`DualConeProjector.__call__` using the
``quadprog`` QP solver.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could add a link to https://github.com/quadprog/quadprog

Also, we could say in what this solver is good and what are its limitations. E.g. "This solver is very precise and converges in fixed time, but it is slow when m is large (e.g. m >= 64)."

This could come in a future PR though, when we add another solver. So feel free to no include that.


:param norm_eps: A small value to avoid division by zero when normalizing.
:param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to
numerical errors when computing the gramian, it might not exactly be positive definite.
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
ensures that it is positive definite.
"""

def __init__(
self,
*,
norm_eps: float = 0.0001,
reg_eps: float = 0.0001,
) -> None:
self.norm_eps = norm_eps
self.reg_eps = reg_eps

def __repr__(self) -> str:
return f"QuadprogProjector(norm_eps={self.norm_eps}, reg_eps={self.reg_eps})"

def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor:

G = regularize(normalize(G, self.norm_eps), self.reg_eps)

G_ = _to_array(G)
U_ = _to_array(U)

W = np.apply_along_axis(lambda u: self._project_weight_vector(u, G_), axis=-1, arr=U_)

return torch.as_tensor(W, device=G.device, dtype=G.dtype)

def _project_weight_vector(self, u: np.ndarray, G: np.ndarray) -> np.ndarray:

m = G.shape[0]
w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver="quadprog")

if w is None: # This may happen when G has large values.
raise ValueError("Failed to solve the quadratic programming problem.")

return w


def _to_array(tensor: Tensor) -> np.ndarray:
"""Transforms a tensor into a numpy array with float64 dtype."""

return tensor.cpu().detach().numpy().astype(np.float64)
82 changes: 20 additions & 62 deletions src/torchjd/aggregation/_dualproj.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from torch import Tensor

from torchjd._linalg import normalize, regularize
from torchjd._linalg import DualConeProjector, projector_or_default
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._mixins import _NonDifferentiable
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import _GramianWeighting

Expand All @@ -19,31 +18,21 @@ class DualProjWeighting(_NonDifferentiable, _GramianWeighting):

:param pref_vector: The preference vector to use. If not provided, defaults to
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
:param norm_eps: A small value to avoid division by zero when normalizing.
:param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to
numerical errors when computing the gramian, it might not exactly be positive definite.
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
ensures that it is positive definite.
:param solver: The solver used to optimize the underlying optimization problem.
:param projector: The :class:`~torchjd.linalg.DualConeProjector` used to compute the projection.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
:param projector: The :class:`~torchjd.linalg.DualConeProjector` used tocompute the projection.
:param projector: The :class:`~torchjd.linalg.DualConeProjector` used to compute the projection.

"""

def __init__(
self,
pref_vector: Tensor | None = None,
norm_eps: float = 0.0001,
reg_eps: float = 0.0001,
solver: SUPPORTED_SOLVER = "quadprog",
projector: DualConeProjector | None = None,
) -> None:
super().__init__()
self.pref_vector = pref_vector
self.norm_eps = norm_eps
self.reg_eps = reg_eps
self.solver: SUPPORTED_SOLVER = solver
self.projector = projector_or_default(projector)

def forward(self, gramian: PSDMatrix, /) -> Tensor:
u = self.weighting(gramian)
G = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
w = project_weights(u, G, self.solver)
w = self.projector(u, gramian)
return w

@property
Expand All @@ -56,26 +45,12 @@ def pref_vector(self, value: Tensor | None) -> None:
self._pref_vector = value

@property
def norm_eps(self) -> float:
return self._norm_eps
def projector(self) -> DualConeProjector:
return self._projector

@norm_eps.setter
def norm_eps(self, value: float) -> None:
if value < 0:
raise ValueError(f"norm_eps must be non-negative, but got {value}.")

self._norm_eps = value

@property
def reg_eps(self) -> float:
return self._reg_eps

@reg_eps.setter
def reg_eps(self, value: float) -> None:
if value < 0:
raise ValueError(f"reg_eps must be non-negative, but got {value}.")

self._reg_eps = value
@projector.setter
def projector(self, value: DualConeProjector | None) -> None:
self._projector = projector_or_default(value)


class DualProj(_NonDifferentiable, GramianWeightedAggregator):
Expand All @@ -87,27 +62,18 @@ class DualProj(_NonDifferentiable, GramianWeightedAggregator):

:param pref_vector: The preference vector used to combine the rows. If not provided, defaults to
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
:param norm_eps: A small value to avoid division by zero when normalizing.
:param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to
numerical errors when computing the gramian, it might not exactly be positive definite.
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
ensures that it is positive definite.
:param solver: The solver used to optimize the underlying optimization problem.
:param projector: The :class:`~torchjd.linalg.DualConeProjector` used to compute the projection.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
:param projector: The :class:`~torchjd.linalg.DualConeProjector` used tocompute the projection.
:param projector: The :class:`~torchjd.linalg.DualConeProjector` used to compute the projection.

"""

gramian_weighting: DualProjWeighting

def __init__(
self,
pref_vector: Tensor | None = None,
norm_eps: float = 0.0001,
reg_eps: float = 0.0001,
solver: SUPPORTED_SOLVER = "quadprog",
projector: DualConeProjector | None = None,
) -> None:
self._solver: SUPPORTED_SOLVER = solver

super().__init__(
DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver),
DualProjWeighting(pref_vector, projector=projector),
)

@property
Expand All @@ -119,25 +85,17 @@ def pref_vector(self, value: Tensor | None) -> None:
self.gramian_weighting.pref_vector = value

@property
def norm_eps(self) -> float:
return self.gramian_weighting.norm_eps

@norm_eps.setter
def norm_eps(self, value: float) -> None:
self.gramian_weighting.norm_eps = value

@property
def reg_eps(self) -> float:
return self.gramian_weighting.reg_eps
def projector(self) -> DualConeProjector:
return self.gramian_weighting.projector

@reg_eps.setter
def reg_eps(self, value: float) -> None:
self.gramian_weighting.reg_eps = value
@projector.setter
def projector(self, value: DualConeProjector | None) -> None:
self.gramian_weighting.projector = value

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps="
f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})"
f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, projector="
f"{repr(self.projector)})"
)

def __str__(self) -> str:
Expand Down
Loading
Loading