Skip to content

Commit 10573be

Browse files
committed
labelize_forward on ensemble
1 parent 796f42e commit 10573be

11 files changed

Lines changed: 15 additions & 7 deletions

pina/_src/solver/ensemble_simple_solver.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Module for the DeepEnsemble simple solver."""
22

33
from pina._src.solver.multi_model_simple_solver import MultiModelSimpleSolver
4+
from pina._src.core.utils import check_consistency
45

56

67
class DeepEnsembleSimpleSolver(MultiModelSimpleSolver):
@@ -99,5 +100,7 @@ def __init__(
99100
weighting=weighting,
100101
loss=loss,
101102
use_lt=use_lt,
102-
ensemble_dim=ensemble_dim,
103103
)
104+
105+
check_consistency(ensemble_dim, int)
106+
self.num_ensemble = len(models)

pina/_src/solver/multi_model_simple_solver.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def __init__(
6262
weighting=None,
6363
loss=None,
6464
use_lt=True,
65-
ensemble_dim=0,
6665
):
6766
"""
6867
Initialize the multi-model simple solver.
@@ -90,7 +89,6 @@ def __init__(
9089
loss = torch.nn.MSELoss()
9190

9291
check_consistency(loss, (LossInterface, _Loss), subclass=False)
93-
check_consistency(ensemble_dim, int)
9492

9593
super().__init__(
9694
problem=problem,
@@ -103,7 +101,6 @@ def __init__(
103101

104102
self._loss_fn = loss
105103
self._reduction = getattr(loss, "reduction", "mean")
106-
self._ensemble_dim = ensemble_dim
107104

108105
if hasattr(self._loss_fn, "reduction"):
109106
self._loss_fn.reduction = "none"
@@ -194,9 +191,16 @@ def optimization_cycle(self, batch):
194191
self.forward = ( # noqa: E731
195192
lambda x, _idx=idx: self.models[_idx].forward(x)
196193
)
194+
from pina._src.core.utils import labelize_forward
195+
problem = self.problem
196+
self.forward = labelize_forward(
197+
self.forward,
198+
input_variables=problem.input_variables,
199+
output_variables=problem.output_variables,
200+
)
197201
loss_tensor = condition.evaluate(
198202
condition_data, self, self._loss_fn
199-
)
203+
).tensor
200204
self.forward = original_forward
201205
per_model_losses.append(self._apply_reduction(loss_tensor))
202206

File renamed without changes.
File renamed without changes.
File renamed without changes.

tests/test_solver/test_autoregressive_solver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pina.condition import TimeSeriesCondition
1010
from pina.problem import AbstractProblem
1111
from pina.model import FeedForward
12+
from torch._dynamo import OptimizedModule
1213

1314

1415
# Hyperparameters and settings

0 commit comments

Comments
 (0)