diff --git a/config/evaluate/config_zarr2verif.yaml b/config/evaluate/config_zarr2verif.yaml new file mode 100644 index 000000000..8cc573f40 --- /dev/null +++ b/config/evaluate/config_zarr2verif.yaml @@ -0,0 +1,114 @@ +# List of variables to compare between WeatherGenerator output and MetNor observation files +# To add more variables use the same format for the variable using the `var` field to put the +# name of the variable in the WeatherGenerator dataset and units for each stream in `wg_unit` +# Additionally, add the chosen variable to required channels in `verif_parser.py` L92 + +variables: + 2t: + var: 2t + long: 2 meter temperature + wg_unit: {CERRA: K, + MEPS: K, + NORA3: K, + ERA5: K, + DEFAULT: K} + verif_unit: K + obs_name: air_temperature + obs_units: K + level_type: sfc + + sp: + var: sp + long: Surface pressure + wg_unit: {CERRA: Pa, + MEPS: Pa, + NORA3: Pa, + ERA5: Pa, + DEFAULT: Pa} + verif_unit: Pa + obs_name: surface_air_pressure + obs_units: Pa + level_type: sfc + + tp: + var: tp + long: Total precipitation amount + wg_unit: {CERRA: kg/m^2, + MEPS: kg/m^2, + NORA3: kg/m^2, + ERA5: m, + DEFAULT: kg/m^2} + verif_unit: kg/m^2 + obs_name: precipitation_amount_1h + obs_units: kg/m^2 + level_type: sfc + + + msl: + var: mslp + long: Mean sea level pressure + wg_unit: {CERRA: Pa, + MEPS: Pa, + NORA3: Pa, + ERA5: Pa, + DEFAULT: Pa} + verif_unit: Pa + obs_name: surface_air_pressure #check with Rolf + obs_units: Pa + level_type: sfc + + + 10si: + var: 10si # derived channel + long: wind speed + wg_unit: {CERRA: m/s, + MEPS: m/s, + NORA3: m/s, + ERA5: m/s, + DEFAULT: m/s} + verif_unit: m/s + obs_name: wind_speed + obs_units: m/s + level_type: sfc + + +coordinates: + sfc: + lat: latitude + lon: longitude + forecast_step: leadtime + forecast_reference_time: time + ncells: ncells + pl: + #not needed + pressure_level: pressure + lat: latitude + lon: longitude + forecast_step: leadtime + forecast_reference_time: time + ncells: ncells + +dimensions: + lat: +
verif: latitude + std: latitude + verif_unit: degrees_north + lon: + verif: longitude + std: longitude + verif_unit: degrees_east + pressure_level: + verif: pressure + std: pressure + verif_unit: hPa + forecast_reference_time: + verif: time + std: forecast_reference_time + forecast_step: + verif: leadtime + std: forecast_period + long: time since forecast_reference_time + verif_unit: hour + ncells: + verif: ncells + std: ncells \ No newline at end of file diff --git a/packages/evaluate/pyproject.toml b/packages/evaluate/pyproject.toml index 9674bedfc..dae7f4009 100644 --- a/packages/evaluate/pyproject.toml +++ b/packages/evaluate/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "eckitlib==1.32.3.7", "earthkit-data==0.17.0", "earthkit-utils==0.1.2" -] + ] [dependency-groups] dev = [ diff --git a/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py index 1367e4b3c..3a9e7b2e5 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py @@ -25,10 +25,8 @@ def __init__(self, config, **kwargs): grid_type : str Type of grid ('regular' or 'gaussian'). 
""" - for k, v in kwargs.items(): setattr(self, k, v) - self.config = config self.file_extension = _get_file_extension(self.output_format) self.fstep_hours = np.timedelta64(self.fstep_hours, "h") @@ -96,10 +94,12 @@ def _get_file_extension(output_format: str) -> str: """ if output_format == "netcdf": return "nc" + if output_format == "verif": + return "nc" elif output_format == "quaver": return "grib" else: raise ValueError( f"Unsupported output format: {output_format}," - "supported formats are ['netcdf', 'DWD', 'quaver']" + "supported formats are ['netcdf', 'verif', 'quaver']" ) diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_core.py b/packages/evaluate/src/weathergen/evaluate/export/export_core.py index 2bd6de36a..0679cdd75 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/export_core.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_core.py @@ -54,6 +54,12 @@ def get_data_worker(args: tuple) -> tuple[int, int, xr.DataArray]: coords_arr = np.asarray(ds_group["coords"]) # (npoints, 2) times_arr = np.asarray(ds_group["times"]).astype("datetime64[ns]") # (npoints,) channels = list(ds_group.attrs["channels"]) + source_interval_start = np.asarray(ds_group.attrs["source_interval"]["start"]).astype( + "datetime64[ns]" + ) + source_interval_end = np.asarray(ds_group.attrs["source_interval"]["end"]).astype( + "datetime64[ns]" + ) # Build a lightweight xarray DataArray with the same structure # that process_sample / assign_coords expects: @@ -75,6 +81,8 @@ def get_data_worker(args: tuple) -> tuple[int, int, xr.DataArray]: "valid_time": ("ipoint", times_arr), "lat": ("ipoint", coords_arr[:, 0]), "lon": ("ipoint", coords_arr[:, 1]), + "source_interval_start": source_interval_start, + "source_interval_end": source_interval_end, }, ) @@ -268,6 +276,13 @@ def get_ref_times(fname_zarr, stream, samples, fstep_hours, n_processes) -> list return ref_times +def get_streams(stream, fname_zarr): + with zarrio_reader(fname_zarr) as zio: 
+ zio_streams = zio.streams + streams = zio_streams if stream is None else [stream] + return streams + + def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: """ Retrieve data from Zarr store and export to the requested format. @@ -303,96 +318,105 @@ def export_model_outputs(data_type: str, config: OmegaConf, **kwargs) -> None: fname_zarr = get_model_results(run_id, epoch, rank) fsteps = get_fsteps(fsteps, fname_zarr) samples = get_samples(samples, fname_zarr) - grid_type = get_grid_type(data_type, stream, fname_zarr) - channels = get_channels(channels, stream, fname_zarr) - ref_times = get_ref_times(fname_zarr, stream, samples, fstep_hours, n_processes) - - kwargs["grid_type"] = grid_type - kwargs["channels"] = channels - kwargs["data_type"] = data_type - - parser = CfParserFactory.get_parser(config=config, **kwargs) - - n_fsteps = len(fsteps) - total_tasks = len(samples) * n_fsteps + streams = get_streams(stream, fname_zarr) + for stream in streams: + grid_type = get_grid_type(data_type, stream, fname_zarr) + stream_channels = get_channels(channels, stream, fname_zarr) + ref_times = get_ref_times(fname_zarr, stream, samples, fstep_hours, n_processes) + kwargs["grid_type"] = grid_type + kwargs["channels"] = stream_channels + kwargs["data_type"] = data_type + + parser = CfParserFactory.get_parser(config=config, **kwargs) + + n_fsteps = len(fsteps) + total_tasks = len(samples) * n_fsteps + + # Batch size in *samples*. Limits how many samples can be in-flight at once, + # bounding peak memory while still allowing read/write overlap within each batch. + batch_size = max(1, n_processes * 2) + n_batches = (len(samples) + batch_size - 1) // batch_size + + _logger.info( + f"Exporting {len(samples)} samples × {n_fsteps} fsteps " + f"({total_tasks} total tasks) in {n_batches} batch(es) of up to " + f"{batch_size} samples, using {n_processes} workers. " + f"Reading and writing are interleaved within each batch." + ) - # Batch size in *samples*.
Limits how many samples can be in-flight at once, - # bounding peak memory while still allowing read/write overlap within each batch. - batch_size = max(1, n_processes * 2) - n_batches = (len(samples) + batch_size - 1) // batch_size + # Initialise each worker with the zarr path so it is resolved only once. + with Pool( + processes=n_processes, + initializer=_init_worker, + initargs=(fname_zarr,), + ) as pool: + samples_written = 0 + + for batch_idx in range(n_batches): + batch_start = batch_idx * batch_size + batch_end = min(batch_start + batch_size, len(samples)) + batch_samples = samples[batch_start:batch_end] + batch_ref_times = ref_times[batch_start:batch_end] + + # Map sample -> index within this batch for ref_times lookup. + sample_to_batch_idx = {s: i for i, s in enumerate(batch_samples)} + + batch_tasks = [ + (sample, fstep, stream, data_type) + for sample in batch_samples + for fstep in fsteps + ] + + _logger.info( + f"Batch {batch_idx + 1}/{n_batches}: " + f"samples {batch_start}–{batch_end - 1} " + f"({len(batch_samples)} samples, {len(batch_tasks)} tasks)" + ) - _logger.info( - f"Exporting {len(samples)} samples × {n_fsteps} fsteps " - f"({total_tasks} total tasks) in {n_batches} batch(es) of up to " - f"{batch_size} samples, using {n_processes} workers. " - f"Reading and writing are interleaved within each batch." - ) + # Interleaved read/write: as soon as all fsteps for a sample + # arrive, write it immediately while workers continue reading. + sample_results: dict[int, list] = defaultdict(list) + batch_written = 0 - # Initialise each worker with the zarr path so it is resolved only once. 
- with Pool( - processes=n_processes, - initializer=_init_worker, - initargs=(fname_zarr,), - ) as pool: - samples_written = 0 - - for batch_idx in range(n_batches): - batch_start = batch_idx * batch_size - batch_end = min(batch_start + batch_size, len(samples)) - batch_samples = samples[batch_start:batch_end] - batch_ref_times = ref_times[batch_start:batch_end] - - # Map sample -> index within this batch for ref_times lookup. - sample_to_batch_idx = {s: i for i, s in enumerate(batch_samples)} - - batch_tasks = [ - (sample, fstep, stream, data_type) for sample in batch_samples for fstep in fsteps - ] - - _logger.info( - f"Batch {batch_idx + 1}/{n_batches}: " - f"samples {batch_start}–{batch_end - 1} " - f"({len(batch_samples)} samples, {len(batch_tasks)} tasks)" - ) - - # Interleaved read/write: as soon as all fsteps for a sample - # arrive, write it immediately while workers continue reading. - sample_results: dict[int, list] = defaultdict(list) - batch_written = 0 - - pbar = tqdm( - total=len(batch_tasks), - desc=f" Batch {batch_idx + 1}/{n_batches}", - ) - - for sample, _fstep, data in pool.imap_unordered( - get_data_worker, batch_tasks, chunksize=max(1, n_fsteps) - ): - sample_results[sample].append(data) - pbar.update(1) - - # Check if this sample is complete (all fsteps received). - if len(sample_results[sample]) == n_fsteps: - b_idx = sample_to_batch_idx[sample] - ref_time = batch_ref_times[b_idx] - results_iter = iter(sample_results[sample]) - parser.process_sample(results_iter, ref_time=ref_time) - - # Free memory immediately. - del sample_results[sample] - batch_written += 1 - - pbar.close() - - samples_written += batch_written - if batch_written != len(batch_samples): - _logger.error( - f"Batch {batch_idx + 1}: expected {len(batch_samples)} " - f"samples but only wrote {batch_written}. 
" - f"Incomplete: {list(sample_results.keys())}" + pbar = tqdm( + total=len(batch_tasks), + desc=f" Batch {batch_idx + 1}/{n_batches}", ) - # Free any remaining refs before next batch. - del sample_results - - _logger.info(f"Export complete. Wrote {samples_written}/{len(samples)} samples.") + processed_samples = [] + + for sample, _fstep, data in pool.imap_unordered( + get_data_worker, batch_tasks, chunksize=max(1, n_fsteps) + ): + sample_results[sample].append(data) + pbar.update(1) + + # Check if this sample is complete (all fsteps received). + if len(sample_results[sample]) == n_fsteps: + b_idx = sample_to_batch_idx[sample] + ref_time = batch_ref_times[b_idx] + results_iter = iter(sample_results[sample]) + processed = parser.process_sample(results_iter, ref_time=ref_time) + processed_samples.append(processed) + + # Free memory immediately. + del sample_results[sample] + batch_written += 1 + + # Only save here if need to merge samples, otherwise saved in process_sample + if any(p is not None for p in processed_samples): + parser.save([p for p in processed_samples if p is not None]) + pbar.close() + + samples_written += batch_written + if batch_written != len(batch_samples): + _logger.error( + f"Batch {batch_idx + 1}: expected {len(batch_samples)} " + f"samples but only wrote {batch_written}. " + f"Incomplete: {list(sample_results.keys())}" + ) + + # Free any remaining refs before next batch. + del sample_results + + _logger.info(f"Export complete.
Wrote {samples_written}/{len(samples)} samples.") diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py index 12711ee9e..d748b0068 100755 --- a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py @@ -95,7 +95,7 @@ def parse_args(args: list) -> argparse.Namespace: "--format", dest="output_format", type=str, - choices=["netcdf", "grib", "quaver"], + choices=["netcdf", "verif", "quaver"], help="Output file format (currently only netcdf supported)", required=True, ) @@ -103,9 +103,8 @@ def parse_args(args: list) -> argparse.Namespace: parser.add_argument( "--stream", type=str, - choices=["ERA5", "IMERG_ANEMOI"], + choices=["ERA5", "CERRA", "MEPS", "NORA3", "IMERG_ANEMOI"], help="Stream name to retrieve data for", - required=True, ) parser.add_argument( @@ -203,6 +202,24 @@ def parse_args(args: list) -> argparse.Namespace: help="Type of grid to regrid to (only used if --regrid-degree is specified)", ) + parser.add_argument("-b", "--obs", help="observation file for creating verif files") + + parser.add_argument( + "-m", + "--method", + default="2d", + choices=["2d", "lat_lon", "nearest"], + help="Interpolation method used for verif. 
Default: 2d", + ) + + parser.add_argument( + "--verif-template", + default="verif/%S/%V/verif_%S_%V_%M_%D.nc", + help="Template for the output nc filenames; by default creates output/verif/%%S/%%V \ + directories where %%S, %%V, %%M, %%D are replaced by the " + "stream, variable, method and data type", + ) + args, unknown_args = parser.parse_known_args(args) if unknown_args: _logger.warning(f"Unknown arguments: {unknown_args}") @@ -256,7 +273,10 @@ def export_from_args(args: list) -> None: args = parse_args(args) # Load configuration - config_file = Path(_REPO_ROOT, "config/evaluate/config_zarr2cf.yaml") + if args.output_format == "verif": + config_file = Path(_REPO_ROOT, "config/evaluate/config_zarr2verif.yaml") + else: + config_file = Path(_REPO_ROOT, "config/evaluate/config_zarr2cf.yaml") config = OmegaConf.load(config_file) # check config loaded correctly assert len(config["variables"].keys()) > 0, "Config file not loaded correctly" diff --git a/packages/evaluate/src/weathergen/evaluate/export/io_utils.py b/packages/evaluate/src/weathergen/evaluate/export/io_utils.py index 06f0cf25f..e9dd3ea52 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/io_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/export/io_utils.py @@ -1,7 +1,5 @@ import logging -from pathlib import Path -import numpy as np import xarray as xr from weathergen.common.config import get_model_results @@ -11,42 +9,6 @@ _logger.setLevel(logging.INFO) -def output_filename( - prefix: str, - run_id: str, - output_dir: str, - output_format: str, - forecast_ref_time: np.datetime64, - regrid_degree: float, -) -> Path: - """ - Generate output filename based on prefix (should refer to type e.g. pred/targ), run_id, sample - index, output directory, format and forecast_ref_time. - - Parameters - ---------- - prefix : Prefix for file name (e.g., 'pred' or 'targ'). - run_id :Run ID to include in the filename. - output_dir : Directory to save the output file.
- output_format : Output file format (currently only 'netcdf' supported). - forecast_ref_time : Forecast reference time to include in the filename. - - Returns - ------- - Full path to the output file. - """ - if output_format not in ["netcdf"]: - raise ValueError( - f"Unsupported output format: {output_format}, supported formates are ['netcdf']" - ) - file_extension = "nc" - frt = np.datetime_as_string(forecast_ref_time, unit="h") - if regrid_degree is not None: - run_id += f"_regular{regrid_degree, regrid_degree}" - out_fname = Path(output_dir) / f"{prefix}_{frt}_{run_id}.{file_extension}" - return out_fname - - def get_data_worker(args: tuple) -> xr.DataArray: """ Worker function to retrieve data for a single sample and forecast step. diff --git a/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py b/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py index d248b0c78..1d3f7e26b 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parser_factory.py @@ -3,6 +3,7 @@ from weathergen.evaluate.export.cf_utils import CfParser from weathergen.evaluate.export.parsers.netcdf_parser import NetcdfParser from weathergen.evaluate.export.parsers.quaver_parser import QuaverParser +from weathergen.evaluate.export.parsers.verif_parser import VerifParser class CfParserFactory: @@ -30,17 +31,16 @@ def get_parser(config: OmegaConf, **kwargs) -> CfParser: _parser_map = { "netcdf": (NetcdfParser, ["grid_type"]), "quaver": (QuaverParser, ["grid_type", "channels", "template"]), + "verif": (VerifParser, ["obs", "method", "verif_template"]), } fmt = kwargs.get("output_format") parser_class = _parser_map.get(fmt) - parser = parser_class[0] - # allowed_keys = parser_class[1] # filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed_keys} if parser_class is None: raise ValueError(f"Unsupported format: {fmt}") - + parser = parser_class[0] return parser(config, **kwargs) diff --git
a/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py new file mode 100644 index 000000000..958ca1a1f --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/verif_parser.py @@ -0,0 +1,652 @@ +# pylint: disable=bad-builtin + +import contextlib +import logging +from pathlib import Path +from typing import Any + +import numpy as np +import xarray as xr +from omegaconf import OmegaConf + +from weathergen.evaluate.export.cf_utils import CfParser +from weathergen.evaluate.export.preprocess import compute_mslp, compute_precip +from weathergen.evaluate.export.reshape import ( + InterpolatorFactory, + find_pl, + get_grid_points, + get_obs_coordinates, +) + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + +""" +Usage: + +uv run export --run-id wgp6fowx --stream ERA5 \ +--output-dir ../test_output1 \ +--format verif --samples 1 2 --fsteps 1 2 3 \ +--obs /p/project1/weatherai/myhre1/metno_observations_v3.nc \ +--method 2d +""" + + +class VerifParser(CfParser): + """ + Child class for handling NetCDF output format for MetNor Verif software. + """ + + def __init__(self, config: OmegaConf, **kwargs): + """ + CF-compliant parser that handles both regular and Gaussian grids. + + Parameters + ---------- + config : OmegaConf + Configuration defining variable mappings and dimension metadata. + ds : xr.Dataset + Input dataset. + + Returns + ------- + xr.Dataset + CF-compliant dataset with consistent naming and attributes. 
+ """ + for k, v in kwargs.items(): + setattr(self, k, v) + + super().__init__(config=config) + + if getattr(self, "obs", None) is None: + raise ValueError("Observation data required for creating verif compliant NetCDFs") + + self.mapping = config.get("variables", {}) + + # add extra attributes + self.obs = xr.open_dataset(self.obs) + lat, lon, _ = get_obs_coordinates(self.obs) + self.obs_coords = np.column_stack((lat.values, lon.values)) + self.zarr_coords = None + + required_channels = ["10u", "10v", "sp", "2t", "msl"] + self.channels = list(set(self.channels) & set(required_channels)) + self.zarr_dt: np.timedelta64 | None = None + + def process_sample( + self, + fstep_iterator_results: iter, + ref_time: np.datetime64, + ): + """ + Process results from get_data_worker: reshape, concatenate, add metadata, and save. + Parameters + ---------- + fstep_iterator_results : Iterator over results from get_data_worker. + ref_time : Forecast reference time for the sample. + Returns + ------- + None + """ + # check ref_time exists in the obs data + if ref_time not in self.obs.time.values: + _logger.warning( + f"Reference time {ref_time} not found in observation data. Skipping sample." + ) + return + + da_fs = [] + for result in fstep_iterator_results: + if result is None: + continue + # result is already a materialized xarray DataArray (built in the worker).
+ if not isinstance(result, xr.DataArray): + result = result.as_xarray().squeeze() + result = result.sel(channel=self.channels) + result = self.preprocess(result) + result = self.reshape(result) + da_fs.append(result) + + _logger.info(f"Retrieved {len(da_fs)} forecast steps for type {self.data_type}.") + + if da_fs: + if self.zarr_coords is None: + self.zarr_coords = get_grid_points(da_fs[0]) + self.zarr_dt = self.get_zarr_dt(da_fs[0]) + # check consistency of grid points across forecast steps + if not all(np.array_equal(get_grid_points(d), self.zarr_coords) for d in da_fs): + raise ValueError( + "Grid points between forecast steps are not consistent. " + "Check that inference was not performed with masking" + ) + da_fs = self.concatenate(da_fs) + da_fs = self.assign_frt(da_fs, ref_time) + da_fs = self.add_attrs(da_fs) + vars_to_merge = {verif_var: None for verif_var in self.mapping.keys()} + + for verif_var in self.mapping.keys(): + da_var = self.regrid(da_fs, verif_var) + if da_var is None: + continue + da_var = self.add_encoding(da_var) + obs_result = self.obs_preprocess(da_var, verif_var) + obs_result = self.add_encoding(obs_result) + merged = self.merge(da_var, obs_result) + merged = self.add_metadata(merged, verif_var) + vars_to_merge[verif_var] = merged + return vars_to_merge + + def get_zarr_dt(self, ds: xr.Dataset) -> np.timedelta64: + """ + Compute the time difference between forecast steps in hours from the WG output dataset. + Parameters + ---------- + ds : xr.Dataset + Input dataset from which to compute the time difference. + Returns + ------- + np.timedelta64 + Time difference between forecast steps in hours.
+ """ + zarr_dt = ds.source_interval_end.values - ds.source_interval_start.values + zarr_dt = zarr_dt.astype("timedelta64[h]") + + return zarr_dt + + def get_output_filename(self, variable: str) -> Path: + """ + Create output directories for the verif files + and return path to output file + Args: + variables (list[string]) + outfiles (string): template for the output files + Outputs: + None + """ + outfile = Path( + self.verif_template.replace("%S", self.stream) + .replace("%V", variable) + .replace("%M", self.method) + .replace("%D", self.data_type) + ) + outfile = Path(self.output_dir) / outfile + pathdir = outfile.parent + _logger.info(f"Output directory: {pathdir}") + pathdir.mkdir(exist_ok=True, parents=True) + return outfile + + def reshape(self, data: xr.DataArray) -> xr.Dataset: + """ + Reshape dataset while preserving grid structure (regular or Gaussian). + + Parameters + ---------- + data : xr.DataArray + Input data with dimensions (ipoint, channel) + + Returns + ------- + xr.Dataset + Reshaped dataset appropriate for the grid type + """ + grid_type = self.grid_type + + # Original logic + var_dict = find_pl(data.channel.values) + data_vars = {} + + for new_var, pls in var_dict.items(): + if pls[0] is not None: + old_vars = [f"{new_var}_{p}" for p in pls] + data_vars[new_var] = xr.DataArray( + data.sel(channel=old_vars).values, + dims=["ipoint", "pressure_level"], + coords={"pressure_level": pls}, + ) + else: + data_vars[new_var] = xr.DataArray( + data.sel(channel=new_var).values, + dims=["ipoint"], + ) + + reshaped_dataset = xr.Dataset(data_vars) + reshaped_dataset = reshaped_dataset.assign_coords( + ipoint=data.coords["ipoint"], + ) + + # order using pressure_level coord + if "pressure_level" in reshaped_dataset.coords: + reshaped_dataset = reshaped_dataset.sortby("pressure_level") + + if grid_type == "regular": + # Use original reshape logic for regular grids + # This is safe for regular grids + reshaped_dataset = reshaped_dataset.set_index( + 
ipoint=("valid_time", "lat", "lon") + ).unstack("ipoint") + else: + # Use new logic for Gaussian/unstructured grids + reshaped_dataset = reshaped_dataset.set_index(ipoint2=("ipoint", "valid_time")).unstack( + "ipoint2" + ) + # rename ipoint to ncells + reshaped_dataset = reshaped_dataset.rename_dims({"ipoint": "ncells"}) + reshaped_dataset = reshaped_dataset.rename_vars({"ipoint": "ncells"}) + + return reshaped_dataset + + def obs_preprocess(self, ds_var, verif_var: str) -> xr.DataArray: + """ + Preprocess the observation data for the given variable and valid times. + This includes computing derived variables like MSLP and total precipitation if needed. + + Parameters + ---------- + ds_var : xr.DataArray + The forecast data array to which the observation data should be regridded. + verif_var : str + Key of the variable in the config ``variables`` mapping. + + Returns + ------- + xr.DataArray + Regridded observation data matching the forecast grid. + """ + obs_data = self.obs + mapped_info = self.mapping.get(verif_var, {}) + obs_name = mapped_info.get("obs_name", {}) + + original_shape = ds_var.shape + new_shape = list(original_shape) + + obs_dataarray = np.empty(new_shape, dtype=np.float32) + + for i, leadtime in enumerate(ds_var.coords["leadtime"].values): + valid_time = ds_var.coords["time"] + np.timedelta64(int(leadtime), "h") + if verif_var in ("msl", "mslp"): + obs_dataarray[:, i, :] = compute_mslp(obs_data, valid_time) + elif verif_var == "tp": + obs_dataarray[:, i, :] = compute_precip(obs_data, self.zarr_dt, valid_time) + else: + obs_dataarray[:, i, :] = obs_data.data_vars[obs_name].sel(time=valid_time) + + obs_dataarray = ds_var.copy(data=obs_dataarray) + obs_dataarray.name = "obs" + + return obs_dataarray + + def preprocess(self, ds: xr.Dataset) -> xr.Dataset: + """ + Preprocess variables and only keep relevant ones for WG output + Parameters + ---------- + ds : xr.Dataset + + + Returns + ------- + xr.Dataset + """ + if set(["10u", "10v"]).issubset(self.channels): + u =
ds.sel(channel="10u") + v = ds.sel(channel="10v") + # hypotenuese + wind_speed = xr.apply_ufunc( + np.hypot, u, v, dask="parallelized", output_dtypes=[ds.dtype] + ).astype("float32") + wind_speed = wind_speed.expand_dims(channel=["10si"]) + if ds.chunks: + wind_speed = wind_speed.chunk( + {"ipoint": ds.chunks[ds.get_axis_num("ipoint")][0], "channel": 1} + ) + new_ds = xr.concat([ds, wind_speed], dim="channel") + new_ds.attrs = ds.attrs + + # remove unnecessary + new_ds = new_ds.drop_sel(channel=["10u", "10v"]) + return new_ds + else: + return ds + + def regrid(self, ds: xr.Dataset, verif_var: str) -> xr.Dataset: + """ + Regrid a single xarray Dataset using specific method. + Parameters + ---------- + ds: native xarray Dataset + Returns + ------- + Regridded xarray Dataset. + """ + mapped_info = self.mapping.get(verif_var, {}) + wg_var = mapped_info.get("var", None) + try: + ds_var = ds[wg_var] + except KeyError as e: + _logger.info(f"{wg_var} not available in WeatherGenerator output: {e}") + return + # set coords + # TODO: tidy this up + new_coords = { + "time": (["time"], np.atleast_1d(ds_var.coords["time"].values), ds_var["time"].attrs), + "location": ( + ["location"], + self.obs.location.values, + {"long_name": "Norwegian station ID"}, + ), + "leadtime": ( + ["leadtime"], + np.atleast_1d(ds_var.coords["leadtime"].values.astype("float32")), + ds_var["leadtime"].attrs, + ), + } + # set variable attrs + attrs = ds_var.attrs.copy() + with contextlib.suppress(KeyError): + del attrs["ncells"] # + + original_shape = ds_var.shape + new_shape = list(original_shape) + pos = ds_var.dims.index("ncells") + new_shape[pos] = self.obs.location.shape[0] + # rearrange to be time,location + order = [1, 0] + new_shape = [new_shape[x] for x in order] + + fcstdata = np.empty(new_shape, dtype=np.float32) + + # set interpolation method + method_factory = InterpolatorFactory(self.method) + interpolator = method_factory.get_interpolator(self.zarr_coords, self.obs_coords) + + 
num_leadtimes = np.atleast_1d(ds_var.coords["leadtime"].values).shape[0] + + for idx in range(num_leadtimes): + regrid_values = interpolator.interpolate(ds_var.values[:, idx]) + fcstdata[idx, :] = regrid_values + + regridded_var = xr.DataArray( + np.array([fcstdata]), + dims=["time", "leadtime", "location"], + coords={**new_coords}, + name="fcst", + attrs=attrs, + ) + return regridded_var + + def concatenate( + self, + array_list, + dim="valid_time", + data_vars="minimal", + coords="different", + compat="equals", + combine_attrs="drop", + sortby_dim="valid_time", + ) -> xr.Dataset: + """ + Uses list of pred/target xarray DataArrays to save one sample to a NetCDF file. + + Parameters + ---------- + type_str : str + Type of data ('pred' or 'targ') to include in the filename. + array_list : list of xr.DataArray + List of DataArrays to concatenate. + dim : str, optional + Dimension along which to concatenate. Default is 'valid_time'. + data_vars : str, optional + How to handle data variables during concatenation. Default is 'minimal'. + coords : str, optional + How to handle coordinates during concatenation. Default is 'different'. + compat : str, optional + Compatibility check for variables. Default is 'equals'. + combine_attrs : str, optional + How to combine attributes. Default is 'drop'. + sortby_dim : str, optional + Dimension to sort the final dataset by. Default is 'valid_time'. + + Returns + ------- + xr.Dataset + Concatenated xarray Dataset. + """ + + data = xr.concat( + array_list, + dim=dim, + data_vars=data_vars, + coords=coords, + compat=compat, + combine_attrs=combine_attrs, + ).sortby(sortby_dim) + + return data + + def assign_frt(self, ds: xr.Dataset, reference_time: np.datetime64) -> xr.Dataset: + """ + Assign forecast reference time coordinate to the dataset. + + Parameters + ---------- + ds : xarray Dataset to assign coordinates to. + reference_time : Forecast reference time to assign. 
+ + Returns + ------- + xarray Dataset with assigned forecast reference time coordinate. + """ + ds = ds.assign_coords(forecast_reference_time=reference_time) + + if "sample" in ds.coords: + ds = ds.drop_vars("sample") + n_hours = self.fstep_hours.astype("int64") + ds["forecast_step"] = ds["forecast_step"] * n_hours + return ds + + def add_attrs(self, ds: xr.Dataset) -> xr.Dataset: + """ + Add CF-compliant attributes to the dataset variables. + + Parameters + ---------- + ds : xarray Dataset to add attributes to. + Returns + ------- + xarray Dataset with CF-compliant variable attributes. + """ + variables = self._attrs_gaussian_grid(ds) + dataset = xr.merge(variables.values(), compat="no_conflicts") + return dataset + + def add_encoding(self, ds: xr.Dataset) -> xr.Dataset: + """ + Add time encoding to the dataset variables. + Add aux coordinates to leadtime + + Parameters + ---------- + ds : xarray Dataset to add time encoding to. + Returns + ------- + xarray Dataset with time encoding added. + """ + time_encoding = { + "units": "seconds since 1970-01-01 00:00:00", + "calendar": "proleptic_gregorian", + } + + if "time" in ds.coords: + ds["time"].encoding.update(time_encoding) + + if "forecast_reference_time" in ds.coords: + ds["forecast_reference_time"].encoding.update(time_encoding) + + if "leadtime" in ds.coords: + ds["leadtime"].encoding.update({"coordinates": "forecast_reference_time"}) + + return ds + + def add_metadata(self, ds: xr.Dataset, verif_var) -> xr.Dataset: + """ + Add CF conventions to the dataset attributes. + + Parameters + ---------- + ds : Input xarray Dataset to add conventions to. + Returns + ------- + xarray Dataset with CF conventions added to attributes. 
+ """ + ds.attrs["title"] = ( + f"WeatherGenerator Output for {self.run_id}, variable {verif_var} " + f"using stream {self.stream}" + ) + ds.attrs["institution"] = "WeatherGenerator Collaboration" + ds.attrs["source"] = "WeatherGenerator v0.0" + ds.attrs["history"] = "Created using the verif_parser on " + np.datetime_as_string( + np.datetime64("now"), unit="s" + ) + ds.attrs["conventions"] = "verif_1.0.0" + + return ds + + def _attrs_gaussian_grid(self, ds: xr.Dataset) -> dict[str, xr.DataArray]: + """ + Assign CF-compliant attributes to variables in a gaussian grid dataset. + Parameters + ---------- + ds : xr.Dataset + Input dataset. + Returns + ------- + dict[str, xr.DataArray] + Output variable name mapped to its DataArray with units/long_name attributes. + """ + unit_conversion = {"kg/m^2": 1.0, "Pa": 1.0, "K": 1.0, "m/s": 1.0, "m": 1000.0} + + variables = {} + dims_cfg = self.config.get("dimensions", {}) + ds, ds_attrs = self._assign_dim_attrs(ds, dims_cfg) + for var_name, da in ds.data_vars.items(): + mapped_info = self.mapping.get(var_name, {}) + mapped_name = mapped_info.get("var", var_name) + mapped_units = mapped_info.get("wg_unit", {}) + + coords = self._build_coordinate_mapping(ds, mapped_info, ds_attrs) + + wg_unit = mapped_units.get(self.stream, mapped_units.get("DEFAULT")) + verif_unit = mapped_info.get("verif_unit", None) + if wg_unit != verif_unit: + # perform unit conversion + da.values = da.values * unit_conversion[wg_unit] + + attributes = { + "units": verif_unit, + } + + if "long" in mapped_info: + attributes["long_name"] = mapped_info["long"] + variables[mapped_name] = xr.DataArray( + data=da.values, + dims=da.dims, + coords=coords, + attrs=attributes, + name=mapped_name, + ) + + return variables + + def _assign_dim_attrs( + self, ds: xr.Dataset, dim_cfg: dict[str, Any] + ) -> tuple[xr.Dataset, dict[str, dict[str, str]]]: + """ + Assign CF attributes from given config file. + Parameters + ---------- + ds : xr.Dataset + Input dataset. + dim_cfg : Dict[str, Any] + Dimension configuration from mapping.
def save(self, list_samples: list) -> None:
    """
    Concatenate the per-sample datasets of each variable and write them to NetCDF.

    For every variable in ``self.mapping`` the entries found in the samples are
    concatenated along the ``time`` dimension and written to the file name
    returned by ``self.get_output_filename``. Samples that do not contain the
    variable, or contain ``None`` for it, are ignored; a variable without any
    usable entry is skipped with a warning.

    Parameters
    ----------
        list_samples : list of dictionary containing variables to merge and save.
            Each dictionary corresponds to a sample and contains variables for that sample.

    Returns
    -------
        None
    """
    for verif_var in self.mapping:
        # Drop missing and None entries up front: xr.concat cannot handle None,
        # so a partially filled list would otherwise fail instead of being saved.
        var_list = [
            sample[verif_var]
            for sample in list_samples
            if verif_var in sample and sample[verif_var] is not None
        ]
        if not var_list:
            _logger.warning(f"No data to save for variable {verif_var}. Skipping.")
            continue
        ds = xr.concat(
            var_list, dim="time", data_vars="minimal", coords="minimal", join="exact"
        )
        out_fname = self.get_output_filename(verif_var)
        _logger.info(f"Saving to {out_fname}.")
        ds.to_netcdf(out_fname)
        _logger.info(f"Saved NetCDF file to {out_fname}.")
        _logger.info(f"Saved {verif_var} data to {self.output_format} in {self.output_dir}.")
def compute_precip(
    obs_data: xr.Dataset, zarr_dt: np.timedelta64, frt: np.datetime64
) -> np.typing.NDArray:
    """
    Compute accumulated precipitation over the forecast time step.

    The observation interval is taken from the first two entries of
    ``obs_data.time``. If it is at least as long as ``zarr_dt`` the
    ``precipitation_amount_1h`` values are returned as they are; otherwise the
    observations inside the window ending at ``frt`` are summed.

    Parameters
    ----------
        obs_data : xarray Dataset
            Input data containing precipitation observations
            (variable ``precipitation_amount_1h``, coordinates ``time`` and ``location``).
        zarr_dt : np.timedelta64
            Time difference between forecast steps in hours.
        frt : np.datetime64
            Time at which the accumulation window ends.
            NOTE(review): originally documented as the forecast reference time;
            the code subtracts ``zarr_dt`` from it - confirm against the caller.
    Returns
    -------
        np.ndarray
            Accumulated precipitation values for the forecast time step.
    """
    # Observation spacing, derived from the first two timestamps only
    # (requires at least two time entries).
    obs_dt = obs_data.time.values[1] - obs_data.time.values[0]
    obs_dt = obs_dt.astype("timedelta64[h]")

    if obs_dt >= zarr_dt:
        # Observations are not finer than the forecast step: nothing to accumulate.
        # NOTE(review): this returns the values for all times, with no selection
        # at ``frt`` - confirm that callers expect the full array here.
        return obs_data["precipitation_amount_1h"].values
    else:
        # One accumulator entry per observation location.
        accumulate = np.zeros(obs_data.location.shape[0])
        # Number of observation intervals inside one forecast step.
        int_factor = int(zarr_dt / obs_dt)

        for i in range(int_factor):
            # Walk through the window (frt - zarr_dt, frt] in steps of obs_dt.
            back_time = frt - zarr_dt + (i + 1) * obs_dt
            # NOTE(review): in-place addition of an xarray object to a numpy
            # array - verify the resulting type is what callers expect.
            accumulate += obs_data.data_vars["precipitation_amount_1h"].sel(time=back_time)
        return accumulate
def convert_coordinates(coords: np.typing.NDArray) -> np.typing.NDArray:
    """
    Map latitude/longitude pairs (in degrees) to cartesian points on the unit sphere.

    Parameters
    ----------
        coords : np.ndarray
            Array of shape (n, 2); column 0 is latitude, column 1 is longitude.
    Returns
    -------
        np.ndarray
            float32 array of shape (n, 3) holding the x, y, z components.
    """
    # Degrees -> radians, computed once per column.
    lat_rad = np.pi * coords[:, 0] / 180.0
    lon_rad = np.pi * coords[:, 1] / 180.0
    cos_lat = np.cos(lat_rad)

    n_points = coords.shape[0]
    xyz_coords = np.empty((n_points, 3), dtype="float32")
    xyz_coords[:, 0] = cos_lat * np.cos(lon_rad)
    xyz_coords[:, 1] = cos_lat * np.sin(lon_rad)
    xyz_coords[:, 2] = np.sin(lat_rad)

    return xyz_coords
class VerifLatLonInterpolator(VerifInterpolator):
    """
    Linear interpolator on a Delaunay triangulation of the grid points.

    The triangulation is built once from the grid points exactly as given
    (presumably (lat, lon) pairs - confirm against the caller) and reused for
    every call to ``interpolate``.
    """

    def __init__(self, grid_points, obs_points):
        """
        Triangulate the grid points and remember the observation points.

        Parameters
        ----------
            grid_points : np.ndarray
                Coordinates of the grid points, one row per point.
            obs_points : np.ndarray
                Coordinates of the observation points, one row per point.
        """

        self.obs_points = obs_points
        self.triangulation = Delaunay(grid_points)

    def interpolate(
        self, values: np.typing.NDArray, intmap: np.typing.NDArray = None
    ) -> np.typing.NDArray:
        """
        Interpolate grid values to the observation points.

        Parameters
        ----------
            values : np.ndarray
                Values at the grid points.
            intmap : np.ndarray, optional
                Index map; when given, grid point ``i`` takes ``values[intmap[i]]``.
        Returns
        -------
            np.ndarray
                float32 values at the observation points (scipy's
                LinearNDInterpolator yields NaN outside the convex hull).
        """

        if intmap is None:
            newvalues = values
        else:
            values = np.asarray(values)
            # Reorder in a single indexing operation instead of a Python loop;
            # as before, only the first len(values) map entries are used.
            newvalues = values[np.asarray(intmap)[: len(values)]]

        interpolator = LinearNDInterpolator(self.triangulation, newvalues)

        return interpolator(self.obs_points).astype(np.float32)
VerifLatLonInterpolator(zarr_coords, obs_coords) + + elif self.method == "nearest": + _logger.info("nearest neighbour interpolation") + return VerifNearestInterpolator(zarr_coords, obs_coords) diff --git a/packages/evaluate/src/weathergen/evaluate/io/data/dataarray_postprocessing.py b/packages/evaluate/src/weathergen/evaluate/io/data/dataarray_postprocessing.py index 854013948..82f22e744 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/data/dataarray_postprocessing.py +++ b/packages/evaluate/src/weathergen/evaluate/io/data/dataarray_postprocessing.py @@ -8,7 +8,7 @@ # nor does it submit to any jurisdiction. """ -Post-processing helpers for evaluation DataArrays +Post-processing helpers for evaluation DataArrays (channel selection, derived channels, lead-time). """ diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index ff8e3b510..6870bf413 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -118,10 +118,10 @@ def forward(self, batch, pe_embed): # per cell indices into positional encoding tok_counts = batch.tokens_lens.permute([2, 0, 1, 3]).sum(0).flatten() - rows = torch.arange( tok_counts.max(), device=tok_counts.device).unsqueeze(0) + rows = torch.arange(tok_counts.max(), device=tok_counts.device).unsqueeze(0) rows = rows.expand(tok_counts.shape[0], -1) pe_idxs = rows[rows < tok_counts.unsqueeze(1)] - + # actual scatter operation tokens_all.scatter_(0, scatter_idxs, torch.cat(x_embeds) + pe_embed[pe_idxs]) diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py index 6f41a5e24..edd8e53b6 100644 --- a/src/weathergen/train/target_and_aux_ssl_teacher.py +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -30,8 +30,6 @@ logger = logging.getLogger(__name__) -logger = logging.getLogger(__name__) - class EncoderTeacher(TargetAndAuxModuleBase): """Base class for SSL teacher models.