diff --git a/EL_entity_linking/eval/el_evaluator.py b/EL_entity_linking/eval/el_evaluator.py index b44ac91..010cedd 100644 --- a/EL_entity_linking/eval/el_evaluator.py +++ b/EL_entity_linking/eval/el_evaluator.py @@ -36,6 +36,7 @@ def calculate_f1(precision, recall): def evaluate_dbpedia(gold_answers, system_answers): count, total_p, total_r, total_f1 = 0, 0, 0, 0 for ques_id in gold_answers: + print(ques_id) count += 1 # if an answer is not provided to a question, we just move on if ques_id not in system_answers: @@ -48,9 +49,11 @@ def evaluate_dbpedia(gold_answers, system_answers): # collect all gold entities all_gold_ents = set() for gold_ent_set in gold_answer_list: - all_gold_ents.update(gold_ent_set) + all_gold_ents.add(gold_ent_set) # mark all correct answers from system answers + print("gold entities: ",all_gold_ents) + print("system entities: ",system_entities) for rel in system_entities: if rel in all_gold_ents: correct_count += 1 @@ -58,23 +61,25 @@ def evaluate_dbpedia(gold_answers, system_answers): # check how many entities sets are covered in system answers. In ground truth, we have multiple correct # entities for a given slot. # For example, {"dbo:locatedInArea", "dbo:city", "dbo:isPartOf", "dbo:location", "dbo:region"} - for gold_ent_set in gold_answer_list: + for ent in gold_answer_list: gold_ent_count += 1 - found_ent = False - for rel in gold_ent_set: - if rel in system_entities: - found_ent = True - system_entities.remove(rel) - break - if found_ent: + if ent in system_entities: found_count += 1 - + print("sys_ent_count: ", sys_ent_count) + print("gold_ent_count: ",gold_ent_count) + print("correct_count: ", correct_count) + print("found_count: ",found_count) # precision, recall and F1 calculation - precision = correct_count / sys_ent_count - recall = found_count / gold_ent_count + precision = correct_count / (sys_ent_count+0.0000001) + recall = found_count / (gold_ent_count+0.0000001) total_p += precision total_r += recall total_f1 += calculate_f1(precision, recall) + print("precision :",precision) + print("recall: ",recall) + print("f1: ",calculate_f1(precision, recall)) + print("average f1 so far:",(total_f1/count)) + print("====================================") return total_p/count, total_r/count, total_f1/count