From abe044ffe96640fd10626f26ca90fcb850859239 Mon Sep 17 00:00:00 2001
From: "Eric W. Tramel"
Date: Fri, 5 Jun 2026 23:49:14 +0000
Subject: [PATCH] Expose row-group admission controls

Fixes #741

Signed-off-by: Eric W. Tramel <1223539+eric-tramel@users.noreply.github.com>
---
 architecture/dataset-builders.md              |   2 +-
 docs/concepts/architecture-and-performance.md |  67 ++-
 .../concepts/architecture-and-performance.mdx | 131 +++--
 .../src/data_designer/config/__init__.py      |   8 +
 .../src/data_designer/config/run_config.py    | 101 +++-
 .../tests/config/test_run_config.py           | 141 ++++++
 .../src/data_designer/engine/capacity.py      |   1 +
 .../dataset_builders/async_scheduler.py       | 269 ++++++----
 .../dataset_builders/dataset_builder.py       |  13 +
 .../dataset_builders/scheduling/completion.py | 467 ++++++++++++++++--
 .../scheduling/test_completion.py             | 329 +++++++++++-
 .../test_async_builder_integration.py         | 158 +++++-
 .../dataset_builders/test_async_scheduler.py  | 459 ++++++++++++++++-
 .../data_designer/interface/data_designer.py  |   3 +-
 plans/645/async-scheduling-epic.puml          |   6 +
 plans/645/contracts.md                        |   6 +
 plans/741/row-group-admission.md              |  65 +++
 17 files changed, 2013 insertions(+), 213 deletions(-)
 create mode 100644 plans/741/row-group-admission.md

diff --git a/architecture/dataset-builders.md b/architecture/dataset-builders.md
index 70b6afed4..e05d3b312 100644
--- a/architecture/dataset-builders.md
+++ b/architecture/dataset-builders.md
@@ -130,7 +130,7 @@ DatasetBuilder.build()
   → collect TaskTraces, emit telemetry
 ```
 
-Row-group admission is fixed by default in the dataset-builder path: the configured row-group concurrency is the hard in-flight cap. The scheduler also has an internal adaptive row-group mode for direct use that only raises a soft target up to that cap; it is additive ramp-up, not AIMD shrink/recovery behavior.
+Row-group admission is fixed by default in the dataset-builder path: the default row-group concurrency is the hard in-flight cap. Public async runs can override that horizon with `RunConfig.row_group_admission`. Fixed mode uses `max_concurrent_row_groups` as the hard in-flight cap. The historical default fixed horizon remains row-group-count-only, while widened fixed horizons derive an active-row guard when `max_admitted_rows` is omitted. Adaptive mode starts from `adaptive_initial_target` and only raises a soft target up to that cap; it is additive ramp-up, not AIMD shrink/recovery behavior. Adaptive mode also derives an active-row guard when `max_admitted_rows` is omitted, and rejects row groups above that guard, so wide row groups cannot silently admit unbounded active state.
 
 When request admission is available, async scheduling may use request-pressure snapshots as a read-only advisory during fair-queue selection. A request-pressured task can be skipped for an eligible peer without mutating request-admission state; provider/model/domain request limits remain owned by request admission.
 
diff --git a/docs/concepts/architecture-and-performance.md b/docs/concepts/architecture-and-performance.md
index f03b549d8..89145b556 100644
--- a/docs/concepts/architecture-and-performance.md
+++ b/docs/concepts/architecture-and-performance.md
@@ -99,7 +99,8 @@ Within each column, cells are processed **in parallel** up to the configured lim
 
 ### Concurrency Formula
 
-At any moment, the number of concurrent LLM requests is:
+On the sync engine, each batch is processed one column at a time. At any moment,
+the number of concurrent LLM requests is:
 
 ```python
 concurrent_requests = min(
@@ -109,6 +110,21 @@ concurrent_requests = min(
 )
 ```
 
+On the async engine, ready cells can come from multiple active row groups:
+
+```python
+concurrent_requests = min(
+    active_ready_model_cells,     # Ready cells across admitted row groups
+    current_request_limit,        # AIMD-managed limit (≤ max_parallel_requests)
+    max_in_flight_tasks           # Scheduler task-lease ceiling
+)
+```
+
+`active_ready_model_cells` is bounded by row-group admission:
+`max_concurrent_row_groups`, the effective `max_admitted_rows` guard, the DAG
+dependencies that have become ready, and any rows already dropped by processors
+or failures.
+
 `max_parallel_requests` sets the **ceiling**. The actual limit (`current_request_limit`) is managed at runtime by adaptive request admission that reacts to rate-limit signals from the inference server:
 
 - **During optional startup ramp**: when `startup_ramp_seconds` is greater than 0, a new request resource starts at one concurrent request and increases linearly toward `max_parallel_requests` over that duration.
@@ -152,6 +168,53 @@ designer.set_run_config(run_config)
 
 ---
 
+### Row-Group Admission (RunConfig)
+
+Controls how many async row groups can be active at once. A row group contains
+`buffer_size` records, so this setting is the scheduler horizon above the batch
+size: a wider horizon can expose more ready model work to fast endpoints, while
+a smaller horizon tends to checkpoint completed records earlier and hold less
+active state.
+
+```python
+import data_designer.config as dd
+from data_designer.interface import DataDesigner
+
+run_config = dd.RunConfig(
+    buffer_size=1000,
+    max_in_flight_tasks=4096,
+    row_group_admission=dd.RowGroupAdmissionConfig(
+        mode="adaptive",
+        max_concurrent_row_groups=8,
+        adaptive_initial_target=2,
+        max_admitted_rows=16_000,
+    ),
+)
+
+designer = DataDesigner()
+designer.set_run_config(run_config)
+```
+
+| Parameter | Default | Effect |
+|-----------|---------|--------|
+| `mode` | `fixed` | `fixed` admits up to the hard cap immediately; `adaptive` starts lower and raises the target when scheduler pressure shows that more ready work can be useful. |
+| `max_concurrent_row_groups` | 3 | Hard cap on active row groups. Maximum is 64. |
+| `adaptive_initial_target` | 1 in adaptive mode | Initial soft target before adaptive additive ramp-up. |
+| `max_admitted_rows` | Engine-derived for adaptive mode and widened fixed horizons; unset for the default fixed horizon | Optional guardrail on total records held across active row groups. When omitted for adaptive mode or fixed mode with `max_concurrent_row_groups > 3`, the engine derives `max(max_concurrent_row_groups * buffer_size, 8192)`, bounded by the requested target record count when available, falling back to scheduled rows for direct scheduler plans, and a 1,000,000-row ceiling. Derived guards require `buffer_size` at or below that ceiling. Explicit values must be at least `buffer_size` and at most 1,000,000. |
+
+**When to use fixed mode**: You want predictable checkpoint cadence, lower
+active memory, or easier debugging.
+
+**When to use adaptive mode**: Large async DAGs, fan-out/fan-in flows, mixed
+latency columns, or high-capacity endpoints where the default horizon leaves
+capacity idle.
+
+Async scheduler telemetry includes the effective mode, active row-group target,
+observed maximum target, active row-group count, max admitted rows, and blocked
+reasons when scheduler event instrumentation is enabled.
+
+---
+
 ### `max_parallel_requests` (InferenceParams)
 
 Sets the **maximum** concurrent LLM API calls **per model**. This is the ceiling that adaptive request admission can ramp up to — the actual concurrency at runtime may be lower if the server signals rate limits.
@@ -319,7 +382,7 @@ DATA_DESIGNER_ASYNC_ENGINE=0 python my_pipeline.py
 | Problem | Symptom | Solution |
 |---------|---------|----------|
 | **Low throughput** | Low GPU utilization | Increase `max_parallel_requests` and/or `buffer_size`. If request admission has self-reduced due to earlier 429s (check logs for "concurrency reduced" messages), the server may need more capacity or you can wait for AIMD recovery. |
-| **Frequent 429 → recovery cycles** | Logs show repeated concurrency drops and ramp-ups | The `max_parallel_requests` ceiling is above the server's sustained capacity. This is handled automatically, but you can lower the ceiling to reduce the sawtooth. |
+| **Frequent 429 → recovery cycles** | Logs show repeated concurrency drops and ramp-ups | The `max_parallel_requests` ceiling is above the server's sustained capacity. This is handled automatically, but you can lower the ceiling to reduce the sawtooth or tune `request_admission` with `RequestAdmissionTuningConfig`. |
 | **Long tail of slow generations** | Most records fast, few very slow | Reduce `max_conversation_restarts`, simplify schemas, improve prompts |
 | **Multi-model idle periods** | One model busy, others idle | Reduce `buffer_size` for faster cycling, or consolidate models |
 | **Memory errors** | OOM crashes | Reduce `buffer_size` and `max_parallel_requests` |
diff --git a/fern/versions/latest/pages/concepts/architecture-and-performance.mdx b/fern/versions/latest/pages/concepts/architecture-and-performance.mdx
index f5d31eed4..8473bccf4 100644
--- a/fern/versions/latest/pages/concepts/architecture-and-performance.mdx
+++ b/fern/versions/latest/pages/concepts/architecture-and-performance.mdx
@@ -51,7 +51,7 @@ This guide explains the architecture, execution model, and how to tune performan
 ## Execution Model
 
- The default execution path is the **async engine**, which dispatches work at the cell level and overlaps independent columns — see [Async Engine](#async-engine) below for its semantics. The legacy **sync engine** is still available for one transitional release via `DATA_DESIGNER_ASYNC_ENGINE=0` and is what this section describes. The configuration knobs documented below (`buffer_size`, `max_parallel_requests`, AIMD throttle config, error handling) apply to both engines; the differences are flagged inline.
+ The default execution path is the **async engine**, which dispatches work at the cell level and overlaps independent columns — see [Async Engine](#async-engine) below for its semantics. The legacy **sync engine** is still available for one transitional release via `DATA_DESIGNER_ASYNC_ENGINE=0` and is what this section describes. The public configuration knobs documented below (`buffer_size`, `max_parallel_requests`, error handling) apply to both engines; the differences are flagged inline.
 
 The sync engine processes datasets in **batches**, with **parallel** operations within each batch.
@@ -103,29 +103,45 @@ Within each column, cells are processed **in parallel** up to the configured lim
 
 ### Concurrency Formula
 
-At any moment, the number of concurrent LLM requests is:
+On the sync engine, each batch is processed one column at a time. At any moment,
+the number of concurrent LLM requests is:
 
 ```python
 concurrent_requests = min(
     buffer_size,                  # Records in current batch
-    current_throttle_limit,       # AIMD-managed limit (≤ max_parallel_requests)
+    current_request_limit,        # AIMD-managed limit (≤ max_parallel_requests)
     remaining_cells_in_column     # Cells left to generate
 )
 ```
 
-`max_parallel_requests` sets the **ceiling**. The actual limit (`current_throttle_limit`) is managed at runtime by an AIMD (Additive Increase / Multiplicative Decrease) controller that reacts to rate-limit signals from the inference server:
+On the async engine, ready cells can come from multiple active row groups:
 
-- **During optional startup ramp**: when `rampup_seconds` is greater than 0, a new throttle domain starts at one concurrent request and increases linearly toward `max_parallel_requests` over that duration.
+```python
+concurrent_requests = min(
+    active_ready_model_cells,     # Ready cells across admitted row groups
+    current_request_limit,        # AIMD-managed limit (≤ max_parallel_requests)
+    max_in_flight_tasks           # Scheduler task-lease ceiling
+)
+```
+
+`active_ready_model_cells` is bounded by row-group admission:
+`max_concurrent_row_groups`, the effective `max_admitted_rows` guard, the DAG
+dependencies that have become ready, and any rows already dropped by processors
+or failures.
+
+`max_parallel_requests` sets the **ceiling**. The actual limit (`current_request_limit`) is managed at runtime by adaptive request admission that reacts to rate-limit signals from the inference server:
+
+- **During optional startup ramp**: when `startup_ramp_seconds` is greater than 0, a new request resource starts at one concurrent request and increases linearly toward `max_parallel_requests` over that duration.
 - **On the first 429 in a burst**: the limit is reduced by a configurable factor (default: 25% reduction) and a cooldown is applied. Further 429s from already in-flight requests in the same burst do not reduce the limit again — they release their permits and hold the limit steady.
 - **After consecutive successes**: the limit increases by 1 (by default) until it reaches the ceiling or a stabilized rate-limit threshold.
 
 This means Data Designer automatically finds the right concurrency level for your server without manual tuning.
 
- AIMD adaptive concurrency is fully active on the default **async engine**. The legacy **sync engine** is available for one transitional release via `DATA_DESIGNER_ASYNC_ENGINE=0`; on that path 429s are first retried at the HTTP transport layer and AIMD only engages as a fallback. See [Async engine](#async-engine) below.
+ Request admission wraps model requests on both sync and async paths. When request admission is active, provider 429 responses propagate to the AIMD controller instead of being hidden by HTTP transport retries. See [Async engine](#async-engine) below.
 
-**Example**: With `buffer_size=100` and `max_parallel_requests=32`, Data Designer can send up to 32 requests in parallel. If `rampup_seconds=30`, it starts at one request and climbs linearly toward 32 over 30 seconds. If the server returns 429s, startup ramp stops, concurrency drops automatically (e.g., to 24, then 18), and normal AIMD recovery takes over once the server catches up.
+**Example**: With `buffer_size=100` and `max_parallel_requests=32`, Data Designer can send up to 32 requests in parallel. If `startup_ramp_seconds=30`, it starts at one request and climbs linearly toward 32 over 30 seconds. If the server returns 429s, startup ramp stops, concurrency drops automatically (e.g., to 24, then 18), and normal AIMD recovery takes over once the server catches up.
 
 ---
 
@@ -157,6 +173,53 @@ designer.set_run_config(run_config)
 
 ---
 
+### Row-Group Admission (RunConfig)
+
+Controls how many async row groups can be active at once. A row group contains
+`buffer_size` records, so this setting is the scheduler horizon above the batch
+size: a wider horizon can expose more ready model work to fast endpoints, while
+a smaller horizon tends to checkpoint completed records earlier and hold less
+active state.
+
+```python
+import data_designer.config as dd
+from data_designer.interface import DataDesigner
+
+run_config = dd.RunConfig(
+    buffer_size=1000,
+    max_in_flight_tasks=4096,
+    row_group_admission=dd.RowGroupAdmissionConfig(
+        mode="adaptive",
+        max_concurrent_row_groups=8,
+        adaptive_initial_target=2,
+        max_admitted_rows=16_000,
+    ),
+)
+
+designer = DataDesigner()
+designer.set_run_config(run_config)
+```
+
+| Parameter | Default | Effect |
+|-----------|---------|--------|
+| `mode` | `fixed` | `fixed` admits up to the hard cap immediately; `adaptive` starts lower and raises the target when scheduler pressure shows that more ready work can be useful. |
+| `max_concurrent_row_groups` | 3 | Hard cap on active row groups. Maximum is 64. |
+| `adaptive_initial_target` | 1 in adaptive mode | Initial soft target before adaptive additive ramp-up. |
+| `max_admitted_rows` | Engine-derived for adaptive mode and widened fixed horizons; unset for the default fixed horizon | Optional guardrail on total records held across active row groups. When omitted for adaptive mode or fixed mode with `max_concurrent_row_groups > 3`, the engine derives `max(max_concurrent_row_groups * buffer_size, 8192)`, bounded by the requested target record count when available, falling back to scheduled rows for direct scheduler plans, and a 1,000,000-row ceiling. Derived guards require `buffer_size` at or below that ceiling. Explicit values must be at least `buffer_size` and at most 1,000,000. |
+
+**When to use fixed mode**: You want predictable checkpoint cadence, lower
+active memory, or easier debugging.
+
+**When to use adaptive mode**: Large async DAGs, fan-out/fan-in flows, mixed
+latency columns, or high-capacity endpoints where the default horizon leaves
+capacity idle.
+
+Async scheduler telemetry includes the effective mode, active row-group target,
+observed maximum target, active row-group count, max admitted rows, and blocked
+reasons when scheduler event instrumentation is enabled.
+
+---
+
 ## Resuming Interrupted Runs
 
 Long generation jobs can be resumed from checkpoints by passing `resume` to `DataDesigner.create()` or `data-designer create --resume`.
@@ -199,7 +262,7 @@ Only resume datasets from trusted artifact directories. Resume reads local `meta
 
 ### `max_parallel_requests` (InferenceParams)
 
-Sets the **maximum** concurrent LLM API calls **per model**. This is the ceiling that the AIMD throttle controller can ramp up to — the actual concurrency at runtime may be lower if the server signals rate limits.
+Sets the **maximum** concurrent LLM API calls **per model**. This is the ceiling that adaptive request admission can ramp up to — the actual concurrency at runtime may be lower if the server signals rate limits.
 
 ```python
 import data_designer.config as dd
@@ -216,7 +279,7 @@ model = dd.ModelConfig(
 
 **Default**: 4
 
-**When to increase**: Your inference backend has high throughput capacity, you're using a cloud API with generous rate limits, or you're running vLLM/TensorRT-LLM with multiple GPUs. With AIMD, setting an aggressively high value is safer than before — the system will self-correct downward if the server can't keep up. The salvage queue on the async engine (default) reclaims failed rows; on the sync engine the initial burst of 429s before AIMD stabilizes can drop rows, so start with a more conservative ceiling if you've opted into sync.
+**When to increase**: Your inference backend has high throughput capacity, you're using a cloud API with generous rate limits, or you're running vLLM/TensorRT-LLM with multiple GPUs. With adaptive request admission, setting an aggressively high value is safer than before — the system will self-correct downward if the server can't keep up. The salvage queue on the async engine (default) reclaims failed rows; on the sync engine the initial burst of 429s before AIMD stabilizes can drop rows, so start with a more conservative ceiling if you've opted into sync.
 
 **When to decrease**: You want to cap resource usage to a known safe level, or you want more predictable/debuggable execution.
 
@@ -224,7 +287,7 @@ model = dd.ModelConfig(
 Finding the optimal value
 The right value depends on your inference stack and model. Self-hosted vLLM servers can often handle values as high as 256, 512, or even 1024 depending on your hardware.
 
-With AIMD, a practical approach is to set `max_parallel_requests` to the **upper bound** you're comfortable with and let the throttle controller find the sustainable level automatically. If you see frequent 429 → recovery cycles in the logs, your ceiling is above the server's true capacity but the system is handling it. If you never see any throttle activity, you may have room to increase the ceiling further.
+With adaptive request admission, a practical approach is to set `max_parallel_requests` to the **upper bound** you're comfortable with and let the request controller find the sustainable level automatically. If you see frequent 429 → recovery cycles in the logs, your ceiling is above the server's true capacity but the system is handling it. If you never see any request-admission activity, you may have room to increase the ceiling further.
 
 **Benchmark approach**: Run a small dataset (e.g., 100 records) with increasing `max_parallel_requests` values (4 → 8 → 16 → 32 → ...) and measure generation time. Stop increasing when the runtime stops decreasing—that's when your inference server is saturated.
 
@@ -246,12 +309,12 @@ designer.set_run_config(run_config)
 
 ---
 
-### Adaptive Throttling (RunConfig)
+### Adaptive Request Admission
 
-Data Designer uses an AIMD (Additive Increase / Multiplicative Decrease) controller to automatically adjust concurrency per model based on rate-limit feedback from the inference server. The defaults work well for most workloads. Override them via `ThrottleConfig` only when you understand the trade-offs.
+Data Designer uses AIMD (Additive Increase / Multiplicative Decrease) request admission to automatically adjust concurrency per provider/model/domain based on rate-limit feedback from the inference server. For most workloads, set `max_parallel_requests` as the user-facing ceiling and inspect `AsyncCapacityPlan`/logs to understand the effective runtime limits. Advanced AIMD tuning is available through `RequestAdmissionTuningConfig`.
 
- Adaptive throttling is fully active on the default **async engine**, where 429 responses propagate directly to the AIMD controller. On the legacy **sync engine** (`DATA_DESIGNER_ASYNC_ENGINE=0`), 429s are first retried at the HTTP transport layer; `ThrottleConfig` settings only take effect as a fallback if transport retries are exhausted.
+ Request admission wraps model requests on both sync and async paths. When request admission is active, provider 429 responses propagate to the AIMD controller instead of being hidden by HTTP transport retries.
 
 ```python
@@ -259,13 +322,12 @@ import data_designer.config as dd
 from data_designer.interface import DataDesigner
 
 run_config = dd.RunConfig(
-    throttle=dd.ThrottleConfig(
-        reduce_factor=0.75,        # Multiply limit by this on a 429 (default: 0.75)
-        additive_increase=1,       # Add this many slots after success_window successes (default: 1)
-        success_window=25,         # Consecutive successes before increasing (default: 25)
-        cooldown_seconds=2.0,      # Pause after a 429 when no Retry-After header (default: 2.0)
-        ceiling_overshoot=0.10,    # Probe 10% above observed server limit (default: 0.10)
-        rampup_seconds=0.0,        # Optional startup ramp duration; 0 disables it (default: 0.0)
+    request_admission=dd.RequestAdmissionTuningConfig(
+        multiplicative_decrease_factor=0.75,  # Multiply limit by this on a 429
+        additive_increase_step=1,  # Slots added after each success window
+        successes_until_increase=25,  # Successful releases before increasing
+        cooldown_seconds=2.0,  # Fallback pause when no Retry-After header is present
+        startup_ramp_seconds=0.0,  # Optional startup ramp duration; 0 disables it
     ),
 )
 
@@ -275,16 +337,17 @@ designer.set_run_config(run_config)
 
 | Parameter | Default | Effect |
 |-----------|---------|--------|
-| `reduce_factor` | 0.75 | How aggressively to cut concurrency on a 429. Lower = more aggressive. |
-| `additive_increase` | 1 | Slots added per recovery step. Higher = faster ramp-up, but riskier. |
-| `success_window` | 25 | Consecutive successes required before each increase step. |
-| `cooldown_seconds` | 2.0 | Pause duration after a 429 (used when the server doesn't send `Retry-After`). |
-| `ceiling_overshoot` | 0.10 | Fraction above the observed rate-limit ceiling the controller is allowed to probe. |
-| `rampup_seconds` | 0.0 | Optional startup ramp duration. When greater than 0, domains start at one concurrent request and linearly climb to the configured ceiling unless a 429 aborts the ramp. |
+| `multiplicative_decrease_factor` | 0.75 | How aggressively to cut concurrency on a 429. Lower = more aggressive. |
+| `additive_increase_step` | 1 | Slots added per recovery step. Higher = faster recovery, but riskier. |
+| `successes_until_increase` | 25 | Successful releases required before each increase step. |
+| `cooldown_seconds` | 2.0 | Pause duration after a 429 when the server does not send `Retry-After`. |
+| `startup_ramp_seconds` | 0.0 | Optional startup ramp duration. When greater than 0, resources start at one concurrent request and linearly climb to the configured ceiling unless a 429 aborts the ramp. |
+
+`RunConfig.throttle` and `ThrottleConfig` remain as deprecated compatibility shims. Existing `reduce_factor`, `additive_increase`, `success_window`, `cooldown_seconds`, and `rampup_seconds` values are translated to `RequestAdmissionTuningConfig`; `ceiling_overshoot` is accepted for compatibility but is no longer forwarded because request admission does not expose that knob.
 
 How it works in practice
-When a model endpoint returns HTTP 429, the controller reduces the concurrency limit for that model and pauses briefly. After enough consecutive successes, it begins ramping back up. If the server rate-limits again, the controller records that level as a ceiling and stabilizes just below it, with a small overshoot band to detect when the server can handle more load.
+When a model endpoint returns HTTP 429, the controller reduces the concurrency limit for that request resource and pauses briefly. After enough consecutive successes, it begins ramping back up. If the server rate-limits again, the controller records that level as a ceiling and stabilizes at a lower sustainable limit.
 
 You can observe this in the logs — look for messages like `concurrency reduced from X → Y` and `concurrency increased from X → Y`.
 
@@ -316,11 +379,11 @@ designer.set_run_config(run_config)
 
 ## Async Engine
 
-The async engine is the default execution path. It dispatches work at the cell level rather than the column level, so independent columns overlap in time and per-(provider, model) AIMD pools tune themselves independently. See the [Async All the Way Down](/dev-notes/async-all-the-way-down) dev note for the full architecture.
+The async engine is the default execution path. It dispatches work at the cell level rather than the column level, so independent columns overlap in time and provider/model/domain request resources tune themselves independently. See the [Async All the Way Down](/dev-notes/async-all-the-way-down) dev note for the full architecture.
 
 ### Per-model timeouts drive every deadline
 
-The `inference_parameters.timeout` field on a `ModelConfig` sets the per-request HTTP timeout. The same value also drives the sync→async bridge that custom columns use when they call `model.generate()`. There is no separate queue-wait deadline — waits scale with provider speed and AIMD's adaptive concurrency. Slow self-hosted endpoints (e.g. large models on a single GPU) only need this one knob raised:
+The `inference_parameters.timeout` field on a `ModelConfig` sets the per-request HTTP timeout. The same value also drives the sync→async bridge that custom columns use when they call `model.generate()`. There is no separate queue-wait deadline — waits scale with provider speed and adaptive request admission. Slow self-hosted endpoints (e.g. large models on a single GPU) only need this one knob raised:
 
 ```python
 import data_designer.config as dd
@@ -369,8 +432,8 @@ DATA_DESIGNER_ASYNC_ENGINE=0 python my_pipeline.py
 
 | Problem | Symptom | Solution |
 |---------|---------|----------|
-| **Low throughput** | Low GPU utilization | Increase `max_parallel_requests` and/or `buffer_size`. If the throttle has self-reduced due to earlier 429s (check logs for "concurrency reduced" messages), the server may need more capacity or you can wait for AIMD recovery. |
-| **Frequent 429 → recovery cycles** | Logs show repeated concurrency drops and ramp-ups | The `max_parallel_requests` ceiling is above the server's sustained capacity. This is handled automatically, but you can lower the ceiling to reduce the sawtooth or tune `reduce_factor` / `success_window`. |
+| **Low throughput** | Low GPU utilization | Increase `max_parallel_requests` and/or `buffer_size`. If request admission has self-reduced due to earlier 429s (check logs for "concurrency reduced" messages), the server may need more capacity or you can wait for AIMD recovery. |
+| **Frequent 429 → recovery cycles** | Logs show repeated concurrency drops and ramp-ups | The `max_parallel_requests` ceiling is above the server's sustained capacity. This is handled automatically, but you can lower the ceiling to reduce the sawtooth or tune `request_admission` with `RequestAdmissionTuningConfig`. |
 | **Long tail of slow generations** | Most records fast, few very slow | Reduce `max_conversation_restarts`, simplify schemas, improve prompts |
 | **Multi-model idle periods** | One model busy, others idle | Reduce `buffer_size` for faster cycling, or consolidate models |
 | **Memory errors** | OOM crashes | Reduce `buffer_size` and `max_parallel_requests` |
@@ -380,10 +443,10 @@ DATA_DESIGNER_ASYNC_ENGINE=0 python my_pipeline.py
 
 ## Tuning Workflow
 
-1. **Start with defaults** for initial development — AIMD handles rate-limit adaptation automatically
+1. **Start with defaults** for initial development — adaptive request admission handles rate-limit adaptation automatically
 2. **Profile your workload**: How many LLM columns? How many records? What models?
-3. **Identify bottleneck**: Low GPU util → increase `max_parallel_requests` (AIMD will self-correct if you overshoot). Memory issues → decrease `buffer_size`. Long tails → tune retry settings.
-4. **Check throttle logs**: Look for "concurrency reduced" / "concurrency increased" messages to understand whether rate limits are the bottleneck
+3. **Identify bottleneck**: Low GPU util → increase `max_parallel_requests` (request admission will self-correct if you overshoot). Memory issues → decrease `buffer_size`. Long tails → tune retry settings.
+4. **Check request-admission logs**: Look for "concurrency reduced" / "concurrency increased" messages to understand whether rate limits are the bottleneck
 5. **Iterate**: Make one change at a time, measure impact before next change
 
 ---
diff --git a/packages/data-designer-config/src/data_designer/config/__init__.py b/packages/data-designer-config/src/data_designer/config/__init__.py
index e608476b2..0b3d68397 100644
--- a/packages/data-designer-config/src/data_designer/config/__init__.py
+++ b/packages/data-designer-config/src/data_designer/config/__init__.py
@@ -61,8 +61,12 @@
     SchemaTransformProcessorConfig,
 )
 from data_designer.config.run_config import (  # noqa: F401
+    MAX_ROW_GROUP_ADMISSION_HORIZON,
+    MAX_ROW_GROUP_ADMITTED_ROWS,
     JinjaRenderingEngine,
     RequestAdmissionTuningConfig,
+    RowGroupAdmissionConfig,
+    RowGroupAdmissionMode,
     RunConfig,
     ThrottleConfig,
 )
@@ -188,7 +192,11 @@
     "SchemaTransformProcessorConfig": (_MOD_PROCESSORS, "SchemaTransformProcessorConfig"),
     # run_config
     "JinjaRenderingEngine": (f"{_MOD_BASE}.run_config", "JinjaRenderingEngine"),
+    "MAX_ROW_GROUP_ADMISSION_HORIZON": (f"{_MOD_BASE}.run_config", "MAX_ROW_GROUP_ADMISSION_HORIZON"),
+    "MAX_ROW_GROUP_ADMITTED_ROWS": (f"{_MOD_BASE}.run_config", "MAX_ROW_GROUP_ADMITTED_ROWS"),
     "RequestAdmissionTuningConfig": (f"{_MOD_BASE}.run_config", "RequestAdmissionTuningConfig"),
+    "RowGroupAdmissionConfig": (f"{_MOD_BASE}.run_config", "RowGroupAdmissionConfig"),
+    "RowGroupAdmissionMode": (f"{_MOD_BASE}.run_config", "RowGroupAdmissionMode"),
     "RunConfig": (f"{_MOD_BASE}.run_config", "RunConfig"),
     "ThrottleConfig": (f"{_MOD_BASE}.run_config", "ThrottleConfig"),
     # scheduling metadata
diff --git a/packages/data-designer-config/src/data_designer/config/run_config.py b/packages/data-designer-config/src/data_designer/config/run_config.py
index 654edde84..13b3f94f4 100644
--- a/packages/data-designer-config/src/data_designer/config/run_config.py
+++ b/packages/data-designer-config/src/data_designer/config/run_config.py
@@ -25,6 +25,10 @@ class JinjaRenderingEngine(StrEnum):
     "RequestAdmissionTuningConfig for supported advanced request-admission tuning."
 )
 
+DEFAULT_ROW_GROUP_ADMISSION_HORIZON = 3
+MAX_ROW_GROUP_ADMISSION_HORIZON = 64
+MAX_ROW_GROUP_ADMITTED_ROWS = 1_000_000
+
 
 class RequestAdmissionTuningConfig(ConfigBase):
     """Advanced request-admission AIMD tuning for model API calls.
@@ -65,6 +69,70 @@ class RequestAdmissionTuningConfig(ConfigBase):
     )
 
 
+class RowGroupAdmissionMode(StrEnum):
+    """Row-group admission policy used by the async scheduler."""
+
+    FIXED = "fixed"
+    ADAPTIVE = "adaptive"
+
+
+class RowGroupAdmissionConfig(ConfigBase):
+    """Async row-group admission horizon and optional adaptive ramp-up policy.
+
+    ``buffer_size`` defines how many records belong to one row group. This
+    policy controls how many row groups and records the async scheduler may keep
+    active at once. Fixed mode uses ``max_concurrent_row_groups`` as a hard
+    horizon. Adaptive mode starts at ``adaptive_initial_target`` and raises the
+    active target up to ``max_concurrent_row_groups`` when scheduler pressure
+    indicates more ready work can be admitted. Adaptive mode and widened fixed
+    horizons derive an active-record guard when ``max_admitted_rows`` is omitted,
+    while the default fixed horizon preserves historical row-group-count-only
+    behavior.
+ """ + + mode: RowGroupAdmissionMode = Field( + default=RowGroupAdmissionMode.FIXED, + description="Use a fixed row-group horizon or adaptive additive ramp-up beneath that hard cap.", + ) + max_concurrent_row_groups: int = Field( + default=DEFAULT_ROW_GROUP_ADMISSION_HORIZON, + ge=1, + le=MAX_ROW_GROUP_ADMISSION_HORIZON, + description="Hard cap on row groups that may be active in the async scheduler at once. Maximum is 64.", + ) + adaptive_initial_target: int | None = Field( + default=None, + ge=1, + description=( + "Initial active row-group target for adaptive mode. Defaults to 1 when omitted. " + "Must not exceed max_concurrent_row_groups." + ), + ) + max_admitted_rows: int | None = Field( + default=None, + ge=1, + le=MAX_ROW_GROUP_ADMITTED_ROWS, + description=( + "Optional guardrail on the total records held across active row groups. " + "When set on RunConfig, it must be at least buffer_size and at most 1,000,000. " + "When omitted in adaptive mode or widened fixed mode, the engine derives a conservative " + "guardrail from buffer_size and target record count." + ), + ) + + @model_validator(mode="after") + def validate_adaptive_settings(self) -> Self: + mode = RowGroupAdmissionMode(self.mode) + if mode == RowGroupAdmissionMode.FIXED: + if self.adaptive_initial_target is not None: + raise ValueError("adaptive_initial_target applies only when row-group admission mode is 'adaptive'.") + elif self.adaptive_initial_target is None: + self.adaptive_initial_target = 1 + elif self.adaptive_initial_target > self.max_concurrent_row_groups: + raise ValueError("adaptive_initial_target must not exceed max_concurrent_row_groups.") + return self + + class ThrottleConfig(ConfigBase): """Deprecated compatibility DTO for request-admission tuning. @@ -130,9 +198,10 @@ class RunConfig(ConfigBase): early shutdown is enabled. Default is 0.5. shutdown_error_window: Minimum number of completed tasks before error rate monitoring begins. Must be >= 1. Default is 10. 
- buffer_size: Number of records to process in each batch during dataset generation. - A batch is processed end-to-end (column generation, post-batch processors, and writing the batch - to artifact storage) before moving on to the next batch. Must be > 0. Default is 1000. + buffer_size: Number of records in each sync batch or async row group during dataset generation. + The sync engine processes one batch end-to-end before moving to the next batch. The async engine + may admit multiple row groups concurrently according to row_group_admission. Must be > 0. + Default is 1000. max_in_flight_tasks: Maximum number of async scheduler tasks that may hold task leases at once. Tasks may be executing, awaiting I/O, or waiting on model request admission. Model API request concurrency is controlled separately by @@ -161,6 +230,10 @@ class RunConfig(ConfigBase): Default is ``secure``. request_admission: Advanced AIMD request-admission tuning for provider/model calls. Most users should leave this unset and tune ``max_parallel_requests`` instead. + row_group_admission: Async scheduler row-group horizon/adaptive admission policy. + Defaults to a fixed horizon of three active row groups. Tune this + for large async runs that need earlier checkpoints or wider endpoint + occupancy. Notes: Request-admission controller internals remain engine-owned. 
This field @@ -200,6 +273,7 @@ class RunConfig(ConfigBase): ), ) request_admission: RequestAdmissionTuningConfig | None = None + row_group_admission: RowGroupAdmissionConfig = Field(default_factory=RowGroupAdmissionConfig) @model_validator(mode="before") @classmethod @@ -225,6 +299,27 @@ def translate_deprecated_throttle_config(cls, data: Any) -> Any: return normalized return data + @model_validator(mode="after") + def validate_row_group_admission_budget(self) -> Self: + mode = RowGroupAdmissionMode(self.row_group_admission.mode) + requires_derived_row_guard = ( + mode == RowGroupAdmissionMode.ADAPTIVE + or self.row_group_admission.max_concurrent_row_groups > DEFAULT_ROW_GROUP_ADMISSION_HORIZON + ) + if ( + self.row_group_admission.max_admitted_rows is None + and requires_derived_row_guard + and self.buffer_size > MAX_ROW_GROUP_ADMITTED_ROWS + ): + raise ValueError( + f"row-group admission with a derived active-row guard requires buffer_size to be at most " + f"{MAX_ROW_GROUP_ADMITTED_ROWS}." 
+ ) + max_admitted_rows = self.row_group_admission.max_admitted_rows + if max_admitted_rows is not None and max_admitted_rows < self.buffer_size: + raise ValueError("row_group_admission.max_admitted_rows must be at least buffer_size.") + return self + @model_validator(mode="after") def normalize_shutdown_settings(self) -> Self: """Normalize shutdown settings for compatibility.""" diff --git a/packages/data-designer-config/tests/config/test_run_config.py b/packages/data-designer-config/tests/config/test_run_config.py index 3b6718100..aa6befaeb 100644 --- a/packages/data-designer-config/tests/config/test_run_config.py +++ b/packages/data-designer-config/tests/config/test_run_config.py @@ -8,8 +8,12 @@ import data_designer.config as dd from data_designer.config.run_config import ( + MAX_ROW_GROUP_ADMISSION_HORIZON, + MAX_ROW_GROUP_ADMITTED_ROWS, JinjaRenderingEngine, RequestAdmissionTuningConfig, + RowGroupAdmissionConfig, + RowGroupAdmissionMode, RunConfig, ThrottleConfig, ) @@ -142,10 +146,147 @@ def test_run_config_accepts_request_admission_tuning_dict() -> None: assert run_config.request_admission.startup_ramp_seconds == 10.0 +def test_run_config_exposes_default_row_group_admission_policy() -> None: + run_config = RunConfig() + + assert run_config.row_group_admission is not None + assert RowGroupAdmissionMode(run_config.row_group_admission.mode) == RowGroupAdmissionMode.FIXED + assert run_config.row_group_admission.max_concurrent_row_groups == 3 + assert run_config.row_group_admission.adaptive_initial_target is None + assert run_config.row_group_admission.max_admitted_rows is None + + +def test_run_config_exports_row_group_admission_public_caps() -> None: + assert dd.MAX_ROW_GROUP_ADMISSION_HORIZON == MAX_ROW_GROUP_ADMISSION_HORIZON + assert dd.MAX_ROW_GROUP_ADMITTED_ROWS == MAX_ROW_GROUP_ADMITTED_ROWS + + +def test_run_config_accepts_row_group_admission_tuning() -> None: + run_config = RunConfig( + row_group_admission=RowGroupAdmissionConfig( + 
mode=RowGroupAdmissionMode.ADAPTIVE, + max_concurrent_row_groups=8, + adaptive_initial_target=2, + max_admitted_rows=4096, + ) + ) + + assert run_config.row_group_admission is not None + assert RowGroupAdmissionMode(run_config.row_group_admission.mode) == RowGroupAdmissionMode.ADAPTIVE + assert run_config.row_group_admission.max_concurrent_row_groups == 8 + assert run_config.row_group_admission.adaptive_initial_target == 2 + assert run_config.row_group_admission.max_admitted_rows == 4096 + + +def test_row_group_admission_config_normalizes_adaptive_initial_target() -> None: + row_group_admission = RowGroupAdmissionConfig(mode=RowGroupAdmissionMode.ADAPTIVE) + + assert row_group_admission.adaptive_initial_target == 1 + + +def test_run_config_accepts_fixed_row_group_admission_row_budget() -> None: + run_config = RunConfig(row_group_admission=RowGroupAdmissionConfig(max_admitted_rows=2048)) + + assert RowGroupAdmissionMode(run_config.row_group_admission.mode) == RowGroupAdmissionMode.FIXED + assert run_config.row_group_admission.max_admitted_rows == 2048 + + +def test_run_config_accepts_row_group_admission_dict() -> None: + run_config = RunConfig( + row_group_admission={ + "mode": "adaptive", + "max_concurrent_row_groups": 5, + "adaptive_initial_target": 3, + } + ) + + assert run_config.row_group_admission is not None + assert RowGroupAdmissionMode(run_config.row_group_admission.mode) == RowGroupAdmissionMode.ADAPTIVE + assert run_config.row_group_admission.max_concurrent_row_groups == 5 + assert run_config.row_group_admission.adaptive_initial_target == 3 + + +def test_row_group_admission_config_rejects_invalid_adaptive_target() -> None: + with pytest.raises(ValidationError, match="adaptive_initial_target must not exceed max_concurrent_row_groups"): + RowGroupAdmissionConfig( + mode=RowGroupAdmissionMode.ADAPTIVE, + max_concurrent_row_groups=2, + adaptive_initial_target=3, + ) + + +def test_row_group_admission_config_rejects_adaptive_only_fields_in_fixed_mode() -> 
None: + with pytest.raises(ValidationError, match="adaptive_initial_target applies only"): + RowGroupAdmissionConfig(adaptive_initial_target=1) + + +def test_row_group_admission_config_rejects_unsafe_public_caps() -> None: + with pytest.raises(ValidationError, match="max_concurrent_row_groups"): + RowGroupAdmissionConfig(max_concurrent_row_groups=65) + + with pytest.raises(ValidationError, match="max_admitted_rows"): + RowGroupAdmissionConfig(max_admitted_rows=1_000_001) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"max_concurrent_row_groups": 0}, + {"mode": RowGroupAdmissionMode.ADAPTIVE, "adaptive_initial_target": 0}, + {"max_admitted_rows": 0}, + ], +) +def test_row_group_admission_config_rejects_non_positive_values(kwargs: dict[str, object]) -> None: + with pytest.raises(ValidationError): + RowGroupAdmissionConfig(**kwargs) + + +def test_run_config_rejects_row_group_admission_row_budget_below_buffer_size() -> None: + with pytest.raises(ValidationError, match="row_group_admission.max_admitted_rows must be at least buffer_size"): + RunConfig( + buffer_size=1000, + row_group_admission=RowGroupAdmissionConfig(max_admitted_rows=999), + ) + + +def test_run_config_rejects_adaptive_row_group_size_above_public_row_guard() -> None: + with pytest.raises(ValidationError, match="derived active-row guard requires buffer_size"): + RunConfig( + buffer_size=MAX_ROW_GROUP_ADMITTED_ROWS + 1, + row_group_admission=RowGroupAdmissionConfig(mode=RowGroupAdmissionMode.ADAPTIVE), + ) + + +def test_run_config_rejects_widened_fixed_row_group_size_above_public_row_guard() -> None: + with pytest.raises(ValidationError, match="derived active-row guard requires buffer_size"): + RunConfig( + buffer_size=MAX_ROW_GROUP_ADMITTED_ROWS + 1, + row_group_admission=RowGroupAdmissionConfig(max_concurrent_row_groups=4), + ) + + +def test_run_config_allows_fixed_row_group_size_above_public_row_guard_by_default() -> None: + run_config = RunConfig(buffer_size=MAX_ROW_GROUP_ADMITTED_ROWS + 1) + + 
assert RowGroupAdmissionMode(run_config.row_group_admission.mode) == RowGroupAdmissionMode.FIXED + assert run_config.row_group_admission.max_admitted_rows is None + + +def test_row_group_admission_config_rejects_unknown_fields() -> None: + with pytest.raises(ValidationError, match="row_group_horizon"): + RowGroupAdmissionConfig(row_group_horizon=3) + + def test_request_admission_tuning_config_is_exported_from_config_package() -> None: assert dd.RequestAdmissionTuningConfig is RequestAdmissionTuningConfig +def test_row_group_admission_config_is_exported_from_config_package() -> None: + assert dd.MAX_ROW_GROUP_ADMITTED_ROWS == MAX_ROW_GROUP_ADMITTED_ROWS + assert dd.RowGroupAdmissionConfig is RowGroupAdmissionConfig + assert dd.RowGroupAdmissionMode is RowGroupAdmissionMode + + def test_deprecated_throttle_config_is_exported_from_config_package() -> None: assert dd.ThrottleConfig is ThrottleConfig namespace: dict[str, object] = {} diff --git a/packages/data-designer-engine/src/data_designer/engine/capacity.py b/packages/data-designer-engine/src/data_designer/engine/capacity.py index e10a729e7..48ac84a14 100644 --- a/packages/data-designer-engine/src/data_designer/engine/capacity.py +++ b/packages/data-designer-engine/src/data_designer/engine/capacity.py @@ -43,6 +43,7 @@ class RowGroupAdmission: target_in_flight: int | None = None observed_max_target: int | None = None max_admitted_rows: int | None = None + max_admitted_rows_source: CapacityValueSource | None = None blocked_reasons: Mapping[str, int] = field(default_factory=dict) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 6d415c6b4..4389e6860 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -16,12 +16,18 @@ 
import data_designer.lazy_heavy_imports as lazy from data_designer.config.column_configs import GenerationStrategy +from data_designer.config.run_config import ( + DEFAULT_ROW_GROUP_ADMISSION_HORIZON, + MAX_ROW_GROUP_ADMISSION_HORIZON, + MAX_ROW_GROUP_ADMITTED_ROWS, +) from data_designer.engine.capacity import ( AsyncCapacityConfigured, AsyncCapacityObservedMaxima, AsyncCapacityPlan, AsyncCapacityRuntimeSnapshot, CapacityValue, + CapacityValueSource, RequestAdmissionConfigSnapshot, RowGroupAdmission, ) @@ -160,7 +166,7 @@ def __init__( row_groups: RowGroupInput, buffer_manager: RowGroupBufferManager | None = None, *, - max_concurrent_row_groups: int = 3, + max_concurrent_row_groups: int = DEFAULT_ROW_GROUP_ADMISSION_HORIZON, max_in_flight_tasks: int = DEFAULT_IN_FLIGHT_TASK_CAPACITY, max_model_task_admission: int = DEFAULT_IN_FLIGHT_TASK_CAPACITY, task_admission_config: TaskAdmissionConfig | None = None, @@ -184,9 +190,24 @@ def __init__( initial_completed_records: int = 0, adaptive_row_group_admission: bool = False, adaptive_row_group_initial_target: int = 1, + max_admitted_rows: int | None = None, + row_group_admission_source: CapacityValueSource = "default", request_pressure_provider: RequestPressureSnapshotProvider | None = None, request_pressure_advisory: bool = False, ) -> None: + if max_concurrent_row_groups < 1: + raise ValueError("max_concurrent_row_groups must be at least 1.") + if max_concurrent_row_groups > MAX_ROW_GROUP_ADMISSION_HORIZON: + raise ValueError(f"max_concurrent_row_groups must be at most {MAX_ROW_GROUP_ADMISSION_HORIZON}.") + if adaptive_row_group_initial_target < 1: + raise ValueError("adaptive_row_group_initial_target must be at least 1.") + if adaptive_row_group_initial_target > max_concurrent_row_groups: + raise ValueError("adaptive_row_group_initial_target must not exceed max_concurrent_row_groups.") + if max_admitted_rows is not None and max_admitted_rows < 1: + raise ValueError("max_admitted_rows must be at least 1.") + if 
max_admitted_rows is not None and max_admitted_rows > MAX_ROW_GROUP_ADMITTED_ROWS: + raise ValueError(f"max_admitted_rows must be at most {MAX_ROW_GROUP_ADMITTED_ROWS}.") + self._generators = generators self._graph = graph self._tracker = tracker @@ -302,7 +323,6 @@ def __init__( self._first_non_retryable_error: Exception | None = None self._fatal_worker_error: BaseException | None = None - self._max_concurrent_row_groups = max_concurrent_row_groups self._max_in_flight_tasks = max_in_flight_tasks self._max_model_task_admission = max_model_task_admission self._num_records = num_records @@ -317,9 +337,9 @@ def __init__( self._observed_max_provider_model_aggregate_in_flight: dict[ProviderModelKey, int] = {} self._observed_max_request_domain_current_limits: dict[RequestResourceKey, int] = {} self._adaptive_row_group_admission = adaptive_row_group_admission - self._row_group_admission_hard_cap = max(1, max_concurrent_row_groups) + self._row_group_admission_hard_cap = max_concurrent_row_groups self._row_group_admission_target = ( - max(1, min(self._row_group_admission_hard_cap, adaptive_row_group_initial_target)) + min(self._row_group_admission_hard_cap, adaptive_row_group_initial_target) if adaptive_row_group_admission else self._row_group_admission_hard_cap ) @@ -328,7 +348,28 @@ def __init__( self._row_group_admission_event.set() self._row_group_admission_pressure_ticks = 0 self._row_group_admission_blocked_reasons: Counter[str] = Counter() - self._adaptive_max_admitted_rows = self._max_admitted_rows_guardrail() + derive_row_guard = ( + adaptive_row_group_admission or max_concurrent_row_groups > DEFAULT_ROW_GROUP_ADMISSION_HORIZON + ) + self._max_admitted_rows = ( + max_admitted_rows + if max_admitted_rows is not None + else self._max_admitted_rows_guardrail() + if derive_row_guard + else None + ) + self._max_admitted_rows_source: CapacityValueSource | None = ( + row_group_admission_source + if max_admitted_rows is not None + else "engine_internal_config" + if 
self._max_admitted_rows is not None + else None + ) + self._validate_row_group_row_budget() + # Diagnostic-only candidate size for the row-group admission attempt + # currently waiting on capacity; actual admission always rechecks guards. + self._pending_row_group_admission_size: int | None = None + self._row_group_admission_source = row_group_admission_source self._request_pressure_provider = request_pressure_provider self._request_pressure_advisory = request_pressure_advisory and request_pressure_provider is not None self._request_pressure_advisory_skips = 0 @@ -513,7 +554,7 @@ def _scheduler_health_diagnostics(self, *, reason: str) -> dict[str, object]: "target_row_groups": self._row_group_admission_target, "hard_cap_row_groups": self._row_group_admission_hard_cap, "active_admitted_rows": self._active_admitted_row_count(), - "max_admitted_rows": self._adaptive_max_admitted_rows, + "max_admitted_rows": self._effective_max_admitted_rows(), "all_row_groups_admitted": self._all_rgs_admitted, "queued_total": queue_view.queued_total, "queued_by_group": _string_keyed_counts(queue_view.queued_by_group), @@ -553,7 +594,7 @@ def _scheduler_job_diagnostics(self) -> dict[str, object]: "adaptive_row_group_admission": self._adaptive_row_group_admission, "row_group_initial_target": self._row_group_admission_target, "row_group_hard_cap": self._row_group_admission_hard_cap, - "max_admitted_rows": self._adaptive_max_admitted_rows, + "max_admitted_rows": self._effective_max_admitted_rows(), "request_pressure_advisory_enabled": self._request_pressure_advisory, } @@ -872,8 +913,12 @@ def _task_flow_identity(self, task: Task) -> tuple[str, ...]: def _max_admitted_rows_guardrail(self) -> int: if self._num_records > 0 and self._buffer_size > 0: - return min(self._num_records, max(3 * self._buffer_size, 8192)) - return max(1, self._row_groups.scheduled_total_rows) + return min( + self._num_records, + max(self._row_group_admission_hard_cap * self._buffer_size, 8192), + 
MAX_ROW_GROUP_ADMITTED_ROWS, + ) + return min(max(1, self._row_groups.scheduled_total_rows), MAX_ROW_GROUP_ADMITTED_ROWS) async def _wait_for_row_group_admission_capacity(self, row_group_size: int) -> None: while True: @@ -888,40 +933,63 @@ async def _wait_for_row_group_admission_capacity(self, row_group_size: int) -> N return if row_guard_blocked: self._row_group_admission_blocked_reasons["max_admitted_rows"] += 1 - self._emit_scheduler_event( + self._emit_row_group_admission_event( "row_group_admission_blocked", - diagnostics=self._row_group_admission_diagnostics(reason="max_admitted_rows"), + reason="max_admitted_rows", ) - self._emit_scheduler_health_snapshot("row_group_admission_blocked") await self._row_group_admission_event.wait() self._raise_if_fatal_worker_error() def _row_group_row_guard_allows(self, row_group_size: int) -> bool: - if not self._adaptive_row_group_admission: + max_admitted_rows = self._max_admitted_rows + if max_admitted_rows is None: return True + if row_group_size > max_admitted_rows: + return False admitted_rows = self._active_admitted_row_count() - return admitted_rows == 0 or admitted_rows + row_group_size <= self._adaptive_max_admitted_rows + return admitted_rows + row_group_size <= max_admitted_rows def _active_admitted_row_count(self) -> int: return sum(state.size for state in self._rg_states.values()) + def _effective_max_admitted_rows(self) -> int | None: + return self._max_admitted_rows + + def _validate_row_group_row_budget(self) -> None: + max_admitted_rows = self._max_admitted_rows + if max_admitted_rows is None: + return + max_row_group_size = self._row_groups.row_group_max_size + if max_row_group_size <= max_admitted_rows: + return + if max_admitted_rows >= MAX_ROW_GROUP_ADMITTED_ROWS: + hint = ( + f"Reduce buffer_size or row-group size; max_admitted_rows is capped at {MAX_ROW_GROUP_ADMITTED_ROWS}." 
+ ) + elif self._max_admitted_rows_source == "engine_internal_config": + hint = "Reduce buffer_size or row-group size, or set row_group_admission.max_admitted_rows explicitly." + else: + hint = "Reduce buffer_size or increase row_group_admission.max_admitted_rows." + raise ValueError( + f"row-group size {max_row_group_size} exceeds row_group_admission.max_admitted_rows " + f"({max_admitted_rows}). {hint}" + ) + def _maybe_update_adaptive_row_group_target(self) -> None: if not self._adaptive_row_group_admission: return if self._all_rgs_admitted or self._early_shutdown or self._fatal_worker_error is not None: return + if self._pending_row_group_admission_size is None: + return if len(self._rg_states) >= self._row_group_admission_hard_cap: self._row_group_admission_pressure_ticks = 0 return - reason = self._adaptive_row_group_block_reason() + reason = self._adaptive_row_group_block_reason(self._pending_row_group_admission_size) if reason is not None: self._row_group_admission_blocked_reasons[reason] += 1 self._row_group_admission_pressure_ticks = 0 - self._emit_scheduler_event( - "row_group_admission_blocked", - diagnostics=self._row_group_admission_diagnostics(reason=reason), - ) - self._emit_scheduler_health_snapshot("row_group_admission_blocked") + self._emit_row_group_admission_event("row_group_admission_blocked", reason=reason) return self._row_group_admission_pressure_ticks += 1 @@ -935,20 +1003,18 @@ def _maybe_update_adaptive_row_group_target(self) -> None: ) self._row_group_admission_pressure_ticks = 0 if self._row_group_admission_target != old_target: - self._emit_scheduler_event( + self._emit_row_group_admission_event( "row_group_admission_target_changed", - diagnostics=self._row_group_admission_diagnostics(reason="horizon_limited") - | {"old_target": old_target, "new_target": self._row_group_admission_target}, + reason="horizon_limited", + extra={"old_target": old_target, "new_target": self._row_group_admission_target}, ) - 
self._emit_scheduler_health_snapshot("row_group_admission_target_changed") self._row_group_admission_event.set() - def _adaptive_row_group_block_reason(self) -> str | None: + def _adaptive_row_group_block_reason(self, next_size: int | None) -> str | None: if self._deferred: return "deferred_tasks" - next_size = self._next_unadmitted_row_group_size() if next_size is None: - return "no_pending_row_groups" + return None if not self._row_group_row_guard_allows(next_size): return "max_admitted_rows" queue_view = self._fair_queue.view() @@ -967,14 +1033,6 @@ def _adaptive_row_group_block_reason(self) -> str | None: return "queued_llm_demand" return None - def _next_unadmitted_row_group_size(self) -> int | None: - for rg_id, rg_size in self._row_groups: - if rg_id not in self._rg_states and not self._tracker.is_row_group_complete( - rg_id, rg_size, self._graph.columns - ): - return rg_size - return None - def _row_group_admission_diagnostics(self, *, reason: str) -> dict[str, object]: queue_view = self._fair_queue.view() task_view = self._task_admission.view() @@ -986,7 +1044,7 @@ def _row_group_admission_diagnostics(self, *, reason: str) -> dict[str, object]: "target_row_groups": self._row_group_admission_target, "hard_cap": self._row_group_admission_hard_cap, "admitted_rows": admitted_rows, - "max_admitted_rows": self._adaptive_max_admitted_rows, + "max_admitted_rows": self._effective_max_admitted_rows(), "queued_total": queue_view.queued_total, "queued_llm_wait_demand": queue_view.queued_peer_demand_by_resource.get("llm_wait", 0), "llm_wait_limit": task_view.resource_limits.get("llm_wait", 0), @@ -995,40 +1053,53 @@ def _row_group_admission_diagnostics(self, *, reason: str) -> dict[str, object]: "blocked_reasons": dict(self._row_group_admission_blocked_reasons), } + def _emit_row_group_admission_event( + self, + event_kind: str, + *, + reason: str, + extra: dict[str, object] | None = None, + ) -> None: + if self._scheduler_event_sink is None: + return + diagnostics = 
self._row_group_admission_diagnostics(reason=reason) + if extra: + diagnostics |= extra + self._emit_scheduler_event(event_kind, diagnostics=diagnostics) + self._emit_scheduler_health_snapshot(event_kind) + async def _admit_row_groups(self) -> None: """Admit row groups as semaphore slots become available.""" all_admitted = True - for rg_id, rg_size in self._row_groups: - await self._wait_for_row_group_admission_capacity(rg_size) - if self._early_shutdown or self._fatal_worker_error is not None: - all_admitted = False - break - await self._rg_semaphore.acquire() - if self._early_shutdown or self._fatal_worker_error is not None: - self._rg_semaphore.release() - all_admitted = False - break - if not self._row_group_row_guard_allows(rg_size): - self._rg_semaphore.release() - await self._wait_for_row_group_admission_capacity(rg_size) + try: + for rg_id, rg_size in self._row_groups: + self._pending_row_group_admission_size = rg_size + try: + await self._wait_for_row_group_admission_capacity(rg_size) + finally: + self._pending_row_group_admission_size = None + if self._early_shutdown or self._fatal_worker_error is not None: + all_admitted = False + break await self._rg_semaphore.acquire() if self._early_shutdown or self._fatal_worker_error is not None: self._rg_semaphore.release() all_admitted = False break - self._rg_states[rg_id] = _RowGroupState(size=rg_size) + self._rg_states[rg_id] = _RowGroupState(size=rg_size) - if self._buffer_manager is not None: - self._buffer_manager.init_row_group(rg_id, rg_size) + if self._buffer_manager is not None: + self._buffer_manager.init_row_group(rg_id, rg_size) - await self._dispatch_seeds(rg_id, rg_size) - self._emit_scheduler_event( - "row_group_admitted", - diagnostics=self._row_group_admission_diagnostics(reason="admitted") - | {"row_group": rg_id, "row_group_size": rg_size}, - ) - self._emit_scheduler_health_snapshot("row_group_admitted") - self._wake_event.set() + await self._dispatch_seeds(rg_id, rg_size) + 
self._emit_row_group_admission_event( + "row_group_admitted", + reason="admitted", + extra={"row_group": rg_id, "row_group_size": rg_size}, + ) + self._wake_event.set() + finally: + self._pending_row_group_admission_size = None self._all_rgs_admitted = all_admitted self._wake_event.set() @@ -1337,11 +1408,12 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: completed = [ (rg_id, state.size) for rg_id, state in self._rg_states.items() + if state.in_flight_count == 0 if self._tracker.is_row_group_complete(rg_id, state.size, all_columns) ] + checkpointed_row_groups: set[int] = set() for rg_id, rg_size in completed: - dropped_rows = sum(1 for ri in range(rg_size) if self._tracker.is_dropped(rg_id, ri)) - checkpointed = False + dropped_rows = self._tracker.dropped_row_count(rg_id, rg_size) checkpoint_result = "unknown" try: if self._on_before_checkpoint: @@ -1353,8 +1425,6 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: raise DatasetGenerationError( f"Post-batch processor failed for row group {rg_id}: {exc}" ) from exc - # Remove from tracking only after the callback succeeds. - del self._rg_states[rg_id] # If all rows were dropped (e.g. 
seed failure), free instead of finalizing if dropped_rows == rg_size: if self._buffer_manager: @@ -1365,48 +1435,51 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: checkpoint_result = "finalized" else: checkpoint_result = "completed" - checkpointed = True except DatasetGenerationError: raise - except Exception: - logger.error(f"Failed to checkpoint row group {rg_id}.", exc_info=True) - finally: - self._rg_semaphore.release() - self._row_group_admission_event.set() - if checkpointed: - self._emit_scheduler_event( - "row_group_checkpointed", - diagnostics={ - "row_group": rg_id, - "row_group_size": rg_size, - "dropped_rows": dropped_rows, - "surviving_rows": rg_size - dropped_rows, - "result": checkpoint_result, - "active_row_groups": len(self._rg_states), - }, - ) - self._emit_scheduler_health_snapshot("row_group_checkpointed") + except Exception as exc: + raise DatasetGenerationError(f"Failed to checkpoint row group {rg_id}: {exc}") from exc + + del self._rg_states[rg_id] + self._rg_semaphore.release() + self._row_group_admission_event.set() + self._emit_scheduler_event( + "row_group_checkpointed", + diagnostics={ + "row_group": rg_id, + "row_group_size": rg_size, + "dropped_rows": dropped_rows, + "surviving_rows": rg_size - dropped_rows, + "result": checkpoint_result, + "active_row_groups": len(self._rg_states), + }, + ) + self._tracker.release_row_group(rg_id, rg_size, all_columns) + checkpointed_row_groups.add(rg_id) + self._emit_scheduler_health_snapshot("row_group_checkpointed") # Clean up deferred tasks for checkpointed row groups - if completed: - checkpointed = {rg_id for rg_id, _ in completed} - self._deferred = [t for t in self._deferred if t.row_group not in checkpointed] + if checkpointed_row_groups: + self._deferred = [t for t in self._deferred if t.row_group not in checkpointed_row_groups] self._deferred_errors = { - task: exc for task, exc in self._deferred_errors.items() if task.row_group not in checkpointed + task: 
exc + for task, exc in self._deferred_errors.items() + if task.row_group not in checkpointed_row_groups } self._preserved_retryable_counts = Counter( { task: count for task, count in self._preserved_retryable_counts.items() - if task.row_group not in checkpointed + if task.row_group not in checkpointed_row_groups } ) self._preserved_retryable_log_state = { row_group: state for row_group, state in self._preserved_retryable_log_state.items() - if row_group not in checkpointed + if row_group not in checkpointed_row_groups } - for rg_id in checkpointed: + self._dispatched = {task for task in self._dispatched if task.row_group not in checkpointed_row_groups} + for rg_id in checkpointed_row_groups: self._drop_pending_ready_for_row_group(rg_id) def _finalize_after_shutdown(self, all_columns: list[str]) -> None: @@ -1612,9 +1685,6 @@ async def _execute_task_inner_impl(self, task: Task, lease: TaskAdmissionLease, output_cols = self._gen_instance_to_columns.get(id(generator), [task.column]) retryable = False cancelled = False - # When True, skip removing from _dispatched so the task isn't re-dispatched - # from the frontier (it was never completed, so it stays in the frontier). - skipped = False uses_model_stage_resource = "llm_wait" in lease.resources stateful_lock_acquired = False @@ -1623,7 +1693,6 @@ async def _execute_task_inner_impl(self, task: Task, lease: TaskAdmissionLease, # when a vacuously-ready downstream is dispatched via create_task # in the same loop iteration that checkpoints the row group). 
if task.row_group not in self._rg_states: - skipped = True return if task.task_type == "from_scratch" and id(generator) in self._stateful_locks: @@ -1643,6 +1712,9 @@ async def _execute_task_inner_impl(self, task: Task, lease: TaskAdmissionLease, else: raise ValueError(f"Unknown task type: {task.task_type}") + if task.row_group not in self._rg_states: + return + # Mark all output columns complete for col in output_cols: if task.row_index is None: @@ -1746,7 +1818,7 @@ async def _execute_task_inner_impl(self, task: Task, lease: TaskAdmissionLease, self._in_flight.discard(task) if (s := self._rg_states.get(task.row_group)) is not None: s.in_flight_count = max(0, s.in_flight_count - 1) - if not retryable and not skipped: + if not retryable: self._dispatched.discard(task) if not retryable: self._deferred_errors.pop(task, None) @@ -2049,14 +2121,15 @@ def capacity_plan(self) -> AsyncCapacityPlan: buffer_size=CapacityValue(value=self._buffer_size, source="run_config"), row_group_admission=RowGroupAdmission( row_group_concurrency=CapacityValue( - value=self._max_concurrent_row_groups, - source="dataset_builder", + value=self._row_group_admission_hard_cap, + source=self._row_group_admission_source, ), observed_in_flight=len(self._rg_states), mode="adaptive" if self._adaptive_row_group_admission else "fixed", target_in_flight=self._row_group_admission_target, observed_max_target=self._observed_max_row_group_admission_target, - max_admitted_rows=self._adaptive_max_admitted_rows, + max_admitted_rows=self._effective_max_admitted_rows(), + max_admitted_rows_source=self._max_admitted_rows_source, blocked_reasons=dict(self._row_group_admission_blocked_reasons), ), submission_capacity=CapacityValue(value=self._max_in_flight_tasks, source="run_config"), diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index efa68edf5..bab1abdf5 100644 
--- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -25,6 +25,7 @@ ProcessorConfig, ProcessorType, ) +from data_designer.config.run_config import RowGroupAdmissionMode from data_designer.config.utils.type_helpers import StrEnum from data_designer.config.utils.warning_helpers import warn_at_caller from data_designer.config.version import get_library_version @@ -1106,6 +1107,13 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: max_in_flight_tasks = self._resource_provider.run_config.max_in_flight_tasks max_model_task_admission = max_in_flight_tasks + row_group_admission = self._resource_provider.run_config.row_group_admission + row_group_admission_mode = RowGroupAdmissionMode(row_group_admission.mode) + adaptive_row_group_initial_target = 1 + if row_group_admission_mode == RowGroupAdmissionMode.ADAPTIVE: + adaptive_row_group_initial_target = row_group_admission.adaptive_initial_target + if adaptive_row_group_initial_target is None: + raise ValueError("adaptive_initial_target must be normalized for adaptive row-group admission.") scheduler = AsyncTaskScheduler( generators=gen_map, @@ -1113,6 +1121,7 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: tracker=tracker, row_groups=row_groups, buffer_manager=buffer_manager, + max_concurrent_row_groups=row_group_admission.max_concurrent_row_groups, max_in_flight_tasks=max_in_flight_tasks, max_model_task_admission=max_model_task_admission, on_finalize_row_group=on_finalize_row_group, @@ -1135,6 +1144,10 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: progress_bar=self._resource_provider.run_config.progress_bar, request_pressure_provider=self._resource_provider.model_registry.request_admission, request_pressure_advisory=True, + adaptive_row_group_admission=row_group_admission_mode == RowGroupAdmissionMode.ADAPTIVE, + 
adaptive_row_group_initial_target=adaptive_row_group_initial_target, + max_admitted_rows=row_group_admission.max_admitted_rows, + row_group_admission_source="run_config", ) return scheduler, buffer_manager diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py index 855c91642..518a5f479 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py @@ -3,6 +3,8 @@ from __future__ import annotations +import sys +from bisect import bisect_right from collections import defaultdict from dataclasses import dataclass from typing import TYPE_CHECKING @@ -20,6 +22,104 @@ from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph +MAX_EXACT_RELEASED_ROW_INTERVALS = 4_096 +MAX_RELEASED_ROW_GROUP_SUMMARY_RANGES = 4_096 + + +def _range_index( + ranges: list[tuple[int, int]] | tuple[tuple[int, int], ...], + value: int, +) -> int | None: + if not ranges or value < ranges[0][0] or value > ranges[-1][1]: + return None + index = bisect_right(ranges, (value, sys.maxsize)) - 1 + if index < 0: + return None + start, end = ranges[index] + if start <= value <= end: + return index + return None + + +def _ranges_from_sorted_values( + values: list[int], + *, + max_ranges: int | None = None, +) -> tuple[tuple[int, int], ...] 
| None: + if not values: + return () + ranges: list[tuple[int, int]] = [] + start = end = values[0] + for value in values[1:]: + if value == end + 1: + end = value + continue + ranges.append((start, end)) + if max_ranges is not None and len(ranges) > max_ranges: + return None + start = end = value + ranges.append((start, end)) + if max_ranges is not None and len(ranges) > max_ranges: + return None + return tuple(ranges) + + +def _survivor_ranges( + row_group_size: int, + dropped_rows: set[int], + *, + max_ranges: int | None = None, +) -> tuple[tuple[int, int], ...] | None: + ranges: list[tuple[int, int]] = [] + start: int | None = None + end: int | None = None + for row_index in range(row_group_size): + if row_index in dropped_rows: + if start is not None and end is not None: + ranges.append((start, end)) + if max_ranges is not None and len(ranges) > max_ranges: + return None + start = end = None + continue + if start is None: + start = end = row_index + else: + end = row_index + if start is not None and end is not None: + ranges.append((start, end)) + if max_ranges is not None and len(ranges) > max_ranges: + return None + return tuple(ranges) + + +def _dropped_ranges( + row_group_size: int, + dropped_rows: set[int], + *, + max_ranges: int | None = None, +) -> tuple[tuple[int, int], ...] 
| None: + ranges: list[tuple[int, int]] = [] + start: int | None = None + end: int | None = None + for row_index in range(row_group_size): + if row_index not in dropped_rows: + if start is not None and end is not None: + ranges.append((start, end)) + if max_ranges is not None and len(ranges) > max_ranges: + return None + start = end = None + continue + if start is None: + start = end = row_index + else: + end = row_index + if start is not None and end is not None: + ranges.append((start, end)) + if max_ranges is not None and len(ranges) > max_ranges: + return None + return tuple(ranges) + + @dataclass(frozen=True) class FrontierDelta: """Tasks added to or removed from the ready frontier by a tracker mutation.""" @@ -32,6 +132,48 @@ def empty(self) -> bool: return not self.added and not self.removed +@dataclass(frozen=True) +class _ReleasedColumns: + complete_columns: frozenset[str] + batch_columns: frozenset[str] + + +@dataclass(frozen=True) +class _ReleasedRows: + row_group_size: int + intervals: tuple[tuple[int, int], ...] 
+ stores_survivors: bool + exact: bool = True + dropped_count_value: int | None = None + + def is_dropped(self, row_index: int) -> bool: + if not self.exact or not 0 <= row_index < self.row_group_size: + return False + in_intervals = _range_index(self.intervals, row_index) is not None + return not in_intervals if self.stores_survivors else in_intervals + + @property + def dropped_count(self) -> int: + if not self.exact: + if self.dropped_count_value is None: + return 0 + return self.dropped_count_value + interval_count = sum(end - start + 1 for start, end in self.intervals) + if self.stores_survivors: + return self.row_group_size - interval_count + return interval_count + + +@dataclass(frozen=True) +class _ReleasedRowGroup: + row_group_size: int + columns: _ReleasedColumns + rows: _ReleasedRows + + def is_dropped(self, row_index: int) -> bool: + return self.rows.is_dropped(row_index) + + class CompletionTracker: """Tracks which cells (column, row_group, row_index) are done. @@ -50,7 +192,13 @@ def __init__(self) -> None: self._graph: ExecutionGraph | None = None self._row_group_plan: RowGroupPlanLike | None = None - self._batch_complete: dict[int, set[str]] = defaultdict(set) + self._batch_complete: dict[int, dict[str, int]] = defaultdict(dict) + self._released_row_groups: dict[int, _ReleasedRowGroup] = {} + self._released_range_summaries: list[tuple[int, int, _ReleasedRowGroup]] = [] + self._remaining_cell_rows: dict[int, dict[str, int]] = defaultdict(dict) + # Exact post-checkpoint dropped-row identity is diagnostic-only. Keep it + # globally bounded so released summaries cannot grow with total rows. 
+ self._released_exact_row_interval_count = 0 self._frontier: set[Task] = set() @classmethod @@ -62,9 +210,16 @@ def with_graph(cls, graph: ExecutionGraph, row_groups: RowGroupInput) -> Complet return tracker def mark_cell_complete(self, column: str, row_group: int, row_index: int) -> FrontierDelta: - self._validate_row_group(row_group) + row_group_size = self._validate_row_group(row_group) self._validate_strategy(column, GenerationStrategy.CELL_BY_CELL, "mark_cell_complete") - self._completed[row_group][column].add(row_index) + if row_group_size is not None and not 0 <= row_index < row_group_size: + raise ValueError(f"row_index out of range for rg={row_group}: got {row_index}, size {row_group_size}") + self._forget_released_row_group(row_group) + completed = self._completed[row_group][column] + was_complete = row_index in completed + completed.add(row_index) + if not was_complete and row_index not in self._dropped.get(row_group, set()): + self._decrement_remaining_cell_rows(row_group, column) removed: list[Task] = [] added: list[Task] = [] if self._graph is not None: @@ -79,8 +234,11 @@ def mark_row_range_complete(self, column: str, row_group: int, row_group_size: i self._validate_strategy(column, GenerationStrategy.FULL_COLUMN, "mark_row_range_complete") if expected is not None and row_group_size != expected: raise ValueError(f"Row-group size mismatch for rg={row_group}: got {row_group_size}, expected {expected}") - self._completed[row_group][column] = set(range(row_group_size)) - self._batch_complete[row_group].add(column) + self._forget_released_row_group(row_group) + if row_group in self._completed: + self._completed[row_group].pop(column, None) + self._remaining_cell_rows.get(row_group, {}).pop(column, None) + self._batch_complete[row_group][column] = row_group_size removed: list[Task] = [] added: list[Task] = [] if self._graph is not None: @@ -91,6 +249,25 @@ def mark_row_range_complete(self, column: str, row_group: int, row_group_size: i return 
self._record_delta(added=added, removed=removed) def is_complete(self, ref: SliceRef) -> bool: + if released := self._released_row_group(ref.row_group): + if ref.column not in released.columns.complete_columns: + return False + if ref.row_index is None: + return ref.column in released.columns.batch_columns + if not 0 <= ref.row_index < released.row_group_size: + return False + if ref.column in released.columns.batch_columns: + return True + if not released.rows.exact: + return True + return not released.is_dropped(ref.row_index) + batch_complete = self._batch_complete.get(ref.row_group, {}) + if ref.column in batch_complete: + if ref.row_index is None: + return True + return 0 <= ref.row_index < batch_complete[ref.column] + if ref.row_index is None: + return False return ref.row_index in self._completed.get(ref.row_group, {}).get(ref.column, set()) def is_all_complete(self, cells: list[SliceRef]) -> bool: @@ -99,28 +276,29 @@ def is_all_complete(self, cells: list[SliceRef]) -> bool: A ``row_index`` of ``None`` means the entire batch for that column must have been completed via ``mark_row_range_complete``. 
""" - for ref in cells: - if ref.row_index is None: - if ref.column not in self._batch_complete.get(ref.row_group, set()): - return False - elif not self.is_complete(ref): - return False - return True + return all(self.is_complete(ref) for ref in cells) def is_column_complete_for_rg(self, column: str, row_group_index: int) -> bool: """Check if *column* has been fully completed for *row_group_index*.""" - if column in self._batch_complete.get(row_group_index, set()): + if released := self._released_row_group(row_group_index): + return column in released.columns.complete_columns + if column in self._batch_complete.get(row_group_index, {}): return True rg_size = self._row_group_size_or_default(row_group_index, default=0) if rg_size == 0: return False - completed = self._completed.get(row_group_index, {}).get(column, set()) - dropped = self._dropped.get(row_group_index, set()) - return all(ri in completed or ri in dropped for ri in range(rg_size)) + return self._is_cell_column_complete(row_group_index, column, rg_size) def drop_row(self, row_group: int, row_index: int) -> FrontierDelta: - self._validate_row_group(row_group) - self._dropped[row_group].add(row_index) + row_group_size = self._validate_row_group(row_group) + if row_group_size is not None and not 0 <= row_index < row_group_size: + raise ValueError(f"row_index out of range for rg={row_group}: got {row_index}, size {row_group_size}") + self._forget_released_row_group(row_group) + dropped = self._dropped[row_group] + was_dropped = row_index in dropped + dropped.add(row_index) + if not was_dropped: + self._decrement_remaining_for_dropped_row(row_group, row_index) removed: list[Task] = [] added: list[Task] = [] if self._graph is not None: @@ -134,8 +312,15 @@ def drop_row(self, row_group: int, row_index: int) -> FrontierDelta: return self._record_delta(added=added, removed=removed) def is_dropped(self, row_group: int, row_index: int) -> bool: + if released := self._released_row_group(row_group): + return 
released.is_dropped(row_index) return row_index in self._dropped.get(row_group, set()) + def dropped_row_count(self, row_group: int, row_group_size: int) -> int: + if released := self._released_row_group(row_group): + return released.rows.dropped_count + return sum(1 for row_index in self._dropped.get(row_group, set()) if 0 <= row_index < row_group_size) + def is_row_group_complete( self, row_group: int, @@ -143,14 +328,16 @@ def is_row_group_complete( all_columns: list[str], ) -> bool: """All non-dropped rows have all columns done.""" - dropped = self._dropped.get(row_group, set()) - completed = self._completed.get(row_group, {}) - for ri in range(row_group_size): - if ri in dropped: + if released := self._released_row_group(row_group): + return row_group_size == released.row_group_size and set(all_columns).issubset( + released.columns.complete_columns + ) + batch_complete = self._batch_complete.get(row_group, {}) + for col in all_columns: + if col in batch_complete and batch_complete[col] == row_group_size: continue - for col in all_columns: - if ri not in completed.get(col, set()): - return False + if not self._is_cell_column_complete(row_group, col, row_group_size): + return False return True def ready_frontier(self) -> tuple[Task, ...]: @@ -165,6 +352,30 @@ def mark_enqueued(self, task_ids: set[str] | list[str] | tuple[str, ...]) -> Non def mark_complete(self, task: Task) -> None: """Compatibility hook for scheduler terminal accounting.""" + def release_row_group(self, row_group: int, row_group_size: int, all_columns: list[str]) -> None: + """Release completion state for a row group after durable checkpointing.""" + self._forget_released_row_group(row_group) + columns = _ReleasedColumns( + complete_columns=frozenset(all_columns), + batch_columns=frozenset(self._batch_complete.get(row_group, {})), + ) + dropped_rows = self._dropped.get(row_group, set()) + released = _ReleasedRowGroup( + row_group_size=row_group_size, + columns=columns, + 
rows=self._released_rows(row_group_size, dropped_rows), + ) + if self._row_group_plan is not None: + self._released_row_groups.pop(row_group, None) + self._add_released_summary_range(row_group, released) + else: + self._released_row_groups[row_group] = released + self._completed.pop(row_group, None) + self._dropped.pop(row_group, None) + self._batch_complete.pop(row_group, None) + self._remaining_cell_rows.pop(row_group, None) + self._frontier = {task for task in self._frontier if task.row_group != row_group} + def add_ready_tasks(self, tasks: list[Task] | tuple[Task, ...]) -> FrontierDelta: """Add ready tasks to the frontier idempotently.""" added: list[Task] = [] @@ -213,6 +424,7 @@ def add_root_tasks( expected = self._validate_row_group(row_group) if expected is not None and expected != row_group_size: raise ValueError(f"Row-group size mismatch for rg={row_group}: got {row_group_size}, expected {expected}") + self._forget_released_row_group(row_group) root_columns = columns or tuple(self._graph.get_root_columns()) added: list[Task] = [] for col in root_columns: @@ -228,6 +440,131 @@ def add_root_tasks( added.append(task) return self._record_delta(added=added, removed=[]) + def _released_row_group(self, row_group: int) -> _ReleasedRowGroup | None: + if released := self._released_row_groups.get(row_group): + return released + index = bisect_right(self._released_range_summaries, (row_group, sys.maxsize)) - 1 + if index >= 0: + start, end, released = self._released_range_summaries[index] + if start <= row_group <= end: + return released + return None + + def _released_rows(self, row_group_size: int, dropped_rows: set[int]) -> _ReleasedRows: + valid_dropped_count = sum(1 for row_index in dropped_rows if 0 <= row_index < row_group_size) + survivor_count = row_group_size - valid_dropped_count + exact_budget = MAX_EXACT_RELEASED_ROW_INTERVALS - self._released_exact_row_interval_count + if exact_budget <= 0: + return self._released_rows_aggregate(row_group_size, 
valid_dropped_count) + + if survivor_count < valid_dropped_count: + intervals = _survivor_ranges(row_group_size, dropped_rows, max_ranges=exact_budget) + if intervals is None: + return self._released_rows_aggregate(row_group_size, valid_dropped_count) + self._released_exact_row_interval_count += len(intervals) + return _ReleasedRows( + row_group_size=row_group_size, + intervals=intervals, + stores_survivors=True, + ) + if valid_dropped_count <= exact_budget: + valid_dropped = sorted(row_index for row_index in dropped_rows if 0 <= row_index < row_group_size) + intervals = _ranges_from_sorted_values(valid_dropped, max_ranges=exact_budget) + else: + intervals = _dropped_ranges(row_group_size, dropped_rows, max_ranges=exact_budget) + if intervals is None: + return self._released_rows_aggregate(row_group_size, valid_dropped_count) + self._released_exact_row_interval_count += len(intervals) + return _ReleasedRows( + row_group_size=row_group_size, + intervals=intervals, + stores_survivors=False, + ) + + def _released_rows_aggregate(self, row_group_size: int, dropped_count: int) -> _ReleasedRows: + return _ReleasedRows( + row_group_size=row_group_size, + intervals=(), + stores_survivors=False, + exact=False, + dropped_count_value=dropped_count, + ) + + def _add_released_summary_range(self, row_group: int, released: _ReleasedRowGroup) -> None: + index = ( + bisect_right( + self._released_range_summaries, + (row_group, sys.maxsize), + ) + - 1 + ) + if index >= 0: + start, end, existing = self._released_range_summaries[index] + if start <= row_group <= end: + return + if existing == released and end + 1 == row_group: + right_index = index + 1 + if ( + right_index < len(self._released_range_summaries) + and self._released_range_summaries[right_index][0] == row_group + 1 + and self._released_range_summaries[right_index][2] == released + ): + self._released_range_summaries[index] = ( + start, + self._released_range_summaries[right_index][1], + released, + ) + del 
self._released_range_summaries[right_index] + else: + self._released_range_summaries[index] = (start, row_group, released) + return + right_index = index + 1 + if ( + right_index < len(self._released_range_summaries) + and self._released_range_summaries[right_index][0] == row_group + 1 + and self._released_range_summaries[right_index][2] == released + ): + _right_start, right_end, _right_released = self._released_range_summaries[right_index] + self._released_range_summaries[right_index] = (row_group, right_end, released) + else: + self._released_range_summaries.insert(right_index, (row_group, row_group, released)) + self._trim_released_summary_ranges() + + def _trim_released_summary_ranges(self) -> None: + excess = len(self._released_range_summaries) - MAX_RELEASED_ROW_GROUP_SUMMARY_RANGES + if excess > 0: + del self._released_range_summaries[:excess] + + def _forget_released_row_group(self, row_group: int) -> None: + self._released_row_groups.pop(row_group, None) + self._remove_released_summary_range(row_group) + + def _remove_released_summary_range(self, row_group: int) -> None: + index = ( + bisect_right( + self._released_range_summaries, + (row_group, sys.maxsize), + ) + - 1 + ) + if index < 0: + return + start, end, released = self._released_range_summaries[index] + if not start <= row_group <= end: + return + if start == end: + del self._released_range_summaries[index] + elif row_group == start: + self._released_range_summaries[index] = (start + 1, end, released) + elif row_group == end: + self._released_range_summaries[index] = (start, end - 1, released) + else: + self._released_range_summaries[index : index + 1] = [ + (start, row_group - 1, released), + (row_group + 1, end, released), + ] + self._trim_released_summary_ranges() + def _record_delta(self, *, added: list[Task], removed: list[Task]) -> FrontierDelta: return FrontierDelta(added=tuple(added), removed=tuple(removed)) @@ -250,7 +587,7 @@ def _enqueue_downstream(self, column: str, row_group: int, 
row_index: int | None added: list[Task] = [] rg_completed = self._completed.get(row_group, {}) rg_dropped = self._dropped.get(row_group, set()) - rg_batch_complete = self._batch_complete.get(row_group, set()) + rg_batch_complete = self._batch_complete.get(row_group, {}) rg_size = self._row_group_size(row_group) for down in sorted(self._graph.get_downstream_columns(column)): @@ -286,9 +623,7 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None added.append(task) else: # FULL_COLUMN downstream: ready when all cell upstreams are fully complete - if down not in rg_batch_complete and self._are_cell_ups_complete( - cell_ups, rg_completed, rg_size, rg_dropped - ): + if down not in rg_batch_complete and self._are_cell_ups_complete(row_group, cell_ups, rg_size): task = Task(column=down, row_group=row_group, row_index=None, task_type="batch") if self._add_frontier_task(task): added.append(task) @@ -299,9 +634,7 @@ def _reevaluate_batch_tasks(self, row_group: int) -> list[Task]: if self._graph is None: raise RuntimeError("This method requires a graph to be set.") added: list[Task] = [] - rg_completed = self._completed.get(row_group, {}) - rg_dropped = self._dropped.get(row_group, set()) - rg_batch_complete = self._batch_complete.get(row_group, set()) + rg_batch_complete = self._batch_complete.get(row_group, {}) rg_size = self._row_group_size(row_group) for col in self._graph.get_topological_order(): @@ -312,7 +645,7 @@ def _reevaluate_batch_tasks(self, row_group: int) -> list[Task]: batch_ups, cell_ups = self._graph.split_upstream_by_strategy(col) if any(up not in rg_batch_complete for up in batch_ups): continue - if self._are_cell_ups_complete(cell_ups, rg_completed, rg_size, rg_dropped): + if self._are_cell_ups_complete(row_group, cell_ups, rg_size): task = Task(column=col, row_group=row_group, row_index=None, task_type="batch") if self._add_frontier_task(task): added.append(task) @@ -320,19 +653,75 @@ def _reevaluate_batch_tasks(self, 
row_group: int) -> list[Task]: def _are_cell_ups_complete( self, + row_group: int, cell_ups: list[str], - rg_completed: dict[str, set[int]], rg_size: int, - rg_dropped: set[int], ) -> bool: """Check all non-dropped rows are complete for each cell-by-cell upstream column.""" for up in cell_ups: - up_completed = rg_completed.get(up, set()) - for ri in range(rg_size): - if ri not in rg_dropped and ri not in up_completed: - return False + if not self._is_cell_column_complete(row_group, up, rg_size): + return False return True + def _is_cell_column_complete(self, row_group: int, column: str, row_group_size: int) -> bool: + completed = self._completed.get(row_group, {}).get(column, set()) + dropped = self._dropped.get(row_group, set()) + if self._row_group_plan is not None and self._row_group_plan.has_row_group(row_group): + return self._remaining_cell_row_count(row_group, column, row_group_size, completed, dropped) == 0 + return self._row_indices_complete_with_dropped(row_group_size, completed, dropped) + + def _remaining_cell_row_count( + self, + row_group: int, + column: str, + row_group_size: int, + completed: set[int], + dropped: set[int], + ) -> int: + remaining_by_column = self._remaining_cell_rows[row_group] + if column in remaining_by_column: + return remaining_by_column[column] + if dropped: + completed_non_dropped = len(completed - dropped) + remaining = row_group_size - len(dropped) - completed_non_dropped + else: + remaining = row_group_size - len(completed) + remaining_by_column[column] = max(0, remaining) + return remaining_by_column[column] + + def _decrement_remaining_cell_rows(self, row_group: int, column: str) -> None: + remaining_by_column = self._remaining_cell_rows.get(row_group) + if remaining_by_column is None or column not in remaining_by_column: + return + remaining_by_column[column] = max(0, remaining_by_column[column] - 1) + + def _decrement_remaining_for_dropped_row(self, row_group: int, row_index: int) -> None: + remaining_by_column = 
self._remaining_cell_rows.get(row_group) + if not remaining_by_column: + return + completed_by_column = self._completed.get(row_group, {}) + for column in list(remaining_by_column): + if row_index not in completed_by_column.get(column, set()): + remaining_by_column[column] = max(0, remaining_by_column[column] - 1) + + def _row_indices_complete_with_dropped( + self, + row_group_size: int, + completed: set[int], + dropped: set[int], + ) -> bool: + if not dropped: + if len(completed) < row_group_size: + return False + if row_group_size == 0: + return True + if len(completed) == row_group_size: + return min(completed) == 0 and max(completed) == row_group_size - 1 + return all(ri in completed for ri in range(row_group_size)) + if len(completed) + len(dropped) < row_group_size: + return False + return all(ri in completed or ri in dropped for ri in range(row_group_size)) + def _validate_strategy(self, column: str, expected: GenerationStrategy, method: str) -> None: """Validate that *column* matches the expected strategy in graph-enabled mode.""" if self._graph is None: diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py index e647d4ac6..477f6ebc1 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py @@ -3,6 +3,7 @@ from __future__ import annotations +import builtins from dataclasses import dataclass import pytest @@ -14,7 +15,10 @@ SamplerColumnConfig, ) from data_designer.config.sampler_params import SamplerType -from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker +from data_designer.engine.dataset_builders.scheduling.completion import ( + MAX_RELEASED_ROW_GROUP_SUMMARY_RANGES, + CompletionTracker, +) from 
data_designer.engine.dataset_builders.scheduling.resources import stable_task_id from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef, Task from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph @@ -86,6 +90,16 @@ def test_mark_cell_complete_raises_on_unknown_row_group(ready_ctx: ReadyTasksFix ready_ctx.tracker.mark_cell_complete("question", row_group=999, row_index=0) +def test_mark_cell_complete_raises_on_out_of_range_row_index(ready_ctx: ReadyTasksFixture) -> None: + with pytest.raises(ValueError, match="row_index out of range"): + ready_ctx.tracker.mark_cell_complete("question", row_group=0, row_index=3) + + +def test_drop_row_raises_on_out_of_range_row_index(ready_ctx: ReadyTasksFixture) -> None: + with pytest.raises(ValueError, match="row_index out of range"): + ready_ctx.tracker.drop_row(row_group=0, row_index=3) + + # -- is_all_complete ----------------------------------------------------------- @@ -133,6 +147,8 @@ def test_drop_row() -> None: assert tracker.is_dropped(0, 2) assert not tracker.is_dropped(0, 0) assert not tracker.is_dropped(1, 2) + assert tracker.dropped_row_count(0, 3) == 1 + assert tracker.dropped_row_count(0, 2) == 0 # -- is_row_group_complete -------------------------------------------------- @@ -153,6 +169,24 @@ def test_row_group_incomplete() -> None: assert not tracker.is_row_group_complete(0, 3, ["col_a", "col_b"]) +def test_row_group_incomplete_with_out_of_range_cell_completion() -> None: + tracker = CompletionTracker() + tracker.mark_cell_complete("col_a", 0, 0) + tracker.mark_cell_complete("col_a", 0, 1) + tracker.mark_cell_complete("col_a", 0, 99) + + assert not tracker.is_column_complete_for_rg("col_a", 0) + assert not tracker.is_row_group_complete(0, 3, ["col_a"]) + + +def test_row_group_incomplete_when_batch_marker_size_mismatches() -> None: + tracker = CompletionTracker() + tracker.mark_row_range_complete("col_a", 0, 2) + + assert tracker.is_row_group_complete(0, 2, 
["col_a"]) + assert not tracker.is_row_group_complete(0, 3, ["col_a"]) + + def test_row_group_complete_with_dropped_rows() -> None: tracker = CompletionTracker() tracker.mark_cell_complete("col_a", 0, 0) @@ -174,6 +208,272 @@ def test_row_group_not_complete_missing_non_dropped() -> None: assert not tracker.is_row_group_complete(0, 3, ["col_a", "col_b"]) +def test_release_row_group_clears_heavy_state_and_preserves_summary(ready_ctx: ReadyTasksFixture) -> None: + tracker = ready_ctx.tracker + tracker.mark_row_range_complete("topic", 0, 3) + tracker.mark_cell_complete("question", 0, 0) + tracker.mark_cell_complete("question", 0, 2) + tracker.drop_row(0, 1) + tracker.mark_row_range_complete("score", 0, 3) + + tracker.release_row_group(0, 3, ["topic", "question", "score"]) + + assert tracker._completed == {} + assert tracker._dropped == {} + assert tracker._batch_complete == {} + assert tracker.ready_frontier() == () + assert tracker.is_row_group_complete(0, 3, ["topic", "question", "score"]) + assert tracker.is_dropped(0, 1) + assert tracker.dropped_row_count(0, 3) == 1 + assert tracker.is_complete(SliceRef("question", 0, 0)) + assert not tracker.is_complete(SliceRef("question", 0, 1)) + assert tracker.is_complete(SliceRef("score", 0, 1)) + assert tracker.is_complete(SliceRef("score", 0, None)) + assert not tracker.is_complete(SliceRef("question", 0, None)) + + +def test_release_row_group_merges_clean_summaries_into_one_range() -> None: + graph = _build_simple_graph() + tracker = CompletionTracker.with_graph(graph, [(rg_id, 3) for rg_id in range(4)]) + + for row_group in range(4): + tracker.mark_row_range_complete("topic", row_group, 3) + for row_index in range(3): + tracker.mark_cell_complete("question", row_group, row_index) + tracker.mark_row_range_complete("score", row_group, 3) + tracker.release_row_group(row_group, 3, ["topic", "question", "score"]) + + assert tracker._completed == {} + assert tracker._dropped == {} + assert tracker._batch_complete == {} + assert 
tracker._released_row_groups == {} + assert len(tracker._released_range_summaries) == 1 + assert tracker.dropped_row_count(2, 3) == 0 + assert tracker.is_row_group_complete(2, 3, ["topic", "question", "score"]) + assert tracker.is_complete(SliceRef("question", 2, 1)) + assert tracker.is_complete(SliceRef("score", 2, None)) + + delta = tracker.add_root_tasks(1, 3) + + assert [task.row_group for task in delta.added] == [1] + assert [(start, end) for start, end, _released in tracker._released_range_summaries] == [(0, 0), (2, 3)] + assert not tracker.is_row_group_complete(1, 3, ["topic", "question", "score"]) + assert tracker.is_row_group_complete(2, 3, ["topic", "question", "score"]) + + +def test_release_row_group_preserves_fragmented_clean_ranges_compactly() -> None: + graph = _build_simple_graph() + tracker = CompletionTracker.with_graph(graph, [(rg_id, 3) for rg_id in range(8)]) + + for row_group in (0, 2, 4, 6): + tracker.mark_row_range_complete("topic", row_group, 3) + tracker.mark_row_range_complete("score", row_group, 3) + tracker.release_row_group(row_group, 3, ["topic", "score"]) + + assert [(start, end) for start, end, _released in tracker._released_range_summaries] == [ + (0, 0), + (2, 2), + (4, 4), + (6, 6), + ] + + tracker.mark_row_range_complete("topic", 7, 3) + + assert [(start, end) for start, end, _released in tracker._released_range_summaries] == [ + (0, 0), + (2, 2), + (4, 4), + (6, 6), + ] + + tracker.mark_row_range_complete("topic", 1, 3) + tracker.mark_row_range_complete("score", 1, 3) + tracker.release_row_group(1, 3, ["topic", "score"]) + + assert [(start, end) for start, end, _released in tracker._released_range_summaries] == [ + (0, 2), + (4, 4), + (6, 6), + ] + + +def test_release_row_group_stores_dropped_rows_as_survivor_ranges_when_smaller() -> None: + graph = _build_simple_graph() + tracker = CompletionTracker.with_graph(graph, [(0, 10)]) + tracker.mark_row_range_complete("topic", 0, 10) + tracker.mark_row_range_complete("score", 0, 10) 
+ for row_index in range(9): + tracker.drop_row(0, row_index) + + tracker.release_row_group(0, 10, ["topic", "score"]) + + released = tracker._released_row_group(0) + assert released is not None + assert released.rows.stores_survivors + assert released.rows.intervals == ((9, 9),) + assert tracker._released_row_groups == {} + assert len(tracker._released_range_summaries) == 1 + assert tracker.dropped_row_count(0, 10) == 9 + assert tracker.is_dropped(0, 0) + assert not tracker.is_dropped(0, 9) + + +def test_release_row_group_merges_dropped_summaries_into_one_range() -> None: + graph = _build_simple_graph() + tracker = CompletionTracker.with_graph(graph, [(row_group, 3) for row_group in range(4)]) + + for row_group in range(4): + tracker.mark_row_range_complete("topic", row_group, 3) + tracker.mark_row_range_complete("score", row_group, 3) + tracker.drop_row(row_group, 1) + tracker.release_row_group(row_group, 3, ["topic", "score"]) + + assert tracker._released_row_groups == {} + assert [(start, end) for start, end, _released in tracker._released_range_summaries] == [(0, 3)] + assert tracker.dropped_row_count(2, 3) == 1 + assert tracker.is_dropped(2, 1) + assert not tracker.is_dropped(2, 0) + assert tracker.is_row_group_complete(2, 3, ["topic", "score"]) + + delta = tracker.add_root_tasks(1, 3) + + assert [task.row_group for task in delta.added] == [1] + assert [(start, end) for start, end, _released in tracker._released_range_summaries] == [(0, 0), (2, 3)] + assert not tracker.is_row_group_complete(1, 3, ["topic", "score"]) + + +def test_release_row_group_bounds_fragmented_dropped_row_summary() -> None: + graph = _build_simple_graph() + row_group_size = 10_000 + tracker = CompletionTracker.with_graph(graph, [(0, row_group_size)]) + + tracker.mark_row_range_complete("topic", 0, row_group_size) + tracker.mark_row_range_complete("score", 0, row_group_size) + for row_index in range(0, row_group_size, 2): + tracker.drop_row(0, row_index) + + tracker.release_row_group(0, 
row_group_size, ["topic", "score"]) + + released = tracker._released_row_group(0) + assert released is not None + assert not released.rows.exact + assert released.rows.intervals == () + assert tracker._released_row_groups == {} + assert len(tracker._released_range_summaries) == 1 + assert tracker.dropped_row_count(0, row_group_size) == row_group_size // 2 + assert tracker.is_row_group_complete(0, row_group_size, ["topic", "score"]) + assert tracker.is_complete(SliceRef("topic", 0, 0)) + + +def test_release_row_group_aggregates_first_fragmented_drop_without_sort(monkeypatch: pytest.MonkeyPatch) -> None: + graph = _build_simple_graph() + row_group_size = 1_000_000 + tracker = CompletionTracker.with_graph(graph, [(0, row_group_size)]) + + tracker.mark_row_range_complete("topic", 0, row_group_size) + tracker.mark_row_range_complete("score", 0, row_group_size) + tracker._dropped[0] = set(range(0, row_group_size, 2)) + + def fail_sorted(*_args: object, **_kwargs: object) -> list[object]: + raise AssertionError("fragmented released rows should aggregate without sorting every dropped row") + + monkeypatch.setattr(builtins, "sorted", fail_sorted) + + tracker.release_row_group(0, row_group_size, ["topic", "score"]) + + released = tracker._released_row_group(0) + assert released is not None + assert not released.rows.exact + assert released.rows.intervals == () + assert tracker._released_exact_row_interval_count == 0 + assert tracker.dropped_row_count(0, row_group_size) == row_group_size // 2 + + +def test_release_row_group_keeps_large_contiguous_dropped_split_exact() -> None: + graph = _build_simple_graph() + row_group_size = 10_000 + tracker = CompletionTracker.with_graph(graph, [(0, row_group_size)]) + + tracker.mark_row_range_complete("topic", 0, row_group_size) + tracker.mark_row_range_complete("score", 0, row_group_size) + for row_index in range(row_group_size // 2): + tracker.drop_row(0, row_index) + + tracker.release_row_group(0, row_group_size, ["topic", "score"]) + 
+ released = tracker._released_row_group(0) + assert released is not None + assert released.rows.exact + assert released.rows.intervals == ((0, 4_999),) + assert tracker.dropped_row_count(0, row_group_size) == row_group_size // 2 + assert tracker.is_dropped(0, 0) + assert tracker.is_dropped(0, 4_999) + assert not tracker.is_dropped(0, 5_000) + + +def test_release_row_group_bounds_exact_dropped_row_summaries_across_run() -> None: + graph = _build_simple_graph() + row_group_size = 8_192 + tracker = CompletionTracker.with_graph(graph, [(0, row_group_size), (1, row_group_size)]) + + for row_group in (0, 1): + tracker.mark_row_range_complete("topic", row_group, row_group_size) + tracker.mark_row_range_complete("score", row_group, row_group_size) + for row_index in range(0, row_group_size, 2): + tracker.drop_row(row_group, row_index) + for row_index in range(1, row_group_size, 2): + tracker.mark_cell_complete("question", row_group, row_index) + tracker.release_row_group(row_group, row_group_size, ["topic", "question", "score"]) + + first = tracker._released_row_group(0) + second = tracker._released_row_group(1) + assert first is not None + assert second is not None + assert first.rows.exact + assert len(first.rows.intervals) == 4_096 + assert not second.rows.exact + assert second.rows.intervals == () + assert tracker.dropped_row_count(0, row_group_size) == row_group_size // 2 + assert tracker.dropped_row_count(1, row_group_size) == row_group_size // 2 + assert tracker.is_row_group_complete(1, row_group_size, ["topic", "question", "score"]) + assert tracker.is_complete(SliceRef("question", 1, 1)) + + +def test_release_row_group_bounds_alternating_summary_ranges() -> None: + graph = _build_simple_graph() + row_group_count = MAX_RELEASED_ROW_GROUP_SUMMARY_RANGES + 16 + tracker = CompletionTracker.with_graph(graph, [(row_group, 8) for row_group in range(row_group_count)]) + + for row_group in range(row_group_count): + tracker.mark_row_range_complete("topic", row_group, 8) + 
tracker.mark_row_range_complete("score", row_group, 8) + tracker.drop_row(row_group, row_group % 8) + tracker.release_row_group(row_group, 8, ["topic", "score"]) + + assert len(tracker._released_range_summaries) == MAX_RELEASED_ROW_GROUP_SUMMARY_RANGES + assert tracker._released_row_group(0) is None + assert tracker.is_row_group_complete(row_group_count - 1, 8, ["topic", "score"]) + assert tracker.dropped_row_count(row_group_count - 1, 8) == 1 + + +def test_reopened_released_row_groups_keep_summary_ranges_bounded() -> None: + graph = _build_simple_graph() + row_group_count = MAX_RELEASED_ROW_GROUP_SUMMARY_RANGES * 2 + 1 + tracker = CompletionTracker.with_graph(graph, [(row_group, 3) for row_group in range(row_group_count)]) + + for row_group in range(row_group_count): + tracker.mark_row_range_complete("topic", row_group, 3) + tracker.mark_row_range_complete("score", row_group, 3) + tracker.release_row_group(row_group, 3, ["topic", "score"]) + + assert len(tracker._released_range_summaries) == 1 + + for row_group in range(1, row_group_count, 2): + tracker.add_root_tasks(row_group, 3) + + assert len(tracker._released_range_summaries) <= MAX_RELEASED_ROW_GROUP_SUMMARY_RANGES + + # -- get_ready_tasks -------------------------------------------------------- @@ -301,6 +601,19 @@ def test_drop_row_unblocks_full_column_downstream(ready_ctx: ReadyTasksFixture) assert score_tasks[0] in delta.added +def test_cached_remaining_cell_count_updates_after_drop(ready_ctx: ReadyTasksFixture) -> None: + ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) + assert not ready_ctx.tracker.is_row_group_complete(0, 3, ["topic", "question", "score"]) + assert ready_ctx.tracker._remaining_cell_rows[0]["question"] == 3 + + ready_ctx.tracker.mark_cell_complete("question", 0, 0) + ready_ctx.tracker.mark_cell_complete("question", 0, 1) + delta = ready_ctx.tracker.drop_row(0, 2) + + assert ready_ctx.tracker._remaining_cell_rows[0]["question"] == 0 + assert Task(column="score", row_group=0, 
row_index=None, task_type="batch") in delta.added + + def test_get_ready_tasks_full_column_waits_for_all_cells(ready_ctx: ReadyTasksFixture) -> None: ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) ready_ctx.tracker.mark_cell_complete("question", 0, 0) @@ -328,6 +641,20 @@ def test_get_ready_tasks_full_column_ready_when_all_cells_done(ready_ctx: ReadyT assert delta.added == (score_tasks[0],) +def test_cached_remaining_cell_count_updates_after_completion(ready_ctx: ReadyTasksFixture) -> None: + ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) + assert not ready_ctx.tracker.is_column_complete_for_rg("question", 0) + assert ready_ctx.tracker._remaining_cell_rows[0]["question"] == 3 + + delta = None + for ri in range(3): + delta = ready_ctx.tracker.mark_cell_complete("question", 0, ri) + + assert ready_ctx.tracker._remaining_cell_rows[0]["question"] == 0 + assert delta is not None + assert delta.added == (Task(column="score", row_group=0, row_index=None, task_type="batch"),) + + def test_get_ready_tasks_multiple_row_groups() -> None: graph = _build_simple_graph() tracker = CompletionTracker.with_graph(graph, [(0, 3), (1, 2)]) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py index 66ef898e6..d8834de40 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py @@ -19,6 +19,7 @@ LLMTextColumnConfig, SamplerColumnConfig, ) +from data_designer.config.run_config import RowGroupAdmissionConfig, RowGroupAdmissionMode, RunConfig from data_designer.config.sampler_params import SamplerType from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, @@ -193,7 +194,14 @@ def finalize_row_group(rg_id: int) -> None: assert 
tracker.is_row_group_complete(1, 2, all_cols) -def test_prepare_async_run_enables_request_pressure_advisory(monkeypatch: pytest.MonkeyPatch) -> None: +def _capture_prepare_async_run_kwargs( + monkeypatch: pytest.MonkeyPatch, + run_config: RunConfig, + *, + num_records: int = 1, + buffer_size: int = 1, + request_admission: object | None = None, +) -> dict[str, object]: captured_kwargs: dict[str, object] = {} class _SpyScheduler: @@ -201,7 +209,6 @@ def __init__(self, **kwargs: object) -> None: captured_kwargs.update(kwargs) monkeypatch.setattr(builder_mod, "AsyncTaskScheduler", _SpyScheduler) - request_admission = object() model_registry = MagicMock() model_registry.get_aggregate_max_parallel_requests.side_effect = AssertionError( "model task admission should follow max_in_flight_tasks directly" @@ -209,7 +216,7 @@ def __init__(self, **kwargs: object) -> None: model_registry.request_admission = request_admission provider = SimpleNamespace( model_registry=model_registry, - run_config=SimpleNamespace(max_in_flight_tasks=64, progress_interval=5.0, progress_bar=False), + run_config=run_config, ) processor_runner = MagicMock() processor_runner.has_processors_for.return_value = False @@ -222,7 +229,18 @@ def __init__(self, **kwargs: object) -> None: ) generator = MockSeed(config=_expr_config("seed"), resource_provider=provider) - DatasetBuilder._prepare_async_run(builder, [generator], num_records=1, buffer_size=1) + DatasetBuilder._prepare_async_run(builder, [generator], num_records=num_records, buffer_size=buffer_size) + + return captured_kwargs + + +def test_prepare_async_run_enables_request_pressure_advisory(monkeypatch: pytest.MonkeyPatch) -> None: + request_admission = object() + captured_kwargs = _capture_prepare_async_run_kwargs( + monkeypatch, + RunConfig(max_in_flight_tasks=64), + request_admission=request_admission, + ) assert captured_kwargs["request_pressure_provider"] is request_admission assert captured_kwargs["request_pressure_advisory"] is True @@ -230,34 
+248,120 @@ def __init__(self, **kwargs: object) -> None: assert captured_kwargs["max_model_task_admission"] == 64 -def test_prepare_async_run_uses_compact_plan_for_large_fresh_runs(monkeypatch: pytest.MonkeyPatch) -> None: - captured_kwargs: dict[str, object] = {} +@pytest.mark.parametrize( + "run_config,expected_cap,expected_adaptive,expected_initial_target,expected_row_budget", + [ + pytest.param( + RunConfig(max_in_flight_tasks=64), + 3, + False, + 1, + None, + id="default_fixed", + ), + pytest.param( + RunConfig( + max_in_flight_tasks=64, + row_group_admission=RowGroupAdmissionConfig(max_concurrent_row_groups=7), + ), + 7, + False, + 1, + None, + id="custom_fixed", + ), + pytest.param( + RunConfig( + max_in_flight_tasks=64, + row_group_admission=RowGroupAdmissionConfig( + max_concurrent_row_groups=7, + max_admitted_rows=4096, + ), + ), + 7, + False, + 1, + 4096, + id="custom_fixed_row_budget", + ), + pytest.param( + RunConfig( + max_in_flight_tasks=64, + row_group_admission=RowGroupAdmissionConfig( + mode=RowGroupAdmissionMode.ADAPTIVE, + max_concurrent_row_groups=9, + adaptive_initial_target=3, + max_admitted_rows=2048, + ), + ), + 9, + True, + 3, + 2048, + id="adaptive", + ), + pytest.param( + RunConfig( + max_in_flight_tasks=64, + row_group_admission=RowGroupAdmissionConfig( + mode=RowGroupAdmissionMode.ADAPTIVE, + max_concurrent_row_groups=5, + ), + ), + 5, + True, + 1, + None, + id="adaptive_default_initial_target", + ), + ], +) +def test_prepare_async_run_threads_row_group_admission_config( + monkeypatch: pytest.MonkeyPatch, + run_config: RunConfig, + expected_cap: int, + expected_adaptive: bool, + expected_initial_target: int, + expected_row_budget: int | None, +) -> None: + captured_kwargs = _capture_prepare_async_run_kwargs(monkeypatch, run_config) + + assert captured_kwargs["max_concurrent_row_groups"] == expected_cap + assert captured_kwargs["adaptive_row_group_admission"] is expected_adaptive + assert 
captured_kwargs["adaptive_row_group_initial_target"] == expected_initial_target + assert captured_kwargs["max_admitted_rows"] == expected_row_budget + assert captured_kwargs["row_group_admission_source"] == "run_config" + + +def test_prepare_async_run_threads_adaptive_row_budget_inputs(monkeypatch: pytest.MonkeyPatch) -> None: + captured_kwargs = _capture_prepare_async_run_kwargs( + monkeypatch, + RunConfig( + max_in_flight_tasks=64, + row_group_admission=RowGroupAdmissionConfig( + mode=RowGroupAdmissionMode.ADAPTIVE, + max_concurrent_row_groups=5, + ), + ), + num_records=12345, + buffer_size=123, + ) - class _SpyScheduler: - def __init__(self, **kwargs: object) -> None: - captured_kwargs.update(kwargs) + assert captured_kwargs["adaptive_row_group_admission"] is True + assert captured_kwargs["max_admitted_rows"] is None + assert captured_kwargs["num_records"] == 12345 + assert captured_kwargs["buffer_size"] == 123 - monkeypatch.setattr(builder_mod, "AsyncTaskScheduler", _SpyScheduler) - model_registry = MagicMock() - model_registry.request_admission = None - provider = SimpleNamespace( - model_registry=model_registry, - run_config=SimpleNamespace(max_in_flight_tasks=64, progress_interval=5.0, progress_bar=False), - ) - processor_runner = MagicMock() - processor_runner.has_processors_for.return_value = False - config = SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}) - builder = SimpleNamespace( - _column_configs=[config], - _processor_runner=processor_runner, - artifact_storage=MagicMock(), - _resource_provider=provider, - ) - generator = MockSeed(config=_expr_config("seed"), resource_provider=provider) +def test_prepare_async_run_uses_compact_plan_for_large_fresh_runs(monkeypatch: pytest.MonkeyPatch) -> None: tracemalloc.start() try: - DatasetBuilder._prepare_async_run(builder, [generator], num_records=2_000_000, buffer_size=2) + captured_kwargs = _capture_prepare_async_run_kwargs( + monkeypatch, + 
RunConfig(max_in_flight_tasks=64), + num_records=2_000_000, + buffer_size=2, + ) _current, peak_bytes = tracemalloc.get_traced_memory() finally: tracemalloc.stop() diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index d36d12366..92776284e 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -41,7 +41,7 @@ from data_designer.engine.dataset_builders.row_group_plan import CompactRowGroupPlan from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta from data_designer.engine.dataset_builders.scheduling.task_admission import TaskAdmissionConfig, TaskAdmissionLease -from data_designer.engine.dataset_builders.scheduling.task_model import Task +from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef, Task from data_designer.engine.dataset_builders.scheduling.task_policies import BoundedBorrowTaskAdmissionPolicyConfig from data_designer.engine.dataset_builders.utils.async_progress_reporter import AsyncProgressReporter from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph @@ -466,6 +466,42 @@ def test_scheduler_preparation_memory_stays_bounded_for_million_row_groups() -> assert peak_bytes < 5 * 1024 * 1024 +def test_completion_tracker_full_column_batch_uses_compact_completion_marker() -> None: + configs = [SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]})] + graph = ExecutionGraph.create(configs, {"seed": GenerationStrategy.FULL_COLUMN}) + tracker = CompletionTracker.with_graph(graph, [(0, 1_000_000)]) + + delta = tracker.mark_row_range_complete("seed", 0, 1_000_000) + + assert delta.empty + assert tracker._batch_complete[0] == {"seed": 1_000_000} + assert 
"seed" not in tracker._completed.get(0, {}) + assert tracker.is_complete(SliceRef(column="seed", row_group=0, row_index=999_999)) + assert tracker.is_all_complete( + [ + SliceRef(column="seed", row_group=0, row_index=None), + SliceRef(column="seed", row_group=0, row_index=123), + ] + ) + assert tracker.is_column_complete_for_rg("seed", 0) + assert tracker.is_row_group_complete(0, 1_000_000, ["seed"]) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_releases_completion_state_after_checkpoint() -> None: + scheduler, tracker = _build_simple_pipeline(num_records=4, buffer_size=2) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + assert tracker._completed == {} + assert tracker._dropped == {} + assert tracker._batch_complete == {} + assert tracker.ready_frontier() == () + assert tracker._released_row_groups == {} + assert tracker.is_row_group_complete(0, 2, ["seed", "cell_out"]) + assert tracker.is_row_group_complete(1, 2, ["seed", "cell_out"]) + + def _seed_plus_cell_setup( cell_generator: ColumnGenerator, num_records: int, @@ -730,6 +766,57 @@ async def test_scheduler_non_retryable_failure_drops_row() -> None: assert tracker.is_row_group_complete(0, 2, ["seed", "fail_col"]) +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_waits_for_in_flight_siblings_before_checkpoint_release() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="fail_col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="slow_col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "fail_col": GenerationStrategy.CELL_BY_CELL, + "slow_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 1)] + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = 
AsyncTaskScheduler( + generators={ + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "fail_col": MockFailingGenerator(config=_expr_config("fail_col"), resource_provider=provider), + "slow_col": SlowCellGenerator( + config=_expr_config("slow_col"), + resource_provider=provider, + delay=0.05, + ), + }, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_in_flight_tasks=2, + max_model_task_admission=2, + num_records=1, + buffer_size=1, + ) + release_in_flight_counts: list[int] = [] + release_row_group = tracker.release_row_group + + def spy_release_row_group(row_group: int, row_group_size: int, all_columns: list[str]) -> None: + release_in_flight_counts.append(sum(task.row_group == row_group for task in scheduler._in_flight)) + release_row_group(row_group, row_group_size, all_columns) + + tracker.release_row_group = spy_release_row_group # type: ignore[method-assign] + + await scheduler.run() + + assert release_in_flight_counts == [0] + assert tracker.is_dropped(0, 0) + assert tracker.is_row_group_complete(0, 1, ["seed", "fail_col", "slow_col"]) + assert scheduler.capacity_plan().observed_maxima.row_groups_in_flight == 1 + + def test_scheduler_internal_bug_classifier_preserves_scheduler_builtin_failures() -> None: scheduler, tracker = _build_simple_pipeline(num_records=1) assert scheduler._is_internal_bug(KeyError("missing internal key")) @@ -2718,6 +2805,7 @@ async def test_scheduler_429_beyond_salvage_cap_is_delayed_not_dropped() -> None assert scheduler._deferred_errors == {} assert scheduler._preserved_retryable_counts == {} assert scheduler._preserved_retryable_log_state == {} + assert scheduler._dispatched == set() @pytest.mark.asyncio(loop_scope="session") @@ -2874,6 +2962,7 @@ async def test_scheduler_drops_non_preserved_retryable_errors_when_salvage_exhau assert scheduler._deferred_errors == {} assert scheduler._preserved_retryable_counts == {} assert scheduler._preserved_retryable_log_state == {} + assert 
scheduler._dispatched == set() def test_scheduler_rejects_zero_salvage_rounds() -> None: @@ -3135,6 +3224,22 @@ async def agenerate(self, data: dict) -> dict: return self.generate(data) +class SlowFullColumnGenerator(ColumnGeneratorFullColumn[ExpressionColumnConfig]): + """Full-column generator with configurable async delay.""" + + def __init__(self, *args: Any, delay: float = 0.05, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._delay = delay + + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + data[self.config.name] = "gen" + return data + + async def agenerate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + await asyncio.sleep(self._delay) + return self.generate(data) + + class SlowLLMBoundCellGenerator(SlowCellGenerator): """Slow cell generator that participates in model-stage scheduling.""" @@ -3632,6 +3737,45 @@ def test_scheduler_capacity_plan_reports_default_request_initial_limit_after_aim assert plan.runtime_snapshot.request_domain_current_limits[resource] == 3 +def test_scheduler_capacity_plan_reports_public_row_group_admission_config() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="model_col", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "model_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 2), (1, 2), (2, 2)] + graph = ExecutionGraph.create(configs, strategies) + scheduler = AsyncTaskScheduler( + generators={ + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + "model_col": MockCellGenerator(config=_expr_config("model_col"), resource_provider=provider), + }, + graph=graph, + tracker=CompletionTracker.with_graph(graph, row_groups), + row_groups=row_groups, + max_concurrent_row_groups=5, + adaptive_row_group_admission=True, + 
adaptive_row_group_initial_target=2, + max_admitted_rows=2048, + row_group_admission_source="run_config", + num_records=6, + buffer_size=2, + ) + + plan = scheduler.capacity_plan() + + assert plan.configured.row_group_admission.row_group_concurrency.value == 5 + assert plan.configured.row_group_admission.row_group_concurrency.source == "run_config" + assert plan.configured.row_group_admission.mode == "adaptive" + assert plan.configured.row_group_admission.target_in_flight == 2 + assert plan.configured.row_group_admission.max_admitted_rows == 2048 + assert plan.configured.row_group_admission.max_admitted_rows_source == "run_config" + + @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_emits_job_health_and_row_group_telemetry() -> None: provider = _mock_provider() @@ -3744,10 +3888,12 @@ async def test_scheduler_adaptive_row_group_admission_expands_target_for_horizon assert plan.configured.row_group_admission.observed_max_target is not None assert plan.configured.row_group_admission.observed_max_target > 1 assert plan.observed_maxima.row_groups_in_flight > 1 + assert "no_pending_row_groups" not in plan.configured.row_group_admission.blocked_reasons assert any(event.event_kind == "row_group_admission_target_changed" for event in sink.scheduler_events) -def test_scheduler_adaptive_row_group_row_guard_blocks_extra_large_groups() -> None: +@pytest.mark.parametrize("adaptive", [False, True], ids=["fixed", "adaptive"]) +def test_scheduler_derived_row_group_row_guard_blocks_extra_large_groups(adaptive: bool) -> None: provider = _mock_provider() configs = [ SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), @@ -3772,7 +3918,7 @@ def test_scheduler_adaptive_row_group_row_guard_blocks_extra_large_groups() -> N tracker=CompletionTracker.with_graph(graph, row_groups), row_groups=row_groups, max_concurrent_row_groups=4, - adaptive_row_group_admission=True, + adaptive_row_group_admission=adaptive, 
adaptive_row_group_initial_target=4, num_records=10_000, buffer_size=1, @@ -3780,11 +3926,251 @@ def test_scheduler_adaptive_row_group_row_guard_blocks_extra_large_groups() -> N scheduler._rg_states[0] = SimpleNamespace(size=5_000) - assert scheduler._adaptive_max_admitted_rows == 8_192 + assert scheduler._max_admitted_rows == 8_192 assert not scheduler._row_group_row_guard_allows(5_000) assert scheduler._row_group_row_guard_allows(1_000) scheduler._rg_states.clear() - assert scheduler._row_group_row_guard_allows(9_000) + assert not scheduler._row_group_row_guard_allows(9_000) + + +def test_scheduler_derived_row_group_row_guard_is_capped_to_public_max() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="model_col", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "model_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(rg_id, 20_000) for rg_id in range(64)] + graph = ExecutionGraph.create(configs, strategies) + scheduler = AsyncTaskScheduler( + generators={ + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + "model_col": MockCellGenerator(config=_expr_config("model_col"), resource_provider=provider), + }, + graph=graph, + tracker=CompletionTracker.with_graph(graph, row_groups), + row_groups=row_groups, + max_concurrent_row_groups=64, + adaptive_row_group_admission=True, + num_records=1_280_000, + buffer_size=20_000, + ) + + assert scheduler._max_admitted_rows == 1_000_000 + plan = scheduler.capacity_plan() + assert plan.configured.row_group_admission.max_admitted_rows == 1_000_000 + assert plan.configured.row_group_admission.max_admitted_rows_source == "engine_internal_config" + + +def test_scheduler_rejects_row_group_larger_than_row_budget() -> None: + provider = _mock_provider() + configs = [ + 
SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="model_col", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "model_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 1_000_001)] + graph = ExecutionGraph.create(configs, strategies) + + with pytest.raises(ValueError, match="row-group size 1000001 exceeds row_group_admission.max_admitted_rows"): + AsyncTaskScheduler( + generators={ + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + "model_col": MockCellGenerator(config=_expr_config("model_col"), resource_provider=provider), + }, + graph=graph, + tracker=CompletionTracker.with_graph(graph, row_groups), + row_groups=row_groups, + max_concurrent_row_groups=1, + adaptive_row_group_admission=True, + num_records=1_000_001, + buffer_size=1_000_001, + ) + + +def test_scheduler_adaptive_target_does_not_grow_before_pending_row_group() -> None: + scheduler, _tracker = _build_simple_pipeline(num_records=2, buffer_size=1) + scheduler._adaptive_row_group_admission = True + scheduler._row_group_admission_hard_cap = 2 + scheduler._row_group_admission_target = 1 + scheduler._pending_row_group_admission_size = None + + scheduler._maybe_update_adaptive_row_group_target() + + assert scheduler._row_group_admission_target == 1 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_adaptive_derived_row_guard_blocks_admission() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + ExpressionColumnConfig(name="slow_full", expr="{{ topic }}", dtype="str"), + ] + strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "slow_full": GenerationStrategy.FULL_COLUMN, + } + row_groups = [(0, 4_097), (1, 4_097)] + graph = ExecutionGraph.create(configs, strategies) + tracker = 
CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators={ + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + "slow_full": SlowFullColumnGenerator( + config=_expr_config("slow_full"), + resource_provider=provider, + delay=0.01, + ), + }, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_concurrent_row_groups=2, + max_in_flight_tasks=16, + max_model_task_admission=16, + adaptive_row_group_admission=True, + adaptive_row_group_initial_target=2, + num_records=8_194, + buffer_size=4_096, + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + plan = scheduler.capacity_plan() + assert plan.configured.row_group_admission.max_admitted_rows == 8_192 + assert plan.observed_maxima.row_groups_in_flight == 1 + assert plan.configured.row_group_admission.blocked_reasons["max_admitted_rows"] > 0 + assert tracker.is_row_group_complete(0, 4_097, ["topic", "slow_full"]) + assert tracker.is_row_group_complete(1, 4_097, ["topic", "slow_full"]) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_widened_fixed_horizon_derives_row_group_row_guard() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + ExpressionColumnConfig(name="slow_full", expr="{{ topic }}", dtype="str"), + ] + strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "slow_full": GenerationStrategy.FULL_COLUMN, + } + row_groups = [(0, 2_000), (1, 2_000), (2, 2_000), (3, 2_000)] + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators={ + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + "slow_full": SlowFullColumnGenerator( + config=_expr_config("slow_full"), + resource_provider=provider, + delay=0.01, + ), + }, + graph=graph, + tracker=tracker, + 
row_groups=row_groups, + max_concurrent_row_groups=4, + max_in_flight_tasks=16, + max_model_task_admission=16, + num_records=8_000, + buffer_size=2_000, + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + plan = scheduler.capacity_plan() + assert plan.configured.row_group_admission.max_admitted_rows == 8_000 + assert plan.configured.row_group_admission.max_admitted_rows_source == "engine_internal_config" + assert plan.observed_maxima.row_groups_in_flight == 4 + + +def test_scheduler_default_fixed_horizon_preserves_count_only_row_group_default() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="model_col", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "model_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 1_000_001)] + graph = ExecutionGraph.create(configs, strategies) + scheduler = AsyncTaskScheduler( + generators={ + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + "model_col": MockCellGenerator(config=_expr_config("model_col"), resource_provider=provider), + }, + graph=graph, + tracker=CompletionTracker.with_graph(graph, row_groups), + row_groups=row_groups, + num_records=1_000_001, + buffer_size=1_000_001, + ) + + assert scheduler.capacity_plan().configured.row_group_admission.max_admitted_rows is None + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("adaptive", [False, True], ids=["fixed", "adaptive"]) +async def test_scheduler_explicit_row_group_row_guard_blocks_admission(adaptive: bool) -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="model_col", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "topic": 
GenerationStrategy.FULL_COLUMN, + "model_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 500), (1, 500)] + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, row_groups) + buffer_manager = RowGroupBufferManager(_make_storage()) + + def finalize_row_group(row_group: int) -> None: + buffer_manager.checkpoint_row_group(row_group) + + scheduler = AsyncTaskScheduler( + generators={ + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + "model_col": SlowLLMBoundCellGenerator( + config=_expr_config("model_col"), + resource_provider=provider, + delay=0.001, + ), + }, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_manager, + on_finalize_row_group=finalize_row_group, + max_concurrent_row_groups=2, + max_in_flight_tasks=128, + max_model_task_admission=128, + adaptive_row_group_admission=adaptive, + adaptive_row_group_initial_target=2, + max_admitted_rows=600, + num_records=1_000, + buffer_size=500, + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + plan = scheduler.capacity_plan() + assert tracker.is_row_group_complete(0, 500, ["topic", "model_col"]) + assert tracker.is_row_group_complete(1, 500, ["topic", "model_col"]) + assert plan.configured.row_group_admission.mode == ("adaptive" if adaptive else "fixed") + assert plan.observed_maxima.row_groups_in_flight == 1 + assert plan.configured.row_group_admission.blocked_reasons["max_admitted_rows"] > 0 def test_scheduler_adaptive_row_group_block_reason_prefers_llm_saturation() -> None: @@ -3823,7 +4209,7 @@ def test_scheduler_adaptive_row_group_block_reason_prefers_llm_saturation() -> N view=lambda: SimpleNamespace(resource_limits={"llm_wait": 1}, resources_available={"llm_wait": 0}) ) - assert scheduler._adaptive_row_group_block_reason() == "llm_wait_saturated" + assert scheduler._adaptive_row_group_block_reason(1) == "llm_wait_saturated" def 
test_scheduler_adaptive_row_group_queue_guard_uses_in_flight_task_cap() -> None: @@ -3834,7 +4220,7 @@ def test_scheduler_adaptive_row_group_queue_guard_uses_in_flight_task_cap() -> N view=lambda: SimpleNamespace(queued_total=8, queued_peer_demand_by_resource={}) ) - assert scheduler._adaptive_row_group_block_reason() == "queued_task_guardrail" + assert scheduler._adaptive_row_group_block_reason(1) == "queued_task_guardrail" @pytest.mark.asyncio(loop_scope="session") @@ -4277,6 +4663,12 @@ async def test_scheduler_post_batch_failure_raises() -> None: graph = ExecutionGraph.create(configs, strategies) row_groups = [(0, 3)] tracker = CompletionTracker.with_graph(graph, row_groups) + release_calls: list[int] = [] + release_row_group = tracker.release_row_group + + def spy_release_row_group(row_group: int, row_group_size: int, all_columns: list[str]) -> None: + release_calls.append(row_group) + release_row_group(row_group, row_group_size, all_columns) storage = MagicMock() storage.dataset_name = "test" @@ -4286,6 +4678,7 @@ async def test_scheduler_post_batch_failure_raises() -> None: def fail_post_batch(rg_id: int, rg_size: int) -> None: raise RuntimeError("post-batch processor exploded") + tracker.release_row_group = spy_release_row_group # type: ignore[method-assign] scheduler = AsyncTaskScheduler( generators=generators, graph=graph, @@ -4297,6 +4690,58 @@ def fail_post_batch(rg_id: int, rg_size: int) -> None: with pytest.raises(DatasetGenerationError, match="Post-batch processor failed"): await scheduler.run() + assert release_calls == [] + assert 0 in scheduler._rg_states + assert scheduler._active_admitted_row_count() == 3 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_checkpoint_failure_preserves_admission_state() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", 
model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cell_out": GenerationStrategy.CELL_BY_CELL, + } + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 1), (1, 1)] + tracker = CompletionTracker.with_graph(graph, row_groups) + release_calls: list[int] = [] + release_row_group = tracker.release_row_group + + def spy_release_row_group(row_group: int, row_group_size: int, all_columns: list[str]) -> None: + release_calls.append(row_group) + release_row_group(row_group, row_group_size, all_columns) + + def fail_finalize(rg_id: int) -> None: + raise RuntimeError(f"storage checkpoint failed for row group {rg_id}") + + tracker.release_row_group = spy_release_row_group # type: ignore[method-assign] + scheduler = AsyncTaskScheduler( + generators={ + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "cell_out": MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider), + }, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_concurrent_row_groups=1, + num_records=2, + buffer_size=1, + on_finalize_row_group=fail_finalize, + ) + + with pytest.raises(DatasetGenerationError, match="Failed to checkpoint row group 0"): + await scheduler.run() + + assert release_calls == [] + assert 0 in scheduler._rg_states + assert scheduler._active_admitted_row_count() == 1 + assert tracker.is_row_group_complete(0, 1, ["seed", "cell_out"]) + # -- Early shutdown drains workers ------------------------------------------- diff --git a/packages/data-designer/src/data_designer/interface/data_designer.py b/packages/data-designer/src/data_designer/interface/data_designer.py index 535e2391c..ce2beedae 100644 --- a/packages/data-designer/src/data_designer/interface/data_designer.py +++ b/packages/data-designer/src/data_designer/interface/data_designer.py @@ -606,7 +606,8 @@ def set_run_config(self, run_config: RunConfig) -> None: Args: run_config: A RunConfig instance 
containing runtime settings such as early shutdown behavior, batch sizing via `buffer_size`, async task lease - capacity via `max_in_flight_tasks`, and non-inference worker concurrency via + capacity via `max_in_flight_tasks`, async row-group horizon/admission via + `row_group_admission`, and non-inference worker concurrency via `non_inference_max_parallel_workers`. Notes: diff --git a/plans/645/async-scheduling-epic.puml b/plans/645/async-scheduling-epic.puml index e581eab8f..84b8f094c 100644 --- a/plans/645/async-scheduling-epic.puml +++ b/plans/645/async-scheduling-epic.puml @@ -558,6 +558,12 @@ package "Capacity planning (issue 654)" { class RowGroupAdmission { +row_group_concurrency +observed_in_flight + +mode + +target_in_flight + +observed_max_target + +max_admitted_rows + +max_admitted_rows_source + +blocked_reasons } class TransportPoolConfig { diff --git a/plans/645/contracts.md b/plans/645/contracts.md index 0c7a51a4f..ec961e081 100644 --- a/plans/645/contracts.md +++ b/plans/645/contracts.md @@ -477,6 +477,12 @@ CapacityValue[T]: RowGroupAdmission: row_group_concurrency: CapacityValue[int] observed_in_flight: int | None + mode: fixed | adaptive + target_in_flight: int + observed_max_target: int + max_admitted_rows: int | None + max_admitted_rows_source: default | run_config | dataset_builder | model_metadata | engine_internal_config | adapter_config | environment | runtime_snapshot | benchmark_override | None + blocked_reasons: Mapping[str, int] ProviderModelStaticCap: cap: int diff --git a/plans/741/row-group-admission.md b/plans/741/row-group-admission.md new file mode 100644 index 000000000..3c0c7be83 --- /dev/null +++ b/plans/741/row-group-admission.md @@ -0,0 +1,65 @@ +# Plan: Public Row-Group Admission Controls + +Fixes #741. + +## Problem + +Large async runs expose public controls for batch size, scheduler task leases, and +model request caps, but the row-group admission horizon is still hidden in the +async scheduler. 
Users cannot choose whether the scheduler admits a fixed number +of row groups or ramps the active row-group target adaptively, even though that +choice affects checkpoint cadence, active state size, and endpoint occupancy. + +## Goals + +1. Expose row-group admission as a supported `RunConfig` policy. +2. Preserve the existing fixed default behavior unless users opt into a wider or + adaptive policy. +3. Thread the public policy through the dataset-builder boundary into + `AsyncTaskScheduler`. +4. Keep scheduler diagnostics/capacity plans aligned with the effective public + settings. +5. Validate fixed and adaptive policies with local mock-provider experiments. + +## Non-Goals + +- Do not redesign task admission or request admission. +- Do not make adaptive row-group admission AIMD; it remains additive ramp-up + beneath a hard cap. +- Do not add new model/provider request-concurrency knobs. + +## Design + +Add `RowGroupAdmissionConfig` and `RowGroupAdmissionMode` to the config package. +`RunConfig.row_group_admission` defaults to a fixed horizon of three active row +groups, matching the current scheduler default while making the policy visible. + +`DatasetBuilder._prepare_async_run()` translates the public config into the +existing scheduler constructor arguments: + +- `max_concurrent_row_groups` +- `adaptive_row_group_admission` +- `adaptive_row_group_initial_target` +- `max_admitted_rows` + +The scheduler keeps ownership of the actual admission loop and capacity-plan +diagnostics. It records whether the row-group horizon and active-row budget came +from `run_config` or internal derivation so capacity reports can distinguish +public configuration from internal defaults. +The public config bounds `max_concurrent_row_groups`. 
The historical default +fixed horizon preserves row-group-count-only behavior, while widened fixed +horizons and adaptive mode derive an active-row guard when `max_admitted_rows` +is omitted so public row-group tuning cannot multiply active buffers silently. +Adaptive mode rejects row groups larger than the effective active-row guard +instead of admitting an oversized first group. + +## Validation + +- Config tests cover default exposure, dict/object construction, exports, and + invalid adaptive-only fields. +- Builder tests verify public row-group admission settings are passed to the + scheduler. +- Scheduler tests verify capacity diagnostics report the public source, fixed + cap, adaptive target, and explicit max-admitted-row guard. +- Local mock-provider experiments compare fixed and adaptive horizons across + fan-out, dependency-chain, and wide-row-group workloads.