diff --git a/config/evaluate/eval_config.yml b/config/evaluate/eval_config.yml index a7415052b..4d54014d7 100644 --- a/config/evaluate/eval_config.yml +++ b/config/evaluate/eval_config.yml @@ -29,6 +29,13 @@ # max_workers: 36 # hard cap on parallel workers (I/O, plotting, scoring) evaluation: + metrics : + - rmse + - mae + - psd: + psd_method: "sht" # "sht" (SHT-based) or "fft" (regrid FFT) + psd_regrid_resolution: 1.0 # degrees, only for fft method + grid_type: "octahedral" # "octahedral" (default), "regular", "reduced". metrics : ["rmse", "mae"] # regions: ["global", "nhem"] # Have regions here, if you want for them to apply to all streams (scores calculation) summary_plots : true diff --git a/packages/evaluate/pyproject.toml b/packages/evaluate/pyproject.toml index a2d37fcae..5c4afdc9c 100644 --- a/packages/evaluate/pyproject.toml +++ b/packages/evaluate/pyproject.toml @@ -21,8 +21,15 @@ dependencies = [ "earthkit-data==0.17.0", "earthkit-utils==0.1.2", "imageio[ffmpeg]>=2.37.2", + "scipy>=1.12", ] +[project.optional-dependencies] +fft = [ + "scitools-iris>=3.11", + "cf-units", +] + [dependency-groups] dev = [ "pytest>=8.3", diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py b/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py index 7363c495e..17a117cde 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py @@ -15,6 +15,7 @@ from pathlib import Path import matplotlib.pyplot as plt +import numpy as np import seaborn as sns import xarray as xr @@ -667,3 +668,86 @@ def heat_map( parts = ["heat_map", metric, tag] name = "_".join(filter(None, parts)) plt.savefig(f"{self.out_plot_dir.joinpath(name)}.{self.image_format}") + + # ------------------------------------------------------------------ + # PSD summary plot + # ------------------------------------------------------------------ + + def psd_plot( + self, + psd_datasets: list[dict], + labels: list[str], + tag: str = "", + variable: str = "", + forecast_step: str = "", + ) -> None: + """Create a PSD summary plot overlaying multiple runs. + + Each entry in *psd_datasets* is a dict with keys + ``frequencies``, ``psd_target``, ``psd_prediction``, ``psd_method``. + + Parameters + ---------- + psd_datasets : list[dict] + One dict per run, each containing the PSD arrays stored by + ``Scores.calc_psd`` in ``.attrs``. + labels : list[str] + Human-readable label for each run. + tag : str + Filename tag. + """ + out_dir = Path(self.out_plot_dir_lines) / "psd" + out_dir.mkdir(parents=True, exist_ok=True) + + # Use the target from the first run as reference + freq = np.asarray(psd_datasets[0]["frequencies"]) + tar_psd = np.asarray(psd_datasets[0]["psd_target"]) + + fig, (ax_spec, ax_ratio) = plt.subplots( + 2, + 1, + figsize=self.fig_size or (10, 8), + gridspec_kw={"height_ratios": [2, 1], "hspace": 0.08}, + ) + + # Upper panel: log-log spectra + ax_spec.loglog(freq, tar_psd, color="black", lw=1.5, label="Target") + colors = plt.cm.tab10.colors + for i, (ds, label) in enumerate(zip(psd_datasets, labels, strict=False)): + c = colors[i % len(colors)] + ax_spec.loglog( + np.asarray(ds["frequencies"]), + np.asarray(ds["psd_prediction"]), + color=c, + lw=1.5, + label=label, + ) + ax_spec.set_ylabel("Power") + psd_method = psd_datasets[0].get("psd_method", "sht") + title_parts = [f"PSD ({psd_method})"] + if variable: + title_parts.append(variable) + if forecast_step: + title_parts.append(f"step {forecast_step}") + ax_spec.set_title(" – ".join(title_parts)) + ax_spec.legend(frameon=False, fontsize=7) + ax_spec.grid(True, which="both", ls="--", alpha=0.4) + + # Lower panel: ratio (pred / target) + for i, (ds, label) in enumerate(zip(psd_datasets, labels, strict=False)): + c = colors[i % len(colors)] + pred = np.asarray(ds["psd_prediction"]) + with np.errstate(divide="ignore", invalid="ignore"): + ratio = np.where(tar_psd > 0, pred / tar_psd, np.nan) + ax_ratio.semilogx(freq, ratio, color=c, lw=1.2, label=label) + ax_ratio.axhline(1.0, ls="--", color="gray", lw=0.8) + ax_ratio.set_ylabel("Pred / Target") + ax_ratio.set_xlabel("Frequency (1/deg)") + ax_ratio.set_ylim(0, 2) + ax_ratio.grid(True, which="both", ls="--", alpha=0.4) + + name = tag or "psd" + fname = out_dir / f"{name}.{self.image_format}" + _logger.info(f"Saving PSD summary plot to {fname}") + fig.savefig(str(fname), bbox_inches="tight", dpi=self.dpi_val) + plt.close(fig) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index f129c34eb..ae680aacf 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -29,6 +29,7 @@ bar_plot_metric_region, heat_maps_metric_region, plot_metric_region, + psd_plot_metric_region, quantile_plot_metric_region, ratio_plot_metric_region, score_card_metric_region, @@ -788,7 +789,12 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): for metric in metrics: for region in scores_dict[metric].keys(): if eval_opt.get("summary_plots", False): - plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary) + if metric == "psd": + psd_plot_metric_region(metric, region, runs, scores_dict, plotter) + elif metric == "qq_analysis": + quantile_plot_metric_region(metric, region, runs, scores_dict, quantile_plotter) + else: + plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary) if eval_opt.get("ratio_plots", False): ratio_plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary) if eval_opt.get("heat_maps", False): @@ -797,5 +803,3 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): score_card_metric_region(metric, region, runs, scores_dict, sc_plotter) if eval_opt.get("bar_plots", False): bar_plot_metric_region(metric, region, runs, scores_dict, br_plotter) - if metric == "qq_analysis": - quantile_plot_metric_region(metric, region, runs, scores_dict, quantile_plotter) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index 5fd501dd5..61e935831 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -768,6 +768,76 @@ def quantile_plot_metric_region( ) +def _extract_psd_attrs(data_ch: xr.DataArray, fstep: int, ch: str) -> list[dict] | None: + """Extract PSD curve data from DataArray attrs for a given fstep/channel. + + Returns a single-element list of dicts ready for the plotter, or None if keys are missing. + """ + attrs = data_ch.attrs + fp = f"fstep_{fstep}/" + + for prefix in (f"{fp}{ch}/", fp): + if f"{prefix}frequencies" in attrs and f"{prefix}psd_target" in attrs: + return [ + { + "frequencies": np.array(attrs[f"{prefix}frequencies"]), + "psd_target": np.array(attrs[f"{prefix}psd_target"]), + "psd_prediction": np.array(attrs[f"{prefix}psd_prediction"]), + "psd_method": attrs.get(f"{fp}psd_method", attrs.get("psd_method", "sht")), + } + ] + return None + + +def psd_plot_metric_region( + metric: str, + region: str, + runs: dict, + scores_dict: dict, + plotter: object, +) -> None: + """Create PSD plots for all streams and channels for a given metric and region. + + PSD curves (frequencies, target PSD, prediction PSD) are stored in + ``score.attrs`` by ``Scores.calc_psd`` and read back here. + """ + streams_set = collect_streams(runs) + channels_set = collect_channels(scores_dict, metric, region, runs) + + for stream in streams_set: + for ch in channels_set: + for run_id, data in scores_dict[metric][region].get(stream, {}).items(): + if ch not in np.atleast_1d(data.channel.values): + continue + + data_ch = data.sel(channel=ch) if "channel" in data.dims else data + if data_ch.isnull().all(): + continue + + attr_fsteps = data_ch.attrs.get("attr_fsteps", []) + if not attr_fsteps: + _logger.warning(f"PSD attrs missing for {run_id}/{stream}/{ch}. Skipping.") + continue + + label = runs[run_id].get("label", run_id) + + for fstep in attr_fsteps: + psd_datasets = _extract_psd_attrs(data_ch, fstep, ch) + if psd_datasets is None: + continue + + method_tag = psd_datasets[0].get("psd_method", "sht") + name = create_filename( + prefix=[metric, method_tag, region], + middle=[run_id], + suffix=[stream, ch, f"fstep{fstep}"], + ) + plotter.psd_plot( + psd_datasets, [label], tag=name, + variable=ch, forecast_step=str(fstep), + ) + + def create_filename( *, prefix: Sequence[str] = (), diff --git a/packages/evaluate/src/weathergen/evaluate/scores/psd.py b/packages/evaluate/src/weathergen/evaluate/scores/psd.py new file mode 100644 index 000000000..a063c9347 --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/scores/psd.py @@ -0,0 +1,703 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +"""Power Spectral Density (PSD) computation. + +Provides two PSD computation paths: + +- **Path A – SHT-based PSD** (``method="sht"``): + Spherical Harmonic Transform on unstructured grids (octahedral, reduced + Gaussian, regular lat-lon). Ported from anemoi.models ``spectral_transforms.py`` to pure + numpy using Legendre helpers from anemoi.models ``spectral_helpers.py``. + [anemoi.models.spectral_transforms] + https://github.com/ecmwf/anemoi-core/blob/main/models/src/anemoi/models/layers/spectral_transforms.py + [anemoi.models.spectral_helpers] + https://github.com/ecmwf/anemoi-core/blob/main/models/src/anemoi/models/layers/spectral_helpers.py + +- **Path B – FFT PSD** (``method="fft"``): + 1-D zonal FFT along the longitude dimension on a regular lat-lon grid. + Absorbs the functions previously in ``example_extras/power_spectra/psd_calc.py``. +""" + +from __future__ import annotations + +import logging + +import numpy as np +from scipy.interpolate import griddata + +_logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Numpy-based Spherical Harmonic Transform (ported from spectral_helpers.py) +# --------------------------------------------------------------------------- + + +def _legendre_gauss_weights( + n: int, a: float = -1.0, b: float = 1.0 +) -> tuple[np.typing.NDArray, np.typing.NDArray]: + """Return Legendre-Gauss nodes and weights on ``[a, b]``.""" + xlg, wlg = np.polynomial.legendre.leggauss(n) + xlg = (b - a) * 0.5 * xlg + (b + a) * 0.5 + wlg = wlg * (b - a) * 0.5 + return xlg, wlg + + +def _legpoly( + mmax: int, lmax: int, x: np.typing.NDArray, inverse: bool = False +) -> np.typing.NDArray: + """Compute associated Legendre polynomials. + + Returns shape ``(mmax+1, lmax+1, len(x))``. + """ + nmax = max(mmax, lmax) + vdm = np.zeros((nmax + 1, nmax + 1, len(x)), dtype=np.float64) + + norm_factor = np.sqrt(4 * np.pi) + norm_factor = 1.0 / norm_factor if inverse else norm_factor + vdm[0, 0, :] = norm_factor / np.sqrt(4 * np.pi) + + for n in range(1, nmax + 1): + vdm[n - 1, n, :] = np.sqrt(2 * n + 1) * x * vdm[n - 1, n - 1, :] + vdm[n, n, :] = np.sqrt((2 * n + 1) * (1 + x) * (1 - x) / 2 / n) * vdm[n - 1, n - 1, :] + + for n in range(2, nmax + 1): + for m in range(0, n - 1): + vdm[m, n, :] = ( + x * np.sqrt((2 * n - 1) / (n - m) * (2 * n + 1) / (n + m)) * vdm[m, n - 1, :] + - np.sqrt((n + m - 1) / (n - m) * (2 * n + 1) / (2 * n - 3) * (n - m - 1) / (n + m)) + * vdm[m, n - 2, :] + ) + + return vdm[: mmax + 1, : lmax + 1] + + +class SphericalHarmonicTransform: + """Spherical Harmonic Transform in pure numpy. + + Mirrors the ``SphericalHarmonicTransform`` from ``spectral_helpers.py`` in anemoi.models + but operates on numpy arrays rather than torch tensors. + + Parameters + ---------- + lons_per_lat : list[int] + Number of longitude points on each latitude ring (pole to pole). + truncation : int + Maximum total wavenumber to retain. + """ + + def __init__(self, lons_per_lat: list[int], truncation: int) -> None: + self.lons_per_lat = lons_per_lat + self.nlat = len(lons_per_lat) + self.truncation = truncation + assert 0 < truncation <= self.nlat, f"Truncation {truncation} must be in (0, {self.nlat}]" + self.n_grid_points = sum(lons_per_lat) + + # Offsets into the flattened grid for each latitude ring + self.slon = [0] + list(np.cumsum(lons_per_lat))[:-1] + + # Whether all rings have the same number of points (regular grid) + self._is_regular = len(set(lons_per_lat)) == 1 + + # Precompute Gaussian latitudes + quadrature weights + theta, weight = _legendre_gauss_weights(self.nlat) + theta = np.flip(np.arccos(theta)) + + # Associated Legendre polynomials (m, l, lat) + pct = _legpoly(truncation, truncation, np.cos(theta)) + + # Pre-multiply by quadrature weights → shape (m, l, lat) + self.weight = np.einsum("mlk,k->mlk", pct, weight) + + # internal FFT helpers + + def _rfft_regular(self, x: np.typing.NDArray) -> np.typing.NDArray: + """Batched real FFT for a *regular* grid. + + Parameters + ---------- + x : np.typing.NDArray, shape ``(..., grid)`` + + Returns + ------- + np.typing.NDArray, complex, shape ``(..., nlat, nlon//2+1)`` + """ + nlon = self.lons_per_lat[0] + return np.fft.rfft(x.reshape(*x.shape[:-1], self.nlat, nlon), norm="forward") + + def _rfft_reduced(self, x: np.typing.NDArray) -> np.typing.NDArray: + """Per-ring real FFT for a *reduced* (variable-resolution) grid. + + Parameters + ---------- + x : np.typing.NDArray, shape ``(..., grid)`` + + Returns + ------- + np.typing.NDArray, complex, shape ``(..., nlat, max_nlon//2+1)`` + """ + max_nlon = max(self.lons_per_lat) + out_shape = (*x.shape[:-1], self.nlat, max_nlon // 2 + 1) + out = np.zeros(out_shape, dtype=np.complex128) + + for i, (slon, nlon) in enumerate(zip(self.slon, self.lons_per_lat, strict=False)): + out[..., i, : nlon // 2 + 1] = np.fft.rfft(x[..., slon : slon + nlon], norm="forward") + return out + + # transform + + def transform(self, x: np.typing.NDArray) -> np.typing.NDArray: + """Compute the SHT. + + Parameters + ---------- + x : np.typing.NDArray, real, shape ``(..., grid)`` + + Returns + ------- + np.typing.NDArray, complex, shape ``(..., L, M)`` where + ``L = M = truncation + 1``. + """ + if self._is_regular: + x_fft = self._rfft_regular(x) + else: + x_fft = self._rfft_reduced(x) + + x_fft = 2.0 * np.pi * x_fft + + real_part = x_fft[..., : self.truncation + 1].real + imag_part = x_fft[..., : self.truncation + 1].imag + + rl = np.einsum("...km,mlk->...lm", real_part, self.weight) + im = np.einsum("...km,mlk->...lm", imag_part, self.weight) + + return rl + 1j * im + + +class InverseSphericalHarmonicTransform: + """Inverse Spherical Harmonic Transform in pure numpy. + + Reconstructs a spatial field from spectral coefficients (l, m). + Mirrors the ``InverseSphericalHarmonicTransform`` from ``spectral_helpers.py`` + in anemoi.models but operates on numpy arrays. + + Parameters + ---------- + lons_per_lat : list[int] + Number of longitude points on each latitude ring (pole to pole). + truncation : int + Maximum total wavenumber. + """ + + def __init__(self, lons_per_lat: list[int], truncation: int) -> None: + self.lons_per_lat = lons_per_lat + self.nlat = len(lons_per_lat) + self.truncation = truncation + self.n_grid_points = sum(lons_per_lat) + self._is_regular = len(set(lons_per_lat)) == 1 + + # Gaussian latitudes (no quadrature weights needed for inverse) + theta, _ = _legendre_gauss_weights(self.nlat) + theta = np.flip(np.arccos(theta)) + + # Associated Legendre polynomials with inverse=True + self.pct = _legpoly(truncation, truncation, np.cos(theta), inverse=True) + + def _irfft_regular(self, x: np.typing.NDArray) -> np.typing.NDArray: + """Inverse FFT for a regular grid. + + Parameters + ---------- + x : np.typing.NDArray, complex, shape ``(..., nlat, M)`` + + Returns + ------- + np.typing.NDArray, real, shape ``(..., grid)`` + """ + nlon = self.lons_per_lat[0] + spatial = np.fft.irfft(x, n=nlon, norm="forward") # (..., nlat, nlon) + return spatial.reshape(*spatial.shape[:-2], self.n_grid_points) + + def _irfft_reduced(self, x: np.typing.NDArray) -> np.typing.NDArray: + """Per-ring inverse FFT for a reduced grid. + + Parameters + ---------- + x : np.typing.NDArray, complex, shape ``(..., nlat, M)`` + + Returns + ------- + np.typing.NDArray, real, shape ``(..., grid)`` + """ + lead_shape = x.shape[:-2] + out = np.zeros((*lead_shape, self.n_grid_points), dtype=np.float64) + offset = 0 + for i, nlon in enumerate(self.lons_per_lat): + ring = np.fft.irfft(x[..., i, :], n=nlon, norm="forward") + out[..., offset : offset + nlon] = ring + offset += nlon + return out + + def transform(self, coeffs: np.typing.NDArray) -> np.typing.NDArray: + """Compute the inverse SHT. + + Parameters + ---------- + coeffs : np.typing.NDArray, complex, shape ``(..., L, M)`` + + Returns + ------- + np.typing.NDArray, real, shape ``(..., grid)`` + """ + # Inverse Legendre transform: (..., l, m) × (m, l, k) → (..., k, m) + real_part = coeffs.real + imag_part = coeffs.imag + + rl = np.einsum("...lm,mlk->...km", real_part, self.pct) + im = np.einsum("...lm,mlk->...km", imag_part, self.pct) + + x_fourier = rl + 1j * im # (..., nlat, M) + + # Inverse FFT per ring + if self._is_regular: + return self._irfft_regular(x_fourier) + else: + return self._irfft_reduced(x_fourier) + + +# --------------------------------------------------------------------------- +# Grid helpers for building lons_per_lat +# --------------------------------------------------------------------------- + + +def _octahedral_lons_per_lat(nlat: int) -> list[int]: + """Return lons_per_lat for an octahedral reduced Gaussian grid.""" + half = [20 + 4 * i for i in range(nlat // 2)] + return half + list(reversed(half)) + + +def _regular_lons_per_lat(nlat: int) -> list[int]: + """Return lons_per_lat for a regular lat-lon grid (nlon = 2*nlat).""" + return [2 * nlat] * nlat + + +# --------------------------------------------------------------------------- +# High-level SHT PSD +# --------------------------------------------------------------------------- + + +def sht_psd( + data: np.typing.NDArray, + nlat: int, + truncation: int | None = None, + grid_type: str = "octahedral", +) -> tuple[np.typing.NDArray, np.typing.NDArray]: + """Compute PSD via Spherical Harmonic Transform. + + 1. Forward SHT: spatial → spectral coefficients ``(l, m)``. + 2. PSD: L2-norm over ``m`` for each total wavenumber ``l``. + + Parameters + ---------- + data : np.typing.NDArray + Spatial field with shape ``(n_points,)`` or ``(n_samples, n_points)``. + nlat : int + Number of latitudes in the grid. + truncation : int | None + Spectral truncation. Defaults to ``nlat // 2 - 1``. + grid_type : str + One of ``"octahedral"``, ``"regular"``, ``"reduced"``. + + Returns + ------- + wavenumbers : np.typing.NDArray, shape ``(L,)`` + Total wavenumber indices ``0, 1, …, L-1``. + psd : np.typing.NDArray, shape ``(L,)`` + Power spectral density averaged over samples. + """ + if data.ndim == 1: + data = data[np.newaxis, :] + n_samples, n_points = data.shape + + # Build the SHT for the appropriate grid + if grid_type == "octahedral": + lons_per_lat = _octahedral_lons_per_lat(nlat) + elif grid_type == "regular": + lons_per_lat = _regular_lons_per_lat(nlat) + elif grid_type == "reduced": + try: + from anemoi.transform.grids.named import lookup + except ImportError: + raise ImportError( + "anemoi.transform is required for grid_type='reduced'. " + "Install: pip install anemoi-transform" + ) from None + lats = lookup("N320")["latitudes"] + unique_lats = sorted(set(lats)) + lons_per_lat = [int((lats == lat).sum()) for lat in unique_lats] + else: + raise ValueError(f"Unknown grid_type: {grid_type!r}") + + trunc = truncation or nlat // 2 - 1 + sht = SphericalHarmonicTransform(lons_per_lat=lons_per_lat, truncation=trunc) + + assert n_points == sht.n_grid_points, ( + f"Input points={n_points} != expected grid points={sht.n_grid_points} " + f"for grid_type={grid_type!r}, nlat={nlat}" + ) + + # SphericalHarmonicTransform.transform accepts (..., grid) → (..., L, M) + # Pass (n_samples, n_points) directly. + coeffs = sht.transform(data) # (n_samples, L, M) + + # PSD = sum |coeffs|^2 over m for each total wavenumber l, averaged over samples + psd_per_sample = np.sum(np.abs(coeffs) ** 2, axis=-1) # (n_samples, L) + psd = psd_per_sample.mean(axis=0) + + n_wavenumbers = psd.shape[0] + wavenumbers = np.arange(n_wavenumbers, dtype=np.float64) + + return wavenumbers, psd + + +# --------------------------------------------------------------------------- +# FFT PSD (absorbed from psd_calc.py) +# --------------------------------------------------------------------------- + + +def _fft_psd_calc(ht: np.typing.NDArray) -> np.typing.NDArray: + """Return the PSD for positive non-zero frequencies of an even-length signal. + + Assumes *ht* has an even number of points. + + Parameters + ---------- + ht : np.typing.NDArray + 1-D real-valued signal (one latitude ring). + + Returns + ------- + np.typing.NDArray + PSD for positive frequencies, length ``n // 2``. + """ + n = len(ht) + hf = np.fft.rfft(ht, norm="forward") + power = np.abs(hf[1 : round(n / 2 + 1)]) ** 2 + power *= 2.0 # compensate for positive frequencies only + return power + + +def _cubepsd(field_2d: np.typing.NDArray) -> np.typing.NDArray: + """Compute PSD averaged over all latitude rows. + + Parameters + ---------- + field_2d : np.typing.NDArray + 2-D array of shape ``(nlat, nlon)``. + + Returns + ------- + np.typing.NDArray + PSD of shape ``(nlon // 2,)``. + """ + nlat, nlon = field_2d.shape + field_psd = np.zeros(nlon // 2) + for row in field_2d: + field_psd += _fft_psd_calc(row) + field_psd /= nlat + return field_psd + + +def _calcposfreq(npoints: int, spacing_deg: float = 1.0) -> np.typing.NDArray: + """Return the positive frequencies for a signal of *npoints* evenly spaced points. + + Parameters + ---------- + npoints : int + Number of equally-spaced longitude points. + spacing_deg : float + Grid spacing in degrees. + + Returns + ------- + np.typing.NDArray + Positive frequencies, length ``npoints // 2``. + """ + freq = np.fft.fftfreq(npoints, d=spacing_deg) + return np.abs(freq[1 : round(npoints / 2 + 1)]) + + +def fft_psd( + data: np.typing.NDArray, + lats: np.typing.NDArray, + lons: np.typing.NDArray, + lat_range: tuple[float, float] = (-60.0, 60.0), + regrid_resolution: float = 1.0, +) -> tuple[np.typing.NDArray, np.typing.NDArray]: + """Compute PSD using 1-D FFT along the longitude dimension. + + For unstructured grids (where lats/lons are per-point coordinates rather + than regular axis arrays), the data is first regridded to a regular lat-lon + grid using scipy nearest-neighbor interpolation. + + Parameters + ---------- + data : np.typing.NDArray + Field values. Shape ``(n_samples, n_points)`` or ``(n_points,)``. + lats : np.typing.NDArray + Latitude values. Either per-point (length ``n_points``) or axis (length ``nlat``). + lons : np.typing.NDArray + Longitude values. Either per-point (length ``n_points``) or axis (length ``nlon``). + lat_range : tuple[float, float] + Latitude bounds to restrict the computation to. + regrid_resolution : float + Grid spacing in degrees for the regular target grid. + + Returns + ------- + frequencies : np.typing.NDArray + Positive frequencies in cycles per degree, shape ``(nfreq,)``. + psd : np.typing.NDArray + Power spectral density averaged over samples and latitude rows, + shape ``(nfreq,)``. + """ + + # Ensure 2-D: (n_samples, n_points) + if data.ndim == 1: + data = data.reshape(1, -1) + + n_samples, n_points = data.shape + + # Determine if the grid is regular or unstructured + unique_lats = np.unique(lats) + unique_lons = np.unique(lons) + is_regular = len(unique_lats) * len(unique_lons) == n_points + + if is_regular and len(lats) == len(unique_lats): + # lats/lons are axis arrays for a regular grid + lat_axis = unique_lats + lon_axis = unique_lons + nlat, nlon = len(lat_axis), len(lon_axis) + data_3d = data.reshape(n_samples, nlat, nlon) + else: + # Unstructured grid — regrid to regular lat-lon + lat_min = max(lat_range[0], lats.min()) + lat_max = min(lat_range[1], lats.max()) + lon_min, lon_max = lons.min(), lons.max() + + lat_axis = np.arange(lat_min, lat_max + regrid_resolution / 2, regrid_resolution) + lon_axis = np.arange(lon_min, lon_max + regrid_resolution / 2, regrid_resolution) + nlat, nlon = len(lat_axis), len(lon_axis) + + grid_lon, grid_lat = np.meshgrid(lon_axis, lat_axis) + points = np.column_stack((lats, lons)) + + data_3d = np.empty((n_samples, nlat, nlon)) + for s in range(n_samples): + data_3d[s] = griddata(points, data[s], (grid_lat, grid_lon), method="nearest") + + # Apply latitude mask + lat_mask = (lat_axis >= lat_range[0]) & (lat_axis <= lat_range[1]) + data_3d = data_3d[:, lat_mask, :] + nlon_sub = data_3d.shape[2] + + # Compute PSD per sample and average + psds = [] + for s in range(data_3d.shape[0]): + psds.append(_cubepsd(data_3d[s])) + psd_result = np.mean(psds, axis=0) + + spacing = 360.0 / nlon_sub if nlon_sub > 0 else regrid_resolution + frequencies = _calcposfreq(nlon_sub, spacing_deg=spacing) + return frequencies, psd_result + + +# --------------------------------------------------------------------------- +# Dispatch +# --------------------------------------------------------------------------- + + +def compute_psd_for_field( + data: np.typing.NDArray, + method: str = "sht", + nlat: int | None = None, + lats: np.typing.NDArray | None = None, + lons: np.typing.NDArray | None = None, + lat_range: tuple[float, float] = (-60.0, 60.0), + regrid_resolution: float = 1.0, + sht_truncation: int | None = None, + grid_type: str = "octahedral", +) -> tuple[np.typing.NDArray, np.typing.NDArray]: + """Compute PSD using the selected method. + + Parameters + ---------- + data : np.typing.NDArray + Spatial field. Shape depends on the method (see ``sht_psd`` / ``fft_psd``). + method : str + ``"sht"`` for SHT-based PSD, ``"fft"`` for FFT PSD. + nlat : int | None + Number of latitudes (required for SHT method). + lats, lons : np.typing.NDArray | None + Latitude / longitude coordinate arrays (required for fft method). + lat_range : tuple[float, float] + Latitude bounds for the fft method. + regrid_resolution : float + Grid spacing in degrees for the fft method. + sht_truncation : int | None + Spectral truncation for SHT. + grid_type : str + Grid type for SHT (``"octahedral"``, ``"regular"``, ``"reduced"``). + + Returns + ------- + x_values : np.typing.NDArray + Wavenumbers (SHT) or positive frequencies (fft). + psd : np.typing.NDArray + Power spectral density. + """ + if method == "sht": + if nlat is None: + raise ValueError("nlat is required for method='sht'") + return sht_psd( + data=data, + nlat=nlat, + truncation=sht_truncation, + grid_type=grid_type, + ) + elif method == "fft": + if lats is None or lons is None: + raise ValueError("lats and lons are required for method='fft'") + return fft_psd( + data=data, + lats=lats, + lons=lons, + lat_range=lat_range, + regrid_resolution=regrid_resolution, + ) + else: + raise ValueError(f"Unknown PSD method: {method!r}. Use 'sht' or 'fft'.") + + +def compute_psd_score( + gt: np.typing.NDArray, + p: np.typing.NDArray, + lats: np.typing.NDArray | None, + lons: np.typing.NDArray | None, + nlat: int | None, + n_points: int, + psd_method: str = "sht", + psd_regrid_resolution: float = 1.0, + psd_sht_truncation: int | None = None, + lat_range: tuple[float, float] = (-60.0, 60.0), + grid_type: str = "octahedral", +) -> tuple[float, dict]: + """Compute PSD for a pair of 2-D fields and return a scalar score + curves. + + This is the main entry point called from the Scores class. It handles NaN + masking, calls ``compute_psd_for_field`` for both inputs, and computes a + log-spectral MSE summary score. + + Parameters + ---------- + gt, p : np.typing.NDArray + Ground truth and prediction arrays of shape ``(n_samples, n_points)``. + lats, lons : np.typing.NDArray | None + Latitude / longitude arrays of length ``n_points`` (or None). + nlat : int | None + Number of latitudes (for SHT fallback). + n_points : int + Original number of spatial points (before NaN masking). + psd_method : str + ``"sht"`` or ``"fft"``. + psd_regrid_resolution : float + Grid spacing for fft method. + psd_sht_truncation : int | None + Spectral truncation for SHT. + lat_range : tuple[float, float] + Latitude bounds for fft method. + + Returns + ------- + score : float + Log-spectral MSE scalar. + attrs : dict + Dict with keys ``"frequencies"``, ``"psd_target"``, ``"psd_prediction"`` + (lists for JSON serialization). + """ + # Handle NaN grid points (e.g. from regional masking). + valid_mask = ~np.isnan(gt).all(axis=0) + gt = gt[:, valid_mask] + p = p[:, valid_mask] + + # Filter lat/lon to match valid points + lats_valid = lats[valid_mask] if lats is not None and len(lats) == n_points else lats + lons_valid = lons[valid_mask] if lons is not None and len(lons) == n_points else lons + nlat_valid = len(np.unique(lats_valid)) if lats_valid is not None else nlat + + # SHT requires the structurally-complete grid (all points present). + # If the data is a regional subset, skip with a warning. + if psd_method == "sht": + if grid_type == "octahedral": + expected_pts = sum(_octahedral_lons_per_lat(nlat_valid)) + elif grid_type == "regular": + expected_pts = sum(_regular_lons_per_lat(nlat_valid)) + else: + expected_pts = None # cannot validate + actual_pts = gt.shape[-1] + if expected_pts is not None and actual_pts != expected_pts: + _logger.warning( + f"PSD (SHT): grid point mismatch ({actual_pts} vs expected {expected_pts} " + f"for grid_type={grid_type!r}, nlat={nlat_valid}). SHT scores are only " + f"available for the full (global/unmasked) grid. Skipping this region." + ) + return np.nan, {} + + try: + freq_gt, psd_gt = compute_psd_for_field( + data=gt, + method=psd_method, + nlat=nlat_valid, + lats=lats_valid, + lons=lons_valid, + lat_range=lat_range, + regrid_resolution=psd_regrid_resolution, + sht_truncation=psd_sht_truncation, + grid_type=grid_type, + ) + freq_p, psd_p = compute_psd_for_field( + data=p, + method=psd_method, + nlat=nlat_valid, + lats=lats_valid, + lons=lons_valid, + lat_range=lat_range, + regrid_resolution=psd_regrid_resolution, + sht_truncation=psd_sht_truncation, + grid_type=grid_type, + ) + except Exception: + _logger.exception("PSD computation failed, returning NaN.") + return np.nan, {} + + # Scalar summary: mean squared error of log10 PSD + valid = (psd_gt > 0) & (psd_p > 0) + if valid.any(): + log_mse = float(np.mean((np.log10(psd_p[valid]) - np.log10(psd_gt[valid])) ** 2)) + else: + log_mse = np.nan + + attrs = { + "frequencies": freq_gt.tolist(), + "psd_target": psd_gt.tolist(), + "psd_prediction": psd_p.tolist(), + } + + return log_mse, attrs diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index afee96d9a..28b58bb90 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -16,6 +16,7 @@ import xarray as xr from scipy.spatial import cKDTree +from weathergen.evaluate.scores.psd import compute_psd_score from weathergen.evaluate.scores.score_utils import to_list # from common.io import MockIO @@ -198,6 +199,7 @@ def __init__( "seeps": self.calc_seeps, "qq_analysis": self.calc_quantiles, "nse": self.calc_nse, + "psd": self.calc_psd, } self.prob_metrics_dict = { "ssr": self.calc_ssr, @@ -1807,3 +1809,179 @@ def calc_quantiles( _logger.info(f"Q-Q analysis completed with {len(overall_qq_score.attrs)} attributes") return overall_qq_score + + def calc_psd( + self, + p: xr.DataArray, + gt: xr.DataArray, + psd_method: str = "sht", + psd_regrid_resolution: float = 1.0, + psd_sht_truncation: int | None = None, + lat_range: tuple[float, float] = (-60.0, 60.0), + grid_type: str = "octahedral", + ) -> xr.DataArray: + """Compute power spectral density for prediction and ground truth. + + Returns a scalar summary score (log-spectral MSE) and stores the full + PSD curves in ``.attrs`` for plotting downstream. + + Parameters + ---------- + p: xr.DataArray + Forecast data array + gt: xr.DataArray + Ground truth data array + psd_method: str + Method to compute the PSD. Options: 'sht' (spherical harmonic transform), + 'fft' (2D Fourier transform) + psd_regrid_resolution: float + Resolution in degrees to regrid data for PSD calculation. Default is 1.0 degree + psd_sht_truncation: int | None + Maximum spherical harmonic degree for truncation. If None, no truncation is applied. + lat_range: tuple[float, float] + Latitude range (min, max) to include in PSD calculation. Default is (-60, + 60) degrees. + grid_type: str + Type of grid for PSD calculation. Options: 'octahedral', 'regular'. + Default is 'octahedral'. + + Returns + ------- + xr.DataArray + Power spectral density score (log-spectral MSE) averaged over aggregation dimensions. + + """ + if self._agg_dims is None: + raise ValueError("Cannot calculate PSD without aggregation dimensions.") + if len(self._agg_dims) != 1: + raise ValueError( + f"PSD expects exactly one spatial aggregation dimension, " + f"got agg_dims={self._agg_dims}." + ) + spatial_dim = self._agg_dims[0] + if spatial_dim not in gt.dims: + raise ValueError( + f"Spatial dimension '{spatial_dim}' not found in dims {list(gt.dims)}." + ) + + n_points = gt.sizes[spatial_dim] + nlat, lats, lons = self._get_psd_grid_info(gt, spatial_dim) + + if psd_method == "fft" and (lats is None or lons is None): + raise ValueError(f"PSD method 'fft' requires lat/lon coords on '{spatial_dim}'.") + + psd_kwargs = dict( + lats=lats, + lons=lons, + nlat=nlat, + n_points=n_points, + psd_method=psd_method, + psd_regrid_resolution=psd_regrid_resolution, + psd_sht_truncation=psd_sht_truncation, + lat_range=lat_range, + grid_type=grid_type, + ) + + # Dims to preserve (e.g. channel) vs batch dims (sample, ens) + other_dims = [d for d in gt.dims if d != spatial_dim] + preserve_dims = [d for d in other_dims if d not in ("sample", "ens")] + + if not preserve_dims: + gt_np, p_np = self._stack_for_psd(gt, p, spatial_dim, n_points) + slice_score, slice_attrs = compute_psd_score(gt=gt_np, p=p_np, **psd_kwargs) + score = xr.DataArray(slice_score) + score.attrs.update(slice_attrs) + score.attrs["psd_method"] = psd_method + return score + + # Iterate over preserved dims (typically per channel) + shape = tuple(gt.sizes[d] for d in preserve_dims) + score_values = np.empty(shape) + all_attrs: dict = {} + + for idx in np.ndindex(*shape): + sel = dict(zip(preserve_dims, idx, strict=False)) + gt_slice = gt.isel(**sel) + p_slice = p.isel(**sel) + gt_np, p_np = self._stack_for_psd(gt_slice, p_slice, spatial_dim, n_points) + + slice_score, slice_attrs = compute_psd_score(gt=gt_np, p=p_np, **psd_kwargs) + score_values[idx] = slice_score + + key = "_".join( + str(gt.coords[d].values[i]) if d in gt.coords else str(i) for d, i in sel.items() + ) + for k, v in slice_attrs.items(): + all_attrs[f"{key}/{k}"] = v + + coords = {d: gt.coords[d] for d in preserve_dims if d in gt.coords} + score = xr.DataArray(score_values, dims=preserve_dims, coords=coords) + all_attrs["psd_method"] = psd_method + all_attrs["preserve_dims"] = preserve_dims + score.attrs.update(all_attrs) + return score + + @staticmethod + def _get_psd_grid_info( + gt: xr.DataArray, spatial_dim: str + ) -> tuple[int | None, np.typing.NDArray | None, np.typing.NDArray | None]: + """ + Extract nlat, lats, lons from ground-truth coords. + + Parameters + ---------- + gt: xr.DataArray + Ground truth data array with lat/lon coordinates. + spatial_dim: str + Name of the spatial dimension along which to compute the PSD. + Returns + ------- + nlat: int | None + Number of latitude points, or None if lat/lon coords are not found. + lats: np.typing.NDArray | None + Latitude values, or None if lat/lon coords are not found. + lons: np.typing.NDArray | None + Longitude values, or None if lat/lon coords are not found. + + """ + if "lat" in gt.coords and "lon" in gt.coords: + if gt.coords["lat"].dims == (spatial_dim,) and gt.coords["lon"].dims == (spatial_dim,): + lats = gt.coords["lat"].values + lons = gt.coords["lon"].values + return len(np.unique(lats)), lats, lons + raise ValueError(f"PSD requires lat/lon coords on spatial dimension '{spatial_dim}'.") + + @staticmethod + def _stack_for_psd( + gt: xr.DataArray, p: xr.DataArray, spatial_dim: str, n_points: int + ) -> tuple[np.typing.NDArray, np.typing.NDArray]: + """ + Reshape data to (n_batch, n_points) for PSD computation. + + Parameters + ---------- + gt: xr.DataArray + Ground truth data array. + p: xr.DataArray + Forecast data array. + spatial_dim: str + Name of the spatial dimension along which to compute the PSD. + n_points: int + Number of points along the spatial dimension. + Returns + ------- + gt_np: np.typing.NDArray + Reshaped ground truth data of shape (n_batch, n_points). + p_np: np.typing.NDArray + Reshaped forecast data of shape (n_batch, n_points). + """ + non_spatial = [d for d in gt.dims if d != spatial_dim] + if non_spatial: + gt_np = gt.transpose(*non_spatial, spatial_dim).values.reshape(-1, n_points) + p_np = p.transpose( + *[d for d in p.dims if d != spatial_dim], spatial_dim + ).values.reshape(-1, n_points) + else: + gt_np = gt.values.reshape(1, -1) + p_np = p.values.reshape(1, -1) + return gt_np, p_np diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py b/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py index 82bc8a026..b56dcb7ae 100644 --- a/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score_orchestration.py @@ -375,10 +375,11 @@ def store_metrics_for_region( for fstep, combined_metrics, _fstep_attrs in fstep_results: criteria = { "forecast_step": int(fstep), - "sample": combined_metrics.sample.values, "channel": combined_metrics.channel.values, "metric": combined_metrics.metric.values, } + if "sample" in combined_metrics.dims: + criteria["sample"] = combined_metrics.sample.values if "ens" in combined_metrics.dims: criteria["ens"] = combined_metrics.ens.values @@ -420,11 +421,20 @@ def store_metrics_for_region( for metric, parameters in metrics.items(): metric_data = metric_stream.sel({"metric": metric}).assign_attrs(parameters) + + # Restore attrs from all fsteps, keyed by fstep so downstream code + # (plotting) can produce one plot per forecast step. for (_stored_fstep, stored_metric), attrs in all_metric_attrs.items(): if stored_metric == metric and attrs: - _logger.debug(f"Restoring {len(attrs)} attributes for {metric}") - metric_data.attrs.update(attrs) - break + for k, v in attrs.items(): + metric_data.attrs[f"fstep_{_stored_fstep}/{k}"] = v + # Also store the list of fsteps that have attrs + attr_fsteps = sorted( + {fs for (fs, m) in all_metric_attrs if m == metric and all_metric_attrs[(fs, m)]} + ) + if attr_fsteps: + metric_data.attrs["attr_fsteps"] = attr_fsteps + _logger.debug(f"Stored per-fstep attributes for {metric}: fsteps={attr_fsteps}") local_scores.setdefault(metric, {}).setdefault(region, {}).setdefault(stream, {})[ run_id diff --git a/tests/test_psd.py b/tests/test_psd.py new file mode 100644 index 000000000..0b853c457 --- /dev/null +++ b/tests/test_psd.py @@ -0,0 +1,212 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + +"""Tests for weathergen.evaluate.scores.psd.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from weathergen.evaluate.scores.psd import ( + SphericalHarmonicTransform, + ZonalPSD, + compute_psd_for_field, + sht_psd, + zonal_psd, +) + + +# --------------------------------------------------------------------------- +# ZonalPSD (absorbs psd_calc.py) +# --------------------------------------------------------------------------- + + +class TestZonalPSD: + """Unit tests for the ZonalPSD class.""" + + def test_psd_1d_even(self): + """PSD of a pure sine should peak at the correct frequency bin.""" + n = 360 + freq_idx = 5 # wave number 5 + x = np.sin(2 * np.pi * freq_idx * np.arange(n) / n) + power = ZonalPSD.psd_1d(x) + assert power.shape == (n // 2,) + # Peak should be at index freq_idx - 1 (psd starts at freq index 1) + assert np.argmax(power) == freq_idx - 1 + + def test_positive_frequencies(self): + """Positive frequencies length and positivity.""" + npoints = 360 + freq = ZonalPSD.positive_frequencies(npoints, spacing_deg=1.0) + assert len(freq) == npoints // 2 + assert np.all(freq > 0) + + def test_compute_2d(self): + """ZonalPSD.compute averages PSD over latitude rows.""" + nlat, nlon = 10, 360 + field = np.random.default_rng(42).standard_normal((nlat, nlon)) + psd = ZonalPSD.compute(field) + assert psd.shape == (nlon // 2,) + assert np.all(psd >= 0) + + +# --------------------------------------------------------------------------- +# zonal_psd wrapper +# --------------------------------------------------------------------------- + + +class TestZonalPsdWrapper: + """Tests for the zonal_psd() dispatch-level wrapper.""" + + def test_basic_shape(self): + """Output shape matches expected number of positive frequencies.""" + nlat, nlon = 90, 360 + lats = np.linspace(90, -90, nlat) + lons = np.linspace(0, 359, nlon) + data = np.random.default_rng(0).standard_normal((nlat, nlon)) + + freq, psd = zonal_psd(data, lats, lons, lat_range=(-60, 60)) + assert freq.ndim == 1 + assert psd.ndim == 1 + assert len(freq) == len(psd) + assert np.all(psd >= 0) + + def test_multi_sample(self): + """Multi-sample input is averaged correctly.""" + nlat, nlon = 45, 180 + lats = np.linspace(90, -90, nlat) + lons = np.linspace(0, 358, nlon) + rng = np.random.default_rng(1) + data = rng.standard_normal((3, nlat, nlon)) # 3 samples + + freq, psd = zonal_psd(data, lats, lons) + assert freq.shape == psd.shape + + +# --------------------------------------------------------------------------- +# SphericalHarmonicTransform +# --------------------------------------------------------------------------- + + +class TestSphericalHarmonicTransform: + """Tests for the pure-numpy SHT transform.""" + + def test_regular_grid_shape(self): + """SHT on a regular grid produces the correct output shape.""" + nlat = 32 + nlon = 64 + trunc = 15 + sht = SphericalHarmonicTransform(lons_per_lat=[nlon] * nlat, truncation=trunc) + + x = np.random.default_rng(7).standard_normal(sht.n_grid_points) + coeffs = sht.transform(x) + assert coeffs.shape == (trunc + 1, trunc + 1) + assert np.iscomplexobj(coeffs) + + def test_octahedral_grid_shape(self): + """SHT on an octahedral grid produces the correct output shape.""" + nlat = 64 + trunc = 31 + lons = [20 + 4 * i for i in range(nlat // 2)] + lons = lons + list(reversed(lons)) + sht = SphericalHarmonicTransform(lons_per_lat=lons, truncation=trunc) + + x = np.random.default_rng(8).standard_normal(sht.n_grid_points) + coeffs = sht.transform(x) + assert coeffs.shape == (trunc + 1, trunc + 1) + + def test_constant_field(self): + """A constant field should have energy only at wavenumber 0.""" + nlat = 32 + nlon = 64 + trunc = 15 + sht = SphericalHarmonicTransform(lons_per_lat=[nlon] * nlat, truncation=trunc) + + x = np.ones(sht.n_grid_points) * 42.0 + coeffs = sht.transform(x) + # l=0, m=0 should dominate + assert np.abs(coeffs[0, 0]) > 0 + # All other coefficients should be negligible + mask = np.ones_like(coeffs, dtype=bool) + mask[0, 0] = False + assert np.allclose(coeffs[mask], 0, atol=1e-10) + + +# --------------------------------------------------------------------------- +# sht_psd +# --------------------------------------------------------------------------- + + +class TestShtPsd: + """Tests for the sht_psd high-level function.""" + + def test_output_shape(self): + """sht_psd returns wavenumbers and PSD with matching shapes.""" + nlat = 64 + lons = [20 + 4 * i for i in range(nlat // 2)] + lons = lons + list(reversed(lons)) + n_points = sum(lons) + data = np.random.default_rng(9).standard_normal(n_points) + + wn, psd = sht_psd(data, nlat=nlat, grid_type="octahedral") + assert wn.shape == psd.shape + assert len(wn) == nlat // 2 # truncation default = nlat // 2 - 1 → L = nlat // 2 + assert np.all(psd >= 0) + + def test_multi_sample(self): + """Multi-sample input is averaged.""" + nlat = 32 + nlon = 64 + n_points = nlat * nlon + data = np.random.default_rng(10).standard_normal((4, n_points)) + + wn, psd = sht_psd(data, nlat=nlat, grid_type="regular") + assert wn.shape == psd.shape + + +# --------------------------------------------------------------------------- +# compute_psd_for_field dispatch +# --------------------------------------------------------------------------- + + +class TestComputePsdForField: + """Tests for the dispatch function.""" + + def test_sht_dispatch(self): + """method='sht' calls sht_psd correctly.""" + nlat = 32 + nlon = 64 + data = np.random.default_rng(11).standard_normal(nlat * nlon) + wn, psd = compute_psd_for_field(data, method="sht", nlat=nlat, grid_type="regular") + assert len(wn) == len(psd) + assert np.all(psd >= 0) + + def test_zonal_dispatch(self): + """method='zonal' calls zonal_psd correctly.""" + nlat, nlon = 90, 360 + lats = np.linspace(90, -90, nlat) + lons = np.linspace(0, 359, nlon) + data = np.random.default_rng(12).standard_normal((nlat, nlon)) + + freq, psd = compute_psd_for_field( + data, method="zonal", lats=lats, lons=lons + ) + assert len(freq) == len(psd) + + def test_invalid_method(self): + """Unknown method raises ValueError.""" + with pytest.raises(ValueError, match="Unknown PSD method"): + compute_psd_for_field(np.zeros(10), method="invalid") + + def test_missing_nlat_for_sht(self): + """method='sht' without nlat raises ValueError.""" + with pytest.raises(ValueError, match="nlat is required"): + compute_psd_for_field(np.zeros(10), method="sht") + + def test_missing_lats_for_zonal(self): + """method='zonal' without lats raises ValueError.""" + with pytest.raises(ValueError, match="lats and lons are required"): + compute_psd_for_field(np.zeros(10), method="zonal") diff --git a/tests/test_sht_roundtrip.py b/tests/test_sht_roundtrip.py new file mode 100644 index 000000000..3fae345dc --- /dev/null +++ b/tests/test_sht_roundtrip.py @@ -0,0 +1,95 @@ +"""Test that SHT forward → inverse is (approximately) the identity.""" + +import numpy as np +import pytest + +from weathergen.evaluate.scores.psd import ( + InverseSphericalHarmonicTransform, + SphericalHarmonicTransform, + _octahedral_lons_per_lat, + _regular_lons_per_lat, +) + + +@pytest.mark.parametrize("grid_type,nlat", [ + ("regular", 32), + ("regular", 64), + ("octahedral", 32), + ("octahedral", 64), +]) +def test_sht_roundtrip_identity(grid_type: str, nlat: int) -> None: + """Applying SHT then inverse SHT on random noise recovers the original field.""" + rng = np.random.default_rng(42) + + if grid_type == "regular": + lons_per_lat = _regular_lons_per_lat(nlat) + else: + lons_per_lat = _octahedral_lons_per_lat(nlat) + + n_grid_points = sum(lons_per_lat) + truncation = nlat // 2 - 1 + + sht = SphericalHarmonicTransform(lons_per_lat=lons_per_lat, truncation=truncation) + isht = InverseSphericalHarmonicTransform(lons_per_lat=lons_per_lat, truncation=truncation) + + # Random spatial field + x = rng.standard_normal(n_grid_points) + + # Forward → inverse + coeffs = sht.transform(x) + x_reconstructed = isht.transform(coeffs) + + # The reconstruction is approximate due to truncation, but should be close + # for smooth-enough fields. For a bandlimited signal it should be exact. + # Use a generous tolerance since truncation discards high-frequency content. + assert x_reconstructed.shape == x.shape, ( + f"Shape mismatch: {x_reconstructed.shape} vs {x.shape}" + ) + + # Check correlation is positive — truncation discards high-frequency content + # so white noise won't be perfectly recovered, but the low-frequency part should match. + corr = np.corrcoef(x.ravel(), x_reconstructed.ravel())[0, 1] + assert corr > 0.30, f"Correlation too low: {corr:.4f}" + + # More importantly: verify the energy is preserved for the retained modes + # by checking that the relative L2 error is bounded + rel_error = np.linalg.norm(x - x_reconstructed) / np.linalg.norm(x) + assert rel_error < 1.0, f"Relative L2 error too large: {rel_error:.4f}" + + +@pytest.mark.parametrize("grid_type,nlat", [ + ("regular", 32), + ("regular", 64), + ("octahedral", 32), + ("octahedral", 64), +]) +def test_sht_roundtrip_bandlimited(grid_type: str, nlat: int) -> None: + """For a bandlimited signal, SHT → inverse SHT should be near-exact.""" + if grid_type == "regular": + lons_per_lat = _regular_lons_per_lat(nlat) + else: + lons_per_lat = _octahedral_lons_per_lat(nlat) + + n_grid_points = sum(lons_per_lat) + truncation = nlat // 2 - 1 + + sht = SphericalHarmonicTransform(lons_per_lat=lons_per_lat, truncation=truncation) + isht = InverseSphericalHarmonicTransform(lons_per_lat=lons_per_lat, truncation=truncation) + + # Create a bandlimited signal by doing inverse SHT on random coefficients + rng = np.random.default_rng(123) + L = truncation + 1 + random_coeffs = rng.standard_normal((L, L)) + 1j * rng.standard_normal((L, L)) + # Make it physically meaningful: zero out upper triangle (m > l) + for l in range(L): + random_coeffs[l, l + 1:] = 0.0 + + # Inverse → forward → inverse should give back the same spatial field + x_bandlimited = isht.transform(random_coeffs) + coeffs_recovered = sht.transform(x_bandlimited) + x_roundtrip = isht.transform(coeffs_recovered) + + np.testing.assert_allclose( + x_roundtrip, x_bandlimited, rtol=1e-6, atol=1e-10, + err_msg="Roundtrip on bandlimited signal should be near-exact", + )