Skip to content
Merged

[2037] #2050

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion config/evaluate/config_zarr2cf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ variables:
wg_unit: Pa
std_unit: Pa
level_type: sfc
tp_imerg:
tp_imerg_0:
var: tp_imerg_0
long: imerg_total_precipitation
std: precipitation_amount
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def process_sample(
continue

result = result.as_xarray().squeeze()
if "channel" not in result.indexes:
if "channel" not in result.indexes:
result = result.expand_dims("channel")
result = result.sel(channel=self.channels)
result = self.reshape(result)
Expand Down Expand Up @@ -128,26 +128,30 @@ def reshape(self, data: xr.DataArray) -> xr.Dataset:
grid_type = self.grid_type

# Original logic
var_dict, pl = find_pl(data.channel.values)
var_dict = find_pl(data.channel.values)
data_vars = {}

for new_var, old_vars in var_dict.items():
if len(old_vars) > 1:
for new_var, pls in var_dict.items():
if pls[0] is not None:
old_vars = [f"{new_var}_{p}" for p in pls]
data_vars[new_var] = xr.DataArray(
data.sel(channel=old_vars).values,
dims=["ipoint", "pressure_level"],
coords={"pressure_level": pls},
)
else:
data_vars[new_var] = xr.DataArray(
data.sel(channel=old_vars[0]).values,
data.sel(channel=new_var).values,
dims=["ipoint"],
)

reshaped_dataset = xr.Dataset(data_vars)
reshaped_dataset = reshaped_dataset.assign_coords(
ipoint=data.coords["ipoint"],
pressure_level=pl,
)
# order using pressure_level coord
if "pressure_level" in reshaped_dataset.coords:
reshaped_dataset = reshaped_dataset.sortby("pressure_level")

if grid_type == "regular":
# Use original reshape logic for regular grids
Expand Down Expand Up @@ -274,7 +278,7 @@ def add_attrs(self, ds: xr.Dataset) -> xr.Dataset:
else:
variables = self._attrs_regular_grid(ds)

dataset = xr.merge(variables.values())
dataset = xr.merge(variables.values(), compat="no_conflicts")
dataset.attrs = ds.attrs
return dataset

Expand Down
12 changes: 6 additions & 6 deletions packages/evaluate/src/weathergen/evaluate/export/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,18 @@ def find_pl(vars: list) -> tuple[dict[str, list[str]], list[int]]:
List of unique pressure levels found in the variable names.
"""
var_dict = {}
pl = []
for var in vars:
match = re.search(r"^([a-zA-Z0-9_]+)_(\d+)$", var)
if match:
var_name = match.group(1)
pressure_level = int(match.group(2))
pl.append(pressure_level)
var_dict.setdefault(var_name, []).append(var)
if pressure_level == 0:
var_dict.setdefault(var, []).append(None)
return var_dict
var_dict.setdefault(var_name, []).append(pressure_level)
else:
var_dict.setdefault(var, []).append(var)
pl = sorted(set(pl))
return var_dict, pl
var_dict.setdefault(var, []).append(None)
return var_dict


class Regridder:
Expand Down
Loading