From 4c3f9452a2821ebdc92694dd8fea597ceaa1902f Mon Sep 17 00:00:00 2001 From: iluise Date: Wed, 20 May 2026 13:52:05 +0200 Subject: [PATCH 1/8] first working implementation of spectra --- packages/evaluate/pyproject.toml | 7 + .../evaluate/plotting/line_plots.py | 72 +++ .../evaluate/plotting/plot_orchestration.py | 27 +- .../evaluate/plotting/plot_utils.py | 110 ++++ .../weathergen/evaluate/plotting/plotter.py | 263 ++++++++ .../src/weathergen/evaluate/scores/psd.py | 566 ++++++++++++++++++ .../src/weathergen/evaluate/scores/score.py | 166 +++++ .../evaluate/scores/score_orchestration.py | 18 +- 8 files changed, 1214 insertions(+), 15 deletions(-) create mode 100644 packages/evaluate/src/weathergen/evaluate/scores/psd.py diff --git a/packages/evaluate/pyproject.toml b/packages/evaluate/pyproject.toml index a2d37fcae..e6329a8a6 100644 --- a/packages/evaluate/pyproject.toml +++ b/packages/evaluate/pyproject.toml @@ -21,8 +21,15 @@ dependencies = [ "earthkit-data==0.17.0", "earthkit-utils==0.1.2", "imageio[ffmpeg]>=2.37.2", + "scipy>=1.12", ] +[project.optional-dependencies] +zonal = [ + "scitools-iris>=3.11", + "cf-units", +] + [dependency-groups] dev = [ "pytest>=8.3", diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py b/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py index 7363c495e..e7307188c 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py @@ -15,6 +15,7 @@ from pathlib import Path import matplotlib.pyplot as plt +import numpy as np import seaborn as sns import xarray as xr @@ -667,3 +668,74 @@ def heat_map( parts = ["heat_map", metric, tag] name = "_".join(filter(None, parts)) plt.savefig(f"{self.out_plot_dir.joinpath(name)}.{self.image_format}") + + # ------------------------------------------------------------------ + # PSD summary plot + # ------------------------------------------------------------------ + + def psd_plot( + self, + psd_datasets: list[dict], + labels: list[str], + tag: str = "", + ) -> None: + """Create a PSD summary plot overlaying multiple runs. + + Each entry in *psd_datasets* is a dict with keys + ``frequencies``, ``psd_target``, ``psd_prediction``, ``psd_method``. + + Parameters + ---------- + psd_datasets : list[dict] + One dict per run, each containing the PSD arrays stored by + ``Scores.calc_psd`` in ``.attrs``. + labels : list[str] + Human-readable label for each run. + tag : str + Filename tag. + """ + out_dir = Path(self.out_plot_dir_lines) / "psd" + out_dir.mkdir(parents=True, exist_ok=True) + + # Use the target from the first run as reference + freq = np.asarray(psd_datasets[0]["frequencies"]) + tar_psd = np.asarray(psd_datasets[0]["psd_target"]) + + fig, (ax_spec, ax_ratio) = plt.subplots( + 2, 1, figsize=self.fig_size or (10, 8), + gridspec_kw={"height_ratios": [2, 1], "hspace": 0.08}, + ) + + # Upper panel: log-log spectra + ax_spec.loglog(freq, tar_psd, color="black", lw=1.5, label="Target") + colors = plt.cm.tab10.colors + for i, (ds, label) in enumerate(zip(psd_datasets, labels, strict=False)): + c = colors[i % len(colors)] + ax_spec.loglog( + np.asarray(ds["frequencies"]), + np.asarray(ds["psd_prediction"]), + color=c, lw=1.5, label=label, + ) + ax_spec.set_ylabel("Power") + ax_spec.set_title("PSD summary") + ax_spec.legend(frameon=False, fontsize=7) + ax_spec.grid(True, which="both", ls="--", alpha=0.4) + + # Lower panel: ratio (pred / target) + for i, (ds, label) in enumerate(zip(psd_datasets, labels, strict=False)): + c = colors[i % len(colors)] + pred = np.asarray(ds["psd_prediction"]) + with np.errstate(divide="ignore", invalid="ignore"): + ratio = np.where(tar_psd > 0, pred / tar_psd, np.nan) + ax_ratio.semilogx(freq, ratio, color=c, lw=1.2, label=label) + ax_ratio.axhline(1.0, ls="--", color="gray", lw=0.8) + ax_ratio.set_ylabel("Pred / Target") + ax_ratio.set_xlabel("Frequency (1/deg)") + ax_ratio.set_ylim(0, 2) + ax_ratio.grid(True, which="both", ls="--", alpha=0.4) + + name = tag or "psd" + fname = out_dir / f"{name}.{self.image_format}" + _logger.info(f"Saving PSD summary plot to {fname}") + fig.savefig(str(fname), bbox_inches="tight", dpi=self.dpi_val) + plt.close(fig) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index 0df1c59e5..1d45930c3 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -29,6 +29,7 @@ bar_plot_metric_region, heat_maps_metric_region, plot_metric_region, + psd_plot_metric_region, quantile_plot_metric_region, ratio_plot_metric_region, score_card_metric_region, @@ -479,10 +480,11 @@ def _plot_all_samples( ) -> None: """Plot histograms across all samples for a single fstep. - Unlike per-sample histograms, these aggregate all samples together. + Unlike per-sample plots, these aggregate all samples together. The output filename uses 'global' instead of a sample id and omits the timestep. """ - if not (plot_histograms is True or plot_histograms == "across-samples"): + has_work = (plot_histograms is True or plot_histograms == "across-samples") + if not has_work: return matplotlib.use("Agg") @@ -496,14 +498,15 @@ def _plot_all_samples( preds_tag = "" if "ens" not in preds.dims else f"ens_{ens}" preds_name = "_".join(filter(None, ["preds", preds_tag])) - plotter.create_histograms( - tars, - preds_ens, - plot_chs, - data_selection, - preds_name, - ranges=maps_config, - ) + if plot_histograms is True or plot_histograms == "across-samples": + plotter.create_histograms( + tars, + preds_ens, + plot_chs, + data_selection, + preds_name, + ranges=maps_config, + ) plotter.clean_data_selection() @@ -788,7 +791,7 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): quantile_plotter = QuantilePlots(plot_cfg, summary_dir) for region in regions: for metric in metrics: - if eval_opt.get("summary_plots", False): + if eval_opt.get("summary_plots", False) and metric != "psd": plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary) if eval_opt.get("ratio_plots", False): ratio_plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary) @@ -800,3 +803,5 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): bar_plot_metric_region(metric, region, runs, scores_dict, br_plotter) if metric == "qq_analysis": quantile_plot_metric_region(metric, region, runs, scores_dict, quantile_plotter) + if metric == "psd": + psd_plot_metric_region(metric, region, runs, scores_dict, plotter) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 5fd501dd5..4851f0484 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -768,6 +768,116 @@ def quantile_plot_metric_region( ) +def psd_plot_metric_region( + metric: str, + region: str, + runs: dict, + scores_dict: dict, + plotter: object, +) -> None: + """Create PSD plots for all streams and channels for a given metric and region. + + Follows the same pattern as ``quantile_plot_metric_region``: the PSD + curves (frequencies, target PSD, prediction PSD) are stored in + ``score.attrs`` by ``Scores.calc_psd`` and read back here. + + Parameters + ---------- + metric : str + Metric name (should be ``"psd"``). + region : str + Region name. + runs : dict + Run config dict (run_id → config). + scores_dict : dict + Nested score dict ``{metric: {region: {stream: {run_id: DataArray}}}}``. + plotter : object + Plotter that has a ``psd_summary_plot`` method (or a generic line plotter). + """ + streams_set = collect_streams(runs) + channels_set = collect_channels(scores_dict, metric, region, runs) + + for stream in streams_set: + for ch in channels_set: + for run_id, data in scores_dict[metric][region].get(stream, {}).items(): + if ch not in np.atleast_1d(data.channel.values): + continue + + data_ch = data.sel(channel=ch) if "channel" in data.dims else data + + if data_ch.isnull().all(): + continue + + # Get list of fsteps that have PSD attrs + attr_fsteps = data_ch.attrs.get("attr_fsteps", []) + if not attr_fsteps: + _logger.warning( + f"PSD attrs missing for {run_id}/{stream}/{ch}. Skipping." + ) + continue + + label = runs[run_id].get("label", run_id) + + for fstep in attr_fsteps: + # Look up per-fstep, per-channel attrs + fstep_prefix = f"fstep_{fstep}/" + ch_prefix = f"{ch}/" + + # Try fstep+channel prefix first, then fstep-only + freq_key = None + for candidate in [ + f"{fstep_prefix}{ch_prefix}frequencies", + f"{fstep_prefix}frequencies", + ]: + if candidate in data_ch.attrs: + freq_key = candidate + break + + if freq_key is None: + continue + + # Derive the other keys from the same prefix pattern + key_prefix = freq_key.replace("frequencies", "") + psd_t_key = f"{key_prefix}psd_target" + psd_p_key = f"{key_prefix}psd_prediction" + + if psd_t_key not in data_ch.attrs: + continue + + psd_datasets = [ + { + "frequencies": np.array(data_ch.attrs[freq_key]), + "psd_target": np.array(data_ch.attrs[psd_t_key]), + "psd_prediction": np.array(data_ch.attrs[psd_p_key]), + "psd_method": data_ch.attrs.get( + f"{fstep_prefix}psd_method", + data_ch.attrs.get("psd_method", "sht"), + ), + } + ] + + # Resolve lead time for title + lead_time_str = f"fstep {fstep}" + if "lead_time" in data_ch.coords and "forecast_step" in data_ch.dims: + try: + lt = int(data_ch.coords["lead_time"].sel(forecast_step=fstep).values) + if lt > 0: + lead_time_str = f"{lt}h" + except Exception: + pass + + _logger.info( + f"Creating PSD plot for {metric} - {region} - {stream} - " + f"{ch} - {lead_time_str}." + ) + name = create_filename( + prefix=[metric, region], + middle=[run_id], + suffix=[stream, ch, f"fstep{fstep}"], + ) + plotter.psd_plot(psd_datasets, [label], tag=name) + + def create_filename( *, prefix: Sequence[str] = (), diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 97e0840a6..37b3a67d8 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -478,6 +478,269 @@ def plot_histogram( return name + # ------------------------------------------------------------------ + # PSD plots + # ------------------------------------------------------------------ + + # -- PSD plot annotation helpers (from psd_plots.py) ---------------- + + @staticmethod + def _psd_add_wavenumbers(ax: plt.Axes) -> None: + """Add vertical lines at selected total wavenumbers.""" + yscale = ax.yaxis.get_scale() + ylims = ax.get_ylim() + if yscale == "log": + ytxt = 10.0 ** (0.85 * (np.log10(ylims[1] / ylims[0])) + np.log10(ylims[0])) + else: + ytxt = 0.85 * (ylims[1] - ylims[0]) + ylims[0] + + for wvn in (1, 2, 4, 8, 16, 24, 48, 96, 144, 216, 320, 640, 1280, 2560): + ax.plot( + [wvn / 360.0, wvn / 360.0], ylims, + color="black", lw=0.6, scalex=False, scaley=False, + ) + ax.text(wvn / 360.0, ytxt, f"n{wvn}", rotation="vertical", fontsize=6) + + @staticmethod + def _psd_add_lengths(ax: plt.Axes, lat_center: float = 0.0) -> None: + """Add vertical dashed lines at selected physical length scales (km).""" + re = 6.37e6 # earth radius in metres + yscale = ax.yaxis.get_scale() + ylims = ax.get_ylim() + if yscale == "log": + ytxt = 10.0 ** (0.05 * (np.log10(ylims[1] / ylims[0])) + np.log10(ylims[0])) + else: + ytxt = 0.05 * (ylims[1] - ylims[0]) + ylims[0] + + lengths_km = np.array([1.0e4, 3.0e3, 1.0e3, 3.0e2, 1.0e2, 3.0e1, 1.0e1]) + f_lengths = ( + 2.0 * np.pi * re * np.cos(np.radians(lat_center)) / (1000.0 * lengths_km * 360.0) + ) + for fl, lkm in zip(f_lengths, lengths_km): + ax.plot( + [fl, fl], ylims, + color="black", ls="--", lw=0.6, scalex=False, scaley=False, + ) + ax.text(fl, ytxt, f"{lkm:.0f}km", rotation="vertical", fontsize=6) + + @staticmethod + def _psd_add_ideal_slope( + ax: plt.Axes, slope: float = -3.0, + x_range: tuple[float, float] = (0.01, 0.1), y0: float = 10.0, + ) -> None: + """Add an idealised slope line on a log-log axes.""" + xs = np.array(x_range) + ys = y0 * np.array([1.0, (xs[1] / xs[0]) ** slope]) + ax.plot(xs, ys, color="black", lw=2.0, scalex=False, scaley=False) + xt = np.sqrt(np.prod(xs)) + yt = np.sqrt(np.prod(ys)) + ax.text(xt, yt, f"$k^{{{slope:.0f}}}$", fontsize=10, weight="bold") + + def plot_psd( + self, + freq: np.ndarray, + tar_psd: np.ndarray, + pred_psd: np.ndarray, + psd_output_dir: Path, + varname: str, + tag: str = "", + region: str = "", + ) -> str: + """Plot power spectral density: log-log spectra + ratio. + + Parameters + ---------- + freq : np.ndarray + Positive frequencies (or wavenumbers). + tar_psd : np.ndarray + Target PSD values. + pred_psd : np.ndarray + Prediction PSD values. + psd_output_dir : Path + Output directory for PSD plots. + varname : str + Variable / channel name. + tag, region : str + Filename parts (same convention as histograms / maps). + + Returns + ------- + str + Plot name (without extension). + """ + fig, (ax_spec, ax_ratio) = plt.subplots( + 2, 1, figsize=self.fig_size or (8, 8), + gridspec_kw={"height_ratios": [2, 1], "hspace": 0.08}, + ) + + # Upper panel: log-log spectra + ax_spec.loglog(freq, tar_psd, color="black", lw=1.5, label="Target") + ax_spec.loglog(freq, pred_psd, color="#00897B", lw=1.5, label="Prediction") + ax_spec.set_ylabel("Power") + ax_spec.set_title(f"PSD: {self.stream}, {varname}") + ax_spec.legend(frameon=False) + ax_spec.grid(True, which="both", ls="--", alpha=0.4) + ax_spec.set_xlim(1.0e-3, 1.0e1) + self._psd_add_wavenumbers(ax_spec) + self._psd_add_ideal_slope(ax_spec) + + # Lower panel: ratio + with np.errstate(divide="ignore", invalid="ignore"): + ratio = np.where(tar_psd > 0, pred_psd / tar_psd, np.nan) + ax_ratio.semilogx(freq, ratio, color="#00897B", lw=1.2) + ax_ratio.axhline(1.0, ls="--", color="gray", lw=0.8) + ax_ratio.set_ylabel("Pred / Target") + ax_ratio.set_xlabel("Frequency (1/deg)") + ax_ratio.set_ylim(0, 2) + ax_ratio.set_xlim(1.0e-3, 1.0e1) + ax_ratio.grid(True, which="both", ls="--", alpha=0.4) + self._psd_add_wavenumbers(ax_ratio) + + # Build filename (same pattern as maps / histograms) + is_global = str(self.sample) == "all_samples" + valid_time = None # PSD is always aggregated across time + parts = [ + "psd", + str(self.run_id), + str(tag) if tag else "", + str(self.sample), + valid_time, + str(self.stream), + region if region else "", + varname, + f"{self.fstep:03d}", + ] + name = "_".join(filter(None, parts)) + + fname = psd_output_dir / f"{name}.{self.image_format}" + _logger.debug(f"Saving PSD plot to {fname}") + fig.savefig(fname, bbox_inches="tight") + plt.close(fig) + + return name + + def create_psd_plots( + self, + target: xr.DataArray, + preds: xr.DataArray, + variables: list, + select: dict, + tag: str = "", + psd_method: str = "sht", + psd_regions: list[str] | None = None, + psd_regrid_resolution: float = 1.0, + psd_sht_truncation: int | None = None, + ) -> list[str]: + """Compute and plot PSD for target and prediction. + + Parameters + ---------- + target, preds : xr.DataArray + Target / prediction arrays. + variables : list + List of channel names to process. + select : dict + Data selection dict (sample, stream, forecast_step). + tag : str + Filename tag (e.g., ``"preds_ens_0"``). + psd_method : str + ``"sht"`` or ``"zonal"``. + psd_regions : list[str] | None + Region names. ``None`` → use ``self.regions``. + psd_regrid_resolution : float + Grid spacing for the zonal method. + psd_sht_truncation : int | None + Spectral truncation for SHT method. + + Returns + ------- + list[str] + Names of saved PSD plots. + """ + from weathergen.evaluate.scores.psd import compute_psd_for_field + + self.update_data_selection(select) + psd_output_dir = self.get_psd_output_dir() + os.makedirs(psd_output_dir, exist_ok=True) + + regions = psd_regions or self.regions + plot_names: list[str] = [] + + for region in regions: + if region != "global": + bbox = RegionBoundingBox.from_region_name(region) + reg_target = bbox.apply_mask(target) + reg_preds = bbox.apply_mask(preds) + else: + reg_target = target + reg_preds = preds + + for var in variables: + select_var = self.select | {"channel": var} + targ = self.select_from_da(reg_target, select_var).dropna(dim="ipoint") + prd = self.select_from_da(reg_preds, select_var).dropna(dim="ipoint") + + if targ.size == 0 or prd.size == 0: + _logger.warning(f"PSD: empty data for {var} in {region}. Skipping.") + continue + + targ_np = targ.values + prd_np = prd.values + + # Determine nlat from the grid if available + nlat = None + if "lat" in targ.coords: + nlat = len(np.unique(targ.coords["lat"].values)) + elif hasattr(targ, "attrs") and "nlat" in targ.attrs: + nlat = int(targ.attrs["nlat"]) + + lats = targ.coords["lat"].values if "lat" in targ.coords else None + lons = targ.coords["lon"].values if "lon" in targ.coords else None + + try: + freq_tar, psd_tar = compute_psd_for_field( + data=targ_np, + method=psd_method, + nlat=nlat, + lats=lats, + lons=lons, + lat_range=(-60.0, 60.0), + regrid_resolution=psd_regrid_resolution, + sht_truncation=psd_sht_truncation, + ) + freq_prd, psd_prd = compute_psd_for_field( + data=prd_np, + method=psd_method, + nlat=nlat, + lats=lats, + lons=lons, + lat_range=(-60.0, 60.0), + regrid_resolution=psd_regrid_resolution, + sht_truncation=psd_sht_truncation, + ) + except Exception: + _logger.exception(f"PSD computation failed for {var} in {region}") + continue + + name = self.plot_psd( + freq_tar, psd_tar, psd_prd, psd_output_dir, var, + tag=tag, region=region, + ) + plot_names.append(name) + + self.clean_data_selection() + return plot_names + + def get_psd_output_dir(self) -> Path: + """Return the output directory path for PSD plots. + + Returns + ------- + Path + Resolved directory path: ``//psd``. + """ + return self.out_plot_basedir / self.stream / "psd" + def create_maps_per_sample( self, data: xr.DataArray, diff --git a/packages/evaluate/src/weathergen/evaluate/scores/psd.py b/packages/evaluate/src/weathergen/evaluate/scores/psd.py new file mode 100644 index 000000000..d484c3246 --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/scores/psd.py @@ -0,0 +1,566 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +"""Power Spectral Density (PSD) computation. + +Provides two PSD computation paths: + +- **Path A – SHT-based PSD** (``method="sht"``): + Spherical Harmonic Transform on unstructured grids (octahedral, reduced + Gaussian, regular lat-lon). Ported from ``spectral_transforms.py`` to pure + numpy using Legendre helpers from ``spectral_helpers.py``. + +- **Path B – Zonal FFT PSD** (``method="zonal"``): + 1-D zonal FFT along the longitude dimension on a regular lat-lon grid. + Absorbs the functions previously in ``example_extras/power_spectra/psd_calc.py``. +""" + +from __future__ import annotations + +import logging +import numpy as np + +_logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Numpy-based Spherical Harmonic Transform (ported from spectral_helpers.py) +# --------------------------------------------------------------------------- + + +def _legendre_gauss_weights(n: int, a: float = -1.0, b: float = 1.0) -> tuple[np.ndarray, np.ndarray]: + """Return Legendre-Gauss nodes and weights on ``[a, b]``.""" + xlg, wlg = np.polynomial.legendre.leggauss(n) + xlg = (b - a) * 0.5 * xlg + (b + a) * 0.5 + wlg = wlg * (b - a) * 0.5 + return xlg, wlg + + +def _legpoly(mmax: int, lmax: int, x: np.ndarray, inverse: bool = False) -> np.ndarray: + """Compute associated Legendre polynomials. + + Returns shape ``(mmax+1, lmax+1, len(x))``. + """ + nmax = max(mmax, lmax) + vdm = np.zeros((nmax + 1, nmax + 1, len(x)), dtype=np.float64) + + norm_factor = np.sqrt(4 * np.pi) + norm_factor = 1.0 / norm_factor if inverse else norm_factor + vdm[0, 0, :] = norm_factor / np.sqrt(4 * np.pi) + + for n in range(1, nmax + 1): + vdm[n - 1, n, :] = np.sqrt(2 * n + 1) * x * vdm[n - 1, n - 1, :] + vdm[n, n, :] = np.sqrt((2 * n + 1) * (1 + x) * (1 - x) / 2 / n) * vdm[n - 1, n - 1, :] + + for n in range(2, nmax + 1): + for m in range(0, n - 1): + vdm[m, n, :] = ( + x * np.sqrt((2 * n - 1) / (n - m) * (2 * n + 1) / (n + m)) * vdm[m, n - 1, :] + - np.sqrt((n + m - 1) / (n - m) * (2 * n + 1) / (2 * n - 3) * (n - m - 1) / (n + m)) + * vdm[m, n - 2, :] + ) + + return vdm[: mmax + 1, : lmax + 1] + + +class SphericalHarmonicTransform: + """Spherical Harmonic Transform in pure numpy. + + Mirrors the ``SphericalHarmonicTransform`` from ``spectral_helpers.py`` in anemoi.models + but operates on numpy arrays rather than torch tensors. + + Parameters + ---------- + lons_per_lat : list[int] + Number of longitude points on each latitude ring (pole to pole). + truncation : int + Maximum total wavenumber to retain. + """ + + def __init__(self, lons_per_lat: list[int], truncation: int) -> None: + self.lons_per_lat = lons_per_lat + self.nlat = len(lons_per_lat) + self.truncation = truncation + assert 0 < truncation <= self.nlat, ( + f"Truncation {truncation} must be in (0, {self.nlat}]" + ) + self.n_grid_points = sum(lons_per_lat) + + # Offsets into the flattened grid for each latitude ring + self.slon = [0] + list(np.cumsum(lons_per_lat))[:-1] + + # Whether all rings have the same number of points (regular grid) + self._is_regular = len(set(lons_per_lat)) == 1 + + # Precompute Gaussian latitudes + quadrature weights + theta, weight = _legendre_gauss_weights(self.nlat) + theta = np.flip(np.arccos(theta)) + + # Associated Legendre polynomials (m, l, lat) + pct = _legpoly(truncation, truncation, np.cos(theta)) + + # Pre-multiply by quadrature weights → shape (m, l, lat) + self.weight = np.einsum("mlk,k->mlk", pct, weight) + + # -- internal FFT helpers ----------------------------------------------- + + def _rfft_regular(self, x: np.ndarray) -> np.ndarray: + """Batched real FFT for a *regular* grid. + + Parameters + ---------- + x : np.ndarray, shape ``(..., grid)`` + + Returns + ------- + np.ndarray, complex, shape ``(..., nlat, nlon//2+1)`` + """ + nlon = self.lons_per_lat[0] + return np.fft.rfft(x.reshape(*x.shape[:-1], self.nlat, nlon), norm="forward") + + def _rfft_reduced(self, x: np.ndarray) -> np.ndarray: + """Per-ring real FFT for a *reduced* (variable-resolution) grid. + + Parameters + ---------- + x : np.ndarray, shape ``(..., grid)`` + + Returns + ------- + np.ndarray, complex, shape ``(..., nlat, max_nlon//2+1)`` + """ + max_nlon = max(self.lons_per_lat) + out_shape = (*x.shape[:-1], self.nlat, max_nlon // 2 + 1) + out = np.zeros(out_shape, dtype=np.complex128) + + for i, (slon, nlon) in enumerate(zip(self.slon, self.lons_per_lat)): + out[..., i, : nlon // 2 + 1] = np.fft.rfft( + x[..., slon : slon + nlon], norm="forward" + ) + return out + + # -- transform --------------------------------------------------- + + def transform(self, x: np.ndarray) -> np.ndarray: + """Compute the SHT. + + Parameters + ---------- + x : np.ndarray, real, shape ``(..., grid)`` + + Returns + ------- + np.ndarray, complex, shape ``(..., L, M)`` where + ``L = M = truncation + 1``. + """ + if self._is_regular: + x_fft = self._rfft_regular(x) + else: + x_fft = self._rfft_reduced(x) + + x_fft = 2.0 * np.pi * x_fft + + real_part = x_fft[..., : self.truncation + 1].real + imag_part = x_fft[..., : self.truncation + 1].imag + + rl = np.einsum("...km,mlk->...lm", real_part, self.weight) + im = np.einsum("...km,mlk->...lm", imag_part, self.weight) + + return rl + 1j * im + + +# --------------------------------------------------------------------------- +# Grid helpers for building lons_per_lat +# --------------------------------------------------------------------------- + + +def _octahedral_lons_per_lat(nlat: int) -> list[int]: + """Return lons_per_lat for an octahedral reduced Gaussian grid.""" + half = [20 + 4 * i for i in range(nlat // 2)] + return half + list(reversed(half)) + + +def _regular_lons_per_lat(nlat: int) -> list[int]: + """Return lons_per_lat for a regular lat-lon grid (nlon = 2*nlat).""" + return [2 * nlat] * nlat + + +# --------------------------------------------------------------------------- +# High-level SHT PSD +# --------------------------------------------------------------------------- + + +def sht_psd( + data: np.ndarray, + nlat: int, + truncation: int | None = None, + grid_type: str = "octahedral", +) -> tuple[np.ndarray, np.ndarray]: + """Compute PSD via Spherical Harmonic Transform. + + 1. Forward SHT: spatial → spectral coefficients ``(l, m)``. + 2. PSD: L2-norm over ``m`` for each total wavenumber ``l``. + + Parameters + ---------- + data : np.ndarray + Spatial field with shape ``(n_points,)`` or ``(n_samples, n_points)``. + nlat : int + Number of latitudes in the grid. + truncation : int | None + Spectral truncation. Defaults to ``nlat // 2 - 1``. + grid_type : str + One of ``"octahedral"``, ``"regular"``, ``"reduced"``. + + Returns + ------- + wavenumbers : np.ndarray, shape ``(L,)`` + Total wavenumber indices ``0, 1, …, L-1``. + psd : np.ndarray, shape ``(L,)`` + Power spectral density averaged over samples. + """ + if data.ndim == 1: + data = data[np.newaxis, :] + n_samples, n_points = data.shape + + # Build the SHT for the appropriate grid + if grid_type == "octahedral": + lons_per_lat = _octahedral_lons_per_lat(nlat) + elif grid_type == "regular": + lons_per_lat = _regular_lons_per_lat(nlat) + elif grid_type == "reduced": + try: + from anemoi.transform.grids.named import lookup + except ImportError: + raise ImportError( + "anemoi.transform is required for grid_type='reduced'. " + "Install: pip install anemoi-transform" + ) from None + lats = lookup("N320")["latitudes"] + unique_lats = sorted(set(lats)) + lons_per_lat = [int((lats == lat).sum()) for lat in unique_lats] + else: + raise ValueError(f"Unknown grid_type: {grid_type!r}") + + trunc = truncation or nlat // 2 - 1 + sht = SphericalHarmonicTransform(lons_per_lat=lons_per_lat, truncation=trunc) + + assert n_points == sht.n_grid_points, ( + f"Input points={n_points} != expected grid points={sht.n_grid_points} " + f"for grid_type={grid_type!r}, nlat={nlat}" + ) + + # SphericalHarmonicTransform.transform accepts (..., grid) → (..., L, M) + # Pass (n_samples, n_points) directly. + coeffs = sht.transform(data) # (n_samples, L, M) + + # PSD = sum |coeffs|^2 over m for each total wavenumber l, averaged over samples + psd_per_sample = np.sum(np.abs(coeffs) ** 2, axis=-1) # (n_samples, L) + psd = psd_per_sample.mean(axis=0) + + L = psd.shape[0] + wavenumbers = np.arange(L, dtype=np.float64) + + return wavenumbers, psd + + +# --------------------------------------------------------------------------- +# Zonal FFT PSD (absorbed from psd_calc.py) +# --------------------------------------------------------------------------- + + +class ZonalPSD: + """Zonal power spectral density via 1-D FFT along the longitude dimension. + + This class absorbs the functionality previously in + ``example_extras/power_spectra/psd_calc.py``. + """ + + @staticmethod + def psd_1d(ht: np.ndarray) -> np.ndarray: + """Return the PSD for positive non-zero frequencies of an even-length signal. + + Parameters + ---------- + ht : np.ndarray + 1-D real-valued signal (one latitude ring). + + Returns + ------- + np.ndarray + PSD for positive frequencies, length ``n // 2``. + """ + n = len(ht) + hf = np.fft.rfft(ht, norm="forward") + power = np.abs(hf[1 : round(n / 2 + 1)]) ** 2 + power *= 2.0 # compensate for positive frequencies only + return power + + @staticmethod + def positive_frequencies(npoints: int, spacing_deg: float = 1.0) -> np.ndarray: + """Return the positive frequencies for a signal of *npoints* evenly spaced points. + + Parameters + ---------- + npoints : int + Number of equally-spaced longitude points. + spacing_deg : float + Grid spacing in degrees. Default is ``360 / npoints``. + + Returns + ------- + np.ndarray + Positive frequencies, length ``npoints // 2``. + """ + freq = np.fft.fftfreq(npoints, d=spacing_deg) + return np.abs(freq[1 : round(npoints / 2 + 1)]) + + @classmethod + def compute( + cls, + field_2d: np.ndarray, + ) -> np.ndarray: + """Compute the zonal PSD averaged over all latitude rows. + + Parameters + ---------- + field_2d : np.ndarray + 2-D array of shape ``(nlat, nlon)``. + + Returns + ------- + np.ndarray + PSD of shape ``(nlon // 2,)``. + """ + nlat, nlon = field_2d.shape + psd_accum = np.zeros(nlon // 2) + for row in field_2d: + psd_accum += cls.psd_1d(row) + psd_accum /= nlat + return psd_accum + + +def zonal_psd( + data: np.ndarray, + lats: np.ndarray, + lons: np.ndarray, + lat_range: tuple[float, float] = (-60.0, 60.0), + regrid_resolution: float = 1.0, +) -> tuple[np.ndarray, np.ndarray]: + """Compute zonal PSD using 1-D FFT along the longitude dimension. + + The input data is expected to be on a regular lat-lon grid already (or + will be interpolated/regridded by the caller). + + Parameters + ---------- + data : np.ndarray + Field values. If 1-D, interpreted as flattened ``(nlat * nlon,)``. + If 2-D, interpreted as ``(nlat, nlon)`` or ``(n_samples, nlat * nlon)``. + If 3-D, interpreted as ``(n_samples, nlat, nlon)``. + lats : np.ndarray + 1-D array of latitude values (descending, length ``nlat``). + lons : np.ndarray + 1-D array of longitude values (ascending, length ``nlon``). + lat_range : tuple[float, float] + Latitude bounds to restrict the computation to. + regrid_resolution : float + Grid spacing in degrees (used only for frequency calculation). + + Returns + ------- + frequencies : np.ndarray + Positive frequencies in cycles per degree, shape ``(nfreq,)``. + psd : np.ndarray + Power spectral density averaged over samples and latitude rows, + shape ``(nfreq,)``. + """ + nlat = len(lats) + nlon = len(lons) + + # Reshape to (n_samples, nlat, nlon) + if data.ndim == 1: + data = data.reshape(1, nlat, nlon) + elif data.ndim == 2: + if data.shape == (nlat, nlon): + data = data[np.newaxis, :, :] + else: + # (n_samples, nlat * nlon) + data = data.reshape(data.shape[0], nlat, nlon) + # data is now (n_samples, nlat, nlon) + + # Apply latitude mask + lat_mask = (lats >= lat_range[0]) & (lats <= lat_range[1]) + data = data[:, lat_mask, :] + nlon_sub = data.shape[2] + + # Compute PSD per sample and average + psds = [] + for s in range(data.shape[0]): + psds.append(ZonalPSD.compute(data[s])) + psd = np.mean(psds, axis=0) + + spacing = 360.0 / nlon_sub if nlon_sub > 0 else regrid_resolution + frequencies = ZonalPSD.positive_frequencies(nlon_sub, spacing_deg=spacing) + + return frequencies, psd + + +# --------------------------------------------------------------------------- +# Dispatch +# --------------------------------------------------------------------------- + + +def compute_psd_for_field( + data: np.ndarray, + method: str = "sht", + nlat: int | None = None, + lats: np.ndarray | None = None, + lons: np.ndarray | None = None, + lat_range: tuple[float, float] = (-60.0, 60.0), + regrid_resolution: float = 1.0, + sht_truncation: int | None = None, + grid_type: str = "octahedral", +) -> tuple[np.ndarray, np.ndarray]: + """Compute PSD using the selected method. + + Parameters + ---------- + data : np.ndarray + Spatial field. Shape depends on the method (see ``sht_psd`` / ``zonal_psd``). + method : str + ``"sht"`` for SHT-based PSD, ``"zonal"`` for zonal FFT PSD. + nlat : int | None + Number of latitudes (required for SHT method). + lats, lons : np.ndarray | None + Latitude / longitude coordinate arrays (required for zonal method). + lat_range : tuple[float, float] + Latitude bounds for the zonal method. + regrid_resolution : float + Grid spacing in degrees for the zonal method. + sht_truncation : int | None + Spectral truncation for SHT. + grid_type : str + Grid type for SHT (``"octahedral"``, ``"regular"``, ``"reduced"``). + + Returns + ------- + x_values : np.ndarray + Wavenumbers (SHT) or positive frequencies (zonal). + psd : np.ndarray + Power spectral density. + """ + if method == "sht": + if nlat is None: + raise ValueError("nlat is required for method='sht'") + return sht_psd( + data=data, + nlat=nlat, + truncation=sht_truncation, + grid_type=grid_type, + ) + elif method == "zonal": + if lats is None or lons is None: + raise ValueError("lats and lons are required for method='zonal'") + return zonal_psd( + data=data, + lats=lats, + lons=lons, + lat_range=lat_range, + regrid_resolution=regrid_resolution, + ) + else: + raise ValueError(f"Unknown PSD method: {method!r}. Use 'sht' or 'zonal'.") + + +def compute_psd_score( + gt: np.ndarray, + p: np.ndarray, + lats: np.ndarray | None, + lons: np.ndarray | None, + nlat: int | None, + n_points: int, + psd_method: str = "sht", + psd_regrid_resolution: float = 1.0, + psd_sht_truncation: int | None = None, + lat_range: tuple[float, float] = (-60.0, 60.0), +) -> tuple[float, dict]: + """Compute PSD for a pair of 2-D fields and return a scalar score + curves. + + This is the main entry point called from the Scores class. It handles NaN + masking, calls ``compute_psd_for_field`` for both inputs, and computes a + log-spectral MSE summary score. + + Parameters + ---------- + gt, p : np.ndarray + Ground truth and prediction arrays of shape ``(n_samples, n_points)``. + lats, lons : np.ndarray | None + Latitude / longitude arrays of length ``n_points`` (or None). + nlat : int | None + Number of latitudes (for SHT fallback). + n_points : int + Original number of spatial points (before NaN masking). + psd_method : str + ``"sht"`` or ``"zonal"``. + psd_regrid_resolution : float + Grid spacing for zonal method. + psd_sht_truncation : int | None + Spectral truncation for SHT. + lat_range : tuple[float, float] + Latitude bounds for zonal method. + + Returns + ------- + score : float + Log-spectral MSE scalar. + attrs : dict + Dict with keys ``"frequencies"``, ``"psd_target"``, ``"psd_prediction"`` + (lists for JSON serialization). + """ + # Drop NaN columns (masked grid points) + valid_mask = ~np.isnan(gt).all(axis=0) + gt = gt[:, valid_mask] + p = p[:, valid_mask] + + # Filter lat/lon to match valid points + lats_valid = lats[valid_mask] if lats is not None and len(lats) == n_points else lats + lons_valid = lons[valid_mask] if lons is not None and len(lons) == n_points else lons + nlat_valid = len(np.unique(lats_valid)) if lats_valid is not None else nlat + + try: + freq_gt, psd_gt = compute_psd_for_field( + data=gt, method=psd_method, nlat=nlat_valid, lats=lats_valid, lons=lons_valid, + lat_range=lat_range, regrid_resolution=psd_regrid_resolution, + sht_truncation=psd_sht_truncation, + ) + freq_p, psd_p = compute_psd_for_field( + data=p, method=psd_method, nlat=nlat_valid, lats=lats_valid, lons=lons_valid, + lat_range=lat_range, regrid_resolution=psd_regrid_resolution, + sht_truncation=psd_sht_truncation, + ) + except Exception: + import logging + logging.getLogger(__name__).exception("PSD computation failed, returning NaN.") + return np.nan, {} + + # Scalar summary: mean squared error of log10 PSD + valid = (psd_gt > 0) & (psd_p > 0) + if valid.any(): + log_mse = float(np.mean((np.log10(psd_p[valid]) - np.log10(psd_gt[valid])) ** 2)) + else: + log_mse = np.nan + + attrs = { + "frequencies": freq_gt.tolist(), + "psd_target": psd_gt.tolist(), + "psd_prediction": psd_p.tolist(), + } + + return log_mse, attrs diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index afee96d9a..dd951f009 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -16,6 +16,7 @@ import xarray as xr from scipy.spatial import cKDTree +from weathergen.evaluate.scores.psd import compute_psd_score from weathergen.evaluate.scores.score_utils import to_list # from common.io import MockIO @@ -198,6 +199,7 @@ def __init__( "seeps": self.calc_seeps, "qq_analysis": self.calc_quantiles, "nse": self.calc_nse, + "psd": self.calc_psd, } self.prob_metrics_dict = { "ssr": self.calc_ssr, @@ -1807,3 +1809,167 @@ def calc_quantiles( _logger.info(f"Q-Q analysis completed with {len(overall_qq_score.attrs)} attributes") return overall_qq_score + + def calc_psd( + self, + p: xr.DataArray, + gt: xr.DataArray, + psd_method: str = "sht", + psd_regrid_resolution: float = 1.0, + psd_sht_truncation: int | None = None, + lat_range: tuple[float, float] = (-60.0, 60.0), + ) -> xr.DataArray: + """Compute power spectral density for prediction and ground truth. + + Returns a scalar summary score (log-spectral MSE) and stores the full + PSD curves in ``.attrs`` so that ``psd_plot_metric_region`` in the + summary-plot phase can produce spectral plots — following the same + pattern as ``calc_quantiles`` / ``qq_analysis``. + + Parameters + ---------- + p : xr.DataArray + Prediction data. + gt : xr.DataArray + Ground truth data. + psd_method : str + ``"sht"`` (default) or ``"zonal"``. + psd_regrid_resolution : float + Grid spacing for the zonal method (degrees). + psd_sht_truncation : int | None + Spectral truncation for the SHT method. + lat_range : tuple[float, float] + Latitude range for the zonal method. + + Returns + ------- + xr.DataArray + Scalar score (log-spectral MSE) with PSD data stored in ``.attrs``. + """ + + if self._agg_dims is None: + raise ValueError("Cannot calculate PSD without aggregation dimensions.") + + # PSD expects exactly one spatial aggregation dimension (e.g. "ipoint"). + if len(self._agg_dims) != 1: + raise ValueError( + f"PSD expects exactly one spatial aggregation dimension (points), " + f"but got agg_dims={self._agg_dims}. Do not flatten over multiple dims." + ) + spatial_dim = self._agg_dims[0] + if spatial_dim not in gt.dims: + raise ValueError( + f"Spatial dimension '{spatial_dim}' not found in data dims {list(gt.dims)}." + ) + + # The data arriving here has dims like (sample, channel, ipoint). + # PSD must be computed per channel (and averaged over samples). + # We iterate over all non-spatial dims except the "sample" axis, + # which becomes the sample dimension for the SHT. + other_dims = [d for d in gt.dims if d != spatial_dim] + _logger.debug( + f"PSD: spatial_dim={spatial_dim}, other_dims={other_dims}, " + f"shape p={dict(zip(p.dims, p.shape))}, gt={dict(zip(gt.dims, gt.shape))}" + ) + + # Determine grid info from coords (same for all channels/samples) + n_points = gt.sizes[spatial_dim] + nlat = None + lats, lons = None, None + if "lat" in gt.coords and "lon" in gt.coords: + if gt.coords["lat"].dims == (spatial_dim,) and gt.coords["lon"].dims == (spatial_dim,): + lats = gt.coords["lat"].values + lons = gt.coords["lon"].values + nlat = len(np.unique(lats)) + else: + _logger.warning( + f"PSD: lat/lon coords exist but have unexpected dims " + f"(lat: {gt.coords['lat'].dims}, lon: {gt.coords['lon'].dims}). " + f"Expected ({spatial_dim},)." + ) + + if nlat is None and lats is None: + raise ValueError( + "PSD requires grid information. Either provide lat/lon coords on the " + f"spatial dimension '{spatial_dim}', or set 'nlat' in the DataArray attrs." + ) + if psd_method == "zonal" and (lats is None or lons is None): + raise ValueError( + "PSD method 'zonal' requires lat/lon coordinate arrays on the spatial " + f"dimension '{spatial_dim}', but they are not available." + ) + + # Identify batch dimensions (to average PSD over) vs dims to preserve. + # Convention: "sample" and "ens" dims are averaged over (flattened into + # the batch axis); all other non-spatial dims (e.g. "channel") are preserved. + batch_dims = [d for d in other_dims if d in ("sample", "ens")] + preserve_dims = [d for d in other_dims if d not in ("sample", "ens")] + + # Compute PSD per slice of the preserved dims (typically per channel). + # For each slice, samples form the batch axis for the SHT. + if preserve_dims: + # Iterate over the cartesian product of preserved dims + score_values = {} + attrs_per_slice = {} + for idx in np.ndindex(*[gt.sizes[d] for d in preserve_dims]): + sel = dict(zip(preserve_dims, idx)) + gt_slice = gt.isel(**sel) # dims: (sample, ipoint) or (ipoint,) + p_slice = p.isel(**sel) + + # Ensure 2-D: (n_batch, n_points) — flatten batch dims into one axis + non_spatial = [d for d in gt_slice.dims if d != spatial_dim] + if non_spatial: + gt_np = gt_slice.transpose(*non_spatial, spatial_dim).values.reshape(-1, n_points) + p_np = p_slice.transpose(*[d for d in p_slice.dims if d != spatial_dim], spatial_dim).values.reshape(-1, n_points) + else: + gt_np = gt_slice.values.reshape(1, -1) + p_np = p_slice.values.reshape(1, -1) + + slice_score, slice_attrs = compute_psd_score( + gt=gt_np, p=p_np, lats=lats, lons=lons, nlat=nlat, + n_points=n_points, psd_method=psd_method, + psd_regrid_resolution=psd_regrid_resolution, + psd_sht_truncation=psd_sht_truncation, lat_range=lat_range, + ) + score_values[idx] = slice_score + attrs_per_slice[idx] = slice_attrs + + # Build output DataArray with preserved dims + shape = tuple(gt.sizes[d] for d in preserve_dims) + score_np = np.array([score_values[idx] for idx in np.ndindex(*shape)]).reshape(shape) + coords = {d: gt.coords[d] for d in preserve_dims if d in gt.coords} + score = xr.DataArray(score_np, dims=preserve_dims, coords=coords) + + # Store per-channel PSD curves in attrs (keyed by channel coord value) + all_attrs = {} + for idx in np.ndindex(*shape): + sel = dict(zip(preserve_dims, idx)) + key = "_".join(str(gt.coords[d].values[i]) if d in gt.coords else str(i) + for d, i in sel.items()) + for k, v in attrs_per_slice[idx].items(): + all_attrs[f"{key}/{k}"] = v + all_attrs["psd_method"] = psd_method + all_attrs["preserve_dims"] = preserve_dims + score.attrs.update(all_attrs) + else: + # No preserved dims — compute single PSD over all data + non_spatial = [d for d in gt.dims if d != spatial_dim] + if non_spatial: + gt_np = gt.transpose(*non_spatial, spatial_dim).values.reshape(-1, n_points) + p_np = p.transpose(*[d for d in p.dims if d != spatial_dim], spatial_dim).values.reshape(-1, n_points) + else: + gt_np = gt.values.reshape(1, -1) + p_np = p.values.reshape(1, -1) + + slice_score, slice_attrs = compute_psd_score( + gt=gt_np, p=p_np, lats=lats, lons=lons, nlat=nlat, + n_points=n_points, psd_method=psd_method, + psd_regrid_resolution=psd_regrid_resolution, + psd_sht_truncation=psd_sht_truncation, lat_range=lat_range, + ) + score = xr.DataArray(slice_score) + score.attrs.update(slice_attrs) + score.attrs["psd_method"] = psd_method + + _logger.info(f"PSD analysis completed (score shape={score.shape})") + return score diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py b/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py index 82bc8a026..b56dcb7ae 100644 --- a/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py @@ -375,10 +375,11 @@ def store_metrics_for_region( for fstep, combined_metrics, _fstep_attrs in fstep_results: criteria = { "forecast_step": int(fstep), - "sample": combined_metrics.sample.values, "channel": combined_metrics.channel.values, "metric": combined_metrics.metric.values, } + if "sample" in combined_metrics.dims: + criteria["sample"] = combined_metrics.sample.values if "ens" in combined_metrics.dims: criteria["ens"] = combined_metrics.ens.values @@ -420,11 +421,20 @@ def store_metrics_for_region( for metric, parameters in metrics.items(): metric_data = metric_stream.sel({"metric": metric}).assign_attrs(parameters) + + # Restore attrs from all fsteps, keyed by fstep so downstream code + # (plotting) can produce one plot per forecast step. for (_stored_fstep, stored_metric), attrs in all_metric_attrs.items(): if stored_metric == metric and attrs: - _logger.debug(f"Restoring {len(attrs)} attributes for {metric}") - metric_data.attrs.update(attrs) - break + for k, v in attrs.items(): + metric_data.attrs[f"fstep_{_stored_fstep}/{k}"] = v + # Also store the list of fsteps that have attrs + attr_fsteps = sorted( + {fs for (fs, m) in all_metric_attrs if m == metric and all_metric_attrs[(fs, m)]} + ) + if attr_fsteps: + metric_data.attrs["attr_fsteps"] = attr_fsteps + _logger.debug(f"Stored per-fstep attributes for {metric}: fsteps={attr_fsteps}") local_scores.setdefault(metric, {}).setdefault(region, {}).setdefault(stream, {})[ run_id From effd35ce41b4fea0a4feb3553503e81609015eff Mon Sep 17 00:00:00 2001 From: iluise Date: Wed, 20 May 2026 13:57:31 +0200 Subject: [PATCH 2/8] fix zonal (ukmet) psd --- .../src/weathergen/evaluate/scores/psd.py | 69 ++++++++++++------- 1 file changed, 45 insertions(+), 24 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/scores/psd.py b/packages/evaluate/src/weathergen/evaluate/scores/psd.py index d484c3246..3fbe49107 100644 --- a/packages/evaluate/src/weathergen/evaluate/scores/psd.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/psd.py @@ -355,23 +355,22 @@ def zonal_psd( ) -> tuple[np.ndarray, np.ndarray]: """Compute zonal PSD using 1-D FFT along the longitude dimension. - The input data is expected to be on a regular lat-lon grid already (or - will be interpolated/regridded by the caller). + For unstructured grids (where lats/lons are per-point coordinates rather + than regular axis arrays), the data is first regridded to a regular lat-lon + grid using scipy nearest-neighbor interpolation. Parameters ---------- data : np.ndarray - Field values. If 1-D, interpreted as flattened ``(nlat * nlon,)``. - If 2-D, interpreted as ``(nlat, nlon)`` or ``(n_samples, nlat * nlon)``. - If 3-D, interpreted as ``(n_samples, nlat, nlon)``. + Field values. Shape ``(n_samples, n_points)`` or ``(n_points,)``. lats : np.ndarray - 1-D array of latitude values (descending, length ``nlat``). + Latitude values. Either per-point (length ``n_points``) or axis (length ``nlat``). lons : np.ndarray - 1-D array of longitude values (ascending, length ``nlon``). + Longitude values. Either per-point (length ``n_points``) or axis (length ``nlon``). lat_range : tuple[float, float] Latitude bounds to restrict the computation to. regrid_resolution : float - Grid spacing in degrees (used only for frequency calculation). + Grid spacing in degrees for the regular target grid. Returns ------- @@ -381,29 +380,51 @@ def zonal_psd( Power spectral density averaged over samples and latitude rows, shape ``(nfreq,)``. """ - nlat = len(lats) - nlon = len(lons) + from scipy.interpolate import griddata - # Reshape to (n_samples, nlat, nlon) + # Ensure 2-D: (n_samples, n_points) if data.ndim == 1: - data = data.reshape(1, nlat, nlon) - elif data.ndim == 2: - if data.shape == (nlat, nlon): - data = data[np.newaxis, :, :] - else: - # (n_samples, nlat * nlon) - data = data.reshape(data.shape[0], nlat, nlon) - # data is now (n_samples, nlat, nlon) + data = data.reshape(1, -1) + + n_samples, n_points = data.shape + + # Determine if the grid is regular or unstructured + unique_lats = np.unique(lats) + unique_lons = np.unique(lons) + is_regular = (len(unique_lats) * len(unique_lons) == n_points) + + if is_regular and len(lats) == len(unique_lats): + # lats/lons are axis arrays for a regular grid + lat_axis = unique_lats + lon_axis = unique_lons + nlat, nlon = len(lat_axis), len(lon_axis) + data_3d = data.reshape(n_samples, nlat, nlon) + else: + # Unstructured grid — regrid to regular lat-lon + lat_min = max(lat_range[0], lats.min()) + lat_max = min(lat_range[1], lats.max()) + lon_min, lon_max = lons.min(), lons.max() + + lat_axis = np.arange(lat_min, lat_max + regrid_resolution / 2, regrid_resolution) + lon_axis = np.arange(lon_min, lon_max + regrid_resolution / 2, regrid_resolution) + nlat, nlon = len(lat_axis), len(lon_axis) + + grid_lon, grid_lat = np.meshgrid(lon_axis, lat_axis) + points = np.column_stack((lats, lons)) + + data_3d = np.empty((n_samples, nlat, nlon)) + for s in range(n_samples): + data_3d[s] = griddata(points, data[s], (grid_lat, grid_lon), method="nearest") # Apply latitude mask - lat_mask = (lats >= lat_range[0]) & (lats <= lat_range[1]) - data = data[:, lat_mask, :] - nlon_sub = data.shape[2] + lat_mask = (lat_axis >= lat_range[0]) & (lat_axis <= lat_range[1]) + data_3d = data_3d[:, lat_mask, :] + nlon_sub = data_3d.shape[2] # Compute PSD per sample and average psds = [] - for s in range(data.shape[0]): - psds.append(ZonalPSD.compute(data[s])) + for s in range(data_3d.shape[0]): + psds.append(ZonalPSD.compute(data_3d[s])) psd = np.mean(psds, axis=0) spacing = 360.0 / nlon_sub if nlon_sub > 0 else regrid_resolution From e208eece72c9d301396b37eed2b136cdf16f48d5 Mon Sep 17 00:00:00 2001 From: iluise Date: Wed, 20 May 2026 14:05:45 +0200 Subject: [PATCH 3/8] remove old psd implementation --- config/evaluate/eval_config.yml | 7 +- .../weathergen/evaluate/plotting/plotter.py | 110 ------------------ 2 files changed, 6 insertions(+), 111 deletions(-) diff --git a/config/evaluate/eval_config.yml b/config/evaluate/eval_config.yml index 1f0ae14b8..729a8f875 100644 --- a/config/evaluate/eval_config.yml +++ b/config/evaluate/eval_config.yml @@ -29,7 +29,12 @@ # max_workers: 36 # hard cap on parallel workers (I/O, plotting, scoring) evaluation: - metrics : ["rmse", "mae"] + metrics : + - rmse + - mae + - psd: + psd_method: "sht" # "sht" (SHT-based) or "zonal" (iris FFT) + psd_regrid_resolution: 1.0 # degrees, only for zonal method regions: ["global", "nhem"] summary_plots : true ratio_plots : false diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 37b3a67d8..48ad0cda7 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -619,117 +619,7 @@ def plot_psd( return name - def create_psd_plots( - self, - target: xr.DataArray, - preds: xr.DataArray, - variables: list, - select: dict, - tag: str = "", - psd_method: str = "sht", - psd_regions: list[str] | None = None, - psd_regrid_resolution: float = 1.0, - psd_sht_truncation: int | None = None, - ) -> list[str]: - """Compute and plot PSD for target and prediction. - - Parameters - ---------- - target, preds : xr.DataArray - Target / prediction arrays. - variables : list - List of channel names to process. - select : dict - Data selection dict (sample, stream, forecast_step). - tag : str - Filename tag (e.g., ``"preds_ens_0"``). - psd_method : str - ``"sht"`` or ``"zonal"``. - psd_regions : list[str] | None - Region names. ``None`` → use ``self.regions``. - psd_regrid_resolution : float - Grid spacing for the zonal method. - psd_sht_truncation : int | None - Spectral truncation for SHT method. - - Returns - ------- - list[str] - Names of saved PSD plots. - """ - from weathergen.evaluate.scores.psd import compute_psd_for_field - - self.update_data_selection(select) - psd_output_dir = self.get_psd_output_dir() - os.makedirs(psd_output_dir, exist_ok=True) - - regions = psd_regions or self.regions - plot_names: list[str] = [] - for region in regions: - if region != "global": - bbox = RegionBoundingBox.from_region_name(region) - reg_target = bbox.apply_mask(target) - reg_preds = bbox.apply_mask(preds) - else: - reg_target = target - reg_preds = preds - - for var in variables: - select_var = self.select | {"channel": var} - targ = self.select_from_da(reg_target, select_var).dropna(dim="ipoint") - prd = self.select_from_da(reg_preds, select_var).dropna(dim="ipoint") - - if targ.size == 0 or prd.size == 0: - _logger.warning(f"PSD: empty data for {var} in {region}. Skipping.") - continue - - targ_np = targ.values - prd_np = prd.values - - # Determine nlat from the grid if available - nlat = None - if "lat" in targ.coords: - nlat = len(np.unique(targ.coords["lat"].values)) - elif hasattr(targ, "attrs") and "nlat" in targ.attrs: - nlat = int(targ.attrs["nlat"]) - - lats = targ.coords["lat"].values if "lat" in targ.coords else None - lons = targ.coords["lon"].values if "lon" in targ.coords else None - - try: - freq_tar, psd_tar = compute_psd_for_field( - data=targ_np, - method=psd_method, - nlat=nlat, - lats=lats, - lons=lons, - lat_range=(-60.0, 60.0), - regrid_resolution=psd_regrid_resolution, - sht_truncation=psd_sht_truncation, - ) - freq_prd, psd_prd = compute_psd_for_field( - data=prd_np, - method=psd_method, - nlat=nlat, - lats=lats, - lons=lons, - lat_range=(-60.0, 60.0), - regrid_resolution=psd_regrid_resolution, - sht_truncation=psd_sht_truncation, - ) - except Exception: - _logger.exception(f"PSD computation failed for {var} in {region}") - continue - - name = self.plot_psd( - freq_tar, psd_tar, psd_prd, psd_output_dir, var, - tag=tag, region=region, - ) - plot_names.append(name) - - self.clean_data_selection() - return plot_names def get_psd_output_dir(self) -> Path: """Return the output directory path for PSD plots. From 64ed5825afa1f6c05c553303fecf1dbe0fc96143 Mon Sep 17 00:00:00 2001 From: iluise Date: Wed, 20 May 2026 16:31:07 +0200 Subject: [PATCH 4/8] remove dead code --- config/evaluate/eval_config.yml | 3 +- .../evaluate/plotting/line_plots.py | 3 +- .../evaluate/plotting/plot_utils.py | 3 +- .../weathergen/evaluate/plotting/plotter.py | 153 ------------------ .../src/weathergen/evaluate/scores/psd.py | 27 +++- .../src/weathergen/evaluate/scores/score.py | 3 + 6 files changed, 33 insertions(+), 159 deletions(-) diff --git a/config/evaluate/eval_config.yml b/config/evaluate/eval_config.yml index 729a8f875..40cee3f90 100644 --- a/config/evaluate/eval_config.yml +++ b/config/evaluate/eval_config.yml @@ -34,7 +34,8 @@ evaluation: - mae - psd: psd_method: "sht" # "sht" (SHT-based) or "zonal" (iris FFT) - psd_regrid_resolution: 1.0 # degrees, only for zonal method + psd_regrid_resolution: 1.0 # degrees, only for zonal method + grid_type: "octahedral" # "octahedral" (default), "regular", "reduced". regions: ["global", "nhem"] summary_plots : true ratio_plots : false diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py b/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py index e7307188c..422c59a70 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py @@ -717,7 +717,8 @@ def psd_plot( color=c, lw=1.5, label=label, ) ax_spec.set_ylabel("Power") - ax_spec.set_title("PSD summary") + psd_method = psd_datasets[0].get("psd_method", "sht") + ax_spec.set_title(f"PSD summary (psd_{psd_method})") ax_spec.legend(frameon=False, fontsize=7) ax_spec.grid(True, which="both", ls="--", alpha=0.4) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 4851f0484..aecaf8fb8 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -870,8 +870,9 @@ def psd_plot_metric_region( f"Creating PSD plot for {metric} - {region} - {stream} - " f"{ch} - {lead_time_str}." ) + method_tag = psd_datasets[0].get("psd_method", "sht") name = create_filename( - prefix=[metric, region], + prefix=[metric, method_tag, region], middle=[run_id], suffix=[stream, ch, f"fstep{fstep}"], ) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 48ad0cda7..97e0840a6 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -478,159 +478,6 @@ def plot_histogram( return name - # ------------------------------------------------------------------ - # PSD plots - # ------------------------------------------------------------------ - - # -- PSD plot annotation helpers (from psd_plots.py) ---------------- - - @staticmethod - def _psd_add_wavenumbers(ax: plt.Axes) -> None: - """Add vertical lines at selected total wavenumbers.""" - yscale = ax.yaxis.get_scale() - ylims = ax.get_ylim() - if yscale == "log": - ytxt = 10.0 ** (0.85 * (np.log10(ylims[1] / ylims[0])) + np.log10(ylims[0])) - else: - ytxt = 0.85 * (ylims[1] - ylims[0]) + ylims[0] - - for wvn in (1, 2, 4, 8, 16, 24, 48, 96, 144, 216, 320, 640, 1280, 2560): - ax.plot( - [wvn / 360.0, wvn / 360.0], ylims, - color="black", lw=0.6, scalex=False, scaley=False, - ) - ax.text(wvn / 360.0, ytxt, f"n{wvn}", rotation="vertical", fontsize=6) - - @staticmethod - def _psd_add_lengths(ax: plt.Axes, lat_center: float = 0.0) -> None: - """Add vertical dashed lines at selected physical length scales (km).""" - re = 6.37e6 # earth radius in metres - yscale = ax.yaxis.get_scale() - ylims = ax.get_ylim() - if yscale == "log": - ytxt = 10.0 ** (0.05 * (np.log10(ylims[1] / ylims[0])) + np.log10(ylims[0])) - else: - ytxt = 0.05 * (ylims[1] - ylims[0]) + ylims[0] - - lengths_km = np.array([1.0e4, 3.0e3, 1.0e3, 3.0e2, 1.0e2, 3.0e1, 1.0e1]) - f_lengths = ( - 2.0 * np.pi * re * np.cos(np.radians(lat_center)) / (1000.0 * lengths_km * 360.0) - ) - for fl, lkm in zip(f_lengths, lengths_km): - ax.plot( - [fl, fl], ylims, - color="black", ls="--", lw=0.6, scalex=False, scaley=False, - ) - ax.text(fl, ytxt, f"{lkm:.0f}km", rotation="vertical", fontsize=6) - - @staticmethod - def _psd_add_ideal_slope( - ax: plt.Axes, slope: float = -3.0, - x_range: tuple[float, float] = (0.01, 0.1), y0: float = 10.0, - ) -> None: - """Add an idealised slope line on a log-log axes.""" - xs = np.array(x_range) - ys = y0 * np.array([1.0, (xs[1] / xs[0]) ** slope]) - ax.plot(xs, ys, color="black", lw=2.0, scalex=False, scaley=False) - xt = np.sqrt(np.prod(xs)) - yt = np.sqrt(np.prod(ys)) - ax.text(xt, yt, f"$k^{{{slope:.0f}}}$", fontsize=10, weight="bold") - - def plot_psd( - self, - freq: np.ndarray, - tar_psd: np.ndarray, - pred_psd: np.ndarray, - psd_output_dir: Path, - varname: str, - tag: str = "", - region: str = "", - ) -> str: - """Plot power spectral density: log-log spectra + ratio. - - Parameters - ---------- - freq : np.ndarray - Positive frequencies (or wavenumbers). - tar_psd : np.ndarray - Target PSD values. - pred_psd : np.ndarray - Prediction PSD values. - psd_output_dir : Path - Output directory for PSD plots. - varname : str - Variable / channel name. - tag, region : str - Filename parts (same convention as histograms / maps). - - Returns - ------- - str - Plot name (without extension). - """ - fig, (ax_spec, ax_ratio) = plt.subplots( - 2, 1, figsize=self.fig_size or (8, 8), - gridspec_kw={"height_ratios": [2, 1], "hspace": 0.08}, - ) - - # Upper panel: log-log spectra - ax_spec.loglog(freq, tar_psd, color="black", lw=1.5, label="Target") - ax_spec.loglog(freq, pred_psd, color="#00897B", lw=1.5, label="Prediction") - ax_spec.set_ylabel("Power") - ax_spec.set_title(f"PSD: {self.stream}, {varname}") - ax_spec.legend(frameon=False) - ax_spec.grid(True, which="both", ls="--", alpha=0.4) - ax_spec.set_xlim(1.0e-3, 1.0e1) - self._psd_add_wavenumbers(ax_spec) - self._psd_add_ideal_slope(ax_spec) - - # Lower panel: ratio - with np.errstate(divide="ignore", invalid="ignore"): - ratio = np.where(tar_psd > 0, pred_psd / tar_psd, np.nan) - ax_ratio.semilogx(freq, ratio, color="#00897B", lw=1.2) - ax_ratio.axhline(1.0, ls="--", color="gray", lw=0.8) - ax_ratio.set_ylabel("Pred / Target") - ax_ratio.set_xlabel("Frequency (1/deg)") - ax_ratio.set_ylim(0, 2) - ax_ratio.set_xlim(1.0e-3, 1.0e1) - ax_ratio.grid(True, which="both", ls="--", alpha=0.4) - self._psd_add_wavenumbers(ax_ratio) - - # Build filename (same pattern as maps / histograms) - is_global = str(self.sample) == "all_samples" - valid_time = None # PSD is always aggregated across time - parts = [ - "psd", - str(self.run_id), - str(tag) if tag else "", - str(self.sample), - valid_time, - str(self.stream), - region if region else "", - varname, - f"{self.fstep:03d}", - ] - name = "_".join(filter(None, parts)) - - fname = psd_output_dir / f"{name}.{self.image_format}" - _logger.debug(f"Saving PSD plot to {fname}") - fig.savefig(fname, bbox_inches="tight") - plt.close(fig) - - return name - - - - def get_psd_output_dir(self) -> Path: - """Return the output directory path for PSD plots. - - Returns - ------- - Path - Resolved directory path: ``//psd``. - """ - return self.out_plot_basedir / self.stream / "psd" - def create_maps_per_sample( self, data: xr.DataArray, diff --git a/packages/evaluate/src/weathergen/evaluate/scores/psd.py b/packages/evaluate/src/weathergen/evaluate/scores/psd.py index 3fbe49107..7887632af 100644 --- a/packages/evaluate/src/weathergen/evaluate/scores/psd.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/psd.py @@ -511,6 +511,7 @@ def compute_psd_score( psd_regrid_resolution: float = 1.0, psd_sht_truncation: int | None = None, lat_range: tuple[float, float] = (-60.0, 60.0), + grid_type: str = "octahedral", ) -> tuple[float, dict]: """Compute PSD for a pair of 2-D fields and return a scalar score + curves. @@ -545,8 +546,9 @@ def compute_psd_score( Dict with keys ``"frequencies"``, ``"psd_target"``, ``"psd_prediction"`` (lists for JSON serialization). """ - # Drop NaN columns (masked grid points) + # Handle NaN grid points (e.g. from regional masking). valid_mask = ~np.isnan(gt).all(axis=0) + n_valid = valid_mask.sum() gt = gt[:, valid_mask] p = p[:, valid_mask] @@ -555,16 +557,35 @@ def compute_psd_score( lons_valid = lons[valid_mask] if lons is not None and len(lons) == n_points else lons nlat_valid = len(np.unique(lats_valid)) if lats_valid is not None else nlat + # SHT requires the structurally-complete grid (all points present). + # If the data is a regional subset, skip with a warning. + if psd_method == "sht": + if grid_type == "octahedral": + expected_pts = sum(_octahedral_lons_per_lat(nlat_valid)) + elif grid_type == "regular": + expected_pts = sum(_regular_lons_per_lat(nlat_valid)) + else: + expected_pts = None # cannot validate + actual_pts = gt.shape[-1] + if expected_pts is not None and actual_pts != expected_pts: + import logging + logging.getLogger(__name__).warning( + f"PSD (SHT): grid point mismatch ({actual_pts} vs expected {expected_pts} " + f"for grid_type={grid_type!r}, nlat={nlat_valid}). SHT scores are only " + f"available for the full (global/unmasked) grid. Skipping this region." + ) + return np.nan, {} + try: freq_gt, psd_gt = compute_psd_for_field( data=gt, method=psd_method, nlat=nlat_valid, lats=lats_valid, lons=lons_valid, lat_range=lat_range, regrid_resolution=psd_regrid_resolution, - sht_truncation=psd_sht_truncation, + sht_truncation=psd_sht_truncation, grid_type=grid_type, ) freq_p, psd_p = compute_psd_for_field( data=p, method=psd_method, nlat=nlat_valid, lats=lats_valid, lons=lons_valid, lat_range=lat_range, regrid_resolution=psd_regrid_resolution, - sht_truncation=psd_sht_truncation, + sht_truncation=psd_sht_truncation, grid_type=grid_type, ) except Exception: import logging diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index dd951f009..3580664de 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -1818,6 +1818,7 @@ def calc_psd( psd_regrid_resolution: float = 1.0, psd_sht_truncation: int | None = None, lat_range: tuple[float, float] = (-60.0, 60.0), + grid_type: str = "octahedral", ) -> xr.DataArray: """Compute power spectral density for prediction and ground truth. @@ -1930,6 +1931,7 @@ def calc_psd( n_points=n_points, psd_method=psd_method, psd_regrid_resolution=psd_regrid_resolution, psd_sht_truncation=psd_sht_truncation, lat_range=lat_range, + grid_type=grid_type, ) score_values[idx] = slice_score attrs_per_slice[idx] = slice_attrs @@ -1966,6 +1968,7 @@ def calc_psd( n_points=n_points, psd_method=psd_method, psd_regrid_resolution=psd_regrid_resolution, psd_sht_truncation=psd_sht_truncation, lat_range=lat_range, + grid_type=grid_type, ) score = xr.DataArray(slice_score) score.attrs.update(slice_attrs) From efe2d5f84bd86d9baf74520b94310e17518c8864 Mon Sep 17 00:00:00 2001 From: iluise Date: Thu, 21 May 2026 14:52:49 +0200 Subject: [PATCH 5/8] shorter functions --- config/evaluate/eval_config.yml | 4 +- packages/evaluate/pyproject.toml | 2 +- .../evaluate/plotting/plot_orchestration.py | 36 +-- .../evaluate/plotting/plot_utils.py | 94 ++----- .../src/weathergen/evaluate/scores/psd.py | 174 ++++++------ .../src/weathergen/evaluate/scores/score.py | 254 +++++++++--------- 6 files changed, 260 insertions(+), 304 deletions(-) diff --git a/config/evaluate/eval_config.yml b/config/evaluate/eval_config.yml index 40cee3f90..e8d606994 100644 --- a/config/evaluate/eval_config.yml +++ b/config/evaluate/eval_config.yml @@ -33,8 +33,8 @@ evaluation: - rmse - mae - psd: - psd_method: "sht" # "sht" (SHT-based) or "zonal" (iris FFT) - psd_regrid_resolution: 1.0 # degrees, only for zonal method + psd_method: "sht" # "sht" (SHT-based) or "fft" (regrid FFT) + psd_regrid_resolution: 1.0 # degrees, only for fft method grid_type: "octahedral" # "octahedral" (default), "regular", "reduced". regions: ["global", "nhem"] summary_plots : true diff --git a/packages/evaluate/pyproject.toml b/packages/evaluate/pyproject.toml index e6329a8a6..5c4afdc9c 100644 --- a/packages/evaluate/pyproject.toml +++ b/packages/evaluate/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ ] [project.optional-dependencies] -zonal = [ +fft = [ "scitools-iris>=3.11", "cf-units", ] diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index 1d45930c3..194b21367 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -480,11 +480,10 @@ def _plot_all_samples( ) -> None: """Plot histograms across all samples for a single fstep. - Unlike per-sample plots, these aggregate all samples together. + Unlike per-sample histograms, these aggregate all samples together. The output filename uses 'global' instead of a sample id and omits the timestep. """ - has_work = (plot_histograms is True or plot_histograms == "across-samples") - if not has_work: + if not (plot_histograms is True or plot_histograms == "across-samples"): return matplotlib.use("Agg") @@ -498,15 +497,14 @@ def _plot_all_samples( preds_tag = "" if "ens" not in preds.dims else f"ens_{ens}" preds_name = "_".join(filter(None, ["preds", preds_tag])) - if plot_histograms is True or plot_histograms == "across-samples": - plotter.create_histograms( - tars, - preds_ens, - plot_chs, - data_selection, - preds_name, - ranges=maps_config, - ) + plotter.create_histograms( + tars, + preds_ens, + plot_chs, + data_selection, + preds_name, + ranges=maps_config, + ) plotter.clean_data_selection() @@ -791,8 +789,13 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): quantile_plotter = QuantilePlots(plot_cfg, summary_dir) for region in regions: for metric in metrics: - if eval_opt.get("summary_plots", False) and metric != "psd": - plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary) + if eval_opt.get("summary_plots", False): + if metric == "psd": + psd_plot_metric_region(metric, region, runs, scores_dict, plotter) + elif metric == "qq_analysis": + quantile_plot_metric_region(metric, region, runs, scores_dict, quantile_plotter) + else: + plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary) if eval_opt.get("ratio_plots", False): ratio_plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary) if eval_opt.get("heat_maps", False): @@ -801,7 +804,4 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): score_card_metric_region(metric, region, runs, scores_dict, sc_plotter) if eval_opt.get("bar_plots", False): bar_plot_metric_region(metric, region, runs, scores_dict, br_plotter) - if metric == "qq_analysis": - quantile_plot_metric_region(metric, region, runs, scores_dict, quantile_plotter) - if metric == "psd": - psd_plot_metric_region(metric, region, runs, scores_dict, plotter) + \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index aecaf8fb8..0b3b64218 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -768,6 +768,27 @@ def quantile_plot_metric_region( ) +def _extract_psd_attrs(data_ch: xr.DataArray, fstep: int, ch: str) -> list[dict] | None: + """Extract PSD curve data from DataArray attrs for a given fstep/channel. + + Returns a single-element list of dicts ready for the plotter, or None if keys are missing. + """ + attrs = data_ch.attrs + fp = f"fstep_{fstep}/" + + for prefix in (f"{fp}{ch}/", fp): + if f"{prefix}frequencies" in attrs and f"{prefix}psd_target" in attrs: + return [ + { + "frequencies": np.array(attrs[f"{prefix}frequencies"]), + "psd_target": np.array(attrs[f"{prefix}psd_target"]), + "psd_prediction": np.array(attrs[f"{prefix}psd_prediction"]), + "psd_method": attrs.get(f"{fp}psd_method", attrs.get("psd_method", "sht")), + } + ] + return None + + def psd_plot_metric_region( metric: str, region: str, @@ -777,22 +798,8 @@ def psd_plot_metric_region( ) -> None: """Create PSD plots for all streams and channels for a given metric and region. - Follows the same pattern as ``quantile_plot_metric_region``: the PSD - curves (frequencies, target PSD, prediction PSD) are stored in + PSD curves (frequencies, target PSD, prediction PSD) are stored in ``score.attrs`` by ``Scores.calc_psd`` and read back here. - - Parameters - ---------- - metric : str - Metric name (should be ``"psd"``). - region : str - Region name. - runs : dict - Run config dict (run_id → config). - scores_dict : dict - Nested score dict ``{metric: {region: {stream: {run_id: DataArray}}}}``. - plotter : object - Plotter that has a ``psd_summary_plot`` method (or a generic line plotter). """ streams_set = collect_streams(runs) channels_set = collect_channels(scores_dict, metric, region, runs) @@ -804,72 +811,21 @@ def psd_plot_metric_region( continue data_ch = data.sel(channel=ch) if "channel" in data.dims else data - if data_ch.isnull().all(): continue - # Get list of fsteps that have PSD attrs attr_fsteps = data_ch.attrs.get("attr_fsteps", []) if not attr_fsteps: - _logger.warning( - f"PSD attrs missing for {run_id}/{stream}/{ch}. Skipping." - ) + _logger.warning(f"PSD attrs missing for {run_id}/{stream}/{ch}. Skipping.") continue label = runs[run_id].get("label", run_id) for fstep in attr_fsteps: - # Look up per-fstep, per-channel attrs - fstep_prefix = f"fstep_{fstep}/" - ch_prefix = f"{ch}/" - - # Try fstep+channel prefix first, then fstep-only - freq_key = None - for candidate in [ - f"{fstep_prefix}{ch_prefix}frequencies", - f"{fstep_prefix}frequencies", - ]: - if candidate in data_ch.attrs: - freq_key = candidate - break - - if freq_key is None: - continue - - # Derive the other keys from the same prefix pattern - key_prefix = freq_key.replace("frequencies", "") - psd_t_key = f"{key_prefix}psd_target" - psd_p_key = f"{key_prefix}psd_prediction" - - if psd_t_key not in data_ch.attrs: + psd_datasets = _extract_psd_attrs(data_ch, fstep, ch) + if psd_datasets is None: continue - psd_datasets = [ - { - "frequencies": np.array(data_ch.attrs[freq_key]), - "psd_target": np.array(data_ch.attrs[psd_t_key]), - "psd_prediction": np.array(data_ch.attrs[psd_p_key]), - "psd_method": data_ch.attrs.get( - f"{fstep_prefix}psd_method", - data_ch.attrs.get("psd_method", "sht"), - ), - } - ] - - # Resolve lead time for title - lead_time_str = f"fstep {fstep}" - if "lead_time" in data_ch.coords and "forecast_step" in data_ch.dims: - try: - lt = int(data_ch.coords["lead_time"].sel(forecast_step=fstep).values) - if lt > 0: - lead_time_str = f"{lt}h" - except Exception: - pass - - _logger.info( - f"Creating PSD plot for {metric} - {region} - {stream} - " - f"{ch} - {lead_time_str}." - ) method_tag = psd_datasets[0].get("psd_method", "sht") name = create_filename( prefix=[metric, method_tag, region], diff --git a/packages/evaluate/src/weathergen/evaluate/scores/psd.py b/packages/evaluate/src/weathergen/evaluate/scores/psd.py index 7887632af..53eb3c046 100644 --- a/packages/evaluate/src/weathergen/evaluate/scores/psd.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/psd.py @@ -13,10 +13,14 @@ - **Path A – SHT-based PSD** (``method="sht"``): Spherical Harmonic Transform on unstructured grids (octahedral, reduced - Gaussian, regular lat-lon). Ported from ``spectral_transforms.py`` to pure - numpy using Legendre helpers from ``spectral_helpers.py``. - -- **Path B – Zonal FFT PSD** (``method="zonal"``): + Gaussian, regular lat-lon). Ported from anemoi.models ``spectral_transforms.py`` to pure + numpy using Legendre helpers from anemoi.models ``spectral_helpers.py``. + [anemoi.models.spectral_transforms] + https://github.com/ecmwf/anemoi-core/blob/main/models/src/anemoi/models/layers/spectral_transforms.py + [anemoi.models.spectral_helpers] + https://github.com/ecmwf/anemoi-core/blob/main/models/src/anemoi/models/layers/spectral_helpers.py + +- **Path B – FFT PSD** (``method="fft"``): 1-D zonal FFT along the longitude dimension on a regular lat-lon grid. Absorbs the functions previously in ``example_extras/power_spectra/psd_calc.py``. """ @@ -25,6 +29,7 @@ import logging import numpy as np +from scipy.interpolate import griddata _logger = logging.getLogger(__name__) @@ -108,7 +113,7 @@ def __init__(self, lons_per_lat: list[int], truncation: int) -> None: # Pre-multiply by quadrature weights → shape (m, l, lat) self.weight = np.einsum("mlk,k->mlk", pct, weight) - # -- internal FFT helpers ----------------------------------------------- + # internal FFT helpers def _rfft_regular(self, x: np.ndarray) -> np.ndarray: """Batched real FFT for a *regular* grid. @@ -145,7 +150,7 @@ def _rfft_reduced(self, x: np.ndarray) -> np.ndarray: ) return out - # -- transform --------------------------------------------------- + # transform def transform(self, x: np.ndarray) -> np.ndarray: """Compute the SHT. @@ -271,89 +276,80 @@ def sht_psd( # --------------------------------------------------------------------------- -# Zonal FFT PSD (absorbed from psd_calc.py) +# FFT PSD (absorbed from psd_calc.py) # --------------------------------------------------------------------------- -class ZonalPSD: - """Zonal power spectral density via 1-D FFT along the longitude dimension. +def _fft_psd_calc(ht: np.ndarray) -> np.ndarray: + """Return the PSD for positive non-zero frequencies of an even-length signal. - This class absorbs the functionality previously in - ``example_extras/power_spectra/psd_calc.py``. - """ + Assumes *ht* has an even number of points. - @staticmethod - def psd_1d(ht: np.ndarray) -> np.ndarray: - """Return the PSD for positive non-zero frequencies of an even-length signal. + Parameters + ---------- + ht : np.ndarray + 1-D real-valued signal (one latitude ring). - Parameters - ---------- - ht : np.ndarray - 1-D real-valued signal (one latitude ring). + Returns + ------- + np.ndarray + PSD for positive frequencies, length ``n // 2``. + """ + n = len(ht) + hf = np.fft.rfft(ht, norm="forward") + power = np.abs(hf[1 : round(n / 2 + 1)]) ** 2 + power *= 2.0 # compensate for positive frequencies only + return power - Returns - ------- - np.ndarray - PSD for positive frequencies, length ``n // 2``. - """ - n = len(ht) - hf = np.fft.rfft(ht, norm="forward") - power = np.abs(hf[1 : round(n / 2 + 1)]) ** 2 - power *= 2.0 # compensate for positive frequencies only - return power - @staticmethod - def positive_frequencies(npoints: int, spacing_deg: float = 1.0) -> np.ndarray: - """Return the positive frequencies for a signal of *npoints* evenly spaced points. +def _cubepsd(field_2d: np.ndarray) -> np.ndarray: + """Compute PSD averaged over all latitude rows. - Parameters - ---------- - npoints : int - Number of equally-spaced longitude points. - spacing_deg : float - Grid spacing in degrees. Default is ``360 / npoints``. + Parameters + ---------- + field_2d : np.ndarray + 2-D array of shape ``(nlat, nlon)``. - Returns - ------- - np.ndarray - Positive frequencies, length ``npoints // 2``. - """ - freq = np.fft.fftfreq(npoints, d=spacing_deg) - return np.abs(freq[1 : round(npoints / 2 + 1)]) + Returns + ------- + np.ndarray + PSD of shape ``(nlon // 2,)``. + """ + nlat, nlon = field_2d.shape + field_psd = np.zeros(nlon // 2) + for row in field_2d: + field_psd += _fft_psd_calc(row) + field_psd /= nlat + return field_psd - @classmethod - def compute( - cls, - field_2d: np.ndarray, - ) -> np.ndarray: - """Compute the zonal PSD averaged over all latitude rows. - Parameters - ---------- - field_2d : np.ndarray - 2-D array of shape ``(nlat, nlon)``. +def _calcposfreq(npoints: int, spacing_deg: float = 1.0) -> np.ndarray: + """Return the positive frequencies for a signal of *npoints* evenly spaced points. - Returns - ------- - np.ndarray - PSD of shape ``(nlon // 2,)``. - """ - nlat, nlon = field_2d.shape - psd_accum = np.zeros(nlon // 2) - for row in field_2d: - psd_accum += cls.psd_1d(row) - psd_accum /= nlat - return psd_accum + Parameters + ---------- + npoints : int + Number of equally-spaced longitude points. + spacing_deg : float + Grid spacing in degrees. + + Returns + ------- + np.ndarray + Positive frequencies, length ``npoints // 2``. + """ + freq = np.fft.fftfreq(npoints, d=spacing_deg) + return np.abs(freq[1 : round(npoints / 2 + 1)]) -def zonal_psd( +def fft_psd( data: np.ndarray, lats: np.ndarray, lons: np.ndarray, lat_range: tuple[float, float] = (-60.0, 60.0), regrid_resolution: float = 1.0, ) -> tuple[np.ndarray, np.ndarray]: - """Compute zonal PSD using 1-D FFT along the longitude dimension. + """Compute PSD using 1-D FFT along the longitude dimension. For unstructured grids (where lats/lons are per-point coordinates rather than regular axis arrays), the data is first regridded to a regular lat-lon @@ -380,7 +376,6 @@ def zonal_psd( Power spectral density averaged over samples and latitude rows, shape ``(nfreq,)``. """ - from scipy.interpolate import griddata # Ensure 2-D: (n_samples, n_points) if data.ndim == 1: @@ -424,13 +419,12 @@ def zonal_psd( # Compute PSD per sample and average psds = [] for s in range(data_3d.shape[0]): - psds.append(ZonalPSD.compute(data_3d[s])) - psd = np.mean(psds, axis=0) + psds.append(_cubepsd(data_3d[s])) + psd_result = np.mean(psds, axis=0) spacing = 360.0 / nlon_sub if nlon_sub > 0 else regrid_resolution - frequencies = ZonalPSD.positive_frequencies(nlon_sub, spacing_deg=spacing) - - return frequencies, psd + frequencies = _calcposfreq(nlon_sub, spacing_deg=spacing) + return frequencies, psd_result # --------------------------------------------------------------------------- @@ -454,17 +448,17 @@ def compute_psd_for_field( Parameters ---------- data : np.ndarray - Spatial field. Shape depends on the method (see ``sht_psd`` / ``zonal_psd``). + Spatial field. Shape depends on the method (see ``sht_psd`` / ``fft_psd``). method : str - ``"sht"`` for SHT-based PSD, ``"zonal"`` for zonal FFT PSD. + ``"sht"`` for SHT-based PSD, ``"fft"`` for FFT PSD. nlat : int | None Number of latitudes (required for SHT method). lats, lons : np.ndarray | None - Latitude / longitude coordinate arrays (required for zonal method). + Latitude / longitude coordinate arrays (required for fft method). lat_range : tuple[float, float] - Latitude bounds for the zonal method. + Latitude bounds for the fft method. regrid_resolution : float - Grid spacing in degrees for the zonal method. + Grid spacing in degrees for the fft method. sht_truncation : int | None Spectral truncation for SHT. grid_type : str @@ -473,7 +467,7 @@ def compute_psd_for_field( Returns ------- x_values : np.ndarray - Wavenumbers (SHT) or positive frequencies (zonal). + Wavenumbers (SHT) or positive frequencies (fft). psd : np.ndarray Power spectral density. """ @@ -486,10 +480,10 @@ def compute_psd_for_field( truncation=sht_truncation, grid_type=grid_type, ) - elif method == "zonal": + elif method == "fft": if lats is None or lons is None: - raise ValueError("lats and lons are required for method='zonal'") - return zonal_psd( + raise ValueError("lats and lons are required for method='fft'") + return fft_psd( data=data, lats=lats, lons=lons, @@ -497,7 +491,7 @@ def compute_psd_for_field( regrid_resolution=regrid_resolution, ) else: - raise ValueError(f"Unknown PSD method: {method!r}. Use 'sht' or 'zonal'.") + raise ValueError(f"Unknown PSD method: {method!r}. Use 'sht' or 'fft'.") def compute_psd_score( @@ -530,13 +524,13 @@ def compute_psd_score( n_points : int Original number of spatial points (before NaN masking). psd_method : str - ``"sht"`` or ``"zonal"``. + ``"sht"`` or ``"fft"``. psd_regrid_resolution : float - Grid spacing for zonal method. + Grid spacing for fft method. psd_sht_truncation : int | None Spectral truncation for SHT. lat_range : tuple[float, float] - Latitude bounds for zonal method. + Latitude bounds for fft method. Returns ------- @@ -568,8 +562,7 @@ def compute_psd_score( expected_pts = None # cannot validate actual_pts = gt.shape[-1] if expected_pts is not None and actual_pts != expected_pts: - import logging - logging.getLogger(__name__).warning( + _logger.warning( f"PSD (SHT): grid point mismatch ({actual_pts} vs expected {expected_pts} " f"for grid_type={grid_type!r}, nlat={nlat_valid}). SHT scores are only " f"available for the full (global/unmasked) grid. Skipping this region." @@ -588,8 +581,7 @@ def compute_psd_score( sht_truncation=psd_sht_truncation, grid_type=grid_type, ) except Exception: - import logging - logging.getLogger(__name__).exception("PSD computation failed, returning NaN.") + _logger.exception("PSD computation failed, returning NaN.") return np.nan, {} # Scalar summary: mean squared error of log10 PSD diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index 3580664de..0e4c40163 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -1823,156 +1823,164 @@ def calc_psd( """Compute power spectral density for prediction and ground truth. Returns a scalar summary score (log-spectral MSE) and stores the full - PSD curves in ``.attrs`` so that ``psd_plot_metric_region`` in the - summary-plot phase can produce spectral plots — following the same - pattern as ``calc_quantiles`` / ``qq_analysis``. + PSD curves in ``.attrs`` for plotting downstream. Parameters ---------- - p : xr.DataArray - Prediction data. - gt : xr.DataArray - Ground truth data. - psd_method : str - ``"sht"`` (default) or ``"zonal"``. - psd_regrid_resolution : float - Grid spacing for the zonal method (degrees). - psd_sht_truncation : int | None - Spectral truncation for the SHT method. - lat_range : tuple[float, float] - Latitude range for the zonal method. + p: xr.DataArray + Forecast data array + gt: xr.DataArray + Ground truth data array + psd_method: str + Method to compute the PSD. Options: 'sht' (spherical harmonic transform), + 'fft' (2D Fourier transform) + psd_regrid_resolution: float + Resolution in degrees to regrid data for PSD calculation. Default is 1.0 degree + psd_sht_truncation: int | None + Maximum spherical harmonic degree for truncation. If None, no truncation is applied. + lat_range: tuple[float, float] + Latitude range (min, max) to include in PSD calculation. Default is (-60, + 60) degrees. + grid_type: str + Type of grid for PSD calculation. Options: 'octahedral', 'regular'. Default is 'octahedral'. Returns ------- xr.DataArray - Scalar score (log-spectral MSE) with PSD data stored in ``.attrs``. - """ + Power spectral density score (log-spectral MSE) averaged over aggregation dimensions. + """ if self._agg_dims is None: raise ValueError("Cannot calculate PSD without aggregation dimensions.") - - # PSD expects exactly one spatial aggregation dimension (e.g. "ipoint"). if len(self._agg_dims) != 1: raise ValueError( - f"PSD expects exactly one spatial aggregation dimension (points), " - f"but got agg_dims={self._agg_dims}. Do not flatten over multiple dims." + f"PSD expects exactly one spatial aggregation dimension, " + f"got agg_dims={self._agg_dims}." ) spatial_dim = self._agg_dims[0] if spatial_dim not in gt.dims: raise ValueError( - f"Spatial dimension '{spatial_dim}' not found in data dims {list(gt.dims)}." + f"Spatial dimension '{spatial_dim}' not found in dims {list(gt.dims)}." ) - # The data arriving here has dims like (sample, channel, ipoint). - # PSD must be computed per channel (and averaged over samples). - # We iterate over all non-spatial dims except the "sample" axis, - # which becomes the sample dimension for the SHT. - other_dims = [d for d in gt.dims if d != spatial_dim] - _logger.debug( - f"PSD: spatial_dim={spatial_dim}, other_dims={other_dims}, " - f"shape p={dict(zip(p.dims, p.shape))}, gt={dict(zip(gt.dims, gt.shape))}" - ) - - # Determine grid info from coords (same for all channels/samples) n_points = gt.sizes[spatial_dim] - nlat = None - lats, lons = None, None - if "lat" in gt.coords and "lon" in gt.coords: - if gt.coords["lat"].dims == (spatial_dim,) and gt.coords["lon"].dims == (spatial_dim,): - lats = gt.coords["lat"].values - lons = gt.coords["lon"].values - nlat = len(np.unique(lats)) - else: - _logger.warning( - f"PSD: lat/lon coords exist but have unexpected dims " - f"(lat: {gt.coords['lat'].dims}, lon: {gt.coords['lon'].dims}). " - f"Expected ({spatial_dim},)." - ) + nlat, lats, lons = self._get_psd_grid_info(gt, spatial_dim) - if nlat is None and lats is None: - raise ValueError( - "PSD requires grid information. Either provide lat/lon coords on the " - f"spatial dimension '{spatial_dim}', or set 'nlat' in the DataArray attrs." - ) - if psd_method == "zonal" and (lats is None or lons is None): + if psd_method == "fft" and (lats is None or lons is None): raise ValueError( - "PSD method 'zonal' requires lat/lon coordinate arrays on the spatial " - f"dimension '{spatial_dim}', but they are not available." + f"PSD method 'fft' requires lat/lon coords on '{spatial_dim}'." ) - # Identify batch dimensions (to average PSD over) vs dims to preserve. - # Convention: "sample" and "ens" dims are averaged over (flattened into - # the batch axis); all other non-spatial dims (e.g. "channel") are preserved. - batch_dims = [d for d in other_dims if d in ("sample", "ens")] + psd_kwargs = dict( + lats=lats, lons=lons, nlat=nlat, n_points=n_points, + psd_method=psd_method, psd_regrid_resolution=psd_regrid_resolution, + psd_sht_truncation=psd_sht_truncation, lat_range=lat_range, + grid_type=grid_type, + ) + + # Dims to preserve (e.g. channel) vs batch dims (sample, ens) + other_dims = [d for d in gt.dims if d != spatial_dim] preserve_dims = [d for d in other_dims if d not in ("sample", "ens")] - # Compute PSD per slice of the preserved dims (typically per channel). - # For each slice, samples form the batch axis for the SHT. - if preserve_dims: - # Iterate over the cartesian product of preserved dims - score_values = {} - attrs_per_slice = {} - for idx in np.ndindex(*[gt.sizes[d] for d in preserve_dims]): - sel = dict(zip(preserve_dims, idx)) - gt_slice = gt.isel(**sel) # dims: (sample, ipoint) or (ipoint,) - p_slice = p.isel(**sel) - - # Ensure 2-D: (n_batch, n_points) — flatten batch dims into one axis - non_spatial = [d for d in gt_slice.dims if d != spatial_dim] - if non_spatial: - gt_np = gt_slice.transpose(*non_spatial, spatial_dim).values.reshape(-1, n_points) - p_np = p_slice.transpose(*[d for d in p_slice.dims if d != spatial_dim], spatial_dim).values.reshape(-1, n_points) - else: - gt_np = gt_slice.values.reshape(1, -1) - p_np = p_slice.values.reshape(1, -1) - - slice_score, slice_attrs = compute_psd_score( - gt=gt_np, p=p_np, lats=lats, lons=lons, nlat=nlat, - n_points=n_points, psd_method=psd_method, - psd_regrid_resolution=psd_regrid_resolution, - psd_sht_truncation=psd_sht_truncation, lat_range=lat_range, - grid_type=grid_type, - ) - score_values[idx] = slice_score - attrs_per_slice[idx] = slice_attrs - - # Build output DataArray with preserved dims - shape = tuple(gt.sizes[d] for d in preserve_dims) - score_np = np.array([score_values[idx] for idx in np.ndindex(*shape)]).reshape(shape) - coords = {d: gt.coords[d] for d in preserve_dims if d in gt.coords} - score = xr.DataArray(score_np, dims=preserve_dims, coords=coords) - - # Store per-channel PSD curves in attrs (keyed by channel coord value) - all_attrs = {} - for idx in np.ndindex(*shape): - sel = dict(zip(preserve_dims, idx)) - key = "_".join(str(gt.coords[d].values[i]) if d in gt.coords else str(i) - for d, i in sel.items()) - for k, v in attrs_per_slice[idx].items(): - all_attrs[f"{key}/{k}"] = v - all_attrs["psd_method"] = psd_method - all_attrs["preserve_dims"] = preserve_dims - score.attrs.update(all_attrs) - else: - # No preserved dims — compute single PSD over all data - non_spatial = [d for d in gt.dims if d != spatial_dim] - if non_spatial: - gt_np = gt.transpose(*non_spatial, spatial_dim).values.reshape(-1, n_points) - p_np = p.transpose(*[d for d in p.dims if d != spatial_dim], spatial_dim).values.reshape(-1, n_points) - else: - gt_np = gt.values.reshape(1, -1) - p_np = p.values.reshape(1, -1) - - slice_score, slice_attrs = compute_psd_score( - gt=gt_np, p=p_np, lats=lats, lons=lons, nlat=nlat, - n_points=n_points, psd_method=psd_method, - psd_regrid_resolution=psd_regrid_resolution, - psd_sht_truncation=psd_sht_truncation, lat_range=lat_range, - grid_type=grid_type, - ) + if not preserve_dims: + gt_np, p_np = self._stack_for_psd(gt, p, spatial_dim, n_points) + slice_score, slice_attrs = compute_psd_score(gt=gt_np, p=p_np, **psd_kwargs) score = xr.DataArray(slice_score) score.attrs.update(slice_attrs) score.attrs["psd_method"] = psd_method + return score + + # Iterate over preserved dims (typically per channel) + shape = tuple(gt.sizes[d] for d in preserve_dims) + score_values = np.empty(shape) + all_attrs: dict = {} + + for idx in np.ndindex(*shape): + sel = dict(zip(preserve_dims, idx)) + gt_slice = gt.isel(**sel) + p_slice = p.isel(**sel) + gt_np, p_np = self._stack_for_psd(gt_slice, p_slice, spatial_dim, n_points) - _logger.info(f"PSD analysis completed (score shape={score.shape})") + slice_score, slice_attrs = compute_psd_score(gt=gt_np, p=p_np, **psd_kwargs) + score_values[idx] = slice_score + + key = "_".join( + str(gt.coords[d].values[i]) if d in gt.coords else str(i) + for d, i in sel.items() + ) + for k, v in slice_attrs.items(): + all_attrs[f"{key}/{k}"] = v + + coords = {d: gt.coords[d] for d in preserve_dims if d in gt.coords} + score = xr.DataArray(score_values, dims=preserve_dims, coords=coords) + all_attrs["psd_method"] = psd_method + all_attrs["preserve_dims"] = preserve_dims + score.attrs.update(all_attrs) return score + + @staticmethod + def _get_psd_grid_info( + gt: xr.DataArray, spatial_dim: str + ) -> tuple[int | None, np.ndarray | None, np.ndarray | None]: + """ + Extract nlat, lats, lons from ground-truth coords. + + Parameters + ---------- + gt: xr.DataArray + Ground truth data array with lat/lon coordinates. + spatial_dim: str + Name of the spatial dimension along which to compute the PSD. + Returns + ------- + nlat: int | None + Number of latitude points, or None if lat/lon coords are not found. + lats: np.ndarray | None + Latitude values, or None if lat/lon coords are not found. + lons: np.ndarray | None + Longitude values, or None if lat/lon coords are not found. + + """ + if "lat" in gt.coords and "lon" in gt.coords: + if gt.coords["lat"].dims == (spatial_dim,) and gt.coords["lon"].dims == (spatial_dim,): + lats = gt.coords["lat"].values + lons = gt.coords["lon"].values + return len(np.unique(lats)), lats, lons + raise ValueError( + f"PSD requires lat/lon coords on spatial dimension '{spatial_dim}'." + ) + + @staticmethod + def _stack_for_psd( + gt: xr.DataArray, p: xr.DataArray, spatial_dim: str, n_points: int + ) -> tuple[np.ndarray, np.ndarray]: + """ + Reshape data to (n_batch, n_points) for PSD computation. + + Parameters + ---------- + gt: xr.DataArray + Ground truth data array. + p: xr.DataArray + Forecast data array. + spatial_dim: str + Name of the spatial dimension along which to compute the PSD. + n_points: int + Number of points along the spatial dimension. + Returns + ------- + gt_np: np.ndarray + Reshaped ground truth data of shape (n_batch, n_points). + p_np: np.ndarray + Reshaped forecast data of shape (n_batch, n_points). + """ + non_spatial = [d for d in gt.dims if d != spatial_dim] + if non_spatial: + gt_np = gt.transpose(*non_spatial, spatial_dim).values.reshape(-1, n_points) + p_np = p.transpose( + *[d for d in p.dims if d != spatial_dim], spatial_dim + ).values.reshape(-1, n_points) + else: + gt_np = gt.values.reshape(1, -1) + p_np = p.values.reshape(1, -1) + return gt_np, p_np From fbca45a96bef135faff8fc6d203d93ec7f0ec941 Mon Sep 17 00:00:00 2001 From: iluise Date: Thu, 21 May 2026 17:36:50 +0200 Subject: [PATCH 6/8] implement inverse transform and tests --- .../src/weathergen/evaluate/scores/psd.py | 91 ++++++++ tests/test_psd.py | 212 ++++++++++++++++++ tests/test_sht_roundtrip.py | 95 ++++++++ 3 files changed, 398 insertions(+) create mode 100644 tests/test_psd.py create mode 100644 tests/test_sht_roundtrip.py diff --git a/packages/evaluate/src/weathergen/evaluate/scores/psd.py b/packages/evaluate/src/weathergen/evaluate/scores/psd.py index 53eb3c046..5504df92f 100644 --- a/packages/evaluate/src/weathergen/evaluate/scores/psd.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/psd.py @@ -180,6 +180,97 @@ def transform(self, x: np.ndarray) -> np.ndarray: return rl + 1j * im +class InverseSphericalHarmonicTransform: + """Inverse Spherical Harmonic Transform in pure numpy. + + Reconstructs a spatial field from spectral coefficients (l, m). + Mirrors the ``InverseSphericalHarmonicTransform`` from ``spectral_helpers.py`` + in anemoi.models but operates on numpy arrays. + + Parameters + ---------- + lons_per_lat : list[int] + Number of longitude points on each latitude ring (pole to pole). + truncation : int + Maximum total wavenumber. + """ + + def __init__(self, lons_per_lat: list[int], truncation: int) -> None: + self.lons_per_lat = lons_per_lat + self.nlat = len(lons_per_lat) + self.truncation = truncation + self.n_grid_points = sum(lons_per_lat) + self._is_regular = len(set(lons_per_lat)) == 1 + + # Gaussian latitudes (no quadrature weights needed for inverse) + theta, _ = _legendre_gauss_weights(self.nlat) + theta = np.flip(np.arccos(theta)) + + # Associated Legendre polynomials with inverse=True + self.pct = _legpoly(truncation, truncation, np.cos(theta), inverse=True) + + def _irfft_regular(self, x: np.ndarray) -> np.ndarray: + """Inverse FFT for a regular grid. + + Parameters + ---------- + x : np.ndarray, complex, shape ``(..., nlat, M)`` + + Returns + ------- + np.ndarray, real, shape ``(..., grid)`` + """ + nlon = self.lons_per_lat[0] + spatial = np.fft.irfft(x, n=nlon, norm="forward") # (..., nlat, nlon) + return spatial.reshape(*spatial.shape[:-2], self.n_grid_points) + + def _irfft_reduced(self, x: np.ndarray) -> np.ndarray: + """Per-ring inverse FFT for a reduced grid. + + Parameters + ---------- + x : np.ndarray, complex, shape ``(..., nlat, M)`` + + Returns + ------- + np.ndarray, real, shape ``(..., grid)`` + """ + lead_shape = x.shape[:-2] + out = np.zeros((*lead_shape, self.n_grid_points), dtype=np.float64) + offset = 0 + for i, nlon in enumerate(self.lons_per_lat): + ring = np.fft.irfft(x[..., i, :], n=nlon, norm="forward") + out[..., offset : offset + nlon] = ring + offset += nlon + return out + + def transform(self, coeffs: np.ndarray) -> np.ndarray: + """Compute the inverse SHT. + + Parameters + ---------- + coeffs : np.ndarray, complex, shape ``(..., L, M)`` + + Returns + ------- + np.ndarray, real, shape ``(..., grid)`` + """ + # Inverse Legendre transform: (..., l, m) × (m, l, k) → (..., k, m) + real_part = coeffs.real + imag_part = coeffs.imag + + rl = np.einsum("...lm,mlk->...km", real_part, self.pct) + im = np.einsum("...lm,mlk->...km", imag_part, self.pct) + + x_fourier = rl + 1j * im # (..., nlat, M) + + # Inverse FFT per ring + if self._is_regular: + return self._irfft_regular(x_fourier) + else: + return self._irfft_reduced(x_fourier) + + # --------------------------------------------------------------------------- # Grid helpers for building lons_per_lat # --------------------------------------------------------------------------- diff --git a/tests/test_psd.py b/tests/test_psd.py new file mode 100644 index 000000000..0b853c457 --- /dev/null +++ b/tests/test_psd.py @@ -0,0 +1,212 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + +"""Tests for weathergen.evaluate.scores.psd.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from weathergen.evaluate.scores.psd import ( + SphericalHarmonicTransform, + ZonalPSD, + compute_psd_for_field, + sht_psd, + zonal_psd, +) + + +# --------------------------------------------------------------------------- +# ZonalPSD (absorbs psd_calc.py) +# --------------------------------------------------------------------------- + + +class TestZonalPSD: + """Unit tests for the ZonalPSD class.""" + + def test_psd_1d_even(self): + """PSD of a pure sine should peak at the correct frequency bin.""" + n = 360 + freq_idx = 5 # wave number 5 + x = np.sin(2 * np.pi * freq_idx * np.arange(n) / n) + power = ZonalPSD.psd_1d(x) + assert power.shape == (n // 2,) + # Peak should be at index freq_idx - 1 (psd starts at freq index 1) + assert np.argmax(power) == freq_idx - 1 + + def test_positive_frequencies(self): + """Positive frequencies length and positivity.""" + npoints = 360 + freq = ZonalPSD.positive_frequencies(npoints, spacing_deg=1.0) + assert len(freq) == npoints // 2 + assert np.all(freq > 0) + + def test_compute_2d(self): + """ZonalPSD.compute averages PSD over latitude rows.""" + nlat, nlon = 10, 360 + field = np.random.default_rng(42).standard_normal((nlat, nlon)) + psd = ZonalPSD.compute(field) + assert psd.shape == (nlon // 2,) + assert np.all(psd >= 0) + + +# --------------------------------------------------------------------------- +# zonal_psd wrapper +# --------------------------------------------------------------------------- + + +class TestZonalPsdWrapper: + """Tests for the zonal_psd() dispatch-level wrapper.""" + + def test_basic_shape(self): + """Output shape matches expected number of positive frequencies.""" + nlat, nlon = 90, 360 + lats = np.linspace(90, -90, nlat) + lons = np.linspace(0, 359, nlon) + data = np.random.default_rng(0).standard_normal((nlat, nlon)) + + freq, psd = zonal_psd(data, lats, lons, lat_range=(-60, 60)) + assert freq.ndim == 1 + assert psd.ndim == 1 + assert len(freq) == len(psd) + assert np.all(psd >= 0) + + def test_multi_sample(self): + """Multi-sample input is averaged correctly.""" + nlat, nlon = 45, 180 + lats = np.linspace(90, -90, nlat) + lons = np.linspace(0, 358, nlon) + rng = np.random.default_rng(1) + data = rng.standard_normal((3, nlat, nlon)) # 3 samples + + freq, psd = zonal_psd(data, lats, lons) + assert freq.shape == psd.shape + + +# --------------------------------------------------------------------------- +# SphericalHarmonicTransform +# --------------------------------------------------------------------------- + + +class TestSphericalHarmonicTransform: + """Tests for the pure-numpy SHT transform.""" + + def test_regular_grid_shape(self): + """SHT on a regular grid produces the correct output shape.""" + nlat = 32 + nlon = 64 + trunc = 15 + sht = SphericalHarmonicTransform(lons_per_lat=[nlon] * nlat, truncation=trunc) + + x = np.random.default_rng(7).standard_normal(sht.n_grid_points) + coeffs = sht.transform(x) + assert coeffs.shape == (trunc + 1, trunc + 1) + assert np.iscomplexobj(coeffs) + + def test_octahedral_grid_shape(self): + """SHT on an octahedral grid produces the correct output shape.""" + nlat = 64 + trunc = 31 + lons = [20 + 4 * i for i in range(nlat // 2)] + lons = lons + list(reversed(lons)) + sht = SphericalHarmonicTransform(lons_per_lat=lons, truncation=trunc) + + x = np.random.default_rng(8).standard_normal(sht.n_grid_points) + coeffs = sht.transform(x) + assert coeffs.shape == (trunc + 1, trunc + 1) + + def test_constant_field(self): + """A constant field should have energy only at wavenumber 0.""" + nlat = 32 + nlon = 64 + trunc = 15 + sht = SphericalHarmonicTransform(lons_per_lat=[nlon] * nlat, truncation=trunc) + + x = np.ones(sht.n_grid_points) * 42.0 + coeffs = sht.transform(x) + # l=0, m=0 should dominate + assert np.abs(coeffs[0, 0]) > 0 + # All other coefficients should be negligible + mask = np.ones_like(coeffs, dtype=bool) + mask[0, 0] = False + assert np.allclose(coeffs[mask], 0, atol=1e-10) + + +# --------------------------------------------------------------------------- +# sht_psd +# --------------------------------------------------------------------------- + + +class TestShtPsd: + """Tests for the sht_psd high-level function.""" + + def test_output_shape(self): + """sht_psd returns wavenumbers and PSD with matching shapes.""" + nlat = 64 + lons = [20 + 4 * i for i in range(nlat // 2)] + lons = lons + list(reversed(lons)) + n_points = sum(lons) + data = np.random.default_rng(9).standard_normal(n_points) + + wn, psd = sht_psd(data, nlat=nlat, grid_type="octahedral") + assert wn.shape == psd.shape + assert len(wn) == nlat // 2 # truncation default = nlat // 2 - 1 → L = nlat // 2 + assert np.all(psd >= 0) + + def test_multi_sample(self): + """Multi-sample input is averaged.""" + nlat = 32 + nlon = 64 + n_points = nlat * nlon + data = np.random.default_rng(10).standard_normal((4, n_points)) + + wn, psd = sht_psd(data, nlat=nlat, grid_type="regular") + assert wn.shape == psd.shape + + +# --------------------------------------------------------------------------- +# compute_psd_for_field dispatch +# --------------------------------------------------------------------------- + + +class TestComputePsdForField: + """Tests for the dispatch function.""" + + def test_sht_dispatch(self): + """method='sht' calls sht_psd correctly.""" + nlat = 32 + nlon = 64 + data = np.random.default_rng(11).standard_normal(nlat * nlon) + wn, psd = compute_psd_for_field(data, method="sht", nlat=nlat, grid_type="regular") + assert len(wn) == len(psd) + assert np.all(psd >= 0) + + def test_zonal_dispatch(self): + """method='zonal' calls zonal_psd correctly.""" + nlat, nlon = 90, 360 + lats = np.linspace(90, -90, nlat) + lons = np.linspace(0, 359, nlon) + data = np.random.default_rng(12).standard_normal((nlat, nlon)) + + freq, psd = compute_psd_for_field( + data, method="zonal", lats=lats, lons=lons + ) + assert len(freq) == len(psd) + + def test_invalid_method(self): + """Unknown method raises ValueError.""" + with pytest.raises(ValueError, match="Unknown PSD method"): + compute_psd_for_field(np.zeros(10), method="invalid") + + def test_missing_nlat_for_sht(self): + """method='sht' without nlat raises ValueError.""" + with pytest.raises(ValueError, match="nlat is required"): + compute_psd_for_field(np.zeros(10), method="sht") + + def test_missing_lats_for_zonal(self): + """method='zonal' without lats raises ValueError.""" + with pytest.raises(ValueError, match="lats and lons are required"): + compute_psd_for_field(np.zeros(10), method="zonal") diff --git a/tests/test_sht_roundtrip.py b/tests/test_sht_roundtrip.py new file mode 100644 index 000000000..3fae345dc --- /dev/null +++ b/tests/test_sht_roundtrip.py @@ -0,0 +1,95 @@ +"""Test that SHT forward → inverse is (approximately) the identity.""" + +import numpy as np +import pytest + +from weathergen.evaluate.scores.psd import ( + InverseSphericalHarmonicTransform, + SphericalHarmonicTransform, + _octahedral_lons_per_lat, + _regular_lons_per_lat, +) + + +@pytest.mark.parametrize("grid_type,nlat", [ + ("regular", 32), + ("regular", 64), + ("octahedral", 32), + ("octahedral", 64), +]) +def test_sht_roundtrip_identity(grid_type: str, nlat: int) -> None: + """Applying SHT then inverse SHT on random noise recovers the original field.""" + rng = np.random.default_rng(42) + + if grid_type == "regular": + lons_per_lat = _regular_lons_per_lat(nlat) + else: + lons_per_lat = _octahedral_lons_per_lat(nlat) + + n_grid_points = sum(lons_per_lat) + truncation = nlat // 2 - 1 + + sht = SphericalHarmonicTransform(lons_per_lat=lons_per_lat, truncation=truncation) + isht = InverseSphericalHarmonicTransform(lons_per_lat=lons_per_lat, truncation=truncation) + + # Random spatial field + x = rng.standard_normal(n_grid_points) + + # Forward → inverse + coeffs = sht.transform(x) + x_reconstructed = isht.transform(coeffs) + + # The reconstruction is approximate due to truncation, but should be close + # for smooth-enough fields. For a bandlimited signal it should be exact. + # Use a generous tolerance since truncation discards high-frequency content. + assert x_reconstructed.shape == x.shape, ( + f"Shape mismatch: {x_reconstructed.shape} vs {x.shape}" + ) + + # Check correlation is positive — truncation discards high-frequency content + # so white noise won't be perfectly recovered, but the low-frequency part should match. + corr = np.corrcoef(x.ravel(), x_reconstructed.ravel())[0, 1] + assert corr > 0.30, f"Correlation too low: {corr:.4f}" + + # More importantly: verify the energy is preserved for the retained modes + # by checking that the relative L2 error is bounded + rel_error = np.linalg.norm(x - x_reconstructed) / np.linalg.norm(x) + assert rel_error < 1.0, f"Relative L2 error too large: {rel_error:.4f}" + + +@pytest.mark.parametrize("grid_type,nlat", [ + ("regular", 32), + ("regular", 64), + ("octahedral", 32), + ("octahedral", 64), +]) +def test_sht_roundtrip_bandlimited(grid_type: str, nlat: int) -> None: + """For a bandlimited signal, SHT → inverse SHT should be near-exact.""" + if grid_type == "regular": + lons_per_lat = _regular_lons_per_lat(nlat) + else: + lons_per_lat = _octahedral_lons_per_lat(nlat) + + n_grid_points = sum(lons_per_lat) + truncation = nlat // 2 - 1 + + sht = SphericalHarmonicTransform(lons_per_lat=lons_per_lat, truncation=truncation) + isht = InverseSphericalHarmonicTransform(lons_per_lat=lons_per_lat, truncation=truncation) + + # Create a bandlimited signal by doing inverse SHT on random coefficients + rng = np.random.default_rng(123) + L = truncation + 1 + random_coeffs = rng.standard_normal((L, L)) + 1j * rng.standard_normal((L, L)) + # Make it physically meaningful: zero out upper triangle (m > l) + for l in range(L): + random_coeffs[l, l + 1:] = 0.0 + + # Inverse → forward → inverse should give back the same spatial field + x_bandlimited = isht.transform(random_coeffs) + coeffs_recovered = sht.transform(x_bandlimited) + x_roundtrip = isht.transform(coeffs_recovered) + + np.testing.assert_allclose( + x_roundtrip, x_bandlimited, rtol=1e-6, atol=1e-10, + err_msg="Roundtrip on bandlimited signal should be near-exact", + ) From d9f079d7932036dfe6c11efd511e33d1a928889e Mon Sep 17 00:00:00 2001 From: iluise Date: Thu, 21 May 2026 18:00:16 +0200 Subject: [PATCH 7/8] lint --- .../evaluate/plotting/line_plots.py | 8 +- .../evaluate/plotting/plot_orchestration.py | 1 - .../src/weathergen/evaluate/scores/psd.py | 164 ++++++++++-------- .../src/weathergen/evaluate/scores/score.py | 47 ++--- 4 files changed, 118 insertions(+), 102 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py b/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py index 422c59a70..70c2f0448 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py @@ -702,7 +702,9 @@ def psd_plot( tar_psd = np.asarray(psd_datasets[0]["psd_target"]) fig, (ax_spec, ax_ratio) = plt.subplots( - 2, 1, figsize=self.fig_size or (10, 8), + 2, + 1, + figsize=self.fig_size or (10, 8), gridspec_kw={"height_ratios": [2, 1], "hspace": 0.08}, ) @@ -714,7 +716,9 @@ def psd_plot( ax_spec.loglog( np.asarray(ds["frequencies"]), np.asarray(ds["psd_prediction"]), - color=c, lw=1.5, label=label, + color=c, + lw=1.5, + label=label, ) ax_spec.set_ylabel("Power") psd_method = psd_datasets[0].get("psd_method", "sht") diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index 194b21367..6f6d7bc1b 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -804,4 +804,3 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): score_card_metric_region(metric, region, runs, scores_dict, sc_plotter) if eval_opt.get("bar_plots", False): bar_plot_metric_region(metric, region, runs, scores_dict, br_plotter) - \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/scores/psd.py b/packages/evaluate/src/weathergen/evaluate/scores/psd.py index 5504df92f..a063c9347 100644 --- a/packages/evaluate/src/weathergen/evaluate/scores/psd.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/psd.py @@ -15,9 +15,9 @@ Spherical Harmonic Transform on unstructured grids (octahedral, reduced Gaussian, regular lat-lon). Ported from anemoi.models ``spectral_transforms.py`` to pure numpy using Legendre helpers from anemoi.models ``spectral_helpers.py``. - [anemoi.models.spectral_transforms] + [anemoi.models.spectral_transforms] https://github.com/ecmwf/anemoi-core/blob/main/models/src/anemoi/models/layers/spectral_transforms.py - [anemoi.models.spectral_helpers] + [anemoi.models.spectral_helpers] https://github.com/ecmwf/anemoi-core/blob/main/models/src/anemoi/models/layers/spectral_helpers.py - **Path B – FFT PSD** (``method="fft"``): @@ -28,6 +28,7 @@ from __future__ import annotations import logging + import numpy as np from scipy.interpolate import griddata @@ -39,7 +40,9 @@ # --------------------------------------------------------------------------- -def _legendre_gauss_weights(n: int, a: float = -1.0, b: float = 1.0) -> tuple[np.ndarray, np.ndarray]: +def _legendre_gauss_weights( + n: int, a: float = -1.0, b: float = 1.0 +) -> tuple[np.typing.NDArray, np.typing.NDArray]: """Return Legendre-Gauss nodes and weights on ``[a, b]``.""" xlg, wlg = np.polynomial.legendre.leggauss(n) xlg = (b - a) * 0.5 * xlg + (b + a) * 0.5 @@ -47,7 +50,9 @@ def _legendre_gauss_weights(n: int, a: float = -1.0, b: float = 1.0) -> tuple[np return xlg, wlg -def _legpoly(mmax: int, lmax: int, x: np.ndarray, inverse: bool = False) -> np.ndarray: +def _legpoly( + mmax: int, lmax: int, x: np.typing.NDArray, inverse: bool = False +) -> np.typing.NDArray: """Compute associated Legendre polynomials. Returns shape ``(mmax+1, lmax+1, len(x))``. @@ -92,9 +97,7 @@ def __init__(self, lons_per_lat: list[int], truncation: int) -> None: self.lons_per_lat = lons_per_lat self.nlat = len(lons_per_lat) self.truncation = truncation - assert 0 < truncation <= self.nlat, ( - f"Truncation {truncation} must be in (0, {self.nlat}]" - ) + assert 0 < truncation <= self.nlat, f"Truncation {truncation} must be in (0, {self.nlat}]" self.n_grid_points = sum(lons_per_lat) # Offsets into the flattened grid for each latitude ring @@ -115,53 +118,51 @@ def __init__(self, lons_per_lat: list[int], truncation: int) -> None: # internal FFT helpers - def _rfft_regular(self, x: np.ndarray) -> np.ndarray: + def _rfft_regular(self, x: np.typing.NDArray) -> np.typing.NDArray: """Batched real FFT for a *regular* grid. Parameters ---------- - x : np.ndarray, shape ``(..., grid)`` + x : np.typing.NDArray, shape ``(..., grid)`` Returns ------- - np.ndarray, complex, shape ``(..., nlat, nlon//2+1)`` + np.typing.NDArray, complex, shape ``(..., nlat, nlon//2+1)`` """ nlon = self.lons_per_lat[0] return np.fft.rfft(x.reshape(*x.shape[:-1], self.nlat, nlon), norm="forward") - def _rfft_reduced(self, x: np.ndarray) -> np.ndarray: + def _rfft_reduced(self, x: np.typing.NDArray) -> np.typing.NDArray: """Per-ring real FFT for a *reduced* (variable-resolution) grid. Parameters ---------- - x : np.ndarray, shape ``(..., grid)`` + x : np.typing.NDArray, shape ``(..., grid)`` Returns ------- - np.ndarray, complex, shape ``(..., nlat, max_nlon//2+1)`` + np.typing.NDArray, complex, shape ``(..., nlat, max_nlon//2+1)`` """ max_nlon = max(self.lons_per_lat) out_shape = (*x.shape[:-1], self.nlat, max_nlon // 2 + 1) out = np.zeros(out_shape, dtype=np.complex128) - for i, (slon, nlon) in enumerate(zip(self.slon, self.lons_per_lat)): - out[..., i, : nlon // 2 + 1] = np.fft.rfft( - x[..., slon : slon + nlon], norm="forward" - ) + for i, (slon, nlon) in enumerate(zip(self.slon, self.lons_per_lat, strict=False)): + out[..., i, : nlon // 2 + 1] = np.fft.rfft(x[..., slon : slon + nlon], norm="forward") return out - # transform + # transform - def transform(self, x: np.ndarray) -> np.ndarray: + def transform(self, x: np.typing.NDArray) -> np.typing.NDArray: """Compute the SHT. Parameters ---------- - x : np.ndarray, real, shape ``(..., grid)`` + x : np.typing.NDArray, real, shape ``(..., grid)`` Returns ------- - np.ndarray, complex, shape ``(..., L, M)`` where + np.typing.NDArray, complex, shape ``(..., L, M)`` where ``L = M = truncation + 1``. """ if self._is_regular: @@ -209,31 +210,31 @@ def __init__(self, lons_per_lat: list[int], truncation: int) -> None: # Associated Legendre polynomials with inverse=True self.pct = _legpoly(truncation, truncation, np.cos(theta), inverse=True) - def _irfft_regular(self, x: np.ndarray) -> np.ndarray: + def _irfft_regular(self, x: np.typing.NDArray) -> np.typing.NDArray: """Inverse FFT for a regular grid. Parameters ---------- - x : np.ndarray, complex, shape ``(..., nlat, M)`` + x : np.typing.NDArray, complex, shape ``(..., nlat, M)`` Returns ------- - np.ndarray, real, shape ``(..., grid)`` + np.typing.NDArray, real, shape ``(..., grid)`` """ nlon = self.lons_per_lat[0] spatial = np.fft.irfft(x, n=nlon, norm="forward") # (..., nlat, nlon) return spatial.reshape(*spatial.shape[:-2], self.n_grid_points) - def _irfft_reduced(self, x: np.ndarray) -> np.ndarray: + def _irfft_reduced(self, x: np.typing.NDArray) -> np.typing.NDArray: """Per-ring inverse FFT for a reduced grid. Parameters ---------- - x : np.ndarray, complex, shape ``(..., nlat, M)`` + x : np.typing.NDArray, complex, shape ``(..., nlat, M)`` Returns ------- - np.ndarray, real, shape ``(..., grid)`` + np.typing.NDArray, real, shape ``(..., grid)`` """ lead_shape = x.shape[:-2] out = np.zeros((*lead_shape, self.n_grid_points), dtype=np.float64) @@ -244,16 +245,16 @@ def _irfft_reduced(self, x: np.ndarray) -> np.ndarray: offset += nlon return out - def transform(self, coeffs: np.ndarray) -> np.ndarray: + def transform(self, coeffs: np.typing.NDArray) -> np.typing.NDArray: """Compute the inverse SHT. Parameters ---------- - coeffs : np.ndarray, complex, shape ``(..., L, M)`` + coeffs : np.typing.NDArray, complex, shape ``(..., L, M)`` Returns ------- - np.ndarray, real, shape ``(..., grid)`` + np.typing.NDArray, real, shape ``(..., grid)`` """ # Inverse Legendre transform: (..., l, m) × (m, l, k) → (..., k, m) real_part = coeffs.real @@ -293,11 +294,11 @@ def _regular_lons_per_lat(nlat: int) -> list[int]: def sht_psd( - data: np.ndarray, + data: np.typing.NDArray, nlat: int, truncation: int | None = None, grid_type: str = "octahedral", -) -> tuple[np.ndarray, np.ndarray]: +) -> tuple[np.typing.NDArray, np.typing.NDArray]: """Compute PSD via Spherical Harmonic Transform. 1. Forward SHT: spatial → spectral coefficients ``(l, m)``. @@ -305,7 +306,7 @@ def sht_psd( Parameters ---------- - data : np.ndarray + data : np.typing.NDArray Spatial field with shape ``(n_points,)`` or ``(n_samples, n_points)``. nlat : int Number of latitudes in the grid. @@ -316,9 +317,9 @@ def sht_psd( Returns ------- - wavenumbers : np.ndarray, shape ``(L,)`` + wavenumbers : np.typing.NDArray, shape ``(L,)`` Total wavenumber indices ``0, 1, …, L-1``. - psd : np.ndarray, shape ``(L,)`` + psd : np.typing.NDArray, shape ``(L,)`` Power spectral density averaged over samples. """ if data.ndim == 1: @@ -360,8 +361,8 @@ def sht_psd( psd_per_sample = np.sum(np.abs(coeffs) ** 2, axis=-1) # (n_samples, L) psd = psd_per_sample.mean(axis=0) - L = psd.shape[0] - wavenumbers = np.arange(L, dtype=np.float64) + n_wavenumbers = psd.shape[0] + wavenumbers = np.arange(n_wavenumbers, dtype=np.float64) return wavenumbers, psd @@ -371,19 +372,19 @@ def sht_psd( # --------------------------------------------------------------------------- -def _fft_psd_calc(ht: np.ndarray) -> np.ndarray: +def _fft_psd_calc(ht: np.typing.NDArray) -> np.typing.NDArray: """Return the PSD for positive non-zero frequencies of an even-length signal. Assumes *ht* has an even number of points. Parameters ---------- - ht : np.ndarray + ht : np.typing.NDArray 1-D real-valued signal (one latitude ring). Returns ------- - np.ndarray + np.typing.NDArray PSD for positive frequencies, length ``n // 2``. """ n = len(ht) @@ -393,17 +394,17 @@ def _fft_psd_calc(ht: np.ndarray) -> np.ndarray: return power -def _cubepsd(field_2d: np.ndarray) -> np.ndarray: +def _cubepsd(field_2d: np.typing.NDArray) -> np.typing.NDArray: """Compute PSD averaged over all latitude rows. Parameters ---------- - field_2d : np.ndarray + field_2d : np.typing.NDArray 2-D array of shape ``(nlat, nlon)``. Returns ------- - np.ndarray + np.typing.NDArray PSD of shape ``(nlon // 2,)``. """ nlat, nlon = field_2d.shape @@ -414,7 +415,7 @@ def _cubepsd(field_2d: np.ndarray) -> np.ndarray: return field_psd -def _calcposfreq(npoints: int, spacing_deg: float = 1.0) -> np.ndarray: +def _calcposfreq(npoints: int, spacing_deg: float = 1.0) -> np.typing.NDArray: """Return the positive frequencies for a signal of *npoints* evenly spaced points. Parameters @@ -426,7 +427,7 @@ def _calcposfreq(npoints: int, spacing_deg: float = 1.0) -> np.ndarray: Returns ------- - np.ndarray + np.typing.NDArray Positive frequencies, length ``npoints // 2``. """ freq = np.fft.fftfreq(npoints, d=spacing_deg) @@ -434,12 +435,12 @@ def _calcposfreq(npoints: int, spacing_deg: float = 1.0) -> np.ndarray: def fft_psd( - data: np.ndarray, - lats: np.ndarray, - lons: np.ndarray, + data: np.typing.NDArray, + lats: np.typing.NDArray, + lons: np.typing.NDArray, lat_range: tuple[float, float] = (-60.0, 60.0), regrid_resolution: float = 1.0, -) -> tuple[np.ndarray, np.ndarray]: +) -> tuple[np.typing.NDArray, np.typing.NDArray]: """Compute PSD using 1-D FFT along the longitude dimension. For unstructured grids (where lats/lons are per-point coordinates rather @@ -448,11 +449,11 @@ def fft_psd( Parameters ---------- - data : np.ndarray + data : np.typing.NDArray Field values. Shape ``(n_samples, n_points)`` or ``(n_points,)``. - lats : np.ndarray + lats : np.typing.NDArray Latitude values. Either per-point (length ``n_points``) or axis (length ``nlat``). - lons : np.ndarray + lons : np.typing.NDArray Longitude values. Either per-point (length ``n_points``) or axis (length ``nlon``). lat_range : tuple[float, float] Latitude bounds to restrict the computation to. @@ -461,9 +462,9 @@ def fft_psd( Returns ------- - frequencies : np.ndarray + frequencies : np.typing.NDArray Positive frequencies in cycles per degree, shape ``(nfreq,)``. - psd : np.ndarray + psd : np.typing.NDArray Power spectral density averaged over samples and latitude rows, shape ``(nfreq,)``. """ @@ -477,7 +478,7 @@ def fft_psd( # Determine if the grid is regular or unstructured unique_lats = np.unique(lats) unique_lons = np.unique(lons) - is_regular = (len(unique_lats) * len(unique_lons) == n_points) + is_regular = len(unique_lats) * len(unique_lons) == n_points if is_regular and len(lats) == len(unique_lats): # lats/lons are axis arrays for a regular grid @@ -524,27 +525,27 @@ def fft_psd( def compute_psd_for_field( - data: np.ndarray, + data: np.typing.NDArray, method: str = "sht", nlat: int | None = None, - lats: np.ndarray | None = None, - lons: np.ndarray | None = None, + lats: np.typing.NDArray | None = None, + lons: np.typing.NDArray | None = None, lat_range: tuple[float, float] = (-60.0, 60.0), regrid_resolution: float = 1.0, sht_truncation: int | None = None, grid_type: str = "octahedral", -) -> tuple[np.ndarray, np.ndarray]: +) -> tuple[np.typing.NDArray, np.typing.NDArray]: """Compute PSD using the selected method. Parameters ---------- - data : np.ndarray + data : np.typing.NDArray Spatial field. Shape depends on the method (see ``sht_psd`` / ``fft_psd``). method : str ``"sht"`` for SHT-based PSD, ``"fft"`` for FFT PSD. nlat : int | None Number of latitudes (required for SHT method). - lats, lons : np.ndarray | None + lats, lons : np.typing.NDArray | None Latitude / longitude coordinate arrays (required for fft method). lat_range : tuple[float, float] Latitude bounds for the fft method. @@ -557,9 +558,9 @@ def compute_psd_for_field( Returns ------- - x_values : np.ndarray + x_values : np.typing.NDArray Wavenumbers (SHT) or positive frequencies (fft). - psd : np.ndarray + psd : np.typing.NDArray Power spectral density. """ if method == "sht": @@ -586,10 +587,10 @@ def compute_psd_for_field( def compute_psd_score( - gt: np.ndarray, - p: np.ndarray, - lats: np.ndarray | None, - lons: np.ndarray | None, + gt: np.typing.NDArray, + p: np.typing.NDArray, + lats: np.typing.NDArray | None, + lons: np.typing.NDArray | None, nlat: int | None, n_points: int, psd_method: str = "sht", @@ -606,9 +607,9 @@ def compute_psd_score( Parameters ---------- - gt, p : np.ndarray + gt, p : np.typing.NDArray Ground truth and prediction arrays of shape ``(n_samples, n_points)``. - lats, lons : np.ndarray | None + lats, lons : np.typing.NDArray | None Latitude / longitude arrays of length ``n_points`` (or None). nlat : int | None Number of latitudes (for SHT fallback). @@ -633,7 +634,6 @@ def compute_psd_score( """ # Handle NaN grid points (e.g. from regional masking). valid_mask = ~np.isnan(gt).all(axis=0) - n_valid = valid_mask.sum() gt = gt[:, valid_mask] p = p[:, valid_mask] @@ -662,14 +662,26 @@ def compute_psd_score( try: freq_gt, psd_gt = compute_psd_for_field( - data=gt, method=psd_method, nlat=nlat_valid, lats=lats_valid, lons=lons_valid, - lat_range=lat_range, regrid_resolution=psd_regrid_resolution, - sht_truncation=psd_sht_truncation, grid_type=grid_type, + data=gt, + method=psd_method, + nlat=nlat_valid, + lats=lats_valid, + lons=lons_valid, + lat_range=lat_range, + regrid_resolution=psd_regrid_resolution, + sht_truncation=psd_sht_truncation, + grid_type=grid_type, ) freq_p, psd_p = compute_psd_for_field( - data=p, method=psd_method, nlat=nlat_valid, lats=lats_valid, lons=lons_valid, - lat_range=lat_range, regrid_resolution=psd_regrid_resolution, - sht_truncation=psd_sht_truncation, grid_type=grid_type, + data=p, + method=psd_method, + nlat=nlat_valid, + lats=lats_valid, + lons=lons_valid, + lat_range=lat_range, + regrid_resolution=psd_regrid_resolution, + sht_truncation=psd_sht_truncation, + grid_type=grid_type, ) except Exception: _logger.exception("PSD computation failed, returning NaN.") diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index 0e4c40163..28b58bb90 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -1842,7 +1842,8 @@ def calc_psd( Latitude range (min, max) to include in PSD calculation. Default is (-60, 60) degrees. grid_type: str - Type of grid for PSD calculation. Options: 'octahedral', 'regular'. Default is 'octahedral'. + Type of grid for PSD calculation. Options: 'octahedral', 'regular'. + Default is 'octahedral'. Returns ------- @@ -1867,14 +1868,17 @@ def calc_psd( nlat, lats, lons = self._get_psd_grid_info(gt, spatial_dim) if psd_method == "fft" and (lats is None or lons is None): - raise ValueError( - f"PSD method 'fft' requires lat/lon coords on '{spatial_dim}'." - ) + raise ValueError(f"PSD method 'fft' requires lat/lon coords on '{spatial_dim}'.") psd_kwargs = dict( - lats=lats, lons=lons, nlat=nlat, n_points=n_points, - psd_method=psd_method, psd_regrid_resolution=psd_regrid_resolution, - psd_sht_truncation=psd_sht_truncation, lat_range=lat_range, + lats=lats, + lons=lons, + nlat=nlat, + n_points=n_points, + psd_method=psd_method, + psd_regrid_resolution=psd_regrid_resolution, + psd_sht_truncation=psd_sht_truncation, + lat_range=lat_range, grid_type=grid_type, ) @@ -1896,7 +1900,7 @@ def calc_psd( all_attrs: dict = {} for idx in np.ndindex(*shape): - sel = dict(zip(preserve_dims, idx)) + sel = dict(zip(preserve_dims, idx, strict=False)) gt_slice = gt.isel(**sel) p_slice = p.isel(**sel) gt_np, p_np = self._stack_for_psd(gt_slice, p_slice, spatial_dim, n_points) @@ -1905,8 +1909,7 @@ def calc_psd( score_values[idx] = slice_score key = "_".join( - str(gt.coords[d].values[i]) if d in gt.coords else str(i) - for d, i in sel.items() + str(gt.coords[d].values[i]) if d in gt.coords else str(i) for d, i in sel.items() ) for k, v in slice_attrs.items(): all_attrs[f"{key}/{k}"] = v @@ -1921,10 +1924,10 @@ def calc_psd( @staticmethod def _get_psd_grid_info( gt: xr.DataArray, spatial_dim: str - ) -> tuple[int | None, np.ndarray | None, np.ndarray | None]: + ) -> tuple[int | None, np.typing.NDArray | None, np.typing.NDArray | None]: """ Extract nlat, lats, lons from ground-truth coords. - + Parameters ---------- gt: xr.DataArray @@ -1935,28 +1938,26 @@ def _get_psd_grid_info( ------- nlat: int | None Number of latitude points, or None if lat/lon coords are not found. - lats: np.ndarray | None + lats: np.typing.NDArray | None Latitude values, or None if lat/lon coords are not found. - lons: np.ndarray | None + lons: np.typing.NDArray | None Longitude values, or None if lat/lon coords are not found. - + """ if "lat" in gt.coords and "lon" in gt.coords: if gt.coords["lat"].dims == (spatial_dim,) and gt.coords["lon"].dims == (spatial_dim,): lats = gt.coords["lat"].values lons = gt.coords["lon"].values return len(np.unique(lats)), lats, lons - raise ValueError( - f"PSD requires lat/lon coords on spatial dimension '{spatial_dim}'." - ) + raise ValueError(f"PSD requires lat/lon coords on spatial dimension '{spatial_dim}'.") @staticmethod def _stack_for_psd( gt: xr.DataArray, p: xr.DataArray, spatial_dim: str, n_points: int - ) -> tuple[np.ndarray, np.ndarray]: + ) -> tuple[np.typing.NDArray, np.typing.NDArray]: """ Reshape data to (n_batch, n_points) for PSD computation. - + Parameters ---------- gt: xr.DataArray @@ -1969,10 +1970,10 @@ def _stack_for_psd( Number of points along the spatial dimension. Returns ------- - gt_np: np.ndarray + gt_np: np.typing.NDArray Reshaped ground truth data of shape (n_batch, n_points). - p_np: np.ndarray - Reshaped forecast data of shape (n_batch, n_points). + p_np: np.typing.NDArray + Reshaped forecast data of shape (n_batch, n_points). """ non_spatial = [d for d in gt.dims if d != spatial_dim] if non_spatial: From 10115e100d242d55e89c663da04a237d764c0238 Mon Sep 17 00:00:00 2001 From: iluise Date: Thu, 21 May 2026 18:06:37 +0200 Subject: [PATCH 8/8] add var name to psd plot title --- .../src/weathergen/evaluate/plotting/line_plots.py | 9 ++++++++- .../src/weathergen/evaluate/plotting/plot_utils.py | 5 ++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py b/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py index 70c2f0448..17a117cde 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py @@ -678,6 +678,8 @@ def psd_plot( psd_datasets: list[dict], labels: list[str], tag: str = "", + variable: str = "", + forecast_step: str = "", ) -> None: """Create a PSD summary plot overlaying multiple runs. @@ -722,7 +724,12 @@ def psd_plot( ) ax_spec.set_ylabel("Power") psd_method = psd_datasets[0].get("psd_method", "sht") - ax_spec.set_title(f"PSD summary (psd_{psd_method})") + title_parts = [f"PSD ({psd_method})"] + if variable: + title_parts.append(variable) + if forecast_step: + title_parts.append(f"step {forecast_step}") + ax_spec.set_title(" – ".join(title_parts)) ax_spec.legend(frameon=False, fontsize=7) ax_spec.grid(True, which="both", ls="--", alpha=0.4) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 0b3b64218..61e935831 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -832,7 +832,10 @@ def psd_plot_metric_region( middle=[run_id], suffix=[stream, ch, f"fstep{fstep}"], ) - plotter.psd_plot(psd_datasets, [label], tag=name) + plotter.psd_plot( + psd_datasets, [label], tag=name, + variable=ch, forecast_step=str(fstep), + ) def create_filename(