diff --git a/config/default_config.yml b/config/default_config.yml index 30dd87afd..a19542f37 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -220,6 +220,8 @@ validation_config: normalized_samples: False, # output streams to write; default all streams: null, + # number of forecast steps to accumulate before writing to disk, default no streaming + fstep_chunk_size: null, } # run validation before training starts (mainly for model development) diff --git a/src/weathergen/model/embeddings.py b/src/weathergen/model/embeddings.py index 09861d5db..aaf43ef0c 100644 --- a/src/weathergen/model/embeddings.py +++ b/src/weathergen/model/embeddings.py @@ -142,7 +142,7 @@ def forward_channels(self, x_in): x = peh(self.embed(x_in.transpose(-2, -1))) for layer in self.layers: - x = checkpoint(layer,x, use_reentrant=False) + x = checkpoint(layer, x, use_reentrant=False) # read out if self.unembed_mode == "full": diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 273c28838..ae6809eca 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -620,7 +620,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: Tokens are processed through the model components, which were defined in the create method. Args: model_params : Query and embedding parameters - batch + batch : Batch of data Returns: A list containing all prediction results """ diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index bfa289c7b..961b30947 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -563,48 +563,22 @@ def validate(self, mini_epoch, mode_cfg, batch_size): batch.to_device(self.device) - # evaluate model - with torch.autocast( - device_type=f"cuda:{cf.local_rank}", - dtype=self.mixed_precision_dtype, - enabled=cf.with_mixed_precision, - ): - if self.ema_model is None: - preds = self.model( - self.model_params, - batch.get_source_samples(), - ) - else: - preds = self.ema_model.forward_eval( - self.model_params, - batch.get_source_samples(), - ) - - targets_and_auxs = {} - for loss_name, target_aux in self.target_and_aux_calculators_val.items(): - target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name) - targets_and_auxs[loss_name] = target_aux.compute( - self.cf.general.istep, - batch.get_target_samples(target_idxs), - self.model_params, - self.model, - ) - - _ = self.loss_calculator_val.compute_loss( - preds=preds, - targets_and_aux=targets_and_auxs, - metadata=extract_batch_metadata(batch), + # Forward pass + preds = self.model( + self.model_params, + batch.get_source_samples(), + ) + + # denormalization function for data + denormalize_data_fct = ( + (lambda x0, x1: x1) + if mode_cfg.get("output", {}).get("normalized_samples", False) + else self.dataset_val.denormalize_target_channels ) - # log output + # Write output after forward completes + # Single unified path: write_output handles both streaming and non-streaming if bidx < num_samples_write: - # denormalization function for data - denormalize_data_fct = ( - (lambda x0, x1: x1) - if mode_cfg.get("output", {}).get("normalized_samples", False) - else self.dataset_val.denormalize_target_channels - ) - # write output write_output( self.cf, mode_cfg, @@ -614,7 +588,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size): denormalize_data_fct, batch, preds, - targets_and_auxs, + self.target_and_aux_calculators_val, ) pbar.update(batch_size) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 0e09fd38d..4ba8617cc 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -23,95 +23,37 @@ def write_output( cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out ): - """ - Interface for writing model output + """Write model output with configurable step chunking. + + Unifies streaming and non-streaming modes via output.fstep_chunk_size: + - fstep_chunk_size: 1 → Stream after each step + - fstep_chunk_size: N → Accumulate N steps, then write + - fstep_chunk_size: total_steps → Non-streaming (write all at once) + - fstep_chunk_size: None (default) → No streaming (equivalent to total_steps) """ # TODO: how to handle multiple physical loss terms outputs_physical = [ loss_name - for i, (loss_name, loss_term) in enumerate(val_cfg.losses.items()) + for loss_name, loss_term in val_cfg.losses.items() if loss_term.type == "LossPhysical" ] assert len(outputs_physical) == 1 - target_aux_out = target_aux_out[outputs_physical[0]] - - # collect all target / prediction-related information - fp32 = torch.float32 - preds_all, targets_all, targets_coords_all, targets_times_all = [], [], [], [] + target_aux_out_physical = target_aux_out[outputs_physical[0]] timestep_idxs = [0] if len(batch.get_output_idxs()) == 0 else batch.get_output_idxs() - forecast_offset = timestep_idxs[0] - targets_lens = [] - - # TODO Maybe stopping at forecast_steps explained #1657 - for t_idx in timestep_idxs: - preds_all += [[]] - targets_all += [[]] - targets_coords_all += [[]] - targets_times_all += [[]] - targets_lens += [[]] - for stream_info in cf.streams: - sname = stream_info["name"] - - # handle spoof data: do not write since it might corrupt validation (spoofing invisible - # there) - if target_aux_out.physical[t_idx][sname]["is_spoof"][0]: - preds = model_output.get_physical_prediction(t_idx, sname) - preds_shape = preds[0].shape - # for-loop to make sure we have a consistent number of samples - preds_s = [np.zeros((preds_shape[0], 0, preds_shape[2])) for _ in preds] - targets_s = [np.zeros((0, preds_shape[2])) for _ in preds] - t_coords_s = [np.zeros((0, 2)) for _ in preds] - t_times_s = [np.array([]).astype("datetime64[ns]") for _ in preds] - - else: - preds = model_output.get_physical_prediction(t_idx, sname) - targets = target_aux_out.physical[t_idx][sname]["target"] - - preds_s, targets_s, t_coords_s, t_times_s = [], [], [], [] - - # handle forcing streams or if sample is empty - if preds is None: - # preds are empty so create copy of target and add ensemble dimension - assert targets[0].shape[0] == 0, "Empty preds but non-empty targets." - preds = [target.clone().unsqueeze(0) for target in targets] - - for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)): - # denormalize data if requested and map to storage format - preds_s += [dn_data(sname, pred).detach().to(fp32).cpu().numpy()] - targets_s += [dn_data(sname, target).detach().to(fp32).cpu().numpy()] - - # extract original target coords and times from target data - target_data = target_aux_out.physical[t_idx][sname] - t_coords_s += [target_data["target_coords"][i_batch].cpu().numpy()] - t_times_s += [target_data["target_times"][i_batch].astype("datetime64[ns]")] - - targets_lens[-1] += [[]] - targets_lens[-1][-1] += [t.shape[0] for t in targets_s] - - preds_all[-1] += [np.concatenate(preds_s, axis=1)] - targets_all[-1] += [np.concatenate(targets_s)] - targets_coords_all[-1] += [np.concatenate(t_coords_s)] - targets_times_all[-1] += [np.concatenate(t_times_s)] - - # # TODO: re-enable - # if len(idxs_inv) > 0: - # pred = pred[:, idxs_inv] - # target = target[idxs_inv] - # targets_coords_raw[t_idx][i_strm] = targets_coords_raw[t_idx][i_strm][idxs_inv] - # targets_times_raw[t_idx][i_strm] = targets_times_raw[t_idx][i_strm][idxs_inv] - - if len(preds_all) == 0 or np.array([p.shape[1] for pp in preds_all for p in pp]).sum() == 0: - _logger.warning("Writing no data since predictions are empty.") - return - - # collect source information + total_steps = len(timestep_idxs) + + # Get chunking configuration (default: no streaming) + fstep_chunk_size = val_cfg.get("output", {}).get("fstep_chunk_size", None) + if fstep_chunk_size is None: + fstep_chunk_size = total_steps + + # Collect source information (once, outside chunking loop) sources = [] for sample in batch.get_source_samples().get_samples(): sources += [[]] for _, stream_data in sample.streams_data.items(): - # TODO: support multiple input steps sources[-1] += [stream_data.source_raw[0]] sample_idxs = [ @@ -119,9 +61,7 @@ def write_output( for sample in batch.get_source_samples().get_samples() ] - # more prep work - - # output stream names to be written, use specified ones or all if nothing specified + # Output stream configuration stream_names = [stream.name for stream in cf.streams] if val_cfg.get("output").get("streams") is not None: output_stream_names = val_cfg.output.streams @@ -131,43 +71,110 @@ def write_output( output_streams = {name: stream_names.index(name) for name in output_stream_names} _logger.debug(f"Using output streams: {output_streams} from streams: {stream_names}") - target_channels: list[list[str]] = [list(stream.val_target_channels) for stream in cf.streams] - source_channels: list[list[str]] = [list(stream.val_source_channels) for stream in cf.streams] - - geoinfo_channels = [[] for _ in cf.streams] # TODO obtain channels + target_channels = [list(stream.val_target_channels) for stream in cf.streams] + source_channels = [list(stream.val_source_channels) for stream in cf.streams] + geoinfo_channels = [[] for _ in cf.streams] - # calculate global sample indices for this batch by offsetting by sample_start sample_start = batch_idx * batch_size - # write output - + # Calculate source intervals start_date = val_cfg.start_date end_date = val_cfg.end_date - - twh = TimeWindowHandler( - start_date, - end_date, - val_cfg.time_window_len, - val_cfg.time_window_step, - ) + twh = TimeWindowHandler(start_date, end_date, val_cfg.time_window_len, val_cfg.time_window_step) source_windows = (twh.window(idx) for idx in sample_idxs) source_intervals = [TimeRange(window.start, window.end) for window in source_windows] - data = io.OutputBatchData( - sources, - source_intervals, - targets_all, - preds_all, - targets_coords_all, - targets_times_all, - targets_lens, - output_streams, - target_channels, - source_channels, - geoinfo_channels, - sample_start, - forecast_offset, - ) + # Write in chunks based on fstep_chunk_size + fp32 = torch.float32 with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio: - for subset in data.items(): - zio.write_zarr(subset) + for chunk_start in range(0, total_steps, fstep_chunk_size): + chunk_end = min(chunk_start + fstep_chunk_size, total_steps) + chunk_indices = timestep_idxs[chunk_start:chunk_end] + + # Process and write chunk + preds_chunk, targets_chunk, targets_coords_chunk, targets_times_chunk = [], [], [], [] + targets_lens_chunk = [] + + for t_idx in chunk_indices: + preds_step = [] + targets_step = [] + targets_coords_step = [] + targets_times_step = [] + targets_lens_step = [] + + for stream_info in cf.streams: + sname = stream_info["name"] + + if target_aux_out_physical.physical[t_idx][sname]["is_spoof"][0]: + preds = model_output.get_physical_prediction(t_idx, sname) + preds_shape = preds[0].shape if preds else (1, 1, 1) + preds_s = [ + np.zeros((preds_shape[0], 0, preds_shape[2])) + for _ in range(len(preds) if preds else 1) + ] + targets_s = [ + np.zeros((0, preds_shape[2])) for _ in range(len(preds) if preds else 1) + ] + t_coords_s = [np.zeros((0, 2)) for _ in range(len(preds) if preds else 1)] + t_times_s = [ + np.array([]).astype("datetime64[ns]") + for _ in range(len(preds) if preds else 1) + ] + else: + preds = model_output.get_physical_prediction(t_idx, sname) + targets = target_aux_out_physical.physical[t_idx][sname]["target"] + + preds_s, targets_s, t_coords_s, t_times_s = [], [], [], [] + + if preds is None: + assert targets[0].shape[0] == 0, "Empty preds but non-empty targets." + preds = [target.clone().unsqueeze(0) for target in targets] + + for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)): + preds_s += [dn_data(sname, pred).detach().to(fp32).cpu().numpy()] + targets_s += [dn_data(sname, target).detach().to(fp32).cpu().numpy()] + target_data = target_aux_out_physical.physical[t_idx][sname] + t_coords_s += [target_data["target_coords"][i_batch].cpu().numpy()] + t_times_s += [ + target_data["target_times"][i_batch].astype("datetime64[ns]") + ] + + targets_lens_step += [t.shape[0] for t in targets_s] + preds_step += [np.concatenate(preds_s, axis=1)] + targets_step += [np.concatenate(targets_s)] + targets_coords_step += [np.concatenate(t_coords_s)] + targets_times_step += [np.concatenate(t_times_s)] + + if len(preds_step) == 0 or np.array([p.shape[1] for p in preds_step]).sum() == 0: + _logger.warning(f"Empty predictions for step {t_idx}") + continue + + preds_chunk.append(preds_step) + targets_chunk.append(targets_step) + targets_coords_chunk.append(targets_coords_step) + targets_times_chunk.append(targets_times_step) + targets_lens_chunk.append([targets_lens_step]) + + if len(preds_chunk) == 0: + continue + + # Determine forecast offset for this chunk + forecast_offset = chunk_indices[0] + + data = io.OutputBatchData( + sources, + source_intervals, + preds_chunk, + targets_chunk, + targets_coords_chunk, + targets_times_chunk, + targets_lens_chunk, + output_streams, + target_channels, + source_channels, + geoinfo_channels, + sample_start, + forecast_offset, + ) + for subset in data.items(): + zio.write_zarr(subset)