-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_evaluateROC.py
More file actions
85 lines (71 loc) · 3.66 KB
/
_evaluateROC.py
File metadata and controls
85 lines (71 loc) · 3.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import matplotlib.pyplot as plt
from sklearn.metrics import auc
from sklearn.linear_model import LogisticRegression
from utils.parser import get_parser
from utils.save_and_load import load_all_generations
from probes.CCS import ROC_CCS as CCS
from pathlib import Path
# Absolute path of the directory that contains this script; main() writes its plots beneath it.
ROOT_DIR = Path(__file__).resolve().parent
def _split_half(arr):
    """Return ``(first_half, second_half)`` of ``arr``, split at ``len(arr) // 2``.

    When ``len(arr)`` is odd the second half receives the extra element.
    """
    mid = len(arr) // 2
    return arr[:mid], arr[mid:]


def _plot_roc(fpr, tpr, roc_auc, save_path):
    """Draw an ROC curve plus the chance diagonal, save it to ``save_path`` and show it.

    Args:
        fpr: False-positive rates (x coordinates of the curve).
        tpr: True-positive rates (y coordinates of the curve).
        roc_auc: Area under the curve, shown in the legend with two decimals.
        save_path: Destination file for ``savefig`` (str or path-like).
    """
    lw = 2
    fig, ax = plt.subplots()
    ax.plot(fpr, tpr, color="darkorange", lw=lw, label="ROC curve (area = %0.2f)" % roc_auc)
    ax.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("Receiver Operating Characteristic Curve")
    ax.legend(loc="lower right")
    fig.savefig(save_path)
    plt.show()
    # Release the figure so repeated runs in one process don't accumulate open figures.
    plt.close(fig)


def main(args, generation_args):
    """Fit a logistic-regression baseline and CCS on saved hidden states, then plot CCS's ROC.

    Args:
        args: Parsed command-line namespace. Besides the generation options it must
            provide ``nepochs``, ``ntries``, ``lr``, ``ccs_batch_size``, ``verbose``,
            ``ccs_device``, ``linear``, ``weight_decay``, ``var_normalize``,
            ``model_name`` and ``dataset_name``.
        generation_args: Namespace handed to ``load_all_generations`` to locate the
            saved hidden states and labels.

    Raises:
        ValueError: If the negative and positive hidden states differ in shape.
    """
    # The loaded mapping is unpacked positionally; this assumes its values come in the
    # order (negative hidden states, positive hidden states, labels).
    generations = load_all_generations(generation_args)
    neg_hs, pos_hs, y = tuple(generations.values())

    # Explicit check rather than ``assert`` so validation still happens under ``python -O``.
    if neg_hs.shape != pos_hs.shape:
        raise ValueError(
            f"neg/pos hidden states must have the same shape, got {neg_hs.shape} and {pos_hs.shape}"
        )

    # Keep only the last index of the trailing axis ("the last layer"; assumes the
    # trailing axis indexes layers — TODO confirm against the saving code).
    neg_hs, pos_hs = neg_hs[..., -1], pos_hs[..., -1]
    if neg_hs.shape[1] == 1:  # T5 may have an extra dimension; if so, get rid of it
        neg_hs = neg_hs.squeeze(1)
        pos_hs = pos_hs.squeeze(1)

    # Very simple train/test split (relies on the data already being shuffled).
    neg_hs_train, neg_hs_test = _split_half(neg_hs)
    pos_hs_train, pos_hs_test = _split_half(pos_hs)
    y_train, y_test = _split_half(y)

    # Sanity-check with supervised logistic regression: if it can't separate the classes,
    # CCS is unlikely to work. Differences (rather than concatenation) keep the input
    # comparable to what CCS sees.
    x_train = neg_hs_train - pos_hs_train
    x_test = neg_hs_test - pos_hs_test
    log_reg = LogisticRegression(class_weight="balanced")
    log_reg.fit(x_train, y_train)
    print("Logistic regression accuracy: {}".format(log_reg.score(x_test, y_test)))

    # Set up CCS. The defaults usually suffice: ccs = CCS(neg_hs, pos_hs, y)
    ccs = CCS(
        neg_hs_train,
        pos_hs_train,
        nepochs=args.nepochs,
        ntries=args.ntries,
        lr=args.lr,
        batch_size=args.ccs_batch_size,
        verbose=args.verbose,
        device=args.ccs_device,
        linear=args.linear,
        weight_decay=args.weight_decay,
        var_normalize=args.var_normalize,
    )

    # Train and evaluate CCS.
    ccs.repeated_train()
    ccs_acc = ccs.get_acc(neg_hs_test, pos_hs_test, y_test)
    print("CCS accuracy: {}".format(ccs_acc))

    scores = ccs.get_scores(neg_hs_test, pos_hs_test)
    fpr, tpr, roc_auc = ccs.compute_roc(scores, y_test)

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    plot_dir = ROOT_DIR / "plots" / "ROC"
    plot_dir.mkdir(parents=True, exist_ok=True)
    # Names may contain "/" (e.g. hub-style ids); flatten them so the file lands directly
    # in plot_dir instead of a sub-directory that does not exist.
    file_stem = f"{args.model_name}_{args.dataset_name}".replace("/", "_")
    _plot_roc(fpr, tpr, roc_auc, plot_dir / f"{file_stem}.png")
def parse_eval_args(parser, argv=None):
    """Parse generation options, then register and parse the evaluation options.

    ``parser`` is modified in place: the evaluation flags below are added to it, so call
    this at most once per parser.

    Args:
        parser: An ``argparse.ArgumentParser`` already holding the generation options.
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``(args, generation_args)``. ``args`` holds the generation and evaluation options.
        ``generation_args`` holds only the generation options and is used to load the
        correct hidden states and labels.
    """
    # parse_known_args: the evaluation flags are not registered yet, so a plain
    # parse_args() would exit with "unrecognized arguments" whenever one is passed.
    generation_args, _ = parser.parse_known_args(argv)
    # Additional options used only for evaluation.
    parser.add_argument("--nepochs", type=int, default=1000)
    parser.add_argument("--ntries", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--ccs_batch_size", type=int, default=-1)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--ccs_device", type=str, default="cuda")
    parser.add_argument("--linear", action="store_true")
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--var_normalize", action="store_true")
    args = parser.parse_args(argv)
    return args, generation_args


if __name__ == "__main__":
    args, generation_args = parse_eval_args(get_parser())
    main(args, generation_args)