-
Notifications
You must be signed in to change notification settings - Fork 0
/
MNIST.py
67 lines (54 loc) · 1.58 KB
/
MNIST.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from __future__ import absolute_import
import numpy as np
import dgl
import os
import torch
from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url
from torch.utils.data import Dataset
import warnings
import torch.utils.data as data
from PIL import Image
import os
import os.path
import gzip
import numpy as np
import torch
import codecs
from torchvision import transforms
import networkx as nx
def adj_head(m):
    """Build the adjacency matrix of an m x m pixel grid.

    Original note: the graph is created "as per Defferrard et al. 2016".

    Pixels are numbered row-major (pixel ``(r, c)`` is node ``r * m + c``).
    Every node is linked to itself and to each of its up-to-8 neighbours
    (horizontal, vertical and diagonal) that lie inside the grid.

    Args:
        m: side length of the square grid.

    Returns:
        np.ndarray of shape ``(m**2, m**2)`` and dtype ``int`` holding 0/1
        entries; the matrix is symmetric.
    """
    size = m ** 2
    adjacency = np.zeros((size, size), dtype=int)
    steps = (-1, 0, 1)
    for row in range(m):
        for col in range(m):
            node = row * m + col
            for d_row in steps:
                for d_col in steps:
                    nbr_row = row + d_row
                    nbr_col = col + d_col
                    # Keep only neighbours that fall inside the grid.
                    if 0 <= nbr_row < m and 0 <= nbr_col < m:
                        adjacency[node, nbr_row * m + nbr_col] = 1
    return adjacency
def degree_(adj, m):
    """Return the diagonal degree matrix of an adjacency matrix on an m x m grid.

    Diagonal entry ``(i, i)`` is the sum of row ``i`` of ``adj``; all
    off-diagonal entries are 0. The numeric type follows ``adj``: integer
    adjacency gives integer degrees, and weighted (float) adjacency keeps
    fractional degrees instead of truncating them.

    Args:
        adj: adjacency matrix (array-like) with one row per node.
        m: side length of the pixel grid; ``adj`` must have ``m**2`` rows.

    Returns:
        np.ndarray of shape ``(m**2, m**2)``.

    Raises:
        ValueError: if ``adj`` does not have exactly ``m**2`` rows.
    """
    adj = np.asarray(adj)
    size = m ** 2
    # Previously a mismatch silently zero-padded the result (too few rows)
    # or raised an opaque IndexError (too many); fail clearly instead.
    if adj.shape[0] != size:
        raise ValueError(
            f"adj has {adj.shape[0]} rows but m**2 = {size}"
        )
    return np.diag(adj.sum(axis=1))
class GraphTransform:
    """Map an image to a (normalized adjacency, node features) pair.

    The graph is the fixed m x m pixel grid produced by ``adj_head`` (each
    pixel linked to itself and its in-grid 8-neighbours). It is symmetrically
    normalized as ``D^{-1/2} A D^{-1/2}``, where ``D`` holds the row sums of
    ``A``. The graph is built once in ``__init__`` and reused on every call.

    Attributes:
        adj: normalized adjacency matrix, shape ``(m**2, m**2)``, float.
        degree: diagonal matrix ``D^{-1/2}`` (0 where a node has degree 0).
        device: the ``device`` argument as given; not used by this class.
    """

    def __init__(self, device, m=28):
        """
        Args:
            device: accepted for interface compatibility and stored, but not
                used here. NOTE(review): presumably intended as a torch
                device; confirm with callers.
            m: side length of the (square) input images. The default of 28
                reproduces the original hard-coded grid.
        """
        self.device = device
        self.degree, self.adj = self._normalize(adj_head(m))

    @staticmethod
    def _normalize(adj):
        """Return ``(D^{-1/2}, D^{-1/2} A D^{-1/2})`` for adjacency ``adj``.

        ``D`` is the diagonal matrix of row sums of ``adj``. Zero-degree
        nodes get 0 in ``D^{-1/2}`` rather than ``inf``.
        """
        adj = np.asarray(adj)
        deg = adj.sum(axis=1).astype(float)
        inv_sqrt = np.zeros_like(deg)
        # `where=` computes the power only for positive degrees, so no
        # divide-by-zero RuntimeWarning is raised for isolated nodes.
        np.power(deg, -0.5, out=inv_sqrt, where=deg > 0)
        # Scaling rows and columns gives the same entries as the dense
        # products D @ A @ D, in O(M^2) instead of O(M^3).
        normalized = adj * inv_sqrt[:, None] * inv_sqrt[None, :]
        return np.diag(inv_sqrt), normalized

    def __call__(self, img):
        """Return ``(self.adj, features)`` for one image.

        ``features`` is ``img`` converted with ``np.array`` and reshaped to a
        column of shape ``(num_pixels, 1)``. The same ``self.adj`` object is
        returned on every call, so callers must not modify it in place.
        """
        return self.adj, np.array(img).reshape(-1, 1)