Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion config/evaluate/eval_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ run_ids :
color: "magenta" #optional: if not specified, the color is automatically assigned by the plotting module
results_base_dir : "./results/"
epoch: 1 #optional: if not specified epoch 0 (in inference it is always 0) is used
rank: 2 #optional: if not specified rank 0 is used
rank: "all" #optional: int, "all", or list of ints. Default: "all". Use "all" for multi-rank inference.
streams:
ERA5:
channels: ["2t", "10u", "10v"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class IOState:
lon: NDArray
n_workers: int
backend: str = "loky"
rank: str = "0000"


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -247,6 +248,7 @@ def _build_io_state(
ensemble: list[str],
n_io_workers: int,
ens_select: EnsembleSelect,
rank: str = "",
) -> IOState:
"""Resolve all I/O parameters that are shared between the two impl paths."""
zarr_path = str(fname_zarr)
Expand Down Expand Up @@ -283,6 +285,7 @@ def _build_io_state(
lat=lat,
lon=lon,
n_workers=n_io_workers,
rank=rank,
)


Expand Down Expand Up @@ -459,7 +462,8 @@ def get_data_dirstore(state: IOState) -> ReaderOutput:
``n_samples × 1 × n_ipoints × n_channels × 4 bytes``.
"""
_logger.info(
f"RUN {state.run_id} - {state.stream}: Loading {len(state.samples)} samples × "
f"RUN {state.run_id} [rank {state.rank}] - {state.stream}: "
f"Loading {len(state.samples)} samples × "
f"{len(state.fsteps)} fsteps via zarr I/O "
f"(workers={state.n_workers}, backend={state.backend})..."
)
Expand All @@ -472,7 +476,7 @@ def get_data_dirstore(state: IOState) -> ReaderOutput:

for fi, fs in enumerate(state.fsteps):
_logger.info(
f"RUN {state.run_id} - {state.stream}: "
f"RUN {state.run_id} [rank {state.rank}] - {state.stream}: "
f"Reading fstep {fs} ({fi + 1}/{len(state.fsteps)})..."
)

Expand All @@ -487,7 +491,7 @@ def get_data_dirstore(state: IOState) -> ReaderOutput:
is_gridded=state.is_gridded,
n_workers=n_workers,
backend=state.backend,
label=f"RUN {state.run_id} - {state.stream} fstep {fs}",
label=f"RUN {state.run_id} [rank {state.rank}] - {state.stream} fstep {fs}",
)
# If _parallel_read fell back to sequential, honour that for the rest
if fell_back:
Expand Down Expand Up @@ -525,7 +529,7 @@ def get_data_dirstore(state: IOState) -> ReaderOutput:
get_reusable_executor().shutdown(wait=True)

_logger.info(
f"RUN {state.run_id} - {state.stream}: I/O complete. "
f"RUN {state.run_id} [rank {state.rank}] - {state.stream}: I/O complete. "
f"{len(da_tars_dict)} forecast entries loaded."
)
return ReaderOutput(target=da_tars_dict, prediction=da_preds_dict)
Expand All @@ -546,7 +550,8 @@ def get_data_zipstore(state: IOState) -> ReaderOutput:
"""
n_total = len(state.samples) * len(state.fsteps)
_logger.info(
f"RUN {state.run_id} - {state.stream}: Loading {len(state.samples)} samples × "
f"RUN {state.run_id} [rank {state.rank}] - {state.stream}: "
f"Loading {len(state.samples)} samples × "
f"{len(state.fsteps)} fsteps = {n_total} items via ZipStore-parallel zarr I/O "
f"(workers={state.n_workers}, backend={state.backend})..."
)
Expand All @@ -569,7 +574,7 @@ def get_data_zipstore(state: IOState) -> ReaderOutput:
calls,
n_workers=state.n_workers,
backend=state.backend,
desc=f"RUN {state.run_id} - {state.stream} (ZipStore)",
desc=f"RUN {state.run_id} [rank {state.rank}] - {state.stream} (ZipStore)",
verbose=5,
)

Expand Down Expand Up @@ -633,7 +638,7 @@ def get_data_zipstore(state: IOState) -> ReaderOutput:
get_reusable_executor().shutdown(wait=True)

_logger.info(
f"RUN {state.run_id} - {state.stream}: ZipStore-parallel I/O complete. "
f"RUN {state.run_id} [rank {state.rank}] - {state.stream}: ZipStore-parallel I/O complete. "
f"{len(da_tars_dict)} forecast entries loaded."
)
return ReaderOutput(target=da_tars_dict, prediction=da_preds_dict)
Loading
Loading