From c69dfeece2b8c6c477228b62487f078e8b88ec19 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Sat, 16 May 2026 10:47:17 +0200 Subject: [PATCH] bugfix non existing latent state when fo=1 --- .../train/loss_modules/loss_module_latent_diffusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index a89c9d6b8..0918c1732 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -94,8 +94,9 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: for _, _, loss_fct_name in self.loss_fcts } - pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if pl] + pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if ("latent_state" in pl)] target_tokens_all = [latent["diffusion_latent"] for latent in targets.latent if latent] + assert len(pred_tokens_all) == len(target_tokens_all), "Mismatch between predicted and target token lengths" eta = torch.tensor( [targets.aux_outputs["noise_level_rn"]], device=self.device, dtype=torch.float32