Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion fern/versions/latest/pages/concepts/workflow-chaining.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,21 @@ workflow.add_stage("cleanup", cleanup)

This is useful for final cleanup, schema transforms, and format-specific export preparation.

## Resume

Workflow names are durable artifact identities. Running a workflow again under the same name with `resume=ResumeMode.IF_POSSIBLE` reuses compatible completed stages, resumes a matching partial stage through `DataDesigner.create(..., resume=ResumeMode.ALWAYS)`, and reruns the first changed or missing stage together with all of its descendants.

```python
from data_designer.interface import ResumeMode

results = workflow.run(resume=ResumeMode.IF_POSSIBLE)
```

Use `ResumeMode.ALWAYS` for strict resume: every stage up to the first one that is resumed must be reusable. Missing workflow metadata, a changed stage, or a missing selected output raises an error instead of starting fresh. Once a matching partial stage resumes successfully, its descendants are recreated from that stage's current output.

## Current limits

- Stages are linear. DAGs, parallel branches, and joins are planned separately.
- Stage-level resume is not implemented yet.
- `push_to_hub()` does not support selected processor or callback outputs yet. Use `export()` for the selected workflow output.
- `on_success` callbacks are trusted user code. If a callback returns a path, Data Designer reads that path as the next stage input.
- The artifact layout is intended for inspection, but it is not yet a stable public contract.
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,31 @@
import hashlib
import json
import logging
import os
import shutil
import time
import uuid
from collections.abc import Callable, ItemsView, Iterator, KeysView
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any

from pydantic import ValidationError

import data_designer.lazy_heavy_imports as lazy
from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
from data_designer.config.base import ProcessorConfig
from data_designer.config.config_builder import BuilderConfig, DataDesignerConfigBuilder
from data_designer.config.data_designer_config import DataDesignerConfig
from data_designer.config.dataset_metadata import DatasetMetadata
from data_designer.config.errors import InvalidFileFormatError
from data_designer.config.seed import IndexRange, PartitionBlock, SamplingStrategy
from data_designer.config.seed_source import LocalFileSeedSource
from data_designer.config.utils.constants import DEFAULT_NUM_RECORDS
from data_designer.config.utils.type_helpers import StrEnum
from data_designer.config.version import get_library_version
from data_designer.engine.dataset_builders.errors import ArtifactStorageError
from data_designer.engine.storage.artifact_storage import ArtifactStorage, ResumeMode
from data_designer.interface.errors import DataDesignerWorkflowError
from data_designer.interface.results import (
SUPPORTED_EXPORT_FORMATS,
Expand All @@ -37,13 +44,21 @@
if TYPE_CHECKING:
import pandas as pd

from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
from data_designer.interface.data_designer import DataDesigner


logger = logging.getLogger(__name__)

OnSuccessCallback = Callable[[Path], Path | str]
WORKFLOW_METADATA_FILENAME = "workflow-metadata.json"
COMPLETED_STAGE_STATUSES = {"completed", "completed_empty"}
RESUMABLE_STAGE_STATUSES = {"running", "failed"}
WORKFLOW_PATH_METADATA_KEYS = (
"seed_path",
"output_seed_path",
"callback_output_path",
"output_processor_output_path",
)


@dataclass(frozen=True)
Expand Down Expand Up @@ -221,8 +236,8 @@ def add_stage(
)
return self

def run(self) -> CompositeWorkflowResults:
"""Run all stages from scratch.
def run(self, *, resume: ResumeMode = ResumeMode.NEVER) -> CompositeWorkflowResults:
"""Run all stages, optionally reusing compatible completed stage outputs.

Each stage writes a deterministic artifact directory under the parent
Data Designer artifact path. Downstream stages are seeded from the
Expand All @@ -233,6 +248,7 @@ def run(self) -> CompositeWorkflowResults:

workflow_path = self._data_designer.artifact_path / self.name
workflow_path.mkdir(parents=True, exist_ok=True)
prior_metadata = _read_prior_workflow_metadata(workflow_path, self.name, resume)
metadata: dict[str, Any] = {
"name": self.name,
"library_version": get_library_version(),
Expand All @@ -245,6 +261,8 @@ def run(self) -> CompositeWorkflowResults:
previous_stage_name: str | None = None
previous_stage_fingerprint: str | None = None
skipped_upstream_stage: str | None = None
# A stage that runs or resumes may produce new data, so descendants rebuild from its current output.
force_rerun_downstream = False

for index, stage in enumerate(self._stages):
stage_dir_name = _stage_dir_name(index, stage.name)
Expand Down Expand Up @@ -288,7 +306,44 @@ def run(self) -> CompositeWorkflowResults:
upstream_fingerprint=previous_stage_fingerprint,
)
stage_path = workflow_path / stage_dir_name
if stage_path.exists():
prior_stage_metadata = _get_prior_stage_metadata(prior_metadata, index, stage, stage_dir_name)
stage_resume = ResumeMode.NEVER
prior_matches = (
not force_rerun_downstream
and prior_stage_metadata is not None
and prior_stage_metadata.get("fingerprint") == stage_fingerprint
)

if prior_matches and _can_skip_prior_stage(stage, prior_stage_metadata, workflow_path):
stage_metadata.update(prior_stage_metadata)
output_seed_path = _resolve_metadata_path(workflow_path, stage_metadata["output_seed_path"])
_normalize_stage_path_metadata(workflow_path, stage_metadata)
output_records = _count_parquet_records(output_seed_path)
output_result = _stage_result_from_metadata(
workflow_path=workflow_path,
stage=stage,
stage_dir_name=stage_dir_name,
stage_builder=stage_builder,
)
stage_results[stage.name] = output_result
stage_output_paths[stage.name] = output_seed_path
previous_seed_path = output_seed_path
previous_output_records = None if stage_metadata["status"] == "completed_empty" else output_records
previous_stage_name = stage.name
previous_stage_fingerprint = stage_fingerprint
if stage_metadata["status"] == "completed_empty":
skipped_upstream_stage = stage.name
_write_workflow_metadata(workflow_path, metadata)
continue

if prior_matches and prior_stage_metadata.get("status") in RESUMABLE_STAGE_STATUSES and stage_path.exists():
stage_resume = ResumeMode.ALWAYS
elif resume == ResumeMode.ALWAYS and not force_rerun_downstream:
raise DataDesignerWorkflowError(
f"Cannot resume workflow {self.name!r}: stage {stage.name!r} is not reusable."
)

if stage_resume == ResumeMode.NEVER and stage_path.exists():
shutil.rmtree(stage_path)

stage_metadata.update(
Expand All @@ -297,7 +352,11 @@ def run(self) -> CompositeWorkflowResults:
"fingerprint": stage_fingerprint,
"num_records_requested": num_records,
"seeded_from_stage": previous_stage_name,
"seed_path": str(previous_seed_path) if previous_seed_path is not None else None,
"seed_path": (
_metadata_path_value(workflow_path, previous_seed_path)
if previous_seed_path is not None
else None
),
"config": stage_config.model_dump(mode="json"),
}
)
Expand All @@ -310,11 +369,15 @@ def run(self) -> CompositeWorkflowResults:
num_records=num_records,
dataset_name=stage_dir_name,
artifact_path=workflow_path,
resume=stage_resume,
)
actual_records = result.count_records()
output_result = result
output_source_result = result
if stage.output_processors:
output_processor_path = stage_path / "output-processors"
if output_processor_path.exists():
shutil.rmtree(output_processor_path)
output_processor_builder = _output_processor_config_builder(
stage_builder=stage_builder,
seed_path=result.artifact_storage.final_dataset_path,
Expand Down Expand Up @@ -349,10 +412,14 @@ def run(self) -> CompositeWorkflowResults:
"status": status,
"num_records_actual": actual_records,
"output_records": output_records,
"output_seed_path": str(output_seed_path),
"callback_output_path": str(callback_output_path) if callback_output_path else None,
"output_seed_path": _metadata_path_value(workflow_path, output_seed_path),
"callback_output_path": (
_metadata_path_value(workflow_path, callback_output_path) if callback_output_path else None
),
"output_processor_output_path": (
str(output_result.artifact_storage.base_dataset_path) if stage.output_processors else None
_metadata_path_value(workflow_path, output_result.artifact_storage.base_dataset_path)
if stage.output_processors
else None
),
"duration_sec": time.monotonic() - start_time,
}
Expand All @@ -365,9 +432,10 @@ def run(self) -> CompositeWorkflowResults:
stage_results[stage.name] = output_result
stage_output_paths[stage.name] = output_seed_path
previous_seed_path = output_seed_path
previous_output_records = output_records
previous_output_records = None if status == "completed_empty" else output_records
previous_stage_name = stage.name
previous_stage_fingerprint = stage_fingerprint
force_rerun_downstream = True
_write_workflow_metadata(workflow_path, metadata)

return CompositeWorkflowResults(
Expand All @@ -378,6 +446,155 @@ def run(self) -> CompositeWorkflowResults:
)


def _read_prior_workflow_metadata(
    workflow_path: Path,
    workflow_name: str,
    resume: ResumeMode,
) -> dict[str, Any] | None:
    """Load the metadata left by a previous run of this workflow, if it is usable.

    Args:
        workflow_path: Workflow directory expected to contain the metadata file.
        workflow_name: Name of the workflow being run. Used in log/error
            messages and compared with the ``"name"`` recorded in the file.
        resume: Requested resume mode.

    Returns:
        The parsed metadata mapping. ``None`` when ``resume`` is
        ``ResumeMode.NEVER``, or when the mode is not ``ResumeMode.ALWAYS`` and
        the file is missing, unreadable, corrupt, or belongs to a different
        workflow name (a warning is logged for every case except "missing").

    Raises:
        DataDesignerWorkflowError: Under ``ResumeMode.ALWAYS`` when the file is
            missing, unreadable, corrupt, or records a different name.
    """
    if resume == ResumeMode.NEVER:
        return None

    strict = resume == ResumeMode.ALWAYS
    metadata_path = workflow_path / WORKFLOW_METADATA_FILENAME
    if not metadata_path.exists():
        if strict:
            raise DataDesignerWorkflowError(f"Cannot resume workflow {workflow_name!r}: no workflow metadata found.")
        return None

    try:
        metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        # Bytes that are not valid UTF-8 are corruption just like malformed JSON;
        # UnicodeDecodeError is neither a JSONDecodeError nor an OSError, so it
        # must be listed explicitly or it would escape this function.
        if not strict:
            logger.warning("Workflow metadata for %r is corrupt; starting fresh.", workflow_name)
            return None
        raise DataDesignerWorkflowError(
            f"Cannot resume workflow {workflow_name!r}: workflow metadata is corrupt."
        ) from exc
    except OSError as exc:
        if not strict:
            logger.warning("Workflow metadata for %r could not be read; starting fresh.", workflow_name)
            return None
        raise DataDesignerWorkflowError(
            f"Cannot resume workflow {workflow_name!r}: workflow metadata could not be read."
        ) from exc

    if not isinstance(metadata, dict):
        # Syntactically valid JSON of the wrong shape (list, string, number, null)
        # cannot be queried with ``.get`` below, so treat it as corrupt.
        if not strict:
            logger.warning("Workflow metadata for %r is corrupt; starting fresh.", workflow_name)
            return None
        raise DataDesignerWorkflowError(f"Cannot resume workflow {workflow_name!r}: workflow metadata is corrupt.")

    if metadata.get("name") != workflow_name:
        if not strict:
            logger.warning("Workflow metadata for %r has a different name; starting fresh.", workflow_name)
            return None
        raise DataDesignerWorkflowError(
            f"Cannot resume workflow {workflow_name!r}: workflow metadata name does not match."
        )
    return metadata


def _get_prior_stage_metadata(
    prior_metadata: dict[str, Any] | None,
    index: int,
    stage: _WorkflowStage,
    stage_dir_name: str,
) -> dict[str, Any] | None:
    """Return the prior run's metadata entry for the stage at ``index``, if it lines up.

    The entry is returned only when ``prior_metadata["stages"]`` is a list that
    has a dict at ``index`` whose ``"name"`` equals ``stage.name`` and whose
    ``"stage_dir"`` equals ``stage_dir_name``. In every other case (including
    ``prior_metadata`` being ``None``) the result is ``None``.
    """
    recorded_stages = None if prior_metadata is None else prior_metadata.get("stages")
    if not isinstance(recorded_stages, list) or index >= len(recorded_stages):
        return None

    candidate = recorded_stages[index]
    if not isinstance(candidate, dict):
        return None

    same_name = candidate.get("name") == stage.name
    same_dir = candidate.get("stage_dir") == stage_dir_name
    return candidate if same_name and same_dir else None


def _can_skip_prior_stage(stage: _WorkflowStage, prior_stage_metadata: dict[str, Any], workflow_path: Path) -> bool:
    """Decide whether a previously recorded stage can be reused without running it again.

    Returns ``True`` only when all of the following hold:

    * the recorded ``"status"`` is one of ``COMPLETED_STAGE_STATUSES``;
    * the stage either has no ``on_success`` callback or declares an
      ``on_success_version``;
    * ``"output_seed_path"`` is a non-empty string whose resolved location can
      be counted by ``_count_parquet_records`` without a
      ``DataDesignerWorkflowError``.
    """
    finished = prior_stage_metadata.get("status") in COMPLETED_STAGE_STATUSES
    if not finished:
        return False

    # NOTE(review): presumably an unversioned callback cannot be shown to be
    # unchanged since the prior run, so its stage is never skipped — confirm.
    unversioned_callback = stage.on_success is not None and stage.on_success_version is None
    if unversioned_callback:
        return False

    recorded_output = prior_stage_metadata.get("output_seed_path")
    if not (isinstance(recorded_output, str) and recorded_output):
        return False

    try:
        # Only the ability to count matters here; the count itself is discarded.
        _count_parquet_records(_resolve_metadata_path(workflow_path, recorded_output))
    except DataDesignerWorkflowError:
        return False
    else:
        return True


def _metadata_path_value(workflow_path: Path, path: Path) -> str:
    """Return the string stored in metadata for ``path``.

    An absolute ``path`` located under ``workflow_path`` is stored relative to
    ``workflow_path``. Any other path (already relative, or absolute but
    outside ``workflow_path``) is stored unchanged.
    """
    if not path.is_absolute():
        return str(path)
    try:
        inside_workflow = path.relative_to(workflow_path)
    except ValueError:
        # ``relative_to`` raises ValueError when ``path`` is not under ``workflow_path``.
        return str(path)
    return str(inside_workflow)


def _resolve_metadata_path(workflow_path: Path, path: str) -> Path:
    """Turn a path string read from metadata into a ``Path``.

    Absolute paths are returned as they are; relative paths are joined onto
    ``workflow_path``.
    """
    candidate = Path(path)
    return candidate if candidate.is_absolute() else workflow_path / candidate


def _normalize_stage_path_metadata(workflow_path: Path, stage_metadata: dict[str, Any]) -> None:
    """Rewrite the path-valued entries of ``stage_metadata`` in place.

    For each key in ``WORKFLOW_PATH_METADATA_KEYS`` holding a non-empty string,
    the value is resolved against ``workflow_path`` and then converted back
    with ``_metadata_path_value``. Missing, empty, or non-string values are
    left untouched.
    """
    for path_key in WORKFLOW_PATH_METADATA_KEYS:
        recorded = stage_metadata.get(path_key)
        if not isinstance(recorded, str) or not recorded:
            continue
        resolved = _resolve_metadata_path(workflow_path, recorded)
        stage_metadata[path_key] = _metadata_path_value(workflow_path, resolved)


def _stage_result_from_metadata(
    *,
    workflow_path: Path,
    stage: _WorkflowStage,
    stage_dir_name: str,
    stage_builder: DataDesignerConfigBuilder,
) -> DatasetCreationResults:
    """Build a ``DatasetCreationResults`` for a stage whose existing artifacts are reused.

    Args:
        workflow_path: Workflow directory that contains the stage directory.
        stage: The stage definition; only ``output_processors`` is read here.
        stage_dir_name: Name of the stage's directory under ``workflow_path``.
        stage_builder: Config builder for the stage.

    Returns:
        Results backed by the stage's own storage, or — when the stage has
        output processors — by the ``"output-processors"`` storage inside the
        stage directory, paired with an output-processor config builder seeded
        from the stage's final dataset path.
    """
    # Both storages are opened with ResumeMode.ALWAYS.
    # NOTE(review): presumably so the existing directories are kept as-is — confirm against ArtifactStorage.
    stage_storage = ArtifactStorage(artifact_path=workflow_path, dataset_name=stage_dir_name, resume=ResumeMode.ALWAYS)

    if not stage.output_processors:
        storage = stage_storage
        builder = stage_builder
    else:
        storage = ArtifactStorage(
            artifact_path=workflow_path / stage_dir_name,
            dataset_name="output-processors",
            resume=ResumeMode.ALWAYS,
        )
        builder = _output_processor_config_builder(
            stage_builder=stage_builder,
            seed_path=stage_storage.final_dataset_path,
            output_processors=stage.output_processors,
        )

    return DatasetCreationResults(
        artifact_storage=storage,
        analysis=_load_stage_analysis(storage),
        config_builder=builder,
        dataset_metadata=DatasetMetadata(),
    )


def _load_stage_analysis(artifact_storage: ArtifactStorage) -> Any:
    """Best-effort reconstruction of a stage's profiler results from stored metadata.

    Args:
        artifact_storage: Storage whose metadata (and, if needed, final
            dataset) is read.

    Returns:
        A ``DatasetProfilerResults`` validated from the stored metadata, or
        ``None`` when the metadata cannot be read, is not a mapping, has no
        ``"column_statistics"``, the record count cannot be determined, or
        validation fails. Analysis is optional, so this function never raises
        for those conditions.
    """
    try:
        metadata = artifact_storage.read_metadata()
    except (FileNotFoundError, json.JSONDecodeError, OSError):
        return None
    if not isinstance(metadata, dict):
        # Metadata of an unexpected shape cannot be queried with ``.get``.
        return None

    column_statistics = metadata.get("column_statistics")
    if not column_statistics:
        return None

    num_records = metadata.get("actual_num_records")
    if num_records is None:
        # Fall back to counting the stored dataset. A failure here must not
        # abort the caller: every other failure in this loader yields None.
        try:
            num_records = _count_parquet_records(artifact_storage.final_dataset_path)
        except DataDesignerWorkflowError:
            logger.warning(
                "Could not count records in %s; skipping stage analysis.",
                artifact_storage.final_dataset_path,
            )
            return None

    try:
        return DatasetProfilerResults.model_validate(
            {
                "num_records": num_records,
                "target_num_records": metadata.get("target_num_records", num_records),
                "column_statistics": column_statistics,
                "side_effect_column_names": metadata.get("side_effect_column_names"),
                "column_profiles": metadata.get("column_profiles"),
            }
        )
    except ValidationError:
        return None


def _clone_config_builder(config_builder: DataDesignerConfigBuilder) -> DataDesignerConfigBuilder:
    """Return a new builder created from the config that ``config_builder`` builds."""
    built_config = config_builder.build()
    wrapped_config = BuilderConfig(data_designer=built_config)
    return DataDesignerConfigBuilder.from_config(wrapped_config)

Expand Down Expand Up @@ -527,8 +744,16 @@ def _parquet_files(path: Path) -> list[Path]:


def _write_workflow_metadata(workflow_path: Path, metadata: dict[str, Any]) -> None:
    """Write ``metadata`` as JSON to the workflow metadata file via a temp-file swap.

    The JSON is written to a uniquely named temporary file next to the target,
    flushed and fsynced, and then moved over the target with ``os.replace`` so
    that readers never observe a partially written metadata file.

    Args:
        workflow_path: Workflow directory that holds the metadata file.
        metadata: JSON-serializable workflow metadata.
    """
    path = workflow_path / WORKFLOW_METADATA_FILENAME
    # PID plus a random suffix keeps concurrent writers from sharing a temp file.
    tmp_path = path.with_name(f"{path.name}.tmp.{os.getpid()}.{uuid.uuid4().hex}")
    try:
        with tmp_path.open("w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2, sort_keys=True)
            f.flush()
            # Push the bytes to disk before the rename makes them visible.
            os.fsync(f.fileno())
        os.replace(tmp_path, path)
    finally:
        # No-op after a successful replace; removes the leftover temp file on failure.
        tmp_path.unlink(missing_ok=True)


def _validate_stage_output(output: str) -> None:
Expand Down
Loading
Loading