diff --git a/config/streams/era5_1deg_forecasting/era5.yml b/config/streams/era5_1deg_forecasting/era5.yml index ae8e2da53..fce350a59 100644 --- a/config/streams/era5_1deg_forecasting/era5.yml +++ b/config/streams/era5_1deg_forecasting/era5.yml @@ -13,7 +13,7 @@ ERA5 : stream_id : 0 source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] - geoinfo_channels : ['lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] + # geoinfo_channels : ['lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] loss_weight : 1. location_weight : cosine_latitude masking_rate : 0.6 diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index a89c9d6b8..0918c1732 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -94,8 +94,9 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: for _, _, loss_fct_name in self.loss_fcts } - pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if pl] + pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if ("latent_state" in pl)] target_tokens_all = [latent["diffusion_latent"] for latent in targets.latent if latent] + assert len(pred_tokens_all) == len(target_tokens_all), "Mismatch between predicted and target token lengths" eta = torch.tensor( [targets.aux_outputs["noise_level_rn"]], device=self.device, dtype=torch.float32 diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 25f62309f..26ed6af59 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -72,6 +72,9 @@ def _expand_targets_to_match_preds(preds, targets_and_auxs: dict) -> None: """ n_pred = len(preds.physical) for t_aux in targets_and_auxs.values(): + # if the first entry is None (e.g. when forecast_offset > 0), then remove it + if not t_aux.physical[0]: + t_aux.physical = t_aux.physical[1:] n_tgt = len(t_aux.physical) if n_tgt == n_pred or n_tgt == 0: continue diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index e6bbebe8d..45ea832b7 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -63,6 +63,8 @@ def write_output( # TODO Maybe stopping at forecast_steps explained #1657 for t_idx in timestep_idxs: + if cf.training_config.forecast.offset == 1: + t_idx = t_idx - 1 preds_all += [[]] targets_all += [[]] targets_coords_all += [[]]