From 1f3266914b162b62952a1b89a42907d5db0b72e8 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Sat, 11 Apr 2026 08:42:34 +0200 Subject: [PATCH 1/9] return streams_datasets instead of modifiying attribute. --- src/weathergen/datasets/multi_stream_data_sampler.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index e76edd28c..089bfdf94 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -196,13 +196,13 @@ def calc_baseperms(self): perms_len -= (self.fsm + self.output_offset) * (self.time_step // self.step_timedelta) self.base_perms = np.arange(perms_len) - def _init_stream_datasets(self, cf): + def _init_stream_datasets(self, cf) -> dict[StreamName, list[AnyDataReader]]: """Load dataset readers for all streams from config.""" - self.streams_datasets: dict[StreamName, list[AnyDataReader]] = {} + streams_datasets: dict[StreamName, list[AnyDataReader]] = {} for _, stream_info in enumerate(cf.streams): # list of sources for current stream - self.streams_datasets[stream_info["name"]] = [] + streams_datasets[stream_info["name"]] = [] kwargs = { "tw_handler": self.time_window_handler, @@ -259,7 +259,9 @@ def _init_stream_datasets(self, cf): else [1.0 for _ in ds.target_channels] ) - self.streams_datasets[stream_info["name"]] += [ds] + streams_datasets[stream_info["name"]] += [ds] + + return streams_datasets def reset(self): """Reset RNG, shuffle perms, compute forecast steps.""" From 3db99f59564d18a166661bc7a112dc22e858ac50 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Sat, 11 Apr 2026 09:08:42 +0200 Subject: [PATCH 2/9] return perms, perms_num_forecast_steps instead of using attributes. --- .../datasets/multi_stream_data_sampler.py | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 089bfdf94..92442f6f4 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -9,6 +9,7 @@ import logging import pathlib +from collections.abc import Sequence import numpy as np import torch @@ -130,8 +131,6 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): self.data_loader_rng_seed = rs if rs > nw else rs * 97 self.rng = None - self.perms = None - self.perms_num_forecast_steps = None def _init_forecast_cfg(self): if len(self.forecast_cfg) == 0: @@ -189,12 +188,13 @@ def check_samples(self): if not self.repeat_data: assert n_duplicates <= 0 - def calc_baseperms(self): + def _calc_baseperms(self) -> np.typing.NDArray: """This calculates the base permutation array and depends on fsm so must be repeated for __init__ and reset""" perms_len = int(self.index_range.end - self.index_range.start) perms_len -= (self.fsm + self.output_offset) * (self.time_step // self.step_timedelta) - self.base_perms = np.arange(perms_len) + + return np.arange(perms_len) def _init_stream_datasets(self, cf) -> dict[StreamName, list[AnyDataReader]]: """Load dataset readers for all streams from config.""" @@ -263,19 +263,24 @@ def _init_stream_datasets(self, cf) -> dict[StreamName, list[AnyDataReader]]: return streams_datasets - def reset(self): - """Reset RNG, shuffle perms, compute forecast steps.""" + def reset(self) -> tuple[Sequence[int], Sequence[int]]: + """ + Reset RNG, return shuffled perms adn forecast steps for this mini epoch. + + The permutation index size is proportional to self.samples_per_mini_epoch, + wheras the forecast steps index length is proportional to len(self). + + Returns: permutation index, forecast steps index + """ self.rng = np.random.default_rng(self.data_loader_rng_seed) # reset fsm for each mini epoch fsm = self.reset_fsm() if fsm != self.fsm: logger.info(f"Number of forecast steps updated from {self.fsm} to {fsm}.") self.fsm = fsm - self.check_samples() + perms = self.check_samples() self.calc_baseperms() - perms = self.base_perms.copy() - # rng changed, repeat if needed if self.repeat_data and len(perms) < self.samples_per_mini_epoch: perms = np.tile(perms, self.samples_per_mini_epoch // len(perms)) @@ -290,8 +295,6 @@ def reset(self): if self.shuffle: perms = self.rng.permutation(perms) - self.perms = perms - len_dt = len(self) // self.batch_size if self.forecast_policy is None: @@ -310,10 +313,9 @@ def reset(self): else: raise ValueError(f"Unknown forecast policy {self.forecast_policy}") - self.perms_num_forecast_steps = fs - # reset tokenizer RNG self.tokenizer.reset_rng(self.rng) + return (perms, fs) def reset_fsm(self): # fixed number of forecast steps for this run @@ -746,7 +748,7 @@ def __iter__(self) -> ModelBatch: logger.info(f"iter_start={iter_start}, iter_end={iter_end}, len={self.len}") # create new shuffeling - self.reset() + perms, perms_num_forecast_steps = self.reset() # bidx is used to count the #batches that have been emitted # idx_raw is used to index into the dataset; the decoupling is needed @@ -755,12 +757,12 @@ def __iter__(self) -> ModelBatch: for i, _bidx in enumerate(range(iter_start, iter_end, self.batch_size)): # num_forecast_steps needs to be constant per batch # (amortized through data parallel training) - num_forecast_steps = self.perms_num_forecast_steps[i] + num_forecast_steps = perms_num_forecast_steps[i] # use while loop due to the scattered nature of the data in time and to # ensure batches are not empty while True: - idx: TIndex = self.perms[idx_raw % self.perms.shape[0]] + idx: TIndex = perms[idx_raw % perms.shape[0]] idx_raw += 1 batch = self._get_batch(idx, num_forecast_steps) From e641ba9e4e632ab365196968d1201fb88b726939 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Sat, 11 Apr 2026 09:15:15 +0200 Subject: [PATCH 3/9] remove attribute self.fsm (replace by return values) --- .../datasets/multi_stream_data_sampler.py | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 92442f6f4..1af959dc5 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -101,7 +101,6 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): self._init_forecast_cfg() # initialise fsm, but can change for future mini_epochs - self.fsm = self.list_num_forecast_steps[0] self.batch_size = get_batch_size_from_config(mode_cfg) self.shuffle = mode_cfg.shuffle @@ -121,8 +120,7 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): # check samples per mini epoch self.samples_per_mini_epoch = mode_cfg.samples_per_mini_epoch - self.check_samples() - self.calc_baseperms() + self.check_samples(self._get_fsm(self.mini_epoch)) self._init_stream_datasets(cf) # RNG seed setup @@ -188,11 +186,11 @@ def check_samples(self): if not self.repeat_data: assert n_duplicates <= 0 - def _calc_baseperms(self) -> np.typing.NDArray: + def _calc_baseperms(self, fsm: int) -> np.typing.NDArray: """This calculates the base permutation array and depends on fsm so must be repeated for __init__ and reset""" perms_len = int(self.index_range.end - self.index_range.start) - perms_len -= (self.fsm + self.output_offset) * (self.time_step // self.step_timedelta) + perms_len -= (fsm + self.output_offset) * (self.time_step // self.step_timedelta) return np.arange(perms_len) @@ -273,13 +271,9 @@ def reset(self) -> tuple[Sequence[int], Sequence[int]]: Returns: permutation index, forecast steps index """ self.rng = np.random.default_rng(self.data_loader_rng_seed) - # reset fsm for each mini epoch - fsm = self.reset_fsm() - if fsm != self.fsm: - logger.info(f"Number of forecast steps updated from {self.fsm} to {fsm}.") - self.fsm = fsm - perms = self.check_samples() - self.calc_baseperms() + fsm = self._get_fsm() + self.check_samples(fsm) + perms = self._calc_baseperms(fsm) # rng changed, repeat if needed if self.repeat_data and len(perms) < self.samples_per_mini_epoch: @@ -301,12 +295,12 @@ def reset(self) -> tuple[Sequence[int], Sequence[int]]: fs = np.zeros(len_dt, dtype=np.int64) elif self.forecast_policy in ("fixed", "sequential"): - fs = self.fsm * np.ones(len_dt, dtype=np.int64) + fs = fsm * np.ones(len_dt, dtype=np.int64) elif self.forecast_policy in ("random", "sequential_random"): fs = self.rng.integers( low=self.list_num_forecast_steps.min(), - high=self.fsm + 1, + high=fsm + 1, size=len_dt, dtype=np.int64, ) @@ -317,7 +311,8 @@ def reset(self) -> tuple[Sequence[int], Sequence[int]]: self.tokenizer.reset_rng(self.rng) return (perms, fs) - def reset_fsm(self): + def _get_fsm(self) -> int: + """Obtain maximum number of forecast steps for current mini epoch.""" # fixed number of forecast steps for this run if self.forecast_policy != "random": idx = min(self.mini_epoch, len(self.list_num_forecast_steps) - 1) From d59a428a2be66304f110a5f34397fff4323633e4 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Sat, 11 Apr 2026 22:40:13 +0200 Subject: [PATCH 4/9] Inline forecast config initialization. --- .../datasets/multi_stream_data_sampler.py | 34 +++++++------------ 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 1af959dc5..e3542292e 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -40,6 +40,13 @@ logger = logging.getLogger(__name__) +FORECAST_DEFAULTS = { + "offset": 0, + "time_step": np.timedelta64(0, "ms"), + "policy": None, + "num_steps": np.array([0], dtype=np.int32), +} + def collect_datasources(stream_datasets: list, idx: int, type: str, rng) -> IOReaderData: """ @@ -97,8 +104,12 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): self.masker = Masker(cf.healpix_level, stage, self.streams, self.mode_cfg) self.tokenizer = TokenizerMasking(cf.healpix_level, self.masker) - self.forecast_cfg = mode_cfg.get("forecast", {}) - self._init_forecast_cfg() + forecast_cfg = FORECAST_DEFAULTS | mode_cfg.get("forecast", {}) + self.output_offset = forecast_cfg["offset"] + self.time_step = forecast_cfg["time_step"] + self.forecast_policy = forecast_cfg["policy"] + steps = np.array(forecast_cfg.num_steps, dtype=np.int32).reshape(-1) + self.list_num_forecast_steps = np.array(steps, dtype=np.int32) # initialise fsm, but can change for future mini_epochs self.batch_size = get_batch_size_from_config(mode_cfg) @@ -130,25 +141,6 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): self.rng = None - def _init_forecast_cfg(self): - if len(self.forecast_cfg) == 0: - self.list_num_forecast_steps = np.array([0], dtype=np.int32) - self.output_offset = 0 - self.forecast_policy = None - self.time_step = np.timedelta64(0, "ms") - return - - self.output_offset = self.forecast_cfg.get("offset", 0) - self.time_step = self.forecast_cfg.get("time_step", np.timedelta64(0, "ms")) - self.forecast_policy = self.forecast_cfg.get("policy", None) - - if isinstance(self.forecast_cfg.num_steps, int): - steps = [self.forecast_cfg.num_steps] - else: - steps = self.forecast_cfg.num_steps - - self.list_num_forecast_steps = np.array(steps, dtype=np.int32) - def check_samples(self): """Check if samples_per_mini_epoch is suitable Repeated both to initialise the MultiStreamDataSampler and for each mini epoch""" From 246519fb2aadd5d3def93613dba2006ec8ed1b29 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Sat, 11 Apr 2026 22:40:39 +0200 Subject: [PATCH 5/9] fixup init datastreams --- src/weathergen/datasets/multi_stream_data_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index e3542292e..ef2b7a8a6 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -132,7 +132,7 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): # check samples per mini epoch self.samples_per_mini_epoch = mode_cfg.samples_per_mini_epoch self.check_samples(self._get_fsm(self.mini_epoch)) - self._init_stream_datasets(cf) + self.streams_datasets = self._init_stream_datasets(cf) # RNG seed setup rs = cf.data_loading.rng_seed From 72d2d7d69047e6ad4563ef096bd1c96a159f80e6 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Sat, 11 Apr 2026 22:41:13 +0200 Subject: [PATCH 6/9] fixup: fsm no attribute --- src/weathergen/datasets/multi_stream_data_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index ef2b7a8a6..7c5a52347 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -141,7 +141,7 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): self.rng = None - def check_samples(self): + def check_samples(self, fsm: int): """Check if samples_per_mini_epoch is suitable Repeated both to initialise the MultiStreamDataSampler and for each mini epoch""" From 7a0069b2bade06ab834f1f0b8a5ddd8113a1f118 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Sat, 11 Apr 2026 22:44:31 +0200 Subject: [PATCH 7/9] correct calculation of #requested indexes #available samples --- .../datasets/multi_stream_data_sampler.py | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 7c5a52347..aa1c7f222 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -145,10 +145,15 @@ def check_samples(self, fsm: int): """Check if samples_per_mini_epoch is suitable Repeated both to initialise the MultiStreamDataSampler and for each mini epoch""" - forecast_win = self.time_step * (self.fsm + self.output_offset) # in time units - available_samples = ( - self.mode_cfg.end_date - self.mode_cfg.start_date - forecast_win - ) // self.step_timedelta + max_index = self.index_range.end - ( + ( # max time units needed to make a forecast + self.time_step * (fsm + self.output_offset) # translation due to forecasting + + self.len_timedelta # length of forecasting window + ) + // self.step_timedelta # as number of indexs + ) + + available_samples = max_index * self.batch_size # as number of samples assert available_samples > 0, ( "There is an insufficient date range to \ @@ -170,11 +175,11 @@ def check_samples(self, fsm: int): logger.info("Samples will be repeated within the time range") # streamlined calculation of length - self.len = self.samples_per_mini_epoch + epoch_len = self.samples_per_mini_epoch # adjust len to split loading across all workers and ensure it is multiple of batch_size - len_chunk = ((self.len // self.world_size) // self.batch_size) * self.batch_size - self.len = min(self.len, len_chunk) - n_duplicates = self.len - available_samples + self.len = ((epoch_len // self.world_size) // self.batch_size) * self.batch_size + + n_duplicates = self.len * self.world_size - available_samples if not self.repeat_data: assert n_duplicates <= 0 @@ -268,11 +273,12 @@ def reset(self) -> tuple[Sequence[int], Sequence[int]]: perms = self._calc_baseperms(fsm) # rng changed, repeat if needed - if self.repeat_data and len(perms) < self.samples_per_mini_epoch: - perms = np.tile(perms, self.samples_per_mini_epoch // len(perms)) + n_requested_idxs = self.samples_per_mini_epoch // self.batch_size + if self.repeat_data and len(perms) < n_requested_idxs: + perms = np.tile(perms, n_requested_idxs // len(perms)) filler = self.rng.choice( perms, - size=self.samples_per_mini_epoch - len(perms), + size=n_requested_idxs - len(perms), replace=False, ) perms = np.concatenate([perms, filler]) From c0f693cb5dc9485792c0fae90958b12256d75ad8 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Mon, 13 Apr 2026 15:41:42 +0200 Subject: [PATCH 8/9] fix forcast config initialization --- src/weathergen/datasets/multi_stream_data_sampler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index aa1c7f222..66acc1f09 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -13,6 +13,7 @@ import numpy as np import torch +from omegaconf import OmegaConf from weathergen.common.config import Config from weathergen.common.io import IOReaderData @@ -104,7 +105,7 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): self.masker = Masker(cf.healpix_level, stage, self.streams, self.mode_cfg) self.tokenizer = TokenizerMasking(cf.healpix_level, self.masker) - forecast_cfg = FORECAST_DEFAULTS | mode_cfg.get("forecast", {}) + forecast_cfg = OmegaConf.merge(FORECAST_DEFAULTS, mode_cfg.get("forecast", {})) self.output_offset = forecast_cfg["offset"] self.time_step = forecast_cfg["time_step"] self.forecast_policy = forecast_cfg["policy"] From c9cae158fa1e676f2b377e632c3f1c0b10660a31 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Tue, 14 Apr 2026 08:30:14 +0200 Subject: [PATCH 9/9] fix: unable to convert config with interpolated keys to dict. --- packages/common/src/weathergen/common/config.py | 6 +++--- src/weathergen/datasets/multi_stream_data_sampler.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index a4852f0fc..0777460aa 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -88,7 +88,7 @@ def _sanitize_start_end_time_keys(sub_conf): for key in time_keys: if key in sub_conf: raw_key = f"_{key}" - sub_conf[raw_key] = f"${{{key}}}" + sub_conf[raw_key] = sub_conf[key] sub_conf[key] = f"${{{_DATETIME_TYPE_NAME}:{sub_conf[key]}}}" @@ -98,14 +98,14 @@ def _sanitize_delta_time_keys(sub_conf): for key in delta_keys: if key in sub_conf: raw_key = f"_{key}" - sub_conf[raw_key] = f"${{{key}}}" + sub_conf[raw_key] = sub_conf[key] sub_conf[key] = f"${{{_TIMEDELTA_TYPE_NAME}:{sub_conf[key]}}}" if sub_conf.get("forecast") is not None: key = "time_step" if key in sub_conf.forecast: raw_key = f"_{key}" - sub_conf.forecast[raw_key] = f"${{{key}}}" + sub_conf.forecast[raw_key] = sub_conf.forecast[key] sub_conf.forecast[key] = f"${{{_TIMEDELTA_TYPE_NAME}:{sub_conf.forecast[key]}}}" diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 66acc1f09..5ef6e49fe 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -105,11 +105,11 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): self.masker = Masker(cf.healpix_level, stage, self.streams, self.mode_cfg) self.tokenizer = TokenizerMasking(cf.healpix_level, self.masker) - forecast_cfg = OmegaConf.merge(FORECAST_DEFAULTS, mode_cfg.get("forecast", {})) + forecast_cfg = FORECAST_DEFAULTS | OmegaConf.to_object(mode_cfg.get("forecast", {})) self.output_offset = forecast_cfg["offset"] self.time_step = forecast_cfg["time_step"] self.forecast_policy = forecast_cfg["policy"] - steps = np.array(forecast_cfg.num_steps, dtype=np.int32).reshape(-1) + steps = np.array(forecast_cfg["num_steps"], dtype=np.int32).reshape(-1) self.list_num_forecast_steps = np.array(steps, dtype=np.int32) # initialise fsm, but can change for future mini_epochs @@ -132,7 +132,7 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage): # check samples per mini epoch self.samples_per_mini_epoch = mode_cfg.samples_per_mini_epoch - self.check_samples(self._get_fsm(self.mini_epoch)) + self.check_samples(self._get_fsm()) self.streams_datasets = self._init_stream_datasets(cf) # RNG seed setup