diff --git a/rlm/utils/parsing.py b/rlm/utils/parsing.py index e4c2350..bde9b0e 100644 --- a/rlm/utils/parsing.py +++ b/rlm/utils/parsing.py @@ -33,6 +33,12 @@ def find_final_answer(text: str, environment: "BaseEnv | None" = None) -> str | If FINAL_VAR is found and an environment is provided, executes code to retrieve the variable value. Returns None if neither pattern is found. + Only accepts FINAL/FINAL_VAR in these positions to avoid false positives from CoT reasoning: + 1. At the start of a line (with optional leading whitespace) + 2. Immediately after special tokens like <|begin_of_box|> + + This rejects cases like "I will then do FINAL(...)" where the model is explaining its plan. + Args: text: The response text to parse environment: Optional environment to execute code for FINAL_VAR retrieval @@ -40,8 +46,14 @@ def find_final_answer(text: str, environment: "BaseEnv | None" = None) -> str | Returns: The final answer string, or None if no final answer pattern is found """ - # Check for FINAL_VAR pattern first - must be at start of line - final_var_pattern = r"^\s*FINAL_VAR\((.*?)\)" + # Pattern explanation: + # (?:^[ \t]*|(?<=\n)[ \t]*|<\|[^|]+\|>\s*) matches: + # - Start of string with optional spaces/tabs: ^[ \t]* + # - After newline with optional spaces/tabs: (?<=\n)[ \t]* + # - After special tokens like <|...|>: <\|[^|]+\|>\s* + + # Check for FINAL_VAR pattern first (more specific) + final_var_pattern = r"(?:^[ \t]*|(?<=\n)[ \t]*|<\|[^|]+\|>\s*)FINAL_VAR\((.*?)\)" match = re.search(final_var_pattern, text, re.MULTILINE | re.DOTALL) if match: variable_name = match.group(1).strip().strip('"').strip("'") @@ -53,8 +65,8 @@ def find_final_answer(text: str, environment: "BaseEnv | None" = None) -> str | return final_answer return None - # Check for FINAL pattern - must be at start of line - final_pattern = r"^\s*FINAL\((.*?)\)" + # Check for FINAL pattern + final_pattern = r"(?:^[ \t]*|(?<=\n)[ \t]*|<\|[^|]+\|>\s*)FINAL\((.*?)\)" match = re.search(final_pattern, text, re.MULTILINE | re.DOTALL) if match: return match.group(1).strip()