Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions src/maxtext/trainers/post_train/rl/utils_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,12 +484,53 @@ def check_numbers(


def extract_answer(response: str, tmvp_config: Any) -> str:
"""Function to extract the answer from the text based on the tmvp_config format."""
answer_fallback = get_answer_fallback_regex(tmvp_config)
# Find the *last* occurrence of the answer tag (most likely the final answer).
fallback_matches = answer_fallback.findall(response)
extracted_response = fallback_matches[-1].strip() if fallback_matches else FALLBACK_ANSWER
return extracted_response
"""Extract the final numeric answer from the model's response.

Strategy (priority order):
1. Narrow the search scope to the LAST
`{solution_start_token}...{solution_end_token}` block (default
`<answer>...</answer>`) if present; otherwise use the full response.
2. Inside the search scope, find the last `\\boxed{N}` via a brace-
balanced scan (handles nested braces in LaTeX). Fall back to a
permissive `\\boxed{N}` regex if no balanced match is found.
3. If no boxed expression is found, fall back to the same configured
solution-tag regex over the full response, for recipes that emit the
answer as plain text rather than `\\boxed{N}`.

Step 1 + 2 are required for modern reasoning models (Qwen3, DeepSeek-R1,
etc.) that emit `<think>...</think>\\boxed{N}` or `<answer>\\boxed{N}</answer>`
framing. Without `\\boxed` extraction the legacy regex returns the raw
`\\boxed{N}` string and math_verify cannot match it against a bare numeric
gold. Step 3 keeps the function backward-compatible with recipes that
emit plain-text answers inside the configured solution tags.

The solution tags have a single source of truth: both the scoping (step 1)
and the plain-text fallback (step 3) reuse `get_answer_fallback_regex`,
built from `tmvp_config.solution_start_token` / `solution_end_token`.
"""
answer_tag_regex = get_answer_fallback_regex(tmvp_config)
answer_matches = answer_tag_regex.findall(response)
content = answer_matches[-1] if answer_matches else response
boxed_matches: list[str] = []
stack: list[int] = []
for i, ch in enumerate(content):
if ch == "{":
stack.append(i)
elif ch == "}":
if not stack:
continue
op = stack.pop()
if content[:op].endswith(r"\boxed"):
boxed_matches.append(content[op + 1 : i].strip())
if boxed_matches:
return boxed_matches[-1]
m = re.search(r"\\boxed\s*\{?\s*([a-zA-Z0-9\.,\-]+)\s*\}?", content)
if m:
return m.group(1).strip()
fallback_matches = answer_tag_regex.findall(response)
if fallback_matches:
return fallback_matches[-1].strip()
return FALLBACK_ANSWER


def extract_hash_answer(text: str) -> str | None:
Expand Down
124 changes: 124 additions & 0 deletions tests/post_training/unit/extract_answer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright 2023-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for utils_rl.extract_answer (CPU-only).

Covers the two-part contract of the boxed-extraction change:
1. `\\boxed{N}` is extracted (with/without <answer> tags, nested LaTeX,
multiple boxed, whitespace, negatives, and answer-tag scoping).
2. Legacy plain-text answers inside the solution tags still work, so
existing recipes that do not emit `\\boxed` are unaffected.
"""

import unittest
from types import SimpleNamespace

import pytest

from maxtext.trainers.post_train.rl import utils_rl

pytestmark = [pytest.mark.post_training]


def _make_config():
"""Minimal config carrying the solution/reasoning tokens extract_answer reads."""
return SimpleNamespace(
reasoning_start_token="<reasoning>",
reasoning_end_token="</reasoning>",
solution_start_token="<answer>",
solution_end_token="</answer>",
)


class ExtractAnswerTest(unittest.TestCase):
"""Verify boxed extraction and legacy-fallback behavior of extract_answer."""

def setUp(self):
super().setUp()
self.config = _make_config()

# ---- boxed extraction ----

@pytest.mark.cpu_only
def test_boxed_inside_answer_tags(self):
got = utils_rl.extract_answer("<reasoning>2+2</reasoning><answer>\\boxed{4}</answer>", self.config)
self.assertEqual(got, "4")

@pytest.mark.cpu_only
def test_boxed_without_answer_tags(self):
got = utils_rl.extract_answer("the result is \\boxed{42}", self.config)
self.assertEqual(got, "42")

@pytest.mark.cpu_only
def test_boxed_nested_latex(self):
"""Brace-balanced scan keeps the full nested LaTeX content."""
got = utils_rl.extract_answer("<answer>\\boxed{\\frac{1}{2}}</answer>", self.config)
self.assertEqual(got, "\\frac{1}{2}")

@pytest.mark.cpu_only
def test_multiple_boxed_returns_last(self):
got = utils_rl.extract_answer("first \\boxed{1} then \\boxed{99}", self.config)
self.assertEqual(got, "99")

@pytest.mark.cpu_only
def test_boxed_strips_whitespace(self):
got = utils_rl.extract_answer("<answer>\\boxed{ 7 }</answer>", self.config)
self.assertEqual(got, "7")

@pytest.mark.cpu_only
def test_boxed_negative(self):
got = utils_rl.extract_answer("answer: \\boxed{-3}", self.config)
self.assertEqual(got, "-3")

@pytest.mark.cpu_only
def test_answer_tag_scopes_over_reasoning_boxed(self):
"""A boxed value in <reasoning> must not win over the one in <answer>."""
resp = "<reasoning>maybe \\boxed{1}</reasoning><answer>\\boxed{8}</answer>"
self.assertEqual(utils_rl.extract_answer(resp, self.config), "8")

@pytest.mark.cpu_only
def test_scoping_follows_configured_solution_tokens(self):
"""Scoping uses solution_start/end_token, not a hardcoded <answer> tag."""
config = SimpleNamespace(
reasoning_start_token="<reasoning>",
reasoning_end_token="</reasoning>",
solution_start_token="<sol>",
solution_end_token="</sol>",
)
resp = "<reasoning>maybe \\boxed{1}</reasoning><sol>\\boxed{8}</sol>"
self.assertEqual(utils_rl.extract_answer(resp, config), "8")

# ---- legacy fallback (no boxed) ----

@pytest.mark.cpu_only
def test_legacy_plain_answer_in_tags(self):
"""A plain-text answer inside <answer> tags (no boxed) still extracts."""
got = utils_rl.extract_answer("<reasoning>work</reasoning><answer>42</answer>", self.config)
self.assertEqual(got, "42")

@pytest.mark.cpu_only
def test_legacy_last_answer_wins(self):
got = utils_rl.extract_answer("<answer>1</answer> ... <answer>5</answer>", self.config)
self.assertEqual(got, "5")

# ---- no answer ----

@pytest.mark.cpu_only
def test_no_answer_returns_fallback_constant(self):
got = utils_rl.extract_answer("I have no idea", self.config)
self.assertEqual(got, utils_rl.FALLBACK_ANSWER)


if __name__ == "__main__":
unittest.main()
Loading