# test_all.py
"""
Tests the entire pipeline using metrics
"""
from classification.predict_kind import LazyLoadedClassifier
from extraction.preprocess import resolve_coref
from extraction.parse import LazyLoadedExtractor
from extraction.assemble import assemble, remove_duplicates
from extraction.utils import uml, metrics, inquire
import os
import pandas
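# Module-level globals populated by setup():
#   SOURCE_DIR -- the parent directory of this file
#   ZOO_DIR    -- directory of ground-truth UML models (JSON), used by evaluate()
#   GROUPED    -- the grouped.csv dataset read with pandas
#   LOG_DIR    -- where run logs and predicted models are written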
def setup():
    global SOURCE_DIR, ZOO_DIR
    SOURCE_DIR = os.path.join(os.path.dirname(__file__), "..")
    ZOO_DIR = os.path.join(SOURCE_DIR, "zoo")
    # load the CSV of grouped data
    global GROUPED
    try:
        GROUPED = pandas.read_csv(
            os.path.join(SOURCE_DIR, "three-step", "data", "grouped.csv")
        )
    except FileNotFoundError:
        print("grouped.csv does not exist. Please generate it.")
        exit(1)
    global LOG_DIR
    LOG_DIR = os.path.join(SOURCE_DIR, "three-step", "data", "logs")
    os.makedirs(LOG_DIR, exist_ok=True)

def make_predictions():
    """
    Run the pipeline. Returns the predictions.
    """
    classifier = LazyLoadedClassifier()
    class_extractor = LazyLoadedExtractor("", "class")
    rel_extractor = LazyLoadedExtractor("", "rel")
    predictions: dict[str, uml.UML] = {}
    # Read the data
    for _, row in GROUPED.iterrows():
        model_name = row["model"]
        grouped_text = row["text"]
        # preprocess each data point
        preprocessed_text = resolve_coref(grouped_text)
        # classify each sentence, then extract a model fragment from it
        classification_results = {}
        extraction_results = []
        for index, sentence in preprocessed_text.items():
            predicted_kind = classifier.predict(sentence)
            classification_results[index] = predicted_kind
            if predicted_kind == "class":
                class_extractor.extractor.set_sentence(sentence)
                result = class_extractor.handle_class()
            elif predicted_kind == "rel":
                rel_extractor.extractor.set_sentence(sentence)
                result = rel_extractor.handle_rel()
            else:
                raise Exception(f"Unexpected kind: {predicted_kind}")
            extraction_results.append(result)
        # assemble the fragments into a single model
        assembly_result = assemble(extraction_results)
        predictions[model_name] = assembly_result
    return predictions

def evaluate(predictions: dict[str, uml.UML]):
    """
    Fetch the ground-truth models and compare them against the predictions.
    """
    ground_truth = []
    predicted_models = []
    for model_name, prediction in predictions.items():
        ground = inquire.get_json_uml(os.path.join(ZOO_DIR, f"{model_name}.json"))
        ground_truth.append(ground)
        predicted_models.append(prediction)
    return metrics.compute_metrics(predicted_models, ground_truth)

def selective_test():
    """
    Re-run the metrics on a single saved original/prediction fragment pair.
    """
    original = inquire.get_json_uml_fragment(
        r"C:\Users\songy\Documents\My Documents\UDEM\master thesis\uml data\database\analysis\three-step\extraction\temp\assembly\CFG_original_failed.json"
    )
    prediction = inquire.get_json_uml_fragment(
        r"C:\Users\songy\Documents\My Documents\UDEM\master thesis\uml data\database\analysis\three-step\extraction\temp\assembly\CFG_prediction_failed.json"
    )
    original = remove_duplicates(original)
    prediction = remove_duplicates(prediction)
    print(metrics.compute_metrics([prediction], [original]))

def run_tests():
    setup()
    log_message = ""
    log_message += "\nRunning metric-based test suite. Logs are reported in {}".format(
        LOG_DIR
    )
    predictions = make_predictions()
    # log the predictions
    os.makedirs(os.path.join(LOG_DIR, "predictions"), exist_ok=True)
    for model_name, prediction in predictions.items():
        prediction.save(os.path.join(LOG_DIR, "predictions", f"{model_name}.plantuml"))
    class_scores, rel_scores = evaluate(predictions)
    log_message += "\nClass scores\n" + str(class_scores)
    log_message += "\nRel scores\n" + str(rel_scores)
    log_message += "\nUnweighted Mean for Classes\t" + str(
        [sum(y) / len(y) for y in zip(*class_scores)]
    )
    log_message += "\nUnweighted Mean for Relations\t" + str(
        [sum(y) / len(y) for y in zip(*rel_scores)]
    )
    log_message += "\n(Precision, Recall, F1)"
    print(log_message)
    with open(os.path.join(LOG_DIR, "last_run.log"), "w") as out:
        out.write(log_message)

if __name__ == "__main__":
    run_tests()
    # selective_test()