diff --git a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py index a89c9d6b8..0918c1732 100644 --- a/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py +++ b/src/weathergen/train/loss_modules/loss_module_latent_diffusion.py @@ -94,8 +94,9 @@ def compute_loss(self, preds: dict, targets: dict, **kwargs) -> LossValues: for _, _, loss_fct_name in self.loss_fcts } - pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if pl] + pred_tokens_all = [pl["latent_state"].z_pre_norm for pl in preds.latent if ("latent_state" in pl)] target_tokens_all = [latent["diffusion_latent"] for latent in targets.latent if latent] + assert len(pred_tokens_all) == len(target_tokens_all), "Mismatch between predicted and target token lengths" eta = torch.tensor( [targets.aux_outputs["noise_level_rn"]], device=self.device, dtype=torch.float32