Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions config/default_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ validation_config:
normalized_samples: False,
# output streams to write; default all
streams: null,
# number of forecast steps to accumulate before writing to disk, default no streaming
fstep_chunk_size: null,
}

# run validation before training starts (mainly for model development)
Expand Down
2 changes: 1 addition & 1 deletion src/weathergen/model/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def forward_channels(self, x_in):
x = peh(self.embed(x_in.transpose(-2, -1)))

for layer in self.layers:
x = checkpoint(layer,x, use_reentrant=False)
x = checkpoint(layer, x, use_reentrant=False)

# read out
if self.unembed_mode == "full":
Expand Down
2 changes: 1 addition & 1 deletion src/weathergen/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput:
Tokens are processed through the model components, which were defined in the create method.
Args:
model_params : Query and embedding parameters
batch
batch : Batch of data
Returns:
A list containing all prediction results
"""
Expand Down
54 changes: 14 additions & 40 deletions src/weathergen/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,48 +563,22 @@ def validate(self, mini_epoch, mode_cfg, batch_size):

batch.to_device(self.device)

# evaluate model
with torch.autocast(
device_type=f"cuda:{cf.local_rank}",
dtype=self.mixed_precision_dtype,
enabled=cf.with_mixed_precision,
):
if self.ema_model is None:
preds = self.model(
self.model_params,
batch.get_source_samples(),
)
else:
preds = self.ema_model.forward_eval(
self.model_params,
batch.get_source_samples(),
)

targets_and_auxs = {}
for loss_name, target_aux in self.target_and_aux_calculators_val.items():
target_idxs = get_target_idxs_from_cfg(mode_cfg, loss_name)
targets_and_auxs[loss_name] = target_aux.compute(
self.cf.general.istep,
batch.get_target_samples(target_idxs),
self.model_params,
self.model,
)

_ = self.loss_calculator_val.compute_loss(
preds=preds,
targets_and_aux=targets_and_auxs,
metadata=extract_batch_metadata(batch),
# Forward pass
preds = self.model(
self.model_params,
batch.get_source_samples(),
)

# denormalization function for data
denormalize_data_fct = (
(lambda x0, x1: x1)
if mode_cfg.get("output", {}).get("normalized_samples", False)
else self.dataset_val.denormalize_target_channels
)

# log output
# Write output after forward completes
# Single unified path: write_output handles both streaming and non-streaming
if bidx < num_samples_write:
# denormalization function for data
denormalize_data_fct = (
(lambda x0, x1: x1)
if mode_cfg.get("output", {}).get("normalized_samples", False)
else self.dataset_val.denormalize_target_channels
)
# write output
write_output(
self.cf,
mode_cfg,
Expand All @@ -614,7 +588,7 @@ def validate(self, mini_epoch, mode_cfg, batch_size):
denormalize_data_fct,
batch,
preds,
targets_and_auxs,
self.target_and_aux_calculators_val,
)

pbar.update(batch_size)
Expand Down
225 changes: 116 additions & 109 deletions src/weathergen/utils/validation_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,105 +23,45 @@
def write_output(
cf, val_cfg, batch_size, mini_epoch, batch_idx, dn_data, batch, model_output, target_aux_out
):
"""
Interface for writing model output
"""Write model output with configurable step chunking.

Unifies streaming and non-streaming modes via output.fstep_chunk_size:
- fstep_chunk_size: 1 → Stream after each step
- fstep_chunk_size: N → Accumulate N steps, then write
- fstep_chunk_size: total_steps → Non-streaming (write all at once)
- fstep_chunk_size: None (default) → No streaming (equivalent to total_steps)
"""

# TODO: how to handle multiple physical loss terms
outputs_physical = [
loss_name
for i, (loss_name, loss_term) in enumerate(val_cfg.losses.items())
for loss_name, loss_term in val_cfg.losses.items()
if loss_term.type == "LossPhysical"
]
assert len(outputs_physical) == 1
target_aux_out = target_aux_out[outputs_physical[0]]

# collect all target / prediction-related information
fp32 = torch.float32
preds_all, targets_all, targets_coords_all, targets_times_all = [], [], [], []
target_aux_out_physical = target_aux_out[outputs_physical[0]]

timestep_idxs = [0] if len(batch.get_output_idxs()) == 0 else batch.get_output_idxs()
forecast_offset = timestep_idxs[0]
targets_lens = []

# TODO Maybe stopping at forecast_steps explained #1657
for t_idx in timestep_idxs:
preds_all += [[]]
targets_all += [[]]
targets_coords_all += [[]]
targets_times_all += [[]]
targets_lens += [[]]
for stream_info in cf.streams:
sname = stream_info["name"]

# handle spoof data: do not write since it might corrupt validation (spoofing invisible
# there)
if target_aux_out.physical[t_idx][sname]["is_spoof"][0]:
preds = model_output.get_physical_prediction(t_idx, sname)
preds_shape = preds[0].shape
# for-loop to make sure we have a consistent number of samples
preds_s = [np.zeros((preds_shape[0], 0, preds_shape[2])) for _ in preds]
targets_s = [np.zeros((0, preds_shape[2])) for _ in preds]
t_coords_s = [np.zeros((0, 2)) for _ in preds]
t_times_s = [np.array([]).astype("datetime64[ns]") for _ in preds]

else:
preds = model_output.get_physical_prediction(t_idx, sname)
targets = target_aux_out.physical[t_idx][sname]["target"]

preds_s, targets_s, t_coords_s, t_times_s = [], [], [], []

# handle forcing streams or if sample is empty
if preds is None:
# preds are empty so create copy of target and add ensemble dimension
assert targets[0].shape[0] == 0, "Empty preds but non-empty targets."
preds = [target.clone().unsqueeze(0) for target in targets]

for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)):
# denormalize data if requested and map to storage format
preds_s += [dn_data(sname, pred).detach().to(fp32).cpu().numpy()]
targets_s += [dn_data(sname, target).detach().to(fp32).cpu().numpy()]

# extract original target coords and times from target data
target_data = target_aux_out.physical[t_idx][sname]
t_coords_s += [target_data["target_coords"][i_batch].cpu().numpy()]
t_times_s += [target_data["target_times"][i_batch].astype("datetime64[ns]")]

targets_lens[-1] += [[]]
targets_lens[-1][-1] += [t.shape[0] for t in targets_s]

preds_all[-1] += [np.concatenate(preds_s, axis=1)]
targets_all[-1] += [np.concatenate(targets_s)]
targets_coords_all[-1] += [np.concatenate(t_coords_s)]
targets_times_all[-1] += [np.concatenate(t_times_s)]

# # TODO: re-enable
# if len(idxs_inv) > 0:
# pred = pred[:, idxs_inv]
# target = target[idxs_inv]
# targets_coords_raw[t_idx][i_strm] = targets_coords_raw[t_idx][i_strm][idxs_inv]
# targets_times_raw[t_idx][i_strm] = targets_times_raw[t_idx][i_strm][idxs_inv]

if len(preds_all) == 0 or np.array([p.shape[1] for pp in preds_all for p in pp]).sum() == 0:
_logger.warning("Writing no data since predictions are empty.")
return

# collect source information
total_steps = len(timestep_idxs)

# Get chunking configuration (default: no streaming)
fstep_chunk_size = val_cfg.get("output", {}).get("fstep_chunk_size", None)
if fstep_chunk_size is None:
fstep_chunk_size = total_steps

# Collect source information (once, outside chunking loop)
sources = []
for sample in batch.get_source_samples().get_samples():
sources += [[]]
for _, stream_data in sample.streams_data.items():
# TODO: support multiple input steps
sources[-1] += [stream_data.source_raw[0]]

sample_idxs = [
list(sample.streams_data.values())[0].sample_idx
for sample in batch.get_source_samples().get_samples()
]

# more prep work

# output stream names to be written, use specified ones or all if nothing specified
# Output stream configuration
stream_names = [stream.name for stream in cf.streams]
if val_cfg.get("output").get("streams") is not None:
output_stream_names = val_cfg.output.streams
Expand All @@ -131,43 +71,110 @@ def write_output(
output_streams = {name: stream_names.index(name) for name in output_stream_names}
_logger.debug(f"Using output streams: {output_streams} from streams: {stream_names}")

target_channels: list[list[str]] = [list(stream.val_target_channels) for stream in cf.streams]
source_channels: list[list[str]] = [list(stream.val_source_channels) for stream in cf.streams]

geoinfo_channels = [[] for _ in cf.streams] # TODO obtain channels
target_channels = [list(stream.val_target_channels) for stream in cf.streams]
source_channels = [list(stream.val_source_channels) for stream in cf.streams]
geoinfo_channels = [[] for _ in cf.streams]

# calculate global sample indices for this batch by offsetting by sample_start
sample_start = batch_idx * batch_size

# write output

# Calculate source intervals
start_date = val_cfg.start_date
end_date = val_cfg.end_date

twh = TimeWindowHandler(
start_date,
end_date,
val_cfg.time_window_len,
val_cfg.time_window_step,
)
twh = TimeWindowHandler(start_date, end_date, val_cfg.time_window_len, val_cfg.time_window_step)
source_windows = (twh.window(idx) for idx in sample_idxs)
source_intervals = [TimeRange(window.start, window.end) for window in source_windows]

data = io.OutputBatchData(
sources,
source_intervals,
targets_all,
preds_all,
targets_coords_all,
targets_times_all,
targets_lens,
output_streams,
target_channels,
source_channels,
geoinfo_channels,
sample_start,
forecast_offset,
)
# Write in chunks based on fstep_chunk_size
fp32 = torch.float32
with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio:
for subset in data.items():
zio.write_zarr(subset)
for chunk_start in range(0, total_steps, fstep_chunk_size):
chunk_end = min(chunk_start + fstep_chunk_size, total_steps)
chunk_indices = timestep_idxs[chunk_start:chunk_end]

# Process and write chunk
preds_chunk, targets_chunk, targets_coords_chunk, targets_times_chunk = [], [], [], []
targets_lens_chunk = []

for t_idx in chunk_indices:
preds_step = []
targets_step = []
targets_coords_step = []
targets_times_step = []
targets_lens_step = []

for stream_info in cf.streams:
sname = stream_info["name"]

if target_aux_out_physical.physical[t_idx][sname]["is_spoof"][0]:
preds = model_output.get_physical_prediction(t_idx, sname)
preds_shape = preds[0].shape if preds else (1, 1, 1)
preds_s = [
np.zeros((preds_shape[0], 0, preds_shape[2]))
for _ in range(len(preds) if preds else 1)
]
targets_s = [
np.zeros((0, preds_shape[2])) for _ in range(len(preds) if preds else 1)
]
t_coords_s = [np.zeros((0, 2)) for _ in range(len(preds) if preds else 1)]
t_times_s = [
np.array([]).astype("datetime64[ns]")
for _ in range(len(preds) if preds else 1)
]
else:
preds = model_output.get_physical_prediction(t_idx, sname)
targets = target_aux_out_physical.physical[t_idx][sname]["target"]

preds_s, targets_s, t_coords_s, t_times_s = [], [], [], []

if preds is None:
assert targets[0].shape[0] == 0, "Empty preds but non-empty targets."
preds = [target.clone().unsqueeze(0) for target in targets]

for i_batch, (pred, target) in enumerate(zip(preds, targets, strict=True)):
preds_s += [dn_data(sname, pred).detach().to(fp32).cpu().numpy()]
targets_s += [dn_data(sname, target).detach().to(fp32).cpu().numpy()]
target_data = target_aux_out_physical.physical[t_idx][sname]
t_coords_s += [target_data["target_coords"][i_batch].cpu().numpy()]
t_times_s += [
target_data["target_times"][i_batch].astype("datetime64[ns]")
]

targets_lens_step += [t.shape[0] for t in targets_s]
preds_step += [np.concatenate(preds_s, axis=1)]
targets_step += [np.concatenate(targets_s)]
targets_coords_step += [np.concatenate(t_coords_s)]
targets_times_step += [np.concatenate(t_times_s)]

if len(preds_step) == 0 or np.array([p.shape[1] for p in preds_step]).sum() == 0:
_logger.warning(f"Empty predictions for step {t_idx}")
continue

preds_chunk.append(preds_step)
targets_chunk.append(targets_step)
targets_coords_chunk.append(targets_coords_step)
targets_times_chunk.append(targets_times_step)
targets_lens_chunk.append([targets_lens_step])

if len(preds_chunk) == 0:
continue

# Determine forecast offset for this chunk
forecast_offset = chunk_indices[0]

data = io.OutputBatchData(
sources,
source_intervals,
preds_chunk,
targets_chunk,
targets_coords_chunk,
targets_times_chunk,
targets_lens_chunk,
output_streams,
target_channels,
source_channels,
geoinfo_channels,
sample_start,
forecast_offset,
)
for subset in data.items():
zio.write_zarr(subset)
Loading