From 512709e9513ff895367af9553942afa19bc2af13 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Tue, 27 Jan 2026 16:39:17 +0100 Subject: [PATCH 01/28] add verif to develop --- packages/verif/pyproject.toml | 75 ++++ .../verif/src/weathergen/verif/__init__.py | 0 .../src/weathergen/verif/create_verif.py | 398 ++++++++++++++++++ .../src/weathergen/verif/verif_config.py | 44 ++ .../src/weathergen/verif/verif_config.yaml | 59 +++ .../weathergen/verif/verif_interpolator.py | 204 +++++++++ .../src/weathergen/verif/verif_processers.py | 190 +++++++++ pyproject.toml | 6 +- uv.lock | 36 ++ 9 files changed, 1011 insertions(+), 1 deletion(-) create mode 100644 packages/verif/pyproject.toml create mode 100644 packages/verif/src/weathergen/verif/__init__.py create mode 100644 packages/verif/src/weathergen/verif/create_verif.py create mode 100644 packages/verif/src/weathergen/verif/verif_config.py create mode 100644 packages/verif/src/weathergen/verif/verif_config.yaml create mode 100644 packages/verif/src/weathergen/verif/verif_interpolator.py create mode 100644 packages/verif/src/weathergen/verif/verif_processers.py diff --git a/packages/verif/pyproject.toml b/packages/verif/pyproject.toml new file mode 100644 index 000000000..1fef65772 --- /dev/null +++ b/packages/verif/pyproject.toml @@ -0,0 +1,75 @@ +[project] +name = "weathergen-verif" +version = "0.1.0" +description = "The WeatherGenerator Machine Learning Earth System Model" +readme = "../../README.md" +requires-python = ">=3.11,<3.13" +dependencies = [ + "xarray>=2025.6.1", + "dask>=2024.9.1", + "zarr~=3.1.3", + "numcodecs<0.16.0", +] + +[dependency-groups] +dev = [ + "pytest~=8.3.5", + "pytest-mock>=3.14.1", + "ruff==0.9.7", +] + + + +[tool.black] + +# Wide rows +line-length = 100 + + +# The linting configuration +[tool.ruff] + +# Wide rows +line-length = 100 + +[tool.ruff.lint] +# All disabled until the code is formatted. +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", + # Banned imports + "TID" +] + +# These rules are sensible and should be enabled at a later stage. +ignore = [ + # "B006", + "B011", + "UP008", + "SIM117", + "SIM118", + "SIM102", + "SIM401", + "UP040", # TODO: enable later + # To ignore, not relevant for us + "SIM108" # in case additional norm layer supports are added in future +] + + + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/weathergen"] diff --git a/packages/verif/src/weathergen/verif/__init__.py b/packages/verif/src/weathergen/verif/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/packages/verif/src/weathergen/verif/create_verif.py b/packages/verif/src/weathergen/verif/create_verif.py new file mode 100644 index 000000000..e841da4b4 --- /dev/null +++ b/packages/verif/src/weathergen/verif/create_verif.py @@ -0,0 +1,398 @@ +from argparse import ArgumentParser, Namespace +from pathlib import Path +from time import time + +import numpy as np +import xarray as xr + +from weathergen.common.io import ZarrIO +from weathergen.verif.verif_config import Variables +from weathergen.verif.verif_interpolator import Interpolator_factory +from weathergen.verif.verif_processers import Processer_factory + + +def readarg() -> Namespace: + parser = ArgumentParser(description="Create verif files from a zarr file and observation file") + + parser.add_argument( + "-z", + "--zarr", + dest="zarrfile", + required=True, + help="Zarr file (.zarr)", + ) + + parser.add_argument( + "-b", + "--obs", + dest="obsfile", + required=True, + help="Observation file (.nc)", + ) + + parser.add_argument( + "-o", + "--output", + dest="outfiles", + default="output/verif/%S/%V/verif_%S_%V_%M.nc", + required=False, + help="Template for the output nc filenames, default will be to create output/verif/%S/%V \ + repertories where %S, %V, %d are replaced by the streams, variable and date", + ) + + parser.add_argument( + "-d", + "--date", + type=str, + dest="datefromto", + required=False, + default=None, + help="From to date in format %Y%m%d%H:%Y%m%d%H or %Y%m%d:%Y%m%d, \ + excluding the second date for instance 2024010100:2024020200", + ) + + parser.add_argument( + "-v", + "--variables", + default=None, + dest="variables", + nargs="*", + help="Do verif for these variables. Default: 2t", + ) + + parser.add_argument( + "-s", + "--streams", + default=None, + dest="streams", + nargs="*", + help="Do verif for this streams. Default: Infer from .zarr file", + ) + + parser.add_argument( + "-m", + "--method", + default="2d", + dest="method", + choices=["2d", "lat_lon", "nearest"], + help="Interpolation method. Default: 2d_interpolation", + ) + + parser.add_argument( + "-ds", + "--dataset", + default="prediction", + dest="dataset", + choices=["prediction", "target"], + help="Prediction or target dataset.", + ) + + parser.add_argument( + "-c", + "--config_file", + dest="config_file", + default=None, + type=str, + help="Config file used for generating verif file.", + ) + + args = parser.parse_args() + + return args + + +def create_output_paths( + stream: str, variable: str, outfiles: str, method: str, dataset: str +) -> Path: + """ + Create output directories for the verif files + and return path to output file + Args: + stream (string) + variables (list[string]) + outfiles (string): template for the output files + Outputs: + None + """ + outfile = Path( + outfiles.replace("%S", stream) + .replace("%V", variable) + .replace("%M", method) + .replace("%D", dataset) + ) + pathdir = outfile.parent + print(f"Output directory: {pathdir}") + pathdir.mkdir(exist_ok=True, parents=True) + return outfile + + +def generate_time_coordinates( + xdata: xr.DataArray, zarrio: ZarrIO, stream: str, dataset: str +) -> tuple[xr.DataArray, xr.DataArray]: + """ + Read samples and steps from ZarrIO object + and convert to xarray data objects + to be used as coordinates in verrif dataset + """ + + # Initial times are stored as numpy.datetime64 objects in verif + # Get the valid time of the first step for each sample + verif_times = [np.datetime64("nat", "h")] * len(zarrio.samples) + for sample in zarrio.samples: + item = zarrio.get_data(sample=sample, stream=stream, forecast_step=1) + if dataset == "prediction": + verif_times[int(sample)] = item.prediction.as_xarray().source_interval_start.values[0] + else: + verif_times[int(sample)] = item.target.as_xarray().source_interval_start.values[0] + + xrtime = xr.DataArray( + verif_times, + name="time", + dims=["time"], + coords={"time": verif_times}, + attrs={"standard_name": "forecast_reference_time"}, + ) + + dt = xdata.source_interval_end.values[0] - xdata.source_interval_start.values[0] + dt = dt.astype("timedelta64[h]") + + # Lead times are stored as float32 in verif + # Assume all time steps are the same, + # so loop over steps and multiply the time step size by index + leadtimes = np.ndarray(len(zarrio.forecast_steps), dtype=np.float32) + for i in range(len(zarrio.forecast_steps)): + leadtimes[i] = (i + 1) * dt + + xrleadtime = xr.DataArray( + leadtimes, + name="leadtime", + dims=["leadtime"], + coords={"leadtime": leadtimes}, + attrs={"units": "hour"}, + ) + + return xrtime, xrleadtime + + +def get_streams(zarrio: ZarrIO, arg_streams: list) -> list: + """ + Determine the stream, + either by getting streams from argument and check if they are in the zarr file + or just use all the streams in zarrio + Args: + zarrio: ZarrIO object + arg_streams: (list[string]) + Outputs: + streams: (list[string]) + """ + if arg_streams: + for stream in arg_streams: + if stream not in zarrio.streams: + raise Exception( + f"Stream {stream} is not present in .zarr file. zarrio.streams: \ + {zarrio.streams}" + ) + return arg_streams + else: + return zarrio.streams + + +def get_variables(xdata: xr.DataArray, config_file: Path, arg_variables: list, stream: str) -> list: + """ + Go through argument variables, + check if they are in the config_file and return + a list ov variables. + If no arguments are given, + return list of variables found in file. + """ + + config_variables = Variables(config_file) + + config_names = (cv.name for cv in config_variables) + + variables = [] + if arg_variables: + # Check if there's a config for requested variables + for av in arg_variables: + if av not in config_names: + raise Exception(f"Variable {av} does not have an entry in the config file") + + # Add requested variables to list of variables + for cv in config_variables: + if cv.name in arg_variables: + variables += [cv] + + else: + variables = [v for v in config_variables] + + # Check what variables exist in zarr file + vvars = [] + for v in variables: + w = v + + stringnames = (n for n in v.zarr_names if isinstance(n, str)) + listnames = (n for n in v.zarr_names if isinstance(n, list)) + + for n in stringnames: + if n in xdata.channel.values: + w.zarr_names = n + vvars += [w] + for n in listnames: + if len(set(n).intersection(xdata.channel.values)) == len(n): + w.zarr_names = tuple(n) + vvars += [w] + + variables = vvars + + if not (len(variables) == len(set(variables))): + raise Exception("Same variable appears multiple times in zarr file.") + + if not variables: + raise Exception("No variables with configuration found in zarr file.") + + for v in variables: + v.unit = v.zarr_units = v.zarr_units[stream] + + return variables + + +def get_obs_coordinates(obs: xr.Dataset): + """ + Extract latitude, longitude and altitude + from observation dataset + Args: + obs: Dataset + Outputs: + lat: DataArray + lon: DataArray + alt: DataArray + """ + + lat = obs.latitude.astype("float32") + lat.name = "lat" + + lon = obs.longitude.astype("float32") + lon.name = "lon" + + alt = obs.altitude.astype("float32") + + return lat, lon, alt + + +def process_config(config_file: str) -> Path: + """ + Convert input config_file argument to absolute Path object + """ + + if not config_file: + config_path = Path(__file__).parent / "verif_config.yaml" + else: + config_path = Path(config_file).resolve() + + if not config_path.is_file(): + raise Exception(f"{config_file} is not a file.") + + return config_path + + +def main(): + print("Start creating verif files") + + args = readarg() + + print("zarrfile:", args.zarrfile) + print("obsfile:", args.obsfile) + print("outputfile template:", args.outfiles) + print("dataset: ", args.dataset) + + obs = xr.open_dataset(args.obsfile) + lat, lon, alt = get_obs_coordinates(obs) + obs_coords = np.column_stack((lat.values, lon.values)) + + print() + print(obs) + + method_factory = Interpolator_factory(args.method) + + with ZarrIO(args.zarrfile) as zarrio: + streams = get_streams(zarrio, args.streams) + + t_start = time() + + for stream in streams: + print() + print("stream: ", stream) + + item = zarrio.get_data(sample=0, stream=stream, forecast_step=1) + if args.dataset == "prediction": + xdata = item.prediction.as_xarray() + else: + xdata = item.target.as_xarray() + + xrtime, xrleadtime = generate_time_coordinates(xdata, zarrio, stream, args.dataset) + + config_path = process_config(args.config_file) + + variables = get_variables(xdata, config_path, args.variables, stream) + + zarr_coords = np.column_stack((xdata.ipoint.lat.values, xdata.ipoint.lon.values)) + + interpolator = method_factory.get_interpolator(zarr_coords, obs_coords) + + data_shape = (len(zarrio.samples), len(zarrio.forecast_steps), obs.location.shape[0]) + + processers = Processer_factory(zarrio, obs, stream, interpolator, args.dataset) + + for v in variables: + vt_start = time() + + print() + print("variable: ", v.name) + + fcstdata = np.ndarray(data_shape, dtype=np.float32) + obsdata = np.ndarray(data_shape, dtype=np.float32) + + p = processers.get_processer(v.name) + + p.get_data(v, fcstdata, obsdata) + + xrobsdata = xr.DataArray( + obsdata, + dims=["time", "leadtime", "location"], + coords={"time": xrtime, "leadtime": xrleadtime, "location": obs.location}, + name="obs", + attrs=v.attributes, + ) + + xrfcstdata = xr.DataArray( + fcstdata, + dims=["time", "leadtime", "location"], + coords={"time": xrtime, "leadtime": xrleadtime, "location": obs.location}, + name="fcst", + attrs=v.attributes, + ) + + merged = xr.merge([xrfcstdata, xrobsdata, lat, lon, alt]) + + outfile = create_output_paths( + stream, v.name, args.outfiles, args.method, args.dataset + ) + + merged.to_netcdf( + outfile, encoding={"time": {"units": "seconds since 1970-01-01 00:00:00"}} + ) + + vt_end = time() + + print(v.name, "time: ", vt_end - vt_start) + print("merged: ") + print(merged) + + t_end = time() + + print() + print("all the time: ", t_end - t_start) + + +if __name__ == "__main__": + main() diff --git a/packages/verif/src/weathergen/verif/verif_config.py b/packages/verif/src/weathergen/verif/verif_config.py new file mode 100644 index 000000000..490899ebb --- /dev/null +++ b/packages/verif/src/weathergen/verif/verif_config.py @@ -0,0 +1,44 @@ +from pathlib import Path + +from yaml import safe_load + +__all__ = ["Variable"] + + +class Variable: + """ + Object representing a variable + """ + + def __init__(self, **kwargs): + self.name = kwargs.get("name") + self.attributes = kwargs.get("attributes") + self.zarr_names = kwargs.get("zarr_names") + self.zarr_units = kwargs.get("zarr_units") + self.obs_name = kwargs.get("obs_name") + self.obs_units = kwargs.get("obs_units") + + def __repr__(self): + return self.name + + +class Variables: + """ + Utility class to read a verif configuration from .yaml file + """ + + def __init__(self, filename: Path): + print(f"Reading configuration from file: {filename}") + + with open(filename) as stream: + self.schema = safe_load(stream) + + def __iter__(self): + return self.variables.__iter__() + + @property + def variables(self): + """ + Get a list of variables from the locally stored configuration + """ + return [Variable(**var) for var in self.schema] diff --git a/packages/verif/src/weathergen/verif/verif_config.yaml b/packages/verif/src/weathergen/verif/verif_config.yaml new file mode 100644 index 000000000..d6d073efa --- /dev/null +++ b/packages/verif/src/weathergen/verif/verif_config.yaml @@ -0,0 +1,59 @@ +# Default config file +# mapping variables from zarr and observation files +# to verif files + +- name: "2t" + attributes: { + units: "K", + long_name: "2 meter temperature", + conventions: "verif_1.0.0"} + zarr_names: ["2t"] + zarr_units: {"CERRA": "K", + "ERA5": "K"} + obs_name: "air_temperature" + obs_units: "K" + +- name: "sp" + attributes: { + units: "Pa", + long_name: "Surface pressure", + conventions: "verif_1.0.0"} + zarr_names: ["sp"] + zarr_units: {"CERRA": "Pa", + "ERA5": "Pa"} + obs_name: "surface_air_pressure" + obs_units: "Pa" + +- name: "tp" + attributes: { + units: "kg/m^2", + long_name: "Total precipitation amount", + conventions: "verif_1.0.0"} + zarr_names: ["tp"] + zarr_units: {"CERRA": "kg/m^2", + "ERA5": "m"} + obs_name: "precipitation_amount_1h" + obs_units: "kg/m^2" + +- name: "mslp" + attributes: { + units: "Pa", + long_name: "Mean sea level pressure", + conventions: "verif_1.0.0"} + zarr_names: ["msl"] + zarr_units: {"CERRA": "Pa", + "ERA5": "Pa"} + obs_name: "surface_air_pressure" + obs_units: "Pa" + +- name: "wind" + attributes: { + units: "m/s", + long_name: "wind speed", + conventions: "verif_1.0.0"} + zarr_names: ["10si", ["10u", "10v"]] + zarr_units: {"CERRA": "m/s", + "ERA5": "m/s"} + obs_name: "wind_speed" + obs_units: "m/s" + diff --git a/packages/verif/src/weathergen/verif/verif_interpolator.py b/packages/verif/src/weathergen/verif/verif_interpolator.py new file mode 100644 index 000000000..4bbb29d6b --- /dev/null +++ b/packages/verif/src/weathergen/verif/verif_interpolator.py @@ -0,0 +1,204 @@ +import numpy as np +from scipy.interpolate import LinearNDInterpolator +from scipy.spatial import Delaunay, KDTree + + +def convert_coordinates(coords): + """ + Convert lat-lon coordinates to cartesian coordinates in a unit box + """ + + xyz_coords = np.ndarray((coords.shape[0], 3), dtype="float32") + + xyz_coords[:, 0] = np.cos(np.pi * coords[:, 0] / 180.0) * np.cos(np.pi * coords[:, 1] / 180.0) + xyz_coords[:, 1] = np.cos(np.pi * coords[:, 0] / 180.0) * np.sin(np.pi * coords[:, 1] / 180.0) + xyz_coords[:, 2] = np.sin(np.pi * coords[:, 0] / 180.0) + + return xyz_coords + + +def normalise(x): + return x[:] / np.sum(x[:]) + + +class Verif_interpolator: + """ + Interpolator class that's either a wrapper for scipys LinearNDInterpolator + or uses the handmade approximate 2D linear interpolator + """ + + +class Verif_2D_interpolator(Verif_interpolator): + """ + Class that does approximate 2D interpolation + """ + + def __init__(self, grid_points, obs_points): + """ + Initialise the class and store gridpoints + """ + + grid_xyz = convert_coordinates(grid_points) + obs_xyz = convert_coordinates(obs_points) + + self.indices = np.ndarray((obs_points.shape[0], 5), dtype="float32") + tree = KDTree(grid_xyz) + _, self.indices = tree.query(obs_xyz, k=5) + + self.weights = np.ndarray((obs_points.shape[0], 3), dtype="float32") + self.compute_weights(grid_xyz, obs_xyz) + + def compute_weights(self, grid_xyz, obs_xyz): + """ + Compute the weights of the three nearest grid points + by computing the barycentric coordinates, + assuming that the observations are close enough to the plane through the grid points. + """ + + eps = 0.01 + + for i, (obs, indix) in enumerate(zip(obs_xyz, self.indices, strict=True)): + AB = grid_xyz[indix[1]] - grid_xyz[indix[0]] + AC = grid_xyz[indix[2]] - grid_xyz[indix[0]] + BC = grid_xyz[indix[2]] - grid_xyz[indix[1]] + AP = obs - grid_xyz[indix[0]] + BP = obs - grid_xyz[indix[1]] + + area_tot = np.linalg.norm(np.cross(AB, AC)) + self.weights[i, 0] = np.linalg.norm(np.cross(BC, BP)) + self.weights[i, 1] = np.linalg.norm(np.cross(AC, AP)) + self.weights[i, 2] = np.linalg.norm(np.cross(AB, AP)) + + if 1 - area_tot / np.sum(self.weights[i, :]) < eps: + continue + + indix[2] = indix[3] + + AC = grid_xyz[indix[2]] - grid_xyz[indix[0]] + BC = grid_xyz[indix[2]] - grid_xyz[indix[1]] + + area_tot = np.linalg.norm(np.cross(AB, AC)) + self.weights[i, 0] = np.linalg.norm(np.cross(BC, BP)) + self.weights[i, 1] = np.linalg.norm(np.cross(AC, AP)) + + if 1 - area_tot / np.sum(self.weights[i, :]) < eps: + continue + + indix[2] = indix[4] + + AC = grid_xyz[indix[2]] - grid_xyz[indix[0]] + BC = grid_xyz[indix[2]] - grid_xyz[indix[1]] + + self.weights[i, 0] = np.linalg.norm(np.cross(BC, BP)) + self.weights[i, 1] = np.linalg.norm(np.cross(AC, AP)) + + self.weights = self.weights / self.weights.sum(axis=1)[:, np.newaxis] + + def interpolate(self, values, intmap=None): + """ + Interpolate values to points + """ + + wvalues = np.ndarray((self.weights.shape[0]), dtype="float32") + + if intmap is None: + wvalues[:] = ( + self.weights[:, 0] * values[self.indices[:, 0]] + + self.weights[:, 1] * values[self.indices[:, 1]] + + self.weights[:, 2] * values[self.indices[:, 2]] + ) + else: + wvalues[:] = ( + self.weights[:, 0] * values[intmap[self.indices[:, 0]]] + + self.weights[:, 1] * values[intmap[self.indices[:, 1]]] + + self.weights[:, 2] * values[intmap[self.indices[:, 2]]] + ) + + return wvalues + + +class Verif_lat_lon_interpolator(Verif_interpolator): + """ + Class that does approximate 2D interpolation + """ + + def __init__(self, grid_points, obs_points): + """ + Initialise the class and store gridpoints + """ + + self.obs_points = obs_points + self.triangulation = Delaunay(grid_points) + + def interpolate(self, values, intmap=None): + """ + Interpolate values to points + """ + + newvalues = np.empty_like(values) + + if intmap is None: + newvalues = values + else: + for i in range(len(values)): + newvalues[i] = values[intmap[i]] + + interpolator = LinearNDInterpolator(self.triangulation, newvalues) + + return interpolator(self.obs_points).astype(np.float32) + + +class Verif_nearest_interpolator(Verif_interpolator): + """ + Class that does approximate 2D interpolation + """ + + def __init__(self, grid_points, obs_points): + """ + Initialise the class and store gridpoints + """ + + grid_xyz = convert_coordinates(grid_points) + obs_xyz = convert_coordinates(obs_points) + + tree = KDTree(grid_xyz) + _, self.indices = tree.query(obs_xyz, k=1) + + def interpolate(self, values, intmap=None): + """ + Interpolate values to points + """ + + wvalues = np.ndarray((self.indices.shape), dtype="float32") + + if intmap is None: + wvalues[:] = values[self.indices[:]] + else: + wvalues[:] = values[intmap[self.indices[:]]] + + return wvalues + + +class Interpolator_factory: + def __init__(self, method: str): + valid_methods = ("2d", "lat_lon", "nearest") + + if method not in valid_methods: + raise Exception(f"{method} is not a valid method.") + + self.method = method + + def get_interpolator( + self, zarr_coords: np.ndarray, obs_coords: np.ndarray + ) -> Verif_interpolator: + if self.method == "2d": + print("2D interpolation") + return Verif_2D_interpolator(zarr_coords, obs_coords) + + elif self.method == "lat_lon": + print("lat-lon interpolation") + return Verif_lat_lon_interpolator(zarr_coords, obs_coords) + + elif self.method == "nearest": + print("nearest neighbour interpolation") + return Verif_nearest_interpolator(zarr_coords, obs_coords) diff --git a/packages/verif/src/weathergen/verif/verif_processers.py b/packages/verif/src/weathergen/verif/verif_processers.py new file mode 100644 index 000000000..c56d23212 --- /dev/null +++ b/packages/verif/src/weathergen/verif/verif_processers.py @@ -0,0 +1,190 @@ +import numpy as np +import xarray as xr + +from weathergen.common.io import ZarrIO +from weathergen.evaluate.score import Scores +from weathergen.verif.verif_config import Variable +from weathergen.verif.verif_interpolator import Verif_interpolator + + +class Processer: + unit_conversion = {"kg/m^2": 1.0, "Pa": 1.0, "K": 1.0, "m/s": 1.0, "m": 1000.0} + + def __init__( + self, + zarrio: ZarrIO, + obs: xr.DataArray, + stream: str, + interpolator: Verif_interpolator, + dataset: str, + ): + self.zarrio = zarrio + self.obs = obs + self.stream = stream + self.interpolator = interpolator + self.dataset = dataset + + item = zarrio.get_data(sample=0, stream=stream, forecast_step=1) + if self.dataset == "prediction": + self.xdata = item.prediction.as_xarray() + else: + self.xdata = item.target.as_xarray() + + self.obs_dt = self.obs.time.values[1] - self.obs.time.values[0] + self.obs_dt = self.obs_dt.astype("timedelta64[h]") + + self.zarr_dt = ( + self.xdata.source_interval_end.values[0] - self.xdata.source_interval_start.values[0] + ) + self.zarr_dt = self.zarr_dt.astype("timedelta64[h]") + + def get_data(self, v: Variable, fcstdata, obsdata): + for sample in range(len(self.zarrio.samples)): + for step in range(len(self.zarrio.forecast_steps)): + item = self.zarrio.get_data( + sample=sample, stream=self.stream, forecast_step=step + 1 + ) + + if self.dataset == "prediction": + ydata = Scores.sort_by_coords(item.prediction.as_xarray(), self.xdata) + else: + ydata = Scores.sort_by_coords(item.target.as_xarray(), self.xdata) + + obsdata[sample, step, :] = self.get_obsdata( + self.obs, v.obs_name, ydata.valid_time.values[0] + ) + + fcstdata[sample, step, :] = self.get_fcstdata(ydata, v, sample, step + 1) + + def get_obsdata(self, obs: xr.DataArray, name: str, time: np.datetime64) -> np.ndarray: + return obs.data_vars[name].sel(time=time) + + def get_fcstdata(self, ydata: xr.DataArray, v: Variable, sample: int, step: int) -> np.ndarray: + return ( + self.interpolator.interpolate( + ydata.sel( + sample=sample, + stream=self.stream, + forecast_step=step, + channel=v.zarr_names, + ens=0, + ).values + ) + * self.unit_conversion[v.zarr_units] + ) + + +class MSLP_processer(Processer): + def get_obsdata(self, obs: xr.DataArray, name: str, time: np.datetime64) -> np.ndarray: + return self.compute_mslp(obs, time) + + def compute_mslp(self, obs: xr.DataArray, time: np.datetime64) -> np.ndarray: + # g = 9.80665 # Gravitational acceleration (m/s**2) + # R = 8.31447 # Universal gas constant (J/mol*K) + + # a = 0.0065 # Temperature lapse rate (K/m) + # Ch = 0.0012 # (K/Pa) + + A = 17.625 + B = 243.03 + C = 6.1094 + + P = obs.data_vars["surface_air_pressure"].sel(time=time) + T = obs.data_vars["air_temperature"].sel(time=time) + rh = obs.data_vars["relative_humidity"].sel(time=time) + + altitude = obs.altitude + + e = rh * 6.11 * np.power(10.0, ((7.5 * (T - 273.15)) / (T - 38.85))) + + dewpoint = np.where(~np.isnan(e), B * np.log(e / C) / (A - np.log(e / C)), T - 276.15) + + e = np.where(np.isnan(e), 0, e) + + Tv = T / ( + 1.0 - 0.379 * (6.11 * np.power(10.0, ((7.5 * dewpoint) / (237.7 + dewpoint))) / P) + ) + + # mslp = np.where(altitude >= 50., + # P * np.exp((g * altitude / R) / (T + 0.5 * a * altitude + e * Ch)), + # P + P * altitude / (29.27 * Tv)) + + mslp = P + P * altitude / (29.27 * Tv) + + return mslp + + +class Wind_processer(Processer): + def get_fcstdata(self, ydata: xr.DataArray, v: Variable, sample: int, step: int) -> np.ndarray: + if isinstance(v.zarr_names, str): + return super().get_fcstdata(ydata, v, sample, step) + + else: + u = self.interpolator.interpolate( + ydata.sel( + sample=sample, + stream=self.stream, + forecast_step=step, + channel=v.zarr_names[0], + ens=0, + ).values + ) + + v = self.interpolator.interpolate( + ydata.sel( + sample=sample, + stream=self.stream, + forecast_step=step, + channel=v.zarr_names[1], + ens=0, + ).values + ) + + return np.sqrt(np.square(u) + np.square(v)) + + +class Precipitation_processer(Processer): + def get_obsdata(self, obs: xr.DataArray, name: str, time: np.datetime64) -> np.ndarray: + if self.obs_dt >= self.zarr_dt: + return super().get_obsdata(obs, name, time) + else: + accumulate = np.zeros(self.obs.location.shape[0]) + int_factor = int(self.zarr_dt / self.obs_dt) + + for i in range(int_factor): + back_time = time - self.zarr_dt + (i + 1) * self.obs_dt + accumulate += super().get_obsdata(obs, name, back_time) + + return accumulate + + +class Processer_factory: + def __init__( + self, + zarrio: ZarrIO, + obs: xr.DataArray, + stream: str, + interpolator: Verif_interpolator, + dataset: str, + ): + self.zarrio = zarrio + self.obs = obs + self.stream = stream + self.interpolator = interpolator + self.dataset = dataset + + def get_processer(self, name: str) -> Processer: + if name == "mslp": + return MSLP_processer( + self.zarrio, self.obs, self.stream, self.interpolator, self.dataset + ) + elif name == "wind": + return Wind_processer( + self.zarrio, self.obs, self.stream, self.interpolator, self.dataset + ) + elif name == "tp": + return Precipitation_processer( + self.zarrio, self.obs, self.stream, self.interpolator, self.dataset + ) + else: + return Processer(self.zarrio, self.obs, self.stream, self.interpolator, self.dataset) diff --git a/pyproject.toml b/pyproject.toml index e4dfc14d4..b744452fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,8 @@ dependencies = [ "anemoi-datasets", "weathergen-common", "weathergen-evaluate", - "weathergen-readers-extra" + "weathergen-readers-extra", + "weathergen-verif" ] @@ -45,6 +46,7 @@ inference = "weathergen.run_train:inference" evaluate = "weathergen.evaluate.run_evaluation:evaluate" plot_train = "weathergen.utils.plot_training:plot_train" export = "weathergen.evaluate.export.export_inference:export" +create_verif = "weathergen.verif.create_verif:main" [build-system] requires = ["hatchling"] @@ -225,6 +227,7 @@ weathergen-common = { workspace = true } weathergen-evaluate = { workspace = true } weathergen-metrics = { workspace = true } weathergen-readers-extra = { workspace = true } +weathergen-verif = { workspace = true } flash-attn = [ @@ -265,6 +268,7 @@ members = [ "packages/evaluate", "packages/metrics", "packages/readers_extra", + "packages/verif", # Explicitly not depending on 'packages/dashboard' : this causes issues when deploying # the streamlit dashboard. ] diff --git a/uv.lock b/uv.lock index cbeab0cb6..67b58ecac 100644 --- a/uv.lock +++ b/uv.lock @@ -23,6 +23,7 @@ members = [ "weathergen-evaluate", "weathergen-metrics", "weathergen-readers-extra", + "weathergen-verif", ] [[package]] @@ -2984,6 +2985,7 @@ dependencies = [ { name = "weathergen-common", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "weathergen-evaluate", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "weathergen-readers-extra", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "weathergen-verif", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "wheel", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "zarr", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] @@ -3044,6 +3046,7 @@ requires-dist = [ { name = "weathergen-common", editable = "packages/common" }, { name = "weathergen-evaluate", editable = "packages/evaluate" }, { name = "weathergen-readers-extra", editable = "packages/readers_extra" }, + { name = "weathergen-verif", editable = "packages/verif" }, { name = "wheel" }, { name = "zarr", specifier = "~=3.1.3" }, ] @@ -3222,6 +3225,39 @@ dev = [ { name = "ruff", specifier = "==0.9.7" }, ] +[[package]] +name = "weathergen-verif" +version = "0.1.0" +source = { editable = "packages/verif" } +dependencies = [ + { name = "dask", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "numcodecs", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "xarray", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "zarr", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pytest", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "pytest-mock", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "ruff", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'darwin' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform == 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] + +[package.metadata] +requires-dist = [ + { name = "dask", specifier = ">=2024.9.1" }, + { name = "numcodecs", specifier = "<0.16.0" }, + { name = "xarray", specifier = ">=2025.6.1" }, + { name = "zarr", specifier = "~=3.1.3" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pytest", specifier = "~=8.3.5" }, + { name = "pytest-mock", specifier = ">=3.14.1" }, + { name = "ruff", specifier = "==0.9.7" }, +] + [[package]] name = "webencodings" version = "0.5.1" From e7fb3ee3f0f0a28e1096d293bf07c3691461eb33 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Mon, 9 Feb 2026 16:58:50 +0100 Subject: [PATCH 02/28] update to respect core developments and default units --- .../src/weathergen/verif/create_verif.py | 7 ++-- .../src/weathergen/verif/verif_config.yaml | 35 +++++++++++++------ .../src/weathergen/verif/verif_processers.py | 2 +- 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/packages/verif/src/weathergen/verif/create_verif.py b/packages/verif/src/weathergen/verif/create_verif.py index e841da4b4..e3f9ff402 100644 --- a/packages/verif/src/weathergen/verif/create_verif.py +++ b/packages/verif/src/weathergen/verif/create_verif.py @@ -251,7 +251,10 @@ def get_variables(xdata: xr.DataArray, config_file: Path, arg_variables: list, s raise Exception("No variables with configuration found in zarr file.") for v in variables: - v.unit = v.zarr_units = v.zarr_units[stream] + try: + v.zarr_units = v.zarr_units[stream] + except KeyError: + v.zarr_units = v.zarr_units["DEFAULT"] return variables @@ -314,7 +317,7 @@ def main(): method_factory = Interpolator_factory(args.method) - with ZarrIO(args.zarrfile) as zarrio: + with ZarrIO(args.zarrfile, read_only=True) as zarrio: streams = get_streams(zarrio, args.streams) t_start = time() diff --git a/packages/verif/src/weathergen/verif/verif_config.yaml b/packages/verif/src/weathergen/verif/verif_config.yaml index d6d073efa..0662007d6 100644 --- a/packages/verif/src/weathergen/verif/verif_config.yaml +++ b/packages/verif/src/weathergen/verif/verif_config.yaml @@ -8,8 +8,11 @@ long_name: "2 meter temperature", conventions: "verif_1.0.0"} zarr_names: ["2t"] - zarr_units: {"CERRA": "K", - "ERA5": "K"} + zarr_units: {"CERRA": "K", + "MEPS": "K", + "NORA3": "K", + "ERA5": "K", + "DEFAULT": "K"} obs_name: "air_temperature" obs_units: "K" @@ -19,8 +22,11 @@ long_name: "Surface pressure", conventions: "verif_1.0.0"} zarr_names: ["sp"] - zarr_units: {"CERRA": "Pa", - "ERA5": "Pa"} + zarr_units: {"CERRA": "Pa", + "MEPS": "Pa", + "NORA3": "Pa", + "ERA5": "Pa", + "DEFAULT": "Pa"} obs_name: "surface_air_pressure" obs_units: "Pa" @@ -30,8 +36,11 @@ long_name: "Total precipitation amount", conventions: "verif_1.0.0"} zarr_names: ["tp"] - zarr_units: {"CERRA": "kg/m^2", - "ERA5": "m"} + zarr_units: {"CERRA": "kg/m^2", + "MEPS": "kg/m^2", + "NORA3": "kg/m^2", + "ERA5": "m", + "DEFAULT": "kg/m^2"} obs_name: "precipitation_amount_1h" obs_units: "kg/m^2" @@ -41,8 +50,11 @@ long_name: "Mean sea level pressure", conventions: "verif_1.0.0"} zarr_names: ["msl"] - zarr_units: {"CERRA": "Pa", - "ERA5": "Pa"} + zarr_units: {"CERRA": "Pa", + "MEPS": "Pa", + "NORA3": "Pa", + "ERA5": "Pa", + "DEFAULT": "Pa"} obs_name: "surface_air_pressure" obs_units: "Pa" @@ -52,8 +64,11 @@ long_name: "wind speed", conventions: "verif_1.0.0"} zarr_names: ["10si", ["10u", "10v"]] - zarr_units: {"CERRA": "m/s", - "ERA5": "m/s"} + zarr_units: {"CERRA": "m/s", + "MEPS": "m/s", + "NORA3": "m/s", + "ERA5": "m/s", + "DEFAULT": "m/s"} obs_name: "wind_speed" obs_units: "m/s" diff --git a/packages/verif/src/weathergen/verif/verif_processers.py b/packages/verif/src/weathergen/verif/verif_processers.py index c56d23212..eb9a86e34 100644 --- a/packages/verif/src/weathergen/verif/verif_processers.py +++ b/packages/verif/src/weathergen/verif/verif_processers.py @@ -2,7 +2,7 @@ import xarray as xr from weathergen.common.io import ZarrIO -from weathergen.evaluate.score import Scores +from weathergen.evaluate.scores.score import Scores from weathergen.verif.verif_config import Variable from weathergen.verif.verif_interpolator import Verif_interpolator From 4dca44128b3f3fbdef40935dffc5047f77c528d3 Mon Sep 17 00:00:00 2001 From: Sorcha Date: Tue, 17 Feb 2026 17:00:37 +0100 Subject: [PATCH 03/28] initial reworking --- config/evaluate/config_zarr2verif.yaml | 111 ++++ .../weathergen/evaluate/export/cf_utils.py | 4 +- .../evaluate/export/export_inference.py | 7 +- .../evaluate/export/parsers/verif_parser.py | 530 ++++++++++++++++++ 4 files changed, 649 insertions(+), 3 deletions(-) create mode 100644 config/evaluate/config_zarr2verif.yaml create mode 100644 packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py diff --git a/config/evaluate/config_zarr2verif.yaml b/config/evaluate/config_zarr2verif.yaml new file mode 100644 index 000000000..da842cfed --- /dev/null +++ b/config/evaluate/config_zarr2verif.yaml @@ -0,0 +1,111 @@ +variables: + 2t: + var: 2t + long: 2 meter temperature + wg_unit: {CERRA: K, + MEPS: K, + NORA3: K, + ERA5: K, + DEFAULT: K} + verif_unit: K + obs_name: air_temperature + obs_units: K + level_type: sfc + + sp: + var: sp + long: Surface pressure + wg_unit: {CERRA: Pa, + MEPS: Pa, + NORA3: Pa, + ERA5: Pa, + DEFAULT: Pa} + verif_unit: Pa + obs_name: surface_air_pressure + obs_units: Pa + level_type: sfc + + tp: + var: tp + long: Total precipitation amount + wg_unit: {CERRA: kg/m^2, + MEPS: kg/m^2, + NORA3: kg/m^2, + ERA5: m, + DEFAULT: kg/m^2} + verif_unit: kg/m^2 + obs_name: precipitation_amount_1h + obs_units: kg/m^2 + level_type: sfc + + + mlsp: + var: msl + long: Mean sea level pressure + wg_unit: {CERRA: Pa, + MEPS: Pa, + NORA3: Pa, + ERA5: Pa, + DEFAULT: Pa} + verif_unit: Pa + obs_name: surface_air_pressure #check with Rolf + obs_units: kg/m^2 + level_type: sfc + + + wind: + var: 10si # derived channel + long: wind speed + wg_unit: {CERRA: m/s, + MEPS: m/s, + NORA3: m/s, + ERA5: m/s, + DEFAULT: m/s} + verif_unit: m/s + obs_name: wind_speed + obs_units: m/s + level_type: sfc + + +coordinates: + sfc: + valid_time: valid_time + lat: latitude + lon: longitude + stream: stream + forecast_step: leadtime + forecast_reference_time: forecast_reference_time + ncells: ncells + #pl: + #not needed + +dimensions: + valid_time: + wg: valid_time + verif: time + lat: + wg: latitude + verif: latitude + verif_unit: degrees_north + lon: + wg: longitude + verif: longitude + verif_unit: degrees_east + pressure_level: + wg: pressure + verif: pressure + verif_unit: hPa + forecast_reference_time: + wg: forecast_reference_time + verif: forecast_reference_time + forecast_step: + wg: leadtime + verif: forecast_period + long: time since forecast_reference_time + verif_unit: hours + stream: + wg: stream + verif: stream + ncells: + wg: ncells + verif: ncells \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py index 201ffa168..740d41bf2 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py @@ -94,10 +94,12 @@ def _get_file_extension(output_format: str) -> str: """ if output_format == "netcdf": return "nc" + if output_format == "verif": + return "nc" elif output_format == "quaver": return "grib" else: raise ValueError( f"Unsupported output format: {output_format}," - "supported formats are ['netcdf', 'DWD', 'quaver']" + "supported formats are ['netcdf', 'verif', 'quaver']" ) diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py index 0bf4be398..4b97c623c 100755 --- a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py @@ -75,7 +75,7 @@ def parse_args(args: list) -> argparse.Namespace: "--format", dest="output_format", type=str, - choices=["netcdf", "grib", "quaver"], + choices=["netcdf", "verif", "quaver"], help="Output file format (currently only netcdf supported)", required=True, ) @@ -207,7 +207,10 @@ def export_from_args(args: list) -> None: args = parse_args(sys.argv[1:]) # Load configuration - config_file = Path(_REPO_ROOT, "config/evaluate/config_zarr2cf.yaml") + if args.format == "verif": + config_file = Path(_REPO_ROOT, "config/evaluate/config_zarr2verif.yaml") + else: + config_file = Path(_REPO_ROOT, "config/evaluate/config_zarr2cf.yaml") config = OmegaConf.load(config_file) # check config loaded correctly assert len(config["variables"].keys()) > 0, "Config file not loaded correctly" diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py new file mode 100644 index 000000000..337c4a598 --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -0,0 +1,530 @@ +import logging +from pathlib import Path +from typing import Any + +import numpy as np +import xarray as xr +from omegaconf import OmegaConf + +from weathergen.evaluate.export.cf_utils import CfParser +from weathergen.evaluate.export.reshape import Regridder, find_pl + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + +""" +Usage: + +uv run export --run-id ciga1p9c --stream ERA5 +--output-dir ./test_output1 +--format verif --samples 1 2 --fsteps 1 2 3 +""" + + +class VerifParser(CfParser): + """ + Child class for handling NetCDF output format for MetNor Verif software. + """ + + def __init__(self, config: OmegaConf, **kwargs): + """ + CF-compliant parser that handles both regular and Gaussian grids. + + Parameters + ---------- + config : OmegaConf + Configuration defining variable mappings and dimension metadata. + ds : xr.Dataset + Input dataset. + + Returns + ------- + xr.Dataset + CF-compliant dataset with consistent naming and attributes. + """ + for k, v in kwargs.items(): + setattr(self, k, v) + + super().__init__(config=config, grid_type=self.grid_type) + + self.mapping = config.get("variables", {}) + + def process_sample( + self, + fstep_iterator_results: iter, + ref_time: np.datetime64, + ): + """ + Process results from get_data_worker: reshape, concatenate, add metadata, and save. + Parameters + ---------- + fstep_iterator_results : Iterator over results from get_data_worker. + ref_time : Forecast reference time for the sample. + Returns + ------- + None + """ + for var in self.channels: + da_var = [] + for result in fstep_iterator_results: + if result is None: + continue + + result = result.as_xarray().squeeze() + result = result.sel(channel=var) + result = self.reshape(result) + da_fs.append(result) + + _logger.info(f"Retrieved {len(da_fs)} forecast steps for type {self.data_type}.") + _logger.info(f"Saved sample data to {self.output_format} in {self.output_dir}.") + + if da_fs: + da_fs = self.concatenate(da_fs) + da_fs = self.assign_frt(da_fs, ref_time) + da_fs = self.add_attrs(da_fs) + da_fs = self.add_metadata(da_fs) + da_fs = self.add_encoding(da_fs) + self.save(da_fs, ref_time) + + def get_output_filename(self, forecast_ref_time: np.datetime64, stream:str, variable:str) -> Path: + """ + Generate output filename based on date + + Parameters + ---------- + forecast_ref_time : Forecast reference time to include in the filename. + + Returns + ------- + Full path to the output file. + """ + + frt = np.datetime_as_string(forecast_ref_time, unit="h") + out_fname = ( + Path(self.output_dir) / "verif" / stream / variable / frt + ) + # create nested output directories + pathdir = out_fname.parent + pathdir.mkdir(exist_ok=True, parents=True) + + return out_fname + + def reshape(self, data: xr.DataArray) -> xr.Dataset: + """ + Reshape dataset while preserving grid structure (regular or Gaussian). + + Parameters + ---------- + data : xr.DataArray + Input data with dimensions (ipoint, channel) + + Returns + ------- + xr.Dataset + Reshaped dataset appropriate for the grid type + """ + grid_type = self.grid_type + + # Original logic + var_dict, pl = find_pl(data.channel.values) + data_vars = {} + + for new_var, old_vars in var_dict.items(): + if len(old_vars) > 1: + data_vars[new_var] = xr.DataArray( + data.sel(channel=old_vars).values, + dims=["ipoint", "pressure_level"], + ) + else: + data_vars[new_var] = xr.DataArray( + data.sel(channel=old_vars[0]).values, + dims=["ipoint"], + ) + + reshaped_dataset = xr.Dataset(data_vars) + reshaped_dataset = reshaped_dataset.assign_coords( + ipoint=data.coords["ipoint"], + pressure_level=pl, + ) + + if grid_type == "regular": + # Use original reshape logic for regular grids + # This is safe for regular grids + reshaped_dataset = reshaped_dataset.set_index( + ipoint=("valid_time", "lat", "lon") + ).unstack("ipoint") + else: + # Use new logic for Gaussian/unstructured grids + reshaped_dataset = reshaped_dataset.set_index(ipoint2=("ipoint", "valid_time")).unstack( + "ipoint2" + ) + # rename ipoint to ncells + reshaped_dataset = reshaped_dataset.rename_dims({"ipoint": "ncells"}) + reshaped_dataset = reshaped_dataset.rename_vars({"ipoint": "ncells"}) + + return reshaped_dataset + + def regrid(self, ds: xr.Dataset) -> xr.Dataset: + """ + Regrid a single xarray Dataset to specified grid type and degree. + Parameters + ---------- + output_grid_type : Type of grid to regrid to (e.g., 'regular_ll'). + degree : Degree of the grid; for regular grids, this is the lat/lon degree spacing; + for Gaussian grids, this is the N number (e.g., 63 for N63). + Returns + ------- + Regridded xarray Dataset. + """ + if self.regrid_degree is None or self.regrid_type is None: + _logger.info("No regridding specified, skipping regridding step.") + return ds + nc_regridder = Regridder(ds, output_grid_type=self.regrid_type, degree=self.regrid_degree) + + regrid_ds = nc_regridder.regrid_ds() + return regrid_ds + + def concatenate( + self, + array_list, + dim="valid_time", + data_vars="minimal", + coords="different", + compat="equals", + combine_attrs="drop", + sortby_dim="valid_time", + ) -> xr.Dataset: + """ + Uses list of pred/target xarray DataArrays to save one sample to a NetCDF file. + + Parameters + ---------- + type_str : str + Type of data ('pred' or 'targ') to include in the filename. + array_list : list of xr.DataArray + List of DataArrays to concatenate. + dim : str, optional + Dimension along which to concatenate. Default is 'valid_time'. + data_vars : str, optional + How to handle data variables during concatenation. Default is 'minimal'. + coords : str, optional + How to handle coordinates during concatenation. Default is 'different'. + compat : str, optional + Compatibility check for variables. Default is 'equals'. + combine_attrs : str, optional + How to combine attributes. Default is 'drop'. + sortby_dim : str, optional + Dimension to sort the final dataset by. Default is 'valid_time'. + + Returns + ------- + xr.Dataset + Concatenated xarray Dataset. + """ + + data = xr.concat( + array_list, + dim=dim, + data_vars=data_vars, + coords=coords, + compat=compat, + combine_attrs=combine_attrs, + ).sortby(sortby_dim) + + return data + + def assign_frt(self, ds: xr.Dataset, reference_time: np.datetime64) -> xr.Dataset: + """ + Assign forecast reference time coordinate to the dataset. + + Parameters + ---------- + ds : xarray Dataset to assign coordinates to. + reference_time : Forecast reference time to assign. + + Returns + ------- + xarray Dataset with assigned forecast reference time coordinate. + """ + ds = ds.assign_coords(forecast_reference_time=reference_time) + + if "sample" in ds.coords: + ds = ds.drop_vars("sample") + + n_hours = self.fstep_hours.astype("int64") + ds["forecast_step"] = ds["forecast_step"] * n_hours + return ds + + def add_attrs(self, ds: xr.Dataset) -> xr.Dataset: + """ + Add CF-compliant attributes to the dataset variables. + + Parameters + ---------- + ds : xarray Dataset to add attributes to. + Returns + ------- + xarray Dataset with CF-compliant variable attributes. + """ + + if self.grid_type == "gaussian": + variables = self._attrs_gaussian_grid(ds) + else: + variables = self._attrs_regular_grid(ds) + + dataset = xr.merge(variables.values()) + dataset.attrs = ds.attrs + return dataset + + def add_encoding(self, ds: xr.Dataset) -> xr.Dataset: + """ + Add time encoding to the dataset variables. + Add aux coordinates to forecast_period + + Parameters + ---------- + ds : xarray Dataset to add time encoding to. + Returns + ------- + xarray Dataset with time encoding added. + """ + time_encoding = { + "units": "hours since 1970-01-01 00:00:00", + "calendar": "gregorian", + } + + if "valid_time" in ds.coords: + ds["valid_time"].encoding.update(time_encoding) + + if "forecast_reference_time" in ds.coords: + ds["forecast_reference_time"].encoding.update(time_encoding) + + if "forecast_period" in ds.coords: + ds["forecast_period"].encoding.update({"coordinates": "forecast_reference_time"}) + + return ds + + def _attrs_gaussian_grid(self, ds: xr.Dataset) -> xr.Dataset: + """ + Assign CF-compliant attributes to variables in a Gaussian grid dataset. + Parameters + ---------- + ds : xr.Dataset + Input dataset. + Returns + ------- + xr.Dataset + Dataset with CF-compliant variable attributes. + """ + variables = {} + dims_cfg = self.config.get("dimensions", {}) + ds, ds_attrs = self._assign_dim_attrs(ds, dims_cfg) + for var_name, da in ds.data_vars.items(): + mapped_info = self.mapping.get(var_name, {}) + mapped_name = mapped_info.get("var", var_name) + + coords = self._build_coordinate_mapping(ds, mapped_info, ds_attrs) + + attributes = { + "standard_name": mapped_info.get("std", var_name), + "units": mapped_info.get("std_unit", "unknown"), + } + if "long" in mapped_info: + attributes["long_name"] = mapped_info["long"] + variables[mapped_name] = xr.DataArray( + data=da.values, + dims=da.dims, + coords=coords, + attrs=attributes, + name=mapped_name, + ) + + return variables + + def _attrs_regular_grid(self, ds: xr.Dataset) -> xr.Dataset: + """ + Assign CF-compliant attributes to variables in a regular grid dataset. + Parameters + ---------- + ds : xr.Dataset + Input dataset. + Returns + ------- + xr.Dataset + Dataset with CF-compliant variable attributes. + """ + variables = {} + dims_cfg = self.config.get("dimensions", {}) + ds, ds_attrs = self._assign_dim_attrs(ds, dims_cfg) + dims_list = ["pressure", "latitude", "longitude", "valid_time"] + for var_name, da in ds.data_vars.items(): + mapped_info = self.mapping.get(var_name, {}) + mapped_name = mapped_info.get("var", var_name) + dims = dims_list.copy() + if mapped_info.get("level_type") == "sfc": + dims.remove("pressure") + + coords = self._build_coordinate_mapping(ds, mapped_info, ds_attrs) + + attributes = { + "standard_name": mapped_info.get("std", var_name), + "units": mapped_info.get("std_unit", "unknown"), + } + if "long" in mapped_info: + attributes["long_name"] = mapped_info["long"] + variables[mapped_name] = xr.DataArray( + data=da.values, + dims=dims, + coords={**coords, "valid_time": ds["valid_time"].values}, + attrs=attributes, + name=mapped_name, + ) + if da.encoding.get("coordinates"): + variables[mapped_name].encoding["coordinates"] = ( + da.encoding["coordinates"] + .replace(" lat ", " latitude ") + .replace(" lon ", " longitude "), + ) + + return variables + + def _assign_dim_attrs( + self, ds: xr.Dataset, dim_cfg: dict[str, Any] + ) -> tuple[xr.Dataset, dict[str, dict[str, str]]]: + """ + Assign CF attributes from given config file. + Parameters + ---------- + ds : xr.Dataset + Input dataset. + dim_cfg : Dict[str, Any] + Dimension configuration from mapping. + Returns + ------- + Dict[str, Dict[str, str]]: + Attributes for each dimension. + xr.Dataset: + Dataset with renamed dimensions. + """ + ds_attrs = {} + + for dim_name, meta in dim_cfg.items(): + wg_name = meta.get("wg", dim_name) + if dim_name in ds.dims and dim_name != wg_name: + ds = ds.rename_dims({dim_name: wg_name}) + + dim_attrs = {"standard_name": meta.get("std", wg_name)} + if meta.get("std_unit"): + dim_attrs["units"] = meta["std_unit"] + if meta.get("long"): + dim_attrs["long_name"] = meta["long"] + ds_attrs[wg_name] = dim_attrs + + return ds, ds_attrs + + def _build_coordinate_mapping( + self, ds: xr.Dataset, var_cfg: dict[str, Any], attrs: dict[str, dict[str, str]] + ) -> dict[str, Any]: + """Create coordinate mapping for a given variable. + Parameters + ---------- + ds : xr.Dataset + Input dataset. + var_cfg : Dict[str, Any] + Variable configuration from mapping. + attrs : Dict[str, Dict[str, str]] + Attributes for dimensions. + Returns + ------- + Dict[str, Any]: + Coordinate mapping for the variable. + """ + coords = {} + coord_map = self.config.get("coordinates", {}).get(var_cfg.get("level_type"), {}) + + for coord, new_name in coord_map.items(): + coords[new_name] = ( + ds.coords[coord].dims, + ds.coords[coord].values, + attrs[new_name], + ) + + return coords + + def _add_grid_attrs(self, ds: xr.Dataset, grid_info: dict | None = None) -> xr.Dataset: + """ + Add Gaussian grid metadata following CF conventions. + + Parameters + ---------- + ds : xr.Dataset + Dataset to add metadata to + grid_info : dict, optional + Dictionary with grid information: + - 'N': Gaussian grid number (e.g., N320) + - 'reduced': Whether it's a reduced Gaussian grid + + Returns + ------- + xr.Dataset + Dataset with added grid metadata + """ + + if self.grid_type != "gaussian": + return ds + + # ds = ds.copy() + # Add grid mapping information + ds.attrs["grid_type"] = "gaussian" + + # If grid info provided, add it + if grid_info: + ds.attrs["gaussian_grid_number"] = grid_info.get("N", "unknown") + ds.attrs["gaussian_grid_type"] = ( + "reduced" if grid_info.get("reduced", False) else "regular" + ) + + return ds + + def add_metadata(self, ds: xr.Dataset) -> xr.Dataset: + """ + Add CF conventions to the dataset attributes. + + Parameters + ---------- + ds : Input xarray Dataset to add conventions to. + Returns + ------- + xarray Dataset with CF conventions added to attributes. + """ + # ds = ds.copy() + ds.attrs["title"] = f"WeatherGenerator Output for {self.run_id} using stream {self.stream} and {self.variable}" + ds.attrs["institution"] = "WeatherGenerator Project" + ds.attrs["source"] = "WeatherGenerator v0.0" + ds.attrs["history"] = ( + "Created using the export_inference.py script on " + + np.datetime_as_string(np.datetime64("now"), unit="s") + ) + ds.attrs["Conventions"] = "verif_1.0.0" + # drop stream now it's in title + ds = ds.drop_vars("stream") + return ds + + def save(self, ds: xr.Dataset, forecast_ref_time: np.datetime64) -> None: + """ + Save the dataset to a NetCDF file. + + Parameters + ---------- + ds : xarray Dataset to save. + data_type : Type of data ('pred' or 'targ') to include in the filename. + forecast_ref_time : Forecast reference time to include in the filename. + + Returns + ------- + None + """ + out_fname = self.get_output_filename(forecast_ref_time) + _logger.info(f"Saving to {out_fname}.") + ds.to_netcdf(out_fname) + _logger.info(f"Saved NetCDF file to {out_fname}.") From 786bc5671ac73b89dc8a1b22b5bb66aa3a6c42bf Mon Sep 17 00:00:00 2001 From: Sorcha Date: Thu, 19 Feb 2026 13:44:16 +0100 Subject: [PATCH 04/28] working - TODO: delte uncesseary --- config/evaluate/config_zarr2verif.yaml | 40 ++- .../weathergen/evaluate/export/cf_utils.py | 2 - .../evaluate/export/export_inference.py | 26 +- .../weathergen/evaluate/export/io_utils.py | 64 ++-- .../evaluate/export/parser_factory.py | 4 +- .../evaluate/export/parsers/verif_parser.py | 321 +++++++++--------- .../weathergen/evaluate/export/preprocess.py | 63 ++++ .../src/weathergen/evaluate/export/reshape.py | 232 +++++++++++++ .../src/weathergen/verif/create_verif.py | 2 +- 9 files changed, 537 insertions(+), 217 deletions(-) create mode 100644 packages/evaluate/src/weathergen/evaluate/export/preprocess.py diff --git a/config/evaluate/config_zarr2verif.yaml b/config/evaluate/config_zarr2verif.yaml index da842cfed..b76bff633 100644 --- a/config/evaluate/config_zarr2verif.yaml +++ b/config/evaluate/config_zarr2verif.yaml @@ -39,8 +39,8 @@ variables: level_type: sfc - mlsp: - var: msl + msl: + var: mslp long: Mean sea level pressure wg_unit: {CERRA: Pa, MEPS: Pa, @@ -53,8 +53,8 @@ variables: level_type: sfc - wind: - var: 10si # derived channel + 10si: + var: wind # derived channel long: wind speed wg_unit: {CERRA: m/s, MEPS: m/s, @@ -69,43 +69,51 @@ variables: coordinates: sfc: - valid_time: valid_time + valid_time: time lat: latitude lon: longitude stream: stream forecast_step: leadtime forecast_reference_time: forecast_reference_time ncells: ncells - #pl: + pl: #not needed + pressure_level: pressure + valid_time: time + lat: latitude + lon: longitude + stream: stream + forecast_step: leadtime + forecast_reference_time: forecast_reference_time + ncells: ncells dimensions: valid_time: - wg: valid_time verif: time + std: time lat: - wg: latitude verif: latitude + std: latitude verif_unit: degrees_north lon: - wg: longitude verif: longitude + std: longitude verif_unit: degrees_east pressure_level: - wg: pressure verif: pressure + std: pressure verif_unit: hPa forecast_reference_time: - wg: forecast_reference_time verif: forecast_reference_time + std: forecast_reference_time forecast_step: - wg: leadtime - verif: forecast_period + verif: leadtime + std: forecast_period long: time since forecast_reference_time verif_unit: hours stream: - wg: stream verif: stream + std: stream ncells: - wg: ncells - verif: ncells \ No newline at end of file + verif: ncells + std: ncells \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py index 740d41bf2..0b5cd84fe 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py @@ -23,10 +23,8 @@ def __init__(self, config, **kwargs): grid_type : str Type of grid ('regular' or 'gaussian'). """ - for k, v in kwargs.items(): setattr(self, k, v) - self.config = config self.file_extension = _get_file_extension(self.output_format) self.fstep_hours = np.timedelta64(self.fstep_hours, "h") diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py index 4b97c623c..a296417f4 100755 --- a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py @@ -182,9 +182,30 @@ def parse_args(args: list) -> argparse.Namespace: help="Type of grid to regrid to (only used if --regrid-degree is specified)", ) + parser.add_argument( + "-b", + "--obs", + help= "observation file for creating verif files" + ) + + parser.add_argument( + "-m", + "--method", + default="2d", + choices=["2d", "lat_lon", "nearest"], + help="Interpolation method used for verif. Default: 2d_interpolation", + ) + + parser.add_argument( + "--verif-template", + default="verif/%S/%V/verif_%S_%V_%M.nc", + help="Template for the output nc filenames, default will be to create output/verif/%S/%V \ + repertories where %S, %V, %d are replaced by the streams, variable and date", + ) + args, unknown_args = parser.parse_known_args(args) if unknown_args: - _logger.warning(f"Unknown arguments: {unknown_args}") + _logger.warning(f"Unknown arguments: {unknown_args}") return args @@ -205,9 +226,8 @@ def export_from_args(args: list) -> None: args : List of command line arguments. """ args = parse_args(sys.argv[1:]) - # Load configuration - if args.format == "verif": + if args.output_format == "verif": config_file = Path(_REPO_ROOT, "config/evaluate/config_zarr2verif.yaml") else: config_file = Path(_REPO_ROOT, "config/evaluate/config_zarr2cf.yaml") diff --git a/packages/evaluate/src/weathergen/evaluate/export/io_utils.py b/packages/evaluate/src/weathergen/evaluate/export/io_utils.py index 06f0cf25f..710cc19b9 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/io_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/export/io_utils.py @@ -11,40 +11,40 @@ _logger.setLevel(logging.INFO) -def output_filename( - prefix: str, - run_id: str, - output_dir: str, - output_format: str, - forecast_ref_time: np.datetime64, - regrid_degree: float, -) -> Path: - """ - Generate output filename based on prefix (should refer to type e.g. pred/targ), run_id, sample - index, output directory, format and forecast_ref_time. +# def output_filename( +# prefix: str, +# run_id: str, +# output_dir: str, +# output_format: str, +# forecast_ref_time: np.datetime64, +# regrid_degree: float, +# ) -> Path: +# """ +# Generate output filename based on prefix (should refer to type e.g. pred/targ), run_id, sample +# index, output directory, format and forecast_ref_time. - Parameters - ---------- - prefix : Prefix for file name (e.g., 'pred' or 'targ'). - run_id :Run ID to include in the filename. - output_dir : Directory to save the output file. - output_format : Output file format (currently only 'netcdf' supported). - forecast_ref_time : Forecast reference time to include in the filename. +# Parameters +# ---------- +# prefix : Prefix for file name (e.g., 'pred' or 'targ'). +# run_id :Run ID to include in the filename. +# output_dir : Directory to save the output file. +# output_format : Output file format (currently only 'netcdf' supported). +# forecast_ref_time : Forecast reference time to include in the filename. - Returns - ------- - Full path to the output file. - """ - if output_format not in ["netcdf"]: - raise ValueError( - f"Unsupported output format: {output_format}, supported formates are ['netcdf']" - ) - file_extension = "nc" - frt = np.datetime_as_string(forecast_ref_time, unit="h") - if regrid_degree is not None: - run_id += f"_regular{regrid_degree, regrid_degree}" - out_fname = Path(output_dir) / f"{prefix}_{frt}_{run_id}.{file_extension}" - return out_fname +# Returns +# ------- +# Full path to the output file. +# """ +# if output_format not in ["netcdf", "verif"]: +# raise ValueError( +# f"Unsupported output format: {output_format}, supported formates are ['netcdf']" +# ) +# file_extension = "nc" +# frt = np.datetime_as_string(forecast_ref_time, unit="h") +# if regrid_degree is not None: +# run_id += f"_regular{regrid_degree, regrid_degree}" +# out_fname = Path(output_dir) / f"{prefix}_{frt}_{run_id}.{file_extension}" +# return out_fname def get_data_worker(args: tuple) -> xr.DataArray: diff --git a/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py b/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py index d248b0c78..9d5e1c621 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py @@ -3,6 +3,7 @@ from weathergen.evaluate.export.cf_utils import CfParser from weathergen.evaluate.export.parsers.netcdf_parser import NetcdfParser from weathergen.evaluate.export.parsers.quaver_parser import QuaverParser +from weathergen.evaluate.export.parsers.verif_parser import VerifParser class CfParserFactory: @@ -30,17 +31,16 @@ def get_parser(config: OmegaConf, **kwargs) -> CfParser: _parser_map = { "netcdf": (NetcdfParser, ["grid_type"]), "quaver": (QuaverParser, ["grid_type", "channels", "template"]), + "verif": (VerifParser, ["obs", "method", "verif_template"]) } fmt = kwargs.get("output_format") parser_class = _parser_map.get(fmt) parser = parser_class[0] - # allowed_keys = parser_class[1] # filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed_keys} if parser_class is None: raise ValueError(f"Unsupported format: {fmt}") - return parser(config, **kwargs) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index 337c4a598..2085bc307 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -7,7 +7,9 @@ from omegaconf import OmegaConf from weathergen.evaluate.export.cf_utils import CfParser -from weathergen.evaluate.export.reshape import Regridder, find_pl +from weathergen.evaluate.export.reshape import find_pl, get_grid_points, get_obs_coordinates, Interpolator_factory +from weathergen.evaluate.export.preprocess import compute_mslp, compute_precip + _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) @@ -15,9 +17,11 @@ """ Usage: -uv run export --run-id ciga1p9c --stream ERA5 ---output-dir ./test_output1 ---format verif --samples 1 2 --fsteps 1 2 3 +uv run export --run-id wgp6fowx --stream ERA5 \ +--output-dir ../test_output1 \ +--format verif --samples 1 2 --fsteps 1 2 3 \ +--obs /p/project1/weatherai/myhre1/metno_observations_v3.nc \ +--method 2d """ @@ -45,10 +49,19 @@ def __init__(self, config: OmegaConf, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) - super().__init__(config=config, grid_type=self.grid_type) + super().__init__(config=config) + + if not hasattr(self, "obs"): + raise ValueError("Observation data required for creating verif compliant NetCDFs") self.mapping = config.get("variables", {}) + # add extra attributes + self.obs = xr.open_dataset(self.obs) + lat, lon, _ = get_obs_coordinates(self.obs) + self.obs_coords = np.column_stack((lat.values, lon.values)) + self.zarr_coords = None + def process_sample( self, fstep_iterator_results: iter, @@ -64,51 +77,69 @@ def process_sample( ------- None """ - for var in self.channels: - da_var = [] - for result in fstep_iterator_results: - if result is None: - continue - - result = result.as_xarray().squeeze() - result = result.sel(channel=var) - result = self.reshape(result) - da_fs.append(result) - - _logger.info(f"Retrieved {len(da_fs)} forecast steps for type {self.data_type}.") - _logger.info(f"Saved sample data to {self.output_format} in {self.output_dir}.") - - if da_fs: - da_fs = self.concatenate(da_fs) - da_fs = self.assign_frt(da_fs, ref_time) - da_fs = self.add_attrs(da_fs) - da_fs = self.add_metadata(da_fs) - da_fs = self.add_encoding(da_fs) - self.save(da_fs, ref_time) - + required_channels = ["10u", "10v", "sp", "2t", "msl"] + self.channels = list(set(self.channels) & set(required_channels)) + da_fs = [] + for result in fstep_iterator_results: + if result is None: + continue + + result = result.as_xarray().squeeze() + result = result.sel(channel=self.channels) + result = self.preprocess(result) + result = self.reshape(result) + da_fs.append(result) + + _logger.info(f"Retrieved {len(da_fs)} forecast steps for type {self.data_type}.") + + if da_fs: + if self.zarr_coords is None: + self.zarr_coords = get_grid_points(da_fs[0]) + self.zarr_dt = self.get_zarr_dt(da_fs[0]) + + + da_fs = self.concatenate(da_fs) + da_fs = self.assign_frt(da_fs, ref_time) + da_fs = self.add_attrs(da_fs) + da_fs = self.add_metadata(da_fs) + da_fs = self.add_encoding(da_fs) + for verif_var in self.mapping.keys(): + da_var = self.regrid(da_fs, verif_var) + obs_result = self.obs_preprocess(ref_time, verif_var) + self.save(da_var, obs_result, ref_time, verif_var) + _logger.info(f"Saved {verif_var} data for {ref_time} to \ + {self.output_format} in {self.output_dir}.") + + def get_zarr_dt(self, ds): + zarr_dt = ( + ds.source_interval_end.values - ds.source_interval_start.values + ) + zarr_dt = zarr_dt.astype("timedelta64[h]") + return zarr_dt + def get_output_filename(self, forecast_ref_time: np.datetime64, stream:str, variable:str) -> Path: """ - Generate output filename based on date - - Parameters - ---------- - forecast_ref_time : Forecast reference time to include in the filename. - - Returns - ------- - Full path to the output file. + Create output directories for the verif files + and return path to output file + Args: + stream (string) + variables (list[string]) + outfiles (string): template for the output files + Outputs: + None """ - - frt = np.datetime_as_string(forecast_ref_time, unit="h") - out_fname = ( - Path(self.output_dir) / "verif" / stream / variable / frt + outfile = Path( + self.verif_template.replace("%S", self.stream) + .replace("%V", variable) + .replace("%M", self.method) + .replace("%D", self.data_type) ) - # create nested output directories - pathdir = out_fname.parent + outfile = Path(self.output_dir) / outfile + pathdir = outfile.parent + _logger.info(f"Output directory: {pathdir}") pathdir.mkdir(exist_ok=True, parents=True) - - return out_fname - + return outfile + def reshape(self, data: xr.DataArray) -> xr.Dataset: """ Reshape dataset while preserving grid structure (regular or Gaussian). @@ -164,25 +195,67 @@ def reshape(self, data: xr.DataArray) -> xr.Dataset: return reshaped_dataset - def regrid(self, ds: xr.Dataset) -> xr.Dataset: + + def obs_preprocess(self, ref_time, verif_var): + obs_data = self.obs + mapped_info = self.mapping.get(verif_var, {}) + obs_name = mapped_info.get("obs_name", {}) + if verif_var == "mslp": + obs_data[obs_name].values = compute_mslp(obs_data, ref_time) + if verif_var == "tp": + obs_data[obs_name].sel(time = ref_time).values = compute_precip(obs_data, self.zarr_dt, ref_time) + else: + pass + + return obs_data[obs_name] + + def preprocess(self, ds: xr.Dataset) -> xr.Dataset: + """ + Preprocess variables and only keep relevant ones + """ + if all(["10u", "10v"]) in self.channels: + u = ds.sel(channel='10u') + v = ds.sel(channel='10v') + # hypotenuese + wind_speed = xr.apply_ufunc( + np.hypot, u, v, dask="parallelized", output_dtypes=[ds.dtype] + ).astype("float32") + wind_speed = wind_speed.expand_dims(channel=['10si']) + if ds.chunks: + wind_speed = wind_speed.chunk({"ipoint": ds.chunks[ds.get_axis_num("ipoint")][0], "channel": 1}) + new_ds = xr.concat([ds, wind_speed], dim="channel") + new_ds.attrs = ds.attrs + + # remove unnecessary + new_ds = new_ds.drop_sel(channel = ['10u', '10v']) + return new_ds + else: + return ds + + def regrid(self, ds: xr.Dataset, verif_var) -> xr.Dataset: """ - Regrid a single xarray Dataset to specified grid type and degree. + Regrid a single xarray Datas v vvbet using specific method. Parameters ---------- - output_grid_type : Type of grid to regrid to (e.g., 'regular_ll'). - degree : Degree of the grid; for regular grids, this is the lat/lon degree spacing; - for Gaussian grids, this is the N number (e.g., 63 for N63). + ds: native xarray Dataset Returns ------- Regridded xarray Dataset. """ - if self.regrid_degree is None or self.regrid_type is None: - _logger.info("No regridding specified, skipping regridding step.") - return ds - nc_regridder = Regridder(ds, output_grid_type=self.regrid_type, degree=self.regrid_degree) + try: + ds_var = ds[verif_var] + except KeyError as e: + _logger.info(f"{verif_var} not available in WeatherGenerator output: {e}") + return + method_factory = Interpolator_factory(self.method) + interpolator = method_factory.get_interpolator(self.zarr_coords, self.obs_coords) + + for idx in range(len(ds_var.time.values)): + interpolator.interpolate( + ds_var.values[:,idx] + ) + return ds_var - regrid_ds = nc_regridder.regrid_ds() - return regrid_ds def concatenate( self, @@ -250,8 +323,7 @@ def assign_frt(self, ds: xr.Dataset, reference_time: np.datetime64) -> xr.Datase if "sample" in ds.coords: ds = ds.drop_vars("sample") - - n_hours = self.fstep_hours.astype("int64") + n_hours = self.fstep_hours.astype('int64') ds["forecast_step"] = ds["forecast_step"] * n_hours return ds @@ -266,12 +338,7 @@ def add_attrs(self, ds: xr.Dataset) -> xr.Dataset: ------- xarray Dataset with CF-compliant variable attributes. """ - - if self.grid_type == "gaussian": - variables = self._attrs_gaussian_grid(ds) - else: - variables = self._attrs_regular_grid(ds) - + variables = self._attrs_gaussian_grid(ds) dataset = xr.merge(variables.values()) dataset.attrs = ds.attrs return dataset @@ -279,7 +346,7 @@ def add_attrs(self, ds: xr.Dataset) -> xr.Dataset: def add_encoding(self, ds: xr.Dataset) -> xr.Dataset: """ Add time encoding to the dataset variables. - Add aux coordinates to forecast_period + Add aux coordinates to leadtime Parameters ---------- @@ -293,20 +360,20 @@ def add_encoding(self, ds: xr.Dataset) -> xr.Dataset: "calendar": "gregorian", } - if "valid_time" in ds.coords: - ds["valid_time"].encoding.update(time_encoding) + if "time" in ds.coords: + ds["time"].encoding.update(time_encoding) if "forecast_reference_time" in ds.coords: ds["forecast_reference_time"].encoding.update(time_encoding) - if "forecast_period" in ds.coords: - ds["forecast_period"].encoding.update({"coordinates": "forecast_reference_time"}) + if "leadtime" in ds.coords: + ds["leadtime"].encoding.update({"coordinates": "forecast_reference_time"}) return ds def _attrs_gaussian_grid(self, ds: xr.Dataset) -> xr.Dataset: """ - Assign CF-compliant attributes to variables in a Gaussian grid dataset. + Assign CF-compliant attributes to variables in a gaussian grid dataset. Parameters ---------- ds : xr.Dataset @@ -316,19 +383,28 @@ def _attrs_gaussian_grid(self, ds: xr.Dataset) -> xr.Dataset: xr.Dataset Dataset with CF-compliant variable attributes. """ + unit_conversion = {"kg/m^2": 1.0, "Pa": 1.0, "K": 1.0, "m/s": 1.0, "m": 1000.0} + variables = {} dims_cfg = self.config.get("dimensions", {}) ds, ds_attrs = self._assign_dim_attrs(ds, dims_cfg) for var_name, da in ds.data_vars.items(): mapped_info = self.mapping.get(var_name, {}) mapped_name = mapped_info.get("var", var_name) + mapped_units = mapped_info.get("wg_unit", {}) coords = self._build_coordinate_mapping(ds, mapped_info, ds_attrs) + wg_unit = mapped_units.get(self.stream, "DEFAULT") + verif_unit = mapped_info.get("verif_unit", None) + if wg_unit != verif_unit: + #perform unit conversion + da.values = da.values * unit_conversion[wg_unit] + attributes = { - "standard_name": mapped_info.get("std", var_name), - "units": mapped_info.get("std_unit", "unknown"), + "units": verif_unit, } + if "long" in mapped_info: attributes["long_name"] = mapped_info["long"] variables[mapped_name] = xr.DataArray( @@ -341,53 +417,6 @@ def _attrs_gaussian_grid(self, ds: xr.Dataset) -> xr.Dataset: return variables - def _attrs_regular_grid(self, ds: xr.Dataset) -> xr.Dataset: - """ - Assign CF-compliant attributes to variables in a regular grid dataset. - Parameters - ---------- - ds : xr.Dataset - Input dataset. - Returns - ------- - xr.Dataset - Dataset with CF-compliant variable attributes. - """ - variables = {} - dims_cfg = self.config.get("dimensions", {}) - ds, ds_attrs = self._assign_dim_attrs(ds, dims_cfg) - dims_list = ["pressure", "latitude", "longitude", "valid_time"] - for var_name, da in ds.data_vars.items(): - mapped_info = self.mapping.get(var_name, {}) - mapped_name = mapped_info.get("var", var_name) - dims = dims_list.copy() - if mapped_info.get("level_type") == "sfc": - dims.remove("pressure") - - coords = self._build_coordinate_mapping(ds, mapped_info, ds_attrs) - - attributes = { - "standard_name": mapped_info.get("std", var_name), - "units": mapped_info.get("std_unit", "unknown"), - } - if "long" in mapped_info: - attributes["long_name"] = mapped_info["long"] - variables[mapped_name] = xr.DataArray( - data=da.values, - dims=dims, - coords={**coords, "valid_time": ds["valid_time"].values}, - attrs=attributes, - name=mapped_name, - ) - if da.encoding.get("coordinates"): - variables[mapped_name].encoding["coordinates"] = ( - da.encoding["coordinates"] - .replace(" lat ", " latitude ") - .replace(" lon ", " longitude "), - ) - - return variables - def _assign_dim_attrs( self, ds: xr.Dataset, dim_cfg: dict[str, Any] ) -> tuple[xr.Dataset, dict[str, dict[str, str]]]: @@ -409,13 +438,13 @@ def _assign_dim_attrs( ds_attrs = {} for dim_name, meta in dim_cfg.items(): - wg_name = meta.get("wg", dim_name) + wg_name = meta.get("verif", dim_name) if dim_name in ds.dims and dim_name != wg_name: ds = ds.rename_dims({dim_name: wg_name}) dim_attrs = {"standard_name": meta.get("std", wg_name)} - if meta.get("std_unit"): - dim_attrs["units"] = meta["std_unit"] + if meta.get("verif_unit"): + dim_attrs["units"] = meta["verif_unit"] if meta.get("long"): dim_attrs["long_name"] = meta["long"] ds_attrs[wg_name] = dim_attrs @@ -451,41 +480,6 @@ def _build_coordinate_mapping( return coords - def _add_grid_attrs(self, ds: xr.Dataset, grid_info: dict | None = None) -> xr.Dataset: - """ - Add Gaussian grid metadata following CF conventions. - - Parameters - ---------- - ds : xr.Dataset - Dataset to add metadata to - grid_info : dict, optional - Dictionary with grid information: - - 'N': Gaussian grid number (e.g., N320) - - 'reduced': Whether it's a reduced Gaussian grid - - Returns - ------- - xr.Dataset - Dataset with added grid metadata - """ - - if self.grid_type != "gaussian": - return ds - - # ds = ds.copy() - # Add grid mapping information - ds.attrs["grid_type"] = "gaussian" - - # If grid info provided, add it - if grid_info: - ds.attrs["gaussian_grid_number"] = grid_info.get("N", "unknown") - ds.attrs["gaussian_grid_type"] = ( - "reduced" if grid_info.get("reduced", False) else "regular" - ) - - return ds - def add_metadata(self, ds: xr.Dataset) -> xr.Dataset: """ Add CF conventions to the dataset attributes. @@ -498,7 +492,7 @@ def add_metadata(self, ds: xr.Dataset) -> xr.Dataset: xarray Dataset with CF conventions added to attributes. """ # ds = ds.copy() - ds.attrs["title"] = f"WeatherGenerator Output for {self.run_id} using stream {self.stream} and {self.variable}" + ds.attrs["title"] = f"WeatherGenerator Output for {self.run_id} using stream {self.stream}" ds.attrs["institution"] = "WeatherGenerator Project" ds.attrs["source"] = "WeatherGenerator v0.0" ds.attrs["history"] = ( @@ -510,7 +504,7 @@ def add_metadata(self, ds: xr.Dataset) -> xr.Dataset: ds = ds.drop_vars("stream") return ds - def save(self, ds: xr.Dataset, forecast_ref_time: np.datetime64) -> None: + def save(self, ds: xr.Dataset, obs_result, forecast_ref_time: np.datetime64, verif_var) -> None: """ Save the dataset to a NetCDF file. @@ -524,7 +518,12 @@ def save(self, ds: xr.Dataset, forecast_ref_time: np.datetime64) -> None: ------- None """ - out_fname = self.get_output_filename(forecast_ref_time) + lat, lon, alt = get_obs_coordinates(self.obs) + if ds is None: + merged = xr.merge([obs_result, lat, lon, alt]) + else: + merged = xr.merge([ds, obs_result, lat, lon, alt]) + out_fname = self.get_output_filename(forecast_ref_time, self.stream, verif_var) _logger.info(f"Saving to {out_fname}.") - ds.to_netcdf(out_fname) + merged.to_netcdf(out_fname) _logger.info(f"Saved NetCDF file to {out_fname}.") diff --git a/packages/evaluate/src/weathergen/evaluate/export/preprocess.py b/packages/evaluate/src/weathergen/evaluate/export/preprocess.py new file mode 100644 index 000000000..9e0dab513 --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/export/preprocess.py @@ -0,0 +1,63 @@ +import logging +import numpy as np +import xarray as xr + + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + +""" +Extra helper functions to preprocess data +e.g. for verif applications +""" + + +def compute_mslp(obs: xr.DataArray, time: np.datetime64) -> np.ndarray: + # g = 9.80665 # Gravitational acceleration (m/s**2) + # R = 8.31447 # Universal gas constant (J/mol*K) + # a = 0.0065 # Temperature lapse rate (K/m) + # Ch = 0.0012 # (K/Pa) + + A = 17.625 + B = 243.03 + C = 6.1094 + + P = obs.data_vars["surface_air_pressure"].sel(time=time) + T = obs.data_vars["air_temperature"].sel(time=time) + rh = obs.data_vars["relative_humidity"].sel(time=time) + + altitude = obs.altitude + + e = rh * 6.11 * np.power(10.0, ((7.5 * (T - 273.15)) / (T - 38.85))) + + dewpoint = np.where(~np.isnan(e), B * np.log(e / C) / (A - np.log(e / C)), T - 276.15) + + e = np.where(np.isnan(e), 0, e) + + Tv = T / ( + 1.0 - 0.379 * (6.11 * np.power(10.0, ((7.5 * dewpoint) / (237.7 + dewpoint))) / P) + ) + + # mslp = np.where(altitude >= 50., + # P * np.exp((g * altitude / R) / (T + 0.5 * a * altitude + e * Ch)), + # P + P * altitude / (29.27 * Tv)) + + mslp = P + P * altitude / (29.27 * Tv) + + return mslp + +def compute_precip(obs_data, zarr_dt, frt): + + obs_dt = obs_data.time.values[1] - obs_data.time.values[0] + obs_dt = obs_dt.astype("timedelta64[h]") + + if obs_dt >= zarr_dt: + return obs_data["precipitation_amount_1h"].values + else: + accumulate = np.zeros(obs_data.location.shape[0]) + int_factor = int(zarr_dt / obs_dt) + + for i in range(int_factor): + back_time = frt - zarr_dt + (i + 1) * obs_dt + accumulate += obs_data.data_vars["precipitation_amount_1h"].sel(time=back_time) + return accumulate diff --git a/packages/evaluate/src/weathergen/evaluate/export/reshape.py b/packages/evaluate/src/weathergen/evaluate/export/reshape.py index e3e2f11d9..b1e3adb2d 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/reshape.py +++ b/packages/evaluate/src/weathergen/evaluate/export/reshape.py @@ -14,6 +14,31 @@ Enhanced functions to handle Gaussian grids when converting from Zarr to NetCDF. """ +def get_obs_coordinates(obs: xr.Dataset): + """ + Extract latitude, longitude and altitude + from observation dataset + Args: + obs: Dataset + Outputs: + lat: DataArray + lon: DataArray + alt: DataArray + """ + + lat = obs.latitude.astype("float32") + lat.name = "lat" + + lon = obs.longitude.astype("float32") + lon.name = "lon" + + alt = obs.altitude.astype("float32") + + return lat, lon, alt + +def get_grid_points(data: xr.DataArray): + return np.column_stack((data.lat.values, data.lon.values)) + def detect_grid_type(data: xr.DataArray) -> str: """ @@ -541,3 +566,210 @@ def regrid_da(self, da: xr.DataArray) -> xr.DataArray: is not implemented yet.""" ) return regrid_da + +## classes for verif + +import numpy as np +from scipy.interpolate import LinearNDInterpolator +from scipy.spatial import Delaunay, KDTree + + +def convert_coordinates(coords): + """ + Convert lat-lon coordinates to cartesian coordinates in a unit box + """ + + xyz_coords = np.ndarray((coords.shape[0], 3), dtype="float32") + + xyz_coords[:, 0] = np.cos(np.pi * coords[:, 0] / 180.0) * np.cos(np.pi * coords[:, 1] / 180.0) + xyz_coords[:, 1] = np.cos(np.pi * coords[:, 0] / 180.0) * np.sin(np.pi * coords[:, 1] / 180.0) + xyz_coords[:, 2] = np.sin(np.pi * coords[:, 0] / 180.0) + + return xyz_coords + + +def normalise(x): + return x[:] / np.sum(x[:]) + + +class Verif_interpolator: + """ + Interpolator class that's either a wrapper for scipys LinearNDInterpolator + or uses the handmade approximate 2D linear interpolator + """ + + +class Verif_2D_interpolator(Verif_interpolator): + """ + Class that does approximate 2D interpolation + """ + + def __init__(self, grid_points, obs_points): + """ + Initialise the class and store gridpoints + """ + + grid_xyz = convert_coordinates(grid_points) + obs_xyz = convert_coordinates(obs_points) + + self.indices = np.ndarray((obs_points.shape[0], 5), dtype="float32") + tree = KDTree(grid_xyz) + _, self.indices = tree.query(obs_xyz, k=5) + + self.weights = np.ndarray((obs_points.shape[0], 3), dtype="float32") + self.compute_weights(grid_xyz, obs_xyz) + + def compute_weights(self, grid_xyz, obs_xyz): + """ + Compute the weights of the three nearest grid points + by computing the barycentric coordinates, + assuming that the observations are close enough to the plane through the grid points. + """ + + eps = 0.01 + + for i, (obs, indix) in enumerate(zip(obs_xyz, self.indices, strict=True)): + AB = grid_xyz[indix[1]] - grid_xyz[indix[0]] + AC = grid_xyz[indix[2]] - grid_xyz[indix[0]] + BC = grid_xyz[indix[2]] - grid_xyz[indix[1]] + AP = obs - grid_xyz[indix[0]] + BP = obs - grid_xyz[indix[1]] + + area_tot = np.linalg.norm(np.cross(AB, AC)) + self.weights[i, 0] = np.linalg.norm(np.cross(BC, BP)) + self.weights[i, 1] = np.linalg.norm(np.cross(AC, AP)) + self.weights[i, 2] = np.linalg.norm(np.cross(AB, AP)) + + if 1 - area_tot / np.sum(self.weights[i, :]) < eps: + continue + + indix[2] = indix[3] + + AC = grid_xyz[indix[2]] - grid_xyz[indix[0]] + BC = grid_xyz[indix[2]] - grid_xyz[indix[1]] + + area_tot = np.linalg.norm(np.cross(AB, AC)) + self.weights[i, 0] = np.linalg.norm(np.cross(BC, BP)) + self.weights[i, 1] = np.linalg.norm(np.cross(AC, AP)) + + if 1 - area_tot / np.sum(self.weights[i, :]) < eps: + continue + + indix[2] = indix[4] + + AC = grid_xyz[indix[2]] - grid_xyz[indix[0]] + BC = grid_xyz[indix[2]] - grid_xyz[indix[1]] + + self.weights[i, 0] = np.linalg.norm(np.cross(BC, BP)) + self.weights[i, 1] = np.linalg.norm(np.cross(AC, AP)) + + self.weights = self.weights / self.weights.sum(axis=1)[:, np.newaxis] + + def interpolate(self, values, intmap=None): + """ + Interpolate values to points + """ + + wvalues = np.ndarray((self.weights.shape[0]), dtype="float32") + + if intmap is None: + wvalues[:] = ( + self.weights[:, 0] * values[self.indices[:, 0]] + + self.weights[:, 1] * values[self.indices[:, 1]] + + self.weights[:, 2] * values[self.indices[:, 2]] + ) + else: + wvalues[:] = ( + self.weights[:, 0] * values[intmap[self.indices[:, 0]]] + + self.weights[:, 1] * values[intmap[self.indices[:, 1]]] + + self.weights[:, 2] * values[intmap[self.indices[:, 2]]] + ) + + return wvalues + + +class Verif_lat_lon_interpolator(Verif_interpolator): + """ + Class that does approximate 2D interpolation + """ + + def __init__(self, grid_points, obs_points): + """ + Initialise the class and store gridpoints + """ + + self.obs_points = obs_points + self.triangulation = Delaunay(grid_points) + + def interpolate(self, values, intmap=None): + """ + Interpolate values to points + """ + + newvalues = np.empty_like(values) + + if intmap is None: + newvalues = values + else: + for i in range(len(values)): + newvalues[i] = values[intmap[i]] + + interpolator = LinearNDInterpolator(self.triangulation, newvalues) + + return interpolator(self.obs_points).astype(np.float32) + + +class Verif_nearest_interpolator(Verif_interpolator): + """ + Class that does approximate 2D interpolation + """ + + def __init__(self, grid_points, obs_points): + """ + Initialise the class and store gridpoints + """ + + grid_xyz = convert_coordinates(grid_points) + obs_xyz = convert_coordinates(obs_points) + + tree = KDTree(grid_xyz) + _, self.indices = tree.query(obs_xyz, k=1) + + def interpolate(self, values, intmap=None): + """ + Interpolate values to points + """ + + wvalues = np.ndarray((self.indices.shape), dtype="float32") + + if intmap is None: + wvalues[:] = values[self.indices[:]] + else: + wvalues[:] = values[intmap[self.indices[:]]] + + return wvalues + + +class Interpolator_factory: + def __init__(self, method: str): + valid_methods = ("2d", "lat_lon", "nearest") + + if method not in valid_methods: + raise Exception(f"{method} is not a valid method.") + + self.method = method + + def get_interpolator( + self, zarr_coords: np.ndarray, obs_coords: np.ndarray + ) -> Verif_interpolator: + if self.method == "2d": + _logger.info("2D interpolation") + return Verif_2D_interpolator(zarr_coords, obs_coords) + + elif self.method == "lat_lon": + _logger.info("lat-lon interpolation") + return Verif_lat_lon_interpolator(zarr_coords, obs_coords) + + elif self.method == "nearest": + _logger.info("nearest neighbour interpolation") + return Verif_nearest_interpolator(zarr_coords, obs_coords) diff --git a/packages/verif/src/weathergen/verif/create_verif.py b/packages/verif/src/weathergen/verif/create_verif.py index e3f9ff402..a43453aca 100644 --- a/packages/verif/src/weathergen/verif/create_verif.py +++ b/packages/verif/src/weathergen/verif/create_verif.py @@ -318,7 +318,7 @@ def main(): method_factory = Interpolator_factory(args.method) with ZarrIO(args.zarrfile, read_only=True) as zarrio: - streams = get_streams(zarrio, args.streams) + streams = get_streams(zarrio, "ERA5") t_start = time() From 19da58c50dac42416b84d83df510081c0348458e Mon Sep 17 00:00:00 2001 From: Sorcha Date: Thu, 19 Feb 2026 17:05:43 +0100 Subject: [PATCH 05/28] fixes to windpseed --- .../src/weathergen/evaluate/export/parsers/verif_parser.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index 2085bc307..c3e5a5bdb 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -132,9 +132,9 @@ def get_output_filename(self, forecast_ref_time: np.datetime64, stream:str, vari self.verif_template.replace("%S", self.stream) .replace("%V", variable) .replace("%M", self.method) - .replace("%D", self.data_type) + .replace("%D", self.data_type + str(forecast_ref_time)) ) - outfile = Path(self.output_dir) / outfile + outfile = Path(self.output_dir) / outfile pathdir = outfile.parent _logger.info(f"Output directory: {pathdir}") pathdir.mkdir(exist_ok=True, parents=True) @@ -213,7 +213,7 @@ def preprocess(self, ds: xr.Dataset) -> xr.Dataset: """ Preprocess variables and only keep relevant ones """ - if all(["10u", "10v"]) in self.channels: + if set(["10u", "10v"]).issubset(self.channels): u = ds.sel(channel='10u') v = ds.sel(channel='10v') # hypotenuese From 34bd4031c1b3b594ae98c92a8138b1b4a9c68e1f Mon Sep 17 00:00:00 2001 From: Sorcha Date: Thu, 19 Feb 2026 19:05:54 +0100 Subject: [PATCH 06/28] adjusting template --- .../src/weathergen/evaluate/export/export_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py index a296417f4..1a4e0ea85 100755 --- a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py @@ -198,9 +198,9 @@ def parse_args(args: list) -> argparse.Namespace: parser.add_argument( "--verif-template", - default="verif/%S/%V/verif_%S_%V_%M.nc", + default="verif/%S/%V/verif_%S_%V_%M_%D.nc", help="Template for the output nc filenames, default will be to create output/verif/%S/%V \ - repertories where %S, %V, %d are replaced by the streams, variable and date", + repertories where %S, %V, %M, %D are replaced by the streams, variable, method and date", ) args, unknown_args = parser.parse_known_args(args) From ba9ca9fe54e0a1f5c399e3cf42fda6d11e682006 Mon Sep 17 00:00:00 2001 From: Sorcha Date: Fri, 20 Feb 2026 18:21:44 +0100 Subject: [PATCH 07/28] adding multi stream handling --- .../weathergen/evaluate/export/export_core.py | 63 +++++--- .../evaluate/export/export_inference.py | 3 +- .../evaluate/export/parsers/verif_parser.py | 153 ++++++++++++------ 3 files changed, 146 insertions(+), 73 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_core.py b/packages/evaluate/src/weathergen/evaluate/export/export_core.py index 4f7f2a0c8..3ebcb6740 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/export_core.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_core.py @@ -169,6 +169,16 @@ def get_ref_times(fname_zarr, stream, samples, fstep_hours) -> list[np.datetime6 ref_times.append(ref_time) return ref_times +def get_streams(stream, fname_zarr): + with zarrio_reader(fname_zarr) as zio: + zio_streams = zio.streams + streams = ( + zio_streams + if stream is None + else [stream] + ) + return streams + def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: """ @@ -230,30 +240,31 @@ def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: fname_zarr = get_model_results(run_id, epoch, rank) fsteps = get_fsteps(fsteps, fname_zarr) samples = get_samples(samples, fname_zarr) - grid_type = get_grid_type(data_type, stream, fname_zarr) - channels = get_channels(channels, stream, fname_zarr) - ref_times = get_ref_times(fname_zarr, stream, samples, fstep_hours) - - kwargs["grid_type"] = grid_type - kwargs["channels"] = channels - kwargs["data_type"] = data_type - - with Pool(processes=n_processes, maxtasksperchild=5) as pool: - parser = CfParserFactory.get_parser(config=config, **kwargs) - - for s_idx, sample in enumerate(tqdm(samples)): - ref_time = ref_times[s_idx] - - step_tasks = [ - (sample, fstep, run_id, stream, data_type, epoch, rank) for fstep in fsteps - ] - - results_iterator = pool.imap_unordered(get_data_worker, step_tasks, chunksize=1) - - parser.process_sample( - results_iterator, - ref_time=ref_time, - ) + streams = get_streams(stream, fname_zarr) + for stream in streams: + grid_type = get_grid_type(data_type, stream, fname_zarr) + channels = get_channels(channels, stream, fname_zarr) + ref_times = get_ref_times(fname_zarr, stream, samples, fstep_hours) + kwargs["stream"] = stream + kwargs["grid_type"] = grid_type + kwargs["channels"] = channels + kwargs["data_type"] = data_type + + with Pool(processes=n_processes, maxtasksperchild=5) as pool: + parser = CfParserFactory.get_parser(config=config, **kwargs) + + for s_idx, sample in enumerate(tqdm(samples)): + ref_time = ref_times[s_idx] + step_tasks = [ + (sample, fstep, run_id, stream, data_type, epoch, rank) for fstep in fsteps + ] + + results_iterator = pool.imap_unordered(get_data_worker, step_tasks, chunksize=1) + + parser.process_sample( + results_iterator, + ref_time=ref_time, + ) - pool.terminate() - pool.join() + pool.terminate() + pool.join() \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py index 1a4e0ea85..8399a9bbe 100755 --- a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py @@ -83,9 +83,8 @@ def parse_args(args: list) -> argparse.Namespace: parser.add_argument( "--stream", type=str, - choices=["ERA5"], + choices=["ERA5", "CERRA", "MEPS", "NORA3"], help="Stream name to retrieve data for", - required=True, ) parser.add_argument( diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index c3e5a5bdb..7c042a6ba 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -1,3 +1,4 @@ +import contextlib import logging from pathlib import Path from typing import Any @@ -80,6 +81,7 @@ def process_sample( required_channels = ["10u", "10v", "sp", "2t", "msl"] self.channels = list(set(self.channels) & set(required_channels)) da_fs = [] + valid_times = [] for result in fstep_iterator_results: if result is None: continue @@ -89,6 +91,7 @@ def process_sample( result = self.preprocess(result) result = self.reshape(result) da_fs.append(result) + valid_times.append(result.valid_time.values[0]) _logger.info(f"Retrieved {len(da_fs)} forecast steps for type {self.data_type}.") @@ -101,12 +104,18 @@ def process_sample( da_fs = self.concatenate(da_fs) da_fs = self.assign_frt(da_fs, ref_time) da_fs = self.add_attrs(da_fs) - da_fs = self.add_metadata(da_fs) - da_fs = self.add_encoding(da_fs) for verif_var in self.mapping.keys(): - da_var = self.regrid(da_fs, verif_var) - obs_result = self.obs_preprocess(ref_time, verif_var) - self.save(da_var, obs_result, ref_time, verif_var) + mapped_info = self.mapping.get(verif_var, {}) + wg_var = mapped_info.get("var", None) + da_var = self.regrid(da_fs, wg_var, valid_times) + if da_var is None: + continue + da_var = self.add_encoding(da_var) + obs_result = self.obs_preprocess(valid_times, verif_var) + obs_result = self.add_encoding(obs_result) + merged = self.merge(da_var, obs_result) + merged = self.add_metadata(merged,verif_var) + self.save(merged, ref_time, verif_var) _logger.info(f"Saved {verif_var} data for {ref_time} to \ {self.output_format} in {self.output_dir}.") @@ -117,7 +126,7 @@ def get_zarr_dt(self, ds): zarr_dt = zarr_dt.astype("timedelta64[h]") return zarr_dt - def get_output_filename(self, forecast_ref_time: np.datetime64, stream:str, variable:str) -> Path: + def get_output_filename(self, forecast_ref_time: np.datetime64, variable:str) -> Path: """ Create output directories for the verif files and return path to output file @@ -196,18 +205,23 @@ def reshape(self, data: xr.DataArray) -> xr.Dataset: return reshaped_dataset - def obs_preprocess(self, ref_time, verif_var): + def obs_preprocess(self, valid_times, verif_var): obs_data = self.obs mapped_info = self.mapping.get(verif_var, {}) obs_name = mapped_info.get("obs_name", {}) if verif_var == "mslp": - obs_data[obs_name].values = compute_mslp(obs_data, ref_time) + for vtime in valid_times: + obs_data[obs_name].sel(time = valid_times).values = compute_mslp(obs_data, vtime) if verif_var == "tp": - obs_data[obs_name].sel(time = ref_time).values = compute_precip(obs_data, self.zarr_dt, ref_time) + for vtime in valid_times: + obs_data[obs_name].sel(time = vtime).values = compute_precip(obs_data, self.zarr_dt, vtime) else: pass + + new_xarray = obs_data[obs_name].sel(time = valid_times) + new_xarray.name = "obs" - return obs_data[obs_name] + return new_xarray def preprocess(self, ds: xr.Dataset) -> xr.Dataset: """ @@ -232,14 +246,14 @@ def preprocess(self, ds: xr.Dataset) -> xr.Dataset: else: return ds - def regrid(self, ds: xr.Dataset, verif_var) -> xr.Dataset: + def regrid(self, ds: xr.Dataset, verif_var, valid_times) -> xr.Dataset: """ Regrid a single xarray Datas v vvbet using specific method. Parameters ---------- ds: native xarray Dataset Returns - ------- + -------s Regridded xarray Dataset. """ try: @@ -247,14 +261,63 @@ def regrid(self, ds: xr.Dataset, verif_var) -> xr.Dataset: except KeyError as e: _logger.info(f"{verif_var} not available in WeatherGenerator output: {e}") return + # set coords + new_coords = ds_var.coords.copy() + new_coords.update( + {"location": self.obs.location.values} + ) + new_coords._drop_coords(["ncells"]) + + # set attrs + attrs = ds_var.attrs.copy() + with contextlib.suppress(KeyError): + del attrs["ncells"]# + + original_shape = ds_var.shape + new_shape = list(original_shape) + pos = ds_var.dims.index("ncells") + new_shape[pos] = self.obs.location.shape[0] + #rearrange to be time,location + order = [1,0] + new_shape = [new_shape[x] for x in order] + + fcstdata = np.ndarray(new_shape, dtype=np.float32) + + #set interpolation method method_factory = Interpolator_factory(self.method) interpolator = method_factory.get_interpolator(self.zarr_coords, self.obs_coords) + + #fix lat, lon + latitude_da = xr.DataArray( + self.obs.latitude.values, + dims=["location"], + attrs=ds_var.latitude.attrs + ) + longitude_da = xr.DataArray( + self.obs.longitude.values, + dims=["location"], + attrs=ds_var.longitude.attrs + ) - for idx in range(len(ds_var.time.values)): - interpolator.interpolate( + for idx in range(len(valid_times)): + + regrid_values = interpolator.interpolate( ds_var.values[:,idx] ) - return ds_var + fcstdata[idx,:] = regrid_values + + regridded_var = xr.DataArray( + fcstdata, + dims=["time", "location"], + coords={ + **new_coords, + "latitude": latitude_da, + "longitude": longitude_da + }, + name="fcst", + attrs = attrs + ) + return regridded_var def concatenate( @@ -371,6 +434,29 @@ def add_encoding(self, ds: xr.Dataset) -> xr.Dataset: return ds + def add_metadata(self, ds: xr.Dataset, verif_var) -> xr.Dataset: + """ + Add CF conventions to the dataset attributes. + + Parameters + ---------- + ds : Input xarray Dataset to add conventions to. + Returns + ------- + xarray Dataset with CF conventions added to attributes. + """ + ds.attrs["title"] = f"WeatherGenerator Output for {self.run_id}, variable {verif_var} using stream {self.stream}" + ds.attrs["institution"] = "WeatherGenerator Collaboration" + ds.attrs["source"] = "WeatherGenerator v0.0" + ds.attrs["history"] = ( + "Created using the verif_parser on " + + np.datetime_as_string(np.datetime64("now"), unit="s") + ) + ds.attrs["Conventions"] = "verif_1.0.0" + # drop stream now it's in folder layout + ds = ds.drop_vars("stream") + return ds + def _attrs_gaussian_grid(self, ds: xr.Dataset) -> xr.Dataset: """ Assign CF-compliant attributes to variables in a gaussian grid dataset. @@ -480,31 +566,13 @@ def _build_coordinate_mapping( return coords - def add_metadata(self, ds: xr.Dataset) -> xr.Dataset: - """ - Add CF conventions to the dataset attributes. + def merge(self, ds, obs_ds): + lat, lon, alt = get_obs_coordinates(self.obs) + merged = xr.merge([ds, obs_ds, lat, lon, alt]) + return merged - Parameters - ---------- - ds : Input xarray Dataset to add conventions to. - Returns - ------- - xarray Dataset with CF conventions added to attributes. - """ - # ds = ds.copy() - ds.attrs["title"] = f"WeatherGenerator Output for {self.run_id} using stream {self.stream}" - ds.attrs["institution"] = "WeatherGenerator Project" - ds.attrs["source"] = "WeatherGenerator v0.0" - ds.attrs["history"] = ( - "Created using the export_inference.py script on " - + np.datetime_as_string(np.datetime64("now"), unit="s") - ) - ds.attrs["Conventions"] = "verif_1.0.0" - # drop stream now it's in title - ds = ds.drop_vars("stream") - return ds - def save(self, ds: xr.Dataset, obs_result, forecast_ref_time: np.datetime64, verif_var) -> None: + def save(self, ds: xr.Dataset, forecast_ref_time: np.datetime64, verif_var) -> None: """ Save the dataset to a NetCDF file. @@ -518,12 +586,7 @@ def save(self, ds: xr.Dataset, obs_result, forecast_ref_time: np.datetime64, ver ------- None """ - lat, lon, alt = get_obs_coordinates(self.obs) - if ds is None: - merged = xr.merge([obs_result, lat, lon, alt]) - else: - merged = xr.merge([ds, obs_result, lat, lon, alt]) - out_fname = self.get_output_filename(forecast_ref_time, self.stream, verif_var) + out_fname = self.get_output_filename(forecast_ref_time, verif_var) _logger.info(f"Saving to {out_fname}.") - merged.to_netcdf(out_fname) + ds.to_netcdf(out_fname) _logger.info(f"Saved NetCDF file to {out_fname}.") From b88b037ff0ba9328218fc991f2de181bb5d50ce0 Mon Sep 17 00:00:00 2001 From: Sorcha Date: Mon, 23 Feb 2026 13:14:22 +0100 Subject: [PATCH 08/28] first commit --- .../src/weathergen/evaluate/export/export_core.py | 11 ++++------- .../weathergen/evaluate/export/export_inference.py | 11 ++++------- pyproject.toml | 2 -- 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_core.py b/packages/evaluate/src/weathergen/evaluate/export/export_core.py index 3ebcb6740..d0d9837f8 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/export_core.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_core.py @@ -169,14 +169,11 @@ def get_ref_times(fname_zarr, stream, samples, fstep_hours) -> list[np.datetime6 ref_times.append(ref_time) return ref_times + def get_streams(stream, fname_zarr): with zarrio_reader(fname_zarr) as zio: - zio_streams = zio.streams - streams = ( - zio_streams - if stream is None - else [stream] - ) + zio_streams = zio.streams + streams = zio_streams if stream is None else [stream] return streams @@ -267,4 +264,4 @@ def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: ) pool.terminate() - pool.join() \ No newline at end of file + pool.join() diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py index 8399a9bbe..0cae19144 100755 --- a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py @@ -181,11 +181,7 @@ def parse_args(args: list) -> argparse.Namespace: help="Type of grid to regrid to (only used if --regrid-degree is specified)", ) - parser.add_argument( - "-b", - "--obs", - help= "observation file for creating verif files" - ) + parser.add_argument("-b", "--obs", help="observation file for creating verif files") parser.add_argument( "-m", @@ -199,12 +195,13 @@ def parse_args(args: list) -> argparse.Namespace: "--verif-template", default="verif/%S/%V/verif_%S_%V_%M_%D.nc", help="Template for the output nc filenames, default will be to create output/verif/%S/%V \ - repertories where %S, %V, %M, %D are replaced by the streams, variable, method and date", + repertories where %S, %V, %M, %D are replaced by the " + "streams, variable, method and date", ) args, unknown_args = parser.parse_known_args(args) if unknown_args: - _logger.warning(f"Unknown arguments: {unknown_args}") + _logger.warning(f"Unknown arguments: {unknown_args}") return args diff --git a/pyproject.toml b/pyproject.toml index 7c9d1b725..a5f6f5dbb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,6 @@ dependencies = [ "weathergen-common", "weathergen-evaluate", "weathergen-readers-extra", - "weathergen-verif" ] @@ -226,7 +225,6 @@ weathergen-common = { workspace = true } weathergen-evaluate = { workspace = true } weathergen-metrics = { workspace = true } weathergen-readers-extra = { workspace = true } -weathergen-verif = { workspace = true } flash-attn = [ From 703b723db79a02febc1a5f4a55fefda6f86f776f Mon Sep 17 00:00:00 2001 From: Sorcha Date: Mon, 23 Feb 2026 13:15:27 +0100 Subject: [PATCH 09/28] removing verif as a package --- .../weathergen/evaluate/export/io_utils.py | 38 -- .../evaluate/export/parser_factory.py | 2 +- .../evaluate/export/parsers/verif_parser.py | 202 +++++---- .../weathergen/evaluate/export/preprocess.py | 63 ++- .../src/weathergen/evaluate/export/reshape.py | 95 +++-- packages/verif/pyproject.toml | 75 ---- .../verif/src/weathergen/verif/__init__.py | 0 .../src/weathergen/verif/create_verif.py | 401 ------------------ .../src/weathergen/verif/verif_config.py | 44 -- .../src/weathergen/verif/verif_config.yaml | 74 ---- .../weathergen/verif/verif_interpolator.py | 204 --------- .../src/weathergen/verif/verif_processers.py | 190 --------- 12 files changed, 213 insertions(+), 1175 deletions(-) delete mode 100644 packages/verif/pyproject.toml delete mode 100644 packages/verif/src/weathergen/verif/__init__.py delete mode 100644 packages/verif/src/weathergen/verif/create_verif.py delete mode 100644 packages/verif/src/weathergen/verif/verif_config.py delete mode 100644 packages/verif/src/weathergen/verif/verif_config.yaml delete mode 100644 packages/verif/src/weathergen/verif/verif_interpolator.py delete mode 100644 packages/verif/src/weathergen/verif/verif_processers.py diff --git a/packages/evaluate/src/weathergen/evaluate/export/io_utils.py b/packages/evaluate/src/weathergen/evaluate/export/io_utils.py index 710cc19b9..e9dd3ea52 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/io_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/export/io_utils.py @@ -1,7 +1,5 @@ import logging -from pathlib import Path -import numpy as np import xarray as xr from weathergen.common.config import get_model_results @@ -11,42 +9,6 @@ _logger.setLevel(logging.INFO) -# def output_filename( -# prefix: str, -# run_id: str, -# output_dir: str, -# output_format: str, -# forecast_ref_time: np.datetime64, -# regrid_degree: float, -# ) -> Path: -# """ -# Generate output filename based on prefix (should refer to type e.g. pred/targ), run_id, sample -# index, output directory, format and forecast_ref_time. - -# Parameters -# ---------- -# prefix : Prefix for file name (e.g., 'pred' or 'targ'). -# run_id :Run ID to include in the filename. -# output_dir : Directory to save the output file. -# output_format : Output file format (currently only 'netcdf' supported). -# forecast_ref_time : Forecast reference time to include in the filename. - -# Returns -# ------- -# Full path to the output file. -# """ -# if output_format not in ["netcdf", "verif"]: -# raise ValueError( -# f"Unsupported output format: {output_format}, supported formates are ['netcdf']" -# ) -# file_extension = "nc" -# frt = np.datetime_as_string(forecast_ref_time, unit="h") -# if regrid_degree is not None: -# run_id += f"_regular{regrid_degree, regrid_degree}" -# out_fname = Path(output_dir) / f"{prefix}_{frt}_{run_id}.{file_extension}" -# return out_fname - - def get_data_worker(args: tuple) -> xr.DataArray: """ Worker function to retrieve data for a single sample and forecast step. diff --git a/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py b/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py index 9d5e1c621..1d3f7e26b 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py @@ -31,7 +31,7 @@ def get_parser(config: OmegaConf, **kwargs) -> CfParser: _parser_map = { "netcdf": (NetcdfParser, ["grid_type"]), "quaver": (QuaverParser, ["grid_type", "channels", "template"]), - "verif": (VerifParser, ["obs", "method", "verif_template"]) + "verif": (VerifParser, ["obs", "method", "verif_template"]), } fmt = kwargs.get("output_format") diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index 7c042a6ba..ccb7bc312 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -8,9 +8,13 @@ from omegaconf import OmegaConf from weathergen.evaluate.export.cf_utils import CfParser -from weathergen.evaluate.export.reshape import find_pl, get_grid_points, get_obs_coordinates, Interpolator_factory from weathergen.evaluate.export.preprocess import compute_mslp, compute_precip - +from weathergen.evaluate.export.reshape import ( + InterpolatorFactory, + find_pl, + get_grid_points, + get_obs_coordinates, +) _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) @@ -53,7 +57,7 @@ def __init__(self, config: OmegaConf, **kwargs): super().__init__(config=config) if not hasattr(self, "obs"): - raise ValueError("Observation data required for creating verif compliant NetCDFs") + raise ValueError("Observation data required for creating verif compliant NetCDFs") self.mapping = config.get("variables", {}) @@ -78,6 +82,13 @@ def process_sample( ------- None """ + # check ref_time exists in the obs data + if ref_time not in self.obs.time.values: + _logger.warning( + f"Reference time {ref_time} not found in observation data. Skipping sample." + ) + return + required_channels = ["10u", "10v", "sp", "2t", "msl"] self.channels = list(set(self.channels) & set(required_channels)) da_fs = [] @@ -100,33 +111,39 @@ def process_sample( self.zarr_coords = get_grid_points(da_fs[0]) self.zarr_dt = self.get_zarr_dt(da_fs[0]) - da_fs = self.concatenate(da_fs) da_fs = self.assign_frt(da_fs, ref_time) da_fs = self.add_attrs(da_fs) + for verif_var in self.mapping.keys(): - mapped_info = self.mapping.get(verif_var, {}) - wg_var = mapped_info.get("var", None) - da_var = self.regrid(da_fs, wg_var, valid_times) + da_var = self.regrid(da_fs, verif_var, valid_times) if da_var is None: continue da_var = self.add_encoding(da_var) obs_result = self.obs_preprocess(valid_times, verif_var) obs_result = self.add_encoding(obs_result) merged = self.merge(da_var, obs_result) - merged = self.add_metadata(merged,verif_var) + merged = self.add_metadata(merged, verif_var) self.save(merged, ref_time, verif_var) - _logger.info(f"Saved {verif_var} data for {ref_time} to \ - {self.output_format} in {self.output_dir}.") - def get_zarr_dt(self, ds): - zarr_dt = ( - ds.source_interval_end.values - ds.source_interval_start.values - ) + def get_zarr_dt(self, ds: xr.Dataset) -> np.timedelta64: + """ + Compute the time difference between forecast steps in hours from the WG output dataset. + Parameters + ---------- + ds : xr.Dataset + Input dataset from which to compute the time difference. + Returns + ------- + np.timedelta64 + Time difference between forecast steps in hours. + """ + zarr_dt = ds.source_interval_end.values - ds.source_interval_start.values zarr_dt = zarr_dt.astype("timedelta64[h]") + return zarr_dt - - def get_output_filename(self, forecast_ref_time: np.datetime64, variable:str) -> Path: + + def get_output_filename(self, forecast_ref_time: np.datetime64, variable: str) -> Path: """ Create output directories for the verif files and return path to output file @@ -141,14 +158,14 @@ def get_output_filename(self, forecast_ref_time: np.datetime64, variable:str) -> self.verif_template.replace("%S", self.stream) .replace("%V", variable) .replace("%M", self.method) - .replace("%D", self.data_type + str(forecast_ref_time)) + .replace("%D", self.data_type + np.datetime_as_string(forecast_ref_time, unit="h")) ) - outfile = Path(self.output_dir) / outfile + outfile = Path(self.output_dir) / outfile pathdir = outfile.parent _logger.info(f"Output directory: {pathdir}") pathdir.mkdir(exist_ok=True, parents=True) return outfile - + def reshape(self, data: xr.DataArray) -> xr.Dataset: """ Reshape dataset while preserving grid structure (regular or Gaussian). @@ -204,51 +221,78 @@ def reshape(self, data: xr.DataArray) -> xr.Dataset: return reshaped_dataset - - def obs_preprocess(self, valid_times, verif_var): + def obs_preprocess(self, valid_times: list[np.datetime64], verif_var: str) -> xr.DataArray: + """ + Preprocess the observation data for the given variable and valid times. + This includes computing derived variables like MSLP and total precipitation if needed. + Parameters + ---------- + valid_times : list of np.datetime64 + List of valid times for which to preprocess the observation data. + verif_var : str + The variable for which to preprocess the observation data (e.g., 'mslp', 'tp'). + Returns + ------- + xr.DataArray + Preprocessed observation data for the specified variable and valid times. + """ obs_data = self.obs - mapped_info = self.mapping.get(verif_var, {}) + mapped_info = self.mapping.get(verif_var, {}) obs_name = mapped_info.get("obs_name", {}) if verif_var == "mslp": for vtime in valid_times: - obs_data[obs_name].sel(time = valid_times).values = compute_mslp(obs_data, vtime) + obs_data[obs_name].sel(time=valid_times).values = compute_mslp(obs_data, vtime) if verif_var == "tp": for vtime in valid_times: - obs_data[obs_name].sel(time = vtime).values = compute_precip(obs_data, self.zarr_dt, vtime) + obs_data[obs_name].sel(time=vtime).values = compute_precip( + obs_data, self.zarr_dt, vtime + ) else: pass - - new_xarray = obs_data[obs_name].sel(time = valid_times) + + new_xarray = obs_data[obs_name].sel(time=valid_times) new_xarray.name = "obs" return new_xarray - + def preprocess(self, ds: xr.Dataset) -> xr.Dataset: """ - Preprocess variables and only keep relevant ones + Preprocess variables and only keep relevant ones for WG output + Parameters + ---------- + ds : xr.Dataset + + + Returns + ------- + xr.Dataset """ if set(["10u", "10v"]).issubset(self.channels): - u = ds.sel(channel='10u') - v = ds.sel(channel='10v') + u = ds.sel(channel="10u") + v = ds.sel(channel="10v") # hypotenuese - wind_speed = xr.apply_ufunc( + wind_speed = xr.apply_ufunc( np.hypot, u, v, dask="parallelized", output_dtypes=[ds.dtype] ).astype("float32") - wind_speed = wind_speed.expand_dims(channel=['10si']) + wind_speed = wind_speed.expand_dims(channel=["10si"]) if ds.chunks: - wind_speed = wind_speed.chunk({"ipoint": ds.chunks[ds.get_axis_num("ipoint")][0], "channel": 1}) + wind_speed = wind_speed.chunk( + {"ipoint": ds.chunks[ds.get_axis_num("ipoint")][0], "channel": 1} + ) new_ds = xr.concat([ds, wind_speed], dim="channel") new_ds.attrs = ds.attrs # remove unnecessary - new_ds = new_ds.drop_sel(channel = ['10u', '10v']) + new_ds = new_ds.drop_sel(channel=["10u", "10v"]) return new_ds else: return ds - def regrid(self, ds: xr.Dataset, verif_var, valid_times) -> xr.Dataset: + def regrid( + self, ds: xr.Dataset, verif_var: str, valid_times: list[np.datetime64] + ) -> xr.Dataset: """ - Regrid a single xarray Datas v vvbet using specific method. + Regrid a single xarray Dataset using specific method. Parameters ---------- ds: native xarray Dataset @@ -256,70 +300,59 @@ def regrid(self, ds: xr.Dataset, verif_var, valid_times) -> xr.Dataset: -------s Regridded xarray Dataset. """ + mapped_info = self.mapping.get(verif_var, {}) + wg_var = mapped_info.get("var", None) + try: - ds_var = ds[verif_var] + ds_var = ds[wg_var] except KeyError as e: - _logger.info(f"{verif_var} not available in WeatherGenerator output: {e}") + _logger.info(f"{wg_var} not available in WeatherGenerator output: {e}") return # set coords new_coords = ds_var.coords.copy() - new_coords.update( - {"location": self.obs.location.values} - ) + new_coords.update({"location": self.obs.location.values}) new_coords._drop_coords(["ncells"]) # set attrs attrs = ds_var.attrs.copy() with contextlib.suppress(KeyError): - del attrs["ncells"]# + del attrs["ncells"] # original_shape = ds_var.shape new_shape = list(original_shape) pos = ds_var.dims.index("ncells") new_shape[pos] = self.obs.location.shape[0] - #rearrange to be time,location - order = [1,0] + # rearrange to be time,location + order = [1, 0] new_shape = [new_shape[x] for x in order] - fcstdata = np.ndarray(new_shape, dtype=np.float32) + fcstdata = np.empty(new_shape, dtype=np.float32) - #set interpolation method - method_factory = Interpolator_factory(self.method) + # set interpolation method + method_factory = InterpolatorFactory(self.method) interpolator = method_factory.get_interpolator(self.zarr_coords, self.obs_coords) - #fix lat, lon + # fix lat, lon latitude_da = xr.DataArray( - self.obs.latitude.values, - dims=["location"], - attrs=ds_var.latitude.attrs - ) + self.obs.latitude.values, dims=["location"], attrs=ds_var.latitude.attrs + ) longitude_da = xr.DataArray( - self.obs.longitude.values, - dims=["location"], - attrs=ds_var.longitude.attrs - ) - + self.obs.longitude.values, dims=["location"], attrs=ds_var.longitude.attrs + ) + for idx in range(len(valid_times)): - - regrid_values = interpolator.interpolate( - ds_var.values[:,idx] - ) - fcstdata[idx,:] = regrid_values + regrid_values = interpolator.interpolate(ds_var.values[:, idx]) + fcstdata[idx, :] = regrid_values regridded_var = xr.DataArray( - fcstdata, - dims=["time", "location"], - coords={ - **new_coords, - "latitude": latitude_da, - "longitude": longitude_da - }, - name="fcst", - attrs = attrs - ) + fcstdata, + dims=["time", "location"], + coords={**new_coords, "latitude": latitude_da, "longitude": longitude_da}, + name="fcst", + attrs=attrs, + ) return regridded_var - def concatenate( self, array_list, @@ -386,7 +419,7 @@ def assign_frt(self, ds: xr.Dataset, reference_time: np.datetime64) -> xr.Datase if "sample" in ds.coords: ds = ds.drop_vars("sample") - n_hours = self.fstep_hours.astype('int64') + n_hours = self.fstep_hours.astype("int64") ds["forecast_step"] = ds["forecast_step"] * n_hours return ds @@ -445,12 +478,14 @@ def add_metadata(self, ds: xr.Dataset, verif_var) -> xr.Dataset: ------- xarray Dataset with CF conventions added to attributes. """ - ds.attrs["title"] = f"WeatherGenerator Output for {self.run_id}, variable {verif_var} using stream {self.stream}" + ds.attrs["title"] = ( + f"WeatherGenerator Output for {self.run_id}, variable {verif_var} " + f"using stream {self.stream}" + ) ds.attrs["institution"] = "WeatherGenerator Collaboration" ds.attrs["source"] = "WeatherGenerator v0.0" - ds.attrs["history"] = ( - "Created using the verif_parser on " - + np.datetime_as_string(np.datetime64("now"), unit="s") + ds.attrs["history"] = "Created using the verif_parser on " + np.datetime_as_string( + np.datetime64("now"), unit="s" ) ds.attrs["Conventions"] = "verif_1.0.0" # drop stream now it's in folder layout @@ -484,9 +519,9 @@ def _attrs_gaussian_grid(self, ds: xr.Dataset) -> xr.Dataset: wg_unit = mapped_units.get(self.stream, "DEFAULT") verif_unit = mapped_info.get("verif_unit", None) if wg_unit != verif_unit: - #perform unit conversion + # perform unit conversion da.values = da.values * unit_conversion[wg_unit] - + attributes = { "units": verif_unit, } @@ -571,8 +606,7 @@ def merge(self, ds, obs_ds): merged = xr.merge([ds, obs_ds, lat, lon, alt]) return merged - - def save(self, ds: xr.Dataset, forecast_ref_time: np.datetime64, verif_var) -> None: + def save(self, ds: xr.Dataset, forecast_ref_time: np.datetime64, verif_var: str) -> None: """ Save the dataset to a NetCDF file. @@ -590,3 +624,7 @@ def save(self, ds: xr.Dataset, forecast_ref_time: np.datetime64, verif_var) -> N _logger.info(f"Saving to {out_fname}.") ds.to_netcdf(out_fname) _logger.info(f"Saved NetCDF file to {out_fname}.") + _logger.info( + f"Saved {verif_var} data for {forecast_ref_time} to" + f" {self.output_format} in {self.output_dir}." + ) diff --git a/packages/evaluate/src/weathergen/evaluate/export/preprocess.py b/packages/evaluate/src/weathergen/evaluate/export/preprocess.py index 9e0dab513..1f650cd21 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/preprocess.py +++ b/packages/evaluate/src/weathergen/evaluate/export/preprocess.py @@ -1,53 +1,76 @@ -import logging import numpy as np import xarray as xr - -_logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO) - """ Extra helper functions to preprocess data e.g. for verif applications """ -def compute_mslp(obs: xr.DataArray, time: np.datetime64) -> np.ndarray: +def compute_mslp(obs: xr.DataArray, time: np.datetime64) -> np.typing.NDArray: + """ + Compute mean sea level pressure (MSLP) from surface air pressure, + air temperature, and relative humidity. + Parameters + ---------- + obs : xarray DataArray + Input data containing surface air pressure, air temperature, and relative humidity. + time : np.datetime64 + Time over which to compute mean for the MSLP. + Returns + ------- + np.ndarray + Computed mean sea level pressure values. + """ # g = 9.80665 # Gravitational acceleration (m/s**2) # R = 8.31447 # Universal gas constant (J/mol*K) # a = 0.0065 # Temperature lapse rate (K/m) # Ch = 0.0012 # (K/Pa) - A = 17.625 - B = 243.03 - C = 6.1094 + a = 17.625 + b = 243.03 + c = 6.1094 - P = obs.data_vars["surface_air_pressure"].sel(time=time) - T = obs.data_vars["air_temperature"].sel(time=time) + p = obs.data_vars["surface_air_pressure"].sel(time=time) + t = obs.data_vars["air_temperature"].sel(time=time) rh = obs.data_vars["relative_humidity"].sel(time=time) altitude = obs.altitude - e = rh * 6.11 * np.power(10.0, ((7.5 * (T - 273.15)) / (T - 38.85))) + e = rh * 6.11 * np.power(10.0, ((7.5 * (t - 273.15)) / (t - 38.85))) - dewpoint = np.where(~np.isnan(e), B * np.log(e / C) / (A - np.log(e / C)), T - 276.15) + dewpoint = np.where(~np.isnan(e), b * np.log(e / c) / (a - np.log(e / c)), t - 276.15) e = np.where(np.isnan(e), 0, e) - Tv = T / ( - 1.0 - 0.379 * (6.11 * np.power(10.0, ((7.5 * dewpoint) / (237.7 + dewpoint))) / P) - ) + tv = t / (1.0 - 0.379 * (6.11 * np.power(10.0, ((7.5 * dewpoint) / (237.7 + dewpoint))) / p)) # mslp = np.where(altitude >= 50., - # P * np.exp((g * altitude / R) / (T + 0.5 * a * altitude + e * Ch)), - # P + P * altitude / (29.27 * Tv)) + # p * np.exp((g * altitude / R) / (t + 0.5 * a * altitude + e * Ch)), + # p + p * altitude / (29.27 * tv)) - mslp = P + P * altitude / (29.27 * Tv) + mslp = p + p * altitude / (29.27 * tv) return mslp -def compute_precip(obs_data, zarr_dt, frt): +def compute_precip( + obs_data: xr.Dataset, zarr_dt: np.timedelta64, frt: np.datetime64 +) -> np.typing.NDArray: + """ + Compute accumulated precipitation over the forecast time step. + Parameters + ---------- + obs_data : xarray Dataset + Input data containing precipitation observations. + zarr_dt : np.timedelta64 + Time difference between forecast steps in hours. + frt : np.datetime64 + Forecast reference time for which to compute accumulated precipitation. + Returns + ------- + np.ndarray + Accumulated precipitation values for the forecast time step.""" obs_dt = obs_data.time.values[1] - obs_data.time.values[0] obs_dt = obs_dt.astype("timedelta64[h]") diff --git a/packages/evaluate/src/weathergen/evaluate/export/reshape.py b/packages/evaluate/src/weathergen/evaluate/export/reshape.py index b1e3adb2d..b6f627c27 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/reshape.py +++ b/packages/evaluate/src/weathergen/evaluate/export/reshape.py @@ -6,6 +6,8 @@ import numpy as np import xarray as xr from earthkit.regrid import interpolate +from scipy.interpolate import LinearNDInterpolator +from scipy.spatial import Delaunay, KDTree _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) @@ -14,6 +16,7 @@ Enhanced functions to handle Gaussian grids when converting from Zarr to NetCDF. """ + def get_obs_coordinates(obs: xr.Dataset): """ Extract latitude, longitude and altitude @@ -36,6 +39,7 @@ def get_obs_coordinates(obs: xr.Dataset): return lat, lon, alt + def get_grid_points(data: xr.DataArray): return np.column_stack((data.lat.values, data.lon.values)) @@ -567,19 +571,15 @@ def regrid_da(self, da: xr.DataArray) -> xr.DataArray: ) return regrid_da -## classes for verif - -import numpy as np -from scipy.interpolate import LinearNDInterpolator -from scipy.spatial import Delaunay, KDTree +## functions for verif -def convert_coordinates(coords): +def convert_coordinates(coords: np.typing.NDArray) -> np.typing.NDArray: """ Convert lat-lon coordinates to cartesian coordinates in a unit box """ - xyz_coords = np.ndarray((coords.shape[0], 3), dtype="float32") + xyz_coords = np.empty((coords.shape[0], 3), dtype="float32") xyz_coords[:, 0] = np.cos(np.pi * coords[:, 0] / 180.0) * np.cos(np.pi * coords[:, 1] / 180.0) xyz_coords[:, 1] = np.cos(np.pi * coords[:, 0] / 180.0) * np.sin(np.pi * coords[:, 1] / 180.0) @@ -588,23 +588,26 @@ def convert_coordinates(coords): return xyz_coords -def normalise(x): +def normalise(x: np.typing.NDArray) -> np.typing.NDArray: + """ + Normalise an array by dividing by the sum of its elements. + """ return x[:] / np.sum(x[:]) -class Verif_interpolator: +class VerifInterpolator: """ Interpolator class that's either a wrapper for scipys LinearNDInterpolator or uses the handmade approximate 2D linear interpolator """ -class Verif_2D_interpolator(Verif_interpolator): +class Verif2DInterpolator(VerifInterpolator): """ Class that does approximate 2D interpolation """ - def __init__(self, grid_points, obs_points): + def __init__(self, grid_points: np.typing.NDArray, obs_points: np.typing.NDArray): """ Initialise the class and store gridpoints """ @@ -612,14 +615,14 @@ def __init__(self, grid_points, obs_points): grid_xyz = convert_coordinates(grid_points) obs_xyz = convert_coordinates(obs_points) - self.indices = np.ndarray((obs_points.shape[0], 5), dtype="float32") + self.indices = np.empty((obs_points.shape[0], 5), dtype="float32") tree = KDTree(grid_xyz) _, self.indices = tree.query(obs_xyz, k=5) - self.weights = np.ndarray((obs_points.shape[0], 3), dtype="float32") + self.weights = np.empty((obs_points.shape[0], 3), dtype="float32") self.compute_weights(grid_xyz, obs_xyz) - def compute_weights(self, grid_xyz, obs_xyz): + def compute_weights(self, grid_xyz: np.typing.NDArray, obs_xyz: np.typing.NDArray): """ Compute the weights of the three nearest grid points by computing the barycentric coordinates, @@ -629,48 +632,48 @@ def compute_weights(self, grid_xyz, obs_xyz): eps = 0.01 for i, (obs, indix) in enumerate(zip(obs_xyz, self.indices, strict=True)): - AB = grid_xyz[indix[1]] - grid_xyz[indix[0]] - AC = grid_xyz[indix[2]] - grid_xyz[indix[0]] - BC = grid_xyz[indix[2]] - grid_xyz[indix[1]] - AP = obs - grid_xyz[indix[0]] - BP = obs - grid_xyz[indix[1]] + ab = grid_xyz[indix[1]] - grid_xyz[indix[0]] + ac = grid_xyz[indix[2]] - grid_xyz[indix[0]] + bc = grid_xyz[indix[2]] - grid_xyz[indix[1]] + ap = obs - grid_xyz[indix[0]] + bp = obs - grid_xyz[indix[1]] - area_tot = np.linalg.norm(np.cross(AB, AC)) - self.weights[i, 0] = np.linalg.norm(np.cross(BC, BP)) - self.weights[i, 1] = np.linalg.norm(np.cross(AC, AP)) - self.weights[i, 2] = np.linalg.norm(np.cross(AB, AP)) + area_tot = np.linalg.norm(np.cross(ab, ac)) + self.weights[i, 0] = np.linalg.norm(np.cross(bc, bp)) + self.weights[i, 1] = np.linalg.norm(np.cross(ac, ap)) + self.weights[i, 2] = np.linalg.norm(np.cross(ab, ap)) if 1 - area_tot / np.sum(self.weights[i, :]) < eps: continue indix[2] = indix[3] - AC = grid_xyz[indix[2]] - grid_xyz[indix[0]] - BC = grid_xyz[indix[2]] - grid_xyz[indix[1]] + ac = grid_xyz[indix[2]] - grid_xyz[indix[0]] + bc = grid_xyz[indix[2]] - grid_xyz[indix[1]] - area_tot = np.linalg.norm(np.cross(AB, AC)) - self.weights[i, 0] = np.linalg.norm(np.cross(BC, BP)) - self.weights[i, 1] = np.linalg.norm(np.cross(AC, AP)) + area_tot = np.linalg.norm(np.cross(ab, ac)) + self.weights[i, 0] = np.linalg.norm(np.cross(bc, bp)) + self.weights[i, 1] = np.linalg.norm(np.cross(ac, ap)) if 1 - area_tot / np.sum(self.weights[i, :]) < eps: continue indix[2] = indix[4] - AC = grid_xyz[indix[2]] - grid_xyz[indix[0]] - BC = grid_xyz[indix[2]] - grid_xyz[indix[1]] + ac = grid_xyz[indix[2]] - grid_xyz[indix[0]] + bc = grid_xyz[indix[2]] - grid_xyz[indix[1]] - self.weights[i, 0] = np.linalg.norm(np.cross(BC, BP)) - self.weights[i, 1] = np.linalg.norm(np.cross(AC, AP)) + self.weights[i, 0] = np.linalg.norm(np.cross(bc, bp)) + self.weights[i, 1] = np.linalg.norm(np.cross(ac, ap)) self.weights = self.weights / self.weights.sum(axis=1)[:, np.newaxis] - def interpolate(self, values, intmap=None): + def interpolate(self, values: np.typing.NDArray, intmap: np.typing.NDArray = None) -> np.typing.NDArray: """ Interpolate values to points """ - wvalues = np.ndarray((self.weights.shape[0]), dtype="float32") + wvalues = np.empty((self.weights.shape[0]), dtype="float32") if intmap is None: wvalues[:] = ( @@ -688,7 +691,7 @@ def interpolate(self, values, intmap=None): return wvalues -class Verif_lat_lon_interpolator(Verif_interpolator): +class VerifLatLonInterpolator(VerifInterpolator): """ Class that does approximate 2D interpolation """ @@ -701,7 +704,7 @@ def __init__(self, grid_points, obs_points): self.obs_points = obs_points self.triangulation = Delaunay(grid_points) - def interpolate(self, values, intmap=None): + def interpolate(self, values: np.typing.NDArray, intmap: np.typing.NDArray = None) -> np.typing.NDArray: """ Interpolate values to points """ @@ -719,12 +722,12 @@ def interpolate(self, values, intmap=None): return interpolator(self.obs_points).astype(np.float32) -class Verif_nearest_interpolator(Verif_interpolator): +class VerifNearestInterpolator(VerifInterpolator): """ Class that does approximate 2D interpolation """ - def __init__(self, grid_points, obs_points): + def __init__(self, grid_points: np.typing.NDArray, obs_points: np.typing.NDArray): """ Initialise the class and store gridpoints """ @@ -735,12 +738,12 @@ def __init__(self, grid_points, obs_points): tree = KDTree(grid_xyz) _, self.indices = tree.query(obs_xyz, k=1) - def interpolate(self, values, intmap=None): + def interpolate(self, values: np.typing.NDArray, intmap: np.typing.NDArray = None) -> np.typing.NDArray: """ Interpolate values to points """ - wvalues = np.ndarray((self.indices.shape), dtype="float32") + wvalues = np.empty((self.indices.shape[0]), dtype="float32") if intmap is None: wvalues[:] = values[self.indices[:]] @@ -750,7 +753,7 @@ def interpolate(self, values, intmap=None): return wvalues -class Interpolator_factory: +class InterpolatorFactory: def __init__(self, method: str): valid_methods = ("2d", "lat_lon", "nearest") @@ -760,16 +763,16 @@ def __init__(self, method: str): self.method = method def get_interpolator( - self, zarr_coords: np.ndarray, obs_coords: np.ndarray - ) -> Verif_interpolator: + self, zarr_coords: np.typing.NDArray, obs_coords: np.typing.NDArray + ) -> VerifInterpolator: if self.method == "2d": _logger.info("2D interpolation") - return Verif_2D_interpolator(zarr_coords, obs_coords) + return Verif2DInterpolator(zarr_coords, obs_coords) elif self.method == "lat_lon": _logger.info("lat-lon interpolation") - return Verif_lat_lon_interpolator(zarr_coords, obs_coords) + return VerifLatLonInterpolator(zarr_coords, obs_coords) elif self.method == "nearest": _logger.info("nearest neighbour interpolation") - return Verif_nearest_interpolator(zarr_coords, obs_coords) + return VerifNearestInterpolator(zarr_coords, obs_coords) diff --git a/packages/verif/pyproject.toml b/packages/verif/pyproject.toml deleted file mode 100644 index 1fef65772..000000000 --- a/packages/verif/pyproject.toml +++ /dev/null @@ -1,75 +0,0 @@ -[project] -name = "weathergen-verif" -version = "0.1.0" -description = "The WeatherGenerator Machine Learning Earth System Model" -readme = "../../README.md" -requires-python = ">=3.11,<3.13" -dependencies = [ - "xarray>=2025.6.1", - "dask>=2024.9.1", - "zarr~=3.1.3", - "numcodecs<0.16.0", -] - -[dependency-groups] -dev = [ - "pytest~=8.3.5", - "pytest-mock>=3.14.1", - "ruff==0.9.7", -] - - - -[tool.black] - -# Wide rows -line-length = 100 - - -# The linting configuration -[tool.ruff] - -# Wide rows -line-length = 100 - -[tool.ruff.lint] -# All disabled until the code is formatted. -select = [ - # pycodestyle - "E", - # Pyflakes - "F", - # pyupgrade - "UP", - # flake8-bugbear - "B", - # flake8-simplify - "SIM", - # isort - "I", - # Banned imports - "TID" -] - -# These rules are sensible and should be enabled at a later stage. -ignore = [ - # "B006", - "B011", - "UP008", - "SIM117", - "SIM118", - "SIM102", - "SIM401", - "UP040", # TODO: enable later - # To ignore, not relevant for us - "SIM108" # in case additional norm layer supports are added in future -] - - - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["src/weathergen"] diff --git a/packages/verif/src/weathergen/verif/__init__.py b/packages/verif/src/weathergen/verif/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/packages/verif/src/weathergen/verif/create_verif.py b/packages/verif/src/weathergen/verif/create_verif.py deleted file mode 100644 index a43453aca..000000000 --- a/packages/verif/src/weathergen/verif/create_verif.py +++ /dev/null @@ -1,401 +0,0 @@ -from argparse import ArgumentParser, Namespace -from pathlib import Path -from time import time - -import numpy as np -import xarray as xr - -from weathergen.common.io import ZarrIO -from weathergen.verif.verif_config import Variables -from weathergen.verif.verif_interpolator import Interpolator_factory -from weathergen.verif.verif_processers import Processer_factory - - -def readarg() -> Namespace: - parser = ArgumentParser(description="Create verif files from a zarr file and observation file") - - parser.add_argument( - "-z", - "--zarr", - dest="zarrfile", - required=True, - help="Zarr file (.zarr)", - ) - - parser.add_argument( - "-b", - "--obs", - dest="obsfile", - required=True, - help="Observation file (.nc)", - ) - - parser.add_argument( - "-o", - "--output", - dest="outfiles", - default="output/verif/%S/%V/verif_%S_%V_%M.nc", - required=False, - help="Template for the output nc filenames, default will be to create output/verif/%S/%V \ - repertories where %S, %V, %d are replaced by the streams, variable and date", - ) - - parser.add_argument( - "-d", - "--date", - type=str, - dest="datefromto", - required=False, - default=None, - help="From to date in format %Y%m%d%H:%Y%m%d%H or %Y%m%d:%Y%m%d, \ - excluding the second date for instance 2024010100:2024020200", - ) - - parser.add_argument( - "-v", - "--variables", - default=None, - dest="variables", - nargs="*", - help="Do verif for these variables. Default: 2t", - ) - - parser.add_argument( - "-s", - "--streams", - default=None, - dest="streams", - nargs="*", - help="Do verif for this streams. Default: Infer from .zarr file", - ) - - parser.add_argument( - "-m", - "--method", - default="2d", - dest="method", - choices=["2d", "lat_lon", "nearest"], - help="Interpolation method. Default: 2d_interpolation", - ) - - parser.add_argument( - "-ds", - "--dataset", - default="prediction", - dest="dataset", - choices=["prediction", "target"], - help="Prediction or target dataset.", - ) - - parser.add_argument( - "-c", - "--config_file", - dest="config_file", - default=None, - type=str, - help="Config file used for generating verif file.", - ) - - args = parser.parse_args() - - return args - - -def create_output_paths( - stream: str, variable: str, outfiles: str, method: str, dataset: str -) -> Path: - """ - Create output directories for the verif files - and return path to output file - Args: - stream (string) - variables (list[string]) - outfiles (string): template for the output files - Outputs: - None - """ - outfile = Path( - outfiles.replace("%S", stream) - .replace("%V", variable) - .replace("%M", method) - .replace("%D", dataset) - ) - pathdir = outfile.parent - print(f"Output directory: {pathdir}") - pathdir.mkdir(exist_ok=True, parents=True) - return outfile - - -def generate_time_coordinates( - xdata: xr.DataArray, zarrio: ZarrIO, stream: str, dataset: str -) -> tuple[xr.DataArray, xr.DataArray]: - """ - Read samples and steps from ZarrIO object - and convert to xarray data objects - to be used as coordinates in verrif dataset - """ - - # Initial times are stored as numpy.datetime64 objects in verif - # Get the valid time of the first step for each sample - verif_times = [np.datetime64("nat", "h")] * len(zarrio.samples) - for sample in zarrio.samples: - item = zarrio.get_data(sample=sample, stream=stream, forecast_step=1) - if dataset == "prediction": - verif_times[int(sample)] = item.prediction.as_xarray().source_interval_start.values[0] - else: - verif_times[int(sample)] = item.target.as_xarray().source_interval_start.values[0] - - xrtime = xr.DataArray( - verif_times, - name="time", - dims=["time"], - coords={"time": verif_times}, - attrs={"standard_name": "forecast_reference_time"}, - ) - - dt = xdata.source_interval_end.values[0] - xdata.source_interval_start.values[0] - dt = dt.astype("timedelta64[h]") - - # Lead times are stored as float32 in verif - # Assume all time steps are the same, - # so loop over steps and multiply the time step size by index - leadtimes = np.ndarray(len(zarrio.forecast_steps), dtype=np.float32) - for i in range(len(zarrio.forecast_steps)): - leadtimes[i] = (i + 1) * dt - - xrleadtime = xr.DataArray( - leadtimes, - name="leadtime", - dims=["leadtime"], - coords={"leadtime": leadtimes}, - attrs={"units": "hour"}, - ) - - return xrtime, xrleadtime - - -def get_streams(zarrio: ZarrIO, arg_streams: list) -> list: - """ - Determine the stream, - either by getting streams from argument and check if they are in the zarr file - or just use all the streams in zarrio - Args: - zarrio: ZarrIO object - arg_streams: (list[string]) - Outputs: - streams: (list[string]) - """ - if arg_streams: - for stream in arg_streams: - if stream not in zarrio.streams: - raise Exception( - f"Stream {stream} is not present in .zarr file. zarrio.streams: \ - {zarrio.streams}" - ) - return arg_streams - else: - return zarrio.streams - - -def get_variables(xdata: xr.DataArray, config_file: Path, arg_variables: list, stream: str) -> list: - """ - Go through argument variables, - check if they are in the config_file and return - a list ov variables. - If no arguments are given, - return list of variables found in file. - """ - - config_variables = Variables(config_file) - - config_names = (cv.name for cv in config_variables) - - variables = [] - if arg_variables: - # Check if there's a config for requested variables - for av in arg_variables: - if av not in config_names: - raise Exception(f"Variable {av} does not have an entry in the config file") - - # Add requested variables to list of variables - for cv in config_variables: - if cv.name in arg_variables: - variables += [cv] - - else: - variables = [v for v in config_variables] - - # Check what variables exist in zarr file - vvars = [] - for v in variables: - w = v - - stringnames = (n for n in v.zarr_names if isinstance(n, str)) - listnames = (n for n in v.zarr_names if isinstance(n, list)) - - for n in stringnames: - if n in xdata.channel.values: - w.zarr_names = n - vvars += [w] - for n in listnames: - if len(set(n).intersection(xdata.channel.values)) == len(n): - w.zarr_names = tuple(n) - vvars += [w] - - variables = vvars - - if not (len(variables) == len(set(variables))): - raise Exception("Same variable appears multiple times in zarr file.") - - if not variables: - raise Exception("No variables with configuration found in zarr file.") - - for v in variables: - try: - v.zarr_units = v.zarr_units[stream] - except KeyError: - v.zarr_units = v.zarr_units["DEFAULT"] - - return variables - - -def get_obs_coordinates(obs: xr.Dataset): - """ - Extract latitude, longitude and altitude - from observation dataset - Args: - obs: Dataset - Outputs: - lat: DataArray - lon: DataArray - alt: DataArray - """ - - lat = obs.latitude.astype("float32") - lat.name = "lat" - - lon = obs.longitude.astype("float32") - lon.name = "lon" - - alt = obs.altitude.astype("float32") - - return lat, lon, alt - - -def process_config(config_file: str) -> Path: - """ - Convert input config_file argument to absolute Path object - """ - - if not config_file: - config_path = Path(__file__).parent / "verif_config.yaml" - else: - config_path = Path(config_file).resolve() - - if not config_path.is_file(): - raise Exception(f"{config_file} is not a file.") - - return config_path - - -def main(): - print("Start creating verif files") - - args = readarg() - - print("zarrfile:", args.zarrfile) - print("obsfile:", args.obsfile) - print("outputfile template:", args.outfiles) - print("dataset: ", args.dataset) - - obs = xr.open_dataset(args.obsfile) - lat, lon, alt = get_obs_coordinates(obs) - obs_coords = np.column_stack((lat.values, lon.values)) - - print() - print(obs) - - method_factory = Interpolator_factory(args.method) - - with ZarrIO(args.zarrfile, read_only=True) as zarrio: - streams = get_streams(zarrio, "ERA5") - - t_start = time() - - for stream in streams: - print() - print("stream: ", stream) - - item = zarrio.get_data(sample=0, stream=stream, forecast_step=1) - if args.dataset == "prediction": - xdata = item.prediction.as_xarray() - else: - xdata = item.target.as_xarray() - - xrtime, xrleadtime = generate_time_coordinates(xdata, zarrio, stream, args.dataset) - - config_path = process_config(args.config_file) - - variables = get_variables(xdata, config_path, args.variables, stream) - - zarr_coords = np.column_stack((xdata.ipoint.lat.values, xdata.ipoint.lon.values)) - - interpolator = method_factory.get_interpolator(zarr_coords, obs_coords) - - data_shape = (len(zarrio.samples), len(zarrio.forecast_steps), obs.location.shape[0]) - - processers = Processer_factory(zarrio, obs, stream, interpolator, args.dataset) - - for v in variables: - vt_start = time() - - print() - print("variable: ", v.name) - - fcstdata = np.ndarray(data_shape, dtype=np.float32) - obsdata = np.ndarray(data_shape, dtype=np.float32) - - p = processers.get_processer(v.name) - - p.get_data(v, fcstdata, obsdata) - - xrobsdata = xr.DataArray( - obsdata, - dims=["time", "leadtime", "location"], - coords={"time": xrtime, "leadtime": xrleadtime, "location": obs.location}, - name="obs", - attrs=v.attributes, - ) - - xrfcstdata = xr.DataArray( - fcstdata, - dims=["time", "leadtime", "location"], - coords={"time": xrtime, "leadtime": xrleadtime, "location": obs.location}, - name="fcst", - attrs=v.attributes, - ) - - merged = xr.merge([xrfcstdata, xrobsdata, lat, lon, alt]) - - outfile = create_output_paths( - stream, v.name, args.outfiles, args.method, args.dataset - ) - - merged.to_netcdf( - outfile, encoding={"time": {"units": "seconds since 1970-01-01 00:00:00"}} - ) - - vt_end = time() - - print(v.name, "time: ", vt_end - vt_start) - print("merged: ") - print(merged) - - t_end = time() - - print() - print("all the time: ", t_end - t_start) - - -if __name__ == "__main__": - main() diff --git a/packages/verif/src/weathergen/verif/verif_config.py b/packages/verif/src/weathergen/verif/verif_config.py deleted file mode 100644 index 490899ebb..000000000 --- a/packages/verif/src/weathergen/verif/verif_config.py +++ /dev/null @@ -1,44 +0,0 @@ -from pathlib import Path - -from yaml import safe_load - -__all__ = ["Variable"] - - -class Variable: - """ - Object representing a variable - """ - - def __init__(self, **kwargs): - self.name = kwargs.get("name") - self.attributes = kwargs.get("attributes") - self.zarr_names = kwargs.get("zarr_names") - self.zarr_units = kwargs.get("zarr_units") - self.obs_name = kwargs.get("obs_name") - self.obs_units = kwargs.get("obs_units") - - def __repr__(self): - return self.name - - -class Variables: - """ - Utility class to read a verif configuration from .yaml file - """ - - def __init__(self, filename: Path): - print(f"Reading configuration from file: {filename}") - - with open(filename) as stream: - self.schema = safe_load(stream) - - def __iter__(self): - return self.variables.__iter__() - - @property - def variables(self): - """ - Get a list of variables from the locally stored configuration - """ - return [Variable(**var) for var in self.schema] diff --git a/packages/verif/src/weathergen/verif/verif_config.yaml b/packages/verif/src/weathergen/verif/verif_config.yaml deleted file mode 100644 index 0662007d6..000000000 --- a/packages/verif/src/weathergen/verif/verif_config.yaml +++ /dev/null @@ -1,74 +0,0 @@ -# Default config file -# mapping variables from zarr and observation files -# to verif files - -- name: "2t" - attributes: { - units: "K", - long_name: "2 meter temperature", - conventions: "verif_1.0.0"} - zarr_names: ["2t"] - zarr_units: {"CERRA": "K", - "MEPS": "K", - "NORA3": "K", - "ERA5": "K", - "DEFAULT": "K"} - obs_name: "air_temperature" - obs_units: "K" - -- name: "sp" - attributes: { - units: "Pa", - long_name: "Surface pressure", - conventions: "verif_1.0.0"} - zarr_names: ["sp"] - zarr_units: {"CERRA": "Pa", - "MEPS": "Pa", - "NORA3": "Pa", - "ERA5": "Pa", - "DEFAULT": "Pa"} - obs_name: "surface_air_pressure" - obs_units: "Pa" - -- name: "tp" - attributes: { - units: "kg/m^2", - long_name: "Total precipitation amount", - conventions: "verif_1.0.0"} - zarr_names: ["tp"] - zarr_units: {"CERRA": "kg/m^2", - "MEPS": "kg/m^2", - "NORA3": "kg/m^2", - "ERA5": "m", - "DEFAULT": "kg/m^2"} - obs_name: "precipitation_amount_1h" - obs_units: "kg/m^2" - -- name: "mslp" - attributes: { - units: "Pa", - long_name: "Mean sea level pressure", - conventions: "verif_1.0.0"} - zarr_names: ["msl"] - zarr_units: {"CERRA": "Pa", - "MEPS": "Pa", - "NORA3": "Pa", - "ERA5": "Pa", - "DEFAULT": "Pa"} - obs_name: "surface_air_pressure" - obs_units: "Pa" - -- name: "wind" - attributes: { - units: "m/s", - long_name: "wind speed", - conventions: "verif_1.0.0"} - zarr_names: ["10si", ["10u", "10v"]] - zarr_units: {"CERRA": "m/s", - "MEPS": "m/s", - "NORA3": "m/s", - "ERA5": "m/s", - "DEFAULT": "m/s"} - obs_name: "wind_speed" - obs_units: "m/s" - diff --git a/packages/verif/src/weathergen/verif/verif_interpolator.py b/packages/verif/src/weathergen/verif/verif_interpolator.py deleted file mode 100644 index 4bbb29d6b..000000000 --- a/packages/verif/src/weathergen/verif/verif_interpolator.py +++ /dev/null @@ -1,204 +0,0 @@ -import numpy as np -from scipy.interpolate import LinearNDInterpolator -from scipy.spatial import Delaunay, KDTree - - -def convert_coordinates(coords): - """ - Convert lat-lon coordinates to cartesian coordinates in a unit box - """ - - xyz_coords = np.ndarray((coords.shape[0], 3), dtype="float32") - - xyz_coords[:, 0] = np.cos(np.pi * coords[:, 0] / 180.0) * np.cos(np.pi * coords[:, 1] / 180.0) - xyz_coords[:, 1] = np.cos(np.pi * coords[:, 0] / 180.0) * np.sin(np.pi * coords[:, 1] / 180.0) - xyz_coords[:, 2] = np.sin(np.pi * coords[:, 0] / 180.0) - - return xyz_coords - - -def normalise(x): - return x[:] / np.sum(x[:]) - - -class Verif_interpolator: - """ - Interpolator class that's either a wrapper for scipys LinearNDInterpolator - or uses the handmade approximate 2D linear interpolator - """ - - -class Verif_2D_interpolator(Verif_interpolator): - """ - Class that does approximate 2D interpolation - """ - - def __init__(self, grid_points, obs_points): - """ - Initialise the class and store gridpoints - """ - - grid_xyz = convert_coordinates(grid_points) - obs_xyz = convert_coordinates(obs_points) - - self.indices = np.ndarray((obs_points.shape[0], 5), dtype="float32") - tree = KDTree(grid_xyz) - _, self.indices = tree.query(obs_xyz, k=5) - - self.weights = np.ndarray((obs_points.shape[0], 3), dtype="float32") - self.compute_weights(grid_xyz, obs_xyz) - - def compute_weights(self, grid_xyz, obs_xyz): - """ - Compute the weights of the three nearest grid points - by computing the barycentric coordinates, - assuming that the observations are close enough to the plane through the grid points. - """ - - eps = 0.01 - - for i, (obs, indix) in enumerate(zip(obs_xyz, self.indices, strict=True)): - AB = grid_xyz[indix[1]] - grid_xyz[indix[0]] - AC = grid_xyz[indix[2]] - grid_xyz[indix[0]] - BC = grid_xyz[indix[2]] - grid_xyz[indix[1]] - AP = obs - grid_xyz[indix[0]] - BP = obs - grid_xyz[indix[1]] - - area_tot = np.linalg.norm(np.cross(AB, AC)) - self.weights[i, 0] = np.linalg.norm(np.cross(BC, BP)) - self.weights[i, 1] = np.linalg.norm(np.cross(AC, AP)) - self.weights[i, 2] = np.linalg.norm(np.cross(AB, AP)) - - if 1 - area_tot / np.sum(self.weights[i, :]) < eps: - continue - - indix[2] = indix[3] - - AC = grid_xyz[indix[2]] - grid_xyz[indix[0]] - BC = grid_xyz[indix[2]] - grid_xyz[indix[1]] - - area_tot = np.linalg.norm(np.cross(AB, AC)) - self.weights[i, 0] = np.linalg.norm(np.cross(BC, BP)) - self.weights[i, 1] = np.linalg.norm(np.cross(AC, AP)) - - if 1 - area_tot / np.sum(self.weights[i, :]) < eps: - continue - - indix[2] = indix[4] - - AC = grid_xyz[indix[2]] - grid_xyz[indix[0]] - BC = grid_xyz[indix[2]] - grid_xyz[indix[1]] - - self.weights[i, 0] = np.linalg.norm(np.cross(BC, BP)) - self.weights[i, 1] = np.linalg.norm(np.cross(AC, AP)) - - self.weights = self.weights / self.weights.sum(axis=1)[:, np.newaxis] - - def interpolate(self, values, intmap=None): - """ - Interpolate values to points - """ - - wvalues = np.ndarray((self.weights.shape[0]), dtype="float32") - - if intmap is None: - wvalues[:] = ( - self.weights[:, 0] * values[self.indices[:, 0]] - + self.weights[:, 1] * values[self.indices[:, 1]] - + self.weights[:, 2] * values[self.indices[:, 2]] - ) - else: - wvalues[:] = ( - self.weights[:, 0] * values[intmap[self.indices[:, 0]]] - + self.weights[:, 1] * values[intmap[self.indices[:, 1]]] - + self.weights[:, 2] * values[intmap[self.indices[:, 2]]] - ) - - return wvalues - - -class Verif_lat_lon_interpolator(Verif_interpolator): - """ - Class that does approximate 2D interpolation - """ - - def __init__(self, grid_points, obs_points): - """ - Initialise the class and store gridpoints - """ - - self.obs_points = obs_points - self.triangulation = Delaunay(grid_points) - - def interpolate(self, values, intmap=None): - """ - Interpolate values to points - """ - - newvalues = np.empty_like(values) - - if intmap is None: - newvalues = values - else: - for i in range(len(values)): - newvalues[i] = values[intmap[i]] - - interpolator = LinearNDInterpolator(self.triangulation, newvalues) - - return interpolator(self.obs_points).astype(np.float32) - - -class Verif_nearest_interpolator(Verif_interpolator): - """ - Class that does approximate 2D interpolation - """ - - def __init__(self, grid_points, obs_points): - """ - Initialise the class and store gridpoints - """ - - grid_xyz = convert_coordinates(grid_points) - obs_xyz = convert_coordinates(obs_points) - - tree = KDTree(grid_xyz) - _, self.indices = tree.query(obs_xyz, k=1) - - def interpolate(self, values, intmap=None): - """ - Interpolate values to points - """ - - wvalues = np.ndarray((self.indices.shape), dtype="float32") - - if intmap is None: - wvalues[:] = values[self.indices[:]] - else: - wvalues[:] = values[intmap[self.indices[:]]] - - return wvalues - - -class Interpolator_factory: - def __init__(self, method: str): - valid_methods = ("2d", "lat_lon", "nearest") - - if method not in valid_methods: - raise Exception(f"{method} is not a valid method.") - - self.method = method - - def get_interpolator( - self, zarr_coords: np.ndarray, obs_coords: np.ndarray - ) -> Verif_interpolator: - if self.method == "2d": - print("2D interpolation") - return Verif_2D_interpolator(zarr_coords, obs_coords) - - elif self.method == "lat_lon": - print("lat-lon interpolation") - return Verif_lat_lon_interpolator(zarr_coords, obs_coords) - - elif self.method == "nearest": - print("nearest neighbour interpolation") - return Verif_nearest_interpolator(zarr_coords, obs_coords) diff --git a/packages/verif/src/weathergen/verif/verif_processers.py b/packages/verif/src/weathergen/verif/verif_processers.py deleted file mode 100644 index eb9a86e34..000000000 --- a/packages/verif/src/weathergen/verif/verif_processers.py +++ /dev/null @@ -1,190 +0,0 @@ -import numpy as np -import xarray as xr - -from weathergen.common.io import ZarrIO -from weathergen.evaluate.scores.score import Scores -from weathergen.verif.verif_config import Variable -from weathergen.verif.verif_interpolator import Verif_interpolator - - -class Processer: - unit_conversion = {"kg/m^2": 1.0, "Pa": 1.0, "K": 1.0, "m/s": 1.0, "m": 1000.0} - - def __init__( - self, - zarrio: ZarrIO, - obs: xr.DataArray, - stream: str, - interpolator: Verif_interpolator, - dataset: str, - ): - self.zarrio = zarrio - self.obs = obs - self.stream = stream - self.interpolator = interpolator - self.dataset = dataset - - item = zarrio.get_data(sample=0, stream=stream, forecast_step=1) - if self.dataset == "prediction": - self.xdata = item.prediction.as_xarray() - else: - self.xdata = item.target.as_xarray() - - self.obs_dt = self.obs.time.values[1] - self.obs.time.values[0] - self.obs_dt = self.obs_dt.astype("timedelta64[h]") - - self.zarr_dt = ( - self.xdata.source_interval_end.values[0] - self.xdata.source_interval_start.values[0] - ) - self.zarr_dt = self.zarr_dt.astype("timedelta64[h]") - - def get_data(self, v: Variable, fcstdata, obsdata): - for sample in range(len(self.zarrio.samples)): - for step in range(len(self.zarrio.forecast_steps)): - item = self.zarrio.get_data( - sample=sample, stream=self.stream, forecast_step=step + 1 - ) - - if self.dataset == "prediction": - ydata = Scores.sort_by_coords(item.prediction.as_xarray(), self.xdata) - else: - ydata = Scores.sort_by_coords(item.target.as_xarray(), self.xdata) - - obsdata[sample, step, :] = self.get_obsdata( - self.obs, v.obs_name, ydata.valid_time.values[0] - ) - - fcstdata[sample, step, :] = self.get_fcstdata(ydata, v, sample, step + 1) - - def get_obsdata(self, obs: xr.DataArray, name: str, time: np.datetime64) -> np.ndarray: - return obs.data_vars[name].sel(time=time) - - def get_fcstdata(self, ydata: xr.DataArray, v: Variable, sample: int, step: int) -> np.ndarray: - return ( - self.interpolator.interpolate( - ydata.sel( - sample=sample, - stream=self.stream, - forecast_step=step, - channel=v.zarr_names, - ens=0, - ).values - ) - * self.unit_conversion[v.zarr_units] - ) - - -class MSLP_processer(Processer): - def get_obsdata(self, obs: xr.DataArray, name: str, time: np.datetime64) -> np.ndarray: - return self.compute_mslp(obs, time) - - def compute_mslp(self, obs: xr.DataArray, time: np.datetime64) -> np.ndarray: - # g = 9.80665 # Gravitational acceleration (m/s**2) - # R = 8.31447 # Universal gas constant (J/mol*K) - - # a = 0.0065 # Temperature lapse rate (K/m) - # Ch = 0.0012 # (K/Pa) - - A = 17.625 - B = 243.03 - C = 6.1094 - - P = obs.data_vars["surface_air_pressure"].sel(time=time) - T = obs.data_vars["air_temperature"].sel(time=time) - rh = obs.data_vars["relative_humidity"].sel(time=time) - - altitude = obs.altitude - - e = rh * 6.11 * np.power(10.0, ((7.5 * (T - 273.15)) / (T - 38.85))) - - dewpoint = np.where(~np.isnan(e), B * np.log(e / C) / (A - np.log(e / C)), T - 276.15) - - e = np.where(np.isnan(e), 0, e) - - Tv = T / ( - 1.0 - 0.379 * (6.11 * np.power(10.0, ((7.5 * dewpoint) / (237.7 + dewpoint))) / P) - ) - - # mslp = np.where(altitude >= 50., - # P * np.exp((g * altitude / R) / (T + 0.5 * a * altitude + e * Ch)), - # P + P * altitude / (29.27 * Tv)) - - mslp = P + P * altitude / (29.27 * Tv) - - return mslp - - -class Wind_processer(Processer): - def get_fcstdata(self, ydata: xr.DataArray, v: Variable, sample: int, step: int) -> np.ndarray: - if isinstance(v.zarr_names, str): - return super().get_fcstdata(ydata, v, sample, step) - - else: - u = self.interpolator.interpolate( - ydata.sel( - sample=sample, - stream=self.stream, - forecast_step=step, - channel=v.zarr_names[0], - ens=0, - ).values - ) - - v = self.interpolator.interpolate( - ydata.sel( - sample=sample, - stream=self.stream, - forecast_step=step, - channel=v.zarr_names[1], - ens=0, - ).values - ) - - return np.sqrt(np.square(u) + np.square(v)) - - -class Precipitation_processer(Processer): - def get_obsdata(self, obs: xr.DataArray, name: str, time: np.datetime64) -> np.ndarray: - if self.obs_dt >= self.zarr_dt: - return super().get_obsdata(obs, name, time) - else: - accumulate = np.zeros(self.obs.location.shape[0]) - int_factor = int(self.zarr_dt / self.obs_dt) - - for i in range(int_factor): - back_time = time - self.zarr_dt + (i + 1) * self.obs_dt - accumulate += super().get_obsdata(obs, name, back_time) - - return accumulate - - -class Processer_factory: - def __init__( - self, - zarrio: ZarrIO, - obs: xr.DataArray, - stream: str, - interpolator: Verif_interpolator, - dataset: str, - ): - self.zarrio = zarrio - self.obs = obs - self.stream = stream - self.interpolator = interpolator - self.dataset = dataset - - def get_processer(self, name: str) -> Processer: - if name == "mslp": - return MSLP_processer( - self.zarrio, self.obs, self.stream, self.interpolator, self.dataset - ) - elif name == "wind": - return Wind_processer( - self.zarrio, self.obs, self.stream, self.interpolator, self.dataset - ) - elif name == "tp": - return Precipitation_processer( - self.zarrio, self.obs, self.stream, self.interpolator, self.dataset - ) - else: - return Processer(self.zarrio, self.obs, self.stream, self.interpolator, self.dataset) From f3c3a9adc4b373b9bb64e922a868c9164fba68fb Mon Sep 17 00:00:00 2001 From: Sorcha Date: Mon, 23 Feb 2026 15:14:24 +0100 Subject: [PATCH 10/28] fixing pyporject --- pyproject.toml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 345a7559b..2f48fd884 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "anemoi-datasets", "weathergen-common", "weathergen-evaluate", - "weathergen-readers-extra", + "weathergen-readers-extra" ] @@ -44,7 +44,6 @@ inference = "weathergen.run_train:inference" evaluate = "weathergen.evaluate.run_evaluation:evaluate" plot_train = "weathergen.utils.plot_training:plot_train" export = "weathergen.evaluate.export.export_inference:export" -create_verif = "weathergen.verif.create_verif:main" [build-system] requires = ["hatchling"] @@ -266,7 +265,6 @@ members = [ "packages/evaluate", "packages/metrics", "packages/readers_extra", - "packages/verif", # Explicitly not depending on 'packages/dashboard' : this causes issues when deploying # the streamlit dashboard. ] From c040701352fe2f5c8f274352b1b69bdb635cf4a2 Mon Sep 17 00:00:00 2001 From: Sorcha Owens Date: Fri, 27 Feb 2026 11:56:22 +0100 Subject: [PATCH 11/28] pinning eathkit --- config/evaluate/config_zarr2verif.yaml | 7 ++++++- packages/evaluate/pyproject.toml | 5 +++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/config/evaluate/config_zarr2verif.yaml b/config/evaluate/config_zarr2verif.yaml index b76bff633..6ed22b2a3 100644 --- a/config/evaluate/config_zarr2verif.yaml +++ b/config/evaluate/config_zarr2verif.yaml @@ -1,3 +1,8 @@ +# List of variables to compare between WeatherGenerator output and MetNor observation files +# To add more varaibles use the same formate for the variable using the `var` field to put the +# name of the variable in the WeatherGenerator dataset and units for each stream in `wg_uni` +# Additionally, add the chosen variable to required channels in `verif_parser.py` L92 + variables: 2t: var: 2t @@ -54,7 +59,7 @@ variables: 10si: - var: wind # derived channel + var: 10si # derived channel long: wind speed wg_unit: {CERRA: m/s, MEPS: m/s, diff --git a/packages/evaluate/pyproject.toml b/packages/evaluate/pyproject.toml index 0a2991001..6a178833b 100644 --- a/packages/evaluate/pyproject.toml +++ b/packages/evaluate/pyproject.toml @@ -17,8 +17,9 @@ dependencies = [ "eccodes==2.44.0", "eccodeslib==2.44.0.7", "eckitlib==1.32.3.7", - "earthkit-data==0.18.2" -] + "earthkit-data==0.18.2", + "earthkit-utils==0.1.2" + ] [dependency-groups] dev = [ From 991c5d2d2279c01a8128cb7ae1d0ffeeb8b41ad9 Mon Sep 17 00:00:00 2001 From: Sorcha Date: Fri, 27 Feb 2026 12:12:12 +0100 Subject: [PATCH 12/28] add compat/join arguments --- .../src/weathergen/evaluate/export/parsers/verif_parser.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index ccb7bc312..61d479c2b 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -302,7 +302,6 @@ def regrid( """ mapped_info = self.mapping.get(verif_var, {}) wg_var = mapped_info.get("var", None) - try: ds_var = ds[wg_var] except KeyError as e: @@ -435,7 +434,7 @@ def add_attrs(self, ds: xr.Dataset) -> xr.Dataset: xarray Dataset with CF-compliant variable attributes. """ variables = self._attrs_gaussian_grid(ds) - dataset = xr.merge(variables.values()) + dataset = xr.merge(variables.values(), compat="no_conflicts") dataset.attrs = ds.attrs return dataset @@ -603,7 +602,7 @@ def _build_coordinate_mapping( def merge(self, ds, obs_ds): lat, lon, alt = get_obs_coordinates(self.obs) - merged = xr.merge([ds, obs_ds, lat, lon, alt]) + merged = xr.merge([ds, obs_ds, lat, lon, alt], join="outer") return merged def save(self, ds: xr.Dataset, forecast_ref_time: np.datetime64, verif_var: str) -> None: From 8bebaa1c614beb9b2d1848a3a0ffef0d3d85a128 Mon Sep 17 00:00:00 2001 From: Sorcha Date: Fri, 27 Feb 2026 19:51:06 +0100 Subject: [PATCH 13/28] adjjsting to save all samples --- config/evaluate/config_zarr2verif.yaml | 14 +- .../weathergen/evaluate/export/export_core.py | 11 +- .../evaluate/export/parsers/netcdf_parser.py | 2 +- .../evaluate/export/parsers/verif_parser.py | 128 ++++++++++-------- 4 files changed, 86 insertions(+), 69 deletions(-) diff --git a/config/evaluate/config_zarr2verif.yaml b/config/evaluate/config_zarr2verif.yaml index 6ed22b2a3..589626d40 100644 --- a/config/evaluate/config_zarr2verif.yaml +++ b/config/evaluate/config_zarr2verif.yaml @@ -74,28 +74,28 @@ variables: coordinates: sfc: - valid_time: time + valid_time: valid_time lat: latitude lon: longitude stream: stream forecast_step: leadtime - forecast_reference_time: forecast_reference_time + forecast_reference_time: time ncells: ncells pl: #not needed pressure_level: pressure - valid_time: time + valid_time: valid_time lat: latitude lon: longitude stream: stream forecast_step: leadtime - forecast_reference_time: forecast_reference_time + forecast_reference_time: time ncells: ncells dimensions: valid_time: - verif: time - std: time + verif: valid_time + std: valid_time lat: verif: latitude std: latitude @@ -109,7 +109,7 @@ dimensions: std: pressure verif_unit: hPa forecast_reference_time: - verif: forecast_reference_time + verif: time std: forecast_reference_time forecast_step: verif: leadtime diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_core.py b/packages/evaluate/src/weathergen/evaluate/export/export_core.py index d0d9837f8..eb5ab09e3 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/export_core.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_core.py @@ -250,6 +250,8 @@ def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: with Pool(processes=n_processes, maxtasksperchild=5) as pool: parser = CfParserFactory.get_parser(config=config, **kwargs) + processed_samples = [] + for s_idx, sample in enumerate(tqdm(samples)): ref_time = ref_times[s_idx] step_tasks = [ @@ -258,10 +260,17 @@ def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: results_iterator = pool.imap_unordered(get_data_worker, step_tasks, chunksize=1) - parser.process_sample( + + processed = parser.process_sample( results_iterator, ref_time=ref_time, ) + processed_samples.append(processed) + + # Only save here if need to merge samples, otherwise save in process_sample + if processed_samples[0] is not None: + parser.save(processed_samples) + pool.terminate() pool.join() diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py index fe7655fbe..9b19c97d0 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py @@ -270,7 +270,7 @@ def add_attrs(self, ds: xr.Dataset) -> xr.Dataset: else: variables = self._attrs_regular_grid(ds) - dataset = xr.merge(variables.values()) + dataset = xr.merge(variables.values(), compat="no_conflicts") dataset.attrs = ds.attrs return dataset diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index 61d479c2b..5d3951577 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -92,7 +92,6 @@ def process_sample( required_channels = ["10u", "10v", "sp", "2t", "msl"] self.channels = list(set(self.channels) & set(required_channels)) da_fs = [] - valid_times = [] for result in fstep_iterator_results: if result is None: continue @@ -102,7 +101,7 @@ def process_sample( result = self.preprocess(result) result = self.reshape(result) da_fs.append(result) - valid_times.append(result.valid_time.values[0]) + _logger.info(f"Retrieved {len(da_fs)} forecast steps for type {self.data_type}.") @@ -115,17 +114,21 @@ def process_sample( da_fs = self.assign_frt(da_fs, ref_time) da_fs = self.add_attrs(da_fs) + vars_to_merge = {verif_var: None for verif_var in self.mapping.keys()} + for verif_var in self.mapping.keys(): - da_var = self.regrid(da_fs, verif_var, valid_times) + da_var = self.regrid(da_fs, verif_var) if da_var is None: continue da_var = self.add_encoding(da_var) - obs_result = self.obs_preprocess(valid_times, verif_var) + obs_result = self.obs_preprocess(da_var, verif_var) obs_result = self.add_encoding(obs_result) merged = self.merge(da_var, obs_result) merged = self.add_metadata(merged, verif_var) - self.save(merged, ref_time, verif_var) + vars_to_merge[verif_var] = merged + return vars_to_merge + def get_zarr_dt(self, ds: xr.Dataset) -> np.timedelta64: """ Compute the time difference between forecast steps in hours from the WG output dataset. @@ -143,12 +146,11 @@ def get_zarr_dt(self, ds: xr.Dataset) -> np.timedelta64: return zarr_dt - def get_output_filename(self, forecast_ref_time: np.datetime64, variable: str) -> Path: + def get_output_filename(self, variable: str) -> Path: """ Create output directories for the verif files and return path to output file Args: - stream (string) variables (list[string]) outfiles (string): template for the output files Outputs: @@ -158,7 +160,7 @@ def get_output_filename(self, forecast_ref_time: np.datetime64, variable: str) - self.verif_template.replace("%S", self.stream) .replace("%V", variable) .replace("%M", self.method) - .replace("%D", self.data_type + np.datetime_as_string(forecast_ref_time, unit="h")) + .replace("%D", self.data_type) ) outfile = Path(self.output_dir) / outfile pathdir = outfile.parent @@ -221,39 +223,47 @@ def reshape(self, data: xr.DataArray) -> xr.Dataset: return reshaped_dataset - def obs_preprocess(self, valid_times: list[np.datetime64], verif_var: str) -> xr.DataArray: + def obs_preprocess(self, ds_var, verif_var: str) -> xr.DataArray: """ Preprocess the observation data for the given variable and valid times. This includes computing derived variables like MSLP and total precipitation if needed. + Parameters ---------- - valid_times : list of np.datetime64 - List of valid times for which to preprocess the observation data. - verif_var : str - The variable for which to preprocess the observation data (e.g., 'mslp', 'tp'). + obs_data : xr.Dataset + The original observation dataset. + ds_var : xr.DataArray + The forecast data array to which the observation data should be regridded. + Returns ------- xr.DataArray - Preprocessed observation data for the specified variable and valid times. + Regridded observation data matching the forecast grid. """ obs_data = self.obs mapped_info = self.mapping.get(verif_var, {}) obs_name = mapped_info.get("obs_name", {}) - if verif_var == "mslp": - for vtime in valid_times: - obs_data[obs_name].sel(time=valid_times).values = compute_mslp(obs_data, vtime) - if verif_var == "tp": - for vtime in valid_times: - obs_data[obs_name].sel(time=vtime).values = compute_precip( - obs_data, self.zarr_dt, vtime - ) - else: - pass - new_xarray = obs_data[obs_name].sel(time=valid_times) - new_xarray.name = "obs" + original_shape = ds_var.shape + new_shape = list(original_shape) + + obs_dataarray = np.empty(new_shape, dtype=np.float32) + + for i, leadtime in enumerate(ds_var.coords["leadtime"].values): + valid_time = ds_var.coords["time"] + np.timedelta64(int(leadtime), "h") + if verif_var == "mslp": + obs_dataarray[:, i, :] = compute_mslp(obs_data, valid_time) + if verif_var == "tp": + obs_dataarray[:, i, :] = compute_precip( + obs_data, self.zarr_dt, valid_time) + else: + obs_dataarray[:, i, :] = obs_data.data_vars[obs_name].sel(time=valid_time) + + obs_dataarray = ds_var.copy(data=obs_dataarray) + obs_dataarray.name = "obs" - return new_xarray + return obs_dataarray + def preprocess(self, ds: xr.Dataset) -> xr.Dataset: """ @@ -289,7 +299,7 @@ def preprocess(self, ds: xr.Dataset) -> xr.Dataset: return ds def regrid( - self, ds: xr.Dataset, verif_var: str, valid_times: list[np.datetime64] + self, ds: xr.Dataset, verif_var: str ) -> xr.Dataset: """ Regrid a single xarray Dataset using specific method. @@ -297,7 +307,7 @@ def regrid( ---------- ds: native xarray Dataset Returns - -------s + ------- Regridded xarray Dataset. """ mapped_info = self.mapping.get(verif_var, {}) @@ -308,9 +318,9 @@ def regrid( _logger.info(f"{wg_var} not available in WeatherGenerator output: {e}") return # set coords - new_coords = ds_var.coords.copy() - new_coords.update({"location": self.obs.location.values}) - new_coords._drop_coords(["ncells"]) + new_coords = {"time": np.atleast_1d(ds_var.coords["time"].values), + "location": self.obs.location.values, + "leadtime": ds_var.coords["leadtime"].values.astype("float32")} # set attrs attrs = ds_var.attrs.copy() @@ -331,22 +341,14 @@ def regrid( method_factory = InterpolatorFactory(self.method) interpolator = method_factory.get_interpolator(self.zarr_coords, self.obs_coords) - # fix lat, lon - latitude_da = xr.DataArray( - self.obs.latitude.values, dims=["location"], attrs=ds_var.latitude.attrs - ) - longitude_da = xr.DataArray( - self.obs.longitude.values, dims=["location"], attrs=ds_var.longitude.attrs - ) - - for idx in range(len(valid_times)): + for idx in range(len(ds_var.coords["leadtime"].values)): regrid_values = interpolator.interpolate(ds_var.values[:, idx]) fcstdata[idx, :] = regrid_values - + regridded_var = xr.DataArray( - fcstdata, - dims=["time", "location"], - coords={**new_coords, "latitude": latitude_da, "longitude": longitude_da}, + np.array([fcstdata]), + dims=["time", "leadtime", "location"], + coords={**new_coords}, name="fcst", attrs=attrs, ) @@ -487,8 +489,7 @@ def add_metadata(self, ds: xr.Dataset, verif_var) -> xr.Dataset: np.datetime64("now"), unit="s" ) ds.attrs["Conventions"] = "verif_1.0.0" - # drop stream now it's in folder layout - ds = ds.drop_vars("stream") + return ds def _attrs_gaussian_grid(self, ds: xr.Dataset) -> xr.Dataset: @@ -602,28 +603,35 @@ def _build_coordinate_mapping( def merge(self, ds, obs_ds): lat, lon, alt = get_obs_coordinates(self.obs) - merged = xr.merge([ds, obs_ds, lat, lon, alt], join="outer") + merged = xr.merge([ds, obs_ds, lat, lon, alt]) return merged - def save(self, ds: xr.Dataset, forecast_ref_time: np.datetime64, verif_var: str) -> None: + def save(self, list_samples: list) -> None: """ Save the dataset to a NetCDF file. Parameters ---------- - ds : xarray Dataset to save. - data_type : Type of data ('pred' or 'targ') to include in the filename. - forecast_ref_time : Forecast reference time to include in the filename. + list_samples : list of dictionary containing variables to merge and save. + Each dictionary corresponds to a sample and contains variables for that sample. Returns ------- None """ - out_fname = self.get_output_filename(forecast_ref_time, verif_var) - _logger.info(f"Saving to {out_fname}.") - ds.to_netcdf(out_fname) - _logger.info(f"Saved NetCDF file to {out_fname}.") - _logger.info( - f"Saved {verif_var} data for {forecast_ref_time} to" - f" {self.output_format} in {self.output_dir}." - ) + for verif_var in self.mapping.keys(): + var_list = [sample[verif_var] for sample in list_samples if verif_var in sample] + if all(v is None for v in var_list): + _logger.warning(f"No data to save for variable {verif_var}. Skipping.") + continue + ds = xr.concat(var_list, + data_vars="all", + dim="time") + out_fname = self.get_output_filename(verif_var) + _logger.info(f"Saving to {out_fname}.") + ds.to_netcdf(out_fname) + _logger.info(f"Saved NetCDF file to {out_fname}.") + _logger.info( + f"Saved {verif_var} data to" + f" {self.output_format} in {self.output_dir}." + ) \ No newline at end of file From 6a48b8a7b9b1ee54751bb026a3a2f559f0cf209d Mon Sep 17 00:00:00 2001 From: Sorcha Date: Fri, 27 Feb 2026 19:52:00 +0100 Subject: [PATCH 14/28] linting --- .../evaluate/export/parsers/verif_parser.py | 30 +++++++------------ .../src/weathergen/evaluate/export/reshape.py | 13 ++++++-- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index 5d3951577..74d00ed40 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -102,7 +102,6 @@ def process_sample( result = self.reshape(result) da_fs.append(result) - _logger.info(f"Retrieved {len(da_fs)} forecast steps for type {self.data_type}.") if da_fs: @@ -128,7 +127,7 @@ def process_sample( vars_to_merge[verif_var] = merged return vars_to_merge - + def get_zarr_dt(self, ds: xr.Dataset) -> np.timedelta64: """ Compute the time difference between forecast steps in hours from the WG output dataset. @@ -254,8 +253,7 @@ def obs_preprocess(self, ds_var, verif_var: str) -> xr.DataArray: if verif_var == "mslp": obs_dataarray[:, i, :] = compute_mslp(obs_data, valid_time) if verif_var == "tp": - obs_dataarray[:, i, :] = compute_precip( - obs_data, self.zarr_dt, valid_time) + obs_dataarray[:, i, :] = compute_precip(obs_data, self.zarr_dt, valid_time) else: obs_dataarray[:, i, :] = obs_data.data_vars[obs_name].sel(time=valid_time) @@ -263,7 +261,6 @@ def obs_preprocess(self, ds_var, verif_var: str) -> xr.DataArray: obs_dataarray.name = "obs" return obs_dataarray - def preprocess(self, ds: xr.Dataset) -> xr.Dataset: """ @@ -298,9 +295,7 @@ def preprocess(self, ds: xr.Dataset) -> xr.Dataset: else: return ds - def regrid( - self, ds: xr.Dataset, verif_var: str - ) -> xr.Dataset: + def regrid(self, ds: xr.Dataset, verif_var: str) -> xr.Dataset: """ Regrid a single xarray Dataset using specific method. Parameters @@ -318,9 +313,11 @@ def regrid( _logger.info(f"{wg_var} not available in WeatherGenerator output: {e}") return # set coords - new_coords = {"time": np.atleast_1d(ds_var.coords["time"].values), - "location": self.obs.location.values, - "leadtime": ds_var.coords["leadtime"].values.astype("float32")} + new_coords = { + "time": np.atleast_1d(ds_var.coords["time"].values), + "location": self.obs.location.values, + "leadtime": ds_var.coords["leadtime"].values.astype("float32"), + } # set attrs attrs = ds_var.attrs.copy() @@ -344,7 +341,7 @@ def regrid( for idx in range(len(ds_var.coords["leadtime"].values)): regrid_values = interpolator.interpolate(ds_var.values[:, idx]) fcstdata[idx, :] = regrid_values - + regridded_var = xr.DataArray( np.array([fcstdata]), dims=["time", "leadtime", "location"], @@ -624,14 +621,9 @@ def save(self, list_samples: list) -> None: if all(v is None for v in var_list): _logger.warning(f"No data to save for variable {verif_var}. Skipping.") continue - ds = xr.concat(var_list, - data_vars="all", - dim="time") + ds = xr.concat(var_list, data_vars="all", dim="time") out_fname = self.get_output_filename(verif_var) _logger.info(f"Saving to {out_fname}.") ds.to_netcdf(out_fname) _logger.info(f"Saved NetCDF file to {out_fname}.") - _logger.info( - f"Saved {verif_var} data to" - f" {self.output_format} in {self.output_dir}." - ) \ No newline at end of file + _logger.info(f"Saved {verif_var} data to {self.output_format} in {self.output_dir}.") diff --git a/packages/evaluate/src/weathergen/evaluate/export/reshape.py b/packages/evaluate/src/weathergen/evaluate/export/reshape.py index b6f627c27..3516ea359 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/reshape.py +++ b/packages/evaluate/src/weathergen/evaluate/export/reshape.py @@ -574,6 +574,7 @@ def regrid_da(self, da: xr.DataArray) -> xr.DataArray: ## functions for verif + def convert_coordinates(coords: np.typing.NDArray) -> np.typing.NDArray: """ Convert lat-lon coordinates to cartesian coordinates in a unit box @@ -668,7 +669,9 @@ def compute_weights(self, grid_xyz: np.typing.NDArray, obs_xyz: np.typing.NDArra self.weights = self.weights / self.weights.sum(axis=1)[:, np.newaxis] - def interpolate(self, values: np.typing.NDArray, intmap: np.typing.NDArray = None) -> np.typing.NDArray: + def interpolate( + self, values: np.typing.NDArray, intmap: np.typing.NDArray = None + ) -> np.typing.NDArray: """ Interpolate values to points """ @@ -704,7 +707,9 @@ def __init__(self, grid_points, obs_points): self.obs_points = obs_points self.triangulation = Delaunay(grid_points) - def interpolate(self, values: np.typing.NDArray, intmap: np.typing.NDArray = None) -> np.typing.NDArray: + def interpolate( + self, values: np.typing.NDArray, intmap: np.typing.NDArray = None + ) -> np.typing.NDArray: """ Interpolate values to points """ @@ -738,7 +743,9 @@ def __init__(self, grid_points: np.typing.NDArray, obs_points: np.typing.NDArray tree = KDTree(grid_xyz) _, self.indices = tree.query(obs_xyz, k=1) - def interpolate(self, values: np.typing.NDArray, intmap: np.typing.NDArray = None) -> np.typing.NDArray: + def interpolate( + self, values: np.typing.NDArray, intmap: np.typing.NDArray = None + ) -> np.typing.NDArray: """ Interpolate values to points """ From a80892c3e3ab791ca0eba3d4e87624d918ab6647 Mon Sep 17 00:00:00 2001 From: Sorcha Owens Date: Mon, 9 Mar 2026 09:28:14 +0100 Subject: [PATCH 15/28] fxing attributes --- config/evaluate/config_zarr2verif.yaml | 7 +--- .../weathergen/evaluate/export/export_core.py | 3 +- .../evaluate/export/parsers/verif_parser.py | 41 +++++++++++-------- 3 files changed, 25 insertions(+), 26 deletions(-) diff --git a/config/evaluate/config_zarr2verif.yaml b/config/evaluate/config_zarr2verif.yaml index 589626d40..2ac3ba19c 100644 --- a/config/evaluate/config_zarr2verif.yaml +++ b/config/evaluate/config_zarr2verif.yaml @@ -74,7 +74,6 @@ variables: coordinates: sfc: - valid_time: valid_time lat: latitude lon: longitude stream: stream @@ -84,7 +83,6 @@ coordinates: pl: #not needed pressure_level: pressure - valid_time: valid_time lat: latitude lon: longitude stream: stream @@ -93,9 +91,6 @@ coordinates: ncells: ncells dimensions: - valid_time: - verif: valid_time - std: valid_time lat: verif: latitude std: latitude @@ -115,7 +110,7 @@ dimensions: verif: leadtime std: forecast_period long: time since forecast_reference_time - verif_unit: hours + verif_unit: hour stream: verif: stream std: stream diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_core.py b/packages/evaluate/src/weathergen/evaluate/export/export_core.py index eb5ab09e3..74f4aeb07 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/export_core.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_core.py @@ -260,14 +260,13 @@ def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: results_iterator = pool.imap_unordered(get_data_worker, step_tasks, chunksize=1) - processed = parser.process_sample( results_iterator, ref_time=ref_time, ) processed_samples.append(processed) - + # Only save here if need to merge samples, otherwise save in process_sample if processed_samples[0] is not None: parser.save(processed_samples) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index 74d00ed40..c57c22134 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -112,7 +112,6 @@ def process_sample( da_fs = self.concatenate(da_fs) da_fs = self.assign_frt(da_fs, ref_time) da_fs = self.add_attrs(da_fs) - vars_to_merge = {verif_var: None for verif_var in self.mapping.keys()} for verif_var in self.mapping.keys(): @@ -313,13 +312,21 @@ def regrid(self, ds: xr.Dataset, verif_var: str) -> xr.Dataset: _logger.info(f"{wg_var} not available in WeatherGenerator output: {e}") return # set coords + # TODO: tidy this up new_coords = { - "time": np.atleast_1d(ds_var.coords["time"].values), - "location": self.obs.location.values, - "leadtime": ds_var.coords["leadtime"].values.astype("float32"), + "time": (["time"], np.atleast_1d(ds_var.coords["time"].values), ds_var["time"].attrs), + "location": ( + ["location"], + self.obs.location.values, + {"long_name": "Norwegian station ID"}, + ), + "leadtime": ( + ["leadtime"], + ds_var.coords["leadtime"].values.astype("float32"), + ds_var["leadtime"].attrs, + ), } - - # set attrs + # set variable attrs attrs = ds_var.attrs.copy() with contextlib.suppress(KeyError): del attrs["ncells"] # @@ -434,7 +441,6 @@ def add_attrs(self, ds: xr.Dataset) -> xr.Dataset: """ variables = self._attrs_gaussian_grid(ds) dataset = xr.merge(variables.values(), compat="no_conflicts") - dataset.attrs = ds.attrs return dataset def add_encoding(self, ds: xr.Dataset) -> xr.Dataset: @@ -450,8 +456,8 @@ def add_encoding(self, ds: xr.Dataset) -> xr.Dataset: xarray Dataset with time encoding added. """ time_encoding = { - "units": "hours since 1970-01-01 00:00:00", - "calendar": "gregorian", + "units": "seconds since 1970-01-01 00:00:00", + "calendar": "proleptic_gregorian", } if "time" in ds.coords: @@ -485,7 +491,7 @@ def add_metadata(self, ds: xr.Dataset, verif_var) -> xr.Dataset: ds.attrs["history"] = "Created using the verif_parser on " + np.datetime_as_string( np.datetime64("now"), unit="s" ) - ds.attrs["Conventions"] = "verif_1.0.0" + ds.attrs["conventions"] = "verif_1.0.0" return ds @@ -556,17 +562,16 @@ def _assign_dim_attrs( ds_attrs = {} for dim_name, meta in dim_cfg.items(): - wg_name = meta.get("verif", dim_name) - if dim_name in ds.dims and dim_name != wg_name: - ds = ds.rename_dims({dim_name: wg_name}) + verif_name = meta.get("verif", dim_name) + if dim_name in ds.dims and dim_name != verif_name: + ds = ds.rename_dims({dim_name: verif_name}) - dim_attrs = {"standard_name": meta.get("std", wg_name)} + dim_attrs = {"standard_name": meta.get("std", verif_name)} if meta.get("verif_unit"): dim_attrs["units"] = meta["verif_unit"] if meta.get("long"): dim_attrs["long_name"] = meta["long"] - ds_attrs[wg_name] = dim_attrs - + ds_attrs[verif_name] = dim_attrs return ds, ds_attrs def _build_coordinate_mapping( @@ -600,7 +605,7 @@ def _build_coordinate_mapping( def merge(self, ds, obs_ds): lat, lon, alt = get_obs_coordinates(self.obs) - merged = xr.merge([ds, obs_ds, lat, lon, alt]) + merged = xr.merge([ds, obs_ds, lat, lon, alt], compat="minimal") return merged def save(self, list_samples: list) -> None: @@ -621,7 +626,7 @@ def save(self, list_samples: list) -> None: if all(v is None for v in var_list): _logger.warning(f"No data to save for variable {verif_var}. Skipping.") continue - ds = xr.concat(var_list, data_vars="all", dim="time") + ds = xr.concat(var_list, data_vars="all", dim="location") out_fname = self.get_output_filename(verif_var) _logger.info(f"Saving to {out_fname}.") ds.to_netcdf(out_fname) From 90719b5b388ac26d2365a62da33b775cda4dace0 Mon Sep 17 00:00:00 2001 From: Sorcha Owens Date: Thu, 12 Mar 2026 15:57:41 +0100 Subject: [PATCH 16/28] setting at least 1d for single fsteps --- .../src/weathergen/evaluate/export/parsers/verif_parser.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index c57c22134..a01df0f51 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -322,7 +322,7 @@ def regrid(self, ds: xr.Dataset, verif_var: str) -> xr.Dataset: ), "leadtime": ( ["leadtime"], - ds_var.coords["leadtime"].values.astype("float32"), + np.atleast_1d(ds_var.coords["leadtime"].values.astype("float32")), ds_var["leadtime"].attrs, ), } @@ -345,7 +345,9 @@ def regrid(self, ds: xr.Dataset, verif_var: str) -> xr.Dataset: method_factory = InterpolatorFactory(self.method) interpolator = method_factory.get_interpolator(self.zarr_coords, self.obs_coords) - for idx in range(len(ds_var.coords["leadtime"].values)): + num_leadtimes = np.atleast_1d(ds_var.coords["leadtime"].values).shape[0] + + for idx in range(num_leadtimes): regrid_values = interpolator.interpolate(ds_var.values[:, idx]) fcstdata[idx, :] = regrid_values From ab6df3a2f87b8012bc6ab00bf130236b9f36b9c4 Mon Sep 17 00:00:00 2001 From: Sorcha Owens Date: Fri, 13 Mar 2026 11:09:33 +0100 Subject: [PATCH 17/28] fixing duplication of location --- .../src/weathergen/evaluate/export/parsers/verif_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index a01df0f51..ce7b5309a 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -628,7 +628,7 @@ def save(self, list_samples: list) -> None: if all(v is None for v in var_list): _logger.warning(f"No data to save for variable {verif_var}. Skipping.") continue - ds = xr.concat(var_list, data_vars="all", dim="location") + ds = xr.concat(var_list, dim="time", data_vars="minimal", coords="minimal", join="exact") out_fname = self.get_output_filename(verif_var) _logger.info(f"Saving to {out_fname}.") ds.to_netcdf(out_fname) From ed02cb47a0181beb089392375480336015be38dc Mon Sep 17 00:00:00 2001 From: Sorcha Owens Date: Fri, 13 Mar 2026 14:38:59 +0100 Subject: [PATCH 18/28] change to match rest of export package --- .../src/weathergen/evaluate/export/parsers/verif_parser.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index ce7b5309a..b03d3b351 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -1,3 +1,5 @@ +# pylint: disable=bad-builtin + import contextlib import logging from pathlib import Path From 2df28289dff41765a7568cb98857b84154507c70 Mon Sep 17 00:00:00 2001 From: Sorcha Owens Date: Fri, 13 Mar 2026 14:40:11 +0100 Subject: [PATCH 19/28] linting --- .../src/weathergen/evaluate/export/parsers/verif_parser.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index b03d3b351..556b15fa8 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -630,7 +630,9 @@ def save(self, list_samples: list) -> None: if all(v is None for v in var_list): _logger.warning(f"No data to save for variable {verif_var}. Skipping.") continue - ds = xr.concat(var_list, dim="time", data_vars="minimal", coords="minimal", join="exact") + ds = xr.concat( + var_list, dim="time", data_vars="minimal", coords="minimal", join="exact" + ) out_fname = self.get_output_filename(verif_var) _logger.info(f"Saving to {out_fname}.") ds.to_netcdf(out_fname) From 518d716110fad1e614cd8bbf1f09ebd4c1e6a117 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Fri, 13 Mar 2026 15:26:27 +0100 Subject: [PATCH 20/28] declare some variables --- .../src/weathergen/evaluate/export/parsers/verif_parser.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index 556b15fa8..bff561526 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -69,6 +69,9 @@ def __init__(self, config: OmegaConf, **kwargs): self.obs_coords = np.column_stack((lat.values, lon.values)) self.zarr_coords = None + self.channels: list(str()) | None = None + self.zarr_dt: np.timedelta64 | None = None + def process_sample( self, fstep_iterator_results: iter, From 2e2e0f1145b6474c4c273b0be70bee333198c6bf Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Fri, 13 Mar 2026 15:28:40 +0100 Subject: [PATCH 21/28] ruff did not like --- .../src/weathergen/evaluate/export/parsers/verif_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index bff561526..2b8a5afb1 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -69,7 +69,7 @@ def __init__(self, config: OmegaConf, **kwargs): self.obs_coords = np.column_stack((lat.values, lon.values)) self.zarr_coords = None - self.channels: list(str()) | None = None + self.channels: list("foo") | None = None self.zarr_dt: np.timedelta64 | None = None def process_sample( From f21dc0b251f6ce128ed4f04128981b7412980c00 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Fri, 13 Mar 2026 15:49:40 +0100 Subject: [PATCH 22/28] declare self.channels correctly --- .../src/weathergen/evaluate/export/parsers/verif_parser.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index 2b8a5afb1..c58128ee0 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -69,7 +69,8 @@ def __init__(self, config: OmegaConf, **kwargs): self.obs_coords = np.column_stack((lat.values, lon.values)) self.zarr_coords = None - self.channels: list("foo") | None = None + required_channels = ["10u", "10v", "sp", "2t", "msl"] + self.channels = list(set(self.channels) & set(required_channels)) self.zarr_dt: np.timedelta64 | None = None def process_sample( @@ -94,8 +95,6 @@ def process_sample( ) return - required_channels = ["10u", "10v", "sp", "2t", "msl"] - self.channels = list(set(self.channels) & set(required_channels)) da_fs = [] for result in fstep_iterator_results: if result is None: From 59878141be6c9df3688fef3a590cd5ee2dab9738 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Tue, 17 Mar 2026 14:30:59 +0100 Subject: [PATCH 23/28] update some code missed in merge --- .../evaluate/export/parsers/verif_parser.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index c58128ee0..e46332d1a 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -187,27 +187,32 @@ def reshape(self, data: xr.DataArray) -> xr.Dataset: grid_type = self.grid_type # Original logic - var_dict, pl = find_pl(data.channel.values) + var_dict = find_pl(data.channel.values) data_vars = {} - for new_var, old_vars in var_dict.items(): - if len(old_vars) > 1: + for new_var, pls in var_dict.items(): + if pls[0] is not None: + old_vars = [f"{new_var}_{p}" for p in pls] data_vars[new_var] = xr.DataArray( data.sel(channel=old_vars).values, dims=["ipoint", "pressure_level"], + coords={"pressure_level":pls}, ) else: data_vars[new_var] = xr.DataArray( - data.sel(channel=old_vars[0]).values, + data.sel(channel=new_var).values, dims=["ipoint"], ) reshaped_dataset = xr.Dataset(data_vars) reshaped_dataset = reshaped_dataset.assign_coords( ipoint=data.coords["ipoint"], - pressure_level=pl, ) + # order using pressure_level coord + if "pressure_level" in reshaped_dataset.coords: + reshaped_dataset = reshaped_dataset.sortby("pressure_level") + if grid_type == "regular": # Use original reshape logic for regular grids # This is safe for regular grids From 30de27f42b2572490665772099ca5b7d19c9a84e Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Thu, 19 Mar 2026 14:27:09 +0100 Subject: [PATCH 24/28] revert to earthkit-data 0.17.0 --- packages/evaluate/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/evaluate/pyproject.toml b/packages/evaluate/pyproject.toml index 6a178833b..9215070f7 100644 --- a/packages/evaluate/pyproject.toml +++ b/packages/evaluate/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ "eccodes==2.44.0", "eccodeslib==2.44.0.7", "eckitlib==1.32.3.7", - "earthkit-data==0.18.2", + "earthkit-data==0.17.0", "earthkit-utils==0.1.2" ] From c89c58d40d2ffe2d3f73bdce235ae58f8c325a39 Mon Sep 17 00:00:00 2001 From: Sorcha Date: Thu, 9 Apr 2026 20:15:32 +0200 Subject: [PATCH 25/28] _interval_start andddddddddddddd end --- .github/ISSUE_TEMPLATE/initiative.yml | 1 + .github/ISSUE_TEMPLATE/task.yml | 1 + .github/workflows/issue_set_label.yml | 3 +- ci/cscs.yaml | 9 +- config/config_forecasting_eerie.yml | 250 +++++++ config/config_jepa_forecasting_finetuning.yml | 17 +- config/default_config.yml | 11 +- config/evaluate/config_zarr2cf.yaml | 7 +- config/evaluate/config_zarr2verif.yaml | 5 - config/evaluate/eval_config.yml | 3 + config/evaluate/eval_config_default.yml | 64 ++ .../streams/eerie_downscaling/eerie_atmo.yml | 41 ++ .../eerie_downscaling/eerie_ocean_elem.yml | 37 + .../eerie_downscaling/eerie_ocean_node.yml | 39 ++ config/streams/eerie_gridded/eerie_atmo.yml | 39 ++ .../eerie_gridded/eerie_ocean_elem.yml | 34 + .../eerie_gridded/eerie_ocean_node.yml | 36 + config/streams/eerie_native/eerie_atmo.yml | 40 ++ .../streams/eerie_native/eerie_ocean_elem.yml | 37 + .../streams/eerie_native/eerie_ocean_node.yml | 39 ++ integration_tests/small1.yml | 28 +- integration_tests/small1_test.py | 6 - integration_tests/small_multi_stream.yaml | 29 +- integration_tests/small_multi_stream_test.py | 12 +- .../common/src/weathergen/common/paths.py | 4 +- .../weathergen/evaluate/export/export_core.py | 249 +++++-- .../evaluate/export/export_inference.py | 4 +- .../evaluate/export/parsers/netcdf_parser.py | 9 +- .../evaluate/export/parsers/quaver_parser.py | 3 +- .../evaluate/export/parsers/verif_parser.py | 6 +- .../weathergen/evaluate/plotting/plotter.py | 75 +- .../src/weathergen/evaluate/run_evaluation.py | 44 +- .../readers_extra/data_reader_mesh.py | 185 +++-- src/weathergen/datasets/data_reader_fesom.py | 49 +- src/weathergen/model/engines.py | 6 +- src/weathergen/model/model_interface.py | 61 -- .../train/target_and_aux_ssl_teacher.py | 164 ++++- src/weathergen/train/target_and_aux_utils.py | 83 +++ src/weathergen/train/teacher_utils.py | 135 ++++ src/weathergen/train/trainer.py | 4 +- src/weathergen/utils/train_logger.py | 6 - tests/test_encoder_teacher.py | 647 ++++++++++++++++++ 42 files changed, 2159 insertions(+), 363 deletions(-) create mode 100644 config/config_forecasting_eerie.yml create mode 100644 config/evaluate/eval_config_default.yml create mode 100644 config/streams/eerie_downscaling/eerie_atmo.yml create mode 100644 config/streams/eerie_downscaling/eerie_ocean_elem.yml create mode 100644 config/streams/eerie_downscaling/eerie_ocean_node.yml create mode 100644 config/streams/eerie_gridded/eerie_atmo.yml create mode 100644 config/streams/eerie_gridded/eerie_ocean_elem.yml create mode 100644 config/streams/eerie_gridded/eerie_ocean_node.yml create mode 100644 config/streams/eerie_native/eerie_atmo.yml create mode 100644 config/streams/eerie_native/eerie_ocean_elem.yml create mode 100644 config/streams/eerie_native/eerie_ocean_node.yml create mode 100644 src/weathergen/train/target_and_aux_utils.py create mode 100644 src/weathergen/train/teacher_utils.py create mode 100644 tests/test_encoder_teacher.py diff --git a/.github/ISSUE_TEMPLATE/initiative.yml b/.github/ISSUE_TEMPLATE/initiative.yml index 83bb8db58..5b9c42128 100644 --- a/.github/ISSUE_TEMPLATE/initiative.yml +++ b/.github/ISSUE_TEMPLATE/initiative.yml @@ -50,6 +50,7 @@ body: - label: infrastructure and engineering - label: evaluation, export and visualization - label: documentation + - label: performance validations: required: true diff --git a/.github/ISSUE_TEMPLATE/task.yml b/.github/ISSUE_TEMPLATE/task.yml index 2008dd81a..68cd0fefa 100644 --- a/.github/ISSUE_TEMPLATE/task.yml +++ b/.github/ISSUE_TEMPLATE/task.yml @@ -31,6 +31,7 @@ body: - label: infrastructure and engineering - label: evaluation, export and visualization - label: documentation + - label: performance validations: required: true diff --git a/.github/workflows/issue_set_label.yml b/.github/workflows/issue_set_label.yml index de1ff4ac5..419e68ad5 100644 --- a/.github/workflows/issue_set_label.yml +++ b/.github/workflows/issue_set_label.yml @@ -21,7 +21,8 @@ jobs: "science": "science", "infrastructure and engineering": "infra", "evaluation, export and visualization": "eval", - "documentation": "documentation" + "documentation": "documentation", + "performance": "performance" }; const issue = context.payload.issue; diff --git a/ci/cscs.yaml b/ci/cscs.yaml index 7655028ae..eb2d444b4 100644 --- a/ci/cscs.yaml +++ b/ci/cscs.yaml @@ -2,10 +2,10 @@ include: - remote: 'https://gitlab.com/cscs-ci/recipes/-/raw/master/templates/v2/.ci-ext.yml' stages: - - test-single + - test test_job: - stage: test-single + stage: test extends: .uenv-runner-santis-gh200 image: prgenv-gnu/25.6:v2 script: | @@ -21,9 +21,12 @@ test_job: --branch "$PRIVATE_REPO_BRANCH" \ https://oauth2:${PRIVATE_REPO_TOKEN}@gitlab.jsc.fz-juelich.de/esde/WeatherGenerator-private.git + echo "Sync" ./scripts/actions.sh sync + echo "Create links and run tests" ./scripts/actions.sh create-links - ./scripts/actions.sh integration-test-single + echo "Run tests" + ./scripts/actions.sh integration-test${STAGE_TYPE:-} variables: SLURM_JOB_NUM_NODES: 1 WITH_UENV_VIEW: 'modules' diff --git a/config/config_forecasting_eerie.yml b/config/config_forecasting_eerie.yml new file mode 100644 index 000000000..53466ad7d --- /dev/null +++ b/config/config_forecasting_eerie.yml @@ -0,0 +1,250 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 2048 +ae_local_num_blocks: 4 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 4 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 8 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # Main options PerceiverIOCoordConditioning or Linear +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 0 +num_register_tokens: 0 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 16 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_impute_latent_noise_std: 1e-4 +# currently fixed to 1.0 (due to limitations with flex_attention and triton) +forecast_att_dense_rate: 1.0 + +healpix_level: 5 + +rope_2D: False + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mixed_precision_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +freeze_modules: "" +load_chkpt: {} + +norm_type: "LayerNorm" + +##################################### + +streams_directory: "./config/streams/eerie_native/" +streams: ??? + +# type of zarr_store +zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore + +general: + + # mutable parameters + istep: 0 + rank: ??? + world_size: ??? + + # local_rank, + # with_ddp, + # data_path_*, + # model_path, + # run_path, + # path_shared_ + + multiprocessing_method: "fork" + + desc: "" + run_id: ??? + run_history: [] + +# logging frequency in the training loop (in number of batches) +train_logging: + terminal: 10 + metrics: 20 + checkpoint: 250 + +# parameters for data loading +data_loading : + + num_workers: 12 + rng_seed: ??? + repeat_data_in_mini_epoch : False + + +# config for training +training_config: + + # training_mode: "masking", "student_teacher", "latent_loss" + training_mode: ["masking"] + + num_mini_epochs: 64 + samples_per_mini_epoch: 4096 + shuffle: True + + start_date: 1950-01-01T00:00 + end_date: 2004-12-31T00:00 + + time_window_step: 24:00:00 + time_window_len: 24:00:00 + + learning_rate_scheduling : + lr_start: 1e-6 + lr_max: 5e-5 + lr_final_decay: 2e-6 + lr_final: 0.0 + num_steps_warmup: 256 + num_steps_cooldown: 512 + policy_warmup: "cosine" + policy_decay: "constant" + policy_cooldown: "linear" + parallel_scaling_policy: "sqrt" + + optimizer: + grad_clip: 1.0 + weight_decay: 0.1 + log_grad_norms: False + adamw : + # parameters are scaled by number of DDP workers + beta1 : 0.98125 # == 0.85 on 2 nodes x 4 gpus + beta2 : 0.9875 # == 0.90 on 2 nodes x 4 gpus + eps : 2e-08 + + losses : { + "physical": { + type: LossPhysical, + loss_fcts: { "mse": { }, }, + }, + } + + model_input: { + "forecasting" : { + # masking strategy: "random", "healpix", "forecast" + masking_strategy: "forecast", + }, + } + + forecast : + time_step: 24:00:00 + offset: 1 + num_steps: 3 + policy: "fixed" + + +# validation config; full validation config is merge of training and validation config +validation_config: + + samples_per_mini_epoch: 256 + shuffle: False + + start_date: 2004-01-01T00:00 + end_date: 2009-12-31T00:00 + + # whether to track the exponential moving average of weights for validation + validate_with_ema: + enabled : True + ema_ramp_up_ratio: 0.09 + ema_halflife_in_thousands: 1e-3 + + # parameters for validation samples that are written to disk + output : { + # number of samples that are written + num_samples: 0, + # write samples in normalized model space + normalized_samples: False, + # output streams to write; default all + streams: null, + } + + # run validation before training starts (mainly for model development) + validate_before_training: False + + +# test config; full test config is merge of validation and test config +# test config is used by default when running inference + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: null + # The Github issue corresponding to this run (number such as 1234) + # Github issues are the central point when running experiment and contain + # links to hedgedocs, code branches, pull requests etc. + # It is recommended to associate a run with a Github issue. + issue: null + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow, along with the + # issue number. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: null + # *** Experiment-specific tags *** + # All extra tags (including lists, dictionaries, etc.) are treated + # as strings by mlflow, so treat all extra tags as simple string key: value pairs. + grid: null diff --git a/config/config_jepa_forecasting_finetuning.yml b/config/config_jepa_forecasting_finetuning.yml index 2192b8f26..36de995bf 100644 --- a/config/config_jepa_forecasting_finetuning.yml +++ b/config/config_jepa_forecasting_finetuning.yml @@ -9,15 +9,14 @@ # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -fe_num_blocks: 6 +fe_num_blocks: 16 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True -fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) forecast_att_dense_rate: 1.0 -with_step_conditioning: True # False healpix_level: 5 @@ -111,9 +110,9 @@ training_config: learning_rate_scheduling : lr_start: 1e-6 lr_max: 5e-5 - lr_final_decay: 1e-6 + lr_final_decay: 2e-6 lr_final: 0.0 - num_steps_warmup: 512 + num_steps_warmup: 256 num_steps_cooldown: 512 policy_warmup: "cosine" policy_decay: "constant" @@ -126,7 +125,7 @@ training_config: log_grad_norms: False adamw : # parameters are scaled by number of DDP workers - beta1 : 0.975 + beta1 : 0.98125 beta2 : 0.9875 eps : 2e-08 @@ -160,7 +159,7 @@ training_config: forecast : time_step: 06:00:00 - num_steps: 2 + num_steps: 3 offset: 1 policy: "fixed" @@ -176,14 +175,14 @@ validation_config: # whether to track the exponential moving average of weights for validation validate_with_ema: - enabled : False + enabled : True ema_ramp_up_ratio: 0.09 ema_halflife_in_thousands: 1e-3 # parameters for validation samples that are written to disk output : { # number of samples that are written - num_samples: 8, + num_samples: 0, # write samples in normalized model space normalized_samples: False, # output streams to write; default all diff --git a/config/default_config.yml b/config/default_config.yml index 06c0b8ffa..d62b575ce 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -11,7 +11,7 @@ embed_orientation: "channels" embed_unembed_mode: "block" embed_dropout_rate: 0.1 -ae_local_dim_embed: 512 #1024 +ae_local_dim_embed: 1024 ae_local_num_blocks: 2 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 @@ -25,7 +25,7 @@ ae_adapter_with_qk_lnorm: True ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 -ae_global_dim_embed: 512 #1024 #2048 +ae_global_dim_embed: 2048 ae_global_num_blocks: 2 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 @@ -37,7 +37,7 @@ ae_global_block_factor: 64 ae_global_mlp_hidden_factor: 2 ae_global_trailing_layer_norm: False -ae_aggregation_num_blocks: 2 +ae_aggregation_num_blocks: 0 ae_aggregation_num_heads: 32 ae_aggregation_dropout_rate: 0.1 ae_aggregation_with_qk_lnorm: True @@ -50,8 +50,6 @@ pred_adapter_kv: False pred_self_attention: True pred_dyadic_dims: False pred_mlp_adaln: True -num_class_tokens: 1 -num_register_tokens: 7 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder @@ -64,6 +62,9 @@ fe_impute_latent_noise_std: 0.0 # 1e-4 # currently fixed to 1.0 (due to limitations with flex_attention and triton) forecast_att_dense_rate: 1.0 +num_class_tokens: 0 +num_register_tokens: 0 + healpix_level: 5 # Use 2D RoPE instead of traditional global positional encoding diff --git a/config/evaluate/config_zarr2cf.yaml b/config/evaluate/config_zarr2cf.yaml index 431896336..75677858b 100644 --- a/config/evaluate/config_zarr2cf.yaml +++ b/config/evaluate/config_zarr2cf.yaml @@ -36,7 +36,7 @@ variables: wg_unit: m**2 s**-2 std_unit: m2 s-2 level_type: pl - scale_factor: 1/9.80665 + scale_factor: 1.0 #0.10197 #1/9.80665 10u: var: u10 long: u_wind_at_10m @@ -103,7 +103,6 @@ coordinates: valid_time: valid_time lat: latitude lon: longitude - stream: stream forecast_step: forecast_period forecast_reference_time: forecast_reference_time ncells: ncells @@ -112,7 +111,6 @@ coordinates: valid_time: valid_time lat: latitude lon: longitude - stream: stream forecast_step: forecast_period forecast_reference_time: forecast_reference_time ncells: ncells @@ -141,9 +139,6 @@ dimensions: std: forecast_period long: time since forecast_reference_time std_unit: hours - stream: - wg: stream - std: stream ncells: wg: ncells std: ncells \ No newline at end of file diff --git a/config/evaluate/config_zarr2verif.yaml b/config/evaluate/config_zarr2verif.yaml index 2ac3ba19c..8cc573f40 100644 --- a/config/evaluate/config_zarr2verif.yaml +++ b/config/evaluate/config_zarr2verif.yaml @@ -76,7 +76,6 @@ coordinates: sfc: lat: latitude lon: longitude - stream: stream forecast_step: leadtime forecast_reference_time: time ncells: ncells @@ -85,7 +84,6 @@ coordinates: pressure_level: pressure lat: latitude lon: longitude - stream: stream forecast_step: leadtime forecast_reference_time: time ncells: ncells @@ -111,9 +109,6 @@ dimensions: std: forecast_period long: time since forecast_reference_time verif_unit: hour - stream: - verif: stream - std: stream ncells: verif: ncells std: ncells \ No newline at end of file diff --git a/config/evaluate/eval_config.yml b/config/evaluate/eval_config.yml index bfa309a45..b24a1dd90 100644 --- a/config/evaluate/eval_config.yml +++ b/config/evaluate/eval_config.yml @@ -1,5 +1,6 @@ #optional: if commented out all is taken care of by the default settings # NB. global options apply to all run_ids + #global_plotting_options: # regions: ["europe", "global"] # image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. @@ -9,6 +10,8 @@ # marker_size: 2 # scale_marker_size: 1 # marker: "o" +# add_healpix_grid: false +# healpix_nside: 4 # # alpha: 0.5 # 2t: # vmin: 250 diff --git a/config/evaluate/eval_config_default.yml b/config/evaluate/eval_config_default.yml new file mode 100644 index 000000000..a5defae02 --- /dev/null +++ b/config/evaluate/eval_config_default.yml @@ -0,0 +1,64 @@ +#optional: if commented out all is taken care of by the default settings +# NB. global options apply to all run_ids +#global_plotting_options: +# regions: ["europe", "global"] +# image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. +# dpi_val : 300 +# fps: 2 +# ERA5: +# marker_size: 2 +# scale_marker_size: 1 +# marker: "o" +# alpha: 0.5 +# add_healpix_grid: false +# healpix_nside: 4 +# 2t: +# vmin: 250 +# vmax: 300 +# 10u: +# vmin: -40 +# vmax: 40 + +evaluation: + metrics: ["rmse", "mae"] + regions: ["global", "nhem"] + summary_plots : true + ratio_plots : false + heat_maps : false + summary_dir: "./plots/" + plot_ensemble: "members" #supported: false, "std", "minmax", "members" + plot_score_maps: false #plot scores on a 2D maps. it slows down score computation + print_summary: false #print out score values on screen. it can be verbose + log_scale: false + add_grid: false + score_cards: false + bar_plots: false + + +default_streams: + ERA5: + channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850"] + evaluation: + forecast_step: "all" + sample: "all" + ensemble: "all" #supported: "all", "mean", [0,1,2] + plotting: + sample: [0] + forecast_step: "all" #supported: "all", [1,2,3,...], "1-50" (equivalent of [1,2,3,...50]) + plot_maps: true + plot_histograms: false + plot_animations: true + CERRA: + channels: ["z_500", "t_850", "u_850"] #, "blah"] + evaluation: + forecast_step: "all" + sample: "all" + plotting: + sample: [0] + forecast_step: "all" + plot_maps: true + plot_bias: false + plot_target: false + plot_histograms: true + plot_animations: true + diff --git a/config/streams/eerie_downscaling/eerie_atmo.yml b/config/streams/eerie_downscaling/eerie_atmo.yml new file mode 100644 index 000000000..c8d3a3947 --- /dev/null +++ b/config/streams/eerie_downscaling/eerie_atmo.yml @@ -0,0 +1,41 @@ +EERIE_ATMO: + type: mesh + stream_id: 2136 + loss_weight: 1.0 + filenames: # Gridded data is used as source and native high-res data is used as target. + - /work/ab0995/a270225/data_transformation/eerie_atmo_gr025_2d_atm_mesh_daily.parq + target_file: /work/ab0995/a270225/data_transformation/eerie_atmo_2d_reduced_gaussian_daily.parq + source: + - m10u # 10m u wind component 24h mean + - m10v # 10m v wind component 24h mean + - mean2t # 2m temperature 24h mean + - mmsl # mean sea level pressure 24h mean + target: + - m10u + - m10v + - mean2t + - mmsl + sampling_mode: patch + patch_size_deg: 20.0 + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 32 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file diff --git a/config/streams/eerie_downscaling/eerie_ocean_elem.yml b/config/streams/eerie_downscaling/eerie_ocean_elem.yml new file mode 100644 index 000000000..dbd1e0073 --- /dev/null +++ b/config/streams/eerie_downscaling/eerie_ocean_elem.yml @@ -0,0 +1,37 @@ +EERIE_OCEAN_ELEM: + type: mesh + stream_id: 2137 + loss_weight: 1.0 + filenames: # Gridded data is used as source and native high-res data is used as target. + - /work/ab0995/a270225/data_transformation/eerie_ocean_gr025_3d_gridded_daily.parq + target_file: /work/ab0995/a270225/data_transformation/eerie_ocean_3d_elements_daily.parq + source: # Not every vertical level is present in gridded data. + - avg_uoe_2.5m # ocean u velocity at 2.5m depth + - avg_voe_2.5m # ocean v velocity at 2.5m depth + target: + - avg_uoe_2.5m + - avg_voe_2.5m + sampling_mode: patch + patch_size_deg: 20 + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 32 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file diff --git a/config/streams/eerie_downscaling/eerie_ocean_node.yml b/config/streams/eerie_downscaling/eerie_ocean_node.yml new file mode 100644 index 000000000..6b2326803 --- /dev/null +++ b/config/streams/eerie_downscaling/eerie_ocean_node.yml @@ -0,0 +1,39 @@ +EERIE_OCEAN_NODE: + type: mesh + stream_id: 2138 + loss_weight: 1.0 + filenames: # Gridded data is used as source and native high-res data is used as target. + - /work/ab0995/a270225/data_transformation/eerie_ocean_gr025_2d_gridded_daily.parq + target_file: /work/ab0995/a270225/data_transformation/eerie_ocean_2d_nodes_daily.parq + source: + - avg_tos # sea surface temperature + - avg_sos # sea surface salinity + - avg_zos # sea surface height + target: + - avg_tos + - avg_sos + - avg_zos + sampling_mode: patch + patch_size_deg: 20 + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 32 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/config/streams/eerie_gridded/eerie_atmo.yml b/config/streams/eerie_gridded/eerie_atmo.yml new file mode 100644 index 000000000..4250bb2e2 --- /dev/null +++ b/config/streams/eerie_gridded/eerie_atmo.yml @@ -0,0 +1,39 @@ +EERIE_ATMO: + type: mesh + stream_id: 2136 + loss_weight: 1.0 + filenames: + - /work/ab0995/a270225/data_transformation/eerie_atmo_gr025_2d_atm_mesh_daily.parq + target_file: /work/ab0995/a270225/data_transformation/eerie_atmo_gr025_2d_atm_mesh_daily.parq + source: + - m10u # 10m u wind component 24h mean + - m10v # 10m v wind component 24h mean + - mean2t # 2m temperature 24h mean + - mmsl # mean sea level pressure 24h mean + target: + - m10u + - m10v + - mean2t + - mmsl + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 32 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file diff --git a/config/streams/eerie_gridded/eerie_ocean_elem.yml b/config/streams/eerie_gridded/eerie_ocean_elem.yml new file mode 100644 index 000000000..2880c1cd0 --- /dev/null +++ b/config/streams/eerie_gridded/eerie_ocean_elem.yml @@ -0,0 +1,34 @@ +EERIE_OCEAN_ELEM: + type: mesh + stream_id: 2137 + loss_weight: 1.0 + filenames: # In case of gridded data there's no separate node and element files. Not every vertial level is used for training, so the 3D gridded file is used instead of the 3D element file. + - /work/ab0995/a270225/data_transformation/eerie_ocean_gr025_3d_gridded_daily.parq + source: # Not every vertical level is present in gridded data. + - avg_uoe_2.5m # ocean u velocity at 2.5m depth + - avg_voe_2.5m # ocean v velocity at 2.5m depth + target: + - avg_uoe_2.5m + - avg_voe_2.5m + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 32 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file diff --git a/config/streams/eerie_gridded/eerie_ocean_node.yml b/config/streams/eerie_gridded/eerie_ocean_node.yml new file mode 100644 index 000000000..8d270dfda --- /dev/null +++ b/config/streams/eerie_gridded/eerie_ocean_node.yml @@ -0,0 +1,36 @@ +EERIE_OCEAN_NODE: + type: mesh + stream_id: 2138 + loss_weight: 1.0 + filenames: + - /work/ab0995/a270225/data_transformation/eerie_ocean_gr025_2d_gridded_daily.parq + source: + - avg_tos # sea surface temperature + - avg_sos # sea surface salinity + - avg_zos # sea surface height + target: + - avg_tos + - avg_sos + - avg_zos + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 32 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/config/streams/eerie_native/eerie_atmo.yml b/config/streams/eerie_native/eerie_atmo.yml new file mode 100644 index 000000000..2abafb9d4 --- /dev/null +++ b/config/streams/eerie_native/eerie_atmo.yml @@ -0,0 +1,40 @@ +EERIE_ATMO: + type: mesh + stream_id: 2136 + loss_weight: 1.0 + filenames: + - /work/ab0995/a270225/data_transformation/eerie_atmo_2d_reduced_gaussian_daily.parq + source: + - m10u # 10m u wind component 24h mean + - m10v # 10m v wind component 24h mean + - mean2t # 2m temperature 24h mean + - mmsl # mean sea level pressure 24h mean + target: + - m10u + - m10v + - mean2t + - mmsl + patch_size_deg: null + sampling_mode: 'global_sparse' # loads given number of random points globally for each time step. + sample_points: 262144 + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 32 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file diff --git a/config/streams/eerie_native/eerie_ocean_elem.yml b/config/streams/eerie_native/eerie_ocean_elem.yml new file mode 100644 index 000000000..3f28b210e --- /dev/null +++ b/config/streams/eerie_native/eerie_ocean_elem.yml @@ -0,0 +1,37 @@ +EERIE_OCEAN_ELEM: + type: mesh + stream_id: 2137 + loss_weight: 1.0 + filenames: + - /work/ab0995/a270225/data_transformation/eerie_ocean_3d_elements_daily.parq + source: + - avg_uoe_2.5m # ocean u velocity at 2.5m depth + - avg_voe_2.5m # ocean v velocity at 2.5m depth + target: + - avg_uoe_2.5m + - avg_voe_2.5m + patch_size_deg: null + sampling_mode: 'global_sparse' # loads given number of random points globally for each time step. + sample_points: 262144 + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 32 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file diff --git a/config/streams/eerie_native/eerie_ocean_node.yml b/config/streams/eerie_native/eerie_ocean_node.yml new file mode 100644 index 000000000..515853444 --- /dev/null +++ b/config/streams/eerie_native/eerie_ocean_node.yml @@ -0,0 +1,39 @@ +EERIE_OCEAN_NODE: + type: mesh + stream_id: 2138 + loss_weight: 1.0 + filenames: # For deeper layers use the eerie_ocean_3d_nodes_daily.parq file + - /work/ab0995/a270225/data_transformation/eerie_ocean_2d_nodes_daily.parq + source: + - avg_tos # sea surface temperature + - avg_sos # sea surface salinity + - avg_zos # sea surface height + target: + - avg_tos + - avg_sos + - avg_zos + patch_size_deg: null + sampling_mode: 'global_sparse' # loads given number of random points globally for each time step. + sample_points: 262144 + location_weight : cosine_latitude + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 32 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/integration_tests/small1.yml b/integration_tests/small1.yml index 3fbed8381..12268bfef 100644 --- a/integration_tests/small1.yml +++ b/integration_tests/small1.yml @@ -77,7 +77,7 @@ norm_eps: 1e-4 latent_noise_kl_weight: 0.0 # 1e-5 latent_noise_gamma: 2.0 -latent_noise_saturate_encodings: 5 +latent_noise_saturate_encodings: 5 latent_noise_use_additive_noise: False latent_noise_deterministic_latents: True @@ -102,13 +102,13 @@ general: rank: ??? world_size: ??? - # local_rank, + # local_rank, # with_ddp, - # data_path_*, - # model_path, - # run_path, + # data_path_*, + # model_path, + # run_path, # path_shared_ - + multiprocessing_method: "fork" desc: "" @@ -127,14 +127,14 @@ data_loading : rng_seed: ??? repeat_data_in_mini_epoch : False - # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. # If this happens, you can disable the flag, but performance will drop on GH200. memory_pinning: True # config for training training_config: - + training_mode: ["masking"] num_mini_epochs: 1 @@ -148,7 +148,7 @@ training_config: time_window_len: 06:00:00 window_offset_prediction : 1 - + learning_rate_scheduling : lr_start: 1e-6 lr_max: 0.00005 @@ -187,12 +187,12 @@ training_config: forecast : time_step: 06:00:00 num_steps: 2 - policy: "fixed" + policy: "fixed" offset: 1 # validation config; full validation config is merge of training and validation config -validation_config: +validation_config: samples_per_mini_epoch: 32 shuffle: False @@ -203,14 +203,16 @@ validation_config: output: streams: ["ERA5"] - validate_with_ema: + validate_with_ema: enabled : True ema_ramp_up_ratio: 0.09 ema_halflife_in_thousands: 1e-3 test_config: + start_date: 2021-10-10T00:00 + end_date: 2022-10-11T00:00 output: - num_samples: 2 + num_samples: 10 # TODO: read latent from here diff --git a/integration_tests/small1_test.py b/integration_tests/small1_test.py index 65ffec429..d35380478 100644 --- a/integration_tests/small1_test.py +++ b/integration_tests/small1_test.py @@ -71,12 +71,6 @@ def infer(run_id): main( [ "inference", - "-start", - "2021-10-10", - "-end", - "2022-10-11", - "--samples", - "10", "--mini-epoch", "0", "--from-run-id", diff --git a/integration_tests/small_multi_stream.yaml b/integration_tests/small_multi_stream.yaml index 683a1d941..7537c8ee4 100644 --- a/integration_tests/small_multi_stream.yaml +++ b/integration_tests/small_multi_stream.yaml @@ -77,7 +77,7 @@ norm_eps: 1e-4 latent_noise_kl_weight: 0.0 # 1e-5 latent_noise_gamma: 2.0 -latent_noise_saturate_encodings: 5 +latent_noise_saturate_encodings: 5 latent_noise_use_additive_noise: False latent_noise_deterministic_latents: True @@ -100,13 +100,13 @@ general: rank: ??? world_size: ??? - # local_rank, + # local_rank, # with_ddp, - # data_path_*, - # model_path, - # run_path, + # data_path_*, + # model_path, + # run_path, # path_shared_ - + multiprocessing_method: "fork" desc: "" @@ -125,14 +125,14 @@ data_loading : rng_seed: ??? repeat_data_in_mini_epoch : False - # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with + # pin GPU memory for faster transfer; it is possible that enabling memory_pinning with # FSDP2 + DINOv2 can cause the job to hang and trigger a PyTorch timeout error. # If this happens, you can disable the flag, but performance will drop on GH200. memory_pinning: True # config for training training_config: - + training_mode: ["masking"] num_mini_epochs: 1 @@ -146,7 +146,7 @@ training_config: time_window_len: 06:00:00 window_offset_prediction : 1 - + learning_rate_scheduling : lr_start: 1e-6 lr_max: 0.00005 @@ -185,12 +185,12 @@ training_config: forecast : time_step: 06:00:00 num_steps: 2 - policy: "fixed" + policy: "fixed" offset: 1 # validation config; full validation config is merge of training and validation config -validation_config: +validation_config: samples_per_mini_epoch: 32 shuffle: False @@ -198,11 +198,14 @@ validation_config: start_date: 2021-10-10T00:00 end_date: 2022-10-11T00:00 - validate_with_ema: + validate_with_ema: enabled : True ema_ramp_up_ratio: 0.09 ema_halflife_in_thousands: 1e-3 test_config: + start_date: 2021-10-10T00:00 + end_date: 2022-10-11T00:00 output: - num_samples: 2 + num_samples: 10 + streams_output: ["ERA5", "SurfaceCombined", "NPPATMS"] diff --git a/integration_tests/small_multi_stream_test.py b/integration_tests/small_multi_stream_test.py index 92141f100..42d0b37b8 100644 --- a/integration_tests/small_multi_stream_test.py +++ b/integration_tests/small_multi_stream_test.py @@ -66,7 +66,7 @@ def test_train_multi_stream(setup, test_run_id): test_run_id, ] ) - + infer_multi_stream(test_run_id) evaluate_multi_stream_results(test_run_id) assert_metrics_file_exists(test_run_id) @@ -81,22 +81,12 @@ def infer_multi_stream(run_id): main( [ "inference", - "-start", - "2021-10-10", - "-end", - "2022-10-11", - "--samples", - "10", "--mini-epoch", "0", "--from-run-id", run_id, "--run-id", run_id, - "--streams-output", - "ERA5", - "SurfaceCombined", - "NPPATMS", "--config", f"{WEATHERGEN_HOME}/integration_tests/small_multi_stream.yaml", ] diff --git a/packages/common/src/weathergen/common/paths.py b/packages/common/src/weathergen/common/paths.py index 82c31ef49..532588da2 100644 --- a/packages/common/src/weathergen/common/paths.py +++ b/packages/common/src/weathergen/common/paths.py @@ -23,5 +23,7 @@ def get_wg_private_path() -> Path: path = _REPO_ROOT.parent / "WeatherGenerator-private" path = path.resolve() - assert path.is_dir(), f"WeatherGenerator private repo path does not exist or is not a directory: {path}" + assert path.is_dir(), ( + f"WeatherGenerator private repo path does not exist or is not a directory: {path}" + ) return path diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_core.py b/packages/evaluate/src/weathergen/evaluate/export/export_core.py index 74f4aeb07..bacb93212 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/export_core.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_core.py @@ -1,4 +1,5 @@ import logging +from collections import defaultdict from multiprocessing import Pool import numpy as np @@ -14,28 +15,71 @@ _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) +# Module-level cache for the zarr path and open store — resolved once per worker. +_CACHED_FNAME_ZARR: str | None = None +_CACHED_ZIO = None -def get_data_worker(args: tuple) -> xr.DataArray: + +def _init_worker(fname_zarr: str) -> None: + """Pool initializer: open the zarr store once and keep it for the worker's lifetime.""" + global _CACHED_FNAME_ZARR, _CACHED_ZIO + _CACHED_FNAME_ZARR = fname_zarr + _CACHED_ZIO = zarrio_reader(fname_zarr) + _CACHED_ZIO.__enter__() + + +def get_data_worker(args: tuple) -> tuple[int, int, xr.DataArray]: """ - Worker function to retrieve data for a single sample and forecast step. + Worker function to retrieve data for a single (sample, fstep) pair. - Parameters - ---------- - args : Tuple containing (sample, fstep, run_id, stream). + Reads the raw zarr arrays as numpy (bypassing dask) and builds a + lightweight xarray DataArray that can be pickled back to the main + process with all data already in memory. Returns ------- - xarray DataArray for the specified sample and forecast step. + Tuple of (sample, fstep, xarray.DataArray) with data fully in memory. """ - sample, fstep, run_id, stream, dtype, epoch, rank = args - fname_zarr = get_model_results(run_id, epoch, rank) - with zarrio_reader(fname_zarr) as zio: - out = zio.get_data(sample, stream, fstep) - if dtype == "target": - data = out.target - elif dtype == "prediction": - data = out.prediction - return data + sample, fstep, stream, dtype = args + + # Navigate directly to the zarr group for this (sample, stream, fstep, dtype). + group_path = f"{sample}/{stream}/{fstep}/{dtype}" + ds_group = _CACHED_ZIO.data_root.get(group_path) + + # Read raw arrays as numpy — no dask, no chunking overhead. + data_arr = np.asarray(ds_group["data"]) # (npoints, nchannels) or (npoints, nchannels, nens) + coords_arr = np.asarray(ds_group["coords"]) # (npoints, 2) + times_arr = np.asarray(ds_group["times"]).astype("datetime64[ns]") # (npoints,) + channels = list(ds_group.attrs["channels"]) + source_interval_start = np.asarray(ds_group.attrs["source_interval"]["start"]).astype("datetime64[ns]") + source_interval_end = np.asarray(ds_group.attrs["source_interval"]["end"]).astype("datetime64[ns]") + + # Build a lightweight xarray DataArray with the same structure + # that process_sample / assign_coords expects: + # dims = [ipoint, channel] + # coords: forecast_step, channel, valid_time, lat, lon + npoints = data_arr.shape[0] + + # Handle optional ensemble dimension: squeeze it out if present. + if data_arr.ndim == 3 and data_arr.shape[2] == 1: + data_arr = data_arr[:, :, 0] + + da_result = xr.DataArray( + data_arr, + dims=["ipoint", "channel"], + coords={ + "ipoint": np.arange(npoints), + "channel": channels, + "forecast_step": fstep, + "valid_time": ("ipoint", times_arr), + "lat": ("ipoint", coords_arr[:, 0]), + "lon": ("ipoint", coords_arr[:, 1]), + "source_interval_start": source_interval_start, + "source_interval_end": source_interval_end + }, + ) + + return (sample, fstep, da_result) def get_fsteps(fsteps, fname_zarr: str): @@ -141,9 +185,13 @@ def get_grid_type(data_type, stream: str, fname_zarr: str) -> str: # TODO: this will change after restructuring the lead time. -def get_ref_times(fname_zarr, stream, samples, fstep_hours) -> list[np.datetime64]: +def get_ref_times(fname_zarr, stream, samples, fstep_hours, n_processes) -> list[np.datetime64]: """ Retrieve reference times for the specified samples from the Zarr store. + + Reads only the lightweight 'times' array from the zarr hierarchy + instead of loading the full data arrays. + Parameters ---------- fname_zarr : str @@ -154,19 +202,33 @@ def get_ref_times(fname_zarr, stream, samples, fstep_hours) -> list[np.datetime6 List of samples to process. fstep_hours : np.timedelta64 Time difference between forecast steps in hours. + n_processes : int + Number of parallel processes to use (unused, kept for API compat). Returns ------- list[np.datetime64] List of reference times corresponding to the samples. """ + _logger.info(f"Retrieving reference times for {len(samples)} samples...") + ref_times = [] with zarrio_reader(fname_zarr) as zio: - zio_forecast_steps = sorted([int(step) for step in zio.forecast_steps]) - for sample in samples: - data = zio.get_data(sample, stream, zio_forecast_steps[0]) - data = data.target.as_xarray().squeeze() - ref_time = data.valid_time.values[0] - fstep_hours * int(data.forecast_step.values) + first_fstep = sorted([int(step) for step in zio.forecast_steps])[0] + + for sample in tqdm(samples, desc="Getting ref times"): + # Navigate directly to the target group and read only the 'times' array, + # avoiding the expensive full-data load via get_data() / as_xarray(). + group_path = f"{sample}/{stream}/{first_fstep}/target" + target_group = zio.data_root.get(group_path) + + if target_group is None: + raise FileNotFoundError(f"Zarr group '{group_path}' not found in {fname_zarr}") + + times_arr = np.array(target_group["times"]).astype("datetime64[ns]") + valid_time = times_arr[0] + ref_time = valid_time - fstep_hours * first_fstep ref_times.append(ref_time) + return ref_times @@ -179,8 +241,11 @@ def get_streams(stream, fname_zarr): def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: """ - Retrieve data from Zarr store and save one sample to each NetCDF file. - Using multiprocessing to speed up data retrieval. + Retrieve data from Zarr store and export to the requested format. + + All (sample, fstep) pairs are submitted to the pool at once so that + every worker stays busy. Results are grouped by sample and handed to + the parser in sample order. Parameters ---------- @@ -188,36 +253,8 @@ def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: Type of data to retrieve ('target' or 'prediction'). config : OmegaConf Loaded config for cf_parser function. - kwargs: Additional keyword arguments for the parser. - - NOTE: it contains the following parameters: - run_id : str - Run ID to identify the Zarr store. - samples : list - Sample to process - stream : str - Stream name to retrieve data for (e.g., 'ERA5'). - data_type : str - Type of data to retrieve ('target' or 'prediction'). - fsteps : list - List of forecast steps to retrieve. If None, retrieves all available forecast steps. - channels : list - List of channels to retrieve. If None, retrieves all available channels. - n_processes : list - Number of parallel processes to use for data retrieval. - ecpoch : int - Epoch number to identify the Zarr store. - rank : int - Rank number to identify the Zarr store. - regrid_degree : float - If specified, regrid the data to a regular lat/lon grid with the given degree - output_dir : str - Directory to save the NetCDF files. - output_format : str - Output file format (currently only 'netcdf' supported). - """ kwargs = OmegaConf.create(kwargs) @@ -241,35 +278,99 @@ def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: for stream in streams: grid_type = get_grid_type(data_type, stream, fname_zarr) channels = get_channels(channels, stream, fname_zarr) - ref_times = get_ref_times(fname_zarr, stream, samples, fstep_hours) - kwargs["stream"] = stream + ref_times = get_ref_times(fname_zarr, stream, samples, fstep_hours, n_processes) kwargs["grid_type"] = grid_type kwargs["channels"] = channels kwargs["data_type"] = data_type - with Pool(processes=n_processes, maxtasksperchild=5) as pool: - parser = CfParserFactory.get_parser(config=config, **kwargs) - - processed_samples = [] - - for s_idx, sample in enumerate(tqdm(samples)): - ref_time = ref_times[s_idx] - step_tasks = [ - (sample, fstep, run_id, stream, data_type, epoch, rank) for fstep in fsteps + parser = CfParserFactory.get_parser(config=config, **kwargs) + + n_fsteps = len(fsteps) + total_tasks = len(samples) * n_fsteps + + # Batch size in *samples*. Limits how many samples can be in-flight at once, + # bounding peak memory while still allowing read/write overlap within each batch. + batch_size = max(1, n_processes * 2) + n_batches = (len(samples) + batch_size - 1) // batch_size + + _logger.info( + f"Exporting {len(samples)} samples × {n_fsteps} fsteps " + f"({total_tasks} total tasks) in {n_batches} batch(es) of up to " + f"{batch_size} samples, using {n_processes} workers. " + f"Reading and writing are interleaved within each batch." + ) + + # Initialise each worker with the zarr path so it is resolved only once. + with Pool( + processes=n_processes, + initializer=_init_worker, + initargs=(fname_zarr,), + ) as pool: + samples_written = 0 + + for batch_idx in range(n_batches): + batch_start = batch_idx * batch_size + batch_end = min(batch_start + batch_size, len(samples)) + batch_samples = samples[batch_start:batch_end] + batch_ref_times = ref_times[batch_start:batch_end] + + # Map sample -> index within this batch for ref_times lookup. + sample_to_batch_idx = {s: i for i, s in enumerate(batch_samples)} + + batch_tasks = [ + (sample, fstep, stream, data_type) for sample in batch_samples for fstep in fsteps ] - results_iterator = pool.imap_unordered(get_data_worker, step_tasks, chunksize=1) - - processed = parser.process_sample( - results_iterator, - ref_time=ref_time, + _logger.info( + f"Batch {batch_idx + 1}/{n_batches}: " + f"samples {batch_start}–{batch_end - 1} " + f"({len(batch_samples)} samples, {len(batch_tasks)} tasks)" ) - processed_samples.append(processed) + # Interleaved read/write: as soon as all fsteps for a sample + # arrive, write it immediately while workers continue reading. + sample_results: dict[int, list] = defaultdict(list) + batch_written = 0 - # Only save here if need to merge samples, otherwise save in process_sample - if processed_samples[0] is not None: - parser.save(processed_samples) - - pool.terminate() - pool.join() + pbar = tqdm( + total=len(batch_tasks), + desc=f" Batch {batch_idx + 1}/{n_batches}", + ) + + processed_samples = [] + + for sample, _fstep, data in pool.imap_unordered( + get_data_worker, batch_tasks, chunksize=max(1, n_fsteps) + ): + sample_results[sample].append(data) + pbar.update(1) + + # Check if this sample is complete (all fsteps received). + if len(sample_results[sample]) == n_fsteps: + b_idx = sample_to_batch_idx[sample] + ref_time = batch_ref_times[b_idx] + results_iter = iter(sample_results[sample]) + processed = parser.process_sample(results_iter, ref_time=ref_time) + processed_samples.append(processed) + + # Free memory immediately. + del sample_results[sample] + batch_written += 1 + + # Only save here if need to merge samples, otherwise saved in process_sample + if processed_samples[0] is not None: + parser.save(processed_samples) + pbar.close() + + samples_written += batch_written + if batch_written != len(batch_samples): + _logger.error( + f"Batch {batch_idx + 1}: expected {len(batch_samples)} " + f"samples but only wrote {batch_written}. " + f"Incomplete: {list(sample_results.keys())}" + ) + + # Free any remaining refs before next batch. + del sample_results + + _logger.info(f"Export complete. Wrote {samples_written}/{len(samples)} samples.") \ No newline at end of file diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py index 61ca0ef3e..921c4a446 100755 --- a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py @@ -270,7 +270,8 @@ def export_from_args(args: list) -> None: ---------- args : List of command line arguments. """ - args = parse_args(sys.argv[1:]) + args = parse_args(args) + # Load configuration if args.output_format == "verif": config_file = Path(_REPO_ROOT, "config/evaluate/config_zarr2verif.yaml") @@ -285,7 +286,6 @@ def export_from_args(args: list) -> None: if kwargs.get("expver") == "NEW": kwargs["expver"] = generate_new_expver() - _logger.info(kwargs) # Ensure output directory exists diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py index a0373b520..85771f21b 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py @@ -72,7 +72,9 @@ def process_sample( if result is None: continue - result = result.as_xarray().squeeze() + # result is already a materialized xarray DataArray (built in the worker). + if not isinstance(result, xr.DataArray): + result = result.as_xarray().squeeze() if "channel" not in result.indexes: result = result.expand_dims("channel") result = result.sel(channel=self.channels) @@ -504,7 +506,7 @@ def add_metadata(self, ds: xr.Dataset) -> xr.Dataset: xarray Dataset with CF conventions added to attributes. """ # ds = ds.copy() - ds.attrs["title"] = f"WeatherGenerator Output for {self.run_id} using stream {self.stream}" + ds.attrs["title"] = f"WeatherGenerator Output for {self.run_id}" ds.attrs["institution"] = "WeatherGenerator Project" ds.attrs["source"] = "WeatherGenerator v0.0" ds.attrs["history"] = ( @@ -512,8 +514,7 @@ def add_metadata(self, ds: xr.Dataset) -> xr.Dataset: + np.datetime_as_string(np.datetime64("now"), unit="s") ) ds.attrs["Conventions"] = "CF-1.12" - # drop stream now it's in title - ds = ds.drop_vars("stream") + return ds def save(self, ds: xr.Dataset, forecast_ref_time: np.datetime64) -> None: diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py index d698a55f9..9fa7a1b71 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py @@ -87,7 +87,8 @@ def process_sample( if result is None: continue - result = result.as_xarray().squeeze() + if not isinstance(result, xr.DataArray): + result = result.as_xarray().squeeze() result = result.sel(channel=self.channels) da_fs = self.assign_coords(result) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index e46332d1a..70f1d708e 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -99,8 +99,9 @@ def process_sample( for result in fstep_iterator_results: if result is None: continue - - result = result.as_xarray().squeeze() + # result is already a materialized xarray DataArray (built in the worker). + if not isinstance(result, xr.DataArray): + result = result.as_xarray().squeeze() result = result.sel(channel=self.channels) result = self.preprocess(result) result = self.reshape(result) @@ -128,7 +129,6 @@ def process_sample( merged = self.merge(da_var, obs_result) merged = self.add_metadata(merged, verif_var) vars_to_merge[verif_var] = merged - return vars_to_merge def get_zarr_dt(self, ds: xr.Dataset) -> np.timedelta64: diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 2adba17bf..ff4768b9b 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -3,6 +3,7 @@ import logging import os import re +import warnings from pathlib import Path import cartopy @@ -13,29 +14,45 @@ import omegaconf as oc import seaborn as sns import xarray as xr +from astropy_healpix import HEALPix as HEALPixGrid +from cartopy.io import DownloadWarning +from matplotlib.collections import LineCollection from matplotlib.lines import Line2D from PIL import Image from scipy.stats import wilcoxon from weathergen.common.config import _load_private_conf -from weathergen.evaluate.plotting.plot_utils import ( - DefaultMarkerSize, -) +from weathergen.evaluate.plotting.plot_utils import DefaultMarkerSize from weathergen.evaluate.utils.regions import RegionBoundingBox +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + work_dir = Path(_load_private_conf(None)["path_shared_working_dir"]) / "assets/cartopy" cartopy.config["data_dir"] = str(work_dir) cartopy.config["pre_existing_data_dir"] = str(work_dir) os.environ["CARTOPY_DATA_DIR"] = str(work_dir) +# Route Cartopy DownloadWarnings through the logging system so they are visible in logs. +logging.captureWarnings(True) +warnings.filterwarnings("always", category=DownloadWarning) + + +def _download_cartopy_off(enabled: bool) -> None: + """Enable/disable blocking Cartopy downloads by elevating DownloadWarning to error.""" + if enabled: + warnings.filterwarnings("error", category=DownloadWarning) + _logger.info( + "Auto-downloads are blocked for cartopy; only local cartopy data will be used." + ) + else: + warnings.filterwarnings("default", category=DownloadWarning) + np.seterr(divide="ignore", invalid="ignore") logging.getLogger("matplotlib.category").setLevel(logging.ERROR) -_logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO) - _logger.debug(f"Taking cartopy paths from {work_dir}") @@ -72,6 +89,7 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | self.fig_size = plotter_cfg.get("fig_size") self.fps = plotter_cfg.get("fps") self.regions = plotter_cfg.get("regions") + _download_cartopy_off(enabled=True) self.plot_subtimesteps = plotter_cfg.get( "plot_subtimesteps", False ) # True if plots are created for each valid time separately @@ -454,6 +472,14 @@ def scatter_plot( vmax = map_kwargs_save.pop("vmax", None) cmap = plt.get_cmap(map_kwargs_save.pop("colormap", "coolwarm")) + # Healpix grid configuration + add_healpix_grid = map_kwargs_save.pop("add_healpix_grid", False) + healpix_nside = map_kwargs_save.pop("healpix_nside", 4) + healpix_color = map_kwargs_save.pop("healpix_color", "black") + healpix_linewidth = map_kwargs_save.pop("healpix_linewidth", 0.2) + healpix_step = map_kwargs_save.pop("healpix_step", 64) + healpix_linestyle = map_kwargs_save.pop("healpix_linestyle", "-") + if isinstance(map_kwargs_save.get("levels", False), oc.listconfig.ListConfig): norm = mpl.colors.BoundaryNorm( map_kwargs_save.pop("levels", None), cmap.N, extend="both" @@ -482,7 +508,10 @@ def scatter_plot( proj = ccrs.Robinson() ax = fig.add_subplot(1, 1, 1, projection=proj) - ax.coastlines() + try: + ax.coastlines() + except Exception: + _logger.warning("Could not add coastlines to plot; continuing without them.") assert data["lon"].shape == data["lat"].shape == data.shape, ( f"Scatter plot:: Data shape do not match. Shapes: " @@ -502,6 +531,15 @@ def scatter_plot( **map_kwargs_save, ) + # Add Healpix grid (optimized with LineCollection) + if add_healpix_grid: + lc = self.healpixlines( + healpix_nside, healpix_color, healpix_linewidth, healpix_step, healpix_linestyle + ) + ax.add_collection(lc) + else: + ax.gridlines(draw_labels=False, linestyle="--", color="black", linewidth=0.2) + plt.colorbar(scatter_plt, ax=ax, orientation="horizontal", label=f"Variable: {varname}") plt.title(title, fontsize=9.5) if regionname == "global": @@ -514,7 +552,6 @@ def scatter_plot( data["lat"].max().item(), ] ax.set_extent(region_extent, crs=ccrs.PlateCarree()) - ax.gridlines(draw_labels=False, linestyle="--", color="black", linewidth=1) # TODO: make this nicer parts = ["map", self.run_id, tag] @@ -551,6 +588,28 @@ def scatter_plot( return name + def healpixlines( + self, healpix_nside, healpix_color, healpix_linewidth, healpix_step, healpix_linestyle + ): + hp_grid = HEALPixGrid(nside=healpix_nside, order="ring") + lon_all, lat_all = hp_grid.boundaries_lonlat(np.arange(hp_grid.npix), step=healpix_step) + # Ensure closure of polygons + lon_closed = np.concatenate([lon_all.deg, lon_all.deg[:, 0:1]], axis=1) + lat_closed = np.concatenate([lat_all.deg, lat_all.deg[:, 0:1]], axis=1) + # Stack as (N_polys, N_points, 2) + segments = np.stack([lon_closed, lat_closed], axis=-1) + # (cartopy handles transform for LineCollection via set_transform) + lc = LineCollection( + segments, + colors=healpix_color, + linewidths=healpix_linewidth, + linestyles=healpix_linestyle, + alpha=0.5, + zorder=10, + ) + lc.set_transform(ccrs.PlateCarree()) + return lc + def animation(self, samples, fsteps, variables, select, tag) -> list[str]: """ Plot 2D animations for a dataset diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index b2d5e87eb..b901fb895 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -134,6 +134,25 @@ def evaluate_from_args(argl: list[str], log_queue: mp.Queue) -> None: action="store_true", help="(optional) Upload scores to MLFlow.", ) + parser.add_argument( + "--options", + nargs="+", + default=[], + help=( + "Overwrite individual config options." + " Individual items should be of the form: parent_obj.nested_obj=value." + " NOTE: cannot be used for run_ids (use --run-ids instead)." + ), + ) + parser.add_argument( + "--run-ids", + nargs="+", + default=None, + help=( + "Filter run_ids from the config to only these." + " E.g. --run-ids wu4wy9os fy6fgscn so67dku1" + ), + ) args = parser.parse_args(argl) if args.config: @@ -155,6 +174,28 @@ def evaluate_from_args(argl: list[str], log_queue: mp.Queue) -> None: cf = OmegaConf.load(config) assert isinstance(cf, DictConfig) + + # Disable struct flag so that --options and --run-ids can freely modify keys. + OmegaConf.set_struct(cf, False) + + if args.options: + # Filter out any run_ids= items — those must use --run-ids instead. + cli_items = [item for item in args.options if not item.startswith("run_ids=")] + if len(cli_items) != len(args.options): + _logger.warning( + "run_ids= in --options is not supported (it's a dict, not a list). " + "Use --run-ids instead. Ignoring run_ids= items." + ) + if cli_items: + cli_overwrite = OmegaConf.from_cli(cli_items) + cf = OmegaConf.merge(cf, cli_overwrite) + _logger.info(f"Applied --options overwrites: {cli_items}") + + if args.run_ids: + existing = cf.get("run_ids", {}) + cf.run_ids = {k: existing.get(k, {}) for k in args.run_ids} + _logger.info(f"Overwritten run_ids to: {args.run_ids}") + evaluate_from_config(cf, mlflow_client, log_queue) @@ -254,8 +295,7 @@ def _process_stream( ) metric_list_to_json(reader, stream, stream_computed_scores, regions_to_compute) scores_dict = merge(stream_loaded_scores, stream_computed_scores) - return run_id, stream, scores_dict - + return run_id, stream, scores_dict # except Exception as e: diff --git a/packages/readers_extra/src/weathergen/readers_extra/data_reader_mesh.py b/packages/readers_extra/src/weathergen/readers_extra/data_reader_mesh.py index e3fcb6860..5d480f2f1 100644 --- a/packages/readers_extra/src/weathergen/readers_extra/data_reader_mesh.py +++ b/packages/readers_extra/src/weathergen/readers_extra/data_reader_mesh.py @@ -36,6 +36,7 @@ class DataReaderMesh(DataReaderTimestep): - Robust Multi-Node/Worker support (Fork-safe, Dask-safe). - Dynamic Patching (local) OR Global Sparse Sampling. """ + def __init__( self, tw_handler: TimeWindowHandler, @@ -58,6 +59,11 @@ def __init__( self._dask_arrays_trg = {} self.sampling_mode = stream_info.get("sampling_mode", "patch") + self.patch_stability_window = stream_info.get("patch_stability_window", 1) + + # Auto-enable staircase mode if window is defined and we are in patch mode + auto_use_counter = self.sampling_mode == "patch" and "patch_stability_window" in stream_info + self.patch_use_counter = stream_info.get("patch_use_counter", auto_use_counter) if self.filename_source != self.filename_target and self.sampling_mode != "patch": _logger.error( @@ -81,6 +87,7 @@ def __init__( self.col_map = {} self.stats_means = {} self.stats_vars = {} + self.patch_counter = 0 # 1. Probe Source meta_src = self._probe_file(self.filename_source, is_source=True) @@ -123,7 +130,10 @@ def __init__( self.roi_min_lon, self.roi_min_lat, self.roi_max_lon, self.roi_max_lat = self.roi else: self.roi_min_lon, self.roi_min_lat, self.roi_max_lon, self.roi_max_lat = ( - -180.0, -90.0, 180.0, 90.0 + -180.0, + -90.0, + 180.0, + 90.0, ) self.available_channels = list(self.col_map.keys()) @@ -133,7 +143,7 @@ def __init__( self.source_idx = self._select_channels("source") self.target_idx = self._select_channels("target") self.geoinfo_idx = [] - self.geoinfo_channels =[] + self.geoinfo_channels = [] self.source_channels = [self.available_channels[i] for i in self.source_idx] self.target_channels = [self.available_channels[i] for i in self.target_idx] @@ -146,7 +156,7 @@ def _probe_file(self, filepath, is_source=True): with xr.open_dataset(mapper, engine="zarr", chunks={}, consolidated=False) as ds: if "time" not in ds.coords: all_vars = list(ds.coords) + list(ds.data_vars) - time_candidates =[v for v in all_vars if "time" in v.lower()] + time_candidates = [v for v in all_vars if "time" in v.lower()] if time_candidates: target = time_candidates[0] if target in ds.data_vars: @@ -208,35 +218,25 @@ def _lazy_init(self): if self._initialized: return - self.mapper_src = fsspec.get_mapper("reference://", - fo=str(self.filename_source), - remote_protocol="file" - ) + self.mapper_src = fsspec.get_mapper( + "reference://", fo=str(self.filename_source), remote_protocol="file" + ) import warnings + with warnings.catch_warnings(): warnings.filterwarnings("ignore", message=".*separate the stored chunks.*") self.ds_source = xr.open_dataset( - self.mapper_src, - engine="zarr", - chunks={}, - decode_times=True, - consolidated=False + self.mapper_src, engine="zarr", chunks={}, decode_times=True, consolidated=False ) if self.filename_target != self.filename_source: self.mapper_trg = fsspec.get_mapper( - "reference://", - fo=str(self.filename_target), - remote_protocol="file" + "reference://", fo=str(self.filename_target), remote_protocol="file" ) with warnings.catch_warnings(): warnings.filterwarnings("ignore", message=".*separate the stored chunks.*") self.ds_target = xr.open_dataset( - self.mapper_trg, - engine="zarr", - chunks={}, - decode_times=True, - consolidated=False + self.mapper_trg, engine="zarr", chunks={}, decode_times=True, consolidated=False ) else: self.ds_target = self.ds_source @@ -284,19 +284,30 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read if len(t_idxs) == 0 or not channels: return ReaderData.empty(len(channels), 0) - channel_indices =[self.available_channels.index(c) for c in channels] + channel_indices = [self.available_channels.index(c) for c in channels] start_t, end_t = t_idxs[0], t_idxs[-1] + 1 n_steps = len(t_idxs) - - lats_ref = self.lats_src if is_source else self.lats_trg + spatial_indices_ref = self.spatial_indices_src if is_source else self.spatial_indices_trg coords_ref = self.coords_src if is_source else self.coords_trg ds_ref = self.ds_source if is_source else self.ds_target arr_cache = self._dask_arrays_src if is_source else self._dask_arrays_trg - local_seed = int(idx) + 12345 + # Patching Seed Logic: + # Use internal counter for 'staircase' stability OR sample index for variety. + if self.patch_use_counter: + patch_idx = self.patch_counter // self.patch_stability_window + local_seed = patch_idx + 12345 + else: + # Fallback to time-based index (Warning: sampler often seeds this per-rank!) + local_seed = int(idx) + 12345 + patch_idx = int(idx) + patch_rng = np.random.default_rng(local_seed) + # Increment counter for next fetch + self.patch_counter += 1 + if self.sampling_mode == "global_sparse": total_points = len(spatial_indices_ref) target_n = self.sample_points if self.sample_points else 4096 @@ -312,43 +323,34 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read lon_range = max(0.0, (self.roi_max_lon - self.roi_min_lon) - self.patch_size_deg) patch_indices_local = np.array([]) - attempts = 0 - - lat_0_candidates = self.roi_min_lat + patch_rng.random(100) * lat_range - lon_0_candidates = self.roi_min_lon + patch_rng.random(100) * lon_range - while attempts < 100: - lat_0 = lat_0_candidates[attempts] - lon_0 = lon_0_candidates[attempts] + lat_0 = self.roi_min_lat + patch_rng.random() * lat_range + lon_0 = self.roi_min_lon + patch_rng.random() * lon_range - mask_src = ( - (self.lats_src >= lat_0) & (self.lats_src < lat_0 + self.patch_size_deg) & - (self.lons_src >= lon_0) & (self.lons_src < lon_0 + self.patch_size_deg) - ) - mask_trg = ( - (self.lats_trg >= lat_0) & (self.lats_trg < lat_0 + self.patch_size_deg) & - (self.lons_trg >= lon_0) & (self.lons_trg < lon_0 + self.patch_size_deg) - ) - - pts_src = np.count_nonzero(mask_src) - pts_trg = np.count_nonzero(mask_trg) - - if pts_src >= MIN_PATCH_POINTS and pts_trg >= MIN_PATCH_POINTS: - patch_indices_local = np.where(mask_src if is_source else mask_trg)[0] - break - attempts += 1 + mask_src = ( + (self.lats_src >= lat_0) + & (self.lats_src < lat_0 + self.patch_size_deg) + & (self.lons_src >= lon_0) + & (self.lons_src < lon_0 + self.patch_size_deg) + ) + mask_trg = ( + (self.lats_trg >= lat_0) + & (self.lats_trg < lat_0 + self.patch_size_deg) + & (self.lons_trg >= lon_0) + & (self.lons_trg < lon_0 + self.patch_size_deg) + ) - if len(patch_indices_local) < MIN_PATCH_POINTS: - req_points = min(MIN_PATCH_POINTS, len(lats_ref)) - patch_indices_local = patch_rng.choice( - len(lats_ref), size=req_points, replace=False - ) + patch_indices_local = np.where(mask_src if is_source else mask_trg)[0] - patch_coords_base = self.coords_src[patch_indices_local] if is_source else ( - self.coords_trg[patch_indices_local] + patch_coords_base = ( + self.coords_src[patch_indices_local] + if is_source + else (self.coords_trg[patch_indices_local]) ) - final_disk_indices = self.spatial_indices_src[patch_indices_local] if is_source else ( - self.spatial_indices_trg[patch_indices_local] + final_disk_indices = ( + self.spatial_indices_src[patch_indices_local] + if is_source + else (self.spatial_indices_trg[patch_indices_local]) ) use_contiguous_read = True @@ -357,28 +359,35 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read patch_coords_base = self.coords_src if is_source else self.coords_trg use_contiguous_read = True + if len(final_disk_indices) == 0: + _logger.warning( + f"[Stream {self._stream_info.get('name')}] NO POINTS FOUND for patch! Skipping." + ) + return ReaderData.empty(len(channels), n_steps) + if use_contiguous_read: disk_start, disk_stop = np.min(final_disk_indices), np.max(final_disk_indices) + 1 rel_indices = final_disk_indices - disk_start data_block = self._load_block_from_ds( - ds_ref, - arr_cache, - channel_indices, - start_t, - end_t, - n_steps, - slice(disk_start, disk_stop), - rel_indices + ds_ref, + arr_cache, + channel_indices, + start_t, + end_t, + n_steps, + slice(disk_start, disk_stop), + rel_indices, ) else: data_block = self._load_block_from_ds( - ds_ref, - arr_cache, - channel_indices, - start_t, end_t, - n_steps, - final_disk_indices, - None + ds_ref, + arr_cache, + channel_indices, + start_t, + end_t, + n_steps, + final_disk_indices, + None, ) if data_block.size > 0: @@ -399,16 +408,8 @@ def _fetch_data(self, idx: TIndex, channels: list[str], is_source: bool) -> Read return rdata def _load_block_from_ds( - self, - ds, - arr_cache, - indices, - start_t, - end_t, - n_steps, - disk_indices, - rel_indices - ) -> np.typing.NDArray: + self, ds, arr_cache, indices, start_t, end_t, n_steps, disk_indices, rel_indices + ) -> np.typing.NDArray: if rel_indices is not None: num_points = len(rel_indices) else: @@ -462,7 +463,7 @@ def _load_block_from_ds( if "time" in dims: # Contiguous read: Apply raw disk bounds, then rel_indices chunk = chunk[:, disk_indices] - + # Safety check: if chunk is completely empty, fill with NaNs if chunk.shape[1] == 0: assert False, "Empty chunk after disk indexing with time dimension" @@ -510,8 +511,8 @@ def _parse_attr(self, attrs, key): def _select_channels(self, type_key: str) -> list[int]: select = self._stream_info.get(type_key) - exclude = self._stream_info.get(f"{type_key}_exclude",[]) - return[ + exclude = self._stream_info.get(f"{type_key}_exclude", []) + return [ i for i, ch in enumerate(self.available_channels) if (not select or any(s in ch for s in select)) and not any(e in ch for e in exclude) @@ -546,21 +547,17 @@ def normalize_target_channels(self, target: np.typing.NDArray) -> np.typing.NDAr def denormalize_source_channels(self, source): if isinstance(source, torch.Tensor): stdev = torch.tensor( - self.stdev[self.source_idx], - dtype=source.dtype, - device=source.device + self.stdev[self.source_idx], dtype=source.dtype, device=source.device ) mean = torch.tensor( - self.mean[self.source_idx], - dtype=source.dtype, - device=source.device + self.mean[self.source_idx], dtype=source.dtype, device=source.device ) - land_mask = (source == 0.0) + land_mask = source == 0.0 denorm = (source * stdev) + mean denorm[land_mask] = torch.nan return denorm - - land_mask = (source == 0.0) + + land_mask = source == 0.0 denorm = (source * self.stdev[self.source_idx]) + self.mean[self.source_idx] denorm[land_mask] = np.nan return denorm @@ -576,4 +573,4 @@ def denormalize_target_channels(self, data): @override def normalize_geoinfos(self, geoinfos: np.typing.NDArray) -> np.typing.NDArray: norm = (geoinfos - self.mean_geoinfo) / self.stdev_geoinfo - return np.nan_to_num(norm, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32) \ No newline at end of file + return np.nan_to_num(norm, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32) diff --git a/src/weathergen/datasets/data_reader_fesom.py b/src/weathergen/datasets/data_reader_fesom.py index db46bdf73..b37352a7e 100644 --- a/src/weathergen/datasets/data_reader_fesom.py +++ b/src/weathergen/datasets/data_reader_fesom.py @@ -94,7 +94,54 @@ def __init__( # This flag ensures initialization happens only once per worker self._initialized = False - # print(f"checking stream info {list(stream_info.keys())}") + if len(self.filenames) > 0 and len(self.target_files) > 0: + # We need to initialize the channels in __init__ so that the + # MultiStreamDataSampler can correctly identify if a stream is forcing or not. + s_group = zarr.open_group(self.filenames[0], mode="r") + t_group = zarr.open_group(self.target_files[0], mode="r") + + self.source_mesh_size = self._get_mesh_size(s_group) + self.target_mesh_size = self._get_mesh_size(t_group) + + source_colnames: list[str] = list(s_group["data"].attrs["colnames"]) + target_colnames: list[str] = list(t_group["data"].attrs["colnames"]) + + source_cols_idx = list(np.arange(len(source_colnames), dtype=int)) + target_cols_idx = list(np.arange(len(target_colnames), dtype=int)) + + src_lat_index: int = source_colnames.index("lat") + src_lon_index: int = source_colnames.index("lon") + trg_lat_index: int = target_colnames.index("lat") + trg_lon_index: int = target_colnames.index("lon") + + source_colnames = self._remove_lonlat(source_colnames) + target_colnames = self._remove_lonlat(target_colnames) + + source_cols_idx.remove(src_lat_index) + source_cols_idx.remove(src_lon_index) + source_cols_idx = np.array(source_cols_idx) + + target_cols_idx.remove(trg_lat_index) + target_cols_idx.remove(trg_lon_index) + target_cols_idx = np.array(target_cols_idx) + + source_channels = self._stream_info.get("source") + source_excl = self._stream_info.get("source_exclude") + self.source_channels, self.source_idx = ( + self.select(source_colnames, source_cols_idx, source_channels, source_excl) + if source_channels or source_excl + else (source_colnames, source_cols_idx) + ) + + target_channels = self._stream_info.get("target") + target_excl = self._stream_info.get("target_exclude") + self.target_channels, self.target_idx = ( + self.select(target_colnames, target_cols_idx, target_channels, target_excl) + if target_channels or target_excl + else (target_colnames, target_cols_idx) + ) + + self.target_channel_weights = self.parse_target_channel_weights() def _get_mesh_size(self, group: zarr.Group) -> int: if "n_points" in group["data"].attrs: diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 4419c2955..4d54b2bf1 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -117,8 +117,10 @@ def forward(self, batch, pe_embed): # per cell indices into positional encoding tok_counts = batch.tokens_lens.permute([2, 0, 1, 3]).sum(0).flatten() - pe_idxs = torch.cat([torch.arange(c) for c in tok_counts]) - + rows = torch.arange( tok_counts.max(), device=tok_counts.device).unsqueeze(0) + rows = rows.expand(tok_counts.shape[0], -1) + pe_idxs = rows[rows < tok_counts.unsqueeze(1)] + # actual scatter operation tokens_all.scatter_(0, scatter_idxs, torch.cat(x_embeds) + pe_embed[pe_idxs]) diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 1303f06b7..24a2d82b4 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -12,7 +12,6 @@ import itertools import logging -import omegaconf import torch from torch.distributed.fsdp import ( MixedPrecisionPolicy, @@ -28,12 +27,9 @@ MultiSelfAttentionHeadLocal, MultiSelfAttentionHeadVarlen, ) -from weathergen.model.ema import EMAModel from weathergen.model.layers import MLP from weathergen.model.model import Model, ModelParams from weathergen.model.utils import apply_fct_to_blocks, freeze_weights -from weathergen.train.target_and_aux_module_base import PhysicalTargetAndAux -from weathergen.train.target_and_aux_ssl_teacher import EMATeacher from weathergen.utils.distributed import is_root from weathergen.utils.utils import get_dtype @@ -273,60 +269,3 @@ def get_model(cf: Config, training_mode: TrainingMode, dataset, overrides): return Model( cf_with_overrides, sources_size, targets_num_channels, targets_coords_size ).create() - - -def get_target_aux_calculator( - cf: Config, loss_cfg: omegaconf.OmegaConf, dataset, model, device, batch_size_per_gpu, **kwargs -): - """ - Create target aux calculator - """ - - target_and_aux_calc_cfg = loss_cfg.get("target_and_aux_calc", "Physical") - - # parse target_and_aux_calc_cfg specification which can either be a string or config dict - if type(target_and_aux_calc_cfg) is str: - target_and_aux_calc = target_and_aux_calc_cfg - target_and_aux_calc_params = {} - elif type(target_and_aux_calc_cfg) is omegaconf.dictconfig.DictConfig: - # single key is the target_and_aux_calc type - target_and_aux_calc = list(target_and_aux_calc_cfg.keys())[0] - # value is dict with the target_and_aux_calc parameters - target_and_aux_calc_params = list(target_and_aux_calc_cfg.values())[0] - else: - assert False, "target_and_aux_calc needs either be name or config dict." - - # create target_and_aux_calc - if target_and_aux_calc == "Physical": - target_aux = PhysicalTargetAndAux(loss_cfg, model) - - elif target_and_aux_calc == "EMATeacher": - # work around for problems with FSDP2 - assert not cf.with_fsdp, "EMATeacher not supported with FSDP(2) at the moment" - - meta_ema_model, _ = init_model_and_shard( - cf, - dataset, - None, - None, - "student", - device, - with_ddp=False, - with_fsdp=False, - overrides=target_and_aux_calc_params.get("model_param_overrides", {}), - ) - ema_model = EMAModel( - model, - meta_ema_model, - halflife_steps=target_and_aux_calc_params.get("ema_halflife_in_thousands", 1e-3), - rampup_ratio=target_and_aux_calc_params.get("ema_ramp_up_ratio", 0.09), - is_model_sharded=(cf.with_ddp and cf.with_fsdp), - ) - - batch_size = cf.get("world_size_original", cf.get("world_size")) * batch_size_per_gpu - target_aux = EMATeacher(model, ema_model, batch_size, cf.training_config) - - else: - raise NotImplementedError(f"{target_and_aux_calc} is not implemented") - - return target_aux diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index 76994221c..edd8e53b6 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -9,30 +9,39 @@ from __future__ import annotations +import logging from typing import Any import torch +from weathergen.common.config import Config, load_run_config, merge_configs +from weathergen.model.model import ModelParams +from weathergen.model.model_interface import get_model from weathergen.model.ssl_target_processing import ( DINOTargetProcessing, JEPATargetProcessing, iBOTPatchTargetProcessing, ) from weathergen.train.target_and_aux_module_base import TargetAndAuxModuleBase, TargetAuxOutput +from weathergen.train.teacher_utils import ( + load_encoder_from_checkpoint, + prepare_encoder_teacher, +) +logger = logging.getLogger(__name__) -class EMATeacher(TargetAndAuxModuleBase): - def __init__(self, model, ema_model, batch_size, training_cfg, **kwargs): - # One of the issues is that the teacher model may have a different architecture - # to the student, e.g. JEPA. So we need quite a flexible way to instantiate the - # the teacher. Because of the device sharding etc that requires quite a bit of - # massaging we assume that the teacher creates the EMA model correctly. However, - # note that you cannot assume that model.state_dict equals ema_model.state_dict - self.ema_model = ema_model - self.batch_size = batch_size - # is a dict of TargetProcessing classes as we may use several in parallel +class EncoderTeacher(TargetAndAuxModuleBase): + """Base class for SSL teacher models. + Handles shared logic: SSL loss extraction, target postprocessing, compute loop. + Subclasses must implement forward_teacher(). + """ + + def __init__(self, teacher_model, training_cfg, **kwargs): + self.teacher_model = teacher_model + + # Extract SSL loss configs losses_cfg = [ v.loss_fcts for k, v in training_cfg.losses.items() @@ -41,24 +50,12 @@ def __init__(self, model, ema_model, batch_size, training_cfg, **kwargs): # TODO: support multiple LossLatentSSLStudentTeacher loss terms self.postprocess_targets = get_target_postprocessing(losses_cfg[0], training_cfg, **kwargs) - self.reset() - - def reset(self, batch_size=None): - self.ema_model.reset() - if batch_size is not None: - self.batch_size = batch_size - - def update_state_pre_backward(self, istep, batch, model, **kwargs) -> None: - return - - def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: - if self.ema_model.is_model_sharded: - self.ema_model.ema_model.reshard() - self.ema_model.update(istep, self.batch_size) + def forward_teacher(self, model_params, batch) -> Any: + raise NotImplementedError("Subclasses must implement forward_teacher()") - def compute(self, bidx, batch, model_params, model) -> tuple[Any, Any]: + def compute(self, bidx, batch, model_params, model) -> TargetAuxOutput: with torch.no_grad(): - outputs = self.ema_model.forward_eval(model_params, batch).get_latent_prediction(0) + outputs = self.forward_teacher(model_params, batch).get_latent_prediction(0) targets = {} for loss_name, target_module in self.postprocess_targets.items(): targets[loss_name] = target_module(outputs[loss_name]) @@ -72,7 +69,10 @@ def compute(self, bidx, batch, model_params, model) -> tuple[Any, Any]: return targets_out - def to_device(self, device) -> EMATeacher: + def update_state_pre_backward(self, istep, batch, model, **kwargs) -> None: + return + + def to_device(self, device) -> EncoderTeacher: for _, module in self.postprocess_targets.items(): module.to(device) return self @@ -82,10 +82,112 @@ def get_current_beta(self, cur_step: int) -> float: return beta -def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): +class EMATeacher(EncoderTeacher): + """SSL teacher using exponential moving average of student weights.""" + + def __init__(self, model, ema_model, batch_size, training_cfg, **kwargs): + super().__init__(model, training_cfg, **kwargs) + self.ema_model = ema_model + self.batch_size = batch_size + self.reset() + + def forward_teacher(self, model_params, batch): + return self.ema_model.forward_eval(model_params, batch) + + def reset(self, batch_size=None): + self.ema_model.reset() + if batch_size is not None: + self.batch_size = batch_size + + def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: + if self.ema_model.is_model_sharded: + self.ema_model.ema_model.reshard() + self.ema_model.update(istep, self.batch_size) + + def get_current_beta(self, cur_step: int) -> float: + """Return the current EMA interpolation beta for monitoring.""" + return self.ema_model.get_current_beta(cur_step) + + +class FrozenTeacher(EncoderTeacher): + """SSL teacher using a frozen pre-trained encoder. + + The encoder is loaded from a checkpoint and never updated. Non-encoder + parts are discarded; latent heads are created fresh based on the student's + SSL loss config. + """ + + def __init__(self, teacher_model, training_cfg, teacher_model_params=None): + super().__init__(teacher_model, training_cfg) + self.teacher_model_params = teacher_model_params + + # Freeze all parameters + for param in self.teacher_model.parameters(): + param.requires_grad = False + self.teacher_model.eval() + + @classmethod + def from_pretrained(cls, cf: Config, dataset, device, params: dict) -> FrozenTeacher: + """Create a FrozenTeacher from a pre-trained checkpoint. + + Args: + cf: Full training config + dataset: Dataset for model creation + device: Target device + params: Dict with 'teacher_run_id' and optional 'teacher_mini_epoch' + """ + + teacher_run_id = params["teacher_run_id"] + teacher_mini_epoch = params.get("teacher_mini_epoch", -1) + + # Load teacher's config, create model with teacher's architecture + teacher_config = load_run_config(teacher_run_id, teacher_mini_epoch, model_path=None) + teacher_config = merge_configs(teacher_config, {"with_ddp": False, "with_fsdp": False}) + + teacher_model = get_model(teacher_config, "student", dataset, {}) + + # Load only encoder weights + load_encoder_from_checkpoint(teacher_model, cf, teacher_run_id, teacher_mini_epoch, device) + + # Strip to encoder + create fresh heads + prepare_encoder_teacher(teacher_model, cf.training_config, teacher_config) + + # Create model params matching teacher's architecture + teacher_model_params = ModelParams(teacher_config).create(teacher_config).to(device) + + return cls(teacher_model, cf.training_config, teacher_model_params) + + def forward_teacher(self, model_params, batch): + params = ( + self.teacher_model_params if self.teacher_model_params is not None else model_params + ) + return self.teacher_model(params, batch) + + def reset(self, batch_size=None): + pass + + def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: + pass + + +def get_target_postprocessing( + target_losses: dict[str, Any], training_cfg, **kwargs +) -> dict[str, torch.nn.Module]: + """Create target postprocessing modules for each SSL loss type. + + Args: + target_losses: Dict mapping loss name → loss config + training_cfg: Training configuration + + Returns: + Dict mapping loss name → target processing module + """ return_dict = {} for loss_name, conf in target_losses.items(): if loss_name == "iBOT": + for key in ("out_dim", "center_momentum", "teacher_temp", "teacher_style"): + if key not in conf: + raise KeyError(f"iBOT config missing required key {key!r}") return_dict[loss_name] = iBOTPatchTargetProcessing( patch_out_dim=conf["out_dim"], center_momentum=conf["center_momentum"], @@ -94,6 +196,9 @@ def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): teacher_style=conf["teacher_style"], ) elif loss_name == "DINO": + for key in ("out_dim", "center_momentum", "teacher_style"): + if key not in conf: + raise KeyError(f"DINO config missing required key {key!r}") return_dict[loss_name] = DINOTargetProcessing( out_dim=conf["out_dim"], center_momentum=conf["center_momentum"], @@ -103,6 +208,7 @@ def get_target_postprocessing(target_losses: list[str], training_cfg, **kwargs): elif loss_name == "JEPA": return_dict[loss_name] = JEPATargetProcessing() else: - # We skip losses that are not handled by the EMATeacher + # We skip losses that are not handled by the teacher + logger.debug(f"Skipping unknown loss type {loss_name!r} in target postprocessing") continue return return_dict diff --git a/src/weathergen/train/target_and_aux_utils.py b/src/weathergen/train/target_and_aux_utils.py new file mode 100644 index 000000000..efaff18bc --- /dev/null +++ b/src/weathergen/train/target_and_aux_utils.py @@ -0,0 +1,83 @@ +import omegaconf + +from weathergen.common.config import Config, merge_configs +from weathergen.model.ema import EMAModel +from weathergen.model.model_interface import init_model_and_shard +from weathergen.train.target_and_aux_module_base import PhysicalTargetAndAux +from weathergen.train.target_and_aux_ssl_teacher import EMATeacher, FrozenTeacher +from weathergen.train.teacher_utils import load_encoder_from_checkpoint, prepare_encoder_teacher + + +def get_target_aux_calculator( + cf: Config, loss_cfg: omegaconf.OmegaConf, dataset, model, device, batch_size_per_gpu, **kwargs +): + """ + Create target aux calculator + """ + + target_and_aux_calc_cfg = loss_cfg.get("target_and_aux_calc", "Physical") + + # parse target_and_aux_calc_cfg specification which can either be a string or config dict + if type(target_and_aux_calc_cfg) is str: + target_and_aux_calc = target_and_aux_calc_cfg + target_and_aux_calc_params = {} + elif type(target_and_aux_calc_cfg) is omegaconf.dictconfig.DictConfig: + # single key is the target_and_aux_calc type + target_and_aux_calc = list(target_and_aux_calc_cfg.keys())[0] + # value is dict with the target_and_aux_calc parameters + target_and_aux_calc_params = list(target_and_aux_calc_cfg.values())[0] + else: + assert False, "target_and_aux_calc needs either be name or config dict." + + # create target_and_aux_calc + if target_and_aux_calc == "Physical": + target_aux = PhysicalTargetAndAux(loss_cfg, model) + + elif target_and_aux_calc == "EMATeacher": + # work around for problems with FSDP2 + assert not cf.with_fsdp, "EMATeacher not supported with FSDP(2) at the moment" + + meta_ema_model, _ = init_model_and_shard( + cf, + dataset, + None, + None, + "student", + device, + with_ddp=False, + with_fsdp=False, + overrides=target_and_aux_calc_params.get("model_param_overrides", {}), + ) + + # Strip to encoder + create fresh heads + cf_overridden = merge_configs( + cf, target_and_aux_calc_params.get("model_param_overrides", {}) + ) + prepare_encoder_teacher(meta_ema_model, cf.training_config, cf_overridden) + + ema_model = EMAModel( + model, + meta_ema_model, + halflife_steps=target_and_aux_calc_params.get("ema_halflife_in_thousands", 1e-3), + rampup_ratio=target_and_aux_calc_params.get("ema_ramp_up_ratio", 0.09), + is_model_sharded=(cf.with_ddp and cf.with_fsdp), + ) + + batch_size = cf.get("world_size_original", cf.get("world_size")) * batch_size_per_gpu + target_aux = EMATeacher(model, ema_model, batch_size, cf.training_config) + + # Optional: warm start encoder from checkpoint + teacher_run_id = target_and_aux_calc_params.get("teacher_run_id") + if teacher_run_id is not None: + teacher_mini_epoch = target_and_aux_calc_params.get("teacher_mini_epoch", -1) + load_encoder_from_checkpoint( + ema_model.ema_model, cf, teacher_run_id, teacher_mini_epoch, device + ) + + elif target_and_aux_calc == "FrozenTeacher": + target_aux = FrozenTeacher.from_pretrained(cf, dataset, device, target_and_aux_calc_params) + + else: + raise NotImplementedError(f"{target_and_aux_calc} is not implemented") + + return target_aux diff --git a/src/weathergen/train/teacher_utils.py b/src/weathergen/train/teacher_utils.py new file mode 100644 index 000000000..c026960b5 --- /dev/null +++ b/src/weathergen/train/teacher_utils.py @@ -0,0 +1,135 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +from pathlib import Path + +import torch +import torch.nn as nn + +from weathergen.common.config import get_path_model +from weathergen.model.engines import ( + LatentPredictionHeadIdentity, + LatentPredictionHeadMLP, + LatentPredictionHeadTransformer, +) + +logger = logging.getLogger(__name__) + + +def _create_teacher_heads( + name: str, head_type: str, dim_embed: int, loss_conf, cf=None +) -> nn.Module: + """Create a latent prediction head for a given SSL loss type. + + Mirrors Model._create_latent_pred_head() logic with per-loss-type token settings: + iBOT: use_class_token=True, use_patch_token=True + DINO: use_class_token=True, use_patch_token=False + """ + if name == "iBOT": + use_class_token, use_patch_token = True, True + elif name == "DINO": + use_class_token, use_patch_token = True, False + else: + raise ValueError(f"_create_teacher_heads does not support loss type {name!r}") + + if head_type == "mlp": + return LatentPredictionHeadMLP( + f"{name}-head", dim_embed, loss_conf, use_class_token, use_patch_token + ) + elif head_type == "transformer": + if cf is None: + raise ValueError("LatentPredictionHeadTransformer requires a global config (cf)") + return LatentPredictionHeadTransformer( + cf, f"{name}-head", dim_embed, loss_conf, use_class_token, use_patch_token + ) + elif head_type == "identity": + return LatentPredictionHeadIdentity() + else: + raise ValueError(f"Unknown latent prediction head type {head_type!r}") + + +def prepare_encoder_teacher(model: nn.Module, training_cfg, override_cfg) -> None: + """Strip a model to encoder-only and create fresh SSL latent heads. + + Modifies model in-place: + 1. Removes forecast_engine, decoders, pred_heads, embed_target_coords + 2. Ensures latent_pre_norm exists + 3. Creates fresh latent_heads based on the student's SSL loss config + """ + # Strip non-encoder components + teacher_dim_embed = override_cfg.ae_global_dim_embed + model.forecast_engine = None + model.embed_target_coords = nn.ModuleDict() + model.target_token_engines = nn.ModuleDict() + model.pred_heads = nn.ModuleDict() + + # Ensure latent_pre_norm exists (teacher may not have had SSL training) + if model.latent_pre_norm is None: + model.latent_pre_norm = nn.LayerNorm(teacher_dim_embed) + + # Create fresh latent heads from student's SSL config + model.latent_heads = nn.ModuleDict() + ssl_losses = [ + v for v in training_cfg.losses.values() if v.type == "LossLatentSSLStudentTeacher" + ] + for ssl_loss in ssl_losses: + for name, conf in ssl_loss.loss_fcts.items(): + if name == "JEPA": + model.latent_heads[name] = LatentPredictionHeadIdentity() + elif name in ("iBOT", "DINO"): + head_type = conf.get("head", "mlp").lower() + model.latent_heads[name] = _create_teacher_heads( + name, head_type, teacher_dim_embed, conf + ) + else: + logger.warning(f"Unknown SSL loss type {name!r} in teacher setup, skipping.") + + +def load_encoder_from_checkpoint( + model: nn.Module, + cf, + teacher_run_id: str, + teacher_mini_epoch: int | None, + device: torch.device | str, +) -> None: + """Load only encoder weights from a checkpoint into a model. + + Filters checkpoint to encoder.* and latent_pre_norm* keys only, then loads with + strict=False. Moves the model to the given device afterwards. + """ + path_run = Path(cf.get("model_path", get_path_model(run_id=teacher_run_id))) / teacher_run_id + mini_epoch_id = ( + f"chkpt{teacher_mini_epoch:05d}" + if teacher_mini_epoch is not None and teacher_mini_epoch != -1 + else "latest" + ) + filename = f"{teacher_run_id}_{mini_epoch_id}.chkpt" + + params = torch.load(path_run / filename, map_location="cpu", mmap=True, weights_only=True) + + # Filter to encoder + latent_pre_norm only + encoder_params = { + k: v for k, v in params.items() if k.startswith(("encoder.", "latent_pre_norm")) + } + + mkeys, ukeys = model.load_state_dict(encoder_params, strict=False) + model.to(device) + + logging.info(f"Teacher: Loaded encoder weights from checkpoint {filename}") + if mkeys is not None: + logger.info(f"Number of missing keys: {len(mkeys)}") + logger.debug(f"Missing keys: {mkeys}") + if ukeys is not None: + logger.info(f"Number of unused keys: {len(ukeys)}") + logger.debug(f"Unused keys: {ukeys}") + if mkeys is None and ukeys is None: + logger.info("All keys in checkpoint matched successfully.") diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 431b9d1e0..64d7de86d 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -25,13 +25,13 @@ from weathergen.datasets.multi_stream_data_sampler import MultiStreamDataSampler from weathergen.model.ema import EMAModel from weathergen.model.model_interface import ( - get_target_aux_calculator, init_model_and_shard, ) from weathergen.model.utils import apply_fct_to_blocks, set_to_eval from weathergen.train.collapse_monitor import CollapseMonitor from weathergen.train.loss_calculator import LossCalculator from weathergen.train.lr_scheduler import LearningRateScheduler +from weathergen.train.target_and_aux_utils import get_target_aux_calculator from weathergen.train.trainer_base import TrainerBase from weathergen.train.utils import ( TRAIN, @@ -166,8 +166,6 @@ def get_target_aux_calculators(self, mode_cfg): batch_size = get_batch_size_from_config(mode_cfg) # get target_aux calculators for different loss terms - # del self.cf.training_config.losses["student-teacher"]["loss_fcts"]["JEPA"] - # del mode_cfg.losses["student-teacher"]["loss_fcts"]["JEPA"] target_and_aux_calculators = {} for loss_name, loss_cfg in mode_cfg.losses.items(): target_and_aux_calculators[loss_name] = get_target_aux_calculator( diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 0e34a1363..cbdc82acb 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -31,8 +31,6 @@ _weathergen_timestamp = "weathergen.timestamp" _weathergen_reltime = "weathergen.reltime" _weathergen_time = "weathergen.time" -_performance_gpu = "perf.gpu" -_performance_memory = "perf.memory" _logger = logging.getLogger(__name__) @@ -102,8 +100,6 @@ def add_logs( stddev_all: dict, avg_loss: list[float] = None, lr: float = None, - perf_gpu: float = 0.0, - perf_mem: float = 0.0, ) -> None: """ Log training or validation data @@ -114,8 +110,6 @@ def add_logs( metrics["loss_avg_mean"] = np.nanmean(avg_loss) metrics["learning_rate"] = lr metrics["num_samples"] = int(samples) - metrics[_performance_gpu] = perf_gpu - metrics[_performance_memory] = perf_mem for key, value in losses_all.items(): metrics[key] = np.nanmean(value) diff --git a/tests/test_encoder_teacher.py b/tests/test_encoder_teacher.py new file mode 100644 index 000000000..a319a4bfd --- /dev/null +++ b/tests/test_encoder_teacher.py @@ -0,0 +1,647 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +import torch +import torch.nn as nn +from omegaconf import OmegaConf + +# flash_attn is GPU-only; skip entire module if not available (e.g. macOS) +pytest.importorskip("flash_attn", reason="flash_attn required (GPU-only)") + +from weathergen.model.engines import ( # noqa: E402 + LatentPredictionHeadIdentity, + LatentPredictionHeadMLP, +) +from weathergen.model.ssl_target_processing import ( # noqa: E402 + DINOTargetProcessing, + JEPATargetProcessing, + iBOTPatchTargetProcessing, +) +from weathergen.train.target_and_aux_ssl_teacher import ( # noqa: E402 + EMATeacher, + EncoderTeacher, + FrozenTeacher, + get_target_postprocessing, +) +from weathergen.train.teacher_utils import ( # noqa: E402 + _create_teacher_heads, + load_encoder_from_checkpoint, + prepare_encoder_teacher, +) + +# --------------------------------------------------------------------------- +# Fixtures and helpers +# --------------------------------------------------------------------------- + + +def _make_training_cfg(loss_types: dict[str, dict] | None = None) -> OmegaConf: + """Create a minimal training config with SSL losses. + + loss_types: dict mapping loss name -> loss config dict. + Defaults to a single JEPA loss with identity head. + """ + if loss_types is None: + loss_types = {"JEPA": {"head": "identity"}} + + return OmegaConf.create( + { + "losses": { + "ssl_loss": { + "type": "LossLatentSSLStudentTeacher", + "loss_fcts": loss_types, + } + } + } + ) + + +def _make_mock_model(dim_embed: int = 64) -> nn.Module: + """Create a mock model with the attributes that prepare_encoder_teacher expects.""" + model = nn.Module() + model.forecast_engine = nn.Linear(10, 10) + model.embed_target_coords = nn.ModuleDict({"stream1": nn.Linear(3, 3)}) + model.target_token_engines = nn.ModuleDict({"stream1": nn.Linear(5, 5)}) + model.pred_heads = nn.ModuleDict({"stream1": nn.Linear(5, 5)}) + model.latent_pre_norm = None + model.latent_heads = nn.ModuleDict({"existing": nn.Linear(dim_embed, dim_embed)}) + # Add a minimal encoder + model.encoder = nn.Linear(10, dim_embed) + return model + + +def _make_mock_ema_model(): + """Create a mock EMA model for EMATeacher tests.""" + ema = MagicMock() + ema.is_model_sharded = False + ema.reset = MagicMock() + ema.update = MagicMock() + ema.forward_eval = MagicMock() + ema.get_current_beta = MagicMock(return_value=0.99) + return ema + + +# --------------------------------------------------------------------------- +# Tests for prepare_encoder_teacher +# --------------------------------------------------------------------------- + + +class TestPrepareEncoderTeacher: + def test_strips_non_encoder_components(self): + model = _make_mock_model() + training_cfg = _make_training_cfg() + + prepare_encoder_teacher(model, training_cfg, teacher_dim_embed=64) + + assert model.forecast_engine is None + assert len(model.embed_target_coords) == 0 + assert len(model.target_token_engines) == 0 + assert len(model.pred_heads) == 0 + + def test_creates_latent_pre_norm_if_missing(self): + model = _make_mock_model() + assert model.latent_pre_norm is None + + training_cfg = _make_training_cfg() + prepare_encoder_teacher(model, training_cfg, teacher_dim_embed=64) + + assert isinstance(model.latent_pre_norm, nn.LayerNorm) + assert model.latent_pre_norm.normalized_shape == (64,) + + def test_preserves_existing_latent_pre_norm(self): + model = _make_mock_model() + existing_norm = nn.LayerNorm(64) + model.latent_pre_norm = existing_norm + + training_cfg = _make_training_cfg() + prepare_encoder_teacher(model, training_cfg, teacher_dim_embed=64) + + # Should keep the existing norm, not replace it + assert model.latent_pre_norm is existing_norm + + def test_jepa_creates_identity_head(self): + model = _make_mock_model() + training_cfg = _make_training_cfg({"JEPA": {"head": "identity"}}) + + prepare_encoder_teacher(model, training_cfg, teacher_dim_embed=64) + + assert "JEPA" in model.latent_heads + assert isinstance(model.latent_heads["JEPA"], LatentPredictionHeadIdentity) + + def test_ibot_creates_mlp_head_by_default(self): + model = _make_mock_model() + training_cfg = _make_training_cfg( + { + "iBOT": { + "head": "mlp", + "out_dim": 32, + "num_layers": 2, + "hidden_factor": 2, + "center_momentum": 0.9, + "teacher_temp": 0.04, + "teacher_style": "softmax_center", + "loss_extra_args": {"student_temp": 0.1}, + } + } + ) + + prepare_encoder_teacher(model, training_cfg, teacher_dim_embed=64) + + assert "iBOT" in model.latent_heads + assert isinstance(model.latent_heads["iBOT"], LatentPredictionHeadMLP) + + def test_dino_creates_mlp_head_by_default(self): + model = _make_mock_model() + training_cfg = _make_training_cfg( + { + "DINO": { + "head": "mlp", + "out_dim": 32, + "num_layers": 2, + "hidden_factor": 2, + "center_momentum": 0.9, + "teacher_style": "softmax_center", + "loss_extra_args": {"student_temp": 0.1}, + } + } + ) + + prepare_encoder_teacher(model, training_cfg, teacher_dim_embed=64) + + assert "DINO" in model.latent_heads + assert isinstance(model.latent_heads["DINO"], LatentPredictionHeadMLP) + + def test_replaces_existing_heads(self): + model = _make_mock_model() + assert "existing" in model.latent_heads + + training_cfg = _make_training_cfg({"JEPA": {"head": "identity"}}) + prepare_encoder_teacher(model, training_cfg, teacher_dim_embed=64) + + # Old heads should be gone, only new ones + assert "existing" not in model.latent_heads + assert "JEPA" in model.latent_heads + + def test_multiple_ssl_losses(self): + model = _make_mock_model() + training_cfg = _make_training_cfg( + { + "JEPA": {"head": "identity"}, + "iBOT": { + "head": "mlp", + "out_dim": 32, + "num_layers": 2, + "hidden_factor": 2, + "center_momentum": 0.9, + "teacher_temp": 0.04, + "teacher_style": "softmax_center", + "loss_extra_args": {"student_temp": 0.1}, + }, + } + ) + + prepare_encoder_teacher(model, training_cfg, teacher_dim_embed=64) + + assert "JEPA" in model.latent_heads + assert "iBOT" in model.latent_heads + assert isinstance(model.latent_heads["JEPA"], LatentPredictionHeadIdentity) + assert isinstance(model.latent_heads["iBOT"], LatentPredictionHeadMLP) + + def test_no_ssl_losses(self): + model = _make_mock_model() + training_cfg = OmegaConf.create( + { + "losses": { + "phys_loss": { + "type": "LossPhysical", + } + } + } + ) + + prepare_encoder_teacher(model, training_cfg, teacher_dim_embed=64) + + assert len(model.latent_heads) == 0 + + def test_encoder_preserved(self): + model = _make_mock_model() + original_encoder = model.encoder + training_cfg = _make_training_cfg() + + prepare_encoder_teacher(model, training_cfg, teacher_dim_embed=64) + + assert model.encoder is original_encoder + + +# --------------------------------------------------------------------------- +# Tests for load_encoder_from_checkpoint +# --------------------------------------------------------------------------- + + +class TestLoadEncoderFromCheckpoint: + def test_loads_only_encoder_keys(self, tmp_path): + """Verify that only encoder.* and latent_pre_norm* keys are loaded.""" + # Create a mock model + model = nn.Module() + model.encoder = nn.Linear(10, 20) + model.latent_pre_norm = nn.LayerNorm(20) + model.other_module = nn.Linear(20, 5) + + # Create checkpoint with encoder + non-encoder params + checkpoint = {} + for name, param in model.state_dict().items(): + checkpoint[name] = torch.randn_like(param) + # Add some extra non-encoder params that should be ignored + checkpoint["forecast_engine.weight"] = torch.randn(10, 10) + checkpoint["pred_heads.stream1.weight"] = torch.randn(5, 5) + + # Save checkpoint + run_id = "test1234" + run_dir = tmp_path / run_id + run_dir.mkdir() + torch.save(checkpoint, run_dir / f"{run_id}_latest.chkpt") + + cf = OmegaConf.create({"model_path": str(tmp_path)}) + + # Load - should not raise despite extra keys + load_encoder_from_checkpoint(model, cf, run_id, -1, "cpu") + + def test_mini_epoch_filename(self, tmp_path): + """Test that specific mini_epoch generates correct filename.""" + model = nn.Module() + model.encoder = nn.Linear(10, 20) + + run_id = "test1234" + run_dir = tmp_path / run_id + run_dir.mkdir() + torch.save( + {"encoder.weight": torch.randn(20, 10), "encoder.bias": torch.randn(20)}, + run_dir / f"{run_id}_chkpt00042.chkpt", + ) + + cf = OmegaConf.create({"model_path": str(tmp_path)}) + load_encoder_from_checkpoint(model, cf, run_id, 42, "cpu") + + def test_latest_filename(self, tmp_path): + """Test that mini_epoch=-1 generates 'latest' filename.""" + model = nn.Module() + model.encoder = nn.Linear(10, 20) + + run_id = "test1234" + run_dir = tmp_path / run_id + run_dir.mkdir() + torch.save( + {"encoder.weight": torch.randn(20, 10), "encoder.bias": torch.randn(20)}, + run_dir / f"{run_id}_latest.chkpt", + ) + + cf = OmegaConf.create({"model_path": str(tmp_path)}) + load_encoder_from_checkpoint(model, cf, run_id, -1, "cpu") + + def test_none_mini_epoch_uses_latest(self, tmp_path): + """Test that mini_epoch=None generates 'latest' filename.""" + model = nn.Module() + model.encoder = nn.Linear(10, 20) + + run_id = "test1234" + run_dir = tmp_path / run_id + run_dir.mkdir() + torch.save( + {"encoder.weight": torch.randn(20, 10), "encoder.bias": torch.randn(20)}, + run_dir / f"{run_id}_latest.chkpt", + ) + + cf = OmegaConf.create({"model_path": str(tmp_path)}) + load_encoder_from_checkpoint(model, cf, run_id, None, "cpu") + + +# --------------------------------------------------------------------------- +# Tests for _create_head +# --------------------------------------------------------------------------- + + +class TestCreateHead: + def test_ibot_mlp(self): + conf = OmegaConf.create({"out_dim": 32, "num_layers": 2, "hidden_factor": 2}) + head = _create_teacher_heads("iBOT", "mlp", 64, conf) + assert isinstance(head, LatentPredictionHeadMLP) + assert head.use_class_token is True + assert head.use_patch_token is True + + def test_dino_mlp(self): + conf = OmegaConf.create({"out_dim": 32, "num_layers": 2, "hidden_factor": 2}) + head = _create_teacher_heads("DINO", "mlp", 64, conf) + assert isinstance(head, LatentPredictionHeadMLP) + assert head.use_class_token is True + assert head.use_patch_token is False + + def test_identity_head(self): + head = _create_teacher_heads("iBOT", "identity", 64, {}) + assert isinstance(head, LatentPredictionHeadIdentity) + + def test_unknown_loss_type(self): + with pytest.raises(ValueError, match="does not support loss type"): + _create_teacher_heads("UnknownLoss", "mlp", 64, {}) + + def test_unknown_head_type(self): + with pytest.raises(ValueError, match="Unknown latent prediction head type"): + _create_teacher_heads("iBOT", "nonexistent", 64, {}) + + def test_transformer_requires_cf(self): + conf = OmegaConf.create( + { + "out_dim": 32, + "num_blocks": 1, + "num_heads": 2, + "with_qk_lnorm": True, + "intermediate_dim": 32, + "dropout_rate": 0.0, + } + ) + with pytest.raises(ValueError, match="requires a global config"): + _create_teacher_heads("iBOT", "transformer", 64, conf, cf=None) + + +# --------------------------------------------------------------------------- +# Tests for get_target_postprocessing +# --------------------------------------------------------------------------- + + +class TestGetTargetPostprocessing: + def test_jepa(self): + losses = OmegaConf.create({"JEPA": {"head": "identity"}}) + training_cfg = OmegaConf.create({}) + result = get_target_postprocessing(losses, training_cfg) + assert "JEPA" in result + assert isinstance(result["JEPA"], JEPATargetProcessing) + + def test_ibot(self): + losses = OmegaConf.create( + { + "iBOT": { + "out_dim": 32, + "center_momentum": 0.9, + "teacher_temp": 0.04, + "teacher_style": "softmax_center", + "loss_extra_args": {"student_temp": 0.1}, + } + } + ) + training_cfg = OmegaConf.create({}) + result = get_target_postprocessing(losses, training_cfg) + assert "iBOT" in result + assert isinstance(result["iBOT"], iBOTPatchTargetProcessing) + + def test_dino(self): + losses = OmegaConf.create( + { + "DINO": { + "out_dim": 32, + "center_momentum": 0.9, + "teacher_style": "softmax_center", + "loss_extra_args": {"student_temp": 0.1}, + } + } + ) + training_cfg = OmegaConf.create({}) + result = get_target_postprocessing(losses, training_cfg) + assert "DINO" in result + assert isinstance(result["DINO"], DINOTargetProcessing) + + def test_unknown_loss_skipped(self): + losses = OmegaConf.create({"UnknownLoss": {"foo": "bar"}}) + training_cfg = OmegaConf.create({}) + result = get_target_postprocessing(losses, training_cfg) + assert len(result) == 0 + + def test_ibot_missing_config_key(self): + losses = OmegaConf.create({"iBOT": {"out_dim": 32}}) # missing required keys + training_cfg = OmegaConf.create({}) + with pytest.raises(KeyError, match="center_momentum"): + get_target_postprocessing(losses, training_cfg) + + def test_dino_missing_config_key(self): + losses = OmegaConf.create({"DINO": {"out_dim": 32}}) # missing required keys + training_cfg = OmegaConf.create({}) + with pytest.raises(KeyError, match="center_momentum"): + get_target_postprocessing(losses, training_cfg) + + +# --------------------------------------------------------------------------- +# Tests for EncoderTeacher interface +# --------------------------------------------------------------------------- + + +class TestEncoderTeacher: + def test_forward_teacher_not_implemented(self): + teacher = EncoderTeacher.__new__(EncoderTeacher) + teacher.teacher_model = nn.Module() + teacher.postprocess_targets = {} + + with pytest.raises(NotImplementedError): + teacher.forward_teacher(None, None) + + def test_update_state_pre_backward_is_noop(self): + teacher = EncoderTeacher.__new__(EncoderTeacher) + teacher.postprocess_targets = {} + # Should not raise + teacher.update_state_pre_backward(0, None, None) + + +# --------------------------------------------------------------------------- +# Tests for EMATeacher +# --------------------------------------------------------------------------- + + +class TestEMATeacher: + def test_init_calls_reset(self): + ema = _make_mock_ema_model() + training_cfg = _make_training_cfg() + EMATeacher(nn.Module(), ema, batch_size=32, training_cfg=training_cfg) + ema.reset.assert_called_once() + + def test_reset_updates_batch_size(self): + ema = _make_mock_ema_model() + training_cfg = _make_training_cfg() + teacher = EMATeacher(nn.Module(), ema, batch_size=32, training_cfg=training_cfg) + + teacher.reset(batch_size=64) + assert teacher.batch_size == 64 + + def test_reset_without_batch_size(self): + ema = _make_mock_ema_model() + training_cfg = _make_training_cfg() + teacher = EMATeacher(nn.Module(), ema, batch_size=32, training_cfg=training_cfg) + + teacher.reset() + assert teacher.batch_size == 32 + + def test_update_state_post_opt_step_calls_ema_update(self): + ema = _make_mock_ema_model() + training_cfg = _make_training_cfg() + teacher = EMATeacher(nn.Module(), ema, batch_size=32, training_cfg=training_cfg) + + teacher.update_state_post_opt_step(istep=10, batch=None, model=None) + ema.update.assert_called_once_with(10, 32) + + def test_get_current_beta(self): + ema = _make_mock_ema_model() + training_cfg = _make_training_cfg() + teacher = EMATeacher(nn.Module(), ema, batch_size=32, training_cfg=training_cfg) + + teacher.get_current_beta(100) + ema.get_current_beta.assert_called_once_with(100, 32) + + def test_has_required_methods(self): + """EMATeacher has all required TargetAndAuxModuleBase methods.""" + assert hasattr(EMATeacher, "reset") + assert hasattr(EMATeacher, "compute") + assert hasattr(EMATeacher, "update_state_pre_backward") + assert hasattr(EMATeacher, "update_state_post_opt_step") + assert hasattr(EMATeacher, "to_device") + + def test_to_device_moves_postprocessors(self): + ema = _make_mock_ema_model() + training_cfg = _make_training_cfg() + teacher = EMATeacher(nn.Module(), ema, batch_size=32, training_cfg=training_cfg) + + # Should not raise + result = teacher.to_device("cpu") + assert result is teacher + + +# --------------------------------------------------------------------------- +# Tests for FrozenTeacher +# --------------------------------------------------------------------------- + + +class TestFrozenTeacher: + def test_freezes_all_params(self): + model = nn.Module() + model.encoder = nn.Linear(10, 20) + model.latent_heads = nn.ModuleDict() + model.latent_pre_norm = nn.LayerNorm(20) + training_cfg = _make_training_cfg() + + teacher = FrozenTeacher(model, training_cfg) + + for param in teacher.teacher_model.parameters(): + assert not param.requires_grad + + def test_eval_mode(self): + model = nn.Module() + model.encoder = nn.Linear(10, 20) + model.latent_heads = nn.ModuleDict() + model.latent_pre_norm = nn.LayerNorm(20) + training_cfg = _make_training_cfg() + + teacher = FrozenTeacher(model, training_cfg) + assert not teacher.teacher_model.training + + def test_reset_is_noop(self): + model = nn.Module() + model.encoder = nn.Linear(10, 20) + model.latent_heads = nn.ModuleDict() + model.latent_pre_norm = nn.LayerNorm(20) + training_cfg = _make_training_cfg() + + teacher = FrozenTeacher(model, training_cfg) + teacher.reset() # should not raise + teacher.reset(batch_size=64) # should not raise + + def test_update_state_post_opt_step_is_noop(self): + model = nn.Module() + model.encoder = nn.Linear(10, 20) + model.latent_heads = nn.ModuleDict() + model.latent_pre_norm = nn.LayerNorm(20) + training_cfg = _make_training_cfg() + + teacher = FrozenTeacher(model, training_cfg) + teacher.update_state_post_opt_step(istep=0, batch=None, model=None) # should not raise + + def test_forward_teacher_uses_own_params(self): + model = MagicMock() + model.parameters = MagicMock(return_value=iter([])) + model.eval = MagicMock(return_value=model) + training_cfg = _make_training_cfg() + teacher_params = MagicMock() + + teacher = FrozenTeacher(model, training_cfg, teacher_model_params=teacher_params) + + batch = MagicMock() + teacher.forward_teacher(MagicMock(), batch) + model.assert_called_once_with(teacher_params, batch) + + def test_forward_teacher_falls_back_to_student_params(self): + model = MagicMock() + model.parameters = MagicMock(return_value=iter([])) + model.eval = MagicMock(return_value=model) + training_cfg = _make_training_cfg() + + teacher = FrozenTeacher(model, training_cfg, teacher_model_params=None) + + student_params = MagicMock() + batch = MagicMock() + teacher.forward_teacher(student_params, batch) + model.assert_called_once_with(student_params, batch) + + def test_has_required_methods(self): + """FrozenTeacher has all required TargetAndAuxModuleBase methods.""" + assert hasattr(FrozenTeacher, "reset") + assert hasattr(FrozenTeacher, "compute") + assert hasattr(FrozenTeacher, "update_state_pre_backward") + assert hasattr(FrozenTeacher, "update_state_post_opt_step") + assert hasattr(FrozenTeacher, "to_device") + assert hasattr(FrozenTeacher, "from_pretrained") + + def test_from_pretrained_requires_teacher_run_id(self): + cf = OmegaConf.create({"model_path": "/tmp/claude/models"}) + with pytest.raises(KeyError, match="teacher_run_id"): + FrozenTeacher.from_pretrained(cf, None, "cpu", {}) + + +# --------------------------------------------------------------------------- +# Tests for EMAModel.get_current_beta +# --------------------------------------------------------------------------- + + +class TestEMAModelBeta: + def test_get_current_beta(self): + from weathergen.model.ema import EMAModel + + model = nn.Module() + model.p = nn.Parameter(torch.randn(3)) + empty = nn.Module() + empty.p = nn.Parameter(torch.randn(3)) + + ema = EMAModel.__new__(EMAModel) + ema.halflife_steps = 1e-3 + ema.rampup_ratio = 0.09 + + beta = ema.get_current_beta(100, 32) + assert 0.0 < beta < 1.0 + + def test_batch_size_stored_on_update(self): + """Verify that update() stores batch_size.""" + from weathergen.model.ema import EMAModel + + model = nn.Module() + model.p = nn.Parameter(torch.randn(3)) + empty = nn.Module() + empty.p = nn.Parameter(torch.randn(3)) + + ema = EMAModel(model, empty) + assert ema.batch_size == 1 + + ema.update(cur_step=10, batch_size=64) + assert ema.batch_size == 64 From f4934e14eaed905b8a0b29a73e08e4ec780a0a08 Mon Sep 17 00:00:00 2001 From: Sorcha Date: Thu, 9 Apr 2026 20:17:28 +0200 Subject: [PATCH 26/28] linting --- .../weathergen/evaluate/export/export_core.py | 20 ++++++++++++------- .../evaluate/export/parsers/verif_parser.py | 2 +- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_core.py b/packages/evaluate/src/weathergen/evaluate/export/export_core.py index bacb93212..c25b8f429 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/export_core.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_core.py @@ -51,8 +51,12 @@ def get_data_worker(args: tuple) -> tuple[int, int, xr.DataArray]: coords_arr = np.asarray(ds_group["coords"]) # (npoints, 2) times_arr = np.asarray(ds_group["times"]).astype("datetime64[ns]") # (npoints,) channels = list(ds_group.attrs["channels"]) - source_interval_start = np.asarray(ds_group.attrs["source_interval"]["start"]).astype("datetime64[ns]") - source_interval_end = np.asarray(ds_group.attrs["source_interval"]["end"]).astype("datetime64[ns]") + source_interval_start = np.asarray(ds_group.attrs["source_interval"]["start"]).astype( + "datetime64[ns]" + ) + source_interval_end = np.asarray(ds_group.attrs["source_interval"]["end"]).astype( + "datetime64[ns]" + ) # Build a lightweight xarray DataArray with the same structure # that process_sample / assign_coords expects: @@ -75,7 +79,7 @@ def get_data_worker(args: tuple) -> tuple[int, int, xr.DataArray]: "lat": ("ipoint", coords_arr[:, 0]), "lon": ("ipoint", coords_arr[:, 1]), "source_interval_start": source_interval_start, - "source_interval_end": source_interval_end + "source_interval_end": source_interval_end, }, ) @@ -318,7 +322,9 @@ def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: sample_to_batch_idx = {s: i for i, s in enumerate(batch_samples)} batch_tasks = [ - (sample, fstep, stream, data_type) for sample in batch_samples for fstep in fsteps + (sample, fstep, stream, data_type) + for sample in batch_samples + for fstep in fsteps ] _logger.info( @@ -336,7 +342,7 @@ def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: total=len(batch_tasks), desc=f" Batch {batch_idx + 1}/{n_batches}", ) - + processed_samples = [] for sample, _fstep, data in pool.imap_unordered( @@ -356,7 +362,7 @@ def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: # Free memory immediately. del sample_results[sample] batch_written += 1 - + # Only save here if need to merge samples, otherwise saved in process_sample if processed_samples[0] is not None: parser.save(processed_samples) @@ -373,4 +379,4 @@ def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: # Free any remaining refs before next batch. del sample_results - _logger.info(f"Export complete. Wrote {samples_written}/{len(samples)} samples.") \ No newline at end of file + _logger.info(f"Export complete. Wrote {samples_written}/{len(samples)} samples.") diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index 70f1d708e..19166644f 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -196,7 +196,7 @@ def reshape(self, data: xr.DataArray) -> xr.Dataset: data_vars[new_var] = xr.DataArray( data.sel(channel=old_vars).values, dims=["ipoint", "pressure_level"], - coords={"pressure_level":pls}, + coords={"pressure_level": pls}, ) else: data_vars[new_var] = xr.DataArray( From 7b4f87a80ffb66a936c73e5562427ca662fce53c Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre <90390629+rolfhm@users.noreply.github.com> Date: Fri, 17 Apr 2026 15:01:12 +0200 Subject: [PATCH 27/28] Apply suggestion from @enssow check consistent grid Co-authored-by: Sorcha Owens <73587207+enssow@users.noreply.github.com> --- .../src/weathergen/evaluate/export/parsers/verif_parser.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index 19166644f..e99ab5afb 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -113,7 +113,12 @@ def process_sample( if self.zarr_coords is None: self.zarr_coords = get_grid_points(da_fs[0]) self.zarr_dt = self.get_zarr_dt(da_fs[0]) - + # check consistency of grid points across forecast steps + if not np.array_equal(get_grid_points(da_fs[1]), self.zarr_coords): + raise ValueError( + "Grid points between forecast steps are not consistent."\ + "Check that inference was not performed with masking" + ) da_fs = self.concatenate(da_fs) da_fs = self.assign_frt(da_fs, ref_time) da_fs = self.add_attrs(da_fs) From c7a58d3c5730a9fbabc580d22abe7acf1e4144b4 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Fri, 17 Apr 2026 16:29:11 +0200 Subject: [PATCH 28/28] linting --- .../src/weathergen/evaluate/export/parsers/verif_parser.py | 2 +- .../weathergen/evaluate/io/data/dataarray_postprocessing.py | 2 +- src/weathergen/model/engines.py | 4 ++-- src/weathergen/train/target_and_aux_ssl_teacher.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py index e99ab5afb..958ca1a1f 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -116,7 +116,7 @@ def process_sample( # check consistency of grid points across forecast steps if not np.array_equal(get_grid_points(da_fs[1]), self.zarr_coords): raise ValueError( - "Grid points between forecast steps are not consistent."\ + "Grid points between forecast steps are not consistent." "Check that inference was not performed with masking" ) da_fs = self.concatenate(da_fs) diff --git a/packages/evaluate/src/weathergen/evaluate/io/data/dataarray_postprocessing.py b/packages/evaluate/src/weathergen/evaluate/io/data/dataarray_postprocessing.py index 854013948..82f22e744 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/data/dataarray_postprocessing.py +++ b/packages/evaluate/src/weathergen/evaluate/io/data/dataarray_postprocessing.py @@ -8,7 +8,7 @@ # nor does it submit to any jurisdiction. """ -Post-processing helpers for evaluation DataArrays +Post-processing helpers for evaluation DataArrays (channel selection, derived channels, lead-time). """ diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index ff8e3b510..6870bf413 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -118,10 +118,10 @@ def forward(self, batch, pe_embed): # per cell indices into positional encoding tok_counts = batch.tokens_lens.permute([2, 0, 1, 3]).sum(0).flatten() - rows = torch.arange( tok_counts.max(), device=tok_counts.device).unsqueeze(0) + rows = torch.arange(tok_counts.max(), device=tok_counts.device).unsqueeze(0) rows = rows.expand(tok_counts.shape[0], -1) pe_idxs = rows[rows < tok_counts.unsqueeze(1)] - + # actual scatter operation tokens_all.scatter_(0, scatter_idxs, torch.cat(x_embeds) + pe_embed[pe_idxs]) diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index bc89d7488..edd8e53b6 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -28,9 +28,9 @@ prepare_encoder_teacher, ) - logger = logging.getLogger(__name__) + class EncoderTeacher(TargetAndAuxModuleBase): """Base class for SSL teacher models.