Skip to content

Commit 57df9b9

Browse files
ValerianReyclaude
andcommitted
refactor(aggregation): Make NashMTL and CAGrad always importable (#678)
Add _WithOptionalDeps mixin that raises ImportError at instantiation time when optional dependencies are missing, replacing the module-level guard that previously prevented import altogether. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent d1973e7 commit 57df9b9

6 files changed

Lines changed: 74 additions & 57 deletions

File tree

src/torchjd/aggregation/__init__.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262

6363
from ._aggregator_bases import Aggregator, GramianWeightedAggregator, WeightedAggregator
6464
from ._aligned_mtl import AlignedMTL, AlignedMTLWeighting
65+
from ._cagrad import CAGrad, CAGradWeighting
6566
from ._config import ConFIG
6667
from ._constant import Constant, ConstantWeighting
6768
from ._dualproj import DualProj, DualProjWeighting
@@ -73,20 +74,20 @@
7374
from ._mean import Mean, MeanWeighting
7475
from ._mgda import MGDA, MGDAWeighting
7576
from ._mixins import Stateful
77+
from ._nash_mtl import NashMTL
7678
from ._pcgrad import PCGrad, PCGradWeighting
7779
from ._random import Random, RandomWeighting
7880
from ._sum import Sum, SumWeighting
7981
from ._trimmed_mean import TrimmedMean
8082
from ._upgrad import UPGrad, UPGradWeighting
81-
from ._utils.check_dependencies import (
82-
OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError,
83-
)
8483
from ._weighting_bases import GeneralizedWeighting, Weighting
8584

8685
__all__ = [
8786
"Aggregator",
8887
"AlignedMTL",
8988
"AlignedMTLWeighting",
89+
"CAGrad",
90+
"CAGradWeighting",
9091
"ConFIG",
9192
"Constant",
9293
"ConstantWeighting",
@@ -106,6 +107,7 @@
106107
"MeanWeighting",
107108
"MGDA",
108109
"MGDAWeighting",
110+
"NashMTL",
109111
"PCGrad",
110112
"PCGradWeighting",
111113
"Random",
@@ -119,17 +121,3 @@
119121
"WeightedAggregator",
120122
"Weighting",
121123
]
122-
123-
try:
124-
from ._cagrad import CAGrad, CAGradWeighting
125-
126-
__all__ += ["CAGrad", "CAGradWeighting"]
127-
except _OptionalDepsNotInstalledError: # The required dependencies are not installed
128-
pass
129-
130-
try:
131-
from ._nash_mtl import NashMTL
132-
133-
__all__ += ["NashMTL"]
134-
except _OptionalDepsNotInstalledError: # The required dependencies are not installed
135-
pass

src/torchjd/aggregation/_cagrad.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
1+
import contextlib
12
from typing import cast
23

3-
from torchjd.linalg import PSDMatrix
4-
5-
from ._mixins import _NonDifferentiable
6-
from ._utils.check_dependencies import check_dependencies_are_installed
7-
from ._weighting_bases import _GramianWeighting
8-
9-
check_dependencies_are_installed(["cvxpy", "clarabel"])
10-
11-
import cvxpy as cp
124
import numpy as np
135
import torch
146
from torch import Tensor
157

168
from torchjd._linalg import normalize
9+
from torchjd.linalg import PSDMatrix
1710

1811
from ._aggregator_bases import GramianWeightedAggregator
12+
from ._mixins import _NonDifferentiable, _WithOptionalDeps
13+
from ._weighting_bases import _GramianWeighting
14+
15+
with contextlib.suppress(ImportError):
16+
import cvxpy as cp
1917

2018

2119
# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
22-
class CAGradWeighting(_NonDifferentiable, _GramianWeighting):
20+
class CAGradWeighting(_WithOptionalDeps, _NonDifferentiable, _GramianWeighting):
21+
_REQUIRED_DEPS = ["cvxpy", "clarabel"]
22+
_INSTALL_HINT = 'Install them with: pip install "torchjd[cagrad]"'
2323
"""
2424
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
2525
giving the weights of :class:`~torchjd.aggregation.CAGrad`.
@@ -103,10 +103,9 @@ class CAGrad(_NonDifferentiable, GramianWeightedAggregator):
103103
:param norm_eps: A small value to avoid division by zero when normalizing.
104104
105105
.. note::
106-
This aggregator is not installed by default. When not installed, trying to import it should
107-
result in the following error:
108-
``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
109-
To install it, use ``pip install "torchjd[cagrad]"``.
106+
This aggregator requires optional dependencies. When they are not installed, instantiating
107+
it raises an :class:`ImportError` with installation instructions.
108+
To install them, use ``pip install "torchjd[cagrad]"``.
110109
"""
111110

112111
gramian_weighting: CAGradWeighting

src/torchjd/aggregation/_mixins.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,37 @@
11
from abc import ABC, abstractmethod
2+
from importlib.util import find_spec
23
from typing import Any
34

45
import torch
56
from torch import nn
67

78

9+
class _WithOptionalDeps:
10+
"""
11+
Mixin that raises :class:`ImportError` at instantiation time if required optional dependencies
12+
are not installed.
13+
14+
Subclasses must define :attr:`_REQUIRED_DEPS` (list of package names to check via
15+
:func:`importlib.util.find_spec`) and :attr:`_INSTALL_HINT` (appended to the error message).
16+
17+
.. warning::
18+
This mixin must appear **first** in the inheritance list so that its :meth:`__init__`
19+
runs before any base class that uses the optional dependencies.
20+
"""
21+
22+
_REQUIRED_DEPS: list[str]
23+
_INSTALL_HINT: str
24+
25+
def __init__(self, *args: Any, **kwargs: Any) -> None:
26+
missing = [name for name in self._REQUIRED_DEPS if find_spec(name) is None]
27+
if missing:
28+
raise ImportError(
29+
f"{self.__class__.__name__} requires {missing} to be installed. "
30+
f"{self._INSTALL_HINT}"
31+
)
32+
super().__init__(*args, **kwargs)
33+
34+
835
class Stateful(ABC):
936
"""Mixin adding a reset method."""
1037

src/torchjd/aggregation/_nash_mtl.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
11
# Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon.
22
# See NOTICES for the full license text.
33

4-
from torchjd.aggregation._mixins import Stateful, _NonDifferentiable
4+
from __future__ import annotations
55

6-
from ._utils.check_dependencies import check_dependencies_are_installed
7-
from ._weighting_bases import _MatrixWeighting
8-
9-
check_dependencies_are_installed(["cvxpy", "ecos"])
6+
import contextlib
107

11-
import cvxpy as cp
128
import numpy as np
139
import torch
14-
from cvxpy import Expression, SolverError
1510
from torch import Tensor
1611

12+
from torchjd.aggregation._mixins import Stateful, _NonDifferentiable, _WithOptionalDeps
13+
1714
from ._aggregator_bases import WeightedAggregator
15+
from ._weighting_bases import _MatrixWeighting
16+
17+
with contextlib.suppress(ImportError):
18+
import cvxpy as cp
19+
from cvxpy import Expression, SolverError
1820

1921

2022
# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
21-
class _NashMTLWeighting(_NonDifferentiable, Stateful, _MatrixWeighting):
23+
class _NashMTLWeighting(_WithOptionalDeps, _NonDifferentiable, Stateful, _MatrixWeighting):
24+
_REQUIRED_DEPS = ["cvxpy", "ecos"]
25+
_INSTALL_HINT = 'Install them with: pip install "torchjd[nash_mtl]"'
2226
"""
2327
:class:`~torchjd.aggregation._mixins.Stateful`
2428
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that
@@ -215,10 +219,9 @@ class NashMTL(_NonDifferentiable, Stateful, WeightedAggregator):
215219
:param optim_niter: The number of iterations of the underlying optimization process.
216220
217221
.. note::
218-
This aggregator is not installed by default. When not installed, trying to import it should
219-
result in the following error:
220-
``ImportError: cannot import name 'NashMTL' from 'torchjd.aggregation'``.
221-
To install it, use ``pip install "torchjd[nash_mtl]"``.
222+
This aggregator requires optional dependencies. When they are not installed, instantiating
223+
it raises an :class:`ImportError` with installation instructions.
224+
To install them, use ``pip install "torchjd[nash_mtl]"``.
222225
223226
.. warning::
224227
This implementation was adapted from the `official implementation

tests/unit/aggregation/test_cagrad.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1+
import pytest
2+
3+
pytest.importorskip("cvxpy")
4+
pytest.importorskip("clarabel")
5+
16
from contextlib import nullcontext as does_not_raise
27

38
from pytest import mark, raises
49
from torch import Tensor
510
from utils.contexts import ExceptionContext
611
from utils.tensors import ones_
712

8-
try:
9-
from torchjd.aggregation import CAGrad
10-
from torchjd.aggregation._cagrad import CAGradWeighting
11-
except ImportError:
12-
import pytest
13-
14-
pytest.skip("CAGrad dependencies not installed", allow_module_level=True)
13+
from torchjd.aggregation import CAGrad
14+
from torchjd.aggregation._cagrad import CAGradWeighting
1515

1616
from ._asserts import assert_expected_structure, assert_non_conflicting, assert_non_differentiable
1717
from ._inputs import scaled_matrices, typical_matrices

tests/unit/aggregation/test_nash_mtl.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1+
import pytest
2+
3+
pytest.importorskip("cvxpy")
4+
pytest.importorskip("ecos")
5+
16
from pytest import mark, raises
27
from torch import Tensor
38
from torch.testing import assert_close
49
from utils.tensors import ones_, randn_, tensor_
510

6-
try:
7-
from torchjd.aggregation import NashMTL
8-
from torchjd.aggregation._nash_mtl import _NashMTLWeighting
9-
except ImportError:
10-
import pytest
11-
12-
pytest.skip("NashMTL dependencies not installed", allow_module_level=True)
11+
from torchjd.aggregation import NashMTL
12+
from torchjd.aggregation._nash_mtl import _NashMTLWeighting
1313

1414
from ._asserts import assert_expected_structure, assert_non_differentiable
1515
from ._inputs import nash_mtl_matrices

0 commit comments

Comments
 (0)