diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index a4852f0fc..0777460aa 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -88,7 +88,7 @@ def _sanitize_start_end_time_keys(sub_conf): for key in time_keys: if key in sub_conf: raw_key = f"_{key}" - sub_conf[raw_key] = f"${{{key}}}" + sub_conf[raw_key] = sub_conf[key] sub_conf[key] = f"${{{_DATETIME_TYPE_NAME}:{sub_conf[key]}}}" @@ -98,14 +98,14 @@ def _sanitize_delta_time_keys(sub_conf): for key in delta_keys: if key in sub_conf: raw_key = f"_{key}" - sub_conf[raw_key] = f"${{{key}}}" + sub_conf[raw_key] = sub_conf[key] sub_conf[key] = f"${{{_TIMEDELTA_TYPE_NAME}:{sub_conf[key]}}}" if sub_conf.get("forecast") is not None: key = "time_step" if key in sub_conf.forecast: raw_key = f"_{key}" - sub_conf.forecast[raw_key] = f"${{{key}}}" + sub_conf.forecast[raw_key] = sub_conf.forecast[key] sub_conf.forecast[key] = f"${{{_TIMEDELTA_TYPE_NAME}:{sub_conf.forecast[key]}}}" diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index e76edd28c..5ef6e49fe 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -9,9 +9,11 @@ import logging import pathlib +from collections.abc import Sequence import numpy as np import torch +from omegaconf import OmegaConf from weathergen.common.config import Config from weathergen.common.io import IOReaderData @@ -39,6 +41,13 @@ logger = logging.getLogger(__name__) +FORECAST_DEFAULTS = { + "offset": 0, + "time_step": np.timedelta64(0, "ms"), + "policy": None, + "num_steps": np.array([0], dtype=np.int32), +} + def collect_datasources(stream_datasets: list, idx: int, type: str, rng) -> IOReaderData: """ @@ -96,11 +105,14 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): self.masker = Masker(cf.healpix_level, stage, self.streams, self.mode_cfg) self.tokenizer = TokenizerMasking(cf.healpix_level, self.masker) - self.forecast_cfg = mode_cfg.get("forecast", {}) - self._init_forecast_cfg() + forecast_cfg = FORECAST_DEFAULTS | OmegaConf.to_object(mode_cfg.get("forecast", {})) + self.output_offset = forecast_cfg["offset"] + self.time_step = forecast_cfg["time_step"] + self.forecast_policy = forecast_cfg["policy"] + steps = np.array(forecast_cfg["num_steps"], dtype=np.int32).reshape(-1) + self.list_num_forecast_steps = np.array(steps, dtype=np.int32) # initialise fsm, but can change for future mini_epochs - self.fsm = self.list_num_forecast_steps[0] self.batch_size = get_batch_size_from_config(mode_cfg) self.shuffle = mode_cfg.shuffle @@ -120,9 +132,8 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): # check samples per mini epoch self.samples_per_mini_epoch = mode_cfg.samples_per_mini_epoch - self.check_samples() - self.calc_baseperms() - self._init_stream_datasets(cf) + self.check_samples(self._get_fsm()) + self.streams_datasets = self._init_stream_datasets(cf) # RNG seed setup rs = cf.data_loading.rng_seed @@ -130,36 +141,20 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): self.data_loader_rng_seed = rs if rs > nw else rs * 97 self.rng = None - self.perms = None - self.perms_num_forecast_steps = None - - def _init_forecast_cfg(self): - if len(self.forecast_cfg) == 0: - self.list_num_forecast_steps = np.array([0], dtype=np.int32) - self.output_offset = 0 - self.forecast_policy = None - self.time_step = np.timedelta64(0, "ms") - return - - self.output_offset = self.forecast_cfg.get("offset", 0) - self.time_step = self.forecast_cfg.get("time_step", np.timedelta64(0, "ms")) - self.forecast_policy = self.forecast_cfg.get("policy", None) - - if isinstance(self.forecast_cfg.num_steps, int): - steps = [self.forecast_cfg.num_steps] - else: - steps = self.forecast_cfg.num_steps - - self.list_num_forecast_steps = np.array(steps, dtype=np.int32) - def check_samples(self): + def check_samples(self, fsm: int): """Check if samples_per_mini_epoch is suitable Repeated both to initialise the MultiStreamDataSampler and for each mini epoch""" - forecast_win = self.time_step * (self.fsm + self.output_offset) # in time units - available_samples = ( - self.mode_cfg.end_date - self.mode_cfg.start_date - forecast_win - ) // self.step_timedelta + max_index = self.index_range.end - ( + ( # max time units needed to make a forecast + self.time_step * (fsm + self.output_offset) # translation due to forecasting + + self.len_timedelta # length of forecasting window + ) + // self.step_timedelta # as number of indexs + ) + + available_samples = max_index * self.batch_size # as number of samples assert available_samples > 0, ( "There is an insufficient date range to \ @@ -181,28 +176,29 @@ def check_samples(self): logger.info("Samples will be repeated within the time range") # streamlined calculation of length - self.len = self.samples_per_mini_epoch + epoch_len = self.samples_per_mini_epoch # adjust len to split loading across all workers and ensure it is multiple of batch_size - len_chunk = ((self.len // self.world_size) // self.batch_size) * self.batch_size - self.len = min(self.len, len_chunk) - n_duplicates = self.len - available_samples + self.len = ((epoch_len // self.world_size) // self.batch_size) * self.batch_size + + n_duplicates = self.len * self.world_size - available_samples if not self.repeat_data: assert n_duplicates <= 0 - def calc_baseperms(self): + def _calc_baseperms(self, fsm: int) -> np.typing.NDArray: """This calculates the base permutation array and depends on fsm so must be repeated for __init__ and reset""" perms_len = int(self.index_range.end - self.index_range.start) - perms_len -= (self.fsm + self.output_offset) * (self.time_step // self.step_timedelta) - self.base_perms = np.arange(perms_len) + perms_len -= (fsm + self.output_offset) * (self.time_step // self.step_timedelta) - def _init_stream_datasets(self, cf): + return np.arange(perms_len) + + def _init_stream_datasets(self, cf) -> dict[StreamName, list[AnyDataReader]]: """Load dataset readers for all streams from config.""" - self.streams_datasets: dict[StreamName, list[AnyDataReader]] = {} + streams_datasets: dict[StreamName, list[AnyDataReader]] = {} for _, stream_info in enumerate(cf.streams): # list of sources for current stream - self.streams_datasets[stream_info["name"]] = [] + streams_datasets[stream_info["name"]] = [] kwargs = { "tw_handler": self.time_window_handler, @@ -259,27 +255,31 @@ def _init_stream_datasets(self, cf): else [1.0 for _ in ds.target_channels] ) - self.streams_datasets[stream_info["name"]] += [ds] + streams_datasets[stream_info["name"]] += [ds] - def reset(self): - """Reset RNG, shuffle perms, compute forecast steps.""" - self.rng = np.random.default_rng(self.data_loader_rng_seed) - # reset fsm for each mini epoch - fsm = self.reset_fsm() - if fsm != self.fsm: - logger.info(f"Number of forecast steps updated from {self.fsm} to {fsm}.") - self.fsm = fsm - self.check_samples() - self.calc_baseperms() + return streams_datasets - perms = self.base_perms.copy() + def reset(self) -> tuple[Sequence[int], Sequence[int]]: + """ + Reset RNG, return shuffled perms adn forecast steps for this mini epoch. + + The permutation index size is proportional to self.samples_per_mini_epoch, + wheras the forecast steps index length is proportional to len(self). + + Returns: permutation index, forecast steps index + """ + self.rng = np.random.default_rng(self.data_loader_rng_seed) + fsm = self._get_fsm() + self.check_samples(fsm) + perms = self._calc_baseperms(fsm) # rng changed, repeat if needed - if self.repeat_data and len(perms) < self.samples_per_mini_epoch: - perms = np.tile(perms, self.samples_per_mini_epoch // len(perms)) + n_requested_idxs = self.samples_per_mini_epoch // self.batch_size + if self.repeat_data and len(perms) < n_requested_idxs: + perms = np.tile(perms, n_requested_idxs // len(perms)) filler = self.rng.choice( perms, - size=self.samples_per_mini_epoch - len(perms), + size=n_requested_idxs - len(perms), replace=False, ) perms = np.concatenate([perms, filler]) @@ -288,32 +288,30 @@ def reset(self): if self.shuffle: perms = self.rng.permutation(perms) - self.perms = perms - len_dt = len(self) // self.batch_size if self.forecast_policy is None: fs = np.zeros(len_dt, dtype=np.int64) elif self.forecast_policy in ("fixed", "sequential"): - fs = self.fsm * np.ones(len_dt, dtype=np.int64) + fs = fsm * np.ones(len_dt, dtype=np.int64) elif self.forecast_policy in ("random", "sequential_random"): fs = self.rng.integers( low=self.list_num_forecast_steps.min(), - high=self.fsm + 1, + high=fsm + 1, size=len_dt, dtype=np.int64, ) else: raise ValueError(f"Unknown forecast policy {self.forecast_policy}") - self.perms_num_forecast_steps = fs - # reset tokenizer RNG self.tokenizer.reset_rng(self.rng) + return (perms, fs) - def reset_fsm(self): + def _get_fsm(self) -> int: + """Obtain maximum number of forecast steps for current mini epoch.""" # fixed number of forecast steps for this run if self.forecast_policy != "random": idx = min(self.mini_epoch, len(self.list_num_forecast_steps) - 1) @@ -744,7 +742,7 @@ def __iter__(self) -> ModelBatch: logger.info(f"iter_start={iter_start}, iter_end={iter_end}, len={self.len}") # create new shuffeling - self.reset() + perms, perms_num_forecast_steps = self.reset() # bidx is used to count the #batches that have been emitted # idx_raw is used to index into the dataset; the decoupling is needed @@ -753,12 +751,12 @@ def __iter__(self) -> ModelBatch: for i, _bidx in enumerate(range(iter_start, iter_end, self.batch_size)): # num_forecast_steps needs to be constant per batch # (amortized through data parallel training) - num_forecast_steps = self.perms_num_forecast_steps[i] + num_forecast_steps = perms_num_forecast_steps[i] # use while loop due to the scattered nature of the data in time and to # ensure batches are not empty while True: - idx: TIndex = self.perms[idx_raw % self.perms.shape[0]] + idx: TIndex = perms[idx_raw % perms.shape[0]] idx_raw += 1 batch = self._get_batch(idx, num_forecast_steps)