From 40844fe556f47d24251de2e31124797f99d85e00 Mon Sep 17 00:00:00 2001 From: gleb Date: Wed, 27 May 2026 19:02:21 +0200 Subject: [PATCH] SimpleQuestion: discrete questions --- .../test_simple_question_discrete.py | 199 ++++++++++++++++++ .../question_generators/simple_question.py | 158 ++++++++++---- 2 files changed, 319 insertions(+), 38 deletions(-) create mode 100644 code_tests/unit_tests/test_agents_and_tools/test_simple_question_discrete.py diff --git a/code_tests/unit_tests/test_agents_and_tools/test_simple_question_discrete.py b/code_tests/unit_tests/test_agents_and_tools/test_simple_question_discrete.py new file mode 100644 index 00000000..0c22cc44 --- /dev/null +++ b/code_tests/unit_tests/test_agents_and_tools/test_simple_question_discrete.py @@ -0,0 +1,199 @@ +from datetime import datetime, timezone + +import pytest +from pydantic import ValidationError + +from forecasting_tools.agents_and_tools.question_generators.simple_question import ( + SimpleQuestion, +) +from forecasting_tools.data_models.questions import DiscreteQuestion + + +_BASE = dict( + question_text="How many X will happen?", + resolution_criteria="Counts of X reported by source.", + fine_print="", + background_information="", + expected_resolution_date=datetime(2026, 12, 31, tzinfo=timezone.utc), +) + + +def _make_discrete(**overrides): + return SimpleQuestion( + **{ + **_BASE, + "question_type": "discrete", + "open_lower_bound": False, + "open_upper_bound": False, + **overrides, + } + ) + + +def test_coworker_canonical_example_0_to_10(): + sq = _make_discrete(min_value=0, max_value=10, step=1) + mq = SimpleQuestion.simple_questions_to_metaculus_questions([sq])[0] + assert isinstance(mq, DiscreteQuestion) + assert mq.lower_bound == -0.5 + assert mq.upper_bound == 10.5 + assert mq.nominal_lower_bound == 0 + assert mq.nominal_upper_bound == 10 + assert mq.cdf_size == 12 # inbound_outcome_count (11) + 1 + + +def test_round_trip_preserves_step(): + sq = _make_discrete(min_value=0, max_value=10, step=1) + mq = SimpleQuestion.simple_questions_to_metaculus_questions([sq])[0] + back = SimpleQuestion.full_questions_to_simple_questions([mq])[0] + assert back.question_type == "discrete" + assert back.step == 1.0 + assert back.min_value == 0 + assert back.max_value == 10 + + +def test_round_trip_preserves_all_fields_non_integer_step(): + """Full SimpleQuestion -> DiscreteQuestion -> SimpleQuestion round-trip + with a non-1 step that actually exercises the reverse float division.""" + sq = _make_discrete( + min_value=0, + max_value=100, + step=5, + open_lower_bound=False, + open_upper_bound=True, + ) + mq = SimpleQuestion.simple_questions_to_metaculus_questions([sq])[0] + back = SimpleQuestion.full_questions_to_simple_questions([mq])[0] + assert back.question_type == "discrete" + assert back.step == pytest.approx(5) + assert back.min_value == pytest.approx(0) + assert back.max_value == pytest.approx(100) + assert back.open_lower_bound is False + assert back.open_upper_bound is True + assert back.question_text == sq.question_text + assert back.resolution_criteria == sq.resolution_criteria + assert back.expected_resolution_date == sq.expected_resolution_date + assert back.options == [] + + +def test_round_trip_with_float_step(): + """Float step (0.1) — exercises float-arithmetic in both directions.""" + sq = _make_discrete(min_value=0, max_value=1, step=0.1) + mq = SimpleQuestion.simple_questions_to_metaculus_questions([sq])[0] + back = SimpleQuestion.full_questions_to_simple_questions([mq])[0] + assert back.step == pytest.approx(0.1) + assert back.min_value == pytest.approx(0) + assert back.max_value == pytest.approx(1) + + +def test_rejects_max_equal_to_min(): + with pytest.raises(ValidationError, match="must be greater than"): + _make_discrete(min_value=5, max_value=5, step=1) + + +def test_back_compute_when_nominal_bounds_missing(): + """API may return a DiscreteQuestion without nominal_*_bound populated; + full_questions_to_simple_questions should recover step + nominal bounds + from actual bounds and cdf_size.""" + sq = _make_discrete(min_value=0, max_value=10, step=1) + mq = SimpleQuestion.simple_questions_to_metaculus_questions([sq])[0] + mq.nominal_lower_bound = None + mq.nominal_upper_bound = None + back = SimpleQuestion.full_questions_to_simple_questions([mq])[0] + assert back.question_type == "discrete" + assert back.step == 1.0 + assert back.min_value == 0 + assert back.max_value == 10 + + +def test_binned_percentage_step_5(): + sq = _make_discrete(min_value=0, max_value=100, step=5) + mq = SimpleQuestion.simple_questions_to_metaculus_questions([sq])[0] + assert mq.cdf_size == 22 # 21 outcomes + 1 + assert mq.lower_bound == -2.5 + assert mq.upper_bound == 102.5 + + +def test_tenths_step_with_float_arithmetic(): + sq = _make_discrete(min_value=0, max_value=1, step=0.1) + mq = SimpleQuestion.simple_questions_to_metaculus_questions([sq])[0] + assert mq.cdf_size == 12 # 11 outcomes + 1 + + +def test_min_outcomes_edge_3_outcomes(): + sq = _make_discrete(min_value=0, max_value=10, step=5) + mq = SimpleQuestion.simple_questions_to_metaculus_questions([sq])[0] + assert mq.cdf_size == 4 + + +def test_max_outcomes_edge_200_outcomes(): + sq = _make_discrete(min_value=0, max_value=199, step=1) + mq = SimpleQuestion.simple_questions_to_metaculus_questions([sq])[0] + assert mq.cdf_size == 201 + + +def test_rejects_step_too_large_only_two_outcomes(): + with pytest.raises(ValidationError, match="step too large"): + _make_discrete(min_value=0, max_value=10, step=10) + + +def test_rejects_step_too_small_over_200_outcomes(): + with pytest.raises(ValidationError, match="step too small"): + _make_discrete(min_value=0, max_value=100, step=0.1) + + +def test_rejects_non_integer_divisor(): + with pytest.raises(ValidationError, match="must be an integer"): + _make_discrete(min_value=0, max_value=10, step=3) + + +def test_rejects_negative_step(): + with pytest.raises(ValidationError, match="step must be positive"): + _make_discrete(min_value=0, max_value=10, step=-1) + + +def test_rejects_missing_step(): + with pytest.raises(ValidationError, match="step must be provided"): + _make_discrete(min_value=0, max_value=10) + + +def test_rejects_missing_bounds(): + with pytest.raises(ValidationError, match="Upper bound must be provided"): + _make_discrete(min_value=0, step=1) + + +def test_rejects_step_on_binary(): + with pytest.raises(ValidationError, match="step must not be provided"): + SimpleQuestion( + **{ + **_BASE, + "question_type": "binary", + "step": 1, + } + ) + + +def test_rejects_step_on_numeric(): + with pytest.raises(ValidationError, match="step must not be provided"): + SimpleQuestion( + **{ + **_BASE, + "question_type": "numeric", + "min_value": 0, + "max_value": 100, + "open_lower_bound": False, + "open_upper_bound": False, + "step": 1, + } + ) + + +def test_rejects_step_on_multiple_choice(): + with pytest.raises(ValidationError, match="step must not be provided"): + SimpleQuestion( + **{ + **_BASE, + "question_type": "multiple_choice", + "options": ["a", "b"], + "step": 1, + } + ) diff --git a/forecasting_tools/agents_and_tools/question_generators/simple_question.py b/forecasting_tools/agents_and_tools/question_generators/simple_question.py index 23c18524..1fe15684 100644 --- a/forecasting_tools/agents_and_tools/question_generators/simple_question.py +++ b/forecasting_tools/agents_and_tools/question_generators/simple_question.py @@ -8,6 +8,7 @@ from forecasting_tools.data_models.questions import ( BinaryQuestion, DateQuestion, + DiscreteQuestion, MetaculusQuestion, MultipleChoiceQuestion, NumericQuestion, @@ -22,26 +23,32 @@ class SimpleQuestion(BaseModel, Jsonable): fine_print: str | None = None background_information: str | None = None expected_resolution_date: datetime - question_type: Literal["binary", "numeric", "multiple_choice"] = "binary" + question_type: Literal["binary", "numeric", "multiple_choice", "discrete"] = ( + "binary" + ) options: list[str] = Field( default_factory=list, - description="Options are for multiple choice question. Empty if numeric or binary. Must be defined for multiple choice questions.", + description="Options are for multiple choice question. Empty if numeric, discrete, or binary. Must be defined for multiple choice questions.", ) open_upper_bound: bool | None = Field( default=None, - description="Open upper bound defines whether there can be a value higher than upper bound. Must be defined for numeric questions and None for other question types.", + description="Open upper bound defines whether there can be a value higher than upper bound. Must be defined for numeric and discrete questions and None for other question types.", ) open_lower_bound: bool | None = Field( default=None, - description="Open lower bound defines whether there can be a value lower than lower bound. Must be defined for numeric questions and None for other question types.", + description="Open lower bound defines whether there can be a value lower than lower bound. Must be defined for numeric and discrete questions and None for other question types.", ) max_value: float | None = Field( default=None, - description="Max value defines the largest reasonable value that the answer to the question can be. Must be defined for numeric questions and None for other question types.", + description="Max value defines the largest reasonable value that the answer to the question can be. Must be defined for numeric and discrete questions and None for other question types.", ) min_value: float | None = Field( default=None, - description="Min value defines the smallest reasonable value that the answer to the question can be. Must be defined for numeric questions and None for other question types.", + description="Min value defines the smallest reasonable value that the answer to the question can be. Must be defined for numeric and discrete questions and None for other question types.", + ) + step: float | None = Field( + default=None, + description="Spacing between consecutive outcomes for discrete questions. Required for discrete questions; must be None for all other question types.", ) @classmethod @@ -53,12 +60,13 @@ def get_field_descriptions(cls) -> str: - fine_print: Additional information covering *every* edge case that could happen. There should be no chance of an ambiguous resolution. Resolution criteria + fine print should pass the clairvoyance test such that after the event happens there is no debate about whether it happened or not no matter how it resolves. - background_information: Relevant context and historical information to help understand the question - expected_resolution_date: The date when the question is expected to resolve - - question_type: The type of question, either binary, numeric, or multiple_choice based on how the forecaster should answer (with yes/no, a number, or a choice from a list) + - question_type: The type of question — binary, numeric, multiple_choice, or discrete — based on how the forecaster should answer (yes/no, a continuous number, a choice from a list, or a value from a small fixed set of evenly-spaced outcomes). - options: The options for the question, only used for multiple_choice questions. Empty list for other question types. - - open_upper_bound: Whether there can be a value higher than upper bound (e.g. if the value is a percentag, 100 is the max the bound is closed, but number of certifications in a population has an open upper bound), only used for numeric questions. - - open_lower_bound: Whether there can be a value lower than lower bound (e.g. distances can't be negative the bound is closed at 0, but profit margins can be negative so the bound is open), only used for numeric questions. - - max_value: The max value that the answer to the question can be. If bound is closed then choose the max number. If bound is open then pick a really really big number. Only used for numeric questions. (e.g. 100 for a percentage, 1000 for a number of certifications from an small org, 100000 for a number of new houses built in a large city in a year) - - min_value: The min value that the answer to the question can be. If bound is closed then choose the min number. If bound is open then pick a really really negative number. Only used for numeric questions. (e.g. 0 for a percentage, 0 for a number of certifications from a small org, -10000000 for a medium company net profit) + - open_upper_bound: Whether there can be a value higher than upper bound (e.g. if the value is a percentage, 100 is the max the bound is closed, but number of certifications in a population has an open upper bound), only used for numeric and discrete questions. + - open_lower_bound: Whether there can be a value lower than lower bound (e.g. distances can't be negative the bound is closed at 0, but profit margins can be negative so the bound is open), only used for numeric and discrete questions. + - max_value: The max value that the answer to the question can be. If bound is closed then choose the max number. If bound is open then pick a really really big number. For discrete questions, this is the largest nominal outcome. Only used for numeric and discrete questions. (e.g. 100 for a percentage, 1000 for a number of certifications from an small org, 100000 for a number of new houses built in a large city in a year) + - min_value: The min value that the answer to the question can be. If bound is closed then choose the min number. If bound is open then pick a really really negative number. For discrete questions, this is the smallest nominal outcome. Only used for numeric and discrete questions. (e.g. 0 for a percentage, 0 for a number of certifications from a small org, -10000000 for a medium company net profit) + - step: Spacing between consecutive outcomes for discrete questions only (must be None for other types). `(max_value - min_value) / step` must be an integer in [2, 199]. Example: integer counts 0..10 use step=1. """ ) @@ -73,32 +81,32 @@ def ensure_utc_timezone(cls, value: datetime) -> datetime: mode="after", ) def validate_question_type_fields(self: SimpleQuestion) -> SimpleQuestion: - if self.question_type == "numeric": + if self.question_type in ("numeric", "discrete"): assert ( self.max_value is not None - ), "Upper bound must be provided for numeric questions" + ), "Upper bound must be provided for continuous questions" assert ( self.min_value is not None - ), "Lower bound must be provided for numeric questions" + ), "Lower bound must be provided for continuous questions" assert ( self.open_upper_bound is not None - ), "Open upper bound must be provided for numeric questions" + ), "Open upper bound must be provided for continuous questions" assert ( self.open_lower_bound is not None - ), "Open lower bound must be provided for numeric questions" + ), "Open lower bound must be provided for continuous questions" else: assert ( self.max_value is None - ), "Upper bound must not be provided for non-numeric questions" + ), "Upper bound must not be provided for non-numeric/discrete questions" assert ( self.min_value is None - ), "Lower bound must not be provided for non-numeric questions" + ), "Lower bound must not be provided for non-numeric/discrete questions" assert ( self.open_upper_bound is None - ), "Open upper bound must not be provided for non-numeric questions" + ), "Open upper bound must not be provided for non-numeric/discrete questions" assert ( self.open_lower_bound is None - ), "Open lower bound must not be provided for non-numeric questions" + ), "Open lower bound must not be provided for non-numeric/discrete questions" if self.question_type == "multiple_choice": assert ( @@ -108,6 +116,35 @@ def validate_question_type_fields(self: SimpleQuestion) -> SimpleQuestion: assert ( len(self.options) == 0 ), "Options must not be provided for non-multiple choice questions" + + if self.question_type == "discrete": + assert self.step is not None, "step must be provided for discrete questions" + assert self.step > 0, f"step must be positive, got {self.step}" + assert self.max_value is not None and self.min_value is not None + range_size = self.max_value - self.min_value + assert range_size > 0, "max_value must be greater than min_value" + assert ( + self.step <= range_size / 2 + ), "step too large: must be <= (max_value - min_value) / 2" + assert ( + self.step >= range_size / 199 + ), f"step too small: must be >= (max_value - min_value) / 199" + quotient = range_size / self.step + assert abs(round(quotient) - quotient) < 1e-6, ( + "range / step must be an integer; " + f"(max_value - min_value) = {range_size} is not a multiple of " + f"step = {self.step}" + ) + inbound_outcome_count = round(quotient) + 1 + assert 3 <= inbound_outcome_count <= 200, ( + f"derived inbound_outcome_count={inbound_outcome_count} outside " + "the platform's [3, 200] range" + ) + else: + assert ( + self.step is None + ), f"step must not be provided for {self.question_type} questions" + return self @classmethod @@ -126,7 +163,37 @@ def full_questions_to_simple_questions( assert question.scheduled_resolution_time is not None assert question.fine_print is not None - if isinstance(question, NumericQuestion): + step = None + if isinstance(question, DiscreteQuestion): + question_type = "discrete" + options = [] + open_upper_bound = question.open_upper_bound + open_lower_bound = question.open_lower_bound + inbound_outcome_count = question.cdf_size - 1 + if ( + question.nominal_lower_bound is not None + and question.nominal_upper_bound is not None + ): + upper_bound = question.nominal_upper_bound + lower_bound = question.nominal_lower_bound + step = ( + (upper_bound - lower_bound) / (inbound_outcome_count - 1) + if inbound_outcome_count > 1 + else None + ) + else: + # Back-compute nominal bounds from actual bounds + cdf_size. + # Actual range spans inbound_outcome_count * step (since the + # actual bounds are ±0.5*step beyond the nominal bounds), so + # step = actual_range / inbound_outcome_count, and the + # nominal bounds sit half a step inside the actual bounds. + step = ( + (question.upper_bound - question.lower_bound) + / inbound_outcome_count + ) + lower_bound = question.lower_bound + step / 2 + upper_bound = question.upper_bound - step / 2 + elif isinstance(question, NumericQuestion): # TODO: Give more direct support for date questions question_type = "numeric" options = [] @@ -163,6 +230,7 @@ def full_questions_to_simple_questions( min_value=lower_bound, open_upper_bound=open_upper_bound, open_lower_bound=open_lower_bound, + step=step, ) simple_questions.append(simple_question) return simple_questions @@ -173,38 +241,52 @@ def simple_questions_to_metaculus_questions( ) -> list[MetaculusQuestion]: full_questions = [] for question in simple_questions: + base_attrs = { + "question_text": question.question_text, + "background_info": question.background_information, + "resolution_criteria": question.resolution_criteria, + "fine_print": question.fine_print, + "scheduled_resolution_time": question.expected_resolution_date, + } + if question.question_type == "binary": - full_question = BinaryQuestion( - question_text=question.question_text, - background_info=question.background_information, - resolution_criteria=question.resolution_criteria, - fine_print=question.fine_print, - scheduled_resolution_time=question.expected_resolution_date, - ) + full_question = BinaryQuestion(**base_attrs) elif question.question_type == "numeric": assert question.max_value is not None assert question.min_value is not None assert question.open_upper_bound is not None assert question.open_lower_bound is not None full_question = NumericQuestion( - question_text=question.question_text, - background_info=question.background_information, - resolution_criteria=question.resolution_criteria, - fine_print=question.fine_print, + **base_attrs, upper_bound=question.max_value, lower_bound=question.min_value, open_upper_bound=question.open_upper_bound, open_lower_bound=question.open_lower_bound, - scheduled_resolution_time=question.expected_resolution_date, + ) + elif question.question_type == "discrete": + assert question.max_value is not None + assert question.min_value is not None + assert question.step is not None and question.step > 0 + assert question.open_upper_bound is not None + assert question.open_lower_bound is not None + half_step = question.step / 2 + inbound_outcome_count = ( + round((question.max_value - question.min_value) / question.step) + 1 + ) + full_question = DiscreteQuestion( + **base_attrs, + nominal_lower_bound=question.min_value, + nominal_upper_bound=question.max_value, + lower_bound=question.min_value - half_step, + upper_bound=question.max_value + half_step, + open_upper_bound=question.open_upper_bound, + open_lower_bound=question.open_lower_bound, + cdf_size=inbound_outcome_count + 1, ) elif question.question_type == "multiple_choice": full_question = MultipleChoiceQuestion( - question_text=question.question_text, - background_info=question.background_information, - resolution_criteria=question.resolution_criteria, - fine_print=question.fine_print, + **base_attrs, options=question.options, - scheduled_resolution_time=question.expected_resolution_date, ) else: raise ValueError(f"Unknown question type: {question.question_type}")