|
1 | 1 | # Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon. |
2 | 2 | # See NOTICES for the full license text. |
3 | 3 |
|
4 | | -from torchjd.aggregation._mixins import Stateful, _NonDifferentiable |
| 4 | +from __future__ import annotations |
5 | 5 |
|
6 | | -from ._utils.check_dependencies import check_dependencies_are_installed |
7 | | -from ._weighting_bases import _MatrixWeighting |
8 | | - |
9 | | -check_dependencies_are_installed(["cvxpy", "ecos"]) |
| 6 | +import contextlib |
10 | 7 |
|
11 | | -import cvxpy as cp |
12 | 8 | import numpy as np |
13 | 9 | import torch |
14 | | -from cvxpy import Expression, SolverError |
15 | 10 | from torch import Tensor |
16 | 11 |
|
| 12 | +from torchjd.aggregation._mixins import Stateful, _NonDifferentiable, _WithOptionalDeps |
| 13 | + |
17 | 14 | from ._aggregator_bases import WeightedAggregator |
| 15 | +from ._weighting_bases import _MatrixWeighting |
| 16 | + |
| 17 | +with contextlib.suppress(ImportError): |
| 18 | + import cvxpy as cp |
| 19 | + from cvxpy import Expression, SolverError |
18 | 20 |
|
19 | 21 |
|
20 | 22 | # Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph. |
21 | | -class _NashMTLWeighting(_NonDifferentiable, Stateful, _MatrixWeighting): |
| 23 | +class _NashMTLWeighting(_WithOptionalDeps, _NonDifferentiable, Stateful, _MatrixWeighting): |
| 24 | + _REQUIRED_DEPS = ["cvxpy", "ecos"] |
| 25 | + _INSTALL_HINT = 'Install them with: pip install "torchjd[nash_mtl]"' |
22 | 26 | """ |
23 | 27 | :class:`~torchjd.aggregation._mixins.Stateful` |
24 | 28 | :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that |
@@ -215,10 +219,9 @@ class NashMTL(_NonDifferentiable, Stateful, WeightedAggregator): |
215 | 219 | :param optim_niter: The number of iterations of the underlying optimization process. |
216 | 220 |
|
217 | 221 | .. note:: |
218 | | - This aggregator is not installed by default. When not installed, trying to import it should |
219 | | - result in the following error: |
220 | | - ``ImportError: cannot import name 'NashMTL' from 'torchjd.aggregation'``. |
221 | | - To install it, use ``pip install "torchjd[nash_mtl]"``. |
| 222 | + This aggregator requires optional dependencies. When they are not installed, instantiating |
| 223 | + it raises an :class:`ImportError` with installation instructions. |
| 224 | + To install them, use ``pip install "torchjd[nash_mtl]"``. |
222 | 225 |
|
223 | 226 | .. warning:: |
224 | 227 | This implementation was adapted from the `official implementation |
|
0 commit comments