From cddd7dcf64bb34c6045a151bde7d1d164cd352be Mon Sep 17 00:00:00 2001 From: Pooya Moradi Date: Tue, 26 May 2026 16:00:15 +0000 Subject: [PATCH] extract_answer: prefer \boxed{N} extraction, fall back to legacy tags Modern reasoning models (Qwen3, DeepSeek-R1, etc.) emit `\boxed{N}` inside `<answer>...</answer>` (or with no answer tags at all). The legacy regex returned the raw `<answer>` content (e.g. `\boxed{42}` as a string), which math_verify cannot match against a bare numeric gold like "42". Result: ~0% accuracy on Qwen3/GSM8K even when the model's numeric answer is correct. New strategy (priority order): 1. If a `{solution_start_token}...{solution_end_token}` block is present (default `<answer>...</answer>`), use the last block's content as the search scope; otherwise use the full response. 2. Inside the scope, extract the last `\boxed{N}` via brace-balanced scan + permissive regex fallback. 3. If no `\boxed` is found, fall back to the same configured solution-tag regex (backward-compat for recipes that emit plain-text answers). Both the scoping (step 1) and the plain-text fallback (step 3) reuse get_answer_fallback_regex, so the solution tags have a single source of truth in solution_start_token / solution_end_token. 
--- .../trainers/post_train/rl/utils_rl.py | 53 +++++++- .../post_training/unit/extract_answer_test.py | 124 ++++++++++++++++++ 2 files changed, 171 insertions(+), 6 deletions(-) create mode 100644 tests/post_training/unit/extract_answer_test.py diff --git a/src/maxtext/trainers/post_train/rl/utils_rl.py b/src/maxtext/trainers/post_train/rl/utils_rl.py index 5a29b5fbc2..8a339d4865 100644 --- a/src/maxtext/trainers/post_train/rl/utils_rl.py +++ b/src/maxtext/trainers/post_train/rl/utils_rl.py @@ -484,12 +484,53 @@ def check_numbers( def extract_answer(response: str, tmvp_config: Any) -> str: - """Function to extract the answer from the text based on the tmvp_config format.""" - answer_fallback = get_answer_fallback_regex(tmvp_config) - # Find the *last* occurrence of the answer tag (most likely the final answer). - fallback_matches = answer_fallback.findall(response) - extracted_response = fallback_matches[-1].strip() if fallback_matches else FALLBACK_ANSWER - return extracted_response + """Extract the final numeric answer from the model's response. + + Strategy (priority order): + 1. Narrow the search scope to the LAST + `{solution_start_token}...{solution_end_token}` block (default + `...`) if present; otherwise use the full response. + 2. Inside the search scope, find the last `\\boxed{N}` via a brace- + balanced scan (handles nested braces in LaTeX). Fall back to a + permissive `\\boxed{N}` regex if no balanced match is found. + 3. If no boxed expression is found, fall back to the same configured + solution-tag regex over the full response, for recipes that emit the + answer as plain text rather than `\\boxed{N}`. + + Step 1 + 2 are required for modern reasoning models (Qwen3, DeepSeek-R1, + etc.) that emit `...\\boxed{N}` or `\\boxed{N}` + framing. Without `\\boxed` extraction the legacy regex returns the raw + `\\boxed{N}` string and math_verify cannot match it against a bare numeric + gold. 
Step 3 keeps the function backward-compatible with recipes that + emit plain-text answers inside the configured solution tags. + + The solution tags have a single source of truth: both the scoping (step 1) + and the plain-text fallback (step 3) reuse `get_answer_fallback_regex`, + built from `tmvp_config.solution_start_token` / `solution_end_token`. + """ + answer_tag_regex = get_answer_fallback_regex(tmvp_config) + answer_matches = answer_tag_regex.findall(response) + content = answer_matches[-1] if answer_matches else response + boxed_matches: list[str] = [] + stack: list[int] = [] + for i, ch in enumerate(content): + if ch == "{": + stack.append(i) + elif ch == "}": + if not stack: + continue + op = stack.pop() + if content[:op].endswith(r"\boxed"): + boxed_matches.append(content[op + 1 : i].strip()) + if boxed_matches: + return boxed_matches[-1] + m = re.search(r"\\boxed\s*\{?\s*([a-zA-Z0-9\.,\-]+)\s*\}?", content) + if m: + return m.group(1).strip() + fallback_matches = answer_tag_regex.findall(response) + if fallback_matches: + return fallback_matches[-1].strip() + return FALLBACK_ANSWER def extract_hash_answer(text: str) -> str | None: diff --git a/tests/post_training/unit/extract_answer_test.py b/tests/post_training/unit/extract_answer_test.py new file mode 100644 index 0000000000..2129ba05b0 --- /dev/null +++ b/tests/post_training/unit/extract_answer_test.py @@ -0,0 +1,124 @@ +# Copyright 2023-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for utils_rl.extract_answer (CPU-only). + +Covers the two-part contract of the boxed-extraction change: + 1. `\\boxed{N}` is extracted (with/without tags, nested LaTeX, + multiple boxed, whitespace, negatives, and answer-tag scoping). + 2. Legacy plain-text answers inside the solution tags still work, so + existing recipes that do not emit `\\boxed` are unaffected. +""" + +import unittest +from types import SimpleNamespace + +import pytest + +from maxtext.trainers.post_train.rl import utils_rl + +pytestmark = [pytest.mark.post_training] + + +def _make_config(): + """Minimal config carrying the solution/reasoning tokens extract_answer reads.""" + return SimpleNamespace( + reasoning_start_token="", + reasoning_end_token="", + solution_start_token="", + solution_end_token="", + ) + + +class ExtractAnswerTest(unittest.TestCase): + """Verify boxed extraction and legacy-fallback behavior of extract_answer.""" + + def setUp(self): + super().setUp() + self.config = _make_config() + + # ---- boxed extraction ---- + + @pytest.mark.cpu_only + def test_boxed_inside_answer_tags(self): + got = utils_rl.extract_answer("2+2\\boxed{4}", self.config) + self.assertEqual(got, "4") + + @pytest.mark.cpu_only + def test_boxed_without_answer_tags(self): + got = utils_rl.extract_answer("the result is \\boxed{42}", self.config) + self.assertEqual(got, "42") + + @pytest.mark.cpu_only + def test_boxed_nested_latex(self): + """Brace-balanced scan keeps the full nested LaTeX content.""" + got = utils_rl.extract_answer("\\boxed{\\frac{1}{2}}", self.config) + self.assertEqual(got, "\\frac{1}{2}") + + @pytest.mark.cpu_only + def test_multiple_boxed_returns_last(self): + got = utils_rl.extract_answer("first \\boxed{1} then \\boxed{99}", self.config) + self.assertEqual(got, "99") + + @pytest.mark.cpu_only + def test_boxed_strips_whitespace(self): + got = utils_rl.extract_answer("\\boxed{ 7 }", self.config) + self.assertEqual(got, "7") + + @pytest.mark.cpu_only + def 
test_boxed_negative(self): + got = utils_rl.extract_answer("answer: \\boxed{-3}", self.config) + self.assertEqual(got, "-3") + + @pytest.mark.cpu_only + def test_answer_tag_scopes_over_reasoning_boxed(self): + """A boxed value in must not win over the one in .""" + resp = "maybe \\boxed{1}\\boxed{8}" + self.assertEqual(utils_rl.extract_answer(resp, self.config), "8") + + @pytest.mark.cpu_only + def test_scoping_follows_configured_solution_tokens(self): + """Scoping uses solution_start/end_token, not a hardcoded tag.""" + config = SimpleNamespace( + reasoning_start_token="", + reasoning_end_token="", + solution_start_token="", + solution_end_token="", + ) + resp = "maybe \\boxed{1}\\boxed{8}" + self.assertEqual(utils_rl.extract_answer(resp, config), "8") + + # ---- legacy fallback (no boxed) ---- + + @pytest.mark.cpu_only + def test_legacy_plain_answer_in_tags(self): + """A plain-text answer inside tags (no boxed) still extracts.""" + got = utils_rl.extract_answer("work42", self.config) + self.assertEqual(got, "42") + + @pytest.mark.cpu_only + def test_legacy_last_answer_wins(self): + got = utils_rl.extract_answer("1 ... 5", self.config) + self.assertEqual(got, "5") + + # ---- no answer ---- + + @pytest.mark.cpu_only + def test_no_answer_returns_fallback_constant(self): + got = utils_rl.extract_answer("I have no idea", self.config) + self.assertEqual(got, utils_rl.FALLBACK_ANSWER) + + +if __name__ == "__main__": + unittest.main()