diff --git a/config/default_config.yml b/config/default_config.yml index 9e0629ae0..67bcc3e76 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -130,6 +130,17 @@ train_logging: checkpoint: 250 log_grad_norms: False + # Detect anomalous training losses and optionally skip those batches. + loss_spike_detection: + enabled: False + window_size: 50 + min_history: 20 + ratio_threshold: 5.0 + loss_threshold: 0.0 + skip_batch: True + max_unique_times_per_step: 8 + file_name: "loss_spikes.jsonl" + # parameters for data loading data_loading : diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 21dd9390e..d7a0c808e 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -9,8 +9,10 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. import copy +import json import logging import time +from collections import deque from math import sqrt import numpy as np @@ -54,6 +56,17 @@ logger = logging.getLogger(__name__) +LOSS_SPIKE_DETECTION_DEFAULTS = { + "enabled": False, + "window_size": 50, + "min_history": 20, + "ratio_threshold": 5.0, + "loss_threshold": 0.0, + "skip_batch": True, + "max_unique_times_per_step": 8, + "file_name": "loss_spikes.jsonl", +} + # cfg_keys_to_filter = ["losses", "model_input", "target_input"] @@ -87,6 +100,9 @@ def __init__(self, train_logging: Config): self.batch_size_test_per_gpu = -1 self.collapse_monitor: CollapseMonitor | None = None self.perf_tracker: ThroughputTracker | NullThroughputTracker = NullThroughputTracker() + self.loss_spike_cfg = None + self.loss_spike_file = None + self.loss_spike_history = deque() def get_batch_size_total(self, batch_size_per_gpu) -> int: """ @@ -157,6 +173,7 @@ def init(self, cf: Config, devices): config.get_path_model(cf).mkdir(exist_ok=True, parents=True) self.train_logger = TrainLogger(cf, config.get_path_run(self.cf)) + self._init_loss_spike_detection() # Initialize collapse monitor for SSL training collapse_config = cf.train_logging.get("collapse_monitoring", {}) @@ -481,6 +498,18 @@ def train(self, mini_epoch): metadata=extract_batch_metadata(batch), istep=self.cf.general.istep, ) + loss_value = self._get_tensor_item(loss.detach()) + if self._maybe_log_loss_spike(loss_value, batch, mini_epoch, bidx): + self._drop_latest_loss_record() + self.optimizer.zero_grad() + if is_root(): + logger.warning( + "Skipping batch %s in mini_epoch %s due to loss spike: %.8E", + bidx, + mini_epoch, + loss_value, + ) + continue # TODO re-enable this, need to think on how to make it compatible with # student-teacher training @@ -910,6 +939,147 @@ def _get_tensor_item(self, tensor): """ return tensor.full_tensor().item() if isinstance(tensor, DTensor) else tensor.item() + def _init_loss_spike_detection(self) -> None: + configured_loss_spike_cfg = self.cf.train_logging.get("loss_spike_detection", {}) or {} + self.loss_spike_cfg = OmegaConf.merge( + OmegaConf.create(LOSS_SPIKE_DETECTION_DEFAULTS), + configured_loss_spike_cfg, + ) + window_size = int(self.loss_spike_cfg.window_size) + self.loss_spike_history = deque(maxlen=window_size) + self.loss_spike_file = None + + if not self.loss_spike_cfg.enabled: + return + + self.loss_spike_file = config.get_path_run(self.cf) / self.loss_spike_cfg.file_name + + def _serialize_datetimes(self, datetimes) -> list[str]: + if datetimes is None: + return [] + + datetimes_arr = np.asarray(datetimes).reshape(-1) + if datetimes_arr.size == 0: + return [] + + max_unique = int(self.loss_spike_cfg.max_unique_times_per_step) + return [str(dt) for dt in np.unique(datetimes_arr)[:max_unique]] + + @staticmethod + def _to_python_indices(indices): + if hasattr(indices, "astype") and hasattr(indices, "tolist"): + return indices.astype(int).tolist() + if isinstance(indices, list): + return [int(idx) for idx in indices] + if indices is None: + return None + return int(indices) + + @staticmethod + def _to_bool_list(value) -> list[bool]: + if isinstance(value, list): + return [bool(item) for item in value] + return [bool(value)] + + def _collect_sample_debug_info(self, sample, matching_indices) -> dict: + streams = {} + for stream_name, stream_data in sample.streams_data.items(): + if stream_data is None: + continue + + source_raw = getattr(stream_data, "source_raw", []) + target_times_raw = getattr(stream_data, "target_times_raw", []) + source_start_idx = int(stream_data.sample_idx) - len(source_raw) + 1 + + streams[stream_name] = { + "sample_idx": int(stream_data.sample_idx), + "source_is_spoof": self._to_bool_list(stream_data.source_is_spoof), + "target_is_spoof": self._to_bool_list(stream_data.target_is_spoof), + "source_step_indices": [source_start_idx + step for step in range(len(source_raw))], + "target_step_indices": list(range(len(target_times_raw))), + "source_step_datetimes": [ + self._serialize_datetimes(getattr(raw_data, "datetimes", None)) + for raw_data in source_raw + ], + "target_step_datetimes": [ + self._serialize_datetimes(datetimes) for datetimes in target_times_raw + ], + } + + return { + "matching_indices": self._to_python_indices(matching_indices), + "streams": streams, + } + + def _write_loss_spike_record( + self, loss_value, baseline, ratio, batch, mini_epoch, bidx + ) -> None: + if self.loss_spike_file is None: + return + + record = { + "run_id": str(self.cf.general.run_id), + "mini_epoch": int(mini_epoch), + "batch_index": int(bidx), + "global_step": int(self.cf.general.istep), + "loss": float(loss_value), + "loss_repr": f"{loss_value:.8E}", + "baseline_median": float(baseline), + "ratio_to_baseline": float(ratio), + "skip_batch": bool(self.loss_spike_cfg.skip_batch), + "source_samples": [ + self._collect_sample_debug_info(sample, batch.source2target_matching_idxs[sidx]) + for sidx, sample in enumerate(batch.source_samples.get_samples()) + ], + "target_samples": [ + self._collect_sample_debug_info(sample, batch.target2source_matching_idxs[tidx]) + for tidx, sample in enumerate(batch.target_samples.get_samples()) + ], + } + + with self.loss_spike_file.open("a", encoding="utf-8") as file_out: + file_out.write(json.dumps(record) + "\n") + + def _sync_loss_spike_skip(self, should_skip: bool) -> bool: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + skip_flag = torch.tensor( + [int(should_skip)], dtype=torch.int32, device=self.device or torch.device("cpu") + ) + torch.distributed.broadcast(skip_flag, src=0) + should_skip = bool(skip_flag.item()) + + return should_skip + + def _drop_latest_loss_record(self) -> None: + for hist_name in ("loss_hist", "losses_unweighted_hist", "stddev_unweighted_hist"): + hist = getattr(self.loss_calculator, hist_name) + if hist: + hist.pop() + + def _maybe_log_loss_spike(self, loss_value: float, batch, mini_epoch: int, bidx: int) -> bool: + if not self.loss_spike_cfg.enabled: + return False + + should_skip = False + if not is_root(): + return self._sync_loss_spike_skip(should_skip) + + is_finite = np.isfinite(loss_value) + min_history = int(self.loss_spike_cfg.min_history) + if len(self.loss_spike_history) >= min_history: + baseline = float(np.median(self.loss_spike_history)) + ratio = loss_value / baseline if baseline > 0 else np.inf + is_large_enough = loss_value >= float(self.loss_spike_cfg.loss_threshold) + is_spike = ratio >= float(self.loss_spike_cfg.ratio_threshold) + if (is_finite and is_large_enough and is_spike) or not is_finite: + self._write_loss_spike_record(loss_value, baseline, ratio, batch, mini_epoch, bidx) + should_skip = bool(self.loss_spike_cfg.skip_batch) + + if is_finite and not should_skip: + self.loss_spike_history.append(float(loss_value)) + + return self._sync_loss_spike_skip(should_skip) + def _log_instant_grad_norms(self, stage: Stage): """ Log instantaneous grad norms, we do not average because of the cost and because we want to