From 124d8b07da731dc3afbdc9fa049ac08a4fcc9509 Mon Sep 17 00:00:00 2001 From: kcz358 Date: Wed, 20 May 2026 19:20:09 -0700 Subject: [PATCH] fix(profiler): dump memory snapshot on trainer OOM --- src/lmms_engine/train/fsdp2/fsdp2_trainer.py | 10 +++++++++- src/lmms_engine/utils/profiler.py | 15 ++++++++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/lmms_engine/train/fsdp2/fsdp2_trainer.py b/src/lmms_engine/train/fsdp2/fsdp2_trainer.py index 3a876f0e..be709fc6 100644 --- a/src/lmms_engine/train/fsdp2/fsdp2_trainer.py +++ b/src/lmms_engine/train/fsdp2/fsdp2_trainer.py @@ -367,7 +367,15 @@ def train(self, resume_from_checkpoint: bool = False): batch = send_to_device(batch, self.fsdp2_model.device) self.memory_snapshot_profiler.step(self.global_step) start_time = time.perf_counter() - train_metrics = self.training_step(batch) + try: + train_metrics = self.training_step(batch) + except torch.OutOfMemoryError: + self.memory_snapshot_profiler.dump_on_exception(f"oom_step{self.global_step}") + raise + except RuntimeError as e: + if "out of memory" in str(e).lower(): + self.memory_snapshot_profiler.dump_on_exception(f"oom_step{self.global_step}") + raise self.step_profiler.step() if self.step_profiler.should_save(self.global_step + 1): self.step_profiler.stop_and_save() diff --git a/src/lmms_engine/utils/profiler.py b/src/lmms_engine/utils/profiler.py index da5dd454..febf9060 100644 --- a/src/lmms_engine/utils/profiler.py +++ b/src/lmms_engine/utils/profiler.py @@ -1,4 +1,5 @@ import os +import time from contextlib import contextmanager, nullcontext from typing import Any, Dict, Optional @@ -119,16 +120,20 @@ def __init__( self.started = False self.stopped = False - def _dump(self, filename: str): - if not self.enable or self.stopped: + def _dump(self, filename: str, force: bool = False): + if not self.enable or (self.stopped and not force): return os.makedirs(self.directory, exist_ok=True) path = os.path.join(self.directory, filename) try: torch.cuda.memory._dump_snapshot(path) logger.info(f"[MemSnapshot] dumped snapshot to {path} (rank {self.rank})") - except Exception as e: - logger.error(f"[MemSnapshot] failed to dump snapshot: {e}") + except Exception: + logger.exception(f"[MemSnapshot] failed to dump snapshot to {path}") + + def dump_on_exception(self, reason: str): + timestamp = int(time.time()) + self._dump(f"snapshot_{reason}_rank{self.rank}_pid{os.getpid()}_{timestamp}.pickle", force=True) def _oom_observer(self, device, alloc, device_alloc, device_free): # Called by PyTorch BEFORE raising CUDA OOM. Dump current snapshot. @@ -137,7 +142,7 @@ def _oom_observer(self, device, alloc, device_alloc, device_free): f"attempted to alloc {alloc} bytes " f"(device_alloc={device_alloc}, device_free={device_free})" ) - self._dump(f"oom_rank{self.rank}.pickle") + self.dump_on_exception("oom_observer") # Mark stopped so we don't try to dump again on re-raise paths. self.stopped = True