From 390a014f237f2f4404a9a3d4a9790a0ff116fa2b Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Sat, 16 May 2026 10:42:23 +0200 Subject: [PATCH 1/6] bug fix, non existing latent state when forecast offset=1 --- .../train/loss_modules/loss_module_latent_diffusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index a89c9d6b8..0918c1732 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -94,8 +94,9 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: for _, _, loss_fct_name in self.loss_fcts } - pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if pl] + pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if ("latent_state" in pl)] target_tokens_all = [latent["diffusion_latent"] for latent in targets.latent if latent] + assert len(pred_tokens_all) == len(target_tokens_all), "Mismatch between predicted and target token lengths" eta = torch.tensor( [targets.aux_outputs["noise_level_rn"]], device=self.device, dtype=torch.float32 From b65e5b0b9076cdfab158e3f526a6da40875065d3 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Sun, 17 May 2026 15:29:12 +0200 Subject: [PATCH 2/6] merged conf resolved --- .../evaluate/plotting/plot_orchestration.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index b28ad4adc..9b8ef47ae 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -508,6 +508,21 @@ def _plot_all_samples( plotter.clean_data_selection() +def _plot_distribution_analysis( + plotter_cfg: dict, + output_basedir: str, + stream: str, + fstep: int | str, + tars: xr.DataArray, + preds: xr.DataArray, + channels: list[str], +) -> None: + """Aggregate-over-samples distribution histogram for one fstep (loky worker).""" + matplotlib.use("Agg") + plotter = Plotter(plotter_cfg, Path(output_basedir), stream) + plotter.create_distribution_histograms(tars, preds, channels, fstep, stream) + + def plot_data( reader: Reader, stream: str, From 551474b245c3f9f73f50311cea593ad998c1afc8 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Sun, 17 May 2026 17:14:34 +0200 Subject: [PATCH 3/6] hotfix for infernce with forecastoffset>0 --- src/weathergen/train/trainer.py | 3 +++ src/weathergen/utils/validation_io.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 25f62309f..26ed6af59 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -72,6 +72,9 @@ def _expand_targets_to_match_preds(preds, targets_and_auxs: dict) -> None: """ n_pred = len(preds.physical) for t_aux in targets_and_auxs.values(): + # if the first entry is None (e.g. when forecast_offset > 0), then remove it + if not t_aux.physical[0]: + t_aux.physical = t_aux.physical[1:] n_tgt = len(t_aux.physical) if n_tgt == n_pred or n_tgt == 0: continue diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index e6bbebe8d..45ea832b7 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -63,6 +63,8 @@ def write_output( # TODO Maybe stopping at forecast_steps explained #1657 for t_idx in timestep_idxs: + if cf.training_config.forecast.offset == 1: + t_idx = t_idx - 1 preds_all += [[]] targets_all += [[]] targets_coords_all += [[]] From 09652ad976027e184b03d6f6726edca94c7f9b14 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Wed, 20 May 2026 09:33:18 +0200 Subject: [PATCH 4/6] removed geoinfos --- config/streams/era5_1deg_forecasting/era5.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/streams/era5_1deg_forecasting/era5.yml b/config/streams/era5_1deg_forecasting/era5.yml index ae8e2da53..fce350a59 100644 --- a/config/streams/era5_1deg_forecasting/era5.yml +++ b/config/streams/era5_1deg_forecasting/era5.yml @@ -13,7 +13,7 @@ ERA5 : stream_id : 0 source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] - geoinfo_channels : ['lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] + # geoinfo_channels : ['lsm', 'slor', 'sdor', 'insolation', 'cos_local_time', 'sin_local_time', 'cos_julian_day', 'sin_julian_day'] loss_weight : 1. location_weight : cosine_latitude masking_rate : 0.6 From 6efa0503920ae3cf16844aea2633546e4bd6c683 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Wed, 20 May 2026 14:28:58 +0200 Subject: [PATCH 5/6] cleanup plotting --- .../evaluate/plotting/plot_orchestration.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index 9b8ef47ae..8523a049a 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -507,22 +507,6 @@ def _plot_all_samples( plotter.clean_data_selection() - -def _plot_distribution_analysis( - plotter_cfg: dict, - output_basedir: str, - stream: str, - fstep: int | str, - tars: xr.DataArray, - preds: xr.DataArray, - channels: list[str], -) -> None: - """Aggregate-over-samples distribution histogram for one fstep (loky worker).""" - matplotlib.use("Agg") - plotter = Plotter(plotter_cfg, Path(output_basedir), stream) - plotter.create_distribution_histograms(tars, preds, channels, fstep, stream) - - def plot_data( reader: Reader, stream: str, From 5f3b976416c57502fad065138a121c386e449262 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Wed, 20 May 2026 14:30:53 +0200 Subject: [PATCH 6/6] cleanup plotting again --- .../src/weathergen/evaluate/plotting/plot_orchestration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index 8523a049a..b28ad4adc 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -507,6 +507,7 @@ def _plot_all_samples( plotter.clean_data_selection() + def plot_data( reader: Reader, stream: str,