From 2b3c629e64a31aacace97d3377ffd42add14cb65 Mon Sep 17 00:00:00 2001 From: schultzjack Date: Tue, 16 Jun 2026 20:30:14 -0400 Subject: [PATCH] fix: coerce judge score drift Signed-off-by: schultzjack --- .../utils/judge_score_factory.py | 43 +++++++++++++++- .../utils/test_judge_score_factory.py | 51 ++++++++++++++++++- 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/judge_score_factory.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/judge_score_factory.py index 0ee030d0d..bc2f4bd83 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/judge_score_factory.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/utils/judge_score_factory.py @@ -4,8 +4,9 @@ from __future__ import annotations from enum import Enum +from typing import Any -from pydantic import BaseModel, ConfigDict, Field, create_model +from pydantic import BaseModel, ConfigDict, Field, create_model, model_validator from data_designer.config.column_configs import Score @@ -13,12 +14,52 @@ SCORE_FIELD_DESCRIPTION_FORMAT = "Score Descriptions for {enum_name}:\n{scoring}" +def _normalize_score_value(value: Any) -> str: + if isinstance(value, str): + return value.strip().casefold() + if isinstance(value, float) and value.is_integer(): + return str(int(value)) + return str(value).strip().casefold() + + +def _coerce_score_value(value: Any, enum_type: type[Enum]) -> Any: + for member in enum_type: + if isinstance(value, bool) != isinstance(member.value, bool): + continue + if value == member.value: + return value + + normalized_value = _normalize_score_value(value) + matches = [member.value for member in enum_type if _normalize_score_value(member.value) == normalized_value] + if len(matches) == 1: + return matches[0] + return value + + class BaseJudgeResponse(BaseModel): """Base model for all rubrics.""" model_config = ConfigDict(use_enum_values=True) reasoning: str = Field(..., description="Reasoning for the assigned score.") + @model_validator(mode="before") + @classmethod + def coerce_score(cls, data: Any) -> Any: + if not isinstance(data, dict) or "score" not in data: + return data + + score_field = cls.model_fields.get("score") + if score_field is None: + return data + + score_type = score_field.annotation + if not isinstance(score_type, type) or not issubclass(score_type, Enum): + return data + + coerced_data = data.copy() + coerced_data["score"] = _coerce_score_value(data["score"], score_type) + return coerced_data + def _stringify_scoring(options: dict, enum_type: type[Enum]) -> str: """Convert score descriptions into a single text block.""" diff --git a/packages/data-designer-engine/tests/engine/column_generators/utils/test_judge_score_factory.py b/packages/data-designer-engine/tests/engine/column_generators/utils/test_judge_score_factory.py index 8f66e5c25..125ee2a7b 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/utils/test_judge_score_factory.py +++ b/packages/data-designer-engine/tests/engine/column_generators/utils/test_judge_score_factory.py @@ -4,7 +4,7 @@ from enum import Enum import pytest -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from data_designer.config.column_configs import Score from data_designer.engine.column_generators.utils.judge_score_factory import ( @@ -61,6 +61,27 @@ def test_judge_score_factory_create_judge_response_model(): assert instance.reasoning == "Test reasoning" +@pytest.mark.parametrize( + ("options", "input_score", "expected_score"), + [ + ({"1": "Low quality", "2": "High quality"}, 1, "1"), + ({1: "Low quality", 2: "High quality"}, "1", 1), + ({"Poor": "Low quality", "Good": "High quality"}, " good ", "Good"), + ], +) +def test_judge_score_factory_coerces_score_drift(options, input_score, expected_score): + score = Score( + name="quality_score", + description="Quality assessment score", + options=options, + ) + + model_class = create_judge_response_model(score) + instance = model_class(score=input_score, reasoning="Test reasoning") + + assert instance.score == expected_score + + def test_judge_score_factory_create_judge_structured_output_model(): score = Score( name="quality_score", @@ -75,6 +96,34 @@ def test_judge_score_factory_create_judge_structured_output_model(): assert "quality_score" in model_class.model_fields +def test_judge_score_factory_structured_output_coerces_nested_score_drift(): + score = Score( + name="quality_score", + description="Quality assessment score", + options={"1": "Low quality", "2": "High quality"}, + ) + + response_model = create_judge_response_model(score) + model_class = create_judge_structured_output_model([response_model]) + + instance = model_class(quality_score={"score": 1, "reasoning": "Test reasoning"}) + + assert instance.quality_score.score == "1" + + +def test_judge_score_factory_invalid_unhashable_score_uses_pydantic_validation(): + score = Score( + name="quality_score", + description="Quality assessment score", + options={"1": "Low quality", "2": "High quality"}, + ) + + model_class = create_judge_response_model(score) + + with pytest.raises(ValidationError): + model_class(score=["1"], reasoning="Test reasoning") + + def test_judge_score_factory_preserves_score_name_casing(): """Test that Score name casing is preserved in the JSON output keys.""" score = Score(