# demo_tagcn.py — TAGCN demo, forked from CrawlScript/tf_geometric
# coding=utf-8
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import tf_geometric as tfg
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tf_geometric.datasets import CoraDataset
# Load the Cora citation graph together with the standard train/valid/test splits.
graph, (train_index, valid_index, test_index) = CoraDataset().load_data()

# Number of target classes, derived from the label array (labels are 0-based ints).
num_classes = np.max(graph.y) + 1

# Dropout applied between the two graph-convolution layers during training only.
dropout = keras.layers.Dropout(0.3)

# Two-layer TAGCN: 16 hidden units, then one logit per class.
tagcn0 = tfg.layers.TAGCN(16)
tagcn1 = tfg.layers.TAGCN(num_classes)
def forward(graph, training=False):
    """Run the two-layer TAGCN and return per-node class logits.

    Args:
        graph: graph object providing x, edge_index, edge_weight and a cache dict.
        training: when True, dropout is active between the two conv layers.

    Returns:
        Tensor of shape (num_nodes, num_classes) with unnormalized logits.
    """
    # First TAGCN layer; graph.cache memoizes the normalized adjacency.
    hidden = tagcn0([graph.x, graph.edge_index, graph.edge_weight], cache=graph.cache)
    # Regularize hidden features (no-op at inference time).
    hidden = dropout(hidden, training=training)
    # Second TAGCN layer maps hidden features to class logits.
    return tagcn1([hidden, graph.edge_index, graph.edge_weight], cache=graph.cache)
def compute_loss(logits, mask_index, vars):
    """Masked cross-entropy loss plus L2 weight decay on kernel variables.

    Args:
        logits: (num_nodes, num_classes) unnormalized class scores.
        mask_index: node indices to include in the loss (e.g. train_index).
        vars: trainable variables to regularize. NOTE(review): this parameter
            shadows the builtin `vars`; kept for interface compatibility.

    Returns:
        Scalar tensor: mean cross-entropy over the masked nodes
        plus 5e-4 * sum of L2 norms of all "kernel" variables.
    """
    masked_logits = tf.gather(logits, mask_index)
    masked_labels = tf.gather(graph.y, mask_index)
    # Integer labels: the sparse variant avoids materializing one-hot targets
    # and is equivalent to one_hot + softmax_cross_entropy_with_logits.
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=masked_logits,
        labels=masked_labels
    )
    # L2 weight decay on kernels only (biases are not regularized).
    kernel_vals = [var for var in vars if "kernel" in var.name]
    l2_losses = [tf.nn.l2_loss(kernel_var) for kernel_var in kernel_vals]
    return tf.reduce_mean(losses) + tf.add_n(l2_losses) * 5e-4
def evaluate(mask):
    """Return classification accuracy over the nodes selected by `mask`.

    Args:
        mask: indices of the nodes to score (e.g. valid_index or test_index).

    Returns:
        Accuracy as a Python/NumPy float in [0, 1].
    """
    # Inference pass (dropout disabled), normalized to log-probabilities.
    log_probs = tf.nn.log_softmax(forward(graph), axis=-1)
    # Predicted class = argmax over the masked rows.
    predictions = tf.argmax(tf.gather(log_probs, mask), axis=-1, output_type=tf.int32)
    labels = tf.gather(graph.y, mask)
    metric = keras.metrics.Accuracy()
    metric.update_state(labels, predictions)
    return metric.result().numpy()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

# Track the best test accuracy seen so far and the validation accuracy
# observed at that same step.
best_test_acc = tmp_valid_acc = 0

for step in range(1, 101):
    # Record the forward pass so gradients can be taken w.r.t. all
    # variables the tape watched.
    with tf.GradientTape() as tape:
        logits = forward(graph, training=True)
        loss = compute_loss(logits, train_index, tape.watched_variables())

    # Renamed from `vars` to avoid shadowing the builtin.
    trainable_vars = tape.watched_variables()
    optimizer.apply_gradients(zip(tape.gradient(loss, trainable_vars), trainable_vars))

    valid_acc = evaluate(valid_index)
    test_acc = evaluate(test_index)
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        tmp_valid_acc = valid_acc
    print("step = {}\tloss = {}\tvalid_acc = {}\tbest_test_acc = {}".format(step, loss, tmp_valid_acc, best_test_acc))