-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraph_visualiser.py
More file actions
53 lines (38 loc) · 1.63 KB
/
graph_visualiser.py
File metadata and controls
53 lines (38 loc) · 1.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx
from torch_geometric.data import Data
import torch
import XY_to_graph
from XY import XY_model
def plot_spin_graph():
    """Convert a small XY-model spin grid into a graph and draw it.

    An ``XY_model(5, 0.1)`` instance is turned into a torch_geometric
    ``Data`` object. Its node features, edge index and edge attributes come
    from the ``XY_to_graph`` helpers. The object is converted to a directed
    networkx graph and shown in a 6x6-inch matplotlib figure with bold node
    labels.
    """
    model = XY_model(5, 0.1)
    grid = model.spin_grid

    node_features = XY_to_graph.get_xy_spin_node_features(grid, model.spin_vel_grid)
    edges = XY_to_graph.get_xy_edge_index(grid.shape)
    edge_features = XY_to_graph.get_xy_edge_attr(grid, edges)
    graph = to_networkx(
        Data(x=node_features, edge_index=edges, edge_attr=edge_features),
        to_undirected=False,
    )

    n_rows, n_cols = grid.shape
    # Nudge odd rows sideways and odd columns downwards so that arrows between
    # neighbouring nodes do not lie exactly on top of each other.
    nudge = 0.1
    # Node id k = r * n_cols + c (row-major). This assumes XY_to_graph numbers
    # grid sites the same way -- TODO confirm. The y axis is negated so row 0
    # is drawn at the top.
    layout = {
        r * n_cols + c: (c + nudge * (r % 2), -(r + nudge * (c % 2)))
        for r in range(n_rows)
        for c in range(n_cols)
    }

    plt.figure(figsize=(6, 6))
    nx.draw(graph, pos=layout, with_labels=True, node_size=500, font_weight='bold')
    plt.show()
def plot_vortex_graph():
    """Draw a fully connected directed graph on the corners of a unit square.

    The node features also serve as 2-D plot coordinates. Nodes 0-3 sit at
    (0, 1), (1, 0), (1, 1) and (0, 0). Every unordered pair of distinct nodes
    is connected in both directions, which gives 12 directed edges for the
    4 nodes.
    """
    x = torch.tensor([[0, 1], [1, 0], [1, 1], [0, 0]])
    # Derive the node count from the features so the edge set cannot drift
    # out of sync with the node list.
    num_nodes = x.size(0)
    # All unordered pairs (i < j); after .t() this is a (2, num_pairs) tensor.
    src, dst = torch.combinations(torch.arange(num_nodes), 2).t()
    # Append the reversed pairs so every connection appears in both directions.
    edge_index = torch.cat([torch.stack([src, dst]), torch.stack([dst, src])], dim=1)
    data = Data(x=x, edge_index=edge_index)
    G = to_networkx(data, to_undirected=False)
    # networkx expects plain (x, y) pairs for node positions, so convert each
    # feature row to a tuple instead of passing raw 1-D tensors through.
    pos = {i: tuple(coords) for i, coords in enumerate(x.tolist())}
    plt.figure(figsize=(6, 6))
    nx.draw(G, pos=pos, with_labels=True, node_size=500, font_weight='bold')
    plt.show()
if __name__ == "__main__":
    # Only plot when run as a script, so that importing this module does not
    # open a blocking matplotlib window. Call plot_vortex_graph() here instead
    # to view the four-node square graph.
    plot_spin_graph()