From e6f72f24de2ed63e16836edfef479fe1056a9cca Mon Sep 17 00:00:00 2001
From: "julius.polz"
Date: Thu, 30 Apr 2026 10:33:51 +0200
Subject: [PATCH 1/2] enable per-channel location weight scaling

---
NOTE(review): line structure was restored from a whitespace-collapsed copy of
this series. Leading indentation of context lines is inferred from the Python
structure; the nesting depth in the compute_loss hunk is an assumption and
must be confirmed against the tree before applying.

 .../train/loss_modules/loss_functions.py       | 14 ++++++++++++--
 .../train/loss_modules/loss_module_physical.py | 16 ++++++++++++++--
 2 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/src/weathergen/train/loss_modules/loss_functions.py b/src/weathergen/train/loss_modules/loss_functions.py
index 192101278..d5b4b34bf 100644
--- a/src/weathergen/train/loss_modules/loss_functions.py
+++ b/src/weathergen/train/loss_modules/loss_functions.py
@@ -114,7 +114,12 @@ def kernel_crps(
 
     # apply point weighting
     if weights_points is not None:
-        kcrps_locs_chs = kcrps_locs_chs * weights_points
+        if weights_points.dim() == 1:
+            # uniform location weight across channels
+            kcrps_locs_chs = kcrps_locs_chs * weights_points
+        else:
+            # per-channel location weight
+            kcrps_locs_chs = kcrps_locs_chs * weights_points.T.unsqueeze(0)
     # apply channel weighting
     kcrps_chs = torch.mean(torch.mean(kcrps_locs_chs, 0), -1)
     if weights_channels is not None:
@@ -195,7 +200,12 @@ def lp_loss(
         torch.abs(torch.where(mask_nan, target, 0) - torch.where(mask_nan, pred, 0)), p_norm
     )
     if weights_points is not None:
-        diff_p = (diff_p.transpose(1, 0) * weights_points).transpose(1, 0)
+        if weights_points.dim() == 1:
+            # uniform location weight across channels
+            diff_p = (diff_p.transpose(1, 0) * weights_points).transpose(1, 0)
+        else:
+            # per-channel location weight
+            diff_p = diff_p * weights_points
     loss_chs = diff_p.mean(0) if with_mean else diff_p.sum(0)
     loss_chs = torch.pow(loss_chs, 1.0 / p_norm) if with_p_root else loss_chs
     loss = torch.mean(loss_chs * weights_channels if weights_channels is not None else loss_chs)
diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py
index d913b70e4..7aec379c2 100644
--- a/src/weathergen/train/loss_modules/loss_module_physical.py
+++ b/src/weathergen/train/loss_modules/loss_module_physical.py
@@ -100,7 +100,7 @@ def _get_output_step_weights(self, len_forecast_steps):
         decay_factor = list(timestep_weight_config.values())[0]["decay_factor"]
         return weights_timestep_fct(len_forecast_steps, decay_factor)
 
-    def _get_location_weights(self, stream_info, target_coords):
+    def _get_location_weights(self, stream_info, target_coords, target_channels):
         location_weight_type = stream_info.get("location_weight", None)
         if location_weight_type is None:
             return None
@@ -108,6 +108,18 @@ def _get_location_weights(self, stream_info, target_coords):
         weights_locations = weights_locations_fct(target_coords)
         weights_locations = weights_locations.to(device=self.device, non_blocking=True)
 
+        # Channels not listed default to 1.0 (full weighting).
+        location_weight_fraction = stream_info.get("location_weight_fraction", None)
+        if location_weight_fraction is not None:
+            fractions = torch.tensor(
+                [location_weight_fraction.get(ch, 1.0) for ch in target_channels],
+                device=self.device,
+                dtype=weights_locations.dtype,
+            )
+            weights_locations = 1.0 + fractions.unsqueeze(0) * (
+                weights_locations.unsqueeze(1) - 1.0
+            )
+
         return weights_locations
 
     def _get_substep_masks(self, stream_info, output_step, target_times):
@@ -263,7 +275,7 @@ def compute_loss(self, preds: dict, targets: dict, metadata) -> LossValues:
 
                     # get weights for locations
                     weights_locations = self._get_location_weights(
-                        stream_info, targets_coords_batch[target_idx]
+                        stream_info, targets_coords_batch[target_idx], target_channels
                     )
 
                     # loss_st_corr: loss for give source-target correspondence
From a89ff69c1341749839a4d84d84ee97f1f173082a Mon Sep 17 00:00:00 2001
From: "julius.polz"
Date: Thu, 30 Apr 2026 11:17:49 +0200
Subject: [PATCH 2/2] cache location weights

---
NOTE(review): the fraction cache is keyed by (stream name, channel tuple)
instead of the stream name alone, because the cached tensor is built from
target_channels; a stream called with a different channel list would otherwise
reuse a tensor of the wrong length. Commit and blob ids are carried over
unchanged from the submitted series and do not reflect this edit.

 .../loss_modules/loss_module_physical.py      | 26 +++++++++++++------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py
index 7aec379c2..8ef712bf4 100644
--- a/src/weathergen/train/loss_modules/loss_module_physical.py
+++ b/src/weathergen/train/loss_modules/loss_module_physical.py
@@ -57,6 +57,9 @@ def __init__(
         self.device = device
         self.name = "LossPhysical"
 
+        # cache of per-channel location-weight fractions, keyed by (stream name, channels)
+        self._location_weight_fractions: dict[tuple, torch.Tensor | None] = {}
+
         # dynamically load loss functions based on configuration and stage
         self.loss_fcts = [
             [
@@ -108,14 +111,21 @@ def _get_location_weights(self, stream_info, target_coords, target_channels):
         weights_locations = weights_locations_fct(target_coords)
         weights_locations = weights_locations.to(device=self.device, non_blocking=True)
 
-        # Channels not listed default to 1.0 (full weighting).
-        location_weight_fraction = stream_info.get("location_weight_fraction", None)
-        if location_weight_fraction is not None:
-            fractions = torch.tensor(
-                [location_weight_fraction.get(ch, 1.0) for ch in target_channels],
-                device=self.device,
-                dtype=weights_locations.dtype,
-            )
+        # Cache per (stream, channel list): the tensor below depends on target_channels.
+        cache_key = (stream_info["name"], tuple(target_channels))
+        if cache_key not in self._location_weight_fractions:
+            location_weight_fraction = stream_info.get("location_weight_fraction", None)
+            if location_weight_fraction is not None:
+                self._location_weight_fractions[cache_key] = torch.tensor(
+                    [location_weight_fraction.get(ch, 1.0) for ch in target_channels],
+                    device=self.device,
+                    dtype=weights_locations.dtype,
+                )
+            else:
+                self._location_weight_fractions[cache_key] = None
+
+        fractions = self._location_weight_fractions[cache_key]
+        if fractions is not None:
             weights_locations = 1.0 + fractions.unsqueeze(0) * (
                 weights_locations.unsqueeze(1) - 1.0
             )