-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'trintamaki/more-evals' into 'main'
More multimodal evals See merge request ADLR/megatron-lm!2174
- Loading branch information
Showing
10 changed files
with
1,206 additions
and
549 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import argparse | ||
import json | ||
|
||
from evaluate_mmmu import get_input_output_paths | ||
from evaluate_vqav2 import compute_vqa_accuracy | ||
|
||
|
||
def merge_input_files(input_path): | ||
"""Merge input files to a format compatible with the evaluator.""" | ||
input_file_paths, output_file_path = get_input_output_paths(input_path, task="AI2D") | ||
|
||
results = [] | ||
|
||
for input_file_path in input_file_paths: | ||
with open(input_file_path, "r") as input_file: | ||
for line in input_file: | ||
res = json.loads(line) | ||
results.append( | ||
{ | ||
"question_id": res["sample_id"], | ||
"answer": res["answer"], | ||
"gt_answer": res["gt_answer"], | ||
} | ||
) | ||
|
||
with open(output_file_path, "w") as output_file: | ||
json.dump(results, output_file) | ||
|
||
return output_file_path | ||
|
||
|
||
def ai2d_eval(input_path): | ||
"""Run AI2D evaluation.""" | ||
result_file_path = merge_input_files(input_path) | ||
avg_acc = compute_vqa_accuracy(result_file_path, task="AI2D") | ||
return avg_acc | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--input-path', type=str, help="Path to input file(s)") | ||
args = parser.parse_args() | ||
|
||
avg_acc = ai2d_eval(args.input_path) | ||
|
||
print(f"===== AI2D Accuracy {avg_acc:.2f}% =====") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import argparse | ||
import json | ||
import re | ||
|
||
from evaluate_mmmu import get_input_output_paths | ||
from MMMU.mmmu.utils.eval_utils import parse_multi_choice_response | ||
from open_flamingo.eval.vqa_metric import VQAEval | ||
|
||
|
||
def merge_input_files(input_path): | ||
"""Merge input files to a format compatible with the evaluator.""" | ||
input_file_paths, output_file_path = get_input_output_paths(input_path, task="MathVista") | ||
|
||
results = [] | ||
|
||
for input_file_path in input_file_paths: | ||
with open(input_file_path, "r") as input_file: | ||
for line in input_file: | ||
res = json.loads(line) | ||
results.append(res) | ||
|
||
with open(output_file_path, "w") as output_file: | ||
json.dump(results, output_file) | ||
|
||
return output_file_path | ||
|
||
|
||
def extra_processing(text): | ||
"""Extra processing.""" | ||
# Max decimal point capped to 2 decimal point | ||
regex = re.compile(r'^\d+\.\d+$') | ||
decimal = regex.findall(text) | ||
|
||
if len(decimal) > 0: | ||
non_decimal = len(decimal[0].split(".")[0]) | ||
|
||
# if decimal values are all 0, trim them | ||
decimal_digits = [int(d) for d in decimal[0].split(".")[1]] | ||
if sum(decimal_digits) == 0: | ||
text = decimal[0][:non_decimal] | ||
else: | ||
text = decimal[0][: non_decimal + 3] | ||
|
||
# remove % and trailing . | ||
text = text.replace("%", "") | ||
if text[-1] == ".": | ||
text = text[:-1] | ||
|
||
return text | ||
|
||
|
||
def extract_answer(text): | ||
"""Extract answer.""" | ||
alphabet = re.findall(r'[a-zA-Z]+', text) | ||
if len(alphabet) > 0 and "e+" not in text: | ||
template = re.findall(r'answer is -*\d+\.*\d*', text) | ||
if len(template) > 0: | ||
text = template[0] | ||
|
||
numbers = re.findall(r'-*\d+\.*\d*', text) | ||
text = numbers[0] if len(numbers) > 0 else text | ||
|
||
return text | ||
|
||
|
||
def compute_mathvista_accuracy(result_file): | ||
"""Compute MathVista accuracy.""" | ||
merged_results = json.load(open(result_file)) | ||
|
||
vqa = VQAEval(vqa=None, vqaRes=None) | ||
acc = 0 | ||
for res in merged_results: | ||
pred_ans = res["answer"] | ||
if res["question_type"] == "multi_choice": | ||
pred_ans = parse_multi_choice_response(pred_ans, res["all_choices"], res["index2ans"]) | ||
else: | ||
pred_ans = vqa.processPunctuation(pred_ans) | ||
pred_ans = vqa.processDigitArticle(pred_ans) | ||
# Extra processing and extraction. | ||
pred_ans = extra_processing(pred_ans) | ||
pred_ans = extract_answer(pred_ans) | ||
|
||
gt_ans = res["gt_answer"] | ||
if isinstance(gt_ans, list): | ||
assert len(gt_ans) == 1, f"Expected 1 groundtruth, got {gt_ans}" | ||
gt_ans = gt_ans[0] | ||
|
||
if res["question_type"] != "multi_choice": | ||
gt_ans = vqa.processPunctuation(gt_ans) | ||
gt_ans = vqa.processDigitArticle(gt_ans) | ||
|
||
gt_ans = extra_processing(gt_ans) | ||
|
||
if pred_ans == gt_ans: | ||
acc += 1 | ||
acc = acc / len(merged_results) * 100 | ||
return acc | ||
|
||
|
||
def mathvista_eval(input_path): | ||
"""Run MathVista evaluation.""" | ||
result_file_path = merge_input_files(input_path) | ||
acc = compute_mathvista_accuracy(result_file_path) | ||
return acc | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--input-path', type=str, help="Path to input file(s)") | ||
args = parser.parse_args() | ||
|
||
acc = mathvista_eval(args.input_path) | ||
|
||
print(f"===== MathVista accuracy: {acc} =====") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import argparse | ||
import json | ||
|
||
from evaluate_mmmu import get_input_output_paths | ||
|
||
|
||
def merge_input_files(input_path): | ||
"""Merge input files to a format compatible with the evaluator.""" | ||
input_file_paths, output_file_path = get_input_output_paths(input_path, task="OCRBench") | ||
|
||
results = [] | ||
|
||
for input_file_path in input_file_paths: | ||
with open(input_file_path, "r") as input_file: | ||
for line in input_file: | ||
res = json.loads(line) | ||
results.append(res) | ||
|
||
with open(output_file_path, "w") as output_file: | ||
json.dump(results, output_file) | ||
|
||
return output_file_path | ||
|
||
|
||
def compute_ocrbench_score(result_file): | ||
"""Compute OCRBench score.""" | ||
merged_results = json.load(open(result_file)) | ||
|
||
# OCRBench score calculation is adopted from https://github.com/Yuliang-Liu/MultimodalOCR/blob/1b7713f44c91f30f64efb6d3e494c416861ef15f/example.py#L1 | ||
# MIT License. Copyright (c) 2023 Yuliang Liu | ||
score = { | ||
"Regular Text Recognition": 0, | ||
"Irregular Text Recognition": 0, | ||
"Artistic Text Recognition": 0, | ||
"Handwriting Recognition": 0, | ||
"Digit String Recognition": 0, | ||
"Non-Semantic Text Recognition": 0, | ||
"Scene Text-centric VQA": 0, | ||
"Doc-oriented VQA": 0, | ||
"Doc-oriented VQA": 0, | ||
"Key Information Extraction": 0, | ||
"Handwritten Mathematical Expression Recognition": 0, | ||
} | ||
|
||
for res in merged_results: | ||
predict = res["answer"] | ||
answers = res["gt_answer"] | ||
|
||
dataset_name = res["dataset_name"] | ||
ocr_type = res["data_type"] | ||
|
||
if dataset_name == "HME100k": | ||
if isinstance(answers, list): | ||
for j in range(len(answers)): | ||
answer = answers[j].strip().replace("\n", " ").replace(" ", "") | ||
predict = predict.strip().replace("\n", " ").replace(" ", "") | ||
if answer in predict: | ||
score[ocr_type] += 1 | ||
else: | ||
answers = answers.strip().replace("\n", " ").replace(" ", "") | ||
predict = predict.strip().replace("\n", " ").replace(" ", "") | ||
if answers in predict: | ||
score[ocr_type] += 1 | ||
else: | ||
if isinstance(answers, list): | ||
for j in range(len(answers)): | ||
answer = answers[j].lower().strip().replace("\n", " ") | ||
predict = predict.lower().strip().replace("\n", " ") | ||
if answer in predict: | ||
score[ocr_type] += 1 | ||
else: | ||
answers = answers.lower().strip().replace("\n", " ") | ||
predict = predict.lower().strip().replace("\n", " ") | ||
if answers in predict: | ||
score[ocr_type] += 1 | ||
|
||
recognition_score = ( | ||
score['Regular Text Recognition'] | ||
+ score['Irregular Text Recognition'] | ||
+ score['Artistic Text Recognition'] | ||
+ score['Handwriting Recognition'] | ||
+ score['Digit String Recognition'] | ||
+ score['Non-Semantic Text Recognition'] | ||
) | ||
final_score = ( | ||
recognition_score | ||
+ score['Scene Text-centric VQA'] | ||
+ score['Doc-oriented VQA'] | ||
+ score['Key Information Extraction'] | ||
+ score['Handwritten Mathematical Expression Recognition'] | ||
) | ||
result_log = f"""###########################OCRBench############################## | ||
Text Recognition(Total 300): {recognition_score} | ||
------------------Details of Recognition Score------------------- | ||
Regular Text Recognition(Total 50): {score['Regular Text Recognition']} | ||
Irregular Text Recognition(Total 50): {score['Irregular Text Recognition']} | ||
Artistic Text Recognition(Total 50): {score['Artistic Text Recognition']} | ||
Handwriting Recognition(Total 50): {score['Handwriting Recognition']} | ||
Digit String Recognition(Total 50): {score['Digit String Recognition']} | ||
Non-Semantic Text Recognition(Total 50): {score['Non-Semantic Text Recognition']} | ||
---------------------------------------------------------------- | ||
Scene Text-centric VQA(Total 200): {score['Scene Text-centric VQA']} | ||
---------------------------------------------------------------- | ||
Doc-oriented VQA(Total 200): {score['Doc-oriented VQA']} | ||
---------------------------------------------------------------- | ||
Key Information Extraction(Total 200): {score['Key Information Extraction']} | ||
---------------------------------------------------------------- | ||
Handwritten Mathematical Expression Recognition(Total 100): {score['Handwritten Mathematical Expression Recognition']} | ||
----------------------Final Score------------------------------- | ||
Final Score(Total 1000): {final_score}""" | ||
|
||
return result_log, final_score | ||
|
||
|
||
def ocrbench_eval(input_path): | ||
"""Run OCRBench evaluation.""" | ||
result_file_path = merge_input_files(input_path) | ||
result_log, score = compute_ocrbench_score(result_file_path) | ||
return result_log, score | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--input-path', type=str, help="Path to input file(s)") | ||
args = parser.parse_args() | ||
|
||
result_log, _ = ocrbench_eval(args.input_path) | ||
|
||
print(result_log) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.