Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1654,10 +1654,16 @@ def _format_worker_failure_warning(cls, exc: Exception, *, context: dict | None
context_label = f" in column {column_name!r}" if column_name else ""
failure_kind = cls._classify_worker_failure(exc)
failure_detail = cls._extract_failure_detail(exc)
return (
warning = (
f"⚠️ Generation for record at index {record_index} failed{context_label} ({failure_kind}). "
f"Will omit this record from the dataset. Detail: {failure_detail}"
)
if getattr(exc, "truncated_by_max_tokens", False):
warning += (
" The model response appears to have been cut off by max_tokens, which caused the parse/recipe "
"failure. Increase inference_parameters.max_tokens in the model config."
)
return warning

def _worker_error_callback(self, exc: Exception, *, context: dict | None = None) -> None:
"""If a worker fails, we can handle the exception here."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from data_designer.engine.models.clients.parsing import extract_usage, fill_reasoning_token_count_from_content
from data_designer.engine.models.clients.types import (
AssistantMessage,
ChatCompletionChoice,
ChatCompletionRequest,
ChatCompletionResponse,
ToolCall,
Expand Down Expand Up @@ -124,7 +125,14 @@ def parse_anthropic_response(response_json: dict[str, Any]) -> ChatCompletionRes
usage = extract_usage(raw_usage)
usage = fill_reasoning_token_count_from_content(usage, message.reasoning_content)

return ChatCompletionResponse(message=message, usage=usage, raw=response_json)
stop_reason = response_json.get("stop_reason")
finish_reason = stop_reason if isinstance(stop_reason, str) else None
return ChatCompletionResponse(
message=message,
usage=usage,
raw=response_json,
choices=[ChatCompletionChoice(message=message, finish_reason=finish_reason)],
)


def translate_request_messages(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,20 @@ class GenerationValidationFailureError(Exception):
summary: str
detail: str | None
failure_kind: str
truncated_by_max_tokens: bool

def __init__(
self,
summary: str,
*,
detail: str | None = None,
failure_kind: str = "validation_error",
truncated_by_max_tokens: bool = False,
) -> None:
self.summary = summary.strip()
self.detail = _normalize_error_detail(detail)
self.failure_kind = failure_kind
self.truncated_by_max_tokens = truncated_by_max_tokens

message = self.summary
if self.detail is not None:
Expand Down Expand Up @@ -115,20 +118,23 @@ class ModelStructuredOutputError(DataDesignerError): ...
class ModelGenerationValidationFailureError(DataDesignerError):
detail: str | None
failure_kind: str | None
truncated_by_max_tokens: bool

def __init__(
self,
message: object | None = None,
*,
detail: str | None = None,
failure_kind: str | None = None,
truncated_by_max_tokens: bool = False,
) -> None:
if message is None:
super().__init__()
else:
super().__init__(message)
self.detail = _normalize_error_detail(detail)
self.failure_kind = failure_kind
self.truncated_by_max_tokens = truncated_by_max_tokens


class ImageGenerationError(DataDesignerError): ...
Expand Down Expand Up @@ -216,16 +222,35 @@ def handle_llm_exceptions(
case GenerationValidationFailureError():
detail_text = exception.detail.rstrip(".") if exception.detail is not None else None
validation_detail = f" Validation detail: {detail_text}." if detail_text is not None else ""
if exception.truncated_by_max_tokens:
cause = (
f"The model output from {model_name!r} could not be parsed into the requested format "
f"while {purpose} because the response appears to have been cut off by max_tokens."
f"{validation_detail}"
)
solution = (
"Increase inference_parameters.max_tokens in the model config and try again. "
"If the failure persists, simplify or modify the output schema."
)
else:
cause = (
f"The model output from {model_name!r} could not be parsed into the requested format "
f"while {purpose}.{validation_detail}"
)
solution = (
"This is most likely temporary as we make additional attempts. If you continue to see more of "
"this, simplify or modify the output schema for structured output and try again. If you are "
"attempting token-intensive tasks like generations with high-reasoning effort, ensure that "
"max_tokens in the model config is high enough to reach completion."
)
raise ModelGenerationValidationFailureError(
FormattedLLMErrorMessage(
cause=(
f"The model output from {model_name!r} could not be parsed into the requested format "
f"while {purpose}.{validation_detail}"
),
solution="This is most likely temporary as we make additional attempts. If you continue to see more of this, simplify or modify the output schema for structured output and try again. If you are attempting token-intensive tasks like generations with high-reasoning effort, ensure that max_tokens in the model config is high enough to reach completion.",
cause=cause,
solution=solution,
),
detail=exception.detail,
failure_kind=exception.failure_kind,
truncated_by_max_tokens=exception.truncated_by_max_tokens,
) from None

case DataDesignerError():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,27 @@ def _identity(x: Any) -> Any:
logger = logging.getLogger(__name__)


_MAX_TOKENS_FINISH_REASONS = frozenset({"length", "max_tokens"})


def _is_max_tokens_finish_reason(value: Any) -> bool:
return isinstance(value, str) and value.strip().lower() in _MAX_TOKENS_FINISH_REASONS


def _get_response_value(source: Any, key: str) -> Any:
if source is None:
return None
if isinstance(source, dict):
return source.get(key)
return getattr(source, key, None)


def _get_first_response_item(values: Any) -> Any | None:
if isinstance(values, list) and values:
return values[0]
return None


def _classify_generation_failure_kind(exc: ParserException) -> str:
detail = " ".join(str(get_exception_primary_cause(exc)).split()).lower()
if "response_schema" in detail or "model_validate" in detail:
Expand All @@ -67,11 +88,34 @@ def _classify_generation_failure_kind(exc: ParserException) -> str:
return "parse_error"


def _build_generation_validation_error(summary: str, exc: ParserException) -> GenerationValidationFailureError:
def _response_was_truncated_by_max_tokens(completion_response: ChatCompletionResponse) -> bool:
if any(_is_max_tokens_finish_reason(choice.finish_reason) for choice in completion_response.choices):
return True

raw_response = completion_response.raw
if raw_response is None:
return False

first_choice = _get_first_response_item(_get_response_value(raw_response, "choices"))
finish_reason = _get_response_value(first_choice, "finish_reason")
if _is_max_tokens_finish_reason(finish_reason):
return True

stop_reason = _get_response_value(raw_response, "stop_reason")
return _is_max_tokens_finish_reason(stop_reason)


def _build_generation_validation_error(
summary: str,
exc: ParserException,
*,
truncated_by_max_tokens: bool = False,
) -> GenerationValidationFailureError:
return GenerationValidationFailureError(
summary,
detail=str(get_exception_primary_cause(exc)),
failure_kind=_classify_generation_failure_kind(exc),
truncated_by_max_tokens=truncated_by_max_tokens,
)


Expand Down Expand Up @@ -400,10 +444,12 @@ def generate(
output_obj = parser(response) # type: ignore - if not a string will cause a ParserException below
break
except ParserException as exc:
truncated_by_max_tokens = _response_was_truncated_by_max_tokens(completion_response)
if max_correction_steps == 0 and max_conversation_restarts == 0:
raise _build_generation_validation_error(
"Unsuccessful generation attempt. No retries were attempted.",
exc,
truncated_by_max_tokens=truncated_by_max_tokens,
) from exc

if parse_attempts <= max_correction_steps:
Expand All @@ -423,6 +469,7 @@ def generate(
f"and {max_conversation_restarts} conversation restarts."
),
exc,
truncated_by_max_tokens=truncated_by_max_tokens,
) from exc

if not skip_usage_tracking and mcp_facade is not None:
Expand Down Expand Up @@ -505,10 +552,12 @@ async def agenerate(
output_obj = parser(response)
break
except ParserException as exc:
truncated_by_max_tokens = _response_was_truncated_by_max_tokens(completion_response)
if max_correction_steps == 0 and max_conversation_restarts == 0:
raise _build_generation_validation_error(
"Unsuccessful generation attempt. No retries were attempted.",
exc,
truncated_by_max_tokens=truncated_by_max_tokens,
) from exc

if parse_attempts <= max_correction_steps:
Expand All @@ -527,6 +576,7 @@ async def agenerate(
f"and {max_conversation_restarts} conversation restarts."
),
exc,
truncated_by_max_tokens=truncated_by_max_tokens,
) from exc

if not skip_usage_tracking and mcp_facade is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager
from data_designer.engine.models.errors import (
RETRYABLE_MODEL_ERRORS,
FormattedLLMErrorMessage,
ModelGenerationValidationFailureError,
ModelInternalServerError,
ModelRateLimitError,
ModelRequestAdmissionTimeoutError,
Expand Down Expand Up @@ -199,6 +201,29 @@ def generate(self, data: dict) -> dict:
return data


class MockTruncatedParseFailureGenerator(ColumnGenerator[ExpressionColumnConfig]):
"""Generator that simulates a parser failure caused by max_tokens truncation."""

@staticmethod
def get_generation_strategy() -> GenerationStrategy:
return GenerationStrategy.CELL_BY_CELL

def generate(self, _data: dict) -> dict:
raise ModelGenerationValidationFailureError(
FormattedLLMErrorMessage(
cause=(
"The model output from 'test-model' could not be parsed into the requested format while "
"running generation for column 'fail_col' because the response appears to have been cut off "
"by max_tokens. Validation detail: Unterminated JSON object."
),
solution="Increase inference_parameters.max_tokens in the model config and try again.",
),
detail="Unterminated JSON object.",
failure_kind="parse_error",
truncated_by_max_tokens=True,
)


class MockBuggyGenerator(ColumnGenerator[ExpressionColumnConfig]):
"""Generator that raises a bare built-in exception from generator code."""

Expand Down Expand Up @@ -736,6 +761,41 @@ async def test_scheduler_non_retryable_failure_drops_row() -> None:
assert tracker.is_row_group_complete(0, 2, ["seed", "fail_col"])


@pytest.mark.asyncio(loop_scope="session")
async def test_scheduler_logs_max_tokens_truncation_guidance(caplog: pytest.LogCaptureFixture) -> None:
provider = _mock_provider()
configs = [
SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
LLMTextColumnConfig(name="fail_col", prompt="{{ seed }}", model_alias=MODEL_ALIAS),
]
strategies = {
"seed": GenerationStrategy.FULL_COLUMN,
"fail_col": GenerationStrategy.CELL_BY_CELL,
}
generators = {
"seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider),
"fail_col": MockTruncatedParseFailureGenerator(config=_expr_config("fail_col"), resource_provider=provider),
}

graph = ExecutionGraph.create(configs, strategies)
row_groups = [(0, 1)]
tracker = CompletionTracker.with_graph(graph, row_groups)

scheduler = AsyncTaskScheduler(
generators=generators,
graph=graph,
tracker=tracker,
row_groups=row_groups,
)
with caplog.at_level(logging.WARNING):
await scheduler.run()

assert tracker.is_dropped(0, 0)
assert "Non-retryable failure on fail_col[rg=0, row=0]" in caplog.text
assert "cut off by max_tokens" in caplog.text
assert "Increase inference_parameters.max_tokens in the model config" in caplog.text


def test_scheduler_internal_bug_classifier_preserves_scheduler_builtin_failures() -> None:
scheduler, tracker = _build_simple_pipeline(num_records=1)
assert scheduler._is_internal_bug(KeyError("missing internal key"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,36 @@ def test_worker_error_callback_logs_timeout_detail(
assert 17 in stub_dataset_builder._records_to_drop


def test_worker_error_callback_logs_max_tokens_truncation_guidance(
stub_dataset_builder: DatasetBuilder,
caplog: pytest.LogCaptureFixture,
) -> None:
exc = ModelGenerationValidationFailureError(
FormattedLLMErrorMessage(
cause=(
"The model output from 'test-model' could not be parsed into the requested format while "
"running generation for column 'test_column' because the response appears to have been cut off "
"by max_tokens. Validation detail: Unterminated JSON object."
),
solution="Increase inference_parameters.max_tokens in the model config and try again.",
),
detail="Unterminated JSON object.",
failure_kind="parse_error",
truncated_by_max_tokens=True,
)

with caplog.at_level(logging.WARNING):
stub_dataset_builder._worker_error_callback(exc, context={"index": 33, "column_name": "test_column"})

assert "record at index 33" in caplog.text
assert "column 'test_column'" in caplog.text
assert "(parse error)" in caplog.text
assert "Unterminated JSON object." in caplog.text
assert "cut off by max_tokens" in caplog.text
assert "Increase inference_parameters.max_tokens in the model config." in caplog.text
assert 33 in stub_dataset_builder._records_to_drop


def test_worker_error_callback_requires_context_index(
stub_dataset_builder: DatasetBuilder,
caplog: pytest.LogCaptureFixture,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def _make_client(
# --- Response helpers ---


def _text_response(text: str = "Hello!") -> dict[str, Any]:
def _text_response(text: str = "Hello!", stop_reason: str = "end_turn") -> dict[str, Any]:
return {
"content": [{"type": "text", "text": text}],
"usage": {"input_tokens": 10, "output_tokens": 5},
"stop_reason": "end_turn",
"stop_reason": stop_reason,
}


Expand Down Expand Up @@ -92,18 +92,29 @@ def test_completion_maps_text_content() -> None:
result = client.completion(request)

assert result.message.content == "Hello from Claude!"
assert result.choices[0].finish_reason == "end_turn"
assert result.usage is not None
assert result.usage.input_tokens == 10
assert result.usage.output_tokens == 5


def test_completion_maps_max_tokens_stop_reason() -> None:
client = _make_client(sync_client=make_mock_sync_client(_text_response(stop_reason="max_tokens")))

request = ChatCompletionRequest(model=MODEL, messages=[{"role": "user", "content": "Hi"}])
result = client.completion(request)

assert result.choices[0].finish_reason == "max_tokens"


def test_completion_maps_tool_use_blocks() -> None:
client = _make_client(sync_client=make_mock_sync_client(_tool_use_response()))

request = ChatCompletionRequest(model=MODEL, messages=[{"role": "user", "content": "Weather?"}])
result = client.completion(request)

assert result.message.content == "Let me search for that."
assert result.choices[0].finish_reason == "tool_use"
assert len(result.message.tool_calls) == 1
assert result.message.tool_calls[0].id == "toolu_01"
assert result.message.tool_calls[0].name == "search"
Expand Down
Loading
Loading