diff --git a/src/weathergen/train/loss_modules/loss_functions.py b/src/weathergen/train/loss_modules/loss_functions.py index 192101278..d5b4b34bf 100644 --- a/src/weathergen/train/loss_modules/loss_functions.py +++ b/src/weathergen/train/loss_modules/loss_functions.py @@ -114,7 +114,12 @@ def kernel_crps( # apply point weighting if weights_points is not None: - kcrps_locs_chs = kcrps_locs_chs * weights_points + if weights_points.dim() == 1: + # uniform location weight across channels + kcrps_locs_chs = kcrps_locs_chs * weights_points + else: + # per-channel location weight + kcrps_locs_chs = kcrps_locs_chs * weights_points.T.unsqueeze(0) # apply channel weighting kcrps_chs = torch.mean(torch.mean(kcrps_locs_chs, 0), -1) if weights_channels is not None: @@ -195,7 +200,12 @@ def lp_loss( torch.abs(torch.where(mask_nan, target, 0) - torch.where(mask_nan, pred, 0)), p_norm ) if weights_points is not None: - diff_p = (diff_p.transpose(1, 0) * weights_points).transpose(1, 0) + if weights_points.dim() == 1: + # uniform location weight across channels + diff_p = (diff_p.transpose(1, 0) * weights_points).transpose(1, 0) + else: + # per-channel location weight + diff_p = diff_p * weights_points loss_chs = diff_p.mean(0) if with_mean else diff_p.sum(0) loss_chs = torch.pow(loss_chs, 1.0 / p_norm) if with_p_root else loss_chs loss = torch.mean(loss_chs * weights_channels if weights_channels is not None else loss_chs) diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py index d913b70e4..8ef712bf4 100644 --- a/src/weathergen/train/loss_modules/loss_module_physical.py +++ b/src/weathergen/train/loss_modules/loss_module_physical.py @@ -57,6 +57,9 @@ def __init__( self.device = device self.name = "LossPhysical" + # cache for per-stream location-weight + self._location_weight_fractions: dict[str, torch.Tensor | None] = {} + # dynamically load loss functions based on configuration and stage self.loss_fcts = [ [ @@ -100,7 +103,7 @@ def _get_output_step_weights(self, len_forecast_steps): decay_factor = list(timestep_weight_config.values())[0]["decay_factor"] return weights_timestep_fct(len_forecast_steps, decay_factor) - def _get_location_weights(self, stream_info, target_coords): + def _get_location_weights(self, stream_info, target_coords, target_channels): location_weight_type = stream_info.get("location_weight", None) if location_weight_type is None: return None @@ -108,6 +111,25 @@ def _get_location_weights(self, stream_info, target_coords): weights_locations = weights_locations_fct(target_coords) weights_locations = weights_locations.to(device=self.device, non_blocking=True) + # Cache; reuse on every subsequent call. + stream_name = stream_info["name"] + if stream_name not in self._location_weight_fractions: + location_weight_fraction = stream_info.get("location_weight_fraction", None) + if location_weight_fraction is not None: + self._location_weight_fractions[stream_name] = torch.tensor( + [location_weight_fraction.get(ch, 1.0) for ch in target_channels], + device=self.device, + dtype=weights_locations.dtype, + ) + else: + self._location_weight_fractions[stream_name] = None + + fractions = self._location_weight_fractions[stream_name] + if fractions is not None: + weights_locations = 1.0 + fractions.unsqueeze(0) * ( + weights_locations.unsqueeze(1) - 1.0 + ) + return weights_locations def _get_substep_masks(self, stream_info, output_step, target_times): @@ -263,7 +285,7 @@ def compute_loss(self, preds: dict, targets: dict, metadata) -> LossValues: # get weights for locations weights_locations = self._get_location_weights( - stream_info, targets_coords_batch[target_idx] + stream_info, targets_coords_batch[target_idx], target_channels ) # loss_st_corr: loss for give source-target correspondence