diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py
index 72486da2f..5369e6c6d 100644
--- a/src/weathergen/model/engines.py
+++ b/src/weathergen/model/engines.py
@@ -111,9 +111,9 @@ def forward(self, batch, pe_embed):
 
         # if the assert is hit, max_number_tokens_local_per_cell in config needs to be increased
         max_tokens = self.cf.get("ae_local_max_tokens_per_cell", 64)
-        assert (
-            batch.tokens_lens.flatten(0, 2).sum(0).max() <= max_tokens
-        ), "max number of tokens per cell for positional encoding exceeded."
-        " Increase ae_local_max_tokens_per_cell in config."
+        assert batch.tokens_lens.flatten(0, 2).sum(0).max() <= max_tokens, (
+            "max number of tokens per cell for positional encoding exceeded."
+            " Increase ae_local_max_tokens_per_cell in config."
+        )
 
         if batch.tokens_lens.shape[2] == 1:
diff --git a/src/weathergen/train/loss_modules/loss_functions.py b/src/weathergen/train/loss_modules/loss_functions.py
index f9f173fcf..0d741bc28 100644
--- a/src/weathergen/train/loss_modules/loss_functions.py
+++ b/src/weathergen/train/loss_modules/loss_functions.py
@@ -62,9 +62,42 @@ def stats_normalized_erf(target, ens, mu, stddev):
     return torch.mean(d * d)  # + torch.mean( torch.sqrt( stddev) )
 
 
-def mse_ens(target, ens, mu, stddev):
-    mse_loss = torch.nn.functional.mse_loss
-    return torch.stack([mse_loss(target, mem) for mem in ens], 0).mean()
+def mse_ens(
+    target: torch.Tensor,
+    pred: torch.Tensor,
+    weights_channels: torch.Tensor | None,
+    weights_points: torch.Tensor | None,
+    use_ensemble_mean: bool = False,
+):
+    """
+    MSE loss for ensemble predictions, with two modes:
+
+    use_ensemble_mean=False (default):
+        Mean of per-member MSE — equivalent to mean(mse(target, mem) for mem in ens).
+        Penalises every member independently; each member is pushed toward the target.
+
+    use_ensemble_mean=True:
+        MSE of the ensemble mean against the target.
+        Collapses the ensemble to a single prediction before comparing, which
+        ignores spread and rewards a well-calibrated ensemble mean.
+
+    target           : shape (num_data_points, num_channels)
+    pred             : shape (ens_dim, num_data_points, num_channels)
+    weights_channels : shape (num_channels,) or None
+    weights_points   : shape (num_data_points,) or None
+
+    Returns (loss, per-channel loss), each averaged over the ensemble members;
+    with use_ensemble_mean=True the result of mse() is returned unchanged.
+    """
+    if use_ensemble_mean:
+        # mse() is expected to average pred over the ensemble dim (0) first -- TODO confirm
+        return mse(target, pred, weights_channels, weights_points)
+
+    losses, losses_chs = zip(
+        *[mse(target, member.unsqueeze(0), weights_channels, weights_points) for member in pred],
+        strict=False,
+    )
+    return torch.stack(list(losses)).mean(), torch.stack(list(losses_chs)).mean(0)
 
 
 def kernel_crps(