-
Notifications
You must be signed in to change notification settings - Fork 38
Expand file tree
/
Copy path eval.py
More file actions
90 lines (82 loc) · 3.55 KB
/
Copy path eval.py
File metadata and controls
90 lines (82 loc) · 3.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import warnings
from utils.loaddata import load_batch_level_dataset, load_entity_level_dataset, load_metadata
from model.autoencoder import build_model
from utils.poolers import Pooling
from utils.utils import set_random_seed
import numpy as np
from model.eval import batch_level_evaluation, evaluate_entity_level_using_knn
from utils.config import build_args
# Suppress every warning category for the whole process. Note this also hides
# potentially useful runtime warnings from torch/numpy while evaluating.
warnings.filterwarnings('ignore')
def _load_checkpointed_model(main_args, dataset_name, device):
    """Build the model described by ``main_args`` and load its saved weights.

    Weights are read from ``./checkpoints/checkpoint-<dataset_name>.pt`` with
    ``map_location=device``, and the model is then moved to ``device``.
    """
    model = build_model(main_args)
    checkpoint_path = "./checkpoints/checkpoint-{}.pt".format(dataset_name)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model.to(device)


def _embed_entity_split(model, dataset_name, split, n_graphs, device):
    """Embed graphs ``0 .. n_graphs - 1`` of ``split`` with ``model.embed``.

    Returns a pair ``(embeddings, node_counts)``:
      * ``embeddings`` -- per-graph embeddings moved to the CPU, converted to
        NumPy and concatenated along axis 0;
      * ``node_counts`` -- ``g.number_of_nodes()`` for each graph, in load order.

    Callers are expected to run this under ``torch.no_grad()``.
    """
    embeddings = []
    node_counts = []
    for i in range(n_graphs):
        g = load_entity_level_dataset(dataset_name, split, i).to(device)
        node_counts.append(g.number_of_nodes())
        embeddings.append(model.embed(g).cpu().numpy())
        del g  # release the (possibly device-resident) graph before loading the next
    return np.concatenate(embeddings, axis=0), node_counts


def _evaluate_batch_level(main_args, dataset_name, device):
    """Run the batch-level KNN evaluation and return ``(test_auc, test_std)``.

    Sets ``main_args.n_dim`` / ``main_args.e_dim`` from the loaded dataset.
    """
    dataset = load_batch_level_dataset(dataset_name)
    main_args.n_dim = dataset['n_feat']
    main_args.e_dim = dataset['e_feat']
    model = _load_checkpointed_model(main_args, dataset_name, device)
    pooler = Pooling(main_args.pooling)
    # Use the local dataset name rather than a module-global ``args`` so that
    # ``main`` also works when imported and called programmatically.
    return batch_level_evaluation(model, pooler, device, ['knn'], dataset_name,
                                  main_args.n_dim, main_args.e_dim)


def _evaluate_entity_level(main_args, dataset_name, device):
    """Run the entity-level KNN evaluation and return ``(test_auc, test_std)``.

    Sets ``main_args.n_dim`` / ``main_args.e_dim`` from the dataset metadata.
    """
    metadata = load_metadata(dataset_name)
    main_args.n_dim = metadata['node_feature_dim']
    main_args.e_dim = metadata['edge_feature_dim']
    model = _load_checkpointed_model(main_args, dataset_name, device)
    model.eval()
    # metadata['malicious'] is unpacked as a pair; only the first element (the
    # indices into the concatenated test embeddings) is used here.
    malicious, _ = metadata['malicious']
    with torch.no_grad():
        x_train, _ = _embed_entity_split(model, dataset_name, 'train', metadata['n_train'], device)
        x_test, test_node_counts = _embed_entity_split(model, dataset_name, 'test',
                                                       metadata['n_test'], device)

    n = x_test.shape[0]
    y_test = np.zeros(n)
    y_test[malicious] = 1.0

    # Exclude training samples from the test set: nodes belonging to every test
    # graph except the last are dropped unless they are labelled malicious.
    skip_benign = sum(test_node_counts[:-1])
    keep = (np.arange(n) >= skip_benign) | (y_test == 1.0)
    result_x_test = x_test[keep]
    result_y_test = y_test[keep]
    del x_test, y_test  # free the unfiltered arrays before the KNN evaluation

    test_auc, test_std, _, _ = evaluate_entity_level_using_knn(dataset_name, x_train, result_x_test,
                                                               result_y_test)
    return test_auc, test_std


def main(main_args):
    """Evaluate the saved checkpoint for ``main_args.dataset`` and print its test AUC.

    ``main_args`` is the options object produced by ``build_args()``. This
    function mutates it: ``num_hidden`` / ``num_layers`` are overridden per
    dataset, and ``n_dim`` / ``e_dim`` are filled in from the dataset.
    A negative ``main_args.device`` selects the CPU; otherwise it is passed to
    ``torch.device``.

    'streamspot' and 'wget' use the batch-level evaluation; every other dataset
    uses the entity-level evaluation.
    """
    device = torch.device(main_args.device if main_args.device >= 0 else "cpu")
    dataset_name = main_args.dataset
    batch_level = dataset_name in ('streamspot', 'wget')

    # Architecture must match the one the checkpoint was trained with.
    if batch_level:
        main_args.num_hidden = 256
        main_args.num_layers = 4
    else:
        main_args.num_hidden = 64
        main_args.num_layers = 3
    set_random_seed(0)

    if batch_level:
        test_auc, test_std = _evaluate_batch_level(main_args, dataset_name, device)
    else:
        test_auc, test_std = _evaluate_entity_level(main_args, dataset_name, device)
    print(f"#Test_AUC: {test_auc:.4f}±{test_std:.4f}")
if __name__ == '__main__':
    # Script entry point: parse command-line options into a module-level
    # ``args`` (kept global for callers that read it) and run the evaluation.
    args = build_args()
    main(main_args=args)