diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/expression.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/expression.py index 28b79adf3..65ad3f4b1 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/expression.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/expression.py @@ -4,6 +4,7 @@ from __future__ import annotations import logging +from collections import Counter from typing import TYPE_CHECKING import data_designer.lazy_heavy_imports as lazy @@ -12,6 +13,7 @@ from data_designer.engine.column_generators.utils.errors import ExpressionTemplateRenderError from data_designer.engine.context import format_row_group_tag from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering +from data_designer.engine.processing.ginja.exceptions import EmptyTemplateRenderError, UserTemplateError from data_designer.engine.processing.utils import deserialize_json_values if TYPE_CHECKING: @@ -19,6 +21,11 @@ logger = logging.getLogger(__name__) +EMPTY_RENDERED_EXPRESSION = "EmptyRenderedExpression" +TEMPLATE_RENDER_ERROR = "TemplateRenderError" +TYPE_CAST_ERROR = "TypeCastError" +_VALID_DTYPES = {"str", "float", "int", "bool"} + class ExpressionColumnGenerator(WithJinja2UserTemplateRendering, ColumnGeneratorFullColumn[ExpressionColumnConfig]): def generate(self, data: pd.DataFrame) -> pd.DataFrame: @@ -32,13 +39,69 @@ def generate(self, data: pd.DataFrame) -> pd.DataFrame: ) raise ExpressionTemplateRenderError(error_msg) + if self.config.dtype not in _VALID_DTYPES: + raise ValueError(f"Invalid dtype: {self.config.dtype}") + self.prepare_jinja2_template_renderer(self.config.expr, data.columns.to_list()) - records = [] - for record in data.to_dict(orient="records"): - record[self.config.name] = self._cast_type(self.render_template(deserialize_json_values(record))) + records: list[dict] = [] + retained_indexes: list[object] = [] + drop_counts: Counter[str] = Counter() + + for row_index, record in zip(data.index.to_list(), data.to_dict(orient="records"), strict=True): + prepared_record = deserialize_json_values(record) + try: + rendered_value = self.render_template(prepared_record) + except EmptyTemplateRenderError: + drop_counts[EMPTY_RENDERED_EXPRESSION] += 1 + continue + except Exception: + logger.debug( + "Expression column %r dropped row %r after template render failure.", + self.config.name, + row_index, + exc_info=True, + ) + drop_counts[TEMPLATE_RENDER_ERROR] += 1 + continue + + if self._is_empty_rendered_expression(rendered_value): + drop_counts[EMPTY_RENDERED_EXPRESSION] += 1 + continue + + try: + record[self.config.name] = self._cast_type(rendered_value) + except (OverflowError, TypeError, ValueError): + drop_counts[TYPE_CAST_ERROR] += 1 + continue + records.append(record) + retained_indexes.append(row_index) + + total_dropped = sum(drop_counts.values()) + if total_dropped > 0: + self._log_row_drops(drop_counts, input_count=len(data), retained_count=len(records)) + if len(records) == 0: + raise UserTemplateError(f"Expression column {self.config.name!r} produced no valid rows.") - return lazy.pd.DataFrame(records) + return lazy.pd.DataFrame(records, index=retained_indexes) + + @staticmethod + def _is_empty_rendered_expression(value: object) -> bool: + if value is None: + return True + return isinstance(value, str) and len(value.strip()) == 0 + + def _log_row_drops(self, drop_counts: Counter[str], *, input_count: int, retained_count: int) -> None: + breakdown = ", ".join(f"{name}={count}" for name, count in sorted(drop_counts.items())) + total_dropped = sum(drop_counts.values()) + message = ( + f"Expression column {self.config.name!r} dropped {total_dropped}/{input_count} rows after render: " + f"{breakdown}." + ) + if retained_count == 0: + logger.error(message) + else: + logger.warning(f"{message} Continuing with {retained_count} rows.") def _cast_type(self, value: str) -> str | float | int | bool: if self.config.dtype == "str": diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index ace57190f..018c06d1d 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any, Callable import data_designer.lazy_heavy_imports as lazy -from data_designer.config.column_configs import GenerationStrategy +from data_designer.config.column_configs import ExpressionColumnConfig, GenerationStrategy from data_designer.engine.capacity import ( AsyncCapacityConfigured, AsyncCapacityObservedMaxima, @@ -80,6 +80,7 @@ SchedulerAdmissionEventSink, runtime_correlation_provider, ) +from data_designer.engine.processing.ginja.exceptions import UserTemplateError if TYPE_CHECKING: from data_designer.engine.column_generators.generators.base import ColumnGenerator @@ -402,6 +403,8 @@ def first_non_retryable_error(self) -> Exception | None: def _raise_if_fatal_worker_error(self) -> None: if self._fatal_worker_error is None: return + if isinstance(self._fatal_worker_error, UserTemplateError): + raise DatasetGenerationError(str(self._fatal_worker_error)) from self._fatal_worker_error raise DatasetGenerationError( "Unexpected internal task failure in async scheduler." ) from self._fatal_worker_error @@ -1696,6 +1699,19 @@ async def _execute_task_inner_impl(self, task: Task, lease: TaskAdmissionLease, trace.status = "error" trace.error = str(exc) + if self._is_fatal_expression_template_error(task, generator, exc): + logger.error( + "Fatal expression failure on %s[rg=%s, row=%s]: %s", + task.column, + task.row_group, + task.row_index, + exc, + exc_info=True, + ) + self._fatal_worker_error = exc + self._wake_event.set() + raise + if retryable: self._deferred.append(task) self._deferred_errors[task] = exc @@ -1783,6 +1799,9 @@ async def _run_generator_call(self, task: Task, operation: str, call: Coroutine[ try: return await call except Exception as exc: + generator = self._generators[task.column] + if self._is_fatal_expression_template_error(task, generator, exc): + raise if self._is_retryable(exc) or self._is_expected_non_retryable(exc): raise raise DatasetGenerationError( @@ -1808,6 +1827,17 @@ def _require_dataframe_result( ) return result + def _task_supports_row_drops(self, task: Task, generator: ColumnGenerator) -> bool: + return task.task_type == "batch" and isinstance(generator.config, ExpressionColumnConfig) + + def _is_fatal_expression_template_error( + self, + task: Task, + generator: ColumnGenerator, + exc: BaseException, + ) -> bool: + return self._task_supports_row_drops(task, generator) and isinstance(exc, UserTemplateError) + async def _run_from_scratch(self, task: Task, generator: ColumnGenerator) -> Any: """Execute a from_scratch task.""" rg_size = self._get_rg_size(task.row_group) @@ -1920,6 +1950,7 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any: if self._buffer_manager is not None: pre_dropped: set[int] = {ri for ri in range(rg_size) if self._buffer_manager.is_dropped(task.row_group, ri)} active_rows_data: list[dict] = [] + active_row_indices: list[int] = [] # Skip evaluation only applies to single-column configs. # Multi-column configs (sampler/seed) are rejected by the SkipConfig @@ -1937,9 +1968,10 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any: continue active_rows_data.append(record) + active_row_indices.append(ri) batch_df = ( - lazy.pd.DataFrame(strip_skip_metadata_from_records(active_rows_data)) + lazy.pd.DataFrame(strip_skip_metadata_from_records(active_rows_data), index=active_row_indices) if active_rows_data else lazy.pd.DataFrame() ) @@ -1947,6 +1979,7 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any: batch_df = lazy.pd.DataFrame() pre_dropped = set() pre_skipped = set() + active_row_indices = [] if len(batch_df) == 0: return batch_df @@ -1957,28 +1990,81 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any: "batch generation", generator.agenerate(batch_df), ) - result_df = self._require_dataframe_result( - task, - "Batch generator", - result_df, - expected_rows=active_rows, - ) + if self._task_supports_row_drops(task, generator): + result_df = self._require_expression_row_drop_result( + task, + result_df, + active_row_indices=active_row_indices, + ) + else: + result_df = self._require_dataframe_result( + task, + "Batch generator", + result_df, + expected_rows=active_rows, + ) # Merge result columns back to buffer (include side-effect columns) if self._buffer_manager is not None: write_cols = self._gen_instance_to_columns_including_side_effects.get(id(generator), [task.column]) - result_idx = 0 + result_by_row_index = self._batch_result_by_row_index( + result_df, + active_row_indices=active_row_indices, + supports_row_drops=self._task_supports_row_drops(task, generator), + ) for ri in range(rg_size): if ri in pre_dropped or ri in pre_skipped: continue + result_row = result_by_row_index.get(ri) + if result_row is None: + self._drop_row(task.row_group, ri, exclude_columns={task.column}) + continue if not self._buffer_manager.is_dropped(task.row_group, ri): for col in write_cols: - if col in result_df.columns: - self._buffer_manager.update_cell(task.row_group, ri, col, result_df.iloc[result_idx][col]) - result_idx += 1 + if col in result_row: + self._buffer_manager.update_cell(task.row_group, ri, col, result_row[col]) return result_df + def _require_expression_row_drop_result( + self, + task: Task, + result: Any, + *, + active_row_indices: list[int], + ) -> Any: + result_df = self._require_dataframe_result(task, "Batch generator", result) + result_indexes = result_df.index.to_list() + if len(result_indexes) != len(set(result_indexes)): + raise DatasetGenerationError( + f"Batch generator for column '{task.column}' returned duplicate row indexes (rg={task.row_group})." + ) + active_index_set = set(active_row_indices) + unexpected_indexes = set(result_indexes) - active_index_set + if unexpected_indexes: + raise DatasetGenerationError( + f"Batch generator for column '{task.column}' returned row indexes outside the active input rows " + f"(rg={task.row_group}): {sorted(unexpected_indexes)!r}." + ) + if len(result_df) > len(active_row_indices): + raise DatasetGenerationError( + f"Batch generator for column '{task.column}' returned {len(result_df)} rows " + f"but at most {len(active_row_indices)} active rows were expected (rg={task.row_group})." + ) + return result_df + + @staticmethod + def _batch_result_by_row_index( + result_df: Any, + *, + active_row_indices: list[int], + supports_row_drops: bool, + ) -> dict[int, dict[str, Any]]: + result_records = result_df.to_dict(orient="records") + if supports_row_drops: + return dict(zip(result_df.index.to_list(), result_records, strict=True)) + return dict(zip(active_row_indices, result_records, strict=True)) + def _get_rg_size(self, row_group: int) -> int: try: return self._row_groups.row_group_size(row_group) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index 1bbd51df7..84fd14c73 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -17,6 +17,7 @@ from pydantic import ValidationError import data_designer.lazy_heavy_imports as lazy +from data_designer.config.column_configs import ExpressionColumnConfig from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType from data_designer.config.config_builder import BuilderConfig from data_designer.config.data_designer_config import DataDesignerConfig @@ -1335,21 +1336,34 @@ def _run_full_column_generator_without_skip(self, generator: ColumnGenerator) -> """Run the generator on the full batch, preserving skip metadata across the replace.""" original_count = self.batch_manager.num_records_in_buffer allow_resize = generator.config.allow_resize if not isinstance(generator.config, MultiColumnConfig) else False + supports_row_drops = self._generator_supports_row_drops(generator) old_records = [record for _, record in self.batch_manager.iter_current_batch()] input_records, restore_context = prepare_records_for_skip_metadata_round_trip(old_records) df = generator.generate(lazy.pd.DataFrame(input_records)) + if supports_row_drops and len(df) > original_count: + raise DatasetGenerationError( + f"Generator for {self._column_display_name(generator.config)} returned {len(df)} rows " + f"but at most {original_count} were expected." + ) self._log_resize_if_changed(self._column_display_name(generator.config), original_count, len(df), allow_resize) new_records = df.to_dict(orient="records") if restore_context is not None: try: - restore_skip_metadata(new_records, context=restore_context, allow_resize=allow_resize) + restore_skip_metadata( + new_records, + context=restore_context, + allow_resize=allow_resize or supports_row_drops, + ) except ValueError as exc: raise DatasetGenerationError( f"Unable to restore skip provenance after FULL_COLUMN generation for " f"{self._column_display_name(generator.config)}: {exc}" ) from exc - self.batch_manager.replace_buffer(new_records, allow_resize=allow_resize) + self.batch_manager.replace_buffer(new_records, allow_resize=allow_resize or supports_row_drops) + + def _generator_supports_row_drops(self, generator: ColumnGenerator) -> bool: + return isinstance(generator.config, ExpressionColumnConfig) def _run_full_column_generator_with_skip(self, generator: ColumnGenerator, column_name: str) -> None: """Run a FULL_COLUMN generator with per-row skip evaluation and merge-back. @@ -1377,7 +1391,7 @@ def _run_full_column_generator_with_skip(self, generator: ColumnGenerator, colum return batch = self._merge_skipped_and_generated(generator, column_name, active_records, records_with_skip_status) - self.batch_manager.replace_buffer(batch, allow_resize=False) + self.batch_manager.replace_buffer(batch, allow_resize=self._generator_supports_row_drops(generator)) def _merge_skipped_and_generated( self, @@ -1391,7 +1405,16 @@ def _merge_skipped_and_generated( return [record for _, record in records_with_skip_status] active_df = lazy.pd.DataFrame(strip_skip_metadata_from_records(active_records)) - result_records = generator.generate(active_df).to_dict(orient="records") + result_df = generator.generate(active_df) + result_records = result_df.to_dict(orient="records") + if self._generator_supports_row_drops(generator): + return self._merge_row_dropped_generated_records( + result_df=result_df, + result_records=result_records, + active_record_count=len(active_records), + records_with_skip_status=records_with_skip_status, + ) + if len(result_records) != len(active_records): raise DatasetGenerationError( f"Generator for '{column_name}' returned {len(result_records)} rows " @@ -1411,6 +1434,45 @@ def _merge_skipped_and_generated( batch.append(gen_result) return batch + def _merge_row_dropped_generated_records( + self, + *, + result_df: pd.DataFrame, + result_records: list[dict], + active_record_count: int, + records_with_skip_status: list[tuple[bool, dict]], + ) -> list[dict]: + result_indexes = result_df.index.to_list() + if len(result_indexes) != len(set(result_indexes)): + raise DatasetGenerationError("Expression generator returned duplicate row indexes after row drops.") + + result_by_active_index = dict(zip(result_indexes, result_records, strict=True)) + unexpected_indexes = set(result_by_active_index) - set(range(active_record_count)) + if unexpected_indexes: + raise DatasetGenerationError( + "Expression generator returned row indexes outside the active input rows: " + f"{sorted(unexpected_indexes)!r}." + ) + + batch: list[dict] = [] + active_index = 0 + for skipped, record in records_with_skip_status: + if skipped: + batch.append(record) + continue + + gen_result = result_by_active_index.get(active_index) + active_index += 1 + if gen_result is None: + continue + + prior_skipped = record.get(SKIPPED_COLUMNS_RECORD_KEY) + if prior_skipped is not None: + gen_result[SKIPPED_COLUMNS_RECORD_KEY] = prior_skipped + batch.append(gen_result) + + return batch + def _setup_fan_out( self, generator: ColumnGeneratorWithModelRegistry, diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_expression.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_expression.py index 15080368f..4dcb8d1bf 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/generators/test_expression.py +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_expression.py @@ -3,6 +3,7 @@ from __future__ import annotations +import logging from unittest.mock import Mock, patch import pytest @@ -12,7 +13,7 @@ from data_designer.config.run_config import JinjaRenderingEngine, RunConfig from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator from data_designer.engine.column_generators.utils.errors import ExpressionTemplateRenderError -from data_designer.engine.processing.ginja.exceptions import UserTemplateUnsupportedFiltersError +from data_designer.engine.processing.ginja.exceptions import UserTemplateError, UserTemplateUnsupportedFiltersError from data_designer.engine.resources.resource_provider import ResourceProvider @@ -27,6 +28,7 @@ def _create_test_generator(config=None, resource_provider=None): config = _create_test_config() if resource_provider is None: resource_provider = Mock(spec=ResourceProvider) + resource_provider.run_config = RunConfig() return ExpressionColumnGenerator(config=config, resource_provider=resource_provider) @@ -164,6 +166,63 @@ def test_generate_with_missing_columns(): generator.generate(df) +def test_generate_drops_empty_rendered_rows_and_warns(caplog: pytest.LogCaptureFixture) -> None: + config = _create_test_config("output", "{{ answer }}", "str") + generator = _create_test_generator(config) + df = lazy.pd.DataFrame({"answer": ["42", "", " ", "7"]}) + + with caplog.at_level(logging.WARNING): + result = generator.generate(df) + + assert result["output"].tolist() == ["42", "7"] + assert result.index.tolist() == [0, 3] + assert "Expression column 'output' dropped 2/4 rows after render: EmptyRenderedExpression=2." in caplog.text + assert "Continuing with 2 rows." in caplog.text + + +def test_generate_drops_row_specific_template_errors(caplog: pytest.LogCaptureFixture) -> None: + config = _create_test_config("ratio", "{{ 1 / denominator }}", "float") + generator = _create_test_generator(config) + df = lazy.pd.DataFrame({"denominator": [1, 0, 2]}) + + with caplog.at_level(logging.WARNING): + result = generator.generate(df) + + assert result["ratio"].tolist() == [1.0, 0.5] + assert result.index.tolist() == [0, 2] + assert "TemplateRenderError=1" in caplog.text + + +def test_generate_drops_type_cast_errors(caplog: pytest.LogCaptureFixture) -> None: + config = _create_test_config("number", "{{ value }}", "int") + generator = _create_test_generator(config) + df = lazy.pd.DataFrame({"value": ["1", "not-a-number", "3"]}) + + with caplog.at_level(logging.WARNING): + result = generator.generate(df) + + assert result["number"].tolist() == [1, 3] + assert result.index.tolist() == [0, 2] + assert "TypeCastError=1" in caplog.text + + +def test_generate_raises_when_all_rows_drop(caplog: pytest.LogCaptureFixture) -> None: + config = _create_test_config("output", "{{ answer }}", "str") + generator = _create_test_generator(config) + df = lazy.pd.DataFrame({"answer": ["", " "]}) + + with ( + caplog.at_level(logging.ERROR), + pytest.raises( + UserTemplateError, + match="Expression column 'output' produced no valid rows.", + ), + ): + generator.generate(df) + + assert "Expression column 'output' dropped 2/2 rows after render: EmptyRenderedExpression=2." in caplog.text + + def test_generate_respects_run_config_jinja_rendering_engine() -> None: df = lazy.pd.DataFrame({"col1": [["a", "b"]]}) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py index b2bc98c4a..031669477 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py @@ -19,14 +19,18 @@ LLMTextColumnConfig, SamplerColumnConfig, ) +from data_designer.config.run_config import RunConfig from data_designer.config.sampler_params import SamplerType from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, ColumnGeneratorFullColumn, FromScratchColumnGenerator, ) +from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator +from data_designer.engine.context import current_row_group from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder +from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer.engine.dataset_builders.row_group_plan import CompactRowGroupPlan from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph @@ -193,6 +197,118 @@ def finalize_row_group(rg_id: int) -> None: assert tracker.is_row_group_complete(1, 2, all_cols) +@pytest.mark.asyncio(loop_scope="session") +async def test_expression_row_drops_shrink_async_row_group(caplog: pytest.LogCaptureFixture) -> None: + """Expression row drops are applied to the exact row indexes in the async scheduler.""" + + class SeedWithEmpty(MockSeed): + def generate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: + values = ["keep", "", "also"] + return lazy.pd.DataFrame({"seed": values[:num_records]}) + + provider = _mock_provider() + provider.run_config = RunConfig() + seed_gen = SeedWithEmpty(config=_expr_config("seed"), resource_provider=provider) + expr_config = ExpressionColumnConfig(name="copy", expr="{{ seed }}") + expr_gen = ExpressionColumnGenerator(config=expr_config, resource_provider=provider) + + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + expr_config, + ] + strategies = {"seed": GenerationStrategy.FULL_COLUMN, "copy": GenerationStrategy.FULL_COLUMN} + gen_map = {"seed": seed_gen, "copy": expr_gen} + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + written_batches: list[lazy.pd.DataFrame] = [] + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + + def capture_batch(**kwargs: object) -> None: + written_batches.append(kwargs["dataframe"].copy()) + + storage.write_batch_to_parquet_file.side_effect = capture_batch + buffer_manager = RowGroupBufferManager(storage) + + scheduler = AsyncTaskScheduler( + generators=gen_map, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_manager, + on_finalize_row_group=lambda rg_id: buffer_manager.checkpoint_row_group(rg_id), + ) + with caplog.at_level("WARNING"): + await scheduler.run() + + assert buffer_manager.actual_num_records == 2 + assert tracker.is_dropped(0, 1) + assert tracker.is_row_group_complete(0, 3, ["seed", "copy"]) + assert len(written_batches) == 1 + assert written_batches[0]["seed"].tolist() == ["keep", "also"] + assert written_batches[0]["copy"].tolist() == ["keep", "also"] + assert "Expression column 'copy' dropped 1/3 rows after render: EmptyRenderedExpression=1." in caplog.text + + +@pytest.mark.asyncio(loop_scope="session") +async def test_expression_all_dropped_async_row_group_fails_loudly() -> None: + """All-dropped expression batches abort async generation instead of salvaging the row group.""" + + class SeedWithAllDroppedFirstGroup(MockSeed): + def generate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: + rg = current_row_group.get() + values = ["", ""] if rg is not None and rg[0] == 0 else ["keep", "also"] + return lazy.pd.DataFrame({"seed": values[:num_records]}) + + provider = _mock_provider() + provider.run_config = RunConfig() + seed_gen = SeedWithAllDroppedFirstGroup(config=_expr_config("seed"), resource_provider=provider) + expr_config = ExpressionColumnConfig(name="copy", expr="{{ seed }}") + expr_gen = ExpressionColumnGenerator(config=expr_config, resource_provider=provider) + + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + expr_config, + ] + strategies = {"seed": GenerationStrategy.FULL_COLUMN, "copy": GenerationStrategy.FULL_COLUMN} + gen_map = {"seed": seed_gen, "copy": expr_gen} + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 2), (1, 2)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + written_batches: list[lazy.pd.DataFrame] = [] + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + storage.write_batch_to_parquet_file.side_effect = lambda **kwargs: written_batches.append( + kwargs["dataframe"].copy() + ) + buffer_manager = RowGroupBufferManager(storage) + + scheduler = AsyncTaskScheduler( + generators=gen_map, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_manager, + max_concurrent_row_groups=1, + on_finalize_row_group=lambda rg_id: buffer_manager.checkpoint_row_group(rg_id), + ) + + with pytest.raises(DatasetGenerationError, match="Expression column 'copy' produced no valid rows."): + await scheduler.run() + + assert written_batches == [] + assert buffer_manager.actual_num_records == 0 + + def test_prepare_async_run_enables_request_pressure_advisory(monkeypatch: pytest.MonkeyPatch) -> None: captured_kwargs: dict[str, object] = {} diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py index 0a0f192b4..82c6379c9 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py @@ -628,6 +628,61 @@ def bad_fn(df: pd.DataFrame) -> pd.DataFrame: builder.build_preview(num_records=3) +def test_expression_column_row_drops_shrink_sync_batch( + stub_resource_provider: Mock, + stub_model_configs: dict[str, object], + caplog: pytest.LogCaptureFixture, +) -> None: + seed_source = DataFrameSeedSource(df=lazy.pd.DataFrame({"seed_id": [1, 2, 3, 4], "text": ["a", "", "c", "d"]})) + seed_reader = DataFrameSeedReader() + seed_reader.attach(seed_source, Mock()) + stub_resource_provider.seed_reader = seed_reader + + config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) + config_builder.with_seed_dataset(seed_source) + config_builder.add_column(ExpressionColumnConfig(name="copy", expr="{{ text }}")) + builder = DatasetBuilder( + data_designer_config=config_builder.build(), + resource_provider=stub_resource_provider, + ) + + with caplog.at_level(logging.WARNING): + result = builder.build_preview(num_records=4) + + assert result["seed_id"].tolist() == [1, 3, 4] + assert result["copy"].tolist() == ["a", "c", "d"] + assert "Expression column 'copy' dropped 1/4 rows after render: EmptyRenderedExpression=1." in caplog.text + + +def test_expression_column_row_drops_shrink_sync_skip_aware_batch( + stub_resource_provider: Mock, + stub_model_configs: dict[str, object], +) -> None: + seed_source = DataFrameSeedSource(df=lazy.pd.DataFrame({"seed_id": [1, 2, 3], "text": ["skip-me", "", "keep"]})) + seed_reader = DataFrameSeedReader() + seed_reader.attach(seed_source, Mock()) + stub_resource_provider.seed_reader = seed_reader + + config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) + config_builder.with_seed_dataset(seed_source) + config_builder.add_column( + ExpressionColumnConfig( + name="copy", + expr="{{ text }}", + skip=SkipConfig(when="{{ seed_id == 1 }}", value="skipped"), + ) + ) + builder = DatasetBuilder( + data_designer_config=config_builder.build(), + resource_provider=stub_resource_provider, + ) + + result = builder.build_preview(num_records=3) + + assert result["seed_id"].tolist() == [1, 3] + assert result["copy"].tolist() == ["skipped", "keep"] + + def test_build_async_preview_returns_empty_dataframe_when_row_group_is_already_freed( stub_resource_provider, stub_test_config_builder,