-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathadjacency_plot.py
110 lines (92 loc) · 4.05 KB
/
adjacency_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
from utils.datasets import load_dataset, MOLECULAR_DATASETS
from utils.graphs import unflatt_tril
from pylatex import Document, TikZ, NoEscape
from math import isclose
def nextgrouplot(pic, matrix, title, colorbar=False):
    """Append a ``\\nextgroupplot`` panel that draws *matrix* as a pgfplots matrix plot.

    Entry ``matrix[i, j]`` is emitted as the coordinate ``(i+1, j+1)`` with its
    value as explicit point meta, so the cell colour encodes the value. Rows are
    separated by blank lines, which pgfplots treats as end-of-scanline markers
    in a coordinate list.

    Args:
        pic: container with an ``append`` method (in this script, the pylatex
            ``TikZ`` environment that holds the ``groupplot``).
        matrix: 2-D array-like exposing ``shape`` and ``matrix[i, j]`` indexing
            (e.g. a torch tensor). It need not be square.
        title: text used as the panel's x-axis label. It is wrapped in braces so
            that commas or brackets in it cannot break the option list.
        colorbar: if true, attach the colour bar to this panel.
    """
    n_rows, n_cols = matrix.shape[0], matrix.shape[1]

    # One scanline per row: "(x,y) [meta]" tokens separated by spaces.
    # float() yields a plain number for torch/numpy scalars and Python numbers alike.
    rows = [
        ' '.join(f'({i + 1},{j + 1}) [{float(matrix[i, j])}]' for j in range(n_cols))
        for i in range(n_rows)
    ]
    coordinates = '\n\n'.join(rows)

    options = [f'xlabel={{{title}}}']
    if colorbar:
        options.append('colorbar')
    pic.append(NoEscape('\\nextgroupplot[' + ','.join(options) + ']'))
    pic.append(NoEscape(
        '\\addplot[matrix plot, point meta=explicit] coordinates {\n' + coordinates + '\n};'
    ))
def markzeros(pic, matrix, tol=1e-8, color='green'):
    """Overlay a ``$*$`` marker on every cell of *matrix* whose value is zero.

    A cell counts as zero when ``math.isclose(value, 0.0, abs_tol=tol)`` holds.
    Markers are placed at ``(axis cs: i+1, j+1)``, the same coordinates that
    ``nextgrouplot`` uses for ``matrix[i, j]``.

    Args:
        pic: container with an ``append`` method (in this script, the pylatex
            ``TikZ`` environment of the current group plot).
        matrix: 2-D array-like exposing ``shape`` and ``matrix[i, j]`` indexing
            (e.g. a torch tensor). It need not be square.
        tol: absolute tolerance for treating an entry as zero (default 1e-8).
        color: TikZ colour name for the marker text (default ``'green'``).
    """
    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    nodes = [
        f'\\node[text={color}] at (axis cs: {i + 1},{j + 1}) {{$*$}};'
        for i in range(n_rows)
        for j in range(n_cols)
        if isclose(float(matrix[i, j]), 0.0, abs_tol=tol)
    ]
    # Skip appending an empty fragment when no entry is zero.
    if nodes:
        pic.append(NoEscape('\n'.join(nodes)))
if __name__ == "__main__":
    from pathlib import Path

    dataset = 'qm9'
    max_atoms = MOLECULAR_DATASETS[dataset]['max_atoms']

    # Node orderings to compare. All loaders are created up front, in this fixed
    # order, to keep the original sequence of load_dataset calls (in case loading
    # or splitting consumes shared RNG state).
    orders = ['unordered', 'canonical', 'bft', 'dft', 'rcm']
    loaders = {
        order: load_dataset(dataset, 100, [0.98, 0.01, 0.01], order=order)
        for order in orders
    }

    # For each ordering, compute the fraction of training graphs with a positive
    # entry at every adjacency position. Assumes unflatt_tril expands each
    # flattened lower triangle into a (max_atoms, max_atoms) matrix, so that the
    # mean over dim 0 is 2-D. TODO: confirm against utils.graphs.
    edge_freq = {}
    for order in orders:
        a = torch.stack([b['a'] for b in loaders[order]['loader_trn'].dataset])
        a = unflatt_tril(a, max_atoms)
        edge_freq[order] = (a > 0).to(torch.float).mean(dim=0)

    # (ordering key, panel label) in plotting order. Only the last panel gets a colour bar.
    panels = [
        ('unordered', 'Random'),
        ('bft', 'BFT'),
        ('dft', 'DFT'),
        ('rcm', 'RCM'),
        ('canonical', 'MCA'),
    ]
    # Every other index, derived from max_atoms (gives 1,3,5,7,9 for qm9).
    ticks = ','.join(str(t) for t in range(1, max_atoms + 1, 2))

    # document_options is given as a list; the former ('preview') was only a
    # parenthesised string, not a tuple.
    doc = Document(documentclass='standalone', document_options=['preview'], geometry_options={'margin': '1cm'})
    doc.packages.append(NoEscape(r'\usepackage{pgfplots}'))
    doc.packages.append(NoEscape(r'\pgfplotsset{compat=1.18}'))
    doc.packages.append(NoEscape(r'\usepgfplotslibrary{groupplots}'))

    groupplot_options = ','.join([
        f'group style={{group size={len(panels)} by 1}}',
        'height=3.7cm',
        'width=3.7cm',
        'xticklabel pos=right',
        f'xtick={{{ticks}}}',
        f'ytick={{{ticks}}}',
        'xmin=0.5',
        'ymin=0.5',
        f'xmax={max_atoms}.5',
        f'ymax={max_atoms}.5',
        'colormap name=hot',
        'point meta min=0.0',
        'point meta max=0.6',
        'colorbar style={width=5pt}',
    ])

    with doc.create(TikZ(options=NoEscape(r'font=\footnotesize'))) as pic:
        pic.append(NoEscape(r'\begin{groupplot}[' + groupplot_options + ']'))
        for idx, (order, label) in enumerate(panels):
            nextgrouplot(pic, edge_freq[order], label, colorbar=(idx == len(panels) - 1))
            markzeros(pic, edge_freq[order])
        pic.append(NoEscape(r'\end{groupplot}'))

    # Ensure the output directory exists before pylatex writes the .tex/.pdf into it.
    Path('results').mkdir(parents=True, exist_ok=True)
    doc.generate_pdf('results/adjacency_plot', clean_tex=False)