Commit

Merge branch 'trintamaki/more-evals' into 'main'
More multimodal evals

See merge request ADLR/megatron-lm!2174
jon-barker committed Oct 31, 2024
2 parents d546182 + 9ed8473 commit 2e2bdf6
Showing 10 changed files with 1,206 additions and 549 deletions.
46 changes: 46 additions & 0 deletions examples/multimodal/evaluate_ai2d.py
@@ -0,0 +1,46 @@
import argparse
import json

from evaluate_mmmu import get_input_output_paths
from evaluate_vqav2 import compute_vqa_accuracy


def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="AI2D")

results = []

for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
results.append(
{
"question_id": res["sample_id"],
"answer": res["answer"],
"gt_answer": res["gt_answer"],
}
)

with open(output_file_path, "w") as output_file:
json.dump(results, output_file)

return output_file_path


def ai2d_eval(input_path):
"""Run AI2D evaluation."""
result_file_path = merge_input_files(input_path)
avg_acc = compute_vqa_accuracy(result_file_path, task="AI2D")
return avg_acc


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input-path', type=str, help="Path to input file(s)")
args = parser.parse_args()

avg_acc = ai2d_eval(args.input_path)

print(f"===== AI2D Accuracy {avg_acc:.2f}% =====")
2 changes: 1 addition & 1 deletion examples/multimodal/evaluate_chartqa.py
@@ -28,7 +28,7 @@ def merge_input_files(input_path):
def chartqa_eval(input_path):
    """Run ChartQA evaluation."""
    result_file_path = merge_input_files(input_path)
-    return compute_vqa_accuracy(result_file_path, use_chartqa_metric=True)
+    return compute_vqa_accuracy(result_file_path, task="ChartQA")


if __name__ == "__main__":
114 changes: 114 additions & 0 deletions examples/multimodal/evaluate_mathvista.py
@@ -0,0 +1,114 @@
import argparse
import json
import re

from evaluate_mmmu import get_input_output_paths
from MMMU.mmmu.utils.eval_utils import parse_multi_choice_response
from open_flamingo.eval.vqa_metric import VQAEval


def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="MathVista")

results = []

for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
results.append(res)

with open(output_file_path, "w") as output_file:
json.dump(results, output_file)

return output_file_path


def extra_processing(text):
"""Extra processing."""
# Max decimal point capped to 2 decimal point
regex = re.compile(r'^\d+\.\d+$')
decimal = regex.findall(text)

if len(decimal) > 0:
non_decimal = len(decimal[0].split(".")[0])

# if decimal values are all 0, trim them
decimal_digits = [int(d) for d in decimal[0].split(".")[1]]
if sum(decimal_digits) == 0:
text = decimal[0][:non_decimal]
else:
text = decimal[0][: non_decimal + 3]

# remove % and trailing .
text = text.replace("%", "")
if text[-1] == ".":
text = text[:-1]

return text


def extract_answer(text):
"""Extract answer."""
alphabet = re.findall(r'[a-zA-Z]+', text)
if len(alphabet) > 0 and "e+" not in text:
template = re.findall(r'answer is -*\d+\.*\d*', text)
if len(template) > 0:
text = template[0]

numbers = re.findall(r'-*\d+\.*\d*', text)
text = numbers[0] if len(numbers) > 0 else text

return text


def compute_mathvista_accuracy(result_file):
"""Compute MathVista accuracy."""
merged_results = json.load(open(result_file))

vqa = VQAEval(vqa=None, vqaRes=None)
acc = 0
for res in merged_results:
pred_ans = res["answer"]
if res["question_type"] == "multi_choice":
pred_ans = parse_multi_choice_response(pred_ans, res["all_choices"], res["index2ans"])
else:
pred_ans = vqa.processPunctuation(pred_ans)
pred_ans = vqa.processDigitArticle(pred_ans)
# Extra processing and extraction.
pred_ans = extra_processing(pred_ans)
pred_ans = extract_answer(pred_ans)

gt_ans = res["gt_answer"]
if isinstance(gt_ans, list):
assert len(gt_ans) == 1, f"Expected 1 groundtruth, got {gt_ans}"
gt_ans = gt_ans[0]

if res["question_type"] != "multi_choice":
gt_ans = vqa.processPunctuation(gt_ans)
gt_ans = vqa.processDigitArticle(gt_ans)

gt_ans = extra_processing(gt_ans)

if pred_ans == gt_ans:
acc += 1
acc = acc / len(merged_results) * 100
return acc


def mathvista_eval(input_path):
"""Run MathVista evaluation."""
result_file_path = merge_input_files(input_path)
acc = compute_mathvista_accuracy(result_file_path)
return acc


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-path', type=str, help="Path to input file(s)")
    args = parser.parse_args()

    acc = mathvista_eval(args.input_path)

    print(f"===== MathVista accuracy: {acc} =====")
10 changes: 9 additions & 1 deletion examples/multimodal/evaluate_mmmu.py
@@ -40,6 +40,14 @@ def convert_to_mmmu_format(input_path):
                sample_id = res["sample_id"]
                prediction = res["prediction"]

+                if res["question_type"] == "multiple-choice":
+                    from MMMU.mmmu.utils.eval_utils import parse_multi_choice_response
+
+                    prediction = parse_multi_choice_response(
+                        prediction, res["all_choices"], res["index2ans"]
+                    )
+
+                # MMMU eval script expects just a sample_id to prediction mapping.
                output[sample_id] = prediction

    with open(output_file_path, "w") as output_file:
@@ -69,7 +77,7 @@ def mmmu_eval(input_path, groundtruth_path):
    print(output.stderr)
    print(output.stdout)

-    m = re.search("'Overall': {'num': \d, 'acc': (\d.\d+)}", output.stdout)
+    m = re.search("'Overall': {'num': \d+, 'acc': (\d.\d+)}", output.stdout)

    return float(m.group(1)) * 100.0

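
For context, the regex change above widens the 'num' field to multi-digit counts when scraping the MMMU script's stdout. A hedged illustration of the kind of line it is meant to match (the exact output format of the MMMU eval script is an assumption here, and the numbers are made up):

import re

stdout = "{'Overall': {'num': 900, 'acc': 0.5}}"

m = re.search(r"'Overall': {'num': \d+, 'acc': (\d.\d+)}", stdout)
print(float(m.group(1)) * 100.0)  # 50.0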
129 changes: 129 additions & 0 deletions examples/multimodal/evaluate_ocrbench.py
@@ -0,0 +1,129 @@
import argparse
import json

from evaluate_mmmu import get_input_output_paths


def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="OCRBench")

results = []

for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
results.append(res)

with open(output_file_path, "w") as output_file:
json.dump(results, output_file)

return output_file_path


def compute_ocrbench_score(result_file):
"""Compute OCRBench score."""
merged_results = json.load(open(result_file))

# OCRBench score calculation is adopted from https://github.com/Yuliang-Liu/MultimodalOCR/blob/1b7713f44c91f30f64efb6d3e494c416861ef15f/example.py#L1
# MIT License. Copyright (c) 2023 Yuliang Liu
score = {
"Regular Text Recognition": 0,
"Irregular Text Recognition": 0,
"Artistic Text Recognition": 0,
"Handwriting Recognition": 0,
"Digit String Recognition": 0,
"Non-Semantic Text Recognition": 0,
"Scene Text-centric VQA": 0,
"Doc-oriented VQA": 0,
"Doc-oriented VQA": 0,
"Key Information Extraction": 0,
"Handwritten Mathematical Expression Recognition": 0,
}

for res in merged_results:
predict = res["answer"]
answers = res["gt_answer"]

dataset_name = res["dataset_name"]
ocr_type = res["data_type"]

if dataset_name == "HME100k":
if isinstance(answers, list):
for j in range(len(answers)):
answer = answers[j].strip().replace("\n", " ").replace(" ", "")
predict = predict.strip().replace("\n", " ").replace(" ", "")
if answer in predict:
score[ocr_type] += 1
else:
answers = answers.strip().replace("\n", " ").replace(" ", "")
predict = predict.strip().replace("\n", " ").replace(" ", "")
if answers in predict:
score[ocr_type] += 1
else:
if isinstance(answers, list):
for j in range(len(answers)):
answer = answers[j].lower().strip().replace("\n", " ")
predict = predict.lower().strip().replace("\n", " ")
if answer in predict:
score[ocr_type] += 1
else:
answers = answers.lower().strip().replace("\n", " ")
predict = predict.lower().strip().replace("\n", " ")
if answers in predict:
score[ocr_type] += 1

recognition_score = (
score['Regular Text Recognition']
+ score['Irregular Text Recognition']
+ score['Artistic Text Recognition']
+ score['Handwriting Recognition']
+ score['Digit String Recognition']
+ score['Non-Semantic Text Recognition']
)
final_score = (
recognition_score
+ score['Scene Text-centric VQA']
+ score['Doc-oriented VQA']
+ score['Key Information Extraction']
+ score['Handwritten Mathematical Expression Recognition']
)
result_log = f"""###########################OCRBench##############################
Text Recognition(Total 300): {recognition_score}
------------------Details of Recognition Score-------------------
Regular Text Recognition(Total 50): {score['Regular Text Recognition']}
Irregular Text Recognition(Total 50): {score['Irregular Text Recognition']}
Artistic Text Recognition(Total 50): {score['Artistic Text Recognition']}
Handwriting Recognition(Total 50): {score['Handwriting Recognition']}
Digit String Recognition(Total 50): {score['Digit String Recognition']}
Non-Semantic Text Recognition(Total 50): {score['Non-Semantic Text Recognition']}
----------------------------------------------------------------
Scene Text-centric VQA(Total 200): {score['Scene Text-centric VQA']}
----------------------------------------------------------------
Doc-oriented VQA(Total 200): {score['Doc-oriented VQA']}
----------------------------------------------------------------
Key Information Extraction(Total 200): {score['Key Information Extraction']}
----------------------------------------------------------------
Handwritten Mathematical Expression Recognition(Total 100): {score['Handwritten Mathematical Expression Recognition']}
----------------------Final Score-------------------------------
Final Score(Total 1000): {final_score}"""

return result_log, final_score


def ocrbench_eval(input_path):
"""Run OCRBench evaluation."""
result_file_path = merge_input_files(input_path)
result_log, score = compute_ocrbench_score(result_file_path)
return result_log, score


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-path', type=str, help="Path to input file(s)")
    args = parser.parse_args()

    result_log, _ = ocrbench_eval(args.input_path)

    print(result_log)
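
The scorer above awards one point per sample whenever a ground-truth string is contained in the prediction (case-insensitive, except for HME100k samples where whitespace is stripped instead of lower-casing). A minimal sketch of a compatible merged-results file, with made-up field values, assuming this file is importable as evaluate_ocrbench:

import json
import tempfile

from evaluate_ocrbench import compute_ocrbench_score

# One hypothetical record; the keys match what compute_ocrbench_score() reads above.
results = [
    {
        "dataset_name": "IIIT5K",  # made-up dataset name (anything other than "HME100k")
        "data_type": "Regular Text Recognition",
        "answer": "The sign says HELLO",
        "gt_answer": ["hello"],
    }
]

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(results, f)

result_log, final_score = compute_ocrbench_score(f.name)
print(final_score)  # 1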
2 changes: 1 addition & 1 deletion examples/multimodal/evaluate_textvqa.py
@@ -35,7 +35,7 @@ def merge_input_files(input_path):
def textvqa_eval(input_path):
    """Run TextVQA evaluation."""
    result_file_path = merge_input_files(input_path)
-    avg_acc = compute_vqa_accuracy(result_file_path)
+    avg_acc = compute_vqa_accuracy(result_file_path, task="TextVQA")
    return avg_acc


16 changes: 11 additions & 5 deletions examples/multimodal/evaluate_vqav2.py
@@ -34,7 +34,7 @@ def is_number(n: str):
        return False


-def compute_vqa_accuracy(result_file, use_chartqa_metric=False):
+def compute_vqa_accuracy(result_file, task):
    """Compute VQA accuracy."""
    merged_results = json.load(open(result_file))

@@ -51,8 +51,8 @@ def compute_vqa_accuracy(result_file, use_chartqa_metric=False):

        # ChartQA uses relaxed accuracy:
        # "We consider an answer to be correct if it is within 5% of the gold answer.
-        # For non-numeric answers, we still need an exact match to consider an answer to be correct."
-        if use_chartqa_metric:
+        # For non-numeric answers, we still need an exact match to consider an answer to be correct."
+        if task == "ChartQA":
            acc = 0.0
            assert len(gt) == 1, "expected exactly one groundtruth answer."
            gt = gt[0]
@@ -66,10 +66,16 @@ def compute_vqa_accuracy(result_file, use_chartqa_metric=False):
                acc = 1.0

            all_acc.append(acc)
-        else:
+        elif task in ("VQAv2", "TextVQA"):
            num_match = sum([pred == ans for ans in gt])
            acc = min(1.0, num_match / 3.0)
            all_acc.append(acc)
+        elif task == "AI2D":
+            assert len(gt) == 1, f"Expected exactly 1 GT, got {gt}"
+            acc = pred == gt[0]
+            all_acc.append(acc)
+        else:
+            raise NotImplementedError(f"unknown task {task}")

    acc_avg = sum(all_acc) / len(all_acc) * 100

@@ -79,7 +85,7 @@ def compute_vqa_accuracy(result_file, use_chartqa_metric=False):
def vqav2_eval(input_path):
"""Run VQAv2 evaluation."""
result_file = merge_input_files(input_path)
avg_acc = compute_vqa_accuracy(result_file)
avg_acc = compute_vqa_accuracy(result_file, task="VQAv2")
return avg_acc


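
With this change compute_vqa_accuracy dispatches on the task name: ChartQA uses relaxed accuracy (a numeric prediction within 5% of the gold answer counts as correct), VQAv2 and TextVQA use the standard min(matches / 3, 1) rule over the annotator answers, and AI2D requires an exact match against a single ground truth. A simplified standalone sketch of the first two rules (it skips the VQA punctuation/article normalization and the "%" stripping done in the module):

def relaxed_accuracy(pred: str, gt: str) -> float:
    """ChartQA-style rule: numeric predictions within 5% of the gold answer are correct."""
    try:
        p, g = float(pred), float(gt)
        return 1.0 if g * 0.95 <= p <= g * 1.05 else 0.0
    except ValueError:
        return 1.0 if pred == gt else 0.0


def vqa_accuracy(pred: str, gts: list) -> float:
    """VQAv2/TextVQA-style rule: credit grows with the number of matching annotator answers, capped at 1."""
    return min(1.0, sum(pred == ans for ans in gts) / 3.0)


print(relaxed_accuracy("102", "100"))  # 1.0 (within 5%)
print(relaxed_accuracy("90", "100"))   # 0.0
print(vqa_accuracy("cat", ["cat", "cat", "cat", "dog"]))  # 1.0
print(vqa_accuracy("cat", ["cat", "dog", "bird"]))        # ~0.33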