
Commit b65c7d1

Add script for evaluating MRR, recall@k
1 parent f1d849b commit b65c7d1

File tree

1 file changed: +225 -0 lines changed


utils/evaluate_retrieval.py

Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2021 Apple Inc. All Rights Reserved.
#

from argparse import ArgumentParser
import json
import logging
from multiprocessing import Pool
from pathlib import Path
from typing import Dict, List, Tuple

from evaluate_qa import compute_exact, compute_f1
from span_heuristic import find_closest_span_match

"""
Functions for evaluating passage retrieval.

This is used to compute MRR (mean reciprocal rank), Recall@10, and Recall@100 in Table 5 of the paper.
"""

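# Each input .jsonl record carries the retrieval result for one passage of one question, with the
# fields listed in the --retrieved-passages-pattern help below. A minimal sketch of one record
# (the values here are illustrative, not taken from the dataset):
#
#   {"Conversation-ID": 1, "Turn-ID": 2, "docid": "<passage id>", "rank": 1,
#    "content": "<passage text>", "answer": "<gold answer from QReCC>"}
#
# compute_f1_for_retrieved_passage() below adds 'heuristic_answer' and 'f1' to each record before aggregation.
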
RELEVANCE_THRESHOLD = 0.8


def compute_f1_for_retrieved_passage(line: str) -> dict:
    """
    Given a serialized JSON line, with fields 'content' and 'answer', find the closest span matching answer,
    update the deserialized dict with the span and F1 score, and return the dict.
    """
    data = json.loads(line)
    content, answer = data['content'], data['answer']

    # If there is no answer, although the closest extractive answer is '', in the MRR and recall@k functions below
    # we do not count any passage for these questions as relevant.
    if len(answer) < 1:
        data['heuristic_answer'] = ''
        data['f1'] = compute_f1(answer, '')
        return data

    best_span, best_f1 = find_closest_span_match(content, answer)

    data['heuristic_answer'] = best_span
    data['f1'] = best_f1

    return data

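# Example (toy values, assuming compute_f1 scores token overlap as in the SQuAD-style QA evaluation):
#   compute_f1_for_retrieved_passage('{"content": "It opened in 1889.", "answer": "1889", ...}')
#   would return the same dict extended with, e.g., {'heuristic_answer': '1889', 'f1': 1.0},
#   while a record with an empty 'answer' gets heuristic_answer='' and is never counted as relevant below.
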
def compute_mean_reciprocal_rank(
    question_id_to_docs: Dict[str, List[dict]], relevance_threshold: float
) -> float:
    """Given a dictionary mapping a question id to a list of docs, find the mean reciprocal rank."""
    recip_rank_sum = 0
    for qid, docs in question_id_to_docs.items():
        top_rank = float('inf')
        for doc in docs:
            if len(doc['answer']) > 0 and doc['f1'] >= relevance_threshold:
                top_rank = min(top_rank, doc['rank'])

        recip_rank = 1 / top_rank if top_rank != float('inf') else 0
        recip_rank_sum += recip_rank

    return recip_rank_sum / len(question_id_to_docs)

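# Worked example (toy numbers): with two questions, where the first relevant passage
# (non-empty answer with F1 >= relevance_threshold) appears at rank 2 for question A and
# no relevant passage is retrieved for question B, the reciprocal ranks are 1/2 and 0,
# so MRR = (0.5 + 0) / 2 = 0.25.
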
def compute_recall_at_k(
    question_id_to_docs: Dict[str, List[dict]], k: int, relevance_threshold: float
) -> float:
    """
    Given a dictionary mapping a question id to a list of docs, find the recall@k.

    We define recall@k = 1.0 if any document in the top-k is relevant, and 0 otherwise.
    """
    relevant_doc_found_total = 0
    for qid, docs in question_id_to_docs.items():
        relevant_doc_found = 0
        for doc in docs:
            if len(doc['answer']) > 0 and doc['f1'] >= relevance_threshold and doc['rank'] <= k:
                relevant_doc_found = 1
                break

        relevant_doc_found_total += relevant_doc_found

    return relevant_doc_found_total / len(question_id_to_docs)

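# Worked example (toy numbers): with k=10 and the same two questions as above, question A has a
# relevant passage at rank 2 (<= 10) and question B has none, so Recall@10 = 1 / 2 = 0.5.
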
def compute_extractive_upper_bounds(
    question_id_to_docs: Dict[str, List[dict]], temp_files_directory: Path
) -> Tuple[float, float]:
    """Given a dictionary mapping a question id to a list of docs, find the extractive upper bounds of (EM, F1)."""
    total_em, total_f1 = 0, 0.0
    with open(temp_files_directory / 'retrieved-passages-relevant-f1.jsonl', 'w') as outfile:
        for qid, docs in question_id_to_docs.items():
            best_em, best_f1 = 0, 0.0
            best_doc = docs[0]
            for doc in docs:
                em = compute_exact(doc['answer'], doc['heuristic_answer'])
                f1 = compute_f1(doc['answer'], doc['heuristic_answer'])
                if f1 > best_f1:
                    best_doc = doc
                best_em = max(best_em, em)
                best_f1 = max(best_f1, f1)
                if best_em == 1 and best_f1 == 1.0:
                    break

            total_em += best_em
            total_f1 += best_f1

            outfile.write(json.dumps(best_doc) + '\n')

    return (
        total_em / len(question_id_to_docs),
        total_f1 / len(question_id_to_docs),
    )

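# Worked example (toy numbers): if the best heuristic answer across a question's retrieved passages
# scores EM=0 and F1=0.8 against the gold answer, that question contributes 0 and 0.8 to the averages.
# The resulting (EM, F1) pair is the best a purely extractive reader restricted to these passages could reach.
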
def get_unique_relevant_docs_count(
    question_id_to_docs: Dict[str, List[dict]], relevance_threshold: float
) -> int:
    """Given a dictionary mapping a question id to a list of docs, find the number of unique relevant docs."""
    unique_relevant_docs = set()
    for qid, docs in question_id_to_docs.items():
        for doc in docs:
            if len(doc['answer']) > 0 and doc['f1'] >= relevance_threshold:
                unique_relevant_docs.add(doc['docid'])

    return len(unique_relevant_docs)


def get_average_relevant_docs_per_question(
    question_id_to_docs: Dict[str, List[dict]], relevance_threshold: float
) -> float:
    """Given a dictionary mapping a question id to a list of docs, find the average number of relevant docs per question."""
    relevant_docs = 0
    for qid, docs in question_id_to_docs.items():
        for doc in docs:
            if len(doc['answer']) > 0 and doc['f1'] >= relevance_threshold:
                relevant_docs += 1

    return relevant_docs / len(question_id_to_docs)

def main(retrieved_passages_pattern: str, temp_files_directory: str, workers: int):
    retrieved_passages_files = Path().glob(retrieved_passages_pattern)
    temp_files_directory = Path(temp_files_directory)
    temp_files_directory.mkdir(exist_ok=True, parents=True)

    question_id_to_docs = {}

    for retrieved_passages_file in retrieved_passages_files:
        with open(retrieved_passages_file) as infile:
            with Pool(workers) as p:
                for i, passage_results in enumerate(
                    p.imap(compute_f1_for_retrieved_passage, infile)
                ):
                    if (i + 1) % 5000 == 0:
                        logging.info(
                            f'Processing {retrieved_passages_file.name}, {i + 1} lines done...'
                        )

                    qid = f"{passage_results['Conversation-ID']}_{passage_results['Turn-ID']}"
                    if qid not in question_id_to_docs:
                        question_id_to_docs[qid] = []

                    question_id_to_docs[qid].append(
                        {
                            'Conversation-ID': passage_results['Conversation-ID'],
                            'Turn-ID': passage_results['Turn-ID'],
                            'docid': passage_results['docid'],
                            'content': passage_results['content'],
                            'rank': passage_results['rank'],
                            'answer': passage_results['answer'],
                            'heuristic_answer': passage_results['heuristic_answer'],
                            'f1': passage_results['f1'],
                        }
                    )

    print('Final metrics:')
    unique_relevant_docs = get_unique_relevant_docs_count(question_id_to_docs, RELEVANCE_THRESHOLD)
    unique_docs_perfect_f1 = get_unique_relevant_docs_count(question_id_to_docs, 1.0)
    avg_relevant_docs_per_question = get_average_relevant_docs_per_question(
        question_id_to_docs, 1.0
    )

    print(f'Total number of unique queries: {len(question_id_to_docs)}')
    print(f'Total number of unique relevant docs: {unique_relevant_docs}')
    print(f'Total number of unique docs with F1=1.0: {unique_docs_perfect_f1}')
    print(f'Average number of relevant docs per query: {avg_relevant_docs_per_question}')

    mrr = compute_mean_reciprocal_rank(question_id_to_docs, RELEVANCE_THRESHOLD)
    recall_at_10 = compute_recall_at_k(question_id_to_docs, 10, RELEVANCE_THRESHOLD)
    recall_at_100 = compute_recall_at_k(question_id_to_docs, 100, RELEVANCE_THRESHOLD)
    print(f'Mean Reciprocal Rank (MRR): {mrr:.4f}')
    print(f'Recall@10: {recall_at_10 * 100:.2f}%')
    print(f'Recall@100: {recall_at_100 * 100:.2f}%')

    em_upper_bound, f1_upper_bound = compute_extractive_upper_bounds(
        question_id_to_docs, temp_files_directory
    )
    print(f'Extractive Upper Bound for EM (100 point scale): {em_upper_bound * 100:.2f}')
    print(f'Extractive Upper Bound for F1 (100 point scale): {f1_upper_bound * 100:.2f}')

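# Example invocation (the glob pattern below is hypothetical; see the argument help strings for details):
#   python utils/evaluate_retrieval.py \
#       --retrieved-passages-pattern 'retrieval-output/*.jsonl' \
#       --temp-files-directory /tmp/qrecc-retrieval-eval \
#       --workers 8
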
if __name__ == '__main__':
    parser = ArgumentParser(description='Passage retrieval evaluation')
    parser.add_argument(
        '--retrieved-passages-pattern',
        required=True,
        help="""A globbing pattern to select .jsonl files containing retrieved passages.
        Each json line should contain the fields 'Conversation-ID', 'Turn-ID', 'docid', 'content', 'answer', 'rank'.
        'answer' is the gold answer given in the QReCC dataset and rank is the rank of the document starting from 1.""",
    )
    parser.add_argument(
        '--temp-files-directory',
        default='/tmp/qrecc-retrieval-eval',
        help='Directory to store temporary files containing F1 scores, which can be used for debugging and analysis',
    )
    parser.add_argument(
        '--workers', default=8, type=int, help='Number of workers for parallel processing',
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    main(args.retrieved_passages_pattern, args.temp_files_directory, args.workers)
