From 4357f921cabafa28da5057c5d3b71e46c6909048 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 26 Feb 2026 22:26:25 +0000 Subject: [PATCH 01/18] Add on-demand full-state checkpointing for OpenShift AI / KubeFlow preemption MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements signal-driven checkpoint-and-exit for distributed training jobs running in OpenShift AI as KubeFlow training jobs or multi-node bare metal. When `on_demand_checkpointing=True` is set in TrainingArgs: - Parent process (run_training) installs handlers for SIGTERM, SIGINT, SIGUSR1, SIGUSR2, SIGXCPU, and SIGHUP — covering all signals Kubernetes/OpenShift sends before the hard SIGKILL. - On signal receipt, a trigger file is atomically written to /dev/shm (tmpfs, shared within the pod, zero disk I/O). - Worker processes check for the trigger file after each optimizer step via an all_reduce(MAX) collective, ensuring global consensus across all ranks on all nodes. - When any rank detects the trigger, all ranks collectively save a full-state distributed checkpoint (model + optimizer + LR scheduler) then exit gracefully. - Parent waits up to 300s for workers to complete the checkpoint before proceeding with normal shutdown. https://claude.ai/code/session_01HSxsk7SnMULJxy7uafe7t3 --- src/instructlab/training/config.py | 13 + src/instructlab/training/main_ds.py | 102 ++++++- .../training/on_demand_checkpoint.py | 277 ++++++++++++++++++ 3 files changed, 385 insertions(+), 7 deletions(-) create mode 100644 src/instructlab/training/on_demand_checkpoint.py diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 34dfda98..87e29082 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -361,6 +361,19 @@ class TrainingArgs(BaseModel): description="How often to evaluate validation loss (in training steps). 
Required when validation_split > 0.", ) + on_demand_checkpointing: bool = Field( + default=False, + description=( + "Enable on-demand full-state checkpointing triggered by Unix signals. " + "When enabled, the parent process intercepts termination signals " + "(SIGTERM, SIGINT, SIGUSR1, SIGUSR2, SIGXCPU, SIGHUP) and writes a " + "trigger file to /dev/shm. Worker processes check for this trigger " + "after each training step and collectively save a distributed " + "checkpoint before exiting gracefully. Designed for OpenShift AI / " + "KubeFlow training jobs where preemption signals must be handled." + ), + ) + @model_validator(mode="after") def validate_validation_config(self): if not 0.0 <= self.validation_split < 1.0: diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index a9887800..6033e434 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -173,6 +173,7 @@ def train( accelerator: Accelerator, val_data_loader=None, validation_frequency=None, + on_demand_checkpointing: bool = False, ): model.train() @@ -183,6 +184,15 @@ def train( metric_logger = logging.getLogger("instructlab.training.metrics") base_logger = logging.getLogger("instructlab.training") + # Import on-demand checkpointing utilities once if the feature is enabled + if on_demand_checkpointing: + from instructlab.training.on_demand_checkpoint import ( + check_checkpoint_requested, + save_on_demand_checkpoint, + ) + + base_logger.info("On-demand checkpointing is enabled in worker process.") + # Mini_trainer approach: batch_size will be determined dynamically by data loader # For save logic, use effective_batch_size since that's the target samples_seen = 0 @@ -308,6 +318,22 @@ def train( base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank) dist.barrier() + # --- On-demand checkpointing: check if a signal triggered a save --- + if on_demand_checkpointing and check_checkpoint_requested(): + 
save_on_demand_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=model.tokenizer, + samples_seen=samples_seen, + epoch=epoch, + is_lora=bool(args.lora_r), + ) + base_logger.info( + "On-demand checkpoint saved. Exiting training gracefully." + ) + return + global_step += 1 if local_rank == 0: inner_pb.update(1) @@ -572,6 +598,7 @@ def main(args): accelerator=accelerator, val_data_loader=val_loader, validation_frequency=validation_frequency, + on_demand_checkpointing=getattr(args, "on_demand_checkpointing", False), ) dist.barrier() @@ -809,7 +836,24 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.keep_last_checkpoint_only: command.append("--keep_last_checkpoint_only") + if train_args.on_demand_checkpointing: + command.append("--on_demand_checkpointing") + logger.info("Running training command as subprocess: %s", " ".join(command)) + + # --- On-demand checkpointing: install signal handlers in the parent --- + signal_handler = None + if train_args.on_demand_checkpointing: + # First Party + from instructlab.training.on_demand_checkpoint import ParentSignalHandler + + signal_handler = ParentSignalHandler() + signal_handler.install() + logger.info( + "On-demand checkpointing is ENABLED. " + "Termination signals will trigger a full-state checkpoint before exit." + ) + process = None interrupt: KeyboardInterrupt | Exception | None = None failure = False @@ -829,19 +873,49 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: interrupt = e finally: if "process" not in locals() or process is None: + if signal_handler is not None: + signal_handler.uninstall() return + # If a signal was caught by the on-demand checkpoint handler, give + # the workers time to detect the trigger file and save a checkpoint + # before we start sending our own signals to the subprocess. 
+ if signal_handler is not None and signal_handler.signal_received is not None: + logger.info( + "On-demand checkpoint: signal %s received. Waiting for workers to " + "save checkpoint before proceeding with shutdown...", + signal_handler.signal_received.name, + ) + # Give workers generous time to complete the checkpoint save. + # The workers will exit on their own after saving. + try: + process.wait(timeout=300) + except subprocess.TimeoutExpired: + logger.warning( + "On-demand checkpoint: workers did not finish within 300s. " + "Proceeding with shutdown." + ) + # wait for the process to exit so we can properly read the exit code - process.wait(timeout=60) + try: + process.wait(timeout=60) + except subprocess.TimeoutExpired: + pass process_code = process.poll() - failure = process_code != 0 + failure = process_code is not None and process_code != 0 - if not failure: - logger.info("Operation completed successfully! 🎉") + if process_code is not None and not failure: + logger.info("Operation completed successfully!") else: - logger.error( - f"Training subprocess has not exited yet. Sending SIGTERM. Process code: {process_code}" - ) + if process_code is None: + logger.error( + "Training subprocess has not exited yet. Sending SIGTERM." + ) + else: + logger.error( + "Training subprocess exited with code %d. Sending SIGTERM.", + process_code, + ) process.terminate() try: @@ -853,6 +927,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: ) process.kill() + if signal_handler is not None: + signal_handler.uninstall() + if interrupt: raise interrupt if failure: @@ -1072,6 +1149,17 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: ), ) + parser.add_argument( + "--on_demand_checkpointing", + action="store_true", + default=False, + help=( + "Enable on-demand full-state checkpointing triggered by Unix signals. 
" + "When enabled, workers check for a trigger file in /dev/shm after each " + "training step and collectively save a distributed checkpoint before " + "exiting. Designed for OpenShift AI / KubeFlow preemption handling." + ), + ) parser.add_argument( "--use_liger", action="store_true", diff --git a/src/instructlab/training/on_demand_checkpoint.py b/src/instructlab/training/on_demand_checkpoint.py new file mode 100644 index 00000000..8d4e7462 --- /dev/null +++ b/src/instructlab/training/on_demand_checkpoint.py @@ -0,0 +1,277 @@ +# SPDX-License-Identifier: Apache-2.0 + +""" +On-demand checkpointing for distributed training. + +This module enables graceful checkpoint-and-exit when termination signals are +received. It is designed for environments like OpenShift AI / KubeFlow where +training jobs can be preempted at any time and the platform sends Unix signals +before killing the pod. + +Architecture +------------ +There are two sides to this feature: + +**Parent process** (``run_training`` in ``main_ds.py``): + Installs signal handlers that catch every signal OpenShift / Kubernetes can + send before a SIGKILL. When a signal arrives the handler writes a small + *trigger file* to ``/dev/shm`` (a tmpfs shared between containers in the + same pod). Because ``/dev/shm`` is node-local, every worker on the **same + node** can see the file instantly with zero network I/O. + +**Worker processes** (torchrun children): + After every optimizer step the training loop calls + ``check_checkpoint_requested()``. Each rank checks its local ``/dev/shm`` + for the trigger file, converts the boolean to a tensor, and does an + ``all_reduce(MAX)`` so that if *any* rank on *any* node detected the + trigger, *every* rank agrees to save a checkpoint. This works correctly in + multi-node training because all_reduce is a global collective. 
+ +Signals handled +--------------- +We intercept every signal that Kubernetes / OpenShift can deliver before the +hard SIGKILL (which cannot be caught): + +* **SIGTERM** – the standard graceful-shutdown signal. Kubernetes sends this + first (configurable via ``terminationGracePeriodSeconds``). +* **SIGINT** – sent on Ctrl-C or by some job controllers. +* **SIGUSR1 / SIGUSR2** – commonly used by batch schedulers and custom + preemption controllers to signal upcoming eviction. +* **SIGXCPU** – sent when CPU time limits are exceeded (relevant for jobs + with resource quotas). +* **SIGHUP** – sent when the controlling terminal disconnects; some + container runtimes forward this on pod eviction. +""" + +# Standard +import logging +import os +import signal +import tempfile +from pathlib import Path +from typing import Optional + +# Third Party +import torch +import torch.distributed as dist + +logger = logging.getLogger("instructlab.training") + +# --------------------------------------------------------------------------- +# Trigger file helpers +# --------------------------------------------------------------------------- + +# The trigger file lives in /dev/shm which is a tmpfs (RAM-backed filesystem). +# It is: +# 1. Extremely fast (no disk I/O). +# 2. Shared between all containers in the same Kubernetes pod. +# 3. Automatically cleaned up when the pod is destroyed. +_TRIGGER_DIR = Path("/dev/shm") +_TRIGGER_FILENAME = "instructlab_checkpoint_requested" + + +def _get_trigger_path(job_id: Optional[str] = None) -> Path: + """Return the path to the checkpoint trigger file. + + An optional *job_id* can be supplied to avoid collisions if multiple + training jobs share the same ``/dev/shm`` (unlikely but possible). + """ + name = f"{_TRIGGER_FILENAME}_{job_id}" if job_id else _TRIGGER_FILENAME + return _TRIGGER_DIR / name + + +def write_trigger_file(job_id: Optional[str] = None) -> Path: + """Create the trigger file that tells workers to checkpoint. 
+ + This is called from the *parent* process signal handler. + Returns the path that was written. + """ + path = _get_trigger_path(job_id) + # Use a atomic write via tempfile + rename to avoid partial reads. + fd, tmp = tempfile.mkstemp(dir=_TRIGGER_DIR, prefix=".ckpt_trigger_") + try: + os.write(fd, b"1") + finally: + os.close(fd) + os.rename(tmp, path) + logger.info( + "On-demand checkpoint trigger file written: %s", + path, + ) + return path + + +def trigger_file_exists(job_id: Optional[str] = None) -> bool: + """Check whether the trigger file exists (worker-side).""" + return _get_trigger_path(job_id).exists() + + +def remove_trigger_file(job_id: Optional[str] = None) -> None: + """Remove the trigger file after the checkpoint has been saved.""" + path = _get_trigger_path(job_id) + try: + path.unlink(missing_ok=True) + except OSError: + pass + + +# --------------------------------------------------------------------------- +# Parent-side signal handling +# --------------------------------------------------------------------------- + +# Signals that OpenShift / Kubernetes / batch schedulers may send before +# the hard SIGKILL. SIGKILL (9) and SIGSTOP (19) cannot be caught. +_CATCHABLE_SIGNALS = ( + signal.SIGTERM, # Kubernetes default graceful shutdown signal + signal.SIGINT, # Ctrl-C / some job controllers + signal.SIGUSR1, # Custom preemption controllers + signal.SIGUSR2, # Custom preemption controllers + signal.SIGXCPU, # CPU time limit exceeded (resource quotas) + signal.SIGHUP, # Terminal disconnect / some eviction paths +) + + +class ParentSignalHandler: + """Installs signal handlers in the parent (launcher) process. + + When any of the catchable signals fire, the handler: + 1. Writes the trigger file to ``/dev/shm``. + 2. Records that a signal was received (so the caller can decide to + wait for the child process to finish checkpointing). + + The handler is idempotent – multiple signals will not create multiple + trigger files. 
+ + Parameters + ---------- + job_id : str, optional + Unique identifier for this training job. Used to namespace the + trigger file. + """ + + def __init__(self, job_id: Optional[str] = None): + self.job_id = job_id + self.signal_received: Optional[signal.Signals] = None + self._original_handlers: dict[signal.Signals, object] = {} + self._trigger_written = False + + def install(self) -> None: + """Register signal handlers for all catchable signals.""" + for sig in _CATCHABLE_SIGNALS: + try: + self._original_handlers[sig] = signal.getsignal(sig) + signal.signal(sig, self._handle) + except (OSError, ValueError): + # Some signals may not be available on all platforms + logger.debug("Could not install handler for %s", sig.name) + + logger.info( + "On-demand checkpoint signal handlers installed for: %s", + ", ".join(s.name for s in self._original_handlers), + ) + + def uninstall(self) -> None: + """Restore original signal handlers.""" + for sig, handler in self._original_handlers.items(): + try: + signal.signal(sig, handler) + except (OSError, ValueError): + pass + self._original_handlers.clear() + + def _handle(self, signum: int, _frame) -> None: + """Signal handler callback.""" + sig = signal.Signals(signum) + logger.info( + "On-demand checkpoint: received signal %s (%d). " + "Writing trigger file for workers to checkpoint before exit.", + sig.name, + signum, + ) + self.signal_received = sig + + if not self._trigger_written: + write_trigger_file(self.job_id) + self._trigger_written = True + + +# --------------------------------------------------------------------------- +# Worker-side synchronization +# --------------------------------------------------------------------------- + + +def check_checkpoint_requested(job_id: Optional[str] = None) -> bool: + """Check across all ranks whether an on-demand checkpoint was requested. + + This function must be called by **all ranks** at the same point in the + training loop (it contains a collective all_reduce). 
+ + Returns ``True`` if any rank detected the trigger file, meaning all + ranks should save a checkpoint. + """ + local_trigger = trigger_file_exists(job_id) + + # Convert to a tensor and all-reduce (MAX) so that if ANY rank on ANY + # node saw the trigger, every rank gets True. + trigger_tensor = torch.tensor( + [1 if local_trigger else 0], + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + dist.all_reduce(trigger_tensor, op=dist.ReduceOp.MAX) + + requested = trigger_tensor.item() > 0 + + if requested: + logger.info( + "On-demand checkpoint: global consensus reached – " + "all ranks will save a checkpoint." + ) + # Clean up the trigger file so that if the process somehow + # continues, we don't save again immediately. + remove_trigger_file(job_id) + + return requested + + +def save_on_demand_checkpoint( + args, + accelerator, + model, + tokenizer, + samples_seen: int, + epoch: int, + is_lora: bool, +) -> None: + """Save a full-state distributed checkpoint for on-demand resume. + + This is a thin wrapper that calls the existing ``save_checkpoint`` + utility with ``full_state=True`` so that optimizer + LR scheduler + state are also persisted, enabling exact training resumption. 
+ """ + # First Party – imported here to avoid circular imports + from instructlab.training.utils import save_checkpoint + + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if local_rank == 0: + logger.info( + "On-demand checkpoint: saving full-state checkpoint at " + "epoch=%d, samples_seen=%d", + epoch, + samples_seen, + ) + + save_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=tokenizer, + samples_seen=samples_seen, + is_lora=is_lora, + full_state=True, + hf_format=True, + epoch=epoch, + ) + + if local_rank == 0: + logger.info("On-demand checkpoint: checkpoint saved successfully.") From 848f51b1608d3132f309489e7fa5803b5c28eaca Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Mon, 2 Mar 2026 19:28:43 +0000 Subject: [PATCH 02/18] Address review feedback for on-demand checkpointing - Fix mypy error: properly type _original_handlers dict with _SignalHandler type alias instead of bare object - Fix ruff/isort: remove duplicate comment, fix import ordering - Namespace trigger file with rdzv_id as job_id so concurrent jobs sharing /dev/shm don't interfere with each other - Recompute subprocess failure status after forced termination to avoid stale exit code - Gate consensus log message to rank 0 to reduce log noise on large jobs --- src/instructlab/training/main_ds.py | 57 ++++++++++++------- .../training/on_demand_checkpoint.py | 37 +++++++----- 2 files changed, 57 insertions(+), 37 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 6033e434..dcf46761 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -185,12 +185,15 @@ def train( base_logger = logging.getLogger("instructlab.training") # Import on-demand checkpointing utilities once if the feature is enabled + checkpoint_job_id = None if on_demand_checkpointing: + # First Party from instructlab.training.on_demand_checkpoint import ( 
check_checkpoint_requested, save_on_demand_checkpoint, ) + checkpoint_job_id = os.environ.get("INSTRUCTLAB_ON_DEMAND_JOB_ID") base_logger.info("On-demand checkpointing is enabled in worker process.") # Mini_trainer approach: batch_size will be determined dynamically by data loader @@ -319,7 +322,9 @@ def train( dist.barrier() # --- On-demand checkpointing: check if a signal triggered a save --- - if on_demand_checkpointing and check_checkpoint_requested(): + if on_demand_checkpointing and check_checkpoint_requested( + checkpoint_job_id + ): save_on_demand_checkpoint( args=args, accelerator=accelerator, @@ -847,11 +852,16 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: # First Party from instructlab.training.on_demand_checkpoint import ParentSignalHandler - signal_handler = ParentSignalHandler() + # Use rdzv_id to namespace the trigger file so concurrent jobs + # sharing /dev/shm don't interfere with each other. + checkpoint_job_id = str(torch_args.rdzv_id) + os.environ["INSTRUCTLAB_ON_DEMAND_JOB_ID"] = checkpoint_job_id + signal_handler = ParentSignalHandler(job_id=checkpoint_job_id) signal_handler.install() logger.info( - "On-demand checkpointing is ENABLED. " - "Termination signals will trigger a full-state checkpoint before exit." + "On-demand checkpointing is ENABLED (job_id=%s). " + "Termination signals will trigger a full-state checkpoint before exit.", + checkpoint_job_id, ) process = None @@ -902,30 +912,33 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: except subprocess.TimeoutExpired: pass process_code = process.poll() - failure = process_code is not None and process_code != 0 - if process_code is not None and not failure: + if process_code is not None and process_code == 0: logger.info("Operation completed successfully!") - else: - if process_code is None: - logger.error( - "Training subprocess has not exited yet. Sending SIGTERM." 
- ) - else: + elif process_code is None: + logger.error("Training subprocess has not exited yet. Sending SIGTERM.") + process.terminate() + try: + logger.info("Waiting for process to exit, 60s...") + process.wait(timeout=60) + except subprocess.TimeoutExpired: logger.error( - "Training subprocess exited with code %d. Sending SIGTERM.", - process_code, + "Training subprocess did not terminate before timeout, sending SIGKILL." ) - - process.terminate() - try: - logger.info("Waiting for process to exit, 60s...") - process.wait(timeout=60) - except subprocess.TimeoutExpired: + process.kill() + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + pass + else: logger.error( - "Training subprocess did not terminate before timeout, sending SIGKILL." + "Training subprocess exited with code %d.", + process_code, ) - process.kill() + + # Recompute final exit status after any forced shutdown + process_code = process.poll() + failure = process_code is None or process_code != 0 if signal_handler is not None: signal_handler.uninstall() diff --git a/src/instructlab/training/on_demand_checkpoint.py b/src/instructlab/training/on_demand_checkpoint.py index 8d4e7462..b9643531 100644 --- a/src/instructlab/training/on_demand_checkpoint.py +++ b/src/instructlab/training/on_demand_checkpoint.py @@ -44,17 +44,23 @@ """ # Standard +from pathlib import Path +from typing import Callable, Optional, Union import logging import os import signal import tempfile -from pathlib import Path -from typing import Optional +import types # Third Party import torch import torch.distributed as dist +# Type alias matching the return type of signal.getsignal(). 
+_SignalHandler = Union[ + Callable[[int, Optional[types.FrameType]], None], int, signal.Handlers, None +] + logger = logging.getLogger("instructlab.training") # --------------------------------------------------------------------------- @@ -122,12 +128,12 @@ def remove_trigger_file(job_id: Optional[str] = None) -> None: # Signals that OpenShift / Kubernetes / batch schedulers may send before # the hard SIGKILL. SIGKILL (9) and SIGSTOP (19) cannot be caught. _CATCHABLE_SIGNALS = ( - signal.SIGTERM, # Kubernetes default graceful shutdown signal - signal.SIGINT, # Ctrl-C / some job controllers - signal.SIGUSR1, # Custom preemption controllers - signal.SIGUSR2, # Custom preemption controllers - signal.SIGXCPU, # CPU time limit exceeded (resource quotas) - signal.SIGHUP, # Terminal disconnect / some eviction paths + signal.SIGTERM, # Kubernetes default graceful shutdown signal + signal.SIGINT, # Ctrl-C / some job controllers + signal.SIGUSR1, # Custom preemption controllers + signal.SIGUSR2, # Custom preemption controllers + signal.SIGXCPU, # CPU time limit exceeded (resource quotas) + signal.SIGHUP, # Terminal disconnect / some eviction paths ) @@ -152,7 +158,7 @@ class ParentSignalHandler: def __init__(self, job_id: Optional[str] = None): self.job_id = job_id self.signal_received: Optional[signal.Signals] = None - self._original_handlers: dict[signal.Signals, object] = {} + self._original_handlers: dict[signal.Signals, _SignalHandler] = {} self._trigger_written = False def install(self) -> None: @@ -174,7 +180,7 @@ def uninstall(self) -> None: """Restore original signal handlers.""" for sig, handler in self._original_handlers.items(): try: - signal.signal(sig, handler) + signal.signal(sig, handler) # type: ignore[arg-type] except (OSError, ValueError): pass self._original_handlers.clear() @@ -223,10 +229,11 @@ def check_checkpoint_requested(job_id: Optional[str] = None) -> bool: requested = trigger_tensor.item() > 0 if requested: - logger.info( - "On-demand 
checkpoint: global consensus reached – " - "all ranks will save a checkpoint." - ) + if dist.is_initialized() and dist.get_rank() == 0: + logger.info( + "On-demand checkpoint: global consensus reached – " + "all ranks will save a checkpoint." + ) # Clean up the trigger file so that if the process somehow # continues, we don't save again immediately. remove_trigger_file(job_id) @@ -249,7 +256,7 @@ def save_on_demand_checkpoint( utility with ``full_state=True`` so that optimizer + LR scheduler state are also persisted, enabling exact training resumption. """ - # First Party – imported here to avoid circular imports + # First Party from instructlab.training.utils import save_checkpoint local_rank = int(os.environ.get("LOCAL_RANK", "0")) From 4959fe3afa8aaa21cebaad8ee2726112f1feffc8 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Mon, 2 Mar 2026 20:27:05 +0000 Subject: [PATCH 03/18] Check for on-demand checkpoint after each minibatch backward Move the checkpoint request check from after the full optimizer step to after each minibatch's backward pass inside BatchLossManager.process_batch. This ensures the system responds within one fwd+bwd cycle (~1-2s) even when gradient accumulation spans many minibatches, giving more time to save before Kubernetes sends SIGKILL after the grace period. The check is passed as an optional interrupt_check callback to keep checkpoint-specific logic out of BatchLossManager. When triggered, the batch loop breaks early and the training loop saves the checkpoint immediately, skipping the optimizer step to preserve the pre-step model state for exact resumption. 
--- .../training/batch_loss_manager.py | 22 ++++++++- src/instructlab/training/main_ds.py | 49 +++++++++++-------- 2 files changed, 49 insertions(+), 22 deletions(-) diff --git a/src/instructlab/training/batch_loss_manager.py b/src/instructlab/training/batch_loss_manager.py index cc6da021..b199ac17 100644 --- a/src/instructlab/training/batch_loss_manager.py +++ b/src/instructlab/training/batch_loss_manager.py @@ -7,7 +7,8 @@ """ # Standard -from dataclasses import dataclass +from collections.abc import Callable +from dataclasses import dataclass, field import logging # Third Party @@ -33,6 +34,7 @@ class BatchMetrics: accumulated_aux_loss: torch.Tensor | None grad_accum_steps: int num_minibatches: int + interrupted: bool = field(default=False) class BatchLossManager: @@ -62,12 +64,21 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int): self.local_rank: int = local_rank self.torch_device = torch.device("cuda", local_rank) - def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]: + def process_batch( + self, + batch: list[CollatedItem], + interrupt_check: Callable[[], bool] | None = None, + ) -> tuple[BatchMetrics, float]: """ Process a batch of minibatches, computing losses and accumulating gradients. Args: batch: List of minibatches to process + interrupt_check: Optional callback invoked after each minibatch's + backward pass. If it returns ``True``, gradient accumulation + stops early and ``BatchMetrics.interrupted`` is set. Used by + on-demand checkpointing to react within one fwd+bwd cycle + instead of waiting for the full optimizer step. 
Returns: tuple: (BatchMetrics, average_loss_across_ranks) @@ -82,6 +93,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] accumulated_loss = 0.0 accumulated_aux_loss = 0.0 grad_accum_steps = 0 + interrupted = False # process each minibatch for mb in batch: @@ -108,6 +120,11 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] if raw_losses.aux_loss is not None: accumulated_aux_loss += raw_losses.aux_loss + # check for early exit (e.g. on-demand checkpoint requested) + if interrupt_check is not None and interrupt_check(): + interrupted = True + break + # reduce metrics across ranks batch_total_samples, batch_total_length = self._reduce_metrics( batch_total_samples, batch_total_length @@ -127,6 +144,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] accumulated_aux_loss=accumulated_aux_loss, grad_accum_steps=grad_accum_steps, num_minibatches=num_minibatches, + interrupted=interrupted, ) return metrics, avg_loss_across_ranks diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index dcf46761..b3dc5713 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -233,11 +233,38 @@ def train( continue start = time.time() - # Process the batch using the BatchLossManager + # Process the batch using the BatchLossManager. + # When on-demand checkpointing is enabled, pass a callback so + # the check runs after every minibatch backward rather than + # waiting for the full optimizer step. 
+ _interrupt_check = ( + (lambda: check_checkpoint_requested(checkpoint_job_id)) + if on_demand_checkpointing + else None + ) batch_metrics, avg_loss_across_ranks = batch_loss_manager.process_batch( - batch + batch, interrupt_check=_interrupt_check ) + # If the batch was interrupted by an on-demand checkpoint + # request, save immediately and exit — skip the optimizer step + # since we want to preserve the pre-step model state for + # exact resumption. + if batch_metrics.interrupted: + save_on_demand_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=model.tokenizer, + samples_seen=samples_seen, + epoch=epoch, + is_lora=bool(args.lora_r), + ) + base_logger.info( + "On-demand checkpoint saved. Exiting training gracefully." + ) + return + # Update samples seen samples_seen += batch_metrics.total_samples @@ -321,24 +348,6 @@ def train( base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank) dist.barrier() - # --- On-demand checkpointing: check if a signal triggered a save --- - if on_demand_checkpointing and check_checkpoint_requested( - checkpoint_job_id - ): - save_on_demand_checkpoint( - args=args, - accelerator=accelerator, - model=model, - tokenizer=model.tokenizer, - samples_seen=samples_seen, - epoch=epoch, - is_lora=bool(args.lora_r), - ) - base_logger.info( - "On-demand checkpoint saved. Exiting training gracefully." - ) - return - global_step += 1 if local_rank == 0: inner_pb.update(1) From 5b8f1d63eb29a24da347154d5cfeb53d04e4af14 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Mon, 2 Mar 2026 20:29:01 +0000 Subject: [PATCH 04/18] Add diagnostic note when on-demand checkpoint fails to save in time When the training subprocess fails after an on-demand checkpoint signal was received, the error message now includes guidance to increase terminationGracePeriodSeconds or reduce fwd/bwd pass time so the checkpoint check fires before SIGKILL arrives. 
--- src/instructlab/training/main_ds.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index b3dc5713..02adc240 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -955,9 +955,22 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if interrupt: raise interrupt if failure: - raise RuntimeError( - "Suffered a failure during distributed training. Please see the training logs for more context." - ) + msg = "Suffered a failure during distributed training. Please see the training logs for more context." + if ( + signal_handler is not None + and signal_handler.signal_received is not None + ): + msg += ( + f"\n\nNote: signal {signal_handler.signal_received.name} was" + " received and on-demand checkpointing was enabled, but the" + " training subprocess did not exit cleanly. This usually" + " means the process was killed (SIGKILL) before the" + " checkpoint could be saved. To fix this, increase" + " terminationGracePeriodSeconds in your pod spec to give" + " workers more time, or reduce the model's forward/backward" + " pass time so the checkpoint check fires sooner." + ) + raise RuntimeError(msg) if __name__ == "__main__": From 6fb1bb5e54b73e0530620bc9e33de73d1c731b08 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Mon, 2 Mar 2026 20:38:56 +0000 Subject: [PATCH 05/18] Fix help text: checkpoint check happens after each minibatch backward Update --on_demand_checkpointing help text and TrainingArgs description to accurately state that workers check for the trigger file after each minibatch backward pass, not after each training step. 
--- src/instructlab/training/config.py | 2 +- src/instructlab/training/main_ds.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 87e29082..8811ead0 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -368,7 +368,7 @@ class TrainingArgs(BaseModel): "When enabled, the parent process intercepts termination signals " "(SIGTERM, SIGINT, SIGUSR1, SIGUSR2, SIGXCPU, SIGHUP) and writes a " "trigger file to /dev/shm. Worker processes check for this trigger " - "after each training step and collectively save a distributed " + "after each minibatch backward pass and collectively save a distributed " "checkpoint before exiting gracefully. Designed for OpenShift AI / " "KubeFlow training jobs where preemption signals must be handled." ), diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 02adc240..f5edd015 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -1191,7 +1191,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: help=( "Enable on-demand full-state checkpointing triggered by Unix signals. " "When enabled, workers check for a trigger file in /dev/shm after each " - "training step and collectively save a distributed checkpoint before " + "minibatch backward pass and collectively save a distributed checkpoint before " "exiting. Designed for OpenShift AI / KubeFlow preemption handling." ), ) From d7b965b1a7215641385246ee11debb2949891972 Mon Sep 17 00:00:00 2001 From: Oleg S <97077423+RobotSail@users.noreply.github.com> Date: Fri, 20 Mar 2026 10:27:58 -0400 Subject: [PATCH 06/18] Add checkpoint checks at 5 synchronization points per training step Expand on-demand checkpointing to check for a trigger at five points: 1. Before each minibatch forward pass 2. Before each minibatch backward pass 3. 
After each minibatch backward pass (existing) 4. Before the optimizer step 5. After the optimizer step This minimizes the latency between a termination signal arriving and the checkpoint being saved, which is critical when the SIGKILL grace period is short (e.g. 30s on OpenShift/Kubernetes). Also cleans up the save-and-exit logic in train() by extracting a _save_and_exit() helper to eliminate three nearly identical blocks, and fixes _compute_average_loss to handle the case where the minibatch loop is interrupted before any forward pass completes. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../training/batch_loss_manager.py | 37 +++++++++---- src/instructlab/training/main_ds.py | 53 +++++++++++++------ .../training/on_demand_checkpoint.py | 28 +++++++--- 3 files changed, 85 insertions(+), 33 deletions(-) diff --git a/src/instructlab/training/batch_loss_manager.py b/src/instructlab/training/batch_loss_manager.py index b199ac17..46e2af30 100644 --- a/src/instructlab/training/batch_loss_manager.py +++ b/src/instructlab/training/batch_loss_manager.py @@ -74,10 +74,11 @@ def process_batch( Args: batch: List of minibatches to process - interrupt_check: Optional callback invoked after each minibatch's - backward pass. If it returns ``True``, gradient accumulation + interrupt_check: Optional callback invoked at three points per + minibatch: before forward, before backward, and after + backward. If it returns ``True`` at any point, processing stops early and ``BatchMetrics.interrupted`` is set. Used by - on-demand checkpointing to react within one fwd+bwd cycle + on-demand checkpointing to react as quickly as possible instead of waiting for the full optimizer step. 
Returns: @@ -97,6 +98,11 @@ def process_batch( # process each minibatch for mb in batch: + # Check for on-demand checkpoint before forward + if interrupt_check is not None and interrupt_check(): + interrupted = True + break + # extract minibatch-specific info micro_batch_size = mb["num_samples"] total_length = mb["total_length"] @@ -108,10 +114,16 @@ def process_batch( # prepare model inputs model_inputs = self._prepare_model_inputs(mb) - # compute loss and backward pass + # compute loss (forward pass) scaled_loss, raw_losses = self.model.compute_loss( model_inputs, self.world_size, batch_num_loss_counted_tokens ) + + # Check for on-demand checkpoint before backward + if interrupt_check is not None and interrupt_check(): + interrupted = True + break + self.accelerator.backward(scaled_loss) # accumulate losses @@ -120,7 +132,7 @@ def process_batch( if raw_losses.aux_loss is not None: accumulated_aux_loss += raw_losses.aux_loss - # check for early exit (e.g. on-demand checkpoint requested) + # Check for on-demand checkpoint after backward if interrupt_check is not None and interrupt_check(): interrupted = True break @@ -183,8 +195,8 @@ def _reduce_metrics( def _compute_average_loss( self, - accumulated_loss: torch.Tensor, - accumulated_aux_loss: torch.Tensor | None, + accumulated_loss: torch.Tensor | float, + accumulated_aux_loss: torch.Tensor | float | None, batch_num_loss_counted_tokens: int, ) -> float: """Compute average loss across all ranks for metrics logging.""" @@ -195,11 +207,16 @@ def _compute_average_loss( if accumulated_aux_loss is not None: total_batch_loss += accumulated_aux_loss + # Extract scalar value — accumulated_loss may be a plain float if the + # minibatch loop was interrupted before any forward pass completed. 
+ if isinstance(total_batch_loss, torch.Tensor): + loss_value = total_batch_loss.detach().item() + else: + loss_value = float(total_batch_loss) + # reduce across ranks avg_loss_across_ranks = self.accelerator.reduce( - torch.tensor( - total_batch_loss.detach().item(), device=self.accelerator.device - ), + torch.tensor(loss_value, device=self.accelerator.device), reduction="mean", ).item() diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index f5edd015..2870bc51 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -196,6 +196,22 @@ def train( checkpoint_job_id = os.environ.get("INSTRUCTLAB_ON_DEMAND_JOB_ID") base_logger.info("On-demand checkpointing is enabled in worker process.") + def _save_and_exit(checkpoint_location: str) -> None: + """Save an on-demand checkpoint and exit the training loop.""" + save_on_demand_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=model.tokenizer, + samples_seen=samples_seen, + epoch=epoch, + is_lora=bool(args.lora_r), + ) + base_logger.info( + "On-demand checkpoint saved (%s). Exiting training.", + checkpoint_location, + ) + # Mini_trainer approach: batch_size will be determined dynamically by data loader # For save logic, use effective_batch_size since that's the target samples_seen = 0 @@ -251,22 +267,14 @@ def train( # since we want to preserve the pre-step model state for # exact resumption. if batch_metrics.interrupted: - save_on_demand_checkpoint( - args=args, - accelerator=accelerator, - model=model, - tokenizer=model.tokenizer, - samples_seen=samples_seen, - epoch=epoch, - is_lora=bool(args.lora_r), - ) - base_logger.info( - "On-demand checkpoint saved. Exiting training gracefully." 
- ) + _save_and_exit("during minibatch processing") return - # Update samples seen - samples_seen += batch_metrics.total_samples + if on_demand_checkpointing and check_checkpoint_requested( + checkpoint_job_id + ): + _save_and_exit("before optimizer step") + return base_logger.info( f"Epoch: {epoch}, Step: {global_step}, Rank: {dist.get_rank()}, loss = {avg_loss_across_ranks:.6f}, grad_accum_steps = {batch_metrics.grad_accum_steps}" @@ -275,6 +283,15 @@ def train( # Take optimizer step after all minibatches accelerator.take_optimizer_step() + # Update samples seen after the optimizer step has been applied + samples_seen += batch_metrics.total_samples + + if on_demand_checkpointing and check_checkpoint_requested( + checkpoint_job_id + ): + _save_and_exit("after optimizer step") + return + if local_rank == 0: elapsed_time = time.time() - start overall_throughput = batch_metrics.total_samples / elapsed_time @@ -1190,9 +1207,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: default=False, help=( "Enable on-demand full-state checkpointing triggered by Unix signals. " - "When enabled, workers check for a trigger file in /dev/shm after each " - "minibatch backward pass and collectively save a distributed checkpoint before " - "exiting. Designed for OpenShift AI / KubeFlow preemption handling." + "When enabled, workers check for a trigger file in /dev/shm at five " + "synchronization points per step (before/after each minibatch forward " + "and backward pass, and before/after the optimizer step) and collectively " + "save a distributed checkpoint before exiting. Designed for OpenShift AI / " + "KubeFlow preemption handling." 
), ) parser.add_argument( diff --git a/src/instructlab/training/on_demand_checkpoint.py b/src/instructlab/training/on_demand_checkpoint.py index b9643531..20238ee5 100644 --- a/src/instructlab/training/on_demand_checkpoint.py +++ b/src/instructlab/training/on_demand_checkpoint.py @@ -20,12 +20,28 @@ node** can see the file instantly with zero network I/O. **Worker processes** (torchrun children): - After every optimizer step the training loop calls - ``check_checkpoint_requested()``. Each rank checks its local ``/dev/shm`` - for the trigger file, converts the boolean to a tensor, and does an - ``all_reduce(MAX)`` so that if *any* rank on *any* node detected the - trigger, *every* rank agrees to save a checkpoint. This works correctly in - multi-node training because all_reduce is a global collective. + The training loop calls ``check_checkpoint_requested()`` at five + synchronization points per training step, allowing the system to + react as quickly as possible to termination signals: + + 1. **Before each minibatch forward pass** — no partial computation; + the current state is saved as-is. + 2. **Before each minibatch backward pass** — the forward result is + discarded; the pre-step state is saved. + 3. **After each minibatch backward pass** — gradients are computed but + not yet applied; the pre-step state is saved (gradients will be + recomputed on resume). + 4. **Before the optimizer step** — all minibatches are done and + gradients are ready, but the step is skipped; the pre-step state + is saved. + 5. **After the optimizer step** — the step has been applied; + ``samples_seen`` is updated and the post-step state is saved. + + Each rank checks its local ``/dev/shm`` for the trigger file, converts + the boolean to a tensor, and does an ``all_reduce(MAX)`` so that if + *any* rank on *any* node detected the trigger, *every* rank agrees to + save a checkpoint. This works correctly in multi-node training because + all_reduce is a global collective. 
Signals handled --------------- From 25752c1059e91ef582b2b5cb1acd3e5a58e0433d Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Mon, 13 Apr 2026 14:46:11 +0000 Subject: [PATCH 07/18] test: add unit tests for on-demand checkpointing 24 tests covering: - Trigger file helpers (path generation, write, exists, remove) - ParentSignalHandler (install, handle, idempotency, uninstall, real signal) - check_checkpoint_requested (trigger detection, cleanup, all_reduce consensus) - BatchLossManager interrupt handling (all 3 check points, early exit, float loss) --- tests/unit/test_on_demand_checkpoint.py | 349 ++++++++++++++++++++++++ 1 file changed, 349 insertions(+) create mode 100644 tests/unit/test_on_demand_checkpoint.py diff --git a/tests/unit/test_on_demand_checkpoint.py b/tests/unit/test_on_demand_checkpoint.py new file mode 100644 index 00000000..a805f49f --- /dev/null +++ b/tests/unit/test_on_demand_checkpoint.py @@ -0,0 +1,349 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for on-demand checkpointing.""" + +# Standard +from unittest.mock import MagicMock, call, patch +import os +import signal + +# Third Party +import pytest +import torch + +# First Party +from instructlab.training.on_demand_checkpoint import ( + _CATCHABLE_SIGNALS, + ParentSignalHandler, + _get_trigger_path, + check_checkpoint_requested, + remove_trigger_file, + trigger_file_exists, + write_trigger_file, +) + +# --------------------------------------------------------------------------- +# Trigger file helpers +# --------------------------------------------------------------------------- + + +class TestGetTriggerPath: + def test_without_job_id(self): + path = _get_trigger_path() + assert path.name == "instructlab_checkpoint_requested" + assert str(path.parent) == "/dev/shm" + + def test_with_job_id(self): + path = _get_trigger_path("my-job-123") + assert path.name == "instructlab_checkpoint_requested_my-job-123" + + def 
test_different_job_ids_produce_different_paths(self): + p1 = _get_trigger_path("job-a") + p2 = _get_trigger_path("job-b") + assert p1 != p2 + + +class TestWriteTriggerFile: + def test_creates_file(self, tmp_path): + with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): + path = write_trigger_file("test-write") + assert path.exists() + assert path.read_text() == "1" + + def test_returns_correct_path(self, tmp_path): + with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): + path = write_trigger_file("test-path") + assert path == tmp_path / "instructlab_checkpoint_requested_test-path" + + +class TestTriggerFileExists: + def test_returns_false_when_absent(self, tmp_path): + with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): + assert trigger_file_exists("nonexistent") is False + + def test_returns_true_when_present(self, tmp_path): + with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): + write_trigger_file("exists") + assert trigger_file_exists("exists") is True + + +class TestRemoveTriggerFile: + def test_removes_existing_file(self, tmp_path): + with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): + write_trigger_file("to-remove") + assert trigger_file_exists("to-remove") is True + remove_trigger_file("to-remove") + assert trigger_file_exists("to-remove") is False + + def test_noop_on_missing_file(self, tmp_path): + with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): + # Should not raise + remove_trigger_file("never-existed") + + +# --------------------------------------------------------------------------- +# ParentSignalHandler +# --------------------------------------------------------------------------- + + +class TestParentSignalHandler: + def test_install_registers_handlers(self): + handler = ParentSignalHandler(job_id="test-install") + original_handlers = {sig: signal.getsignal(sig) for sig in 
_CATCHABLE_SIGNALS} + try: + handler.install() + for sig in _CATCHABLE_SIGNALS: + current = signal.getsignal(sig) + assert current == handler._handle, ( + f"Expected handler._handle for {sig.name}, got {current}" + ) + finally: + # Restore originals regardless + for sig, orig in original_handlers.items(): + signal.signal(sig, orig) + + def test_handle_writes_trigger_and_records_signal(self, tmp_path): + with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): + handler = ParentSignalHandler(job_id="test-handle") + assert handler.signal_received is None + assert handler._trigger_written is False + + handler._handle(signal.SIGUSR1, None) + + assert handler.signal_received == signal.SIGUSR1 + assert handler._trigger_written is True + assert trigger_file_exists("test-handle") is True + + def test_handle_is_idempotent(self, tmp_path): + """Multiple signals should only write the trigger file once.""" + with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): + handler = ParentSignalHandler(job_id="test-idempotent") + + with patch( + "instructlab.training.on_demand_checkpoint.write_trigger_file" + ) as mock_write: + mock_write.return_value = tmp_path / "dummy" + handler._handle(signal.SIGUSR1, None) + handler._handle(signal.SIGTERM, None) + handler._handle(signal.SIGINT, None) + + # write_trigger_file called only once + mock_write.assert_called_once_with("test-idempotent") + + # signal_received should be the LAST signal + assert handler.signal_received == signal.SIGINT + + def test_uninstall_restores_original_handlers(self): + handler = ParentSignalHandler(job_id="test-uninstall") + originals = {sig: signal.getsignal(sig) for sig in _CATCHABLE_SIGNALS} + + handler.install() + # Verify handlers changed + for sig in _CATCHABLE_SIGNALS: + assert signal.getsignal(sig) == handler._handle + + handler.uninstall() + # Verify handlers restored + for sig in _CATCHABLE_SIGNALS: + assert signal.getsignal(sig) == originals[sig], 
f"{sig.name} not restored" + + def test_install_via_real_signal(self, tmp_path): + """End-to-end: install handler, send SIGUSR1, verify trigger written.""" + with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): + handler = ParentSignalHandler(job_id="test-real-signal") + handler.install() + try: + os.kill(os.getpid(), signal.SIGUSR1) + assert handler.signal_received == signal.SIGUSR1 + assert trigger_file_exists("test-real-signal") is True + finally: + handler.uninstall() + remove_trigger_file("test-real-signal") + + +# --------------------------------------------------------------------------- +# check_checkpoint_requested (worker-side, mocked dist) +# --------------------------------------------------------------------------- + + +class TestCheckCheckpointRequested: + def _mock_all_reduce_propagate(self, tensor, op=None): + """Mock all_reduce that just keeps the local value.""" + pass # tensor already has the local value + + def test_returns_false_when_no_trigger(self, tmp_path): + with ( + patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path), + patch("instructlab.training.on_demand_checkpoint.dist") as mock_dist, + patch("torch.cuda.current_device", return_value=0), + ): + mock_dist.all_reduce = self._mock_all_reduce_propagate + mock_dist.is_initialized.return_value = True + mock_dist.get_rank.return_value = 0 + + result = check_checkpoint_requested("test-no-trigger") + assert result is False + + def test_returns_true_when_trigger_exists(self, tmp_path): + with ( + patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path), + patch("instructlab.training.on_demand_checkpoint.dist") as mock_dist, + patch("torch.cuda.current_device", return_value=0), + ): + mock_dist.all_reduce = self._mock_all_reduce_propagate + mock_dist.is_initialized.return_value = True + mock_dist.get_rank.return_value = 0 + + write_trigger_file("test-trigger") + result = check_checkpoint_requested("test-trigger") + assert result 
is True + + def test_cleans_up_trigger_after_detection(self, tmp_path): + with ( + patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path), + patch("instructlab.training.on_demand_checkpoint.dist") as mock_dist, + patch("torch.cuda.current_device", return_value=0), + ): + mock_dist.all_reduce = self._mock_all_reduce_propagate + mock_dist.is_initialized.return_value = True + mock_dist.get_rank.return_value = 0 + + write_trigger_file("test-cleanup") + check_checkpoint_requested("test-cleanup") + assert trigger_file_exists("test-cleanup") is False + + def test_all_reduce_is_called(self, tmp_path): + with ( + patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path), + patch("instructlab.training.on_demand_checkpoint.dist") as mock_dist, + patch("torch.cuda.current_device", return_value=0), + ): + mock_dist.all_reduce = MagicMock() + mock_dist.is_initialized.return_value = True + mock_dist.get_rank.return_value = 0 + mock_dist.ReduceOp.MAX = torch.distributed.ReduceOp.MAX + + check_checkpoint_requested("test-allreduce") + mock_dist.all_reduce.assert_called_once() + # Verify MAX reduction op + _, kwargs = mock_dist.all_reduce.call_args + assert kwargs.get("op") == torch.distributed.ReduceOp.MAX + + +# --------------------------------------------------------------------------- +# BatchLossManager.process_batch interrupt handling +# --------------------------------------------------------------------------- + + +class TestBatchLossManagerInterrupt: + """Test that interrupt_check callbacks stop processing correctly.""" + + @pytest.fixture + def manager(self): + model = MagicMock() + model.compute_loss.return_value = ( + torch.tensor(1.0, requires_grad=True), + MagicMock(main_loss=torch.tensor(0.5), aux_loss=None), + ) + accelerator = MagicMock() + accelerator.device = torch.device("cpu") + # reduce is called with a 2-element tensor (metrics) and a scalar (loss). + # Return the input unchanged to simulate single-rank "reduction". 
+ accelerator.reduce.side_effect = lambda t, **kw: t + accelerator.backward = MagicMock() + + # First Party + from instructlab.training.batch_loss_manager import BatchLossManager + + mgr = BatchLossManager( + model=model, + accelerator=accelerator, + world_size=1, + local_rank=0, + ) + return mgr + + def _make_batch(self, n_minibatches=3): + """Create a fake batch with n minibatches.""" + return [ + { + "input_ids": torch.randint(0, 100, (2, 32)), + "labels": torch.randint(0, 100, (2, 32)), + "num_samples": 2, + "total_length": 32, + "batch_num_loss_counted_tokens": 64, + } + for _ in range(n_minibatches) + ] + + def test_no_interrupt_processes_all_minibatches(self, manager): + batch = self._make_batch(3) + metrics, _ = manager.process_batch(batch, interrupt_check=None) + assert metrics.interrupted is False + assert metrics.grad_accum_steps == 3 + + def test_interrupt_before_first_forward(self, manager): + """Interrupt fires immediately — no forward/backward should run.""" + batch = self._make_batch(3) + metrics, _ = manager.process_batch(batch, interrupt_check=lambda: True) + assert metrics.interrupted is True + assert metrics.grad_accum_steps == 0 + manager.model.compute_loss.assert_not_called() + manager.accelerator.backward.assert_not_called() + + def test_interrupt_before_backward(self, manager): + """Interrupt fires after forward but before backward.""" + call_count = 0 + + def interrupt_on_second_call(): + nonlocal call_count + call_count += 1 + # First call: before forward — let it pass + # Second call: before backward — interrupt + return call_count == 2 + + batch = self._make_batch(3) + metrics, _ = manager.process_batch( + batch, interrupt_check=interrupt_on_second_call + ) + assert metrics.interrupted is True + # Forward ran once, backward never ran + assert manager.model.compute_loss.call_count == 1 + manager.accelerator.backward.assert_not_called() + assert metrics.grad_accum_steps == 0 + + def test_interrupt_after_backward(self, manager): + 
"""Interrupt fires after first backward — one grad accum step done.""" + call_count = 0 + + def interrupt_on_third_call(): + nonlocal call_count + call_count += 1 + # Calls: 1=before_fwd, 2=before_bwd, 3=after_bwd (interrupt) + return call_count == 3 + + batch = self._make_batch(3) + metrics, _ = manager.process_batch( + batch, interrupt_check=interrupt_on_third_call + ) + assert metrics.interrupted is True + assert metrics.grad_accum_steps == 1 + manager.model.compute_loss.assert_called_once() + manager.accelerator.backward.assert_called_once() + + def test_interrupt_never_fires(self, manager): + """interrupt_check always returns False — full batch processed.""" + batch = self._make_batch(3) + metrics, _ = manager.process_batch(batch, interrupt_check=lambda: False) + assert metrics.interrupted is False + assert metrics.grad_accum_steps == 3 + + def test_compute_average_loss_handles_float_when_interrupted(self, manager): + """When interrupted before any forward, accumulated_loss is 0.0 (float).""" + # _compute_average_loss must handle float, not just Tensor + result = manager._compute_average_loss( + accumulated_loss=0.0, + accumulated_aux_loss=None, + batch_num_loss_counted_tokens=64, + ) + # Should not raise and should return a float + assert isinstance(result, float) From a4e83d5e911782e0288925df2b5b1f3a1b1384f2 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Mon, 13 Apr 2026 15:18:52 +0000 Subject: [PATCH 08/18] fix: correct help text for on_demand_checkpointing sync point count --- src/instructlab/training/main_ds.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 2870bc51..179f0a74 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -1207,11 +1207,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: default=False, help=( "Enable 
on-demand full-state checkpointing triggered by Unix signals. " - "When enabled, workers check for a trigger file in /dev/shm at five " - "synchronization points per step (before/after each minibatch forward " - "and backward pass, and before/after the optimizer step) and collectively " - "save a distributed checkpoint before exiting. Designed for OpenShift AI / " - "KubeFlow preemption handling." + "When enabled, workers check for a trigger file in /dev/shm at " + "multiple synchronization points per step (before/after each " + "minibatch forward and backward pass, and before/after the optimizer " + "step) and collectively save a distributed checkpoint before exiting. " + "Designed for OpenShift AI / KubeFlow preemption handling." ), ) parser.add_argument( From afff1e66d89058d8a4fa4258b27f125194f0d55b Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:01:31 +0000 Subject: [PATCH 09/18] Add stale trigger cleanup and exact mid-epoch resume Two fixes to the on-demand checkpointing feature: 1. Stale trigger file cleanup: ParentSignalHandler.install() now checks for and removes any existing trigger file before installing signal handlers. If the file exists before handlers are installed, it's from a previous run that was killed before workers could clean it up. Prevents a new training job from immediately checkpointing and exiting. 2. Exact mid-epoch resume: save_on_demand_checkpoint() now persists global_step in the checkpoint metadata alongside current_epoch and samples_seen. On resume, load_latest_full_state() detects the global_step field and sets last_step accordingly, so the training loop fast-forwards to the exact step within the epoch. Without this, mid-epoch checkpoints would skip to the next epoch on resume, losing remaining steps. Tested with Qwen2-1.5B-Instruct on 2 GPUs: interrupted at step 19/25, checkpoint saved with global_step=19, resumed and completed steps 20-25. 
--- src/instructlab/training/main_ds.py | 1 + .../training/on_demand_checkpoint.py | 28 ++++++++++++++++++- src/instructlab/training/utils.py | 26 +++++++++++++---- 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 179f0a74..a2551df7 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -205,6 +205,7 @@ def _save_and_exit(checkpoint_location: str) -> None: tokenizer=model.tokenizer, samples_seen=samples_seen, epoch=epoch, + global_step=global_step, is_lora=bool(args.lora_r), ) base_logger.info( diff --git a/src/instructlab/training/on_demand_checkpoint.py b/src/instructlab/training/on_demand_checkpoint.py index 20238ee5..2d21a6d9 100644 --- a/src/instructlab/training/on_demand_checkpoint.py +++ b/src/instructlab/training/on_demand_checkpoint.py @@ -179,6 +179,25 @@ def __init__(self, job_id: Optional[str] = None): def install(self) -> None: """Register signal handlers for all catchable signals.""" + # Clear any stale trigger file from a previous run. If the file + # exists before we've even installed signal handlers, it cannot + # be from this job — it's left over from a prior run that was + # killed before the workers could clean it up. + if trigger_file_exists(self.job_id): + logger.info( + "On-demand checkpoint: clearing stale trigger file from " + "a previous run (job_id=%s).", + self.job_id, + ) + try: + remove_trigger_file(self.job_id) + except Exception: + logger.warning( + "On-demand checkpoint: failed to remove stale trigger file, " + "but continuing anyway.", + exc_info=True, + ) + for sig in _CATCHABLE_SIGNALS: try: self._original_handlers[sig] = signal.getsignal(sig) @@ -264,6 +283,7 @@ def save_on_demand_checkpoint( tokenizer, samples_seen: int, epoch: int, + global_step: int, is_lora: bool, ) -> None: """Save a full-state distributed checkpoint for on-demand resume. 
@@ -271,6 +291,10 @@ def save_on_demand_checkpoint( This is a thin wrapper that calls the existing ``save_checkpoint`` utility with ``full_state=True`` so that optimizer + LR scheduler state are also persisted, enabling exact training resumption. + + The ``global_step`` is saved to the checkpoint metadata so that + on resume the training loop can fast-forward to the exact step + within the epoch where training was interrupted. """ # First Party from instructlab.training.utils import save_checkpoint @@ -279,8 +303,9 @@ def save_on_demand_checkpoint( if local_rank == 0: logger.info( "On-demand checkpoint: saving full-state checkpoint at " - "epoch=%d, samples_seen=%d", + "epoch=%d, global_step=%d, samples_seen=%d", epoch, + global_step, samples_seen, ) @@ -294,6 +319,7 @@ def save_on_demand_checkpoint( full_state=True, hf_format=True, epoch=epoch, + global_step=global_step, ) if local_rank == 0: diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index a7adb32b..410ad132 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -835,6 +835,7 @@ def save_checkpoint( epoch: int = None, hf_format: bool = True, full_state: bool = False, + global_step: int | None = None, ) -> None: if hf_format: save_hf_format_accelerate( @@ -853,10 +854,11 @@ def save_checkpoint( is_lora=is_lora, epoch=epoch, samples_seen=samples_seen, + global_step=global_step, ) -def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int): +def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int, global_step: int | None = None): """ Saves model, optimizer, and lr_scheduler state. TODO: save model config - decided not to do this. @@ -889,9 +891,11 @@ def _get_state_dict_patched(model, unwrap=False): # save metadata file for current training status if accelerator.is_main_process: - # TODO: should we set the global_step here rather than calculating global_step - # based on samples_seen? 
metadata = {"current_epoch": epoch, "samples_seen": samples_seen} + # Save global_step when provided (on-demand mid-epoch checkpoints) + # so that resume can fast-forward to the exact training step. + if global_step is not None: + metadata["global_step"] = global_step torch.save(metadata, output_dir / "training_metadata.json") log_rank_0(f"\033[93mSaving training state: {metadata}\033[0m", to_print=True) @@ -936,10 +940,22 @@ def load_latest_full_state(args, accelerator) -> None: f"\033[93mTraining metadata loaded: {training_metadata}\033[0m", to_print=True ) - # previous epoch is basis for current epoch. - args.__dict__["current_epoch"] = training_metadata["current_epoch"] + 1 args.__dict__["samples_seen"] = training_metadata["samples_seen"] + if "global_step" in training_metadata: + # On-demand mid-epoch checkpoint: resume at the same epoch and + # fast-forward to the exact step via last_step. + args.__dict__["current_epoch"] = training_metadata["current_epoch"] + args.__dict__["last_step"] = training_metadata["global_step"] + log_rank_0( + f"\033[93mResuming mid-epoch: epoch={args.current_epoch}, " + f"last_step={args.last_step}\033[0m", + to_print=True, + ) + else: + # Epoch-boundary checkpoint: start at the next epoch. 
+ args.__dict__["current_epoch"] = training_metadata["current_epoch"] + 1 + def freeze_router_params(model: Model): """ From e2fe8bf2b32cc3858642559fe63d6ec492b27921 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:06:41 +0000 Subject: [PATCH 10/18] Fix ruff formatting in utils.py --- src/instructlab/training/utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 410ad132..bc4d8ed6 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -858,7 +858,14 @@ def save_checkpoint( ) -def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int, global_step: int | None = None): +def save_full_state( + args, + accelerator, + is_lora: bool, + epoch: int, + samples_seen: int, + global_step: int | None = None, +): """ Saves model, optimizer, and lr_scheduler state. TODO: save model config - decided not to do this. From 254b662f3d452ec73c1122d1d9229ff915de0dd8 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:07:24 +0000 Subject: [PATCH 11/18] Add documentation for on-demand checkpointing feature --- docs/on_demand_checkpointing.md | 139 ++++++++++++++++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 docs/on_demand_checkpointing.md diff --git a/docs/on_demand_checkpointing.md b/docs/on_demand_checkpointing.md new file mode 100644 index 00000000..b3cb9f52 --- /dev/null +++ b/docs/on_demand_checkpointing.md @@ -0,0 +1,139 @@ +# On-Demand Checkpointing + +On-demand checkpointing enables graceful checkpoint-and-exit when termination +signals are received during training. It is designed for environments like +OpenShift AI and KubeFlow where training jobs can be preempted at any time. 
+ +## How It Works + +When enabled, the system installs signal handlers in the parent (launcher) +process that catch termination signals before the hard SIGKILL. When a signal +arrives: + +1. The parent writes a trigger file to `/dev/shm` (a fast, node-local tmpfs). +2. Worker processes check for the trigger file at multiple synchronization + points during each training step. +3. Workers coordinate via `all_reduce` so that if any rank on any node + detects the trigger, all ranks agree to save. +4. A full-state checkpoint (model + optimizer + LR scheduler) is saved. +5. Workers exit cleanly. + +On resume, the training loop detects the mid-epoch checkpoint, restores the +full training state, and fast-forwards to the exact step where training was +interrupted. + +## Signals Handled + +The following signals are intercepted (SIGKILL cannot be caught): + +| Signal | Source | +|--------|--------| +| SIGTERM | Kubernetes graceful shutdown (default) | +| SIGINT | Ctrl-C / some job controllers | +| SIGUSR1 | Custom preemption controllers | +| SIGUSR2 | Custom preemption controllers | +| SIGXCPU | CPU time limit exceeded (resource quotas) | +| SIGHUP | Terminal disconnect / some eviction paths | + +## Usage + +### Python API + +```python +from instructlab.training.config import TorchrunArgs, TrainingArgs +from instructlab.training import run_training + +torch_args = TorchrunArgs( + nproc_per_node=8, + nnodes=1, + node_rank=0, + rdzv_id=12345, + rdzv_endpoint="127.0.0.1:29500", +) + +train_args = TrainingArgs( + model_path="Qwen/Qwen2-1.5B-Instruct", + data_path="./data.jsonl", + data_output_dir="./processed", + ckpt_output_dir="./checkpoints", + num_epochs=3, + on_demand_checkpointing=True, # Enable the feature + # ... 
other training args +) + +run_training(torch_args, train_args) +``` + +### CLI + +```bash +torchrun --nproc-per-node=8 \ + instructlab/training/main_ds.py \ + --model_name_or_path Qwen/Qwen2-1.5B-Instruct \ + --data_path ./data.jsonl \ + --output_dir ./checkpoints \ + --on_demand_checkpointing \ + ... +``` + +## Resume Behavior + +When a checkpoint saved by on-demand checkpointing is found in the output +directory, the training loop automatically: + +1. Loads the full optimizer and LR scheduler state from the checkpoint. +2. Reads `global_step` from the checkpoint metadata to determine where + training was interrupted. +3. Resumes at the **same epoch** and fast-forwards to the exact step, + skipping already-completed batches. + +This differs from epoch-boundary checkpoints, which resume at the start of +the next epoch. + +### Checkpoint Metadata + +On-demand checkpoints store additional metadata compared to epoch-boundary +checkpoints: + +```json +{ + "current_epoch": 0, + "samples_seen": 144, + "global_step": 19 +} +``` + +The `global_step` field is what distinguishes an on-demand checkpoint from an +epoch-boundary one. When present, the resume logic keeps `current_epoch` +unchanged and sets `last_step = global_step` to enable fast-forwarding. + +## Multi-Node Training + +The trigger file mechanism works correctly across multiple nodes: + +- The trigger file lives on `/dev/shm`, which is node-local. Each node's + parent process writes its own trigger file when it receives a signal. +- Workers use `all_reduce(MAX)` to synchronize: if any rank on any node + detects a trigger, all ranks on all nodes agree to save. +- The checkpoint itself is saved to the shared filesystem (the configured + `ckpt_output_dir`), accessible by all nodes on resume. + +## Stale Trigger Files + +If a previous training run was killed before workers could clean up the +trigger file, the new run's `ParentSignalHandler` detects and removes it +during initialization. 
This prevents a new job from immediately +checkpointing and exiting due to a leftover trigger from a prior run. + +## Kubernetes / OpenShift Configuration + +To give workers enough time to save a checkpoint before the hard SIGKILL, +increase `terminationGracePeriodSeconds` in your pod spec: + +```yaml +spec: + terminationGracePeriodSeconds: 300 # 5 minutes +``` + +The default of 30 seconds may not be enough for large models. The checkpoint +save time depends on model size, number of GPUs, and filesystem speed. From 6c4ea8a105daea989d90abcd7c502cdff768d772 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:13:40 +0000 Subject: [PATCH 12/18] Document manual trigger file creation for on-demand checkpointing --- docs/on_demand_checkpointing.md | 46 +++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/docs/on_demand_checkpointing.md b/docs/on_demand_checkpointing.md index b3cb9f52..58b4555b 100644 --- a/docs/on_demand_checkpointing.md +++ b/docs/on_demand_checkpointing.md @@ -125,6 +125,52 @@ trigger file, the new run's `ParentSignalHandler` detects and removes it during initialization. This prevents a new job from immediately checkpointing and exiting due to a leftover trigger from a prior run. +## Manually Triggering a Checkpoint + +You can trigger a checkpoint-and-exit without sending a signal by writing +the trigger file directly. This is useful for debugging, testing, or +integration with custom orchestration that doesn't use Unix signals. + +The trigger file path is: + +``` +/dev/shm/instructlab_checkpoint_requested_<job_id> +``` + +Where `<job_id>` is the `rdzv_id` passed to `TorchrunArgs`. If no job ID +was set, the path is `/dev/shm/instructlab_checkpoint_requested` (no suffix).
+ +To trigger a checkpoint from a shell on any node in the training cluster: + +```bash +# Find the job ID (it's the rdzv_id, also stored in the environment) +JOB_ID=$(printenv INSTRUCTLAB_ON_DEMAND_JOB_ID) + +# Write the trigger file +echo 1 > /dev/shm/instructlab_checkpoint_requested_${JOB_ID} +``` + +Or without the job ID: + +```bash +echo 1 > /dev/shm/instructlab_checkpoint_requested +``` + +Workers check for the trigger file at each synchronization point in the +training loop (multiple times per step). Once any rank on any node detects +it, all ranks coordinate via `all_reduce` to save a checkpoint and exit. + +You only need to write the file on **one node** — the `all_reduce` ensures +all nodes participate even if they don't see the file locally. + +From Python: + +```python +from instructlab.training.on_demand_checkpoint import write_trigger_file + +write_trigger_file(job_id="12345") # or job_id=None for default path +``` + ## Kubernetes / OpenShift Configuration To give workers enough time to save a checkpoint before the hard SIGKILL, From c74b452b600c0e88085cf71a7323cf0ce375a4c9 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:15:23 +0000 Subject: [PATCH 13/18] Improve manual trigger docs: clarify job ID requirement --- docs/on_demand_checkpointing.md | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/docs/on_demand_checkpointing.md b/docs/on_demand_checkpointing.md index 58b4555b..70da36a5 100644 --- a/docs/on_demand_checkpointing.md +++ b/docs/on_demand_checkpointing.md @@ -131,29 +131,25 @@ You can trigger a checkpoint-and-exit without sending a signal by writing the trigger file directly. This is useful for debugging, testing, or integration with custom orchestration that doesn't use Unix signals. -The trigger file path is: - -``` -/dev/shm/instructlab_checkpoint_requested_<job_id> -``` - -Where `<job_id>` is the `rdzv_id` passed to `TorchrunArgs`.
If no job ID -was set, the path is `/dev/shm/instructlab_checkpoint_requested` (no suffix). - -To trigger a checkpoint from a shell on any node in the training cluster: +The trigger file lives in `/dev/shm` and is named using the job ID that +the training process was started with. To find the correct filename and +create the trigger: ```bash -# Find the job ID (it's the rdzv_id, also stored in the environment) -JOB_ID=$(printenv INSTRUCTLAB_ON_DEMAND_JOB_ID) +# Find the trigger filename for the running job — look for the job ID +# that was set when training started +ls /dev/shm/instructlab_checkpoint_requested* -# Write the trigger file -echo 1 > /dev/shm/instructlab_checkpoint_requested_${JOB_ID} +# Create the trigger file (use the exact name shown by ls above) +touch /dev/shm/instructlab_checkpoint_requested_<job_id> ``` -Or without the job ID: +If you don't know the job ID, you can read it from the training process +environment: ```bash -echo 1 > /dev/shm/instructlab_checkpoint_requested +# From inside the same pod / container where training is running +cat /proc/$(pgrep -f main_ds.py | head -1)/environ | tr '\0' '\n' | grep INSTRUCTLAB_ON_DEMAND_JOB_ID ``` Workers check for the trigger file at each synchronization point in the From dee3d5fada61ffbbdb38f9d235cab5d451588846 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:19:41 +0000 Subject: [PATCH 14/18] Simplify trigger file: remove job_id namespacing Drop the job_id suffix from the trigger file path. The file is now always /dev/shm/instructlab_checkpoint_requested with no suffix. The namespacing was defensive against concurrent jobs sharing /dev/shm, but in practice Kubernetes pods each get their own /dev/shm.
This makes manual triggering trivial: touch /dev/shm/instructlab_checkpoint_requested --- docs/on_demand_checkpointing.md | 22 ++------- src/instructlab/training/main_ds.py | 19 ++------ .../training/on_demand_checkpoint.py | 46 +++++++------------ 3 files changed, 26 insertions(+), 61 deletions(-) diff --git a/docs/on_demand_checkpointing.md b/docs/on_demand_checkpointing.md index 70da36a5..b21d71d5 100644 --- a/docs/on_demand_checkpointing.md +++ b/docs/on_demand_checkpointing.md @@ -131,25 +131,11 @@ You can trigger a checkpoint-and-exit without sending a signal by writing the trigger file directly. This is useful for debugging, testing, or integration with custom orchestration that doesn't use Unix signals. -The trigger file lives in `/dev/shm` and is named using the job ID that -the training process was started with. To find the correct filename and -create the trigger: +The trigger file is always at a fixed path. To trigger a checkpoint +(e.g. via `kubectl exec` into the training pod): ```bash -# Find the trigger filename for the running job — look for the job ID -# that was set when training started -ls /dev/shm/instructlab_checkpoint_requested* - -# Create the trigger file (use the exact name shown by ls above) -touch /dev/shm/instructlab_checkpoint_requested_<job_id> -``` - -If you don't know the job ID, you can read it from the training process -environment: - -```bash -# From inside the same pod / container where training is running -cat /proc/$(pgrep -f main_ds.py | head -1)/environ | tr '\0' '\n' | grep INSTRUCTLAB_ON_DEMAND_JOB_ID +touch /dev/shm/instructlab_checkpoint_requested ``` Workers check for the trigger file at each synchronization point in the @@ -164,7 +150,7 @@ From Python: ```python from instructlab.training.on_demand_checkpoint import write_trigger_file -write_trigger_file(job_id="12345") # or job_id=None for default path +write_trigger_file() ``` ## Kubernetes / OpenShift Configuration diff --git a/src/instructlab/training/main_ds.py
b/src/instructlab/training/main_ds.py index a2551df7..1e6bd43b 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -185,7 +185,6 @@ def train( base_logger = logging.getLogger("instructlab.training") # Import on-demand checkpointing utilities once if the feature is enabled - checkpoint_job_id = None if on_demand_checkpointing: # First Party from instructlab.training.on_demand_checkpoint import ( @@ -193,7 +192,6 @@ def train( save_on_demand_checkpoint, ) - checkpoint_job_id = os.environ.get("INSTRUCTLAB_ON_DEMAND_JOB_ID") base_logger.info("On-demand checkpointing is enabled in worker process.") def _save_and_exit(checkpoint_location: str) -> None: @@ -255,7 +253,7 @@ def _save_and_exit(checkpoint_location: str) -> None: # the check runs after every minibatch backward rather than # waiting for the full optimizer step. _interrupt_check = ( - (lambda: check_checkpoint_requested(checkpoint_job_id)) + (lambda: check_checkpoint_requested()) if on_demand_checkpointing else None ) @@ -271,9 +269,7 @@ def _save_and_exit(checkpoint_location: str) -> None: _save_and_exit("during minibatch processing") return - if on_demand_checkpointing and check_checkpoint_requested( - checkpoint_job_id - ): + if on_demand_checkpointing and check_checkpoint_requested(): _save_and_exit("before optimizer step") return @@ -287,9 +283,7 @@ def _save_and_exit(checkpoint_location: str) -> None: # Update samples seen after the optimizer step has been applied samples_seen += batch_metrics.total_samples - if on_demand_checkpointing and check_checkpoint_requested( - checkpoint_job_id - ): + if on_demand_checkpointing and check_checkpoint_requested(): _save_and_exit("after optimizer step") return @@ -881,14 +875,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: # Use rdzv_id to namespace the trigger file so concurrent jobs # sharing /dev/shm don't interfere with each other. 
- checkpoint_job_id = str(torch_args.rdzv_id) - os.environ["INSTRUCTLAB_ON_DEMAND_JOB_ID"] = checkpoint_job_id - signal_handler = ParentSignalHandler(job_id=checkpoint_job_id) + signal_handler = ParentSignalHandler() signal_handler.install() logger.info( - "On-demand checkpointing is ENABLED (job_id=%s). " + "On-demand checkpointing is ENABLED. " "Termination signals will trigger a full-state checkpoint before exit.", - checkpoint_job_id, ) process = None diff --git a/src/instructlab/training/on_demand_checkpoint.py b/src/instructlab/training/on_demand_checkpoint.py index 2d21a6d9..2bf0361a 100644 --- a/src/instructlab/training/on_demand_checkpoint.py +++ b/src/instructlab/training/on_demand_checkpoint.py @@ -92,23 +92,18 @@ _TRIGGER_FILENAME = "instructlab_checkpoint_requested" -def _get_trigger_path(job_id: Optional[str] = None) -> Path: - """Return the path to the checkpoint trigger file. - - An optional *job_id* can be supplied to avoid collisions if multiple - training jobs share the same ``/dev/shm`` (unlikely but possible). - """ - name = f"{_TRIGGER_FILENAME}_{job_id}" if job_id else _TRIGGER_FILENAME - return _TRIGGER_DIR / name +def _get_trigger_path() -> Path: + """Return the path to the checkpoint trigger file.""" + return _TRIGGER_DIR / _TRIGGER_FILENAME -def write_trigger_file(job_id: Optional[str] = None) -> Path: +def write_trigger_file() -> Path: """Create the trigger file that tells workers to checkpoint. This is called from the *parent* process signal handler. Returns the path that was written. """ - path = _get_trigger_path(job_id) + path = _get_trigger_path() # Use a atomic write via tempfile + rename to avoid partial reads. 
fd, tmp = tempfile.mkstemp(dir=_TRIGGER_DIR, prefix=".ckpt_trigger_") try: @@ -123,14 +118,14 @@ def write_trigger_file(job_id: Optional[str] = None) -> Path: return path -def trigger_file_exists(job_id: Optional[str] = None) -> bool: +def trigger_file_exists() -> bool: """Check whether the trigger file exists (worker-side).""" - return _get_trigger_path(job_id).exists() + return _get_trigger_path().exists() -def remove_trigger_file(job_id: Optional[str] = None) -> None: +def remove_trigger_file() -> None: """Remove the trigger file after the checkpoint has been saved.""" - path = _get_trigger_path(job_id) + path = _get_trigger_path() try: path.unlink(missing_ok=True) except OSError: @@ -164,15 +159,9 @@ class ParentSignalHandler: The handler is idempotent – multiple signals will not create multiple trigger files. - Parameters - ---------- - job_id : str, optional - Unique identifier for this training job. Used to namespace the - trigger file. """ - def __init__(self, job_id: Optional[str] = None): - self.job_id = job_id + def __init__(self): self.signal_received: Optional[signal.Signals] = None self._original_handlers: dict[signal.Signals, _SignalHandler] = {} self._trigger_written = False @@ -183,14 +172,13 @@ def install(self) -> None: # exists before we've even installed signal handlers, it cannot # be from this job — it's left over from a prior run that was # killed before the workers could clean it up. 
- if trigger_file_exists(self.job_id): + if trigger_file_exists(): logger.info( "On-demand checkpoint: clearing stale trigger file from " - "a previous run (job_id=%s).", - self.job_id, + "a previous run.", ) try: - remove_trigger_file(self.job_id) + remove_trigger_file() except Exception: logger.warning( "On-demand checkpoint: failed to remove stale trigger file, " @@ -232,7 +220,7 @@ def _handle(self, signum: int, _frame) -> None: self.signal_received = sig if not self._trigger_written: - write_trigger_file(self.job_id) + write_trigger_file() self._trigger_written = True @@ -241,7 +229,7 @@ def _handle(self, signum: int, _frame) -> None: # --------------------------------------------------------------------------- -def check_checkpoint_requested(job_id: Optional[str] = None) -> bool: +def check_checkpoint_requested() -> bool: """Check across all ranks whether an on-demand checkpoint was requested. This function must be called by **all ranks** at the same point in the @@ -250,7 +238,7 @@ def check_checkpoint_requested(job_id: Optional[str] = None) -> bool: Returns ``True`` if any rank detected the trigger file, meaning all ranks should save a checkpoint. """ - local_trigger = trigger_file_exists(job_id) + local_trigger = trigger_file_exists() # Convert to a tensor and all-reduce (MAX) so that if ANY rank on ANY # node saw the trigger, every rank gets True. @@ -271,7 +259,7 @@ def check_checkpoint_requested(job_id: Optional[str] = None) -> bool: ) # Clean up the trigger file so that if the process somehow # continues, we don't save again immediately. 
- remove_trigger_file(job_id) + remove_trigger_file() return requested From 07576b44616e9298d3e7c7b534691ef6177db020 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:49:22 +0000 Subject: [PATCH 15/18] Fix pylint: remove unnecessary lambda wrapper --- src/instructlab/training/main_ds.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 1e6bd43b..08ee3f02 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -253,9 +253,7 @@ def _save_and_exit(checkpoint_location: str) -> None: # the check runs after every minibatch backward rather than # waiting for the full optimizer step. _interrupt_check = ( - (lambda: check_checkpoint_requested()) - if on_demand_checkpointing - else None + check_checkpoint_requested if on_demand_checkpointing else None ) batch_metrics, avg_loss_across_ranks = batch_loss_manager.process_batch( batch, interrupt_check=_interrupt_check From d745e9ff0a8fff646f10a62213cba2de835fee28 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Thu, 16 Apr 2026 19:58:46 +0000 Subject: [PATCH 16/18] fix: update unit tests for simplified trigger file API (no job_id) --- tests/unit/test_on_demand_checkpoint.py | 84 +++++++++---------------- 1 file changed, 30 insertions(+), 54 deletions(-) diff --git a/tests/unit/test_on_demand_checkpoint.py b/tests/unit/test_on_demand_checkpoint.py index a805f49f..2f243f91 100644 --- a/tests/unit/test_on_demand_checkpoint.py +++ b/tests/unit/test_on_demand_checkpoint.py @@ -2,7 +2,7 @@ """Tests for on-demand checkpointing.""" # Standard -from unittest.mock import MagicMock, call, patch +from unittest.mock import MagicMock, patch import os import signal @@ -27,57 +27,48 @@ class TestGetTriggerPath: - def test_without_job_id(self): + def test_returns_correct_name(self): path = 
_get_trigger_path() assert path.name == "instructlab_checkpoint_requested" assert str(path.parent) == "/dev/shm" - def test_with_job_id(self): - path = _get_trigger_path("my-job-123") - assert path.name == "instructlab_checkpoint_requested_my-job-123" - - def test_different_job_ids_produce_different_paths(self): - p1 = _get_trigger_path("job-a") - p2 = _get_trigger_path("job-b") - assert p1 != p2 - class TestWriteTriggerFile: def test_creates_file(self, tmp_path): with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): - path = write_trigger_file("test-write") + path = write_trigger_file() assert path.exists() assert path.read_text() == "1" def test_returns_correct_path(self, tmp_path): with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): - path = write_trigger_file("test-path") - assert path == tmp_path / "instructlab_checkpoint_requested_test-path" + path = write_trigger_file() + assert path == tmp_path / "instructlab_checkpoint_requested" class TestTriggerFileExists: def test_returns_false_when_absent(self, tmp_path): with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): - assert trigger_file_exists("nonexistent") is False + assert trigger_file_exists() is False def test_returns_true_when_present(self, tmp_path): with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): - write_trigger_file("exists") - assert trigger_file_exists("exists") is True + write_trigger_file() + assert trigger_file_exists() is True class TestRemoveTriggerFile: def test_removes_existing_file(self, tmp_path): with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): - write_trigger_file("to-remove") - assert trigger_file_exists("to-remove") is True - remove_trigger_file("to-remove") - assert trigger_file_exists("to-remove") is False + write_trigger_file() + assert trigger_file_exists() is True + remove_trigger_file() + assert trigger_file_exists() is False def 
test_noop_on_missing_file(self, tmp_path): with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): # Should not raise - remove_trigger_file("never-existed") + remove_trigger_file() # --------------------------------------------------------------------------- @@ -87,7 +78,7 @@ def test_noop_on_missing_file(self, tmp_path): class TestParentSignalHandler: def test_install_registers_handlers(self): - handler = ParentSignalHandler(job_id="test-install") + handler = ParentSignalHandler() original_handlers = {sig: signal.getsignal(sig) for sig in _CATCHABLE_SIGNALS} try: handler.install() @@ -97,13 +88,12 @@ def test_install_registers_handlers(self): f"Expected handler._handle for {sig.name}, got {current}" ) finally: - # Restore originals regardless for sig, orig in original_handlers.items(): signal.signal(sig, orig) def test_handle_writes_trigger_and_records_signal(self, tmp_path): with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): - handler = ParentSignalHandler(job_id="test-handle") + handler = ParentSignalHandler() assert handler.signal_received is None assert handler._trigger_written is False @@ -111,12 +101,12 @@ def test_handle_writes_trigger_and_records_signal(self, tmp_path): assert handler.signal_received == signal.SIGUSR1 assert handler._trigger_written is True - assert trigger_file_exists("test-handle") is True + assert trigger_file_exists() is True def test_handle_is_idempotent(self, tmp_path): """Multiple signals should only write the trigger file once.""" with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): - handler = ParentSignalHandler(job_id="test-idempotent") + handler = ParentSignalHandler() with patch( "instructlab.training.on_demand_checkpoint.write_trigger_file" @@ -126,38 +116,35 @@ def test_handle_is_idempotent(self, tmp_path): handler._handle(signal.SIGTERM, None) handler._handle(signal.SIGINT, None) - # write_trigger_file called only once - 
mock_write.assert_called_once_with("test-idempotent") + mock_write.assert_called_once() # signal_received should be the LAST signal assert handler.signal_received == signal.SIGINT def test_uninstall_restores_original_handlers(self): - handler = ParentSignalHandler(job_id="test-uninstall") + handler = ParentSignalHandler() originals = {sig: signal.getsignal(sig) for sig in _CATCHABLE_SIGNALS} handler.install() - # Verify handlers changed for sig in _CATCHABLE_SIGNALS: assert signal.getsignal(sig) == handler._handle handler.uninstall() - # Verify handlers restored for sig in _CATCHABLE_SIGNALS: assert signal.getsignal(sig) == originals[sig], f"{sig.name} not restored" def test_install_via_real_signal(self, tmp_path): """End-to-end: install handler, send SIGUSR1, verify trigger written.""" with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): - handler = ParentSignalHandler(job_id="test-real-signal") + handler = ParentSignalHandler() handler.install() try: os.kill(os.getpid(), signal.SIGUSR1) assert handler.signal_received == signal.SIGUSR1 - assert trigger_file_exists("test-real-signal") is True + assert trigger_file_exists() is True finally: handler.uninstall() - remove_trigger_file("test-real-signal") + remove_trigger_file() # --------------------------------------------------------------------------- @@ -168,7 +155,7 @@ def test_install_via_real_signal(self, tmp_path): class TestCheckCheckpointRequested: def _mock_all_reduce_propagate(self, tensor, op=None): """Mock all_reduce that just keeps the local value.""" - pass # tensor already has the local value + pass def test_returns_false_when_no_trigger(self, tmp_path): with ( @@ -180,8 +167,7 @@ def test_returns_false_when_no_trigger(self, tmp_path): mock_dist.is_initialized.return_value = True mock_dist.get_rank.return_value = 0 - result = check_checkpoint_requested("test-no-trigger") - assert result is False + assert check_checkpoint_requested() is False def 
test_returns_true_when_trigger_exists(self, tmp_path): with ( @@ -193,9 +179,8 @@ def test_returns_true_when_trigger_exists(self, tmp_path): mock_dist.is_initialized.return_value = True mock_dist.get_rank.return_value = 0 - write_trigger_file("test-trigger") - result = check_checkpoint_requested("test-trigger") - assert result is True + write_trigger_file() + assert check_checkpoint_requested() is True def test_cleans_up_trigger_after_detection(self, tmp_path): with ( @@ -207,9 +192,9 @@ def test_cleans_up_trigger_after_detection(self, tmp_path): mock_dist.is_initialized.return_value = True mock_dist.get_rank.return_value = 0 - write_trigger_file("test-cleanup") - check_checkpoint_requested("test-cleanup") - assert trigger_file_exists("test-cleanup") is False + write_trigger_file() + check_checkpoint_requested() + assert trigger_file_exists() is False def test_all_reduce_is_called(self, tmp_path): with ( @@ -222,9 +207,8 @@ def test_all_reduce_is_called(self, tmp_path): mock_dist.get_rank.return_value = 0 mock_dist.ReduceOp.MAX = torch.distributed.ReduceOp.MAX - check_checkpoint_requested("test-allreduce") + check_checkpoint_requested() mock_dist.all_reduce.assert_called_once() - # Verify MAX reduction op _, kwargs = mock_dist.all_reduce.call_args assert kwargs.get("op") == torch.distributed.ReduceOp.MAX @@ -246,8 +230,6 @@ def manager(self): ) accelerator = MagicMock() accelerator.device = torch.device("cpu") - # reduce is called with a 2-element tensor (metrics) and a scalar (loss). - # Return the input unchanged to simulate single-rank "reduction". 
accelerator.reduce.side_effect = lambda t, **kw: t accelerator.backward = MagicMock() @@ -297,8 +279,6 @@ def test_interrupt_before_backward(self, manager): def interrupt_on_second_call(): nonlocal call_count call_count += 1 - # First call: before forward — let it pass - # Second call: before backward — interrupt return call_count == 2 batch = self._make_batch(3) @@ -306,7 +286,6 @@ def interrupt_on_second_call(): batch, interrupt_check=interrupt_on_second_call ) assert metrics.interrupted is True - # Forward ran once, backward never ran assert manager.model.compute_loss.call_count == 1 manager.accelerator.backward.assert_not_called() assert metrics.grad_accum_steps == 0 @@ -318,7 +297,6 @@ def test_interrupt_after_backward(self, manager): def interrupt_on_third_call(): nonlocal call_count call_count += 1 - # Calls: 1=before_fwd, 2=before_bwd, 3=after_bwd (interrupt) return call_count == 3 batch = self._make_batch(3) @@ -339,11 +317,9 @@ def test_interrupt_never_fires(self, manager): def test_compute_average_loss_handles_float_when_interrupted(self, manager): """When interrupted before any forward, accumulated_loss is 0.0 (float).""" - # _compute_average_loss must handle float, not just Tensor result = manager._compute_average_loss( accumulated_loss=0.0, accumulated_aux_loss=None, batch_num_loss_counted_tokens=64, ) - # Should not raise and should return a float assert isinstance(result, float) From b2bd5364dc6505a328ddea034f86605512a81682 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Thu, 16 Apr 2026 20:09:57 +0000 Subject: [PATCH 17/18] fix: add compat shim for Nemotron's is_flash_attn_greater_or_equal_2_10 The function was renamed to is_flash_attn_greater_or_equal in transformers 5.x, but Nemotron's HF Hub remote code still imports the old name. Inject a compatibility wrapper before model loading. 
--- src/instructlab/training/model.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py index 09455c2d..53192d4c 100644 --- a/src/instructlab/training/model.py +++ b/src/instructlab/training/model.py @@ -84,6 +84,17 @@ def __init__( if self.is_granitemoehybrid or self.is_nemotronh: self._use_local_mamba_kernels() + # Compatibility shim for Nemotron's HF Hub remote code which imports + # is_flash_attn_greater_or_equal_2_10, renamed in transformers 5.x. + if self.is_nemotronh: + # Third Party + from transformers.utils import import_utils as _iu + + if not hasattr(_iu, "is_flash_attn_greater_or_equal_2_10"): + _iu.is_flash_attn_greater_or_equal_2_10 = lambda: ( + _iu.is_flash_attn_greater_or_equal("2.10") + ) + if self.is_gpt_oss: # Third Party quant_config = Mxfp4Config(dequantize=True) From 48da1647189bfa2f25ab9ab7c58d46c017cce28a Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Thu, 16 Apr 2026 20:34:10 +0000 Subject: [PATCH 18/18] fix: make unit tests work without CUDA - Patch torch.tensor in check_checkpoint_requested tests to avoid CUDA device creation (CI runs without GPUs) - Override BatchLossManager.torch_device to CPU in fixture so _prepare_model_inputs doesn't trigger CUDA init --- tests/unit/test_on_demand_checkpoint.py | 30 +++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_on_demand_checkpoint.py b/tests/unit/test_on_demand_checkpoint.py index 2f243f91..f9db657e 100644 --- a/tests/unit/test_on_demand_checkpoint.py +++ b/tests/unit/test_on_demand_checkpoint.py @@ -67,7 +67,6 @@ def test_removes_existing_file(self, tmp_path): def test_noop_on_missing_file(self, tmp_path): with patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path): - # Should not raise remove_trigger_file() @@ -118,7 +117,6 @@ def test_handle_is_idempotent(self, tmp_path): 
mock_write.assert_called_once() - # signal_received should be the LAST signal assert handler.signal_received == signal.SIGINT def test_uninstall_restores_original_handlers(self): @@ -157,11 +155,25 @@ def _mock_all_reduce_propagate(self, tensor, op=None): """Mock all_reduce that just keeps the local value.""" pass + def _make_cpu_tensor_patch(self): + """Return a patched torch.tensor that forces CPU device.""" + _original_tensor = torch.tensor + + def _cpu_tensor(*args, **kwargs): + kwargs.pop("device", None) + return _original_tensor(*args, **kwargs) + + return _cpu_tensor + def test_returns_false_when_no_trigger(self, tmp_path): with ( patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path), patch("instructlab.training.on_demand_checkpoint.dist") as mock_dist, patch("torch.cuda.current_device", return_value=0), + patch( + "instructlab.training.on_demand_checkpoint.torch.tensor", + side_effect=self._make_cpu_tensor_patch(), + ), ): mock_dist.all_reduce = self._mock_all_reduce_propagate mock_dist.is_initialized.return_value = True @@ -174,6 +186,10 @@ def test_returns_true_when_trigger_exists(self, tmp_path): patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path), patch("instructlab.training.on_demand_checkpoint.dist") as mock_dist, patch("torch.cuda.current_device", return_value=0), + patch( + "instructlab.training.on_demand_checkpoint.torch.tensor", + side_effect=self._make_cpu_tensor_patch(), + ), ): mock_dist.all_reduce = self._mock_all_reduce_propagate mock_dist.is_initialized.return_value = True @@ -187,6 +203,10 @@ def test_cleans_up_trigger_after_detection(self, tmp_path): patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path), patch("instructlab.training.on_demand_checkpoint.dist") as mock_dist, patch("torch.cuda.current_device", return_value=0), + patch( + "instructlab.training.on_demand_checkpoint.torch.tensor", + side_effect=self._make_cpu_tensor_patch(), + ), ): mock_dist.all_reduce = 
self._mock_all_reduce_propagate mock_dist.is_initialized.return_value = True @@ -201,6 +221,10 @@ def test_all_reduce_is_called(self, tmp_path): patch("instructlab.training.on_demand_checkpoint._TRIGGER_DIR", tmp_path), patch("instructlab.training.on_demand_checkpoint.dist") as mock_dist, patch("torch.cuda.current_device", return_value=0), + patch( + "instructlab.training.on_demand_checkpoint.torch.tensor", + side_effect=self._make_cpu_tensor_patch(), + ), ): mock_dist.all_reduce = MagicMock() mock_dist.is_initialized.return_value = True @@ -242,6 +266,8 @@ def manager(self): world_size=1, local_rank=0, ) + # Override torch_device to CPU so tests work without CUDA + mgr.torch_device = torch.device("cpu") return mgr def _make_batch(self, n_minibatches=3):