diff --git a/src/specify_cli/workflows/base.py b/src/specify_cli/workflows/base.py index b61fdb1a08..3a42679309 100644 --- a/src/specify_cli/workflows/base.py +++ b/src/specify_cli/workflows/base.py @@ -97,6 +97,13 @@ class StepBase(ABC): Every step type — built-in or extension-provided — implements this interface and registers in ``STEP_REGISTRY``. + + Thread-safety: ``STEP_REGISTRY`` holds a single shared instance per type, so + a concurrent ``fan-out`` (``max_concurrency > 1``) can invoke ``execute`` on + the same instance from several threads at once. Implementations must be + stateless / thread-safe — derive all per-run state from the ``config`` and + ``context`` arguments and never mutate ``self`` in ``execute``. The built-in + steps follow this rule. """ #: Matches the ``type:`` value in workflow YAML. diff --git a/src/specify_cli/workflows/engine.py b/src/specify_cli/workflows/engine.py index 23b5b0c5c0..68f2ca6f3d 100644 --- a/src/specify_cli/workflows/engine.py +++ b/src/specify_cli/workflows/engine.py @@ -10,10 +10,14 @@ from __future__ import annotations +import dataclasses import json import os import re +import tempfile +import threading import uuid +from concurrent.futures import Future, ThreadPoolExecutor from datetime import datetime, timezone from pathlib import Path from typing import Any @@ -412,6 +416,15 @@ def __init__( self.current_step_index = 0 self.current_step_id: str | None = None self.step_results: dict[str, dict[str, Any]] = {} + # Guards step_results mutation and save() so a concurrent fan-out cannot + # mutate the dict while save() is serializing it (which would raise + # "dictionary changed size during iteration"). + self._lock = threading.Lock() + # Serializes append_log's list append + log.jsonl write so concurrent + # fan-out workers cannot interleave or corrupt log lines. Kept separate + # from _lock so frequent logging never contends with state saves; since + # append_log is never called while _lock is held, the two never nest. + self._log_lock = threading.Lock() self.inputs: dict[str, Any] = {} self.created_at = datetime.now(timezone.utc).isoformat() self.updated_at = self.created_at @@ -421,28 +434,72 @@ def __init__( def runs_dir(self) -> Path: return self.project_root / ".specify" / "workflows" / "runs" / self.run_id + def record_step_result(self, step_id: str, data: dict[str, Any]) -> None: + """Record one step's result under the run lock. + + Routing the mutation through the lock keeps it from racing a concurrent + ``save()`` that is iterating ``step_results`` (e.g. during a concurrent + fan-out). For a sequential run this is an uncontended lock. + """ + with self._lock: + self.step_results[step_id] = data + + def set_step_output(self, step_id: str, output: Any) -> None: + """Replace an already-recorded step's ``output`` under the run lock. + + Fan-out updates its parent step's output after the items have run; + routing that nested mutation through the lock keeps it from racing a + ``save()`` serializing ``step_results`` — the same invariant + ``record_step_result`` provides for the top-level assignment. + """ + with self._lock: + if step_id in self.step_results: + self.step_results[step_id]["output"] = output + def save(self) -> None: - """Persist current state to disk.""" - self.updated_at = datetime.now(timezone.utc).isoformat() + """Persist current state to disk. + + Held under the run lock and written atomically (temp file + ``os.replace``) + so a concurrent fan-out can neither mutate ``step_results`` mid-serialization + nor leave a reader observing a half-written file. Racing writers only + contend to be last; they never corrupt. + """ runs_dir = self.runs_dir runs_dir.mkdir(parents=True, exist_ok=True) - state_data = { - "run_id": self.run_id, - "workflow_id": self.workflow_id, - "status": self.status.value, - "current_step_index": self.current_step_index, - "current_step_id": self.current_step_id, - "step_results": self.step_results, - "created_at": self.created_at, - "updated_at": self.updated_at, - } - with open(runs_dir / "state.json", "w", encoding="utf-8") as f: - json.dump(state_data, f, indent=2) - - inputs_data = {"inputs": self.inputs} - with open(runs_dir / "inputs.json", "w", encoding="utf-8") as f: - json.dump(inputs_data, f, indent=2) + with self._lock: + # Stamp updated_at inside the lock so the timestamp matches the + # snapshot this thread serializes (concurrent savers don't race it). + self.updated_at = datetime.now(timezone.utc).isoformat() + state_data = { + "run_id": self.run_id, + "workflow_id": self.workflow_id, + "status": self.status.value, + "current_step_index": self.current_step_index, + "current_step_id": self.current_step_id, + "step_results": self.step_results, + "created_at": self.created_at, + "updated_at": self.updated_at, + } + self._atomic_write_json(runs_dir / "state.json", state_data) + self._atomic_write_json(runs_dir / "inputs.json", {"inputs": self.inputs}) + + @staticmethod + def _atomic_write_json(path: Path, data: dict[str, Any]) -> None: + """Write *data* as indented JSON to *path* atomically (temp + ``os.replace``).""" + fd, tmp = tempfile.mkstemp( + dir=str(path.parent), prefix=f".{path.name}.", suffix=".tmp" + ) + try: + with os.fdopen(fd, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + os.replace(tmp, path) + except BaseException: + try: + os.unlink(tmp) + except OSError: + pass + raise @classmethod def load(cls, run_id: str, project_root: Path) -> RunState: @@ -490,14 +547,18 @@ def load(cls, run_id: str, project_root: Path) -> RunState: return state def append_log(self, entry: dict[str, Any]) -> None: - """Append a log entry to the run log.""" - entry["timestamp"] = datetime.now(timezone.utc).isoformat() - self.log_entries.append(entry) + """Append a log entry to the run log. + Held under ``_log_lock`` so concurrent fan-out workers serialize their + list append and ``log.jsonl`` write rather than interleaving lines. + """ + entry["timestamp"] = datetime.now(timezone.utc).isoformat() runs_dir = self.runs_dir runs_dir.mkdir(parents=True, exist_ok=True) - with open(runs_dir / "log.jsonl", "a", encoding="utf-8") as f: - f.write(json.dumps(entry) + "\n") + with self._log_lock: + self.log_entries.append(entry) + with open(runs_dir / "log.jsonl", "a", encoding="utf-8") as f: + f.write(json.dumps(entry) + "\n") # -- Workflow Engine ------------------------------------------------------ @@ -509,6 +570,10 @@ class WorkflowEngine: def __init__(self, project_root: Path | None = None) -> None: self.project_root = project_root or Path(".") self.on_step_start: Any = None # Callable[[str, str], None] | None + # Serializes on_step_start so a concurrent fan-out can't interleave the + # callback's output (the CLI sets it to a console.print lambda). Uncontended + # for sequential runs. + self._callback_lock = threading.Lock() def load_workflow(self, source: str | Path) -> WorkflowDefinition: """Load a workflow from an installed ID or a local YAML path. @@ -712,6 +777,22 @@ def resume( state.save() return state + @staticmethod + def _record_result( + context: StepContext, state: RunState, step_id: str, data: dict[str, Any] + ) -> None: + """Record a step result into both the live context and persistent state. + + ``record_step_result`` writes ``state.step_results`` under the run lock. + On a resume run ``context.steps`` *is* that same dict, so that locked + write is the only one needed; mirror into ``context.steps`` separately + only when it is a distinct object (a fresh run), to avoid an unlocked + mutation of the shared dict that could race a concurrent ``save()``. + """ + if context.steps is not state.step_results: + context.steps[step_id] = data + state.record_step_result(step_id, data) + def _execute_steps( self, steps: list[dict[str, Any]], @@ -739,7 +820,8 @@ def _execute_steps( # otherwise stay silent (library-safe default). label = step_config.get("command", "") or step_type if self.on_step_start is not None: - self.on_step_start(step_id, label) + with self._callback_lock: + self.on_step_start(step_id, label) step_impl = registry.get(step_type) if not step_impl: @@ -772,8 +854,7 @@ def _execute_steps( "output": result.output, "status": result.status.value, } - context.steps[step_id] = step_data - state.step_results[step_id] = step_data + self._record_result(context, state, step_id, step_data) state.append_log( { @@ -900,40 +981,32 @@ def _execute_steps( ): return if orig and ns_copy["id"] in context.steps: - context.steps[orig] = context.steps[ns_copy["id"]] - state.step_results[orig] = context.steps[ns_copy["id"]] - - # Fan-out: execute nested step template per item with unique IDs + self._record_result( + context, state, orig, + context.steps[ns_copy["id"]], + ) + + # Fan-out: execute the nested step template once per item. Honors + # max_concurrency — <=1 runs sequentially (default, historical + # behavior); >1 runs up to that many items concurrently. Either way + # results are assembled in item order under the + # parentId:templateId:index id grammar. if step_type == "fan-out": items = result.output.get("items", []) template = result.output.get("step_template", {}) if template and items: - fan_out_results = [] - for item_idx, item_val in enumerate(result.output["items"]): - context.item = item_val - # Per-item ID: parentId:templateId:index - item_step = dict(template) - base_id = item_step.get("id", "item") - item_step["id"] = f"{step_id}:{base_id}:{item_idx}" - self._execute_steps( - [item_step], context, state, registry, - step_offset=-1, - ) - # Collect per-item result for fan-in - item_result = context.steps.get(item_step["id"], {}) - fan_out_results.append(item_result.get("output", {})) - if state.status in ( - RunStatus.PAUSED, - RunStatus.FAILED, - RunStatus.ABORTED, - ): - break + fan_out_results = self._run_fan_out( + items, template, step_id, context, state, registry, + result.output.get("max_concurrency", 1), + ) context.item = None # Preserve original output and add collected results fan_out_output = dict(result.output) fan_out_output["results"] = fan_out_results - context.steps[step_id]["output"] = fan_out_output - state.step_results[step_id]["output"] = fan_out_output + # set_step_output updates the recorded dict under the run lock; + # context.steps[step_id] is that same object, so it reflects the + # change too — no separate (unlocked) context mutation needed. + state.set_step_output(step_id, fan_out_output) if state.status in ( RunStatus.PAUSED, RunStatus.FAILED, @@ -943,8 +1016,170 @@ def _execute_steps( else: # Empty items or no template — normalize output result.output["results"] = [] - context.steps[step_id]["output"] = result.output - state.step_results[step_id]["output"] = result.output + state.set_step_output(step_id, result.output) + + def _run_fan_out( + self, + items: list[Any], + template: dict[str, Any], + step_id: str, + context: StepContext, + state: RunState, + registry: dict[str, Any], + max_concurrency: Any, + ) -> list[Any]: + """Run a fan-out template once per item; return per-item outputs in item order. + + ``max_concurrency`` <= 1 (the default) runs items sequentially, identical + to the historical fan-out behavior. ``max_concurrency`` > 1 runs items on a + bounded thread pool using a sliding submission window of that size: at most + that many items are ever in flight, and no new item is launched once the run + has reached a halting status, so a halt cannot keep starting queued work. + + Results are always returned in item order (never completion order). On a + halt (PAUSED/FAILED/ABORTED) the returned prefix is the items up to and + including the first item *in item order* whose own execution halted the run + — identical to the sequential path. Later items that have not yet started + are cancelled; any already running are allowed to finish but their outputs + are ignored. Halt is attributed per item from that item's recorded result + (not the shared run status, which a concurrently-running later item may have + already flipped), so the prefix never drops the actual halting item. + + ``max_concurrency`` is coerced with ``int()``; a value that cannot be + coerced (``None``, a non-numeric string, …) or that coerces to <= 1 runs + sequentially, while a numeric string like ``"4"`` or a float like ``4.0`` + is honored. + """ + if not items: + return [] + + halting = (RunStatus.PAUSED, RunStatus.FAILED, RunStatus.ABORTED) + try: + workers = max(1, int(max_concurrency)) + except (TypeError, ValueError): + workers = 1 + # Never spin up more workers than there is work — bounds a user-controlled + # max_concurrency from over-allocating threads. + workers = min(workers, len(items)) + + base_id = template.get("id", "item") + + def item_id(idx: int) -> str: + # Per-item ID grammar: parentId:templateId:index. + return f"{step_id}:{base_id}:{idx}" + + def run_item(idx: int, item_ctx: StepContext) -> Any: + item_step = dict(template) + item_step["id"] = item_id(idx) + self._execute_steps( + [item_step], item_ctx, state, registry, step_offset=-1, + ) + # Read back through the context that was actually executed against, + # not the outer closure — clearer and robust if StepContext copying + # ever stops sharing the steps dict by reference. + return item_ctx.steps.get(item_step["id"], {}).get("output", {}) + + # Sequential path — identical to the historical behavior. + if workers <= 1: + results: list[Any] = [] + for item_idx, item_val in enumerate(items): + context.item = item_val + results.append(run_item(item_idx, context)) + if state.status in halting: + break + return results + + # Concurrent path — bounded sliding window; results assembled in item order. + n = len(items) + slots: list[Any] = [None] * n + + def run_isolated(idx: int) -> Any: + # Each item runs against its own context copy so context.item is not + # clobbered across threads; the shared steps dict is written only on the + # disjoint parentId:templateId:index key (GIL-safe on distinct keys). + return run_item(idx, dataclasses.replace(context, item=items[idx])) + + def item_halt_status(idx: int) -> RunStatus | None: + # If THIS item's own execution halted the run, return the resulting run + # status; else None. Decided from the item's own recorded result, not + # the shared run status, so a later item's concurrent halt is never + # misattributed here. Mirrors the sequential mapping: PAUSED -> PAUSED; + # FAILED -> ABORTED when aborted, else FAILED, unless continue_on_error + # routes around it. + rec = context.steps.get(item_id(idx)) + if rec is None: + # Ran but recorded nothing — only when the item failed before + # record_step_result (e.g. an unknown step type returns early). + # Every item runs the same template, so the shared run status is + # this item's own outcome; attribute the halt to it. + return state.status if state.status in halting else None + status = rec.get("status") + if status == StepStatus.PAUSED.value: + return RunStatus.PAUSED + if status == StepStatus.FAILED.value: + out = rec.get("output") or {} + if out.get("aborted"): + return RunStatus.ABORTED + if template.get("continue_on_error") is not True: + return RunStatus.FAILED + return None + + # (halting item index, its run status) once a halt is attributed. + halt: tuple[int, RunStatus] | None = None + collected = 0 + with ThreadPoolExecutor(max_workers=workers) as pool: + futures: dict[int, Future] = {} + next_submit = 0 + for idx in range(n): + # Refill the window: keep <= workers in flight, and stop launching + # new items once the run is halting so a halt cannot keep starting + # queued work. Already-submitted futures are still collected in + # item order below. + while ( + next_submit < n + and len(futures) < workers + and state.status not in halting + ): + futures[next_submit] = pool.submit(run_isolated, next_submit) + next_submit += 1 + + fut = futures.pop(idx, None) + if fut is None: + # Safety net: the window submits indices in order and the loop + # breaks at the first halting item, so every collected index has + # an in-flight future. Stop cleanly rather than raise if a future + # change ever breaks that invariant. + break + try: + slots[idx] = fut.result() + except Exception: + # A genuine exception escaping a step (not a normal step + # FAILED, which sets state.status) must not be masked: cancel + # outstanding work and re-raise — with a bare ``raise`` so the + # original traceback is preserved — so the engine marks the run + # failed instead of reporting a vacuous completion. The pool's + # __exit__ still joins any already-running workers. + for other in futures.values(): + other.cancel() + raise + collected = idx + 1 + halt_status = item_halt_status(idx) + if halt_status is not None: + # First halting item in item order: include it (slots[idx] is + # already set), record its status, and cancel everything pending. + halt = (idx, halt_status) + for other in futures.values(): + other.cancel() + break + + if halt is not None: + halted_at, halted_status = halt + # A later in-flight item may have overwritten state.status before the + # pool joined; restore the halting item's own outcome so the final run + # status matches the sequential semantics. + state.status = halted_status + return slots[: halted_at + 1] + return slots[:collected] def _resolve_inputs( self, diff --git a/tests/test_workflows.py b/tests/test_workflows.py index eebc89fadd..ee63b36ba5 100644 --- a/tests/test_workflows.py +++ b/tests/test_workflows.py @@ -2045,6 +2045,210 @@ def test_validate_wait_for_not_list(self): assert any("non-empty list" in e for e in errors) +class TestFanOutConcurrency: + """Fan-out honors max_concurrency (WorkflowEngine._run_fan_out).""" + + @staticmethod + def _build(tmp_path, on_item=None): + """Wire an engine + run state to a probe step that echoes context.item. + + Per-item output is ``{"seen": }`` so order and per-thread item + isolation are checkable. ``on_item(item)`` may run a side effect and + optionally return a StepStatus to override COMPLETED (or raise). + """ + from specify_cli.workflows.base import ( + RunStatus, + StepBase, + StepContext, + StepResult, + StepStatus, + ) + from specify_cli.workflows.engine import RunState, WorkflowEngine + + class _ProbeStep(StepBase): + type_key = "probe" + + def execute(self, config, context): + status = StepStatus.COMPLETED + if on_item is not None: + override = on_item(context.item) + if override is not None: + status = override + return StepResult(status=status, output={"seen": context.item}) + + engine = WorkflowEngine(project_root=tmp_path) + context = StepContext() + state = RunState(run_id="r", workflow_id="w", project_root=tmp_path) + state.status = RunStatus.RUNNING + template = {"id": "impl", "type": "probe"} + return engine, context, state, {"probe": _ProbeStep()}, template + + def _run(self, tmp_path, items, max_concurrency, on_item=None): + engine, context, state, registry, template = self._build(tmp_path, on_item) + results = engine._run_fan_out( + items, template, "fan", context, state, registry, max_concurrency + ) + return results, state + + def test_sequential_default_preserves_order(self, tmp_path): + results, _ = self._run(tmp_path, list(range(5)), 1) + assert results == [{"seen": i} for i in range(5)] + + def test_concurrent_runs_all_items_in_item_order(self, tmp_path): + results, _ = self._run(tmp_path, list(range(10)), 4) + assert results == [{"seen": i} for i in range(10)] + + def test_sequential_and_concurrent_agree(self, tmp_path): + items = [{"n": i} for i in range(8)] + seq, _ = self._run(tmp_path, items, 1) + con, _ = self._run(tmp_path, items, 4) + assert seq == con == [{"seen": {"n": i}} for i in range(8)] + + def test_shuffled_completion_preserves_item_order(self, tmp_path): + # Determinism keystone: completion order is forced to the exact REVERSE of + # item order by an event chain (no sleeps) — item i blocks until item i+1 + # has finished, so item 0 completes LAST — yet results must still be in + # item order. K == len(items) so all workers are in flight together. + import threading + + n = 4 + done = [threading.Event() for _ in range(n)] + completion: list[int] = [] + clock = threading.Lock() + + def on_item(item): + if item + 1 < n: + assert done[item + 1].wait(2.0), f"item {item + 1} never finished" + with clock: + completion.append(item) + done[item].set() + return None + + results, _ = self._run(tmp_path, list(range(n)), n, on_item) + assert results == [{"seen": i} for i in range(n)] + assert completion == list(reversed(range(n))) + + def test_concurrency_is_real(self, tmp_path): + import threading + + # Deterministic proof of real parallelism (no wall-clock threshold to + # tune or flake): every item must reach the barrier before any may pass. + # Sequential execution would block the first item forever — the barrier + # times out, raises BrokenBarrierError, and fails the test. + n = 4 + barrier = threading.Barrier(n, timeout=5) + + def on_item(item): + barrier.wait() + return None + + results, _ = self._run(tmp_path, list(range(n)), n, on_item) + assert results == [{"seen": i} for i in range(n)] + + @pytest.mark.parametrize("bad", [0, -1, None, "abc", 1.0]) + def test_invalid_max_concurrency_coerces_to_sequential(self, tmp_path, bad): + results, _ = self._run(tmp_path, list(range(4)), bad) + assert results == [{"seen": i} for i in range(4)] + + def test_string_max_concurrency_is_honored(self, tmp_path): + results, _ = self._run(tmp_path, list(range(4)), "2") + assert results == [{"seen": i} for i in range(4)] + + def test_context_item_isolation_across_threads(self, tmp_path): + items = [{"id": f"x{i}"} for i in range(6)] + results, _ = self._run(tmp_path, items, 6) + assert [r["seen"]["id"] for r in results] == [f"x{i}" for i in range(6)] + + def test_empty_items(self, tmp_path): + results, _ = self._run(tmp_path, [], 4) + assert results == [] + + def test_concurrent_halt_status_not_clobbered_by_later_item(self, tmp_path): + # Item 1 PAUSES (first halting item in order); item 3 FAILS while in + # flight. The final run status must be the halting item's (PAUSED), never + # a later item's (FAILED) that raced after it — matching sequential. + from specify_cli.workflows.base import RunStatus, StepStatus + + def on_item(item): + if item == 1: + return StepStatus.PAUSED + if item == 3: + return StepStatus.FAILED + return None + + results, state = self._run(tmp_path, list(range(4)), 4, on_item) + assert results == [{"seen": 0}, {"seen": 1}] + assert state.status == RunStatus.PAUSED + + def test_halt_on_failure_sequential_returns_prefix(self, tmp_path): + from specify_cli.workflows.base import RunStatus, StepStatus + + def on_item(item): + return StepStatus.FAILED if item == 2 else None + + results, state = self._run(tmp_path, list(range(5)), 1, on_item) + assert len(results) == 3 # items 0,1,2 ran; 3,4 never dispatched + assert results[2] == {"seen": 2} + assert state.status == RunStatus.FAILED + + def test_halt_on_failure_concurrent_includes_halting_item(self, tmp_path): + # The concurrent prefix must match the sequential one: items up to and + # INCLUDING the failing item (2), never a short prefix that drops it just + # because a later in-flight item flipped the shared run status first. + from specify_cli.workflows.base import RunStatus, StepStatus + + def on_item(item): + return StepStatus.FAILED if item == 2 else None + + results, state = self._run(tmp_path, list(range(6)), 4, on_item) + assert results == [{"seen": 0}, {"seen": 1}, {"seen": 2}] + assert state.status == RunStatus.FAILED + + def test_continue_on_error_item_does_not_halt_concurrent(self, tmp_path): + # A failing item whose template sets continue_on_error must NOT truncate + # the fan-out: every item still runs and is returned in order. + from specify_cli.workflows.base import StepStatus + + def on_item(item): + return StepStatus.FAILED if item == 2 else None + + engine, context, state, registry, template = self._build(tmp_path, on_item) + template["continue_on_error"] = True + results = engine._run_fan_out( + list(range(5)), template, "fan", context, state, registry, 4 + ) + assert results == [{"seen": i} for i in range(5)] + + def test_unknown_template_type_halts_concurrent_like_sequential(self, tmp_path): + # A template whose type isn't registered fails fast and records no result; + # the concurrent path must still attribute the halt to the first item and + # return the same prefix as sequential — never run on as if completed. + from specify_cli.workflows.base import RunStatus, StepContext + from specify_cli.workflows.engine import RunState, WorkflowEngine + + def fresh(): + state = RunState(run_id="r", workflow_id="w", project_root=tmp_path) + state.status = RunStatus.RUNNING + return WorkflowEngine(project_root=tmp_path), StepContext(), state + + template = {"id": "impl", "type": "does-not-exist"} + e1, c1, s1 = fresh() + seq = e1._run_fan_out(list(range(5)), template, "fan", c1, s1, {}, 1) + e2, c2, s2 = fresh() + con = e2._run_fan_out(list(range(5)), template, "fan", c2, s2, {}, 4) + assert seq == con == [{}] # halted at the first item; rest never returned + assert s1.status == s2.status == RunStatus.FAILED + + def test_first_exception_cancels_and_reraises(self, tmp_path): + def on_item(item): + if item == 0: + raise ValueError("boom") + return None + + with pytest.raises(ValueError, match="boom"): + self._run(tmp_path, list(range(4)), 2, on_item) + + class TestFanInWaitForValidation: """fan-in wait_for must reference a declared step (no silent empty join)."""