
Commit ad034b5

Add DSBench-DA evaluation (#1254)
Squash merge of changes during code-review. Signed-off-by: suriya <sgunasekar@nvidia.com>
1 parent 7593ab3 commit ad034b5

11 files changed

Lines changed: 410 additions & 2 deletions


.gitignore

Lines changed: 2 additions & 1 deletion
@@ -47,7 +47,8 @@ nemo_skills/dataset/aalcr/lcr/
 CLAUDE.md
 AGENTS.md
 .codex
-
+.claude
+.cursor
 .idea
 
 #scripts at root level

Lines changed: 28 additions & 0 deletions
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# settings that define how evaluation should be done by default (all can be changed from cmdline)
EVAL_SPLIT = "test"
METRICS_TYPE = "math"

# Use the DSBench evaluator (extends MathEvaluator) with relaxed extraction,
# case-insensitive MCQ matching, and handling of dict and list answers.
GENERATION_ARGS = "++prompt_config=generic/dsbench-da ++eval_type=dsbench ++eval_config.relaxed_extraction=true"

# Recommended: run an LLM judge to verify dict and list answers correctly.
# JUDGE_PIPELINE_ARGS = {
#     "generation_type": "math_judge",
#     "model": "gpt-4.1",
#     "server_type": "openai",
#     "server_address": "https://api.openai.com/v1",
# }
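
For orientation, here is a minimal sketch of how a pipeline stage could consume these module-level defaults. The helper below is hypothetical: only the attribute names EVAL_SPLIT, METRICS_TYPE, and GENERATION_ARGS come from this commit, and the module path is an assumption since this diff does not show the file's location.

# A minimal sketch, not actual pipeline code. The dataset module path
# "nemo_skills.dataset.dsbench-da" is a guess; only the attribute names are
# taken from the file above.
import importlib

def load_benchmark_defaults(module_name: str) -> dict:
    """Collect the module-level defaults that the cmdline can override."""
    mod = importlib.import_module(module_name)
    return {
        "split": mod.EVAL_SPLIT,           # "test"
        "metrics_type": mod.METRICS_TYPE,  # "math"
        "generation_args": mod.GENERATION_ARGS,
    }
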
Lines changed: 232 additions & 0 deletions
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import zipfile
from pathlib import Path

from huggingface_hub import hf_hub_download


def read_excel_to_text(excel_path: Path) -> str:
    """Read Excel file and convert to text representation."""
    import pandas as pd

    try:
        # Explicitly handle .xlsb files with pyxlsb engine
        engine = "pyxlsb" if excel_path.suffix == ".xlsb" else None
        with pd.ExcelFile(excel_path, engine=engine) as xls:
            sheets = {sheet_name: xls.parse(sheet_name) for sheet_name in xls.sheet_names}
    except Exception as e:
        raise RuntimeError(f"Failed to read Excel file {excel_path}: {e}") from e

    combined_text = ""
    for sheet_name, df in sheets.items():
        sheet_text = df.to_string(index=False)
        combined_text += f"Sheet name: {sheet_name}\n{sheet_text}\n\n"
    return combined_text


def format_paths_for_prompt(paths: list[Path], actual_root: Path, display_root: Path) -> str:
    """Format file paths for display in prompt.

    Args:
        paths: List of absolute Path objects to format
        actual_root: Root directory where files actually exist
        display_root: Root directory to display in paths (absolute for abs paths, Path(".") for relative)
    """
    if not paths:
        return ""

    formatted = []
    for path in paths:
        try:
            rel = path.relative_to(actual_root)
            disp_path = display_root / rel
        except ValueError:
            disp_path = path
        formatted.append(str(disp_path))

    return " ".join(formatted)


def save_data(split: str, data_dir: str | Path, display_root: str | Path | None, incontext_data: bool) -> None:
    """Download and prepare DSBench data."""
    print(f"Preparing DSBench data for {split} split and saving to {data_dir}...")

    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)

    extracted_data_dir = data_dir / "data"

    # Extract if not already cached (hf_hub_download handles download caching)
    if not extracted_data_dir.exists():
        print(" Downloading dataset from HuggingFace...")
        zip_path = Path(
            hf_hub_download(repo_id="liqiang888/DSBench", filename="data_analysis/data.zip", repo_type="dataset")
        )
        print(" Extracting data...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(data_dir)
        if not extracted_data_dir.exists():
            raise FileNotFoundError(f"Could not find data directory after extraction in {extracted_data_dir}")
        print(f" Dataset cached to {data_dir}")
    else:
        print(f" Using cached dataset from {data_dir}")

    # Load metadata
    print(" Loading metadata...")
    metadata_path = Path(
        hf_hub_download(repo_id="liqiang888/DSBench", filename="data_analysis/data.json", repo_type="dataset")
    )
    metadata = []
    with open(metadata_path, "r") as f:
        for line in f:
            if line.strip():
                metadata.append(json.loads(line.strip()))

    # Process all tasks
    if not display_root:
        display_root = extracted_data_dir
    else:
        display_root = Path(display_root)

    print(
        f" Processing {len(metadata)} tasks at {extracted_data_dir} - using display root {display_root} for paths shown in the prompt..."
    )
    all_entries = []

    for task in metadata:
        task_id = task["id"]
        task_dir = extracted_data_dir / task_id

        if not task_dir.exists():
            raise FileNotFoundError(
                f"Task directory not found: {task_dir}. "
                f"Expected task {task_id} from metadata but directory is missing. "
                "Data extraction may have failed."
            )
        if len(task["answers"]) != len(task["questions"]):
            raise ValueError(
                f"Task {task_id}: mismatched questions ({len(task['questions'])}) "
                f"and answers ({len(task['answers'])}) counts in metadata."
            )

        # Read introduction
        intro_file = task_dir / "introduction.txt"
        introduction = ""
        if intro_file.exists():
            introduction = intro_file.read_text(encoding="utf-8", errors="ignore")

        # Get data files - support all Excel formats
        excel_files = []
        for ext in ["*.xlsx", "*.xlsb", "*.xlsm"]:
            excel_files.extend(task_dir.glob(ext))
        excel_files = [f for f in excel_files if "answer" not in f.name.lower()]

        # Read Excel content for in-context mode
        if incontext_data:
            excel_content = ""
            for excel_file in excel_files:
                sheets_text = read_excel_to_text(excel_file)
                excel_content += f"The excel file {excel_file.name} is: {sheets_text}\n\n"

        # Format paths for tool mode (relative to data directory)
        excel_paths = format_paths_for_prompt(excel_files, actual_root=extracted_data_dir, display_root=display_root)

        # Uncomment to get image files and csv files (for future multimodal and agentic support)
        # image_files = []
        # for ext in ["*.jpg", "*.png", "*.jpeg"]:
        #     image_files.extend(task_dir.glob(ext))
        # csv_files = list(task_dir.glob("*.csv"))

        # Process each question
        for idx, question_name in enumerate(task["questions"]):
            question_file = task_dir / f"{question_name}.txt"

            if not question_file.exists():
                print(f" Warning: {task_id}/{question_name}.txt not found, skipping")
                continue

            question_text = question_file.read_text(encoding="utf-8", errors="ignore").strip()

            # Build problem text (introduction + question)
            problem_text = ""
            if introduction:
                problem_text += f"The introduction is detailed as follows.\n{introduction}\n\n"
            problem_text += f"The question for this task is detailed as follows.\n{question_text}"

            # Create entry with all necessary fields
            entry = {
                # Skills standard fields
                "problem": problem_text,
                "expected_answer": task["answers"][idx],
                # For tool mode
                "excel_paths": excel_paths,
                # Metadata
                "task_id": task_id,
                "question_id": question_name,
                "task_name": task["name"],
                "task_url": task["url"],
                "task_year": task["year"],
            }

            if incontext_data:
                entry["excel_content"] = excel_content.strip()

            all_entries.append(entry)

    # Validate we got some entries
    if not all_entries:
        raise ValueError(
            f"No valid entries created! Processed {len(metadata)} tasks but all failed. "
            "Check that data was downloaded correctly and Excel files are readable."
        )

    # Save to output file
    output_file = data_dir / f"{split}.jsonl"
    with open(output_file, "w") as f:
        for entry in all_entries:
            f.write(json.dumps(entry) + "\n")

    print(f" ✓ Saved {len(all_entries)} questions to {output_file}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", default="test", choices=("test",), help="DSBench only has test split")
    parser.add_argument(
        "--data_dir", type=str, default=None, help="Directory to save the data (defaults to dataset directory)"
    )
    parser.add_argument(
        "--display_root",
        type=str,
        default=None,
        help='Root directory to display in paths (absolute for abs paths, Path(".") for relative)',
    )
    parser.add_argument(
        "--incontext_data",
        action="store_true",
        help="Have the excel files read in-context under 'excel_content' field (Default: False)",
    )
    args = parser.parse_args()
    print(args)
    if args.data_dir is None:
        # Save to the same directory as this script
        data_dir = Path(__file__).absolute().parent
    else:
        data_dir = Path(args.data_dir)

    save_data(args.split, data_dir, args.display_root, args.incontext_data)
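
A usage sketch for the preparation script above. The file name prepare.py and the output directory are assumptions; the flags and output layout follow from the argparse block and save_data.

# Command-line usage (flags defined in the argparse block above):
#   python prepare.py --split test --incontext_data
# Programmatic usage, assuming the script is importable as prepare:
from pathlib import Path
from prepare import save_data  # hypothetical module name

save_data(split="test", data_dir=Path("./dsbench_da_data"), display_root=None, incontext_data=False)
# Downloads and extracts the liqiang888/DSBench data_analysis files, then writes
# ./dsbench_da_data/test.jsonl: one JSON line per (task, question) pair with
# "problem", "expected_answer", "excel_paths", and task metadata fields.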

nemo_skills/evaluation/evaluator/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,7 @@
 )
 from nemo_skills.evaluation.evaluator.compute_eval import ComputeEvalEvaluator
 from nemo_skills.evaluation.evaluator.critpt import CritPtEvaluator
+from nemo_skills.evaluation.evaluator.dsbench import DSBenchEvaluator
 from nemo_skills.evaluation.evaluator.icpc import ICPCEvaluator
 from nemo_skills.evaluation.evaluator.ifbench import eval_ifbench
 from nemo_skills.evaluation.evaluator.ifeval import eval_if
@@ -76,6 +77,7 @@
     "bird": BirdEvaluator,
     "compute-eval": ComputeEvalEvaluator,
     "critpt": CritPtEvaluator,
+    "dsbench": DSBenchEvaluator,
 }
 
 # Validation: Ensure no overlap between class and function maps
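
To make the effect of this registration concrete, here is a hedged sketch of the lookup it enables. Only the "dsbench" key, the import, and the class name come from this diff; the map name and helper below are illustrative.

# Illustrative only: EVALUATOR_MAP stands in for the real mapping in this module.
from nemo_skills.evaluation.evaluator.dsbench import DSBenchEvaluator

EVALUATOR_MAP = {"dsbench": DSBenchEvaluator}  # plus the other entries shown above

def build_evaluator(eval_type: str, eval_config: dict):
    # a cmdline setting like ++eval_type=dsbench resolves through this map
    return EVALUATOR_MAP[eval_type](eval_config)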
nemo_skills/evaluation/evaluator/dsbench.py

Lines changed: 100 additions & 0 deletions

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import re
from typing import Any

from math_verify import StringExtractionConfig, parse, verify

from nemo_skills.evaluation.evaluator.math import MathEvaluator
from nemo_skills.evaluation.math_grader import math_equal
from nemo_skills.utils import get_logger_name

LOG = logging.getLogger(get_logger_name(__file__))


def relaxed_equal(gt_answer: Any, predicted_answer: Any) -> bool:
    """
    Relaxed equality check with:
    1. Case-insensitive MCQ matching
    2. Dict/list comparison using math_equal recursively
    """
    if predicted_answer is None:
        return gt_answer is None

    try:
        predicted_answer = json.loads(predicted_answer)
    except Exception:
        pass  # keep original string form
    try:
        gt_answer = json.loads(gt_answer)
    except Exception:
        pass  # keep original string form

    if isinstance(predicted_answer, dict):
        if not isinstance(gt_answer, dict):
            # check if any of the values in predicted_answer are equal to gt_answer
            return any(relaxed_equal(gt_answer, p) for p in predicted_answer.values())

        # check if all the keys in gt_answer are in predicted_answer and if the values
        # are equal; ok for predicted_answer to have more keys
        return all(
            k in predicted_answer and relaxed_equal(gt_answer[k], predicted_answer[k]) for k in gt_answer.keys()
        )

    if isinstance(predicted_answer, list):
        if not isinstance(gt_answer, list):
            # check if any of the values in predicted_answer are equal to gt_answer
            return any(relaxed_equal(gt_answer, p) for p in predicted_answer)
        # check if the lengths are equal and if all the values are equal
        return len(gt_answer) == len(predicted_answer) and all(
            relaxed_equal(e, p) for e, p in zip(gt_answer, predicted_answer)
        )

    # Try case-insensitive MCQ matching
    # TODO: add support for numeric and roman numeral MCQs (i.e. "1", "I", "2", "II", etc.)
    mcq_options = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    norm_gt_mcq = str(gt_answer).strip().upper()
    norm_pred_mcq = str(predicted_answer).strip().upper()
    is_mcq = re.fullmatch("|".join(mcq_options), norm_gt_mcq)
    if is_mcq:
        parsed_gt = parse(norm_gt_mcq, [StringExtractionConfig(strings=tuple(mcq_options))])
        parsed_pred = parse(norm_pred_mcq, [StringExtractionConfig(strings=tuple(mcq_options))])
        mcq_result = verify(parsed_gt, parsed_pred)
        if mcq_result:
            return mcq_result

    return math_equal(str(gt_answer), str(predicted_answer))


class DSBenchEvaluator(MathEvaluator):
    def __init__(self, config: dict, num_parallel_requests=10):
        super().__init__(config, num_parallel_requests)
        self.eval_config.extract_regex = r"(?:The final answer is |\\boxed=)(.+)$"

    async def eval_single(self, data_point: dict[str, Any]) -> dict[str, Any]:
        """Evaluate single DSBench problem with relaxed fallback."""
        # First try standard math evaluation
        data_point = await super().eval_single(data_point)

        # If symbolic_correct is False, try relaxed_equal
        if not data_point["symbolic_correct"]:
            expected_answer = data_point["expected_answer"]
            predicted_answer = data_point["predicted_answer"]

            if relaxed_equal(expected_answer, predicted_answer):
                data_point["symbolic_correct"] = True

        return data_point