@@ -104,7 +104,10 @@ def train(self):
                 self.epochs_trained += self.config.SAVE_EVERY_EPOCHS
                 print('Finished %d epochs' % self.config.SAVE_EVERY_EPOCHS)
                 results, precision, recall, f1 = self.evaluate()
-                print('Accuracy after %d epochs: %.5f' % (self.epochs_trained, results))
+                if self.config.BEAM_WIDTH == 0:
+                    print('Accuracy after %d epochs: %.5f' % (self.epochs_trained, results))
+                else:
+                    print('Accuracy after {} epochs: {}'.format(self.epochs_trained, results))
                 print('After %d epochs: Precision: %.5f, recall: %.5f, F1: %.5f' % (
                     self.epochs_trained, precision, recall, f1))
                 if f1 > best_f1:
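Context for the hunk above (not part of the patch): when beam search is enabled, the `results` value returned by `evaluate()` is presumably a per-rank accuracy vector rather than a single float, so the `%`-style float format no longer fits. A minimal sketch of the difference, with illustrative values:

# Illustrative only; values and names are made up for this sketch.
import numpy as np

results_greedy = 0.43210                      # BEAM_WIDTH == 0: a single scalar accuracy
results_beam = np.array([0.43, 0.51, 0.55])   # BEAM_WIDTH == 3: top-1..top-3 accuracy

print('Accuracy: %.5f' % results_greedy)      # fine for a scalar
print('Accuracy: {}'.format(results_beam))    # '%.5f' would fail for a multi-element array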
@@ -167,7 +170,8 @@ def evaluate(self, release=False):
         with open(model_dirname + '/log.txt', 'w') as output_file, open(ref_file_name, 'w') as ref_file, open(
                 predicted_file_name,
                 'w') as pred_file:
-            num_correct_predictions = 0
+            num_correct_predictions = 0 if self.config.BEAM_WIDTH == 0 \
+                else np.zeros([self.config.BEAM_WIDTH], dtype=np.int32)
             total_predictions = 0
             total_prediction_batches = 0
             true_positive, false_positive, false_negative = 0, 0, 0
@@ -223,17 +227,29 @@ def evaluate(self, release=False):
 
     def update_correct_predictions(self, num_correct_predictions, output_file, results):
         for original_name, predicted in results:
+            original_name_parts = original_name.split(Common.internal_delimiter)  # list
+            filtered_original = Common.filter_impossible_names(original_name_parts)  # list
+            predicted_first = predicted
             if self.config.BEAM_WIDTH > 0:
-                predicted = predicted[0]
-            original_name_parts = original_name.split(Common.internal_delimiter)
-            filtered_original = Common.filter_impossible_names(original_name_parts)
-            filtered_predicted_parts = Common.filter_impossible_names(predicted)
+                predicted_first = predicted[0]
+            filtered_predicted_first_parts = Common.filter_impossible_names(predicted_first)  # list
 
             output_file.write('Original: ' + Common.internal_delimiter.join(original_name_parts) +
-                              ' , predicted 1st: ' + Common.internal_delimiter.join(
-                                  [target for target in filtered_predicted_parts]) + '\n')
-            if filtered_original == filtered_predicted_parts or Common.unique(filtered_original) == Common.unique(
-                    filtered_predicted_parts) or ''.join(filtered_original) == ''.join(filtered_predicted_parts):
-                num_correct_predictions += 1
+                              ' , predicted 1st: ' + Common.internal_delimiter.join(filtered_predicted_first_parts) + '\n')
+
+            if self.config.BEAM_WIDTH == 0:
+                if filtered_original == filtered_predicted_first_parts or Common.unique(filtered_original) == Common.unique(
+                        filtered_predicted_first_parts) or ''.join(filtered_original) == ''.join(filtered_predicted_first_parts):
+                    num_correct_predictions += 1
+            else:
+                filtered_predicted = [Common.internal_delimiter.join(Common.filter_impossible_names(p)) for p in predicted]
+
+                true_ref = original_name
+                if true_ref in filtered_predicted:
+                    index_of_correct = filtered_predicted.index(true_ref)
+                    update = np.concatenate(
+                        [np.zeros(index_of_correct, dtype=np.int32),
+                         np.ones(self.config.BEAM_WIDTH - index_of_correct, dtype=np.int32)])
+                    num_correct_predictions += update
         return num_correct_predictions
 
     def update_per_subtoken_statistics(self, results, true_positive, false_positive, false_negative):
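The hunk above turns `num_correct_predictions` into a per-rank counter when beam search is used: an example whose true name first matches the beam candidate at index k contributes a step vector of k zeros followed by ones, so entry i accumulates hits within the top i+1 candidates. A minimal standalone sketch of that update, with hypothetical values (not part of the patch):

# Illustrative sketch of the per-rank counting; BEAM_WIDTH and the match index are made up.
import numpy as np

BEAM_WIDTH = 5
num_correct = np.zeros([BEAM_WIDTH], dtype=np.int32)

# Suppose the ground-truth name first matches the beam candidate at index 2 for one example:
index_of_correct = 2
update = np.concatenate([np.zeros(index_of_correct, dtype=np.int32),
                         np.ones(BEAM_WIDTH - index_of_correct, dtype=np.int32)])
num_correct += update
print(num_correct)  # [0 0 1 1 1]: entry i counts hits within the top i+1 beams,
                    # so dividing by the number of examples yields top-k accuracy per rank.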