Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,62 @@
from __future__ import annotations

from enum import Enum
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, create_model
from pydantic import BaseModel, ConfigDict, Field, create_model, model_validator

from data_designer.config.column_configs import Score

SCORING_FORMAT = "* {score}: {description}"
SCORE_FIELD_DESCRIPTION_FORMAT = "Score Descriptions for {enum_name}:\n{scoring}"


def _normalize_score_value(value: Any) -> str:
if isinstance(value, str):
return value.strip().casefold()
if isinstance(value, float) and value.is_integer():
return str(int(value))
return str(value).strip().casefold()


def _coerce_score_value(value: Any, enum_type: type[Enum]) -> Any:
for member in enum_type:
if isinstance(value, bool) != isinstance(member.value, bool):
continue
if value == member.value:
return value

normalized_value = _normalize_score_value(value)
matches = [member.value for member in enum_type if _normalize_score_value(member.value) == normalized_value]
if len(matches) == 1:
return matches[0]
return value


class BaseJudgeResponse(BaseModel):
    """Base model for all rubrics."""

    # Enum-typed fields are stored as their plain values rather than members.
    model_config = ConfigDict(use_enum_values=True)
    # Free-text justification that accompanies the score.
    reasoning: str = Field(..., description="Reasoning for the assigned score.")

    @model_validator(mode="before")
    @classmethod
    def coerce_score(cls, data: Any) -> Any:
        """Pre-validation hook that reconciles the raw ``score`` with the enum.

        The input is returned untouched unless it is a dict carrying a
        ``"score"`` key and the model declares a ``score`` field annotated with
        an ``Enum`` subclass. In that case a shallow copy is returned whose
        ``"score"`` entry has been passed through ``_coerce_score_value``.
        """
        has_score_entry = isinstance(data, dict) and "score" in data
        if not has_score_entry:
            return data

        field_info = cls.model_fields.get("score")
        annotation = None if field_info is None else field_info.annotation
        is_enum_annotation = isinstance(annotation, type) and issubclass(annotation, Enum)
        if not is_enum_annotation:
            return data

        # Work on a copy so the caller's dict is never mutated.
        updated = data.copy()
        updated["score"] = _coerce_score_value(data["score"], annotation)
        return updated


def _stringify_scoring(options: dict, enum_type: type[Enum]) -> str:
"""Convert score descriptions into a single text block."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from enum import Enum

import pytest
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError

from data_designer.config.column_configs import Score
from data_designer.engine.column_generators.utils.judge_score_factory import (
Expand Down Expand Up @@ -61,6 +61,27 @@ def test_judge_score_factory_create_judge_response_model():
assert instance.reasoning == "Test reasoning"


@pytest.mark.parametrize(
    ("options", "input_score", "expected_score"),
    [
        ({"1": "Low quality", "2": "High quality"}, 1, "1"),
        ({1: "Low quality", 2: "High quality"}, "1", 1),
        ({"Poor": "Low quality", "Good": "High quality"}, " good ", "Good"),
    ],
)
def test_judge_score_factory_coerces_score_drift(options, input_score, expected_score):
    """Scores that differ from an option only in type, case or padding resolve to that option."""
    judge_model = create_judge_response_model(
        Score(
            name="quality_score",
            description="Quality assessment score",
            options=options,
        )
    )

    response = judge_model(score=input_score, reasoning="Test reasoning")

    assert response.score == expected_score


def test_judge_score_factory_create_judge_structured_output_model():
score = Score(
name="quality_score",
Expand All @@ -75,6 +96,34 @@ def test_judge_score_factory_create_judge_structured_output_model():
assert "quality_score" in model_class.model_fields


def test_judge_score_factory_structured_output_coerces_nested_score_drift():
    """An integer score nested inside the structured output resolves to the string option."""
    string_keyed_score = Score(
        name="quality_score",
        description="Quality assessment score",
        options={"1": "Low quality", "2": "High quality"},
    )
    structured_model = create_judge_structured_output_model(
        [create_judge_response_model(string_keyed_score)]
    )

    nested_payload = {"score": 1, "reasoning": "Test reasoning"}
    parsed = structured_model(quality_score=nested_payload)

    assert parsed.quality_score.score == "1"


def test_judge_score_factory_invalid_unhashable_score_uses_pydantic_validation():
    """A list passed as the score is rejected with pydantic's ``ValidationError``."""
    judge_model = create_judge_response_model(
        Score(
            name="quality_score",
            description="Quality assessment score",
            options={"1": "Low quality", "2": "High quality"},
        )
    )

    with pytest.raises(ValidationError):
        judge_model(score=["1"], reasoning="Test reasoning")


def test_judge_score_factory_preserves_score_name_casing():
"""Test that Score name casing is preserved in the JSON output keys."""
score = Score(
Expand Down
Loading