Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import logging
from collections import Counter
from typing import TYPE_CHECKING

import data_designer.lazy_heavy_imports as lazy
Expand All @@ -12,13 +13,19 @@
from data_designer.engine.column_generators.utils.errors import ExpressionTemplateRenderError
from data_designer.engine.context import format_row_group_tag
from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering
from data_designer.engine.processing.ginja.exceptions import EmptyTemplateRenderError, UserTemplateError
from data_designer.engine.processing.utils import deserialize_json_values

if TYPE_CHECKING:
import pandas as pd

logger = logging.getLogger(__name__)

EMPTY_RENDERED_EXPRESSION = "EmptyRenderedExpression"
TEMPLATE_RENDER_ERROR = "TemplateRenderError"
TYPE_CAST_ERROR = "TypeCastError"
_VALID_DTYPES = {"str", "float", "int", "bool"}


class ExpressionColumnGenerator(WithJinja2UserTemplateRendering, ColumnGeneratorFullColumn[ExpressionColumnConfig]):
def generate(self, data: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -32,13 +39,69 @@ def generate(self, data: pd.DataFrame) -> pd.DataFrame:
)
raise ExpressionTemplateRenderError(error_msg)

if self.config.dtype not in _VALID_DTYPES:
raise ValueError(f"Invalid dtype: {self.config.dtype}")

self.prepare_jinja2_template_renderer(self.config.expr, data.columns.to_list())
records = []
for record in data.to_dict(orient="records"):
record[self.config.name] = self._cast_type(self.render_template(deserialize_json_values(record)))
records: list[dict] = []
retained_indexes: list[object] = []
drop_counts: Counter[str] = Counter()

for row_index, record in zip(data.index.to_list(), data.to_dict(orient="records"), strict=True):
prepared_record = deserialize_json_values(record)
try:
rendered_value = self.render_template(prepared_record)
except EmptyTemplateRenderError:
drop_counts[EMPTY_RENDERED_EXPRESSION] += 1
continue
except Exception:
logger.debug(
"Expression column %r dropped row %r after template render failure.",
self.config.name,
row_index,
exc_info=True,
)
drop_counts[TEMPLATE_RENDER_ERROR] += 1
continue

if self._is_empty_rendered_expression(rendered_value):
drop_counts[EMPTY_RENDERED_EXPRESSION] += 1
continue

try:
record[self.config.name] = self._cast_type(rendered_value)
except (OverflowError, TypeError, ValueError):
drop_counts[TYPE_CAST_ERROR] += 1
continue

records.append(record)
retained_indexes.append(row_index)

total_dropped = sum(drop_counts.values())
if total_dropped > 0:
self._log_row_drops(drop_counts, input_count=len(data), retained_count=len(records))
if len(records) == 0:
raise UserTemplateError(f"Expression column {self.config.name!r} produced no valid rows.")

return lazy.pd.DataFrame(records)
return lazy.pd.DataFrame(records, index=retained_indexes)

@staticmethod
def _is_empty_rendered_expression(value: object) -> bool:
if value is None:
return True
return isinstance(value, str) and len(value.strip()) == 0

def _log_row_drops(self, drop_counts: Counter[str], *, input_count: int, retained_count: int) -> None:
breakdown = ", ".join(f"{name}={count}" for name, count in sorted(drop_counts.items()))
total_dropped = sum(drop_counts.values())
message = (
f"Expression column {self.config.name!r} dropped {total_dropped}/{input_count} rows after render: "
f"{breakdown}."
)
if retained_count == 0:
logger.error(message)
else:
logger.warning(f"{message} Continuing with {retained_count} rows.")

def _cast_type(self, value: str) -> str | float | int | bool:
if self.config.dtype == "str":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import TYPE_CHECKING, Any, Callable

import data_designer.lazy_heavy_imports as lazy
from data_designer.config.column_configs import GenerationStrategy
from data_designer.config.column_configs import ExpressionColumnConfig, GenerationStrategy
from data_designer.engine.capacity import (
AsyncCapacityConfigured,
AsyncCapacityObservedMaxima,
Expand Down Expand Up @@ -80,6 +80,7 @@
SchedulerAdmissionEventSink,
runtime_correlation_provider,
)
from data_designer.engine.processing.ginja.exceptions import UserTemplateError

if TYPE_CHECKING:
from data_designer.engine.column_generators.generators.base import ColumnGenerator
Expand Down Expand Up @@ -402,6 +403,8 @@ def first_non_retryable_error(self) -> Exception | None:
def _raise_if_fatal_worker_error(self) -> None:
if self._fatal_worker_error is None:
return
if isinstance(self._fatal_worker_error, UserTemplateError):
raise DatasetGenerationError(str(self._fatal_worker_error)) from self._fatal_worker_error
raise DatasetGenerationError(
"Unexpected internal task failure in async scheduler."
) from self._fatal_worker_error
Expand Down Expand Up @@ -1696,6 +1699,19 @@ async def _execute_task_inner_impl(self, task: Task, lease: TaskAdmissionLease,
trace.status = "error"
trace.error = str(exc)

if self._is_fatal_expression_template_error(task, generator, exc):
logger.error(
"Fatal expression failure on %s[rg=%s, row=%s]: %s",
task.column,
task.row_group,
task.row_index,
exc,
exc_info=True,
)
self._fatal_worker_error = exc
self._wake_event.set()
raise

if retryable:
self._deferred.append(task)
self._deferred_errors[task] = exc
Expand Down Expand Up @@ -1783,6 +1799,9 @@ async def _run_generator_call(self, task: Task, operation: str, call: Coroutine[
try:
return await call
except Exception as exc:
generator = self._generators[task.column]
if self._is_fatal_expression_template_error(task, generator, exc):
raise
if self._is_retryable(exc) or self._is_expected_non_retryable(exc):
raise
raise DatasetGenerationError(
Expand All @@ -1808,6 +1827,17 @@ def _require_dataframe_result(
)
return result

def _task_supports_row_drops(self, task: Task, generator: ColumnGenerator) -> bool:
return task.task_type == "batch" and isinstance(generator.config, ExpressionColumnConfig)

def _is_fatal_expression_template_error(
self,
task: Task,
generator: ColumnGenerator,
exc: BaseException,
) -> bool:
return self._task_supports_row_drops(task, generator) and isinstance(exc, UserTemplateError)

async def _run_from_scratch(self, task: Task, generator: ColumnGenerator) -> Any:
"""Execute a from_scratch task."""
rg_size = self._get_rg_size(task.row_group)
Expand Down Expand Up @@ -1920,6 +1950,7 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any:
if self._buffer_manager is not None:
pre_dropped: set[int] = {ri for ri in range(rg_size) if self._buffer_manager.is_dropped(task.row_group, ri)}
active_rows_data: list[dict] = []
active_row_indices: list[int] = []

# Skip evaluation only applies to single-column configs.
# Multi-column configs (sampler/seed) are rejected by the SkipConfig
Expand All @@ -1937,16 +1968,18 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any:
continue

active_rows_data.append(record)
active_row_indices.append(ri)

batch_df = (
lazy.pd.DataFrame(strip_skip_metadata_from_records(active_rows_data))
lazy.pd.DataFrame(strip_skip_metadata_from_records(active_rows_data), index=active_row_indices)
if active_rows_data
else lazy.pd.DataFrame()
)
else:
batch_df = lazy.pd.DataFrame()
pre_dropped = set()
pre_skipped = set()
active_row_indices = []

if len(batch_df) == 0:
return batch_df
Expand All @@ -1957,28 +1990,81 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any:
"batch generation",
generator.agenerate(batch_df),
)
result_df = self._require_dataframe_result(
task,
"Batch generator",
result_df,
expected_rows=active_rows,
)
if self._task_supports_row_drops(task, generator):
result_df = self._require_expression_row_drop_result(
task,
result_df,
active_row_indices=active_row_indices,
)
else:
result_df = self._require_dataframe_result(
task,
"Batch generator",
result_df,
expected_rows=active_rows,
)

# Merge result columns back to buffer (include side-effect columns)
if self._buffer_manager is not None:
write_cols = self._gen_instance_to_columns_including_side_effects.get(id(generator), [task.column])
result_idx = 0
result_by_row_index = self._batch_result_by_row_index(
result_df,
active_row_indices=active_row_indices,
supports_row_drops=self._task_supports_row_drops(task, generator),
)
for ri in range(rg_size):
if ri in pre_dropped or ri in pre_skipped:
continue
result_row = result_by_row_index.get(ri)
if result_row is None:
self._drop_row(task.row_group, ri, exclude_columns={task.column})
continue
if not self._buffer_manager.is_dropped(task.row_group, ri):
for col in write_cols:
if col in result_df.columns:
self._buffer_manager.update_cell(task.row_group, ri, col, result_df.iloc[result_idx][col])
result_idx += 1
if col in result_row:
self._buffer_manager.update_cell(task.row_group, ri, col, result_row[col])

return result_df

def _require_expression_row_drop_result(
self,
task: Task,
result: Any,
*,
active_row_indices: list[int],
) -> Any:
result_df = self._require_dataframe_result(task, "Batch generator", result)
result_indexes = result_df.index.to_list()
if len(result_indexes) != len(set(result_indexes)):
raise DatasetGenerationError(
f"Batch generator for column '{task.column}' returned duplicate row indexes (rg={task.row_group})."
)
active_index_set = set(active_row_indices)
unexpected_indexes = set(result_indexes) - active_index_set
if unexpected_indexes:
raise DatasetGenerationError(
f"Batch generator for column '{task.column}' returned row indexes outside the active input rows "
f"(rg={task.row_group}): {sorted(unexpected_indexes)!r}."
)
if len(result_df) > len(active_row_indices):
raise DatasetGenerationError(
f"Batch generator for column '{task.column}' returned {len(result_df)} rows "
f"but at most {len(active_row_indices)} active rows were expected (rg={task.row_group})."
)
return result_df

@staticmethod
def _batch_result_by_row_index(
result_df: Any,
*,
active_row_indices: list[int],
supports_row_drops: bool,
) -> dict[int, dict[str, Any]]:
result_records = result_df.to_dict(orient="records")
if supports_row_drops:
return dict(zip(result_df.index.to_list(), result_records, strict=True))
return dict(zip(active_row_indices, result_records, strict=True))

def _get_rg_size(self, row_group: int) -> int:
try:
return self._row_groups.row_group_size(row_group)
Comment on lines +2063 to 2070

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Third guard in _require_expression_row_drop_result is unreachable

Given that the two preceding checks already pass β€” (1) no duplicate result indexes, and (2) every result index is a member of active_index_set β€” the result set is a subset of active_row_indices with no duplicates. Because active_row_indices is itself duplicate-free (built from range(rg_size) minus pre-dropped rows), len(active_index_set) == len(active_row_indices), so len(result_df) <= len(active_row_indices) is guaranteed and the third if can never fire. This is dead code, not a runtime bug, but it may create a false sense of coverage for this guard.

Prompt To Fix With AI
This is a comment left during a code review.
Path: packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py
Line: 2063-2070

Comment:
**Third guard in `_require_expression_row_drop_result` is unreachable**

Given that the two preceding checks already pass β€” (1) no duplicate result indexes, and (2) every result index is a member of `active_index_set` β€” the result set is a subset of `active_row_indices` with no duplicates. Because `active_row_indices` is itself duplicate-free (built from `range(rg_size)` minus pre-dropped rows), `len(active_index_set) == len(active_row_indices)`, so `len(result_df) <= len(active_row_indices)` is guaranteed and the third `if` can never fire. This is dead code, not a runtime bug, but it may create a false sense of coverage for this guard.

How can I resolve this? If you propose a fix, please make it concise.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Expand Down
Loading
Loading