Skip to content

Commit b8b4f74

Browse files
committed
Print accuracy@k when using beam search
1 parent 022b600 commit b8b4f74

File tree

1 file changed

+27
-11
lines changed

1 file changed

+27
-11
lines changed

model.py

+27-11
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,10 @@ def train(self):
104104
self.epochs_trained += self.config.SAVE_EVERY_EPOCHS
105105
print('Finished %d epochs' % self.config.SAVE_EVERY_EPOCHS)
106106
results, precision, recall, f1 = self.evaluate()
107-
print('Accuracy after %d epochs: %.5f' % (self.epochs_trained, results))
107+
if self.config.BEAM_WIDTH == 0:
108+
print('Accuracy after %d epochs: %.5f' % (self.epochs_trained, results))
109+
else:
110+
print('Accuracy after {} epochs: {}'.format(self.epochs_trained, results))
108111
print('After %d epochs: Precision: %.5f, recall: %.5f, F1: %.5f' % (
109112
self.epochs_trained, precision, recall, f1))
110113
if f1 > best_f1:
@@ -167,7 +170,8 @@ def evaluate(self, release=False):
167170
with open(model_dirname + '/log.txt', 'w') as output_file, open(ref_file_name, 'w') as ref_file, open(
168171
predicted_file_name,
169172
'w') as pred_file:
170-
num_correct_predictions = 0
173+
num_correct_predictions = 0 if self.config.BEAM_WIDTH == 0 \
174+
else np.zeros([self.config.BEAM_WIDTH], dtype=np.int32)
171175
total_predictions = 0
172176
total_prediction_batches = 0
173177
true_positive, false_positive, false_negative = 0, 0, 0
@@ -223,17 +227,29 @@ def evaluate(self, release=False):
223227

224228
def update_correct_predictions(self, num_correct_predictions, output_file, results):
225229
for original_name, predicted in results:
230+
original_name_parts = original_name.split(Common.internal_delimiter) # list
231+
filtered_original = Common.filter_impossible_names(original_name_parts) # list
232+
predicted_first = predicted
226233
if self.config.BEAM_WIDTH > 0:
227-
predicted = predicted[0]
228-
original_name_parts = original_name.split(Common.internal_delimiter)
229-
filtered_original = Common.filter_impossible_names(original_name_parts)
230-
filtered_predicted_parts = Common.filter_impossible_names(predicted)
234+
predicted_first = predicted[0]
235+
filtered_predicted_first_parts = Common.filter_impossible_names(predicted_first) # list
231236
output_file.write('Original: ' + Common.internal_delimiter.join(original_name_parts) +
232-
' , predicted 1st: ' + Common.internal_delimiter.join(
233-
[target for target in filtered_predicted_parts]) + '\n')
234-
if filtered_original == filtered_predicted_parts or Common.unique(filtered_original) == Common.unique(
235-
filtered_predicted_parts) or ''.join(filtered_original) == ''.join(filtered_predicted_parts):
236-
num_correct_predictions += 1
237+
' , predicted 1st: ' + Common.internal_delimiter.join(filtered_predicted_first_parts) + '\n')
238+
239+
if self.config.BEAM_WIDTH == 0:
240+
if filtered_original == filtered_predicted_first_parts or Common.unique(filtered_original) == Common.unique(
241+
filtered_predicted_first_parts) or ''.join(filtered_original) == ''.join(filtered_predicted_first_parts):
242+
num_correct_predictions += 1
243+
else:
244+
filtered_predicted = [Common.internal_delimiter.join(Common.filter_impossible_names(p)) for p in predicted]
245+
246+
true_ref = original_name
247+
if true_ref in filtered_predicted:
248+
index_of_correct = filtered_predicted.index(true_ref)
249+
update = np.concatenate(
250+
[np.zeros(index_of_correct, dtype=np.int32),
251+
np.ones(self.config.BEAM_WIDTH - index_of_correct, dtype=np.int32)])
252+
num_correct_predictions += update
237253
return num_correct_predictions
238254

239255
def update_per_subtoken_statistics(self, results, true_positive, false_positive, false_negative):

0 commit comments

Comments
 (0)