-
Notifications
You must be signed in to change notification settings - Fork 16
refactor!: Add DualConeProjector
#678
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
336b47c
f1076a7
3343ebb
25d0c91
c63571d
f7f4fd6
b2e6d2a
2f05e89
470ccd0
a8ccd16
5780f6b
70f814f
83b528e
563faff
e849d89
01ffc2b
42f227b
89bc6d1
33d171a
e8291ae
cf0439c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| :hide-toc: | ||
|
|
||
| Dual Cone Projectors | ||
| ==================== | ||
|
|
||
| .. autoclass:: torchjd.linalg.DualConeProjector | ||
| :members: __call__ | ||
|
|
||
| .. autoclass:: torchjd.linalg.QuadprogProjector |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,3 +10,4 @@ linalg | |
|
|
||
| matrix.rst | ||
| psd_matrix.rst | ||
| dual_cone.rst | ||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,101 @@ | ||||||||
| from abc import ABC, abstractmethod | ||||||||
|
|
||||||||
| import numpy as np | ||||||||
| import torch | ||||||||
| from qpsolvers import solve_qp | ||||||||
| from torch import Tensor | ||||||||
|
|
||||||||
| from ._gramian import normalize, regularize | ||||||||
| from ._matrix import PSDMatrix | ||||||||
|
|
||||||||
|
|
||||||||
| class DualConeProjector(ABC): | ||||||||
| """ | ||||||||
| Abstract class whose instances are responsible for projecting vectors onto the dual cone of the | ||||||||
| rows of a matrix, or rather the dual form of this problem. The current default | ||||||||
| :class:`~torchjd.linalg.DualConeProjector` is :class:`~torchjd.linalg.QuadprogProjector`. | ||||||||
| """ | ||||||||
|
|
||||||||
| @abstractmethod | ||||||||
| def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor: | ||||||||
| r""" | ||||||||
| Computes the weights :math:`w` of the projection of :math:`J^\top u` onto the dual cone of | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| the rows of :math:`J`, provided :math:`G = J J^\top` and :math:`u`. In other words, this | ||||||||
| computes the :math:`w` that satisfies :math:`\pi_J(J^\top u) = J^\top w`, with | ||||||||
| :math:`\pi_J` defined in Equation 3 of [1]. | ||||||||
|
|
||||||||
| By Proposition 1 of [1], this is equivalent to solving for :math:`v` the following | ||||||||
| quadratic program: | ||||||||
|
|
||||||||
| .. math:: | ||||||||
|
|
||||||||
| \min_{v} \quad & v^\top G v \\ | ||||||||
| \text{subject to} \quad & u \preceq v | ||||||||
|
|
||||||||
| Reference: | ||||||||
| [1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_. | ||||||||
|
|
||||||||
| :param U: The tensor of weights corresponding to the vectors to project, of shape | ||||||||
| ``[..., m]``. | ||||||||
| :param G: The Gramian matrix of shape ``[m, m]``. It must be symmetric and positive | ||||||||
| semi-definite. | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The gramian is only PSD in general I think, even if some solvers only take PD gramians, it's the role of the projector to regularize from PSD to PD and then to apply the solver. So IMO we should talk about PSD here.
Suggested change
|
||||||||
| :return: A tensor of projection weights with the same shape as ``U``. | ||||||||
| """ | ||||||||
|
|
||||||||
|
|
||||||||
| def projector_or_default(projector: DualConeProjector | None) -> DualConeProjector: | ||||||||
| if projector is None: | ||||||||
| return QuadprogProjector() | ||||||||
| return projector | ||||||||
|
|
||||||||
|
|
||||||||
| class QuadprogProjector(DualConeProjector): | ||||||||
| r""" | ||||||||
| Solves the quadratic program defined in :meth:`DualConeProjector.__call__` using the | ||||||||
| ``quadprog`` QP solver. | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could add a link to https://github.com/quadprog/quadprog Also, we could say what this solver is good at and what its limitations are. E.g. "This solver is very precise and converges in fixed time, but it is slow when m is large (e.g. m >= 64)." This could come in a future PR though, when we add another solver. So feel free to not include that. |
||||||||
|
|
||||||||
| :param norm_eps: A small value to avoid division by zero when normalizing. | ||||||||
| :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to | ||||||||
| numerical errors when computing the gramian, it might not exactly be positive definite. | ||||||||
| This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian | ||||||||
| ensures that it is positive definite. | ||||||||
| """ | ||||||||
|
|
||||||||
| def __init__( | ||||||||
| self, | ||||||||
| *, | ||||||||
| norm_eps: float = 0.0001, | ||||||||
| reg_eps: float = 0.0001, | ||||||||
| ) -> None: | ||||||||
| self.norm_eps = norm_eps | ||||||||
| self.reg_eps = reg_eps | ||||||||
|
|
||||||||
| def __repr__(self) -> str: | ||||||||
| return f"QuadprogProjector(norm_eps={self.norm_eps}, reg_eps={self.reg_eps})" | ||||||||
|
|
||||||||
| def __call__(self, U: Tensor, G: PSDMatrix) -> Tensor: | ||||||||
|
|
||||||||
| G = regularize(normalize(G, self.norm_eps), self.reg_eps) | ||||||||
|
|
||||||||
| G_ = _to_array(G) | ||||||||
| U_ = _to_array(U) | ||||||||
|
|
||||||||
| W = np.apply_along_axis(lambda u: self._project_weight_vector(u, G_), axis=-1, arr=U_) | ||||||||
|
|
||||||||
| return torch.as_tensor(W, device=G.device, dtype=G.dtype) | ||||||||
|
|
||||||||
| def _project_weight_vector(self, u: np.ndarray, G: np.ndarray) -> np.ndarray: | ||||||||
|
|
||||||||
| m = G.shape[0] | ||||||||
| w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver="quadprog") | ||||||||
|
|
||||||||
| if w is None: # This may happen when G has large values. | ||||||||
| raise ValueError("Failed to solve the quadratic programming problem.") | ||||||||
|
|
||||||||
| return w | ||||||||
|
|
||||||||
|
|
||||||||
| def _to_array(tensor: Tensor) -> np.ndarray: | ||||||||
| """Transforms a tensor into a numpy array with float64 dtype.""" | ||||||||
|
|
||||||||
| return tensor.cpu().detach().numpy().astype(np.float64) | ||||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,12 +1,11 @@ | ||||||
| from torch import Tensor | ||||||
|
|
||||||
| from torchjd._linalg import normalize, regularize | ||||||
| from torchjd._linalg import DualConeProjector, projector_or_default | ||||||
| from torchjd.linalg import PSDMatrix | ||||||
|
|
||||||
| from ._aggregator_bases import GramianWeightedAggregator | ||||||
| from ._mean import MeanWeighting | ||||||
| from ._mixins import _NonDifferentiable | ||||||
| from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights | ||||||
| from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting | ||||||
| from ._weighting_bases import _GramianWeighting | ||||||
|
|
||||||
|
|
@@ -19,31 +18,21 @@ class DualProjWeighting(_NonDifferentiable, _GramianWeighting): | |||||
|
|
||||||
| :param pref_vector: The preference vector to use. If not provided, defaults to | ||||||
| :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. | ||||||
| :param norm_eps: A small value to avoid division by zero when normalizing. | ||||||
| :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to | ||||||
| numerical errors when computing the gramian, it might not exactly be positive definite. | ||||||
| This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian | ||||||
| ensures that it is positive definite. | ||||||
| :param solver: The solver used to optimize the underlying optimization problem. | ||||||
| :param projector: The :class:`~torchjd.linalg.DualConeProjector` used to compute the projection. | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| """ | ||||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| pref_vector: Tensor | None = None, | ||||||
| norm_eps: float = 0.0001, | ||||||
| reg_eps: float = 0.0001, | ||||||
| solver: SUPPORTED_SOLVER = "quadprog", | ||||||
| projector: DualConeProjector | None = None, | ||||||
| ) -> None: | ||||||
| super().__init__() | ||||||
| self.pref_vector = pref_vector | ||||||
| self.norm_eps = norm_eps | ||||||
| self.reg_eps = reg_eps | ||||||
| self.solver: SUPPORTED_SOLVER = solver | ||||||
| self.projector = projector_or_default(projector) | ||||||
|
|
||||||
| def forward(self, gramian: PSDMatrix, /) -> Tensor: | ||||||
| u = self.weighting(gramian) | ||||||
| G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) | ||||||
| w = project_weights(u, G, self.solver) | ||||||
| w = self.projector(u, gramian) | ||||||
| return w | ||||||
|
|
||||||
| @property | ||||||
|
|
@@ -56,26 +45,12 @@ def pref_vector(self, value: Tensor | None) -> None: | |||||
| self._pref_vector = value | ||||||
|
|
||||||
| @property | ||||||
| def norm_eps(self) -> float: | ||||||
| return self._norm_eps | ||||||
| def projector(self) -> DualConeProjector: | ||||||
| return self._projector | ||||||
|
|
||||||
| @norm_eps.setter | ||||||
| def norm_eps(self, value: float) -> None: | ||||||
| if value < 0: | ||||||
| raise ValueError(f"norm_eps must be non-negative, but got {value}.") | ||||||
|
|
||||||
| self._norm_eps = value | ||||||
|
|
||||||
| @property | ||||||
| def reg_eps(self) -> float: | ||||||
| return self._reg_eps | ||||||
|
|
||||||
| @reg_eps.setter | ||||||
| def reg_eps(self, value: float) -> None: | ||||||
| if value < 0: | ||||||
| raise ValueError(f"reg_eps must be non-negative, but got {value}.") | ||||||
|
|
||||||
| self._reg_eps = value | ||||||
| @projector.setter | ||||||
| def projector(self, value: DualConeProjector | None) -> None: | ||||||
| self._projector = projector_or_default(value) | ||||||
|
|
||||||
|
|
||||||
| class DualProj(_NonDifferentiable, GramianWeightedAggregator): | ||||||
|
|
@@ -87,27 +62,18 @@ class DualProj(_NonDifferentiable, GramianWeightedAggregator): | |||||
|
|
||||||
| :param pref_vector: The preference vector used to combine the rows. If not provided, defaults to | ||||||
| :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. | ||||||
| :param norm_eps: A small value to avoid division by zero when normalizing. | ||||||
| :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to | ||||||
| numerical errors when computing the gramian, it might not exactly be positive definite. | ||||||
| This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian | ||||||
| ensures that it is positive definite. | ||||||
| :param solver: The solver used to optimize the underlying optimization problem. | ||||||
| :param projector: The :class:`~torchjd.linalg.DualConeProjector` used to compute the projection. | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| """ | ||||||
|
|
||||||
| gramian_weighting: DualProjWeighting | ||||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| pref_vector: Tensor | None = None, | ||||||
| norm_eps: float = 0.0001, | ||||||
| reg_eps: float = 0.0001, | ||||||
| solver: SUPPORTED_SOLVER = "quadprog", | ||||||
| projector: DualConeProjector | None = None, | ||||||
| ) -> None: | ||||||
| self._solver: SUPPORTED_SOLVER = solver | ||||||
|
|
||||||
| super().__init__( | ||||||
| DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), | ||||||
| DualProjWeighting(pref_vector, projector=projector), | ||||||
| ) | ||||||
|
|
||||||
| @property | ||||||
|
|
@@ -119,25 +85,17 @@ def pref_vector(self, value: Tensor | None) -> None: | |||||
| self.gramian_weighting.pref_vector = value | ||||||
|
|
||||||
| @property | ||||||
| def norm_eps(self) -> float: | ||||||
| return self.gramian_weighting.norm_eps | ||||||
|
|
||||||
| @norm_eps.setter | ||||||
| def norm_eps(self, value: float) -> None: | ||||||
| self.gramian_weighting.norm_eps = value | ||||||
|
|
||||||
| @property | ||||||
| def reg_eps(self) -> float: | ||||||
| return self.gramian_weighting.reg_eps | ||||||
| def projector(self) -> DualConeProjector: | ||||||
| return self.gramian_weighting.projector | ||||||
|
|
||||||
| @reg_eps.setter | ||||||
| def reg_eps(self, value: float) -> None: | ||||||
| self.gramian_weighting.reg_eps = value | ||||||
| @projector.setter | ||||||
| def projector(self, value: DualConeProjector | None) -> None: | ||||||
| self.gramian_weighting.projector = value | ||||||
|
|
||||||
| def __repr__(self) -> str: | ||||||
| return ( | ||||||
| f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps=" | ||||||
| f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})" | ||||||
| f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, projector=" | ||||||
| f"{repr(self.projector)})" | ||||||
| ) | ||||||
|
|
||||||
| def __str__(self) -> str: | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd rather not talk about subclasses in the docstring of the main class unless that's really necessary to make things understandable.