From 8e405953c51ce61ee316be7a279a8357d182a5d3 Mon Sep 17 00:00:00 2001 From: stepwise-ai-dev Date: Tue, 31 Mar 2026 13:27:49 -0400 Subject: [PATCH 1/2] fix: improve truncation-aware parse failure logging --- .../dataset_builders/dataset_builder.py | 8 +- .../src/data_designer/engine/models/errors.py | 35 ++++++- .../src/data_designer/engine/models/facade.py | 42 +++++++- .../dataset_builders/test_dataset_builder.py | 30 ++++++ .../tests/engine/models/test_facade.py | 97 ++++++++++++++++++- .../tests/engine/models/test_model_errors.py | 22 +++++ 6 files changed, 225 insertions(+), 9 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index 1bbd51df7..adde49760 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -1654,10 +1654,16 @@ def _format_worker_failure_warning(cls, exc: Exception, *, context: dict | None context_label = f" in column {column_name!r}" if column_name else "" failure_kind = cls._classify_worker_failure(exc) failure_detail = cls._extract_failure_detail(exc) - return ( + warning = ( f"⚠️ Generation for record at index {record_index} failed{context_label} ({failure_kind}). " f"Will omit this record from the dataset. Detail: {failure_detail}" ) + if getattr(exc, "truncated_by_max_tokens", False): + warning += ( + " The model response appears to have been cut off by max_tokens, which caused the parse/recipe " + "failure. Increase inference_parameters.max_tokens in the model config." + ) + return warning def _worker_error_callback(self, exc: Exception, *, context: dict | None = None) -> None: """If a worker fails, we can handle the exception here.""" diff --git a/packages/data-designer-engine/src/data_designer/engine/models/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/errors.py index c289d5c62..3c6570656 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/errors.py @@ -45,6 +45,7 @@ class GenerationValidationFailureError(Exception): summary: str detail: str | None failure_kind: str + truncated_by_max_tokens: bool def __init__( self, @@ -52,10 +53,12 @@ def __init__( *, detail: str | None = None, failure_kind: str = "validation_error", + truncated_by_max_tokens: bool = False, ) -> None: self.summary = summary.strip() self.detail = _normalize_error_detail(detail) self.failure_kind = failure_kind + self.truncated_by_max_tokens = truncated_by_max_tokens message = self.summary if self.detail is not None: @@ -115,6 +118,7 @@ class ModelStructuredOutputError(DataDesignerError): ... class ModelGenerationValidationFailureError(DataDesignerError): detail: str | None failure_kind: str | None + truncated_by_max_tokens: bool def __init__( self, @@ -122,6 +126,7 @@ def __init__( *, detail: str | None = None, failure_kind: str | None = None, + truncated_by_max_tokens: bool = False, ) -> None: if message is None: super().__init__() @@ -129,6 +134,7 @@ def __init__( super().__init__(message) self.detail = _normalize_error_detail(detail) self.failure_kind = failure_kind + self.truncated_by_max_tokens = truncated_by_max_tokens class ImageGenerationError(DataDesignerError): ... @@ -216,16 +222,35 @@ def handle_llm_exceptions( case GenerationValidationFailureError(): detail_text = exception.detail.rstrip(".") if exception.detail is not None else None validation_detail = f" Validation detail: {detail_text}." if detail_text is not None else "" + if exception.truncated_by_max_tokens: + cause = ( + f"The model output from {model_name!r} could not be parsed into the requested format " + f"while {purpose} because the response appears to have been cut off by max_tokens." + f"{validation_detail}" + ) + solution = ( + "Increase inference_parameters.max_tokens in the model config and try again. " + "If the failure persists, simplify or modify the output schema." + ) + else: + cause = ( + f"The model output from {model_name!r} could not be parsed into the requested format " + f"while {purpose}.{validation_detail}" + ) + solution = ( + "This is most likely temporary as we make additional attempts. If you continue to see more of " + "this, simplify or modify the output schema for structured output and try again. If you are " + "attempting token-intensive tasks like generations with high-reasoning effort, ensure that " + "max_tokens in the model config is high enough to reach completion." + ) raise ModelGenerationValidationFailureError( FormattedLLMErrorMessage( - cause=( - f"The model output from {model_name!r} could not be parsed into the requested format " - f"while {purpose}.{validation_detail}" - ), - solution="This is most likely temporary as we make additional attempts. If you continue to see more of this, simplify or modify the output schema for structured output and try again. If you are attempting token-intensive tasks like generations with high-reasoning effort, ensure that max_tokens in the model config is high enough to reach completion.", + cause=cause, + solution=solution, ), detail=exception.detail, failure_kind=exception.failure_kind, + truncated_by_max_tokens=exception.truncated_by_max_tokens, ) from None case DataDesignerError(): diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 2c0a7a9ab..e766b6496 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -58,6 +58,20 @@ def _identity(x: Any) -> Any: logger = logging.getLogger(__name__) +def _get_response_value(source: Any, key: str) -> Any: + if source is None: + return None + if isinstance(source, dict): + return source.get(key) + return getattr(source, key, None) + + +def _get_first_response_item(values: Any) -> Any | None: + if isinstance(values, list) and values: + return values[0] + return None + + def _classify_generation_failure_kind(exc: ParserException) -> str: detail = " ".join(str(get_exception_primary_cause(exc)).split()).lower() if "response_schema" in detail or "model_validate" in detail: @@ -67,11 +81,31 @@ def _classify_generation_failure_kind(exc: ParserException) -> str: return "parse_error" -def _build_generation_validation_error(summary: str, exc: ParserException) -> GenerationValidationFailureError: +def _response_was_truncated_by_max_tokens(completion_response: ChatCompletionResponse) -> bool: + raw_response = completion_response.raw + if raw_response is None: + return False + + first_choice = _get_first_response_item(_get_response_value(raw_response, "choices")) + finish_reason = _get_response_value(first_choice, "finish_reason") + if isinstance(finish_reason, str) and finish_reason.strip().lower() == "length": + return True + + stop_reason = _get_response_value(raw_response, "stop_reason") + return isinstance(stop_reason, str) and stop_reason.strip().lower() == "max_tokens" + + +def _build_generation_validation_error( + summary: str, + exc: ParserException, + *, + truncated_by_max_tokens: bool = False, +) -> GenerationValidationFailureError: return GenerationValidationFailureError( summary, detail=str(get_exception_primary_cause(exc)), failure_kind=_classify_generation_failure_kind(exc), + truncated_by_max_tokens=truncated_by_max_tokens, ) @@ -400,10 +434,12 @@ def generate( output_obj = parser(response) # type: ignore - if not a string will cause a ParserException below break except ParserException as exc: + truncated_by_max_tokens = _response_was_truncated_by_max_tokens(completion_response) if max_correction_steps == 0 and max_conversation_restarts == 0: raise _build_generation_validation_error( "Unsuccessful generation attempt. No retries were attempted.", exc, + truncated_by_max_tokens=truncated_by_max_tokens, ) from exc if parse_attempts <= max_correction_steps: @@ -423,6 +459,7 @@ def generate( f"and {max_conversation_restarts} conversation restarts." ), exc, + truncated_by_max_tokens=truncated_by_max_tokens, ) from exc if not skip_usage_tracking and mcp_facade is not None: @@ -505,10 +542,12 @@ async def agenerate( output_obj = parser(response) break except ParserException as exc: + truncated_by_max_tokens = _response_was_truncated_by_max_tokens(completion_response) if max_correction_steps == 0 and max_conversation_restarts == 0: raise _build_generation_validation_error( "Unsuccessful generation attempt. No retries were attempted.", exc, + truncated_by_max_tokens=truncated_by_max_tokens, ) from exc if parse_attempts <= max_correction_steps: @@ -527,6 +566,7 @@ async def agenerate( f"and {max_conversation_restarts} conversation restarts." ), exc, + truncated_by_max_tokens=truncated_by_max_tokens, ) from exc if not skip_usage_tracking and mcp_facade is not None: diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py index 0a0f192b4..8d4fb17b1 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py @@ -236,6 +236,36 @@ def test_worker_error_callback_logs_timeout_detail( assert 17 in stub_dataset_builder._records_to_drop +def test_worker_error_callback_logs_max_tokens_truncation_guidance( + stub_dataset_builder: DatasetBuilder, + caplog: pytest.LogCaptureFixture, +) -> None: + exc = ModelGenerationValidationFailureError( + FormattedLLMErrorMessage( + cause=( + "The model output from 'test-model' could not be parsed into the requested format while " + "running generation for column 'test_column' because the response appears to have been cut off " + "by max_tokens. Validation detail: Unterminated JSON object." + ), + solution="Increase inference_parameters.max_tokens in the model config and try again.", + ), + detail="Unterminated JSON object.", + failure_kind="parse_error", + truncated_by_max_tokens=True, + ) + + with caplog.at_level(logging.WARNING): + stub_dataset_builder._worker_error_callback(exc, context={"index": 33, "column_name": "test_column"}) + + assert "record at index 33" in caplog.text + assert "column 'test_column'" in caplog.text + assert "(parse error)" in caplog.text + assert "Unterminated JSON object." in caplog.text + assert "cut off by max_tokens" in caplog.text + assert "Increase inference_parameters.max_tokens in the model config." in caplog.text + assert 33 in stub_dataset_builder._records_to_drop + + def test_worker_error_callback_requires_context_index( stub_dataset_builder: DatasetBuilder, caplog: pytest.LogCaptureFixture, diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index 0be33bd02..35c89a13e 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -32,9 +32,11 @@ from data_designer.engine.testing import StubMCPFacade, StubMCPRegistry, make_stub_completion_response -def _make_response(content: str | None = None, **kwargs: Any) -> ChatCompletionResponse: +def _make_response(content: str | None = None, raw: Any | None = None, **kwargs: Any) -> ChatCompletionResponse: """Shorthand for creating a ChatCompletionResponse in tests.""" - return make_stub_completion_response(content=content, **kwargs) + response = make_stub_completion_response(content=content, **kwargs) + response.raw = raw + return response def _assert_no_multi_choice_request( @@ -260,6 +262,97 @@ def _failing_parser(response: str) -> str: assert exc_info.value.failure_kind == "schema_validation" +@pytest.mark.parametrize( + ("raw_response", "expected_truncated"), + [ + ({"choices": [{"finish_reason": "length"}]}, True), + ({"stop_reason": "max_tokens"}, True), + ({"choices": [{"finish_reason": "stop"}]}, False), + ({"stop_reason": "end_turn"}, False), + (None, False), + ], + ids=[ + "openai_length", + "anthropic_max_tokens", + "openai_stop", + "anthropic_end_turn", + "missing_raw", + ], +) +@patch.object(ModelFacade, "completion", autospec=True) +def test_generate_sets_truncation_metadata_on_parser_failure( + mock_completion: Any, + stub_model_facade: ModelFacade, + raw_response: dict[str, Any] | None, + expected_truncated: bool, +) -> None: + mock_completion.return_value = _make_response("bad response", raw=raw_response) + + def _failing_parser(response: str) -> str: + raise ParserException("Response doesn't match requested \n'name' is a required property") + + with pytest.raises(ModelGenerationValidationFailureError) as exc_info: + stub_model_facade.generate( + prompt="foo", + parser=_failing_parser, + max_correction_steps=0, + max_conversation_restarts=0, + ) + + assert exc_info.value.truncated_by_max_tokens is expected_truncated + if expected_truncated: + assert "cut off by max_tokens" in str(exc_info.value) + assert "Increase inference_parameters.max_tokens in the model config" in str(exc_info.value) + else: + assert "cut off by max_tokens" not in str(exc_info.value) + + +@pytest.mark.parametrize( + ("raw_response", "expected_truncated"), + [ + ({"choices": [{"finish_reason": "length"}]}, True), + ({"stop_reason": "max_tokens"}, True), + ({"choices": [{"finish_reason": "stop"}]}, False), + ({"stop_reason": "end_turn"}, False), + (None, False), + ], + ids=[ + "openai_length", + "anthropic_max_tokens", + "openai_stop", + "anthropic_end_turn", + "missing_raw", + ], +) +@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_agenerate_sets_truncation_metadata_on_parser_failure( + mock_acompletion: AsyncMock, + stub_model_facade: ModelFacade, + raw_response: dict[str, Any] | None, + expected_truncated: bool, +) -> None: + mock_acompletion.return_value = _make_response("bad response", raw=raw_response) + + def _failing_parser(response: str) -> str: + raise ParserException("Response doesn't match requested \n'name' is a required property") + + with pytest.raises(ModelGenerationValidationFailureError) as exc_info: + await stub_model_facade.agenerate( + prompt="foo", + parser=_failing_parser, + max_correction_steps=0, + max_conversation_restarts=0, + ) + + assert exc_info.value.truncated_by_max_tokens is expected_truncated + if expected_truncated: + assert "cut off by max_tokens" in str(exc_info.value) + assert "Increase inference_parameters.max_tokens in the model config" in str(exc_info.value) + else: + assert "cut off by max_tokens" not in str(exc_info.value) + + @pytest.mark.parametrize( "raw_content,expected", [ diff --git a/packages/data-designer-engine/tests/engine/models/test_model_errors.py b/packages/data-designer-engine/tests/engine/models/test_model_errors.py index 8873aceba..e0f3c9e35 100644 --- a/packages/data-designer-engine/tests/engine/models/test_model_errors.py +++ b/packages/data-designer-engine/tests/engine/models/test_model_errors.py @@ -246,6 +246,28 @@ def test_handle_llm_exceptions_preserves_generation_failure_kind() -> None: assert exc_info.value.failure_kind == "schema_validation" assert exc_info.value.detail == "Response doesn't match requested : 'name' is a required property" + assert exc_info.value.truncated_by_max_tokens is False + + +def test_handle_llm_exceptions_emits_truncation_specific_generation_failure_message() -> None: + with pytest.raises(ModelGenerationValidationFailureError) as exc_info: + handle_llm_exceptions( + GenerationValidationFailureError( + "Generation validation failure", + detail="Response doesn't match requested : 'name' is a required property", + failure_kind="schema_validation", + truncated_by_max_tokens=True, + ), + stub_model_name, + stub_model_provider_name, + stub_purpose, + ) + + assert exc_info.value.failure_kind == "schema_validation" + assert exc_info.value.detail == "Response doesn't match requested : 'name' is a required property" + assert exc_info.value.truncated_by_max_tokens is True + assert "cut off by max_tokens" in str(exc_info.value) + assert "Increase inference_parameters.max_tokens in the model config" in str(exc_info.value) def test_handle_llm_exceptions_strips_duplicate_period_from_validation_detail() -> None: From 84119850d1ab7985f42ec910ef1c95bc94c60cb8 Mon Sep 17 00:00:00 2001 From: stepwise-ai-dev Date: Tue, 16 Jun 2026 17:27:03 -0400 Subject: [PATCH 2/2] fix: use finish reasons for truncation guidance Normalize Anthropic stop reasons into completion choices and prefer canonical finish_reason metadata when detecting max_tokens truncation. Add async scheduler coverage so dropped rows retain the actionable max_tokens guidance. --- .../clients/adapters/anthropic_translation.py | 10 +++- .../src/data_designer/engine/models/facade.py | 14 ++++- .../dataset_builders/test_async_scheduler.py | 60 +++++++++++++++++++ .../engine/models/clients/test_anthropic.py | 15 ++++- .../tests/engine/models/test_facade.py | 58 +++++++++++------- 5 files changed, 129 insertions(+), 28 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py index 5ba186add..20e142b4e 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py @@ -15,6 +15,7 @@ from data_designer.engine.models.clients.parsing import extract_usage, fill_reasoning_token_count_from_content from data_designer.engine.models.clients.types import ( AssistantMessage, + ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, ToolCall, @@ -124,7 +125,14 @@ def parse_anthropic_response(response_json: dict[str, Any]) -> ChatCompletionRes usage = extract_usage(raw_usage) usage = fill_reasoning_token_count_from_content(usage, message.reasoning_content) - return ChatCompletionResponse(message=message, usage=usage, raw=response_json) + stop_reason = response_json.get("stop_reason") + finish_reason = stop_reason if isinstance(stop_reason, str) else None + return ChatCompletionResponse( + message=message, + usage=usage, + raw=response_json, + choices=[ChatCompletionChoice(message=message, finish_reason=finish_reason)], + ) def translate_request_messages( diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index e766b6496..e6f7ca69e 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -58,6 +58,13 @@ def _identity(x: Any) -> Any: logger = logging.getLogger(__name__) +_MAX_TOKENS_FINISH_REASONS = frozenset({"length", "max_tokens"}) + + +def _is_max_tokens_finish_reason(value: Any) -> bool: + return isinstance(value, str) and value.strip().lower() in _MAX_TOKENS_FINISH_REASONS + + def _get_response_value(source: Any, key: str) -> Any: if source is None: return None @@ -82,17 +89,20 @@ def _classify_generation_failure_kind(exc: ParserException) -> str: def _response_was_truncated_by_max_tokens(completion_response: ChatCompletionResponse) -> bool: + if any(_is_max_tokens_finish_reason(choice.finish_reason) for choice in completion_response.choices): + return True + raw_response = completion_response.raw if raw_response is None: return False first_choice = _get_first_response_item(_get_response_value(raw_response, "choices")) finish_reason = _get_response_value(first_choice, "finish_reason") - if isinstance(finish_reason, str) and finish_reason.strip().lower() == "length": + if _is_max_tokens_finish_reason(finish_reason): return True stop_reason = _get_response_value(raw_response, "stop_reason") - return isinstance(stop_reason, str) and stop_reason.strip().lower() == "max_tokens" + return _is_max_tokens_finish_reason(stop_reason) def _build_generation_validation_error( diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index bb5cf5685..a3a523920 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -49,6 +49,8 @@ from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager from data_designer.engine.models.errors import ( RETRYABLE_MODEL_ERRORS, + FormattedLLMErrorMessage, + ModelGenerationValidationFailureError, ModelInternalServerError, ModelRateLimitError, ModelRequestAdmissionTimeoutError, @@ -199,6 +201,29 @@ def generate(self, data: dict) -> dict: return data +class MockTruncatedParseFailureGenerator(ColumnGenerator[ExpressionColumnConfig]): + """Generator that simulates a parser failure caused by max_tokens truncation.""" + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, _data: dict) -> dict: + raise ModelGenerationValidationFailureError( + FormattedLLMErrorMessage( + cause=( + "The model output from 'test-model' could not be parsed into the requested format while " + "running generation for column 'fail_col' because the response appears to have been cut off " + "by max_tokens. Validation detail: Unterminated JSON object." + ), + solution="Increase inference_parameters.max_tokens in the model config and try again.", + ), + detail="Unterminated JSON object.", + failure_kind="parse_error", + truncated_by_max_tokens=True, + ) + + class MockBuggyGenerator(ColumnGenerator[ExpressionColumnConfig]): """Generator that raises a bare built-in exception from generator code.""" @@ -736,6 +761,41 @@ async def test_scheduler_non_retryable_failure_drops_row() -> None: assert tracker.is_row_group_complete(0, 2, ["seed", "fail_col"]) +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_logs_max_tokens_truncation_guidance(caplog: pytest.LogCaptureFixture) -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="fail_col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "fail_col": GenerationStrategy.CELL_BY_CELL, + } + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "fail_col": MockTruncatedParseFailureGenerator(config=_expr_config("fail_col"), resource_provider=provider), + } + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 1)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + ) + with caplog.at_level(logging.WARNING): + await scheduler.run() + + assert tracker.is_dropped(0, 0) + assert "Non-retryable failure on fail_col[rg=0, row=0]" in caplog.text + assert "cut off by max_tokens" in caplog.text + assert "Increase inference_parameters.max_tokens in the model config" in caplog.text + + def test_scheduler_internal_bug_classifier_preserves_scheduler_builtin_failures() -> None: scheduler, tracker = _build_simple_pipeline(num_records=1) assert scheduler._is_internal_bug(KeyError("missing internal key")) diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py index b0c71cc5a..36847de99 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py @@ -47,11 +47,11 @@ def _make_client( # --- Response helpers --- -def _text_response(text: str = "Hello!") -> dict[str, Any]: +def _text_response(text: str = "Hello!", stop_reason: str = "end_turn") -> dict[str, Any]: return { "content": [{"type": "text", "text": text}], "usage": {"input_tokens": 10, "output_tokens": 5}, - "stop_reason": "end_turn", + "stop_reason": stop_reason, } @@ -92,11 +92,21 @@ def test_completion_maps_text_content() -> None: result = client.completion(request) assert result.message.content == "Hello from Claude!" + assert result.choices[0].finish_reason == "end_turn" assert result.usage is not None assert result.usage.input_tokens == 10 assert result.usage.output_tokens == 5 +def test_completion_maps_max_tokens_stop_reason() -> None: + client = _make_client(sync_client=make_mock_sync_client(_text_response(stop_reason="max_tokens"))) + + request = ChatCompletionRequest(model=MODEL, messages=[{"role": "user", "content": "Hi"}]) + result = client.completion(request) + + assert result.choices[0].finish_reason == "max_tokens" + + def test_completion_maps_tool_use_blocks() -> None: client = _make_client(sync_client=make_mock_sync_client(_tool_use_response())) @@ -104,6 +114,7 @@ def test_completion_maps_tool_use_blocks() -> None: result = client.completion(request) assert result.message.content == "Let me search for that." + assert result.choices[0].finish_reason == "tool_use" assert len(result.message.tool_calls) == 1 assert result.message.tool_calls[0].id == "toolu_01" assert result.message.tool_calls[0].name == "search" diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index 35c89a13e..31fdd8460 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -32,10 +32,16 @@ from data_designer.engine.testing import StubMCPFacade, StubMCPRegistry, make_stub_completion_response -def _make_response(content: str | None = None, raw: Any | None = None, **kwargs: Any) -> ChatCompletionResponse: +def _make_response( + content: str | None = None, + raw: Any | None = None, + finish_reason: str | None = None, + **kwargs: Any, +) -> ChatCompletionResponse: """Shorthand for creating a ChatCompletionResponse in tests.""" response = make_stub_completion_response(content=content, **kwargs) response.raw = raw + response.choices[0].finish_reason = finish_reason return response @@ -263,19 +269,21 @@ def _failing_parser(response: str) -> str: @pytest.mark.parametrize( - ("raw_response", "expected_truncated"), + ("finish_reason", "raw_response", "expected_truncated"), [ - ({"choices": [{"finish_reason": "length"}]}, True), - ({"stop_reason": "max_tokens"}, True), - ({"choices": [{"finish_reason": "stop"}]}, False), - ({"stop_reason": "end_turn"}, False), - (None, False), + ("length", None, True), + ("max_tokens", None, True), + ("stop", None, False), + (None, {"choices": [{"finish_reason": "length"}]}, True), + (None, {"stop_reason": "max_tokens"}, True), + (None, None, False), ], ids=[ - "openai_length", - "anthropic_max_tokens", - "openai_stop", - "anthropic_end_turn", + "canonical_openai_length", + "canonical_anthropic_max_tokens", + "canonical_stop", + "raw_openai_length_fallback", + "raw_anthropic_max_tokens_fallback", "missing_raw", ], ) @@ -283,10 +291,11 @@ def _failing_parser(response: str) -> str: def test_generate_sets_truncation_metadata_on_parser_failure( mock_completion: Any, stub_model_facade: ModelFacade, + finish_reason: str | None, raw_response: dict[str, Any] | None, expected_truncated: bool, ) -> None: - mock_completion.return_value = _make_response("bad response", raw=raw_response) + mock_completion.return_value = _make_response("bad response", raw=raw_response, finish_reason=finish_reason) def _failing_parser(response: str) -> str: raise ParserException("Response doesn't match requested \n'name' is a required property") @@ -308,19 +317,21 @@ def _failing_parser(response: str) -> str: @pytest.mark.parametrize( - ("raw_response", "expected_truncated"), + ("finish_reason", "raw_response", "expected_truncated"), [ - ({"choices": [{"finish_reason": "length"}]}, True), - ({"stop_reason": "max_tokens"}, True), - ({"choices": [{"finish_reason": "stop"}]}, False), - ({"stop_reason": "end_turn"}, False), - (None, False), + ("length", None, True), + ("max_tokens", None, True), + ("stop", None, False), + (None, {"choices": [{"finish_reason": "length"}]}, True), + (None, {"stop_reason": "max_tokens"}, True), + (None, None, False), ], ids=[ - "openai_length", - "anthropic_max_tokens", - "openai_stop", - "anthropic_end_turn", + "canonical_openai_length", + "canonical_anthropic_max_tokens", + "canonical_stop", + "raw_openai_length_fallback", + "raw_anthropic_max_tokens_fallback", "missing_raw", ], ) @@ -329,10 +340,11 @@ def _failing_parser(response: str) -> str: async def test_agenerate_sets_truncation_metadata_on_parser_failure( mock_acompletion: AsyncMock, stub_model_facade: ModelFacade, + finish_reason: str | None, raw_response: dict[str, Any] | None, expected_truncated: bool, ) -> None: - mock_acompletion.return_value = _make_response("bad response", raw=raw_response) + mock_acompletion.return_value = _make_response("bad response", raw=raw_response, finish_reason=finish_reason) def _failing_parser(response: str) -> str: raise ParserException("Response doesn't match requested \n'name' is a required property")