From a16421562de610d2966ad975a3ec97c17d92e3f9 Mon Sep 17 00:00:00 2001 From: evenmn Date: Tue, 24 Feb 2026 14:36:46 +0100 Subject: [PATCH 1/3] Implemented step_callback for incremental output of forecast outputs Signed-off-by: evenmn --- config/default_config.yml | 8 + src/weathergen/model/model.py | 13 +- src/weathergen/train/trainer.py | 39 ++- src/weathergen/utils/validation_io.py | 336 +++++++++++++++++--------- 4 files changed, 270 insertions(+), 126 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 30dd87afd..4214ffeaf 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -220,6 +220,14 @@ validation_config: normalized_samples: False, # output streams to write; default all streams: null, + # streaming output configuration for inference + streaming: { + # enable streaming mode: write output after each N forecast steps instead of all at once + enabled: False, + # number of forecast steps to process before writing to disk + # if null, writes after each step (num_steps=1) + num_steps: null, + } } # run validation before training starts (mainly for model development) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 273c28838..6c13fceef 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -614,13 +614,18 @@ def tokens_to_latent_state(self, tokens_post_norm, tokens) -> LatentState: z_pre_norm=tokens, ) - def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: + def forward(self, model_params: ModelParams, batch: ModelBatch, step_callback=None) -> ModelOutput: """Forward pass of the model Tokens are processed through the model components, which were defined in the create method. Args: model_params : Query and embedding parameters - batch + batch : Batch of data + step_callback : Optional callback function to be called after each forecast step. + Called as step_callback(step, output) where step is the forecast step + index and output is the partial ModelOutput for that step. + Used for streaming output writing - allows writing after each step + instead of accumulating all steps in memory. Returns: A list containing all prediction results """ @@ -645,6 +650,10 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: output = self.predict_decoders(model_params, step, tokens, batch, output) # latent predictions (raw and with SSL heads) output = self.predict_latent(model_params, step, tokens, batch, output) + + # invoke callback for streaming output if provided + if step_callback is not None: + step_callback(step, output) return output diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index bfa289c7b..8caa3d770 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -563,6 +563,29 @@ def validate(self, mini_epoch, mode_cfg, batch_size): batch.to_device(self.device) + # denormalization function for data + denormalize_data_fct = ( + (lambda x0, x1: x1) + if mode_cfg.get("output", {}).get("normalized_samples", False) + else self.dataset_val.denormalize_target_channels + ) + + # Prepare for streaming or standard write + step_callback = None + streaming_writer = None + if bidx < num_samples_write: + step_callback, streaming_writer = write_output( + self.cf, + mode_cfg, + batch_size, + mini_epoch, + bidx, + denormalize_data_fct, + batch, + None, # model_output (will be computed below) + self.target_and_aux_calculators_val, + ) + # evaluate model with torch.autocast( device_type=f"cuda:{cf.local_rank}", @@ -573,11 +596,13 @@ def validate(self, mini_epoch, mode_cfg, batch_size): preds = self.model( self.model_params, batch.get_source_samples(), + step_callback=step_callback, ) else: preds = self.ema_model.forward_eval( self.model_params, batch.get_source_samples(), + step_callback=step_callback, ) targets_and_auxs = {} @@ -596,15 +621,11 @@ def validate(self, mini_epoch, mode_cfg, batch_size): metadata=extract_batch_metadata(batch), ) - # log output - if bidx < num_samples_write: - # denormalization function for data - denormalize_data_fct = ( - (lambda x0, x1: x1) - if mode_cfg.get("output", {}).get("normalized_samples", False) - else self.dataset_val.denormalize_target_channels - ) - # write output + # For streaming mode, flush remaining steps + if streaming_writer is not None: + streaming_writer.flush() + # For non-streaming mode, now write the complete output + elif bidx < num_samples_write: write_output( self.cf, mode_cfg, diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 0e09fd38d..9e85d7efb 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -20,98 +20,12 @@ _logger = logging.getLogger(__name__) -def write_output( - cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out -): - """ - Interface for writing model output - """ - - # TODO: how to handle multiple physical loss terms - outputs_physical = [ - loss_name - for i, (loss_name, loss_term) in enumerate(val_cfg.losses.items()) - if loss_term.type == "LossPhysical" - ] - assert len(outputs_physical) == 1 - target_aux_out = target_aux_out[outputs_physical[0]] - - # collect all target / prediction-related information - fp32 = torch.float32 - preds_all, targets_all, targets_coords_all, targets_times_all = [], [], [], [] - - timestep_idxs = [0] if len(batch.get_output_idxs()) == 0 else batch.get_output_idxs() - forecast_offset = timestep_idxs[0] - targets_lens = [] - - # TODO Maybe stopping at forecast_steps explained #1657 - for t_idx in timestep_idxs: - preds_all += [[]] - targets_all += [[]] - targets_coords_all += [[]] - targets_times_all += [[]] - targets_lens += [[]] - for stream_info in cf.streams: - sname = stream_info["name"] - - # handle spoof data: do not write since it might corrupt validation (spoofing invisible - # there) - if target_aux_out.physical[t_idx][sname]["is_spoof"][0]: - preds = model_output.get_physical_prediction(t_idx, sname) - preds_shape = preds[0].shape - # for-loop to make sure we have a consistent number of samples - preds_s = [np.zeros((preds_shape[0], 0, preds_shape[2])) for _ in preds] - targets_s = [np.zeros((0, preds_shape[2])) for _ in preds] - t_coords_s = [np.zeros((0, 2)) for _ in preds] - t_times_s = [np.array([]).astype("datetime64[ns]") for _ in preds] - - else: - preds = model_output.get_physical_prediction(t_idx, sname) - targets = target_aux_out.physical[t_idx][sname]["target"] - - preds_s, targets_s, t_coords_s, t_times_s = [], [], [], [] - - # handle forcing streams or if sample is empty - if preds is None: - # preds are empty so create copy of target and add ensemble dimension - assert targets[0].shape[0] == 0, "Empty preds but non-empty targets." - preds = [target.clone().unsqueeze(0) for target in targets] - - for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)): - # denormalize data if requested and map to storage format - preds_s += [dn_data(sname, pred).detach().to(fp32).cpu().numpy()] - targets_s += [dn_data(sname, target).detach().to(fp32).cpu().numpy()] - - # extract original target coords and times from target data - target_data = target_aux_out.physical[t_idx][sname] - t_coords_s += [target_data["target_coords"][i_batch].cpu().numpy()] - t_times_s += [target_data["target_times"][i_batch].astype("datetime64[ns]")] - - targets_lens[-1] += [[]] - targets_lens[-1][-1] += [t.shape[0] for t in targets_s] - - preds_all[-1] += [np.concatenate(preds_s, axis=1)] - targets_all[-1] += [np.concatenate(targets_s)] - targets_coords_all[-1] += [np.concatenate(t_coords_s)] - targets_times_all[-1] += [np.concatenate(t_times_s)] - - # # TODO: re-enable - # if len(idxs_inv) > 0: - # pred = pred[:, idxs_inv] - # target = target[idxs_inv] - # targets_coords_raw[t_idx][i_strm] = targets_coords_raw[t_idx][i_strm][idxs_inv] - # targets_times_raw[t_idx][i_strm] = targets_times_raw[t_idx][i_strm][idxs_inv] - - if len(preds_all) == 0 or np.array([p.shape[1] for pp in preds_all for p in pp]).sum() == 0: - _logger.warning("Writing no data since predictions are empty.") - return - - # collect source information +def _prepare_batch_data(cf, val_cfg, batch_size, batch_idx, batch): + """Prepare common data independent of forecast steps.""" sources = [] for sample in batch.get_source_samples().get_samples(): sources += [[]] for _, stream_data in sample.streams_data.items(): - # TODO: support multiple input steps sources[-1] += [stream_data.source_raw[0]] sample_idxs = [ @@ -119,9 +33,6 @@ def write_output( for sample in batch.get_source_samples().get_samples() ] - # more prep work - - # output stream names to be written, use specified ones or all if nothing specified stream_names = [stream.name for stream in cf.streams] if val_cfg.get("output").get("streams") is not None: output_stream_names = val_cfg.output.streams @@ -131,43 +42,238 @@ def write_output( output_streams = {name: stream_names.index(name) for name in output_stream_names} _logger.debug(f"Using output streams: {output_streams} from streams: {stream_names}") - target_channels: list[list[str]] = [list(stream.val_target_channels) for stream in cf.streams] - source_channels: list[list[str]] = [list(stream.val_source_channels) for stream in cf.streams] + target_channels = [list(stream.val_target_channels) for stream in cf.streams] + source_channels = [list(stream.val_source_channels) for stream in cf.streams] + geoinfo_channels = [[] for _ in cf.streams] - geoinfo_channels = [[] for _ in cf.streams] # TODO obtain channels - - # calculate global sample indices for this batch by offsetting by sample_start sample_start = batch_idx * batch_size - # write output - start_date = val_cfg.start_date end_date = val_cfg.end_date - - twh = TimeWindowHandler( - start_date, - end_date, - val_cfg.time_window_len, - val_cfg.time_window_step, - ) + twh = TimeWindowHandler(start_date, end_date, val_cfg.time_window_len, val_cfg.time_window_step) source_windows = (twh.window(idx) for idx in sample_idxs) source_intervals = [TimeRange(window.start, window.end) for window in source_windows] - data = io.OutputBatchData( + return ( sources, source_intervals, - targets_all, - preds_all, - targets_coords_all, - targets_times_all, - targets_lens, output_streams, target_channels, source_channels, geoinfo_channels, sample_start, - forecast_offset, ) - with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio: - for subset in data.items(): - zio.write_zarr(subset) + + +def _process_timestep(t_idx, dn_data, model_output, target_aux_out, cf): + """Process a single forecast step and return data for writing.""" + fp32 = torch.float32 + preds_step = [] + targets_step = [] + targets_coords_step = [] + targets_times_step = [] + targets_lens_step = [] + + for stream_info in cf.streams: + sname = stream_info["name"] + + if target_aux_out.physical[t_idx][sname]["is_spoof"][0]: + preds = model_output.get_physical_prediction(t_idx, sname) + preds_shape = preds[0].shape if preds else (1, 1, 1) + preds_s = [np.zeros((preds_shape[0], 0, preds_shape[2])) for _ in range(len(preds) if preds else 1)] + targets_s = [np.zeros((0, preds_shape[2])) for _ in range(len(preds) if preds else 1)] + t_coords_s = [np.zeros((0, 2)) for _ in range(len(preds) if preds else 1)] + t_times_s = [np.array([]).astype("datetime64[ns]") for _ in range(len(preds) if preds else 1)] + else: + preds = model_output.get_physical_prediction(t_idx, sname) + targets = target_aux_out.physical[t_idx][sname]["target"] + + preds_s, targets_s, t_coords_s, t_times_s = [], [], [], [] + + if preds is None: + assert targets[0].shape[0] == 0, "Empty preds but non-empty targets." + preds = [target.clone().unsqueeze(0) for target in targets] + + for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)): + preds_s += [dn_data(sname, pred).detach().to(fp32).cpu().numpy()] + targets_s += [dn_data(sname, target).detach().to(fp32).cpu().numpy()] + target_data = target_aux_out.physical[t_idx][sname] + t_coords_s += [target_data["target_coords"][i_batch].cpu().numpy()] + t_times_s += [target_data["target_times"][i_batch].astype("datetime64[ns]")] + + targets_lens_step += [t.shape[0] for t in targets_s] + preds_step += [np.concatenate(preds_s, axis=1)] + targets_step += [np.concatenate(targets_s)] + targets_coords_step += [np.concatenate(t_coords_s)] + targets_times_step += [np.concatenate(t_times_s)] + + return preds_step, targets_step, targets_coords_step, targets_times_step, targets_lens_step + + +class StreamingOutputWriter: + """Manages streaming output writing with callback mechanism.""" + + def __init__(self, cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out): + self.cf = cf + self.val_cfg = val_cfg + self.dn_data = dn_data + self.model_output = model_output + + # Extract physical target_aux + outputs_physical = [ + loss_name for loss_name, loss_term in val_cfg.losses.items() + if loss_term.type == "LossPhysical" + ] + assert len(outputs_physical) == 1 + self.target_aux_out = target_aux_out[outputs_physical[0]] + + # Prepare batch data once + ( + self.sources, + self.source_intervals, + self.output_streams, + self.target_channels, + self.source_channels, + self.geoinfo_channels, + self.sample_start, + ) = _prepare_batch_data(cf, val_cfg, batch_size, batch_idx, batch) + + # Streaming config + self.streaming_cfg = val_cfg.get("output", {}).get("streaming", {}) + self.write_freq = self.streaming_cfg.get("num_steps", 1) if self.streaming_cfg.get("num_steps") else 1 + + # Accumulate steps before writing + self.accumulated_steps = [] + self._zarrio_writer = None + + def create_callback(self): + """Create and return the step callback for model.forward().""" + def callback(step, output): + self.accumulated_steps.append(step) + # Write when we have accumulated enough steps + if len(self.accumulated_steps) >= self.write_freq: + self._write_accumulated_steps() + self.accumulated_steps = [] + return callback + + def _write_accumulated_steps(self): + """Write the accumulated steps to zarr.""" + if self._zarrio_writer is None: + self._zarrio_writer = zarrio_writer(config.get_path_results(self.cf, self.val_cfg.mini_epoch)).__enter__() + + for t_idx in self.accumulated_steps: + preds_step, targets_step, targets_coords_step, targets_times_step, targets_lens_step = ( + _process_timestep(t_idx, self.dn_data, self.model_output, self.target_aux_out, self.cf) + ) + + if len(preds_step) == 0 or np.array([p.shape[1] for p in preds_step]).sum() == 0: + _logger.warning(f"Empty predictions for step {t_idx}") + continue + + data = io.OutputBatchData( + self.sources, + self.source_intervals, + [preds_step], + [targets_step], + [targets_coords_step], + [targets_times_step], + [[targets_lens_step]], + self.output_streams, + self.target_channels, + self.source_channels, + self.geoinfo_channels, + self.sample_start, + t_idx, + ) + for subset in data.items(): + self._zarrio_writer.write_zarr(subset) + + def flush(self): + """Write any remaining accumulated steps and close.""" + if self.accumulated_steps: + self._write_accumulated_steps() + if self._zarrio_writer is not None: + self._zarrio_writer.__exit__(None, None, None) + + +def write_output( + cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out +): + """Write model output with optional streaming based on config.""" + + # TODO: how to handle multiple physical loss terms + outputs_physical = [ + loss_name for loss_name, loss_term in val_cfg.losses.items() + if loss_term.type == "LossPhysical" + ] + assert len(outputs_physical) == 1 + target_aux_out_physical = target_aux_out[outputs_physical[0]] + + # Check if streaming is enabled + streaming_cfg = val_cfg.get("output", {}).get("streaming", {}) + streaming_enabled = streaming_cfg.get("enabled", False) + + timestep_idxs = [0] if len(batch.get_output_idxs()) == 0 else batch.get_output_idxs() + + if streaming_enabled: + # Create streaming writer and return callback for use in model.forward() + writer = StreamingOutputWriter( + cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, + model_output, target_aux_out + ) + # Update mini_epoch in val_cfg (needed for zarr path) + val_cfg.mini_epoch = mini_epoch + + # Return callback for model.forward() to use + return writer.create_callback(), writer + else: + # Standard non-streaming: process all steps and write together + forecast_offset = timestep_idxs[0] + preds_all, targets_all, targets_coords_all, targets_times_all, targets_lens_all = [], [], [], [], [] + + for t_idx in timestep_idxs: + preds_step, targets_step, targets_coords_step, targets_times_step, targets_lens_step = ( + _process_timestep(t_idx, dn_data, model_output, target_aux_out_physical, cf) + ) + + if len(preds_step) == 0 or np.array([p.shape[1] for p in preds_step]).sum() == 0: + _logger.warning(f"Empty predictions for step {t_idx}") + continue + + preds_all.append(preds_step) + targets_all.append(targets_step) + targets_coords_all.append(targets_coords_step) + targets_times_all.append(targets_times_step) + targets_lens_all.append([targets_lens_step]) + + if len(preds_all) > 0: + ( + sources, + source_intervals, + output_streams, + target_channels, + source_channels, + geoinfo_channels, + sample_start, + ) = _prepare_batch_data(cf, val_cfg, batch_size, batch_idx, batch) + + data = io.OutputBatchData( + sources, + source_intervals, + preds_all, + targets_all, + targets_coords_all, + targets_times_all, + targets_lens_all, + output_streams, + target_channels, + source_channels, + geoinfo_channels, + sample_start, + forecast_offset, + ) + with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio: + for subset in data.items(): + zio.write_zarr(subset) + + return None, None From 498fcfb062e2544a3d854edfe3fc642dfd11ab40 Mon Sep 17 00:00:00 2001 From: evenmn Date: Tue, 24 Feb 2026 14:49:52 +0100 Subject: [PATCH 2/3] Linting Signed-off-by: evenmn --- src/weathergen/model/embeddings.py | 2 +- src/weathergen/model/model.py | 6 ++- src/weathergen/utils/validation_io.py | 77 ++++++++++++++++++++------- 3 files changed, 63 insertions(+), 22 deletions(-) diff --git a/src/weathergen/model/embeddings.py b/src/weathergen/model/embeddings.py index 09861d5db..aaf43ef0c 100644 --- a/src/weathergen/model/embeddings.py +++ b/src/weathergen/model/embeddings.py @@ -142,7 +142,7 @@ def forward_channels(self, x_in): x = peh(self.embed(x_in.transpose(-2, -1))) for layer in self.layers: - x = checkpoint(layer,x, use_reentrant=False) + x = checkpoint(layer, x, use_reentrant=False) # read out if self.unembed_mode == "full": diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 6c13fceef..2d98f22b2 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -614,7 +614,9 @@ def tokens_to_latent_state(self, tokens_post_norm, tokens) -> LatentState: z_pre_norm=tokens, ) - def forward(self, model_params: ModelParams, batch: ModelBatch, step_callback=None) -> ModelOutput: + def forward( + self, model_params: ModelParams, batch: ModelBatch, step_callback=None + ) -> ModelOutput: """Forward pass of the model Tokens are processed through the model components, which were defined in the create method. @@ -650,7 +652,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch, step_callback=No output = self.predict_decoders(model_params, step, tokens, batch, output) # latent predictions (raw and with SSL heads) output = self.predict_latent(model_params, step, tokens, batch, output) - + # invoke callback for streaming output if provided if step_callback is not None: step_callback(step, output) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 9e85d7efb..b0802e2f2 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -80,10 +80,15 @@ def _process_timestep(t_idx, dn_data, model_output, target_aux_out, cf): if target_aux_out.physical[t_idx][sname]["is_spoof"][0]: preds = model_output.get_physical_prediction(t_idx, sname) preds_shape = preds[0].shape if preds else (1, 1, 1) - preds_s = [np.zeros((preds_shape[0], 0, preds_shape[2])) for _ in range(len(preds) if preds else 1)] + preds_s = [ + np.zeros((preds_shape[0], 0, preds_shape[2])) + for _ in range(len(preds) if preds else 1) + ] targets_s = [np.zeros((0, preds_shape[2])) for _ in range(len(preds) if preds else 1)] t_coords_s = [np.zeros((0, 2)) for _ in range(len(preds) if preds else 1)] - t_times_s = [np.array([]).astype("datetime64[ns]") for _ in range(len(preds) if preds else 1)] + t_times_s = [ + np.array([]).astype("datetime64[ns]") for _ in range(len(preds) if preds else 1) + ] else: preds = model_output.get_physical_prediction(t_idx, sname) targets = target_aux_out.physical[t_idx][sname]["target"] @@ -113,20 +118,32 @@ def _process_timestep(t_idx, dn_data, model_output, target_aux_out, cf): class StreamingOutputWriter: """Manages streaming output writing with callback mechanism.""" - def __init__(self, cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out): + def __init__( + self, + cf, + val_cfg, + batch_size, + mini_epoch, + batch_idx, + dn_data, + batch, + model_output, + target_aux_out, + ): self.cf = cf self.val_cfg = val_cfg self.dn_data = dn_data self.model_output = model_output - + # Extract physical target_aux outputs_physical = [ - loss_name for loss_name, loss_term in val_cfg.losses.items() + loss_name + for loss_name, loss_term in val_cfg.losses.items() if loss_term.type == "LossPhysical" ] assert len(outputs_physical) == 1 self.target_aux_out = target_aux_out[outputs_physical[0]] - + # Prepare batch data once ( self.sources, @@ -137,33 +154,41 @@ def __init__(self, cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batc self.geoinfo_channels, self.sample_start, ) = _prepare_batch_data(cf, val_cfg, batch_size, batch_idx, batch) - + # Streaming config self.streaming_cfg = val_cfg.get("output", {}).get("streaming", {}) - self.write_freq = self.streaming_cfg.get("num_steps", 1) if self.streaming_cfg.get("num_steps") else 1 - + self.write_freq = ( + self.streaming_cfg.get("num_steps", 1) if self.streaming_cfg.get("num_steps") else 1 + ) + # Accumulate steps before writing self.accumulated_steps = [] self._zarrio_writer = None def create_callback(self): """Create and return the step callback for model.forward().""" + def callback(step, output): self.accumulated_steps.append(step) # Write when we have accumulated enough steps if len(self.accumulated_steps) >= self.write_freq: self._write_accumulated_steps() self.accumulated_steps = [] + return callback def _write_accumulated_steps(self): """Write the accumulated steps to zarr.""" if self._zarrio_writer is None: - self._zarrio_writer = zarrio_writer(config.get_path_results(self.cf, self.val_cfg.mini_epoch)).__enter__() + self._zarrio_writer = zarrio_writer( + config.get_path_results(self.cf, self.val_cfg.mini_epoch) + ).__enter__() for t_idx in self.accumulated_steps: preds_step, targets_step, targets_coords_step, targets_times_step, targets_lens_step = ( - _process_timestep(t_idx, self.dn_data, self.model_output, self.target_aux_out, self.cf) + _process_timestep( + t_idx, self.dn_data, self.model_output, self.target_aux_out, self.cf + ) ) if len(preds_step) == 0 or np.array([p.shape[1] for p in preds_step]).sum() == 0: @@ -200,10 +225,11 @@ def write_output( cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out ): """Write model output with optional streaming based on config.""" - + # TODO: how to handle multiple physical loss terms outputs_physical = [ - loss_name for loss_name, loss_term in val_cfg.losses.items() + loss_name + for loss_name, loss_term in val_cfg.losses.items() if loss_term.type == "LossPhysical" ] assert len(outputs_physical) == 1 @@ -214,22 +240,35 @@ def write_output( streaming_enabled = streaming_cfg.get("enabled", False) timestep_idxs = [0] if len(batch.get_output_idxs()) == 0 else batch.get_output_idxs() - + if streaming_enabled: # Create streaming writer and return callback for use in model.forward() writer = StreamingOutputWriter( - cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, - model_output, target_aux_out + cf, + val_cfg, + batch_size, + mini_epoch, + batch_idx, + dn_data, + batch, + model_output, + target_aux_out, ) # Update mini_epoch in val_cfg (needed for zarr path) val_cfg.mini_epoch = mini_epoch - + # Return callback for model.forward() to use return writer.create_callback(), writer else: # Standard non-streaming: process all steps and write together forecast_offset = timestep_idxs[0] - preds_all, targets_all, targets_coords_all, targets_times_all, targets_lens_all = [], [], [], [], [] + preds_all, targets_all, targets_coords_all, targets_times_all, targets_lens_all = ( + [], + [], + [], + [], + [], + ) for t_idx in timestep_idxs: preds_step, targets_step, targets_coords_step, targets_times_step, targets_lens_step = ( @@ -275,5 +314,5 @@ def write_output( with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio: for subset in data.items(): zio.write_zarr(subset) - + return None, None From e8d2cd0b8fbc13373282c496b0f04ba7c593c019 Mon Sep 17 00:00:00 2001 From: evenmn Date: Wed, 25 Feb 2026 10:57:38 +0100 Subject: [PATCH 3/3] Removed callbacks from Model.forward(), made non-streaming case a special case of streaming Signed-off-by: evenmn --- config/default_config.yml | 10 +- src/weathergen/model/model.py | 13 +- src/weathergen/train/trainer.py | 65 +---- src/weathergen/utils/validation_io.py | 362 ++++++++------------------ 4 files changed, 124 insertions(+), 326 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 4214ffeaf..a19542f37 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -220,14 +220,8 @@ validation_config: normalized_samples: False, # output streams to write; default all streams: null, - # streaming output configuration for inference - streaming: { - # enable streaming mode: write output after each N forecast steps instead of all at once - enabled: False, - # number of forecast steps to process before writing to disk - # if null, writes after each step (num_steps=1) - num_steps: null, - } + # number of forecast steps to accumulate before writing to disk, default no streaming + fstep_chunk_size: null, } # run validation before training starts (mainly for model development) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 2d98f22b2..ae6809eca 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -614,20 +614,13 @@ def tokens_to_latent_state(self, tokens_post_norm, tokens) -> LatentState: z_pre_norm=tokens, ) - def forward( - self, model_params: ModelParams, batch: ModelBatch, step_callback=None - ) -> ModelOutput: + def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: """Forward pass of the model Tokens are processed through the model components, which were defined in the create method. Args: model_params : Query and embedding parameters batch : Batch of data - step_callback : Optional callback function to be called after each forecast step. - Called as step_callback(step, output) where step is the forecast step - index and output is the partial ModelOutput for that step. - Used for streaming output writing - allows writing after each step - instead of accumulating all steps in memory. Returns: A list containing all prediction results """ @@ -653,10 +646,6 @@ def forward( # latent predictions (raw and with SSL heads) output = self.predict_latent(model_params, step, tokens, batch, output) - # invoke callback for streaming output if provided - if step_callback is not None: - step_callback(step, output) - return output def predict_latent( diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 8caa3d770..961b30947 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -563,6 +563,12 @@ def validate(self, mini_epoch, mode_cfg, batch_size): batch.to_device(self.device) + # Forward pass + preds = self.model( + self.model_params, + batch.get_source_samples(), + ) + # denormalization function for data denormalize_data_fct = ( (lambda x0, x1: x1) @@ -570,62 +576,9 @@ def validate(self, mini_epoch, mode_cfg, batch_size): else self.dataset_val.denormalize_target_channels ) - # Prepare for streaming or standard write - step_callback = None - streaming_writer = None + # Write output after forward completes + # Single unified path: write_output handles both streaming and non-streaming if bidx < num_samples_write: - step_callback, streaming_writer = write_output( - self.cf, - mode_cfg, - batch_size, - mini_epoch, - bidx, - denormalize_data_fct, - batch, - None, # model_output (will be computed below) - self.target_and_aux_calculators_val, - ) - - # evaluate model - with torch.autocast( - device_type=f"cuda:{cf.local_rank}", - dtype=self.mixed_precision_dtype, - enabled=cf.with_mixed_precision, - ): - if self.ema_model is None: - preds = self.model( - self.model_params, - batch.get_source_samples(), - step_callback=step_callback, - ) - else: - preds = self.ema_model.forward_eval( - self.model_params, - batch.get_source_samples(), - step_callback=step_callback, - ) - - targets_and_auxs = {} - for loss_name, target_aux in self.target_and_aux_calculators_val.items(): - target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name) - targets_and_auxs[loss_name] = target_aux.compute( - self.cf.general.istep, - batch.get_target_samples(target_idxs), - self.model_params, - self.model, - ) - - _ = self.loss_calculator_val.compute_loss( - preds=preds, - targets_and_aux=targets_and_auxs, - metadata=extract_batch_metadata(batch), - ) - - # For streaming mode, flush remaining steps - if streaming_writer is not None: - streaming_writer.flush() - # For non-streaming mode, now write the complete output - elif bidx < num_samples_write: write_output( self.cf, mode_cfg, @@ -635,7 +588,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size): denormalize_data_fct, batch, preds, - targets_and_auxs, + self.target_and_aux_calculators_val, ) pbar.update(batch_size) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index b0802e2f2..4ba8617cc 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -20,8 +20,36 @@ _logger = logging.getLogger(__name__) -def _prepare_batch_data(cf, val_cfg, batch_size, batch_idx, batch): - """Prepare common data independent of forecast steps.""" +def write_output( + cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out +): + """Write model output with configurable step chunking. + + Unifies streaming and non-streaming modes via output.fstep_chunk_size: + - fstep_chunk_size: 1 → Stream after each step + - fstep_chunk_size: N → Accumulate N steps, then write + - fstep_chunk_size: total_steps → Non-streaming (write all at once) + - fstep_chunk_size: None (default) → No streaming (equivalent to total_steps) + """ + + # TODO: how to handle multiple physical loss terms + outputs_physical = [ + loss_name + for loss_name, loss_term in val_cfg.losses.items() + if loss_term.type == "LossPhysical" + ] + assert len(outputs_physical) == 1 + target_aux_out_physical = target_aux_out[outputs_physical[0]] + + timestep_idxs = [0] if len(batch.get_output_idxs()) == 0 else batch.get_output_idxs() + total_steps = len(timestep_idxs) + + # Get chunking configuration (default: no streaming) + fstep_chunk_size = val_cfg.get("output", {}).get("fstep_chunk_size", None) + if fstep_chunk_size is None: + fstep_chunk_size = total_steps + + # Collect source information (once, outside chunking loop) sources = [] for sample in batch.get_source_samples().get_samples(): sources += [[]] @@ -33,6 +61,7 @@ def _prepare_batch_data(cf, val_cfg, batch_size, batch_idx, batch): for sample in batch.get_source_samples().get_samples() ] + # Output stream configuration stream_names = [stream.name for stream in cf.streams] if val_cfg.get("output").get("streams") is not None: output_stream_names = val_cfg.output.streams @@ -48,262 +77,98 @@ def _prepare_batch_data(cf, val_cfg, batch_size, batch_idx, batch): sample_start = batch_idx * batch_size + # Calculate source intervals start_date = val_cfg.start_date end_date = val_cfg.end_date twh = TimeWindowHandler(start_date, end_date, val_cfg.time_window_len, val_cfg.time_window_step) source_windows = (twh.window(idx) for idx in sample_idxs) source_intervals = [TimeRange(window.start, window.end) for window in source_windows] - return ( - sources, - source_intervals, - output_streams, - target_channels, - source_channels, - geoinfo_channels, - sample_start, - ) - - -def _process_timestep(t_idx, dn_data, model_output, target_aux_out, cf): - """Process a single forecast step and return data for writing.""" + # Write in chunks based on fstep_chunk_size fp32 = torch.float32 - preds_step = [] - targets_step = [] - targets_coords_step = [] - targets_times_step = [] - targets_lens_step = [] - - for stream_info in cf.streams: - sname = stream_info["name"] - - if target_aux_out.physical[t_idx][sname]["is_spoof"][0]: - preds = model_output.get_physical_prediction(t_idx, sname) - preds_shape = preds[0].shape if preds else (1, 1, 1) - preds_s = [ - np.zeros((preds_shape[0], 0, preds_shape[2])) - for _ in range(len(preds) if preds else 1) - ] - targets_s = [np.zeros((0, preds_shape[2])) for _ in range(len(preds) if preds else 1)] - t_coords_s = [np.zeros((0, 2)) for _ in range(len(preds) if preds else 1)] - t_times_s = [ - np.array([]).astype("datetime64[ns]") for _ in range(len(preds) if preds else 1) - ] - else: - preds = model_output.get_physical_prediction(t_idx, sname) - targets = target_aux_out.physical[t_idx][sname]["target"] - - preds_s, targets_s, t_coords_s, t_times_s = [], [], [], [] - - if preds is None: - assert targets[0].shape[0] == 0, "Empty preds but non-empty targets." - preds = [target.clone().unsqueeze(0) for target in targets] - - for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)): - preds_s += [dn_data(sname, pred).detach().to(fp32).cpu().numpy()] - targets_s += [dn_data(sname, target).detach().to(fp32).cpu().numpy()] - target_data = target_aux_out.physical[t_idx][sname] - t_coords_s += [target_data["target_coords"][i_batch].cpu().numpy()] - t_times_s += [target_data["target_times"][i_batch].astype("datetime64[ns]")] - - targets_lens_step += [t.shape[0] for t in targets_s] - preds_step += [np.concatenate(preds_s, axis=1)] - targets_step += [np.concatenate(targets_s)] - targets_coords_step += [np.concatenate(t_coords_s)] - targets_times_step += [np.concatenate(t_times_s)] - - return preds_step, targets_step, targets_coords_step, targets_times_step, targets_lens_step - - -class StreamingOutputWriter: - """Manages streaming output writing with callback mechanism.""" - - def __init__( - self, - cf, - val_cfg, - batch_size, - mini_epoch, - batch_idx, - dn_data, - batch, - model_output, - target_aux_out, - ): - self.cf = cf - self.val_cfg = val_cfg - self.dn_data = dn_data - self.model_output = model_output - - # Extract physical target_aux - outputs_physical = [ - loss_name - for loss_name, loss_term in val_cfg.losses.items() - if loss_term.type == "LossPhysical" - ] - assert len(outputs_physical) == 1 - self.target_aux_out = target_aux_out[outputs_physical[0]] - - # Prepare batch data once - ( - self.sources, - self.source_intervals, - self.output_streams, - self.target_channels, - self.source_channels, - self.geoinfo_channels, - self.sample_start, - ) = _prepare_batch_data(cf, val_cfg, batch_size, batch_idx, batch) - - # Streaming config - self.streaming_cfg = val_cfg.get("output", {}).get("streaming", {}) - self.write_freq = ( - self.streaming_cfg.get("num_steps", 1) if self.streaming_cfg.get("num_steps") else 1 - ) - - # Accumulate steps before writing - self.accumulated_steps = [] - self._zarrio_writer = None - - def create_callback(self): - """Create and return the step callback for model.forward().""" - - def callback(step, output): - self.accumulated_steps.append(step) - # Write when we have accumulated enough steps - if len(self.accumulated_steps) >= self.write_freq: - self._write_accumulated_steps() - self.accumulated_steps = [] - - return callback - - def _write_accumulated_steps(self): - """Write the accumulated steps to zarr.""" - if self._zarrio_writer is None: - self._zarrio_writer = zarrio_writer( - config.get_path_results(self.cf, self.val_cfg.mini_epoch) - ).__enter__() - - for t_idx in self.accumulated_steps: - preds_step, targets_step, targets_coords_step, targets_times_step, targets_lens_step = ( - _process_timestep( - t_idx, self.dn_data, self.model_output, self.target_aux_out, self.cf - ) - ) - - if len(preds_step) == 0 or np.array([p.shape[1] for p in preds_step]).sum() == 0: - _logger.warning(f"Empty predictions for step {t_idx}") + with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio: + for chunk_start in range(0, total_steps, fstep_chunk_size): + chunk_end = min(chunk_start + fstep_chunk_size, total_steps) + chunk_indices = timestep_idxs[chunk_start:chunk_end] + + # Process and write chunk + preds_chunk, targets_chunk, targets_coords_chunk, targets_times_chunk = [], [], [], [] + targets_lens_chunk = [] + + for t_idx in chunk_indices: + preds_step = [] + targets_step = [] + targets_coords_step = [] + targets_times_step = [] + targets_lens_step = [] + + for stream_info in cf.streams: + sname = stream_info["name"] + + if target_aux_out_physical.physical[t_idx][sname]["is_spoof"][0]: + preds = model_output.get_physical_prediction(t_idx, sname) + preds_shape = preds[0].shape if preds else (1, 1, 1) + preds_s = [ + np.zeros((preds_shape[0], 0, preds_shape[2])) + for _ in range(len(preds) if preds else 1) + ] + targets_s = [ + np.zeros((0, preds_shape[2])) for _ in range(len(preds) if preds else 1) + ] + t_coords_s = [np.zeros((0, 2)) for _ in range(len(preds) if preds else 1)] + t_times_s = [ + np.array([]).astype("datetime64[ns]") + for _ in range(len(preds) if preds else 1) + ] + else: + preds = model_output.get_physical_prediction(t_idx, sname) + targets = target_aux_out_physical.physical[t_idx][sname]["target"] + + preds_s, targets_s, t_coords_s, t_times_s = [], [], [], [] + + if preds is None: + assert targets[0].shape[0] == 0, "Empty preds but non-empty targets." + preds = [target.clone().unsqueeze(0) for target in targets] + + for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)): + preds_s += [dn_data(sname, pred).detach().to(fp32).cpu().numpy()] + targets_s += [dn_data(sname, target).detach().to(fp32).cpu().numpy()] + target_data = target_aux_out_physical.physical[t_idx][sname] + t_coords_s += [target_data["target_coords"][i_batch].cpu().numpy()] + t_times_s += [ + target_data["target_times"][i_batch].astype("datetime64[ns]") + ] + + targets_lens_step += [t.shape[0] for t in targets_s] + preds_step += [np.concatenate(preds_s, axis=1)] + targets_step += [np.concatenate(targets_s)] + targets_coords_step += [np.concatenate(t_coords_s)] + targets_times_step += [np.concatenate(t_times_s)] + + if len(preds_step) == 0 or np.array([p.shape[1] for p in preds_step]).sum() == 0: + _logger.warning(f"Empty predictions for step {t_idx}") + continue + + preds_chunk.append(preds_step) + targets_chunk.append(targets_step) + targets_coords_chunk.append(targets_coords_step) + targets_times_chunk.append(targets_times_step) + targets_lens_chunk.append([targets_lens_step]) + + if len(preds_chunk) == 0: continue - data = io.OutputBatchData( - self.sources, - self.source_intervals, - [preds_step], - [targets_step], - [targets_coords_step], - [targets_times_step], - [[targets_lens_step]], - self.output_streams, - self.target_channels, - self.source_channels, - self.geoinfo_channels, - self.sample_start, - t_idx, - ) - for subset in data.items(): - self._zarrio_writer.write_zarr(subset) - - def flush(self): - """Write any remaining accumulated steps and close.""" - if self.accumulated_steps: - self._write_accumulated_steps() - if self._zarrio_writer is not None: - self._zarrio_writer.__exit__(None, None, None) - - -def write_output( - cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out -): - """Write model output with optional streaming based on config.""" - - # TODO: how to handle multiple physical loss terms - outputs_physical = [ - loss_name - for loss_name, loss_term in val_cfg.losses.items() - if loss_term.type == "LossPhysical" - ] - assert len(outputs_physical) == 1 - target_aux_out_physical = target_aux_out[outputs_physical[0]] - - # Check if streaming is enabled - streaming_cfg = val_cfg.get("output", {}).get("streaming", {}) - streaming_enabled = streaming_cfg.get("enabled", False) - - timestep_idxs = [0] if len(batch.get_output_idxs()) == 0 else batch.get_output_idxs() - - if streaming_enabled: - # Create streaming writer and return callback for use in model.forward() - writer = StreamingOutputWriter( - cf, - val_cfg, - batch_size, - mini_epoch, - batch_idx, - dn_data, - batch, - model_output, - target_aux_out, - ) - # Update mini_epoch in val_cfg (needed for zarr path) - val_cfg.mini_epoch = mini_epoch - - # Return callback for model.forward() to use - return writer.create_callback(), writer - else: - # Standard non-streaming: process all steps and write together - forecast_offset = timestep_idxs[0] - preds_all, targets_all, targets_coords_all, targets_times_all, targets_lens_all = ( - [], - [], - [], - [], - [], - ) - - for t_idx in timestep_idxs: - preds_step, targets_step, targets_coords_step, targets_times_step, targets_lens_step = ( - _process_timestep(t_idx, dn_data, model_output, target_aux_out_physical, cf) - ) - - if len(preds_step) == 0 or np.array([p.shape[1] for p in preds_step]).sum() == 0: - _logger.warning(f"Empty predictions for step {t_idx}") - continue - - preds_all.append(preds_step) - targets_all.append(targets_step) - targets_coords_all.append(targets_coords_step) - targets_times_all.append(targets_times_step) - targets_lens_all.append([targets_lens_step]) - - if len(preds_all) > 0: - ( - sources, - source_intervals, - output_streams, - target_channels, - source_channels, - geoinfo_channels, - sample_start, - ) = _prepare_batch_data(cf, val_cfg, batch_size, batch_idx, batch) + # Determine forecast offset for this chunk + forecast_offset = chunk_indices[0] data = io.OutputBatchData( sources, source_intervals, - preds_all, - targets_all, - targets_coords_all, - targets_times_all, - targets_lens_all, + preds_chunk, + targets_chunk, + targets_coords_chunk, + targets_times_chunk, + targets_lens_chunk, output_streams, target_channels, source_channels, @@ -311,8 +176,5 @@ def write_output( sample_start, forecast_offset, ) - with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio: - for subset in data.items(): - zio.write_zarr(subset) - - return None, None + for subset in data.items(): + zio.write_zarr(subset)