diff --git a/config/evaluate/eval_config.yml b/config/evaluate/eval_config.yml index a7415052b..feb3ae7e9 100644 --- a/config/evaluate/eval_config.yml +++ b/config/evaluate/eval_config.yml @@ -99,7 +99,7 @@ run_ids : color: "magenta" #optional: if not specified, the color is automatically assigned by the plotting module results_base_dir : "./results/" epoch: 1 #optional: if not specified epoch 0 (in inference it is always 0) is used - rank: 2 #optional: if not specified rank 0 is used + rank: "all" #optional: int, "all", or list of ints. Default: "all". Use "all" for multi-rank inference. streams: ERA5: channels: ["2t", "10u", "10v"] diff --git a/packages/evaluate/src/weathergen/evaluate/io/data/io_orchestration.py b/packages/evaluate/src/weathergen/evaluate/io/data/io_orchestration.py index 40a317024..22a0320e9 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/data/io_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/io/data/io_orchestration.py @@ -77,6 +77,7 @@ class IOState: lon: NDArray n_workers: int backend: str = "loky" + rank: str = "0000" # --------------------------------------------------------------------------- @@ -247,6 +248,7 @@ def _build_io_state( ensemble: list[str], n_io_workers: int, ens_select: EnsembleSelect, + rank: str = "", ) -> IOState: """Resolve all I/O parameters that are shared between the two impl paths.""" zarr_path = str(fname_zarr) @@ -283,6 +285,7 @@ def _build_io_state( lat=lat, lon=lon, n_workers=n_io_workers, + rank=rank, ) @@ -459,7 +462,8 @@ def get_data_dirstore(state: IOState) -> ReaderOutput: ``n_samples × 1 × n_ipoints × n_channels × 4 bytes``. """ _logger.info( - f"RUN {state.run_id} - {state.stream}: Loading {len(state.samples)} samples × " + f"RUN {state.run_id} [rank {state.rank}] - {state.stream}: " + f"Loading {len(state.samples)} samples × " f"{len(state.fsteps)} fsteps via zarr I/O " f"(workers={state.n_workers}, backend={state.backend})..." 
) @@ -472,7 +476,7 @@ def get_data_dirstore(state: IOState) -> ReaderOutput: for fi, fs in enumerate(state.fsteps): _logger.info( - f"RUN {state.run_id} - {state.stream}: " + f"RUN {state.run_id} [rank {state.rank}] - {state.stream}: " f"Reading fstep {fs} ({fi + 1}/{len(state.fsteps)})..." ) @@ -487,7 +491,7 @@ def get_data_dirstore(state: IOState) -> ReaderOutput: is_gridded=state.is_gridded, n_workers=n_workers, backend=state.backend, - label=f"RUN {state.run_id} - {state.stream} fstep {fs}", + label=f"RUN {state.run_id} [rank {state.rank}] - {state.stream} fstep {fs}", ) # If _parallel_read fell back to sequential, honour that for the rest if fell_back: @@ -525,7 +529,7 @@ def get_data_dirstore(state: IOState) -> ReaderOutput: get_reusable_executor().shutdown(wait=True) _logger.info( - f"RUN {state.run_id} - {state.stream}: I/O complete. " + f"RUN {state.run_id} [rank {state.rank}] - {state.stream}: I/O complete. " f"{len(da_tars_dict)} forecast entries loaded." ) return ReaderOutput(target=da_tars_dict, prediction=da_preds_dict) @@ -546,7 +550,8 @@ def get_data_zipstore(state: IOState) -> ReaderOutput: """ n_total = len(state.samples) * len(state.fsteps) _logger.info( - f"RUN {state.run_id} - {state.stream}: Loading {len(state.samples)} samples × " + f"RUN {state.run_id} [rank {state.rank}] - {state.stream}: " + f"Loading {len(state.samples)} samples × " f"{len(state.fsteps)} fsteps = {n_total} items via ZipStore-parallel zarr I/O " f"(workers={state.n_workers}, backend={state.backend})..." 
) @@ -569,7 +574,7 @@ def get_data_zipstore(state: IOState) -> ReaderOutput: calls, n_workers=state.n_workers, backend=state.backend, - desc=f"RUN {state.run_id} - {state.stream} (ZipStore)", + desc=f"RUN {state.run_id} [rank {state.rank}] - {state.stream} (ZipStore)", verbose=5, ) @@ -633,7 +638,7 @@ def get_data_zipstore(state: IOState) -> ReaderOutput: get_reusable_executor().shutdown(wait=True) _logger.info( - f"RUN {state.run_id} - {state.stream}: ZipStore-parallel I/O complete. " + f"RUN {state.run_id} [rank {state.rank}] - {state.stream}: ZipStore-parallel I/O complete. " f"{len(da_tars_dict)} forecast entries loaded." ) return ReaderOutput(target=da_tars_dict, prediction=da_preds_dict) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index d9ff6ca5d..d1f0cd29d 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -382,37 +382,163 @@ def get_recomputable_metrics(self, metrics): class WeatherGenZarrReader(WeatherGenReader): def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): - """Data reader class for WeatherGenerator model outputs stored in Zarr format.""" + """Data reader class for WeatherGenerator model outputs stored in Zarr format. + + Supports multi-rank inference outputs where each rank file contains a disjoint + subset of forecast initializations with overlapping local sample indices. + """ super().__init__(eval_cfg, run_id, private_paths) zarr_ext = self.inference_cfg.get("zarr_store", "zarr") - # For backwards compatibility, assume zarr store is local (.zarr format). 
def _discover_rank_files(self) -> list[Path]:
    """Discover zarr rank files based on the ``rank`` config parameter.

    The value is read from ``self.eval_cfg["rank"]`` and falls back to
    ``self.rank`` when the key is absent.

    Supports:
    - ``rank: 0`` (int) — single specific rank (backward compatible)
    - ``rank: "all"`` — glob all matching rank files
    - ``rank: [0, 1, 2]`` — any non-string sequence of ranks; duplicates are
      dropped and files are returned in ascending rank order

    Returns
    -------
    list[Path]
        The discovered files, passed through ``self._validate_rank_files``.

    Raises
    ------
    FileNotFoundError
        If a requested rank file is missing, or ``"all"`` matches nothing.
    ValueError
        If the config value is a bool, an empty sequence, a string other
        than ``"all"``, or any other unsupported type.
    """
    # Function-scope import so the method does not depend on a module-level import.
    from collections.abc import Sequence

    rank_cfg = self.eval_cfg.get("rank", self.rank)
    invalid_msg = f"Invalid rank config: {rank_cfg!r}. Use an int, 'all', or a list of ints."

    def _existing_rank_file(rank: int) -> Path:
        """Return the path of the file for ``rank``; raise if it does not exist."""
        fname = self.results_dir / (
            f"validation_chkpt{self.mini_epoch:05d}_rank{rank:04d}.{self.zarr_ext}"
        )
        if not fname.exists():
            raise FileNotFoundError(f"Zarr file {fname} does not exist.")
        return fname

    # bool is a subclass of int: without this guard ``rank: true`` would
    # silently be interpreted as rank 1.
    if isinstance(rank_cfg, bool):
        raise ValueError(invalid_msg)

    if isinstance(rank_cfg, int):
        # Single rank (backward compatible)
        return self._validate_rank_files([_existing_rank_file(rank_cfg)])

    if isinstance(rank_cfg, str):
        if rank_cfg != "all":
            raise ValueError(invalid_msg)
        pattern = f"validation_chkpt{self.mini_epoch:05d}_rank*.{self.zarr_ext}"
        files = sorted(self.results_dir.glob(pattern))
        if not files:
            raise FileNotFoundError(f"No zarr files matching {pattern} in {self.results_dir}")
        _logger.info(f"Discovered {len(files)} rank file(s) for run {self.run_id}.")
        return self._validate_rank_files(files)

    # Accept every non-string sequence rather than only list/tuple, so that
    # sequence-like config containers that do not subclass list are handled too.
    if isinstance(rank_cfg, Sequence):
        try:
            # A set removes duplicates: the same file listed twice would be
            # loaded twice and its samples counted twice downstream.
            ranks = sorted({int(r) for r in rank_cfg})
        except (TypeError, ValueError) as err:
            raise ValueError(invalid_msg) from err
        if not ranks:
            # Fail fast instead of returning an empty file list.
            raise ValueError(invalid_msg)
        return self._validate_rank_files([_existing_rank_file(r) for r in ranks])

    raise ValueError(invalid_msg)
def _open_any_rank_for_metadata(self):
    """Open a rank file for metadata queries, trying each rank until one succeeds.

    Iterates ``self.rank_files`` in order and returns the result of
    ``zarrio_reader(rank_file)`` for the first file for which that call does
    not raise. The caller is expected to use the returned object in a
    ``with`` statement.

    Raises
    ------
    RuntimeError
        If ``self.rank_files`` is empty or every ``zarrio_reader`` call
        raised. The last underlying error, if any, is attached as the cause.
    """
    # NOTE(review): if zarrio_reader only opens the store on ``__enter__``,
    # the call below cannot fail and this fallback never triggers — confirm
    # against zarrio_reader's implementation.
    last_error: Exception | None = None
    for rank_file in self.rank_files:
        try:
            return zarrio_reader(rank_file)
        except Exception as err:
            # Deliberately broad: any failure to open one rank should move on
            # to the next. Keep the error so it is neither lost from the log
            # nor from the final exception.
            last_error = err
            _logger.warning(
                f"Failed to open {rank_file.name} for metadata ({err!r}), trying next..."
            )
    raise RuntimeError("No rank files could be opened for metadata queries.") from last_error
def _merge_fsteps(self, all_das: dict, global_sample_coords) -> dict:
    """Combine the per-rank arrays of every forecast step into one array each.

    Parameters
    ----------
    all_das :
        Mapping of forecast step to the list of arrays collected for it.
    global_sample_coords :
        Sliceable sequence of sample labels. The first ``len(array.sample)``
        entries become the ``sample`` coordinate of each combined array.

    Returns
    -------
    dict
        Mapping of forecast step to the combined, relabelled array.
    """
    result = {}
    for step in all_das:
        per_rank = all_das[step]
        # A single array needs no concatenation; several are joined along "sample".
        if len(per_rank) > 1:
            stacked = xr.concat(per_rank, dim="sample")
        else:
            stacked = per_rank[0]
        n_samples = len(stacked.sample)
        result[step] = stacked.assign_coords(sample=global_sample_coords[:n_samples])
    return result
+ Parameters ---------- stream : str @@ -437,76 +566,126 @@ def get_data( """ resolved_ensemble = to_list(ensemble or self.get_ensemble(stream)) ens_select = EnsembleSelect.from_names(resolved_ensemble, self.get_ensemble(stream)) - state = _build_io_state( - self.run_id, - self.fname_zarr, - stream, - self.get_stream(stream), - self.get_channels(stream), - self.is_gridded_data(stream), - sorted(int(f) for f in (fsteps or self.get_forecast_steps())), - sorted(int(s) for s in (samples or self.get_samples())), - to_list(channels or self.get_stream(stream).get("channels", self.get_channels(stream))), - resolved_ensemble, - self._num_io_workers, - ens_select, + resolved_fsteps = sorted(int(f) for f in (fsteps or self.get_forecast_steps())) + resolved_channels = to_list( + channels or self.get_stream(stream).get("channels", self.get_channels(stream)) ) - get_data = get_data_zipstore if state.is_zip else get_data_dirstore - return get_data(state) - def get_stream(self, stream: str): - """ - returns the dictionary associated to a particular stream. - Returns an empty dictionary if the stream does not exist in the Zarr file. 
+ rank_sample_map = self._get_rank_sample_map() - Parameters - ---------- - stream: - the stream name + # Determine which ranks to load based on requested global samples + requested_globals = set(int(s) for s in (samples or self.get_samples())) - Returns - ------- - The config dictionary associated to that stream - """ - if self._cached_streams is None: - with zarrio_reader(self.fname_zarr) as zio: - self._cached_streams = set(zio.streams) + all_targets: dict[int, list[xr.DataArray]] = {} + all_predictions: dict[int, list[xr.DataArray]] = {} + ranks_loaded = 0 + + for rank_file in self.rank_files: + local_samples, global_offset = rank_sample_map[rank_file] + + # Check if any of this rank's global samples are requested + rank_globals = set(range(global_offset, global_offset + len(local_samples))) + if not rank_globals & requested_globals: + continue + + # Map requested global indices back to local indices for this rank + rank_local_to_load = [ + local_samples[g - global_offset] for g in sorted(rank_globals & requested_globals) + ] + + _logger.info( + f"RUN {self.run_id} [rank {rank_file.stem.split('rank')[-1]}]: " + f"Loading {len(rank_local_to_load)} samples" + ) + _logger.debug( + f"RUN {self.run_id} [rank {rank_file.stem.split('rank')[-1]}]: " + f"local indices {rank_local_to_load}, " + f"global samples {sorted(rank_globals & requested_globals)}" + ) + + state = _build_io_state( + self.run_id, + rank_file, + stream, + self.get_stream(stream), + self.get_channels(stream), + self.is_gridded_data(stream), + resolved_fsteps, + rank_local_to_load, + resolved_channels, + resolved_ensemble, + self._num_io_workers, + ens_select, + rank=rank_file.stem.split("rank")[-1], + ) + get_data_fn = get_data_zipstore if state.is_zip else get_data_dirstore + result = get_data_fn(state) + + for fstep, da in result.target.items(): + all_targets.setdefault(fstep, []).append(da) + for fstep, da in result.prediction.items(): + all_predictions.setdefault(fstep, []).append(da) + 
def get_stream(self, stream: str):
    """Return the evaluation-config entry for ``stream``.

    The name is looked up in ``self._validated_metadata["streams"]``. When it
    is present, the result of ``self.eval_cfg.streams.get(stream, {})`` is
    returned; otherwise an empty dictionary is returned.
    """
    known_streams = self._validated_metadata["streams"]
    if stream not in known_streams:
        return {}
    return self.eval_cfg.streams.get(stream, {})
def get_forecast_substep_valid_times(self, stream: str) -> set[str]:
    """Return the distinct valid times of one probe item of ``stream`` as strings.

    For a stream that ``self.is_gridded_data`` reports as not gridded, a
    warning is logged and an empty set is returned. Otherwise the first
    sample and first forecast step of the reader obtained from
    ``self._open_any_rank_for_metadata()`` are loaded and the unique entries
    of ``valid_time.data`` are returned, each converted with ``str``.
    """
    if self.is_gridded_data(stream):
        with self._open_any_rank_for_metadata() as reader:
            first_sample = reader.samples[0]
            first_fstep = reader.forecast_steps[0]
            probe = reader.get_data(first_sample, stream, first_fstep)
            # Materialise the unique times while the reader is still open.
            lead_times = np.unique(probe.valid_time.data)
        return {str(lead) for lead in lead_times}

    _logger.warning(f"Stream {stream} is not gridded. Forecast times cannot be retrieved.")
    return set()
Returns ------- @@ -515,18 +694,18 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: _logger.debug(f"Getting ensembles for stream {stream}...") if stream not in self._cached_ensemble: - # TODO: improve this to get ensemble from io class - with zarrio_reader(self.fname_zarr) as zio: - dummy = zio.get_data(0, stream, zio.forecast_steps[0]) + with self._open_any_rank_for_metadata() as zio: + dummy = zio.get_data(zio.samples[0], stream, zio.forecast_steps[0]) self._cached_ensemble[stream] = list(dummy.prediction.as_xarray().coords["ens"].values) return self._cached_ensemble[stream] def is_gridded_data(self, stream: str) -> bool: - """Check if the latitude and longitude coordinates are regularly spaced for a given stream. + """Check if lat/lon coordinates are regularly spaced for a given stream. + Parameters ---------- stream : - The name of the stream to get channels for. + The name of the stream to check. Returns ------- @@ -540,8 +719,8 @@ def _compute_is_gridded(self, stream: str) -> bool: """is_gridded_data logic, called once per stream and cached.""" _logger.debug(f"Checking regular spacing for stream {stream}...") - with zarrio_reader(self.fname_zarr) as zio: - dummy = zio.get_data(0, stream, zio.forecast_steps[0]) + with self._open_any_rank_for_metadata() as zio: + dummy = zio.get_data(zio.samples[0], stream, zio.forecast_steps[0]) sample_idx = zio.samples[1] if len(zio.samples) > 1 else zio.samples[0] fstep_idx = (