From cc4df7375842a9f2b6dc7deacc63f4109ae43da0 Mon Sep 17 00:00:00 2001 From: Matteo Broccoli Date: Fri, 22 May 2026 18:05:56 +0200 Subject: [PATCH 1/2] added arctic region for plotting in stereographic projection and modified projection logic --- .../src/weathergen/evaluate/plotting/plotter.py | 6 +++++- .../src/weathergen/evaluate/utils/regions.py | 17 ++++++++++------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 97e0840a6..adb6d4f7e 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -888,7 +888,11 @@ def scatter_plot( if figsize is None and data.size >= 200_000: figsize = (15, 7) - proj = ccrs.Robinson() if regionname == "global" else ccrs.PlateCarree() + if regionname and regionname.lower() in RegionLibrary.REGIONS: + proj = RegionBoundingBox.from_region_name(regionname).projection + else: + # Fallback if regionname is None or not in library + proj = ccrs.PlateCarree() fig = plt.figure(figsize=figsize, dpi=self.dpi_val) ax = fig.add_subplot(1, 1, 1, projection=proj) try: diff --git a/packages/evaluate/src/weathergen/evaluate/utils/regions.py b/packages/evaluate/src/weathergen/evaluate/utils/regions.py index b2893c314..401402f3c 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/regions.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/regions.py @@ -11,6 +11,7 @@ from dataclasses import dataclass from typing import ClassVar +import cartopy.crs as ccrs import xarray as xr _logger = logging.getLogger(__name__) @@ -22,13 +23,14 @@ class RegionLibrary: Predefined bounding boxes for known regions. """ - REGIONS: ClassVar[dict[str, tuple[float, float, float, float]]] = { - "global": (-90.0, 90.0, -180.0, 180.0), - "nhem": (0.0, 90.0, -180.0, 180.0), - "shem": (-90.0, 0.0, -180.0, 180.0), - "tropics": (-30.0, 30.0, -180.0, 180.0), - "belgium": (49, 52, 2, 7), - "europe": (35, 70, -10, 40), + REGIONS: ClassVar[dict[str, tuple[float, float, float, float, ccrs.Projection]]] = { + "global": (-90.0, 90.0, -180.0, 180.0, ccrs.Robinson()), + "nhem": (0.0, 90.0, -180.0, 180.0, ccrs.PlateCarree()), + "shem": (-90.0, 0.0, -180.0, 180.0, ccrs.PlateCarree()), + "tropics": (-30.0, 30.0, -180.0, 180.0, ccrs.PlateCarree()), + "belgium": (49, 52, 2, 7, ccrs.PlateCarree()), + "europe": (35, 70, -10, 40, ccrs.PlateCarree()), + "arctic": (50.0, 90.0, -180.0, 180.0, ccrs.Stereographic(central_longitude=0, central_latitude=90)) } @@ -38,6 +40,7 @@ class RegionBoundingBox: lat_max: float lon_min: float lon_max: float + projection: ccrs.Projection def __post_init__(self): """Validate the bounding box coordinates.""" From 6251f721a69d2d91c6af3b8ad7d02831be76f2a9 Mon Sep 17 00:00:00 2001 From: Matteo Broccoli Date: Fri, 22 May 2026 18:51:27 +0200 Subject: [PATCH 2/2] new projection logic in plotter --- config/cglors_era5_finetuning_eval.yml | 40 +++++++++++++++++++ .../weathergen/evaluate/plotting/plotter.py | 15 ++++--- .../src/weathergen/evaluate/utils/regions.py | 8 +++- src/weathergen/model/attention.py | 12 ++---- src/weathergen/model/positional_encoding.py | 2 +- 5 files changed, 61 insertions(+), 16 deletions(-) create mode 100644 config/cglors_era5_finetuning_eval.yml diff --git a/config/cglors_era5_finetuning_eval.yml b/config/cglors_era5_finetuning_eval.yml new file mode 100644 index 000000000..832afdbc7 --- /dev/null +++ b/config/cglors_era5_finetuning_eval.yml @@ -0,0 +1,40 @@ +global_plotting_options: + image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. + regions: ["global", "arctic"] + dpi_val : 300 + C-GLORS: + marker_size: 4 + +evaluation: + metrics : ["froct", "rmse"] + regions: ["global"] + summary_plots : true + summary_dir: "./plots/" + print_summary: false #print out score values on screen. it can be verbose + plot_score_maps: true + log_scale: false + add_grid: true + +run_ids : + y0mltkag: + label: "atmofs03-finetune-ycnfjoph" + epoch: 0 + rank: 0 + streams: + C-GLORS: + channels: [ + "icethic", + "iicevelu", + "iicevelv", + "isnowthi", + "soicecov", + "sosstsst"] + evaluation: + sample: "all" + forecast_step: "all" + plotting: + sample: [0] + forecast_step: "all" + plot_maps: true + plot_animations: true + plot_histograms: true diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index adb6d4f7e..63488cef1 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -888,11 +888,16 @@ def scatter_plot( if figsize is None and data.size >= 200_000: figsize = (15, 7) - if regionname and regionname.lower() in RegionLibrary.REGIONS: - proj = RegionBoundingBox.from_region_name(regionname).projection - else: - # Fallback if regionname is None or not in library - proj = ccrs.PlateCarree() + proj = ccrs.PlateCarree() + if regionname: + try: + # This uses the method already available in RegionBoundingBox + bbox = RegionBoundingBox.from_region_name(regionname) + proj = bbox.projection + except ValueError: + # If regionname isn't in the library, fall back to PlateCarree + _logger.warning(f"Region '{regionname}' not found in library, using PlateCarree.") + proj = ccrs.PlateCarree() fig = plt.figure(figsize=figsize, dpi=self.dpi_val) ax = fig.add_subplot(1, 1, 1, projection=proj) try: diff --git a/packages/evaluate/src/weathergen/evaluate/utils/regions.py b/packages/evaluate/src/weathergen/evaluate/utils/regions.py index 401402f3c..75f46a31a 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/regions.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/regions.py @@ -30,7 +30,13 @@ class RegionLibrary: "tropics": (-30.0, 30.0, -180.0, 180.0, ccrs.PlateCarree()), "belgium": (49, 52, 2, 7, ccrs.PlateCarree()), "europe": (35, 70, -10, 40, ccrs.PlateCarree()), - "arctic": (50.0, 90.0, -180.0, 180.0, ccrs.Stereographic(central_longitude=0, central_latitude=90)) + "arctic": ( + 50.0, + 90.0, + -180.0, + 180.0, + ccrs.Stereographic(central_longitude=0, central_latitude=90), + ), } diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index b18791aa5..fb4250d98 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -102,9 +102,7 @@ def forward(self, x, x_lens, ada_ln_aux=None, coords=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype) vs = self.proj_heads_v(x).reshape(s) - qs, ks = apply_rope( - qs, ks, coords, self.rope_mode, 1 - ) + qs, ks = apply_rope(qs, ks, coords, self.rope_mode, 1) if self.rope_post_mod_qk_lnorm: qs = self.post_rope_lnorm_q(qs).to(self.dtype) ks = self.post_rope_lnorm_k(ks).to(self.dtype) @@ -302,9 +300,7 @@ def forward(self, x, coords=None, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3]) vs = self.proj_heads_v(x).reshape(s).permute([0, 2, 1, 3]) - qs, ks = apply_rope( - qs, ks, coords, self.rope_mode, 1 - ) + qs, ks = apply_rope(qs, ks, coords, self.rope_mode, 1) if self.rope_post_mod_qk_lnorm: qs = self.post_rope_lnorm_q(qs).to(self.dtype) ks = self.post_rope_lnorm_k(ks).to(self.dtype) @@ -621,9 +617,7 @@ def forward(self, x, coords=None, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype) vs = self.proj_heads_v(x).reshape(s).to(self.dtype) - qs, ks = apply_rope( - qs, ks, coords, self.rope_mode, 2 - ) + qs, ks = apply_rope(qs, ks, coords, self.rope_mode, 2) if self.rope_post_mod_qk_lnorm: qs = self.post_rope_lnorm_q(qs).to(self.dtype) ks = self.post_rope_lnorm_k(ks).to(self.dtype) diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index ad1c54bee..6aa364e4c 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -181,6 +181,7 @@ def rotary_pos_emb_2d(q, k, coords, base=10000.0, unsqueeze_dim=1): cos, sin = rotary_embedding_2d(coords, q.shape[-1], base=base) return apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim) + # Spherical RoPE def _max_supported_spherical_band(dim_embed: int, num_heads: int) -> int: head_dim = dim_embed // num_heads @@ -324,7 +325,6 @@ def build_spherical_rope_coeff_tensors( ) - @lru_cache(maxsize=32) def _healpy_band_maps( nside: int, band: int