Skip to content
6 changes: 3 additions & 3 deletions packages/common/src/weathergen/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _sanitize_start_end_time_keys(sub_conf):
for key in time_keys:
if key in sub_conf:
raw_key = f"_{key}"
sub_conf[raw_key] = f"${{{key}}}"
sub_conf[raw_key] = sub_conf[key]
sub_conf[key] = f"${{{_DATETIME_TYPE_NAME}:{sub_conf[key]}}}"


Expand All @@ -98,14 +98,14 @@ def _sanitize_delta_time_keys(sub_conf):
for key in delta_keys:
if key in sub_conf:
raw_key = f"_{key}"
sub_conf[raw_key] = f"${{{key}}}"
sub_conf[raw_key] = sub_conf[key]
sub_conf[key] = f"${{{_TIMEDELTA_TYPE_NAME}:{sub_conf[key]}}}"

if sub_conf.get("forecast") is not None:
key = "time_step"
if key in sub_conf.forecast:
raw_key = f"_{key}"
sub_conf.forecast[raw_key] = f"${{{key}}}"
sub_conf.forecast[raw_key] = sub_conf.forecast[key]
sub_conf.forecast[key] = f"${{{_TIMEDELTA_TYPE_NAME}:{sub_conf.forecast[key]}}}"


Expand Down
132 changes: 65 additions & 67 deletions src/weathergen/datasets/multi_stream_data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@

import logging
import pathlib
from collections.abc import Sequence

import numpy as np
import torch
from omegaconf import OmegaConf

from weathergen.common.config import Config
from weathergen.common.io import IOReaderData
Expand Down Expand Up @@ -39,6 +41,13 @@

logger = logging.getLogger(__name__)

# Fallback values for the "forecast" section of a mode config. Keys supplied
# by the config take precedence when the two mappings are merged; with no
# overrides this describes zero forecast steps, no offset and no policy.
FORECAST_DEFAULTS = dict(
    offset=0,
    time_step=np.timedelta64(0, "ms"),
    policy=None,
    num_steps=np.array([0], dtype=np.int32),
)


def collect_datasources(stream_datasets: list, idx: int, type: str, rng) -> IOReaderData:
"""
Expand Down Expand Up @@ -96,11 +105,14 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage):
self.masker = Masker(cf.healpix_level, stage, self.streams, self.mode_cfg)
self.tokenizer = TokenizerMasking(cf.healpix_level, self.masker)

self.forecast_cfg = mode_cfg.get("forecast", {})
self._init_forecast_cfg()
forecast_cfg = FORECAST_DEFAULTS | OmegaConf.to_object(mode_cfg.get("forecast", {}))
self.output_offset = forecast_cfg["offset"]
self.time_step = forecast_cfg["time_step"]
self.forecast_policy = forecast_cfg["policy"]
steps = np.array(forecast_cfg["num_steps"], dtype=np.int32).reshape(-1)
self.list_num_forecast_steps = np.array(steps, dtype=np.int32)

# initialise fsm, but can change for future mini_epochs
self.fsm = self.list_num_forecast_steps[0]
self.batch_size = get_batch_size_from_config(mode_cfg)
self.shuffle = mode_cfg.shuffle

Expand All @@ -120,46 +132,29 @@ def __init__(self, cf: Config, mode_cfg: dict, stage: Stage):

# check samples per mini epoch
self.samples_per_mini_epoch = mode_cfg.samples_per_mini_epoch
self.check_samples()
self.calc_baseperms()
self._init_stream_datasets(cf)
self.check_samples(self._get_fsm())
self.streams_datasets = self._init_stream_datasets(cf)

# RNG seed setup
rs = cf.data_loading.rng_seed
nw = cf.data_loading.num_workers
self.data_loader_rng_seed = rs if rs > nw else rs * 97

self.rng = None
self.perms = None
self.perms_num_forecast_steps = None

def _init_forecast_cfg(self):
if len(self.forecast_cfg) == 0:
self.list_num_forecast_steps = np.array([0], dtype=np.int32)
self.output_offset = 0
self.forecast_policy = None
self.time_step = np.timedelta64(0, "ms")
return

self.output_offset = self.forecast_cfg.get("offset", 0)
self.time_step = self.forecast_cfg.get("time_step", np.timedelta64(0, "ms"))
self.forecast_policy = self.forecast_cfg.get("policy", None)

if isinstance(self.forecast_cfg.num_steps, int):
steps = [self.forecast_cfg.num_steps]
else:
steps = self.forecast_cfg.num_steps

self.list_num_forecast_steps = np.array(steps, dtype=np.int32)

def check_samples(self):
def check_samples(self, fsm: int):
"""Check if samples_per_mini_epoch is suitable
Repeated both to initialise the MultiStreamDataSampler and for each mini epoch"""

forecast_win = self.time_step * (self.fsm + self.output_offset) # in time units
available_samples = (
self.mode_cfg.end_date - self.mode_cfg.start_date - forecast_win
) // self.step_timedelta
max_index = self.index_range.end - (
( # max time units needed to make a forecast
self.time_step * (fsm + self.output_offset) # translation due to forecasting
+ self.len_timedelta # length of forecasting window
)
// self.step_timedelta # as number of indexs
)

available_samples = max_index * self.batch_size # as number of samples

assert available_samples > 0, (
"There is an insufficient date range to \
Expand All @@ -181,28 +176,29 @@ def check_samples(self):
logger.info("Samples will be repeated within the time range")

# streamlined calculation of length
self.len = self.samples_per_mini_epoch
epoch_len = self.samples_per_mini_epoch
# adjust len to split loading across all workers and ensure it is multiple of batch_size
len_chunk = ((self.len // self.world_size) // self.batch_size) * self.batch_size
self.len = min(self.len, len_chunk)
n_duplicates = self.len - available_samples
self.len = ((epoch_len // self.world_size) // self.batch_size) * self.batch_size

n_duplicates = self.len * self.world_size - available_samples
if not self.repeat_data:
assert n_duplicates <= 0

def calc_baseperms(self):
def _calc_baseperms(self, fsm: int) -> np.typing.NDArray:
"""This calculates the base permutation array and
depends on fsm so must be repeated for __init__ and reset"""
perms_len = int(self.index_range.end - self.index_range.start)
perms_len -= (self.fsm + self.output_offset) * (self.time_step // self.step_timedelta)
self.base_perms = np.arange(perms_len)
perms_len -= (fsm + self.output_offset) * (self.time_step // self.step_timedelta)

def _init_stream_datasets(self, cf):
return np.arange(perms_len)

def _init_stream_datasets(self, cf) -> dict[StreamName, list[AnyDataReader]]:
"""Load dataset readers for all streams from config."""
self.streams_datasets: dict[StreamName, list[AnyDataReader]] = {}
streams_datasets: dict[StreamName, list[AnyDataReader]] = {}

for _, stream_info in enumerate(cf.streams):
# list of sources for current stream
self.streams_datasets[stream_info["name"]] = []
streams_datasets[stream_info["name"]] = []

kwargs = {
"tw_handler": self.time_window_handler,
Expand Down Expand Up @@ -259,27 +255,31 @@ def _init_stream_datasets(self, cf):
else [1.0 for _ in ds.target_channels]
)

self.streams_datasets[stream_info["name"]] += [ds]
streams_datasets[stream_info["name"]] += [ds]

def reset(self):
"""Reset RNG, shuffle perms, compute forecast steps."""
self.rng = np.random.default_rng(self.data_loader_rng_seed)
# reset fsm for each mini epoch
fsm = self.reset_fsm()
if fsm != self.fsm:
logger.info(f"Number of forecast steps updated from {self.fsm} to {fsm}.")
self.fsm = fsm
self.check_samples()
self.calc_baseperms()
return streams_datasets

perms = self.base_perms.copy()
def reset(self) -> tuple[Sequence[int], Sequence[int]]:
"""
Reset RNG, return shuffled perms and forecast steps for this mini epoch.

The permutation index size is proportional to self.samples_per_mini_epoch,
whereas the forecast steps index length is proportional to len(self).

Returns: permutation index, forecast steps index
"""
self.rng = np.random.default_rng(self.data_loader_rng_seed)
fsm = self._get_fsm()
self.check_samples(fsm)
perms = self._calc_baseperms(fsm)

# rng changed, repeat if needed
if self.repeat_data and len(perms) < self.samples_per_mini_epoch:
perms = np.tile(perms, self.samples_per_mini_epoch // len(perms))
n_requested_idxs = self.samples_per_mini_epoch // self.batch_size
if self.repeat_data and len(perms) < n_requested_idxs:
perms = np.tile(perms, n_requested_idxs // len(perms))
filler = self.rng.choice(
perms,
size=self.samples_per_mini_epoch - len(perms),
size=n_requested_idxs - len(perms),
replace=False,
)
perms = np.concatenate([perms, filler])
Expand All @@ -288,32 +288,30 @@ def reset(self):
if self.shuffle:
perms = self.rng.permutation(perms)

self.perms = perms

len_dt = len(self) // self.batch_size

if self.forecast_policy is None:
fs = np.zeros(len_dt, dtype=np.int64)

elif self.forecast_policy in ("fixed", "sequential"):
fs = self.fsm * np.ones(len_dt, dtype=np.int64)
fs = fsm * np.ones(len_dt, dtype=np.int64)

elif self.forecast_policy in ("random", "sequential_random"):
fs = self.rng.integers(
low=self.list_num_forecast_steps.min(),
high=self.fsm + 1,
high=fsm + 1,
size=len_dt,
dtype=np.int64,
)
else:
raise ValueError(f"Unknown forecast policy {self.forecast_policy}")

self.perms_num_forecast_steps = fs

# reset tokenizer RNG
self.tokenizer.reset_rng(self.rng)
return (perms, fs)

def reset_fsm(self):
def _get_fsm(self) -> int:
"""Obtain maximum number of forecast steps for current mini epoch."""
# fixed number of forecast steps for this run
if self.forecast_policy != "random":
idx = min(self.mini_epoch, len(self.list_num_forecast_steps) - 1)
Expand Down Expand Up @@ -744,7 +742,7 @@ def __iter__(self) -> ModelBatch:
logger.info(f"iter_start={iter_start}, iter_end={iter_end}, len={self.len}")

# create new shuffling
self.reset()
perms, perms_num_forecast_steps = self.reset()

# bidx is used to count the #batches that have been emitted
# idx_raw is used to index into the dataset; the decoupling is needed
Expand All @@ -753,12 +751,12 @@ def __iter__(self) -> ModelBatch:
for i, _bidx in enumerate(range(iter_start, iter_end, self.batch_size)):
# num_forecast_steps needs to be constant per batch
# (amortized through data parallel training)
num_forecast_steps = self.perms_num_forecast_steps[i]
num_forecast_steps = perms_num_forecast_steps[i]

# use while loop due to the scattered nature of the data in time and to
# ensure batches are not empty
while True:
idx: TIndex = self.perms[idx_raw % self.perms.shape[0]]
idx: TIndex = perms[idx_raw % perms.shape[0]]
idx_raw += 1

batch = self._get_batch(idx, num_forecast_steps)
Expand Down
Loading