-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathqa_translator.py
162 lines (137 loc) · 6.17 KB
/
qa_translator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os
import json
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from validator import HotpotEntryValidator
class QaTranslator(ABC):
    """Translates HotpotQA-style rows through an abstract text-translation backend.

    Subclasses implement :meth:`translate_text` (e.g. calling an LLM / MT API);
    this base class provides JSON round-tripping, per-row validation, and
    batched dataset translation with checkpoint/resume via a temp file.
    """

    @abstractmethod
    def translate_text(self, question_object: str) -> str:
        """Translate a raw string and return the (possibly fenced) completion."""

    def translate_json(self, json_dict: dict) -> dict:
        """Serialize ``json_dict``, translate it, and parse the completion back.

        Returns the parsed translated dict on success, or an error object
        ``{"has_error": True, "completion": ..., "error_messages": [...]}``
        when the completion is not valid JSON.
        """
        input_string = json.dumps(json_dict, ensure_ascii=False)
        translated_string = self.translate_text(input_string)
        # Models often wrap their output in a ```json fenced block, sometimes
        # with surrounding whitespace/newlines; strip() first so the fence
        # markers are actually at the ends before removing them.
        translated_string = translated_string.strip().removeprefix("```json").removesuffix("```")
        try:
            return json.loads(translated_string)
        except json.JSONDecodeError as e:
            return {
                "has_error": True,
                "completion": translated_string,
                "error_messages": [e.msg],
            }

    def qa_translate(self, row: dict) -> dict:
        """Translate one dataset row and validate the result.

        Only the textual fields are sent for translation; id/type/level and
        sentence ids are copied from the source row unchanged. Returns the
        translated row, or an error object with ``has_error=True`` when
        translation or validation fails.
        """
        to_translate = {
            'question': row['question'],
            'answer': row['answer'],
            'supporting_facts': row['supporting_facts'],
            'context': row['context']
        }
        translation = self.translate_json(to_translate)
        if translation.get("has_error"):
            return translation
        error_messages = HotpotEntryValidator.validate(translation)
        if error_messages:
            return {
                "has_error": True,
                "completion": translation,
                "error_messages": error_messages,
            }
        return {
            'id': row['id'],
            'question': translation['question'],
            'answer': translation['answer'],
            'type': row['type'],
            'level': row['level'],
            'supporting_facts': {
                'title': translation['supporting_facts']['title'],
                # sent_id is positional metadata, not text — keep the original.
                'sent_id': row['supporting_facts']['sent_id']
            },
            'context': translation['context']
        }

    def _drain_futures(self, futures, translation_list, id_set, temp_file, error_file):
        """Wait for all pending futures and checkpoint their results.

        Successes are appended to ``translation_list``/``temp_file`` and their
        ids added to ``id_set``; failures are appended to ``error_file``.
        Returns ``(n_success, n_error)``.
        """
        n_success = 0
        n_error = 0
        for future in as_completed(futures):
            translated_row = future.result()
            if translated_row.get("has_error"):
                json.dump(translated_row, error_file, ensure_ascii=False, indent=4)
                error_file.write(",\n")
                # Flush after every row so progress survives a crash/kill.
                error_file.flush()
                n_error += 1
            else:
                translation_list.append(translated_row)
                json.dump(translated_row, temp_file, ensure_ascii=False, indent=4)
                temp_file.write(",\n")
                temp_file.flush()
                id_set.add(translated_row["id"])
                n_success += 1
        print(f"Partial results saved in temporary file. n_success={n_success} / n_error={n_error}")
        return n_success, n_error

    def translate_dataset(self, dataset, temp_path, error_path,
                          max_rows=8000, num_threads=100) -> list:
        """Translate up to ``max_rows`` rows of ``dataset`` concurrently.

        Resumes from ``temp_path`` if it exists (rows already translated are
        skipped by id). Successful rows are checkpointed to ``temp_path``,
        failed rows to ``error_path``, both as comma-separated JSON objects.

        Returns the list of translated rows (recovered + newly translated).
        """
        start_time = time.time()
        translation_list = []
        id_set = set()
        if os.path.exists(temp_path):
            print("A temporary file exists, the process will continue where it stopped.")
            with open(temp_path, 'r', encoding='utf-8') as file:
                temp_file_str = file.read()
            # The temp file is a stream of ",\n"-separated JSON objects;
            # wrap it in brackets to parse it as one JSON array.
            temp_file_str = "[" + temp_file_str.rstrip(",\n") + "]"
            translation_list = json.loads(temp_file_str)
            for temp_element in translation_list:
                id_set.add(temp_element["id"])
            print(f"Entries recovered from the temporary file: Total={len(translation_list)}, Unique={len(id_set)}")
        with \
            ThreadPoolExecutor(max_workers=num_threads) as executor, \
            open(temp_path, "a", encoding="utf-8") as temp_file, \
            open(error_path, "a", encoding="utf-8") as error_file:
            futures = {}
            for i, row in enumerate(dataset):
                # Drain a full batch before submitting more work, so at most
                # num_threads requests are in flight at once.
                if len(futures) >= num_threads:
                    print("Waiting API to return the results.")
                    self._drain_futures(futures, translation_list, id_set, temp_file, error_file)
                    futures = {}
                if i >= max_rows:
                    print(f"Max number of rows achieved (max={max_rows}).")
                    break
                if row["id"] in id_set:
                    print(f"Already translated {i}: {row['id']}")
                    continue
                print(f"Submitting row {i}: {row['id']} for translation")
                futures[executor.submit(self.qa_translate, row)] = row["id"]
            print("Waiting API to return the last batch of results.")
            self._drain_futures(futures, translation_list, id_set, temp_file, error_file)
        elapsed_time = time.time() - start_time
        print(f"Time spent in execution: {elapsed_time:.3f} seconds.")
        return translation_list