Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/weathergen/train/loss_modules/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,12 @@ def kernel_crps(

# apply point weighting
if weights_points is not None:
kcrps_locs_chs = kcrps_locs_chs * weights_points
if weights_points.dim() == 1:
# uniform location weight across channels
kcrps_locs_chs = kcrps_locs_chs * weights_points
else:
# per-channel location weight
kcrps_locs_chs = kcrps_locs_chs * weights_points.T.unsqueeze(0)
# apply channel weighting
kcrps_chs = torch.mean(torch.mean(kcrps_locs_chs, 0), -1)
if weights_channels is not None:
Expand Down Expand Up @@ -195,7 +200,12 @@ def lp_loss(
torch.abs(torch.where(mask_nan, target, 0) - torch.where(mask_nan, pred, 0)), p_norm
)
if weights_points is not None:
diff_p = (diff_p.transpose(1, 0) * weights_points).transpose(1, 0)
if weights_points.dim() == 1:
# uniform location weight across channels
diff_p = (diff_p.transpose(1, 0) * weights_points).transpose(1, 0)
else:
# per-channel location weight
diff_p = diff_p * weights_points
loss_chs = diff_p.mean(0) if with_mean else diff_p.sum(0)
loss_chs = torch.pow(loss_chs, 1.0 / p_norm) if with_p_root else loss_chs
loss = torch.mean(loss_chs * weights_channels if weights_channels is not None else loss_chs)
Expand Down
26 changes: 24 additions & 2 deletions src/weathergen/train/loss_modules/loss_module_physical.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def __init__(
self.device = device
self.name = "LossPhysical"

# cache for per-stream location-weight
self._location_weight_fractions: dict[str, torch.Tensor | None] = {}

# dynamically load loss functions based on configuration and stage
self.loss_fcts = [
[
Expand Down Expand Up @@ -100,14 +103,33 @@ def _get_output_step_weights(self, len_forecast_steps):
decay_factor = list(timestep_weight_config.values())[0]["decay_factor"]
return weights_timestep_fct(len_forecast_steps, decay_factor)

def _get_location_weights(self, stream_info, target_coords):
def _get_location_weights(self, stream_info, target_coords, target_channels):
location_weight_type = stream_info.get("location_weight", None)
if location_weight_type is None:
return None
weights_locations_fct = getattr(loss_fns, location_weight_type)
weights_locations = weights_locations_fct(target_coords)
weights_locations = weights_locations.to(device=self.device, non_blocking=True)

# Cache; reuse on every subsequent call.
stream_name = stream_info["name"]
if stream_name not in self._location_weight_fractions:
location_weight_fraction = stream_info.get("location_weight_fraction", None)
if location_weight_fraction is not None:
self._location_weight_fractions[stream_name] = torch.tensor(
[location_weight_fraction.get(ch, 1.0) for ch in target_channels],
device=self.device,
dtype=weights_locations.dtype,
)
else:
self._location_weight_fractions[stream_name] = None

fractions = self._location_weight_fractions[stream_name]
if fractions is not None:
weights_locations = 1.0 + fractions.unsqueeze(0) * (
weights_locations.unsqueeze(1) - 1.0
)

return weights_locations

def _get_substep_masks(self, stream_info, output_step, target_times):
Expand Down Expand Up @@ -263,7 +285,7 @@ def compute_loss(self, preds: dict, targets: dict, metadata) -> LossValues:

# get weights for locations
weights_locations = self._get_location_weights(
stream_info, targets_coords_batch[target_idx]
stream_info, targets_coords_batch[target_idx], target_channels
)

# loss_st_corr: loss for give source-target correspondence
Expand Down
Loading