diff --git a/config/evaluate/config_zarr2cf.yaml b/config/evaluate/config_zarr2cf.yaml index 95bddd21b..431896336 100644 --- a/config/evaluate/config_zarr2cf.yaml +++ b/config/evaluate/config_zarr2cf.yaml @@ -89,7 +89,7 @@ variables: wg_unit: Pa std_unit: Pa level_type: sfc - tp_imerg: + tp_imerg_0: var: tp_imerg_0 long: imerg_total_precipitation std: precipitation_amount diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py index 56a487c2f..a0373b520 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py @@ -73,7 +73,7 @@ def process_sample( continue result = result.as_xarray().squeeze() - if "channel" not in result.indexes: + if "channel" not in result.indexes: result = result.expand_dims("channel") result = result.sel(channel=self.channels) result = self.reshape(result) @@ -128,26 +128,30 @@ def reshape(self, data: xr.DataArray) -> xr.Dataset: grid_type = self.grid_type # Original logic - var_dict, pl = find_pl(data.channel.values) + var_dict = find_pl(data.channel.values) data_vars = {} - for new_var, old_vars in var_dict.items(): - if len(old_vars) > 1: + for new_var, pls in var_dict.items(): + if pls[0] is not None: + old_vars = [f"{new_var}_{p}" for p in pls] data_vars[new_var] = xr.DataArray( data.sel(channel=old_vars).values, dims=["ipoint", "pressure_level"], + coords={"pressure_level": pls}, ) else: data_vars[new_var] = xr.DataArray( - data.sel(channel=old_vars[0]).values, + data.sel(channel=new_var).values, dims=["ipoint"], ) reshaped_dataset = xr.Dataset(data_vars) reshaped_dataset = reshaped_dataset.assign_coords( ipoint=data.coords["ipoint"], - pressure_level=pl, ) + # order using pressure_level coord + if "pressure_level" in reshaped_dataset.coords: + reshaped_dataset = reshaped_dataset.sortby("pressure_level") if grid_type == "regular": # Use original reshape logic for regular grids @@ -274,7 +278,7 @@ def add_attrs(self, ds: xr.Dataset) -> xr.Dataset: else: variables = self._attrs_regular_grid(ds) - dataset = xr.merge(variables.values()) + dataset = xr.merge(variables.values(), compat="no_conflicts") dataset.attrs = ds.attrs return dataset diff --git a/packages/evaluate/src/weathergen/evaluate/export/reshape.py b/packages/evaluate/src/weathergen/evaluate/export/reshape.py index f32631ea3..74339bdad 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/reshape.py +++ b/packages/evaluate/src/weathergen/evaluate/export/reshape.py @@ -68,18 +68,18 @@ def find_pl(vars: list) -> tuple[dict[str, list[str]], list[int]]: List of unique pressure levels found in the variable names. """ var_dict = {} - pl = [] for var in vars: match = re.search(r"^([a-zA-Z0-9_]+)_(\d+)$", var) if match: var_name = match.group(1) pressure_level = int(match.group(2)) - pl.append(pressure_level) - var_dict.setdefault(var_name, []).append(var) + if pressure_level == 0: + var_dict.setdefault(var, []).append(None) + return var_dict + var_dict.setdefault(var_name, []).append(pressure_level) else: - var_dict.setdefault(var, []).append(var) - pl = sorted(set(pl)) - return var_dict, pl + var_dict.setdefault(var, []).append(None) + return var_dict class Regridder: