Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions config/evaluate/eval_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@
# max_workers: 36 # hard cap on parallel workers (I/O, plotting, scoring)

evaluation:
metrics :
- rmse
- mae
- psd:
psd_method: "sht" # "sht" (SHT-based) or "fft" (regrid FFT)
psd_regrid_resolution: 1.0 # degrees, only for fft method
grid_type: "octahedral" # "octahedral" (default), "regular", "reduced".
metrics : ["rmse", "mae"]
# regions: ["global", "nhem"] # Define regions here to apply them to all streams (scores calculation)
summary_plots : true
Expand Down
7 changes: 7 additions & 0 deletions packages/evaluate/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,15 @@ dependencies = [
"earthkit-data==0.17.0",
"earthkit-utils==0.1.2",
"imageio[ffmpeg]>=2.37.2",
"scipy>=1.12",
]

[project.optional-dependencies]
fft = [
"scitools-iris>=3.11",
"cf-units",
]

[dependency-groups]
dev = [
"pytest>=8.3",
Expand Down
84 changes: 84 additions & 0 deletions packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import xarray as xr

Expand Down Expand Up @@ -667,3 +668,86 @@ def heat_map(
parts = ["heat_map", metric, tag]
name = "_".join(filter(None, parts))
plt.savefig(f"{self.out_plot_dir.joinpath(name)}.{self.image_format}")

# ------------------------------------------------------------------
# PSD summary plot
# ------------------------------------------------------------------

def psd_plot(
    self,
    psd_datasets: list[dict],
    labels: list[str],
    tag: str = "",
    variable: str = "",
    forecast_step: str = "",
) -> None:
    """Create a PSD summary plot overlaying multiple runs.

    The figure has two stacked panels: a log-log plot of the target and
    prediction spectra, and a semi-log plot of the prediction/target ratio
    (y-axis clipped to 0-2). The target spectrum and frequency axis of the
    *first* run are used as the common reference for both panels.

    Each entry in *psd_datasets* is a dict with keys
    ``frequencies``, ``psd_target``, ``psd_prediction``, ``psd_method``.

    Parameters
    ----------
    psd_datasets : list[dict]
        One dict per run, each containing the PSD arrays stored by
        ``Scores.calc_psd`` in ``.attrs``. If empty, a warning is logged
        and nothing is plotted.
    labels : list[str]
        Human-readable label for each run. Paired positionally with
        *psd_datasets*; surplus entries on either side are ignored.
    tag : str
        Filename tag (without extension). Falls back to ``"psd"`` if empty.
    variable : str
        Optional variable/channel name appended to the plot title.
    forecast_step : str
        Optional forecast step appended to the plot title.
    """
    if not psd_datasets:
        # Nothing to reference the target spectrum from; bail out early
        # instead of failing on psd_datasets[0].
        _logger.warning("No PSD datasets provided. Skipping PSD summary plot.")
        return
    if len(psd_datasets) != len(labels):
        _logger.warning(
            f"Got {len(psd_datasets)} PSD datasets but {len(labels)} labels; "
            "unmatched entries are ignored."
        )

    out_dir = Path(self.out_plot_dir_lines) / "psd"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Use the target from the first run as reference
    freq = np.asarray(psd_datasets[0]["frequencies"])
    tar_psd = np.asarray(psd_datasets[0]["psd_target"])

    fig, (ax_spec, ax_ratio) = plt.subplots(
        2,
        1,
        figsize=self.fig_size or (10, 8),
        gridspec_kw={"height_ratios": [2, 1], "hspace": 0.08},
    )

    # Close the figure even if plotting or saving fails, so repeated calls
    # do not accumulate open figures.
    try:
        ax_spec.loglog(freq, tar_psd, color="black", lw=1.5, label="Target")
        colors = plt.cm.tab10.colors

        for i, (ds, label) in enumerate(zip(psd_datasets, labels, strict=False)):
            c = colors[i % len(colors)]
            pred = np.asarray(ds["psd_prediction"])

            # Upper panel: log-log spectra
            ax_spec.loglog(
                np.asarray(ds["frequencies"]),
                pred,
                color=c,
                lw=1.5,
                label=label,
            )

            # Lower panel: ratio (pred / target). The ratio is only defined
            # when the run shares the reference target's sampling.
            if pred.shape != tar_psd.shape:
                _logger.warning(
                    f"PSD shape {pred.shape} of '{label}' does not match the reference "
                    f"target {tar_psd.shape}. Skipping ratio curve."
                )
                continue
            with np.errstate(divide="ignore", invalid="ignore"):
                # Mask bins where the target carries no power.
                ratio = np.where(tar_psd > 0, pred / tar_psd, np.nan)
            ax_ratio.semilogx(freq, ratio, color=c, lw=1.2, label=label)

        ax_spec.set_ylabel("Power")
        psd_method = psd_datasets[0].get("psd_method", "sht")
        title_parts = [f"PSD ({psd_method})"]
        if variable:
            title_parts.append(variable)
        if forecast_step:
            title_parts.append(f"step {forecast_step}")
        ax_spec.set_title(" – ".join(title_parts))
        ax_spec.legend(frameon=False, fontsize=7)
        ax_spec.grid(True, which="both", ls="--", alpha=0.4)

        ax_ratio.axhline(1.0, ls="--", color="gray", lw=0.8)
        ax_ratio.set_ylabel("Pred / Target")
        ax_ratio.set_xlabel("Frequency (1/deg)")
        ax_ratio.set_ylim(0, 2)
        ax_ratio.grid(True, which="both", ls="--", alpha=0.4)

        name = tag or "psd"
        fname = out_dir / f"{name}.{self.image_format}"
        _logger.info(f"Saving PSD summary plot to {fname}")
        fig.savefig(str(fname), bbox_inches="tight", dpi=self.dpi_val)
    finally:
        plt.close(fig)
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
bar_plot_metric_region,
heat_maps_metric_region,
plot_metric_region,
psd_plot_metric_region,
quantile_plot_metric_region,
ratio_plot_metric_region,
score_card_metric_region,
Expand Down Expand Up @@ -788,7 +789,12 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path):
for metric in metrics:
for region in scores_dict[metric].keys():
if eval_opt.get("summary_plots", False):
plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary)
if metric == "psd":
psd_plot_metric_region(metric, region, runs, scores_dict, plotter)
elif metric == "qq_analysis":
quantile_plot_metric_region(metric, region, runs, scores_dict, quantile_plotter)
else:
plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary)
if eval_opt.get("ratio_plots", False):
ratio_plot_metric_region(metric, region, runs, scores_dict, plotter, print_summary)
if eval_opt.get("heat_maps", False):
Expand All @@ -797,5 +803,3 @@ def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path):
score_card_metric_region(metric, region, runs, scores_dict, sc_plotter)
if eval_opt.get("bar_plots", False):
bar_plot_metric_region(metric, region, runs, scores_dict, br_plotter)
if metric == "qq_analysis":
quantile_plot_metric_region(metric, region, runs, scores_dict, quantile_plotter)
70 changes: 70 additions & 0 deletions packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,76 @@ def quantile_plot_metric_region(
)


def _extract_psd_attrs(data_ch: xr.DataArray, fstep: int, ch: str) -> list[dict] | None:
"""Extract PSD curve data from DataArray attrs for a given fstep/channel.

Returns a single-element list of dicts ready for the plotter, or None if keys are missing.
"""
attrs = data_ch.attrs
fp = f"fstep_{fstep}/"

for prefix in (f"{fp}{ch}/", fp):
if f"{prefix}frequencies" in attrs and f"{prefix}psd_target" in attrs:
return [
{
"frequencies": np.array(attrs[f"{prefix}frequencies"]),
"psd_target": np.array(attrs[f"{prefix}psd_target"]),
"psd_prediction": np.array(attrs[f"{prefix}psd_prediction"]),
"psd_method": attrs.get(f"{fp}psd_method", attrs.get("psd_method", "sht")),
}
]
return None


def psd_plot_metric_region(
    metric: str,
    region: str,
    runs: dict,
    scores_dict: dict,
    plotter: object,
) -> None:
    """Create PSD plots for all streams and channels for a given metric and region.

    PSD curves (frequencies, target PSD, prediction PSD) are stored in
    ``score.attrs`` by ``Scores.calc_psd`` and read back here. One plot is
    produced per run, stream, channel and forecast step listed in the
    ``attr_fsteps`` attribute.

    Parameters
    ----------
    metric : str
        Metric key into *scores_dict*; also used as filename prefix.
    region : str
        Region key into ``scores_dict[metric]``.
    runs : dict
        Mapping of run id to run configuration; the optional ``label`` entry
        is used as the legend label, falling back to the run id.
    scores_dict : dict
        Nested mapping ``metric -> region -> stream -> run_id -> scores``.
    plotter : object
        Object exposing ``psd_plot(psd_datasets, labels, tag=..., variable=...,
        forecast_step=...)``.
    """
    streams_set = collect_streams(runs)
    channels_set = collect_channels(scores_dict, metric, region, runs)

    for stream in streams_set:
        for ch in channels_set:
            for run_id, data in scores_dict[metric][region].get(stream, {}).items():
                if ch not in np.atleast_1d(data.channel.values):
                    continue

                data_ch = data.sel(channel=ch) if "channel" in data.dims else data
                if data_ch.isnull().all():
                    continue

                # The attribute may be a list, a NumPy array or a bare scalar
                # depending on how the scores were stored; normalise to a plain
                # list so the emptiness test and the loop below are well defined.
                raw_fsteps = data_ch.attrs.get("attr_fsteps")
                attr_fsteps = (
                    [] if raw_fsteps is None else np.atleast_1d(raw_fsteps).ravel().tolist()
                )
                if not attr_fsteps:
                    _logger.warning(f"PSD attrs missing for {run_id}/{stream}/{ch}. Skipping.")
                    continue

                label = runs[run_id].get("label", run_id)

                for fstep in attr_fsteps:
                    psd_datasets = _extract_psd_attrs(data_ch, fstep, ch)
                    if psd_datasets is None:
                        # No complete set of PSD curves for this step/channel.
                        continue

                    method_tag = psd_datasets[0].get("psd_method", "sht")
                    name = create_filename(
                        prefix=[metric, method_tag, region],
                        middle=[run_id],
                        suffix=[stream, ch, f"fstep{fstep}"],
                    )
                    plotter.psd_plot(
                        psd_datasets,
                        [label],
                        tag=name,
                        variable=ch,
                        forecast_step=str(fstep),
                    )


def create_filename(
*,
prefix: Sequence[str] = (),
Expand Down
Loading
Loading