Skip to content

Commit 447767a

Browse files
committed
Add scripts
1 parent 71d1789 commit 447767a

40 files changed

+9021
-7
lines changed

Geom3D/datasets/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from Geom3D.datasets.dataset_utils import graph_data_obj_to_nx_simple, nx_to_graph_data_obj_simple, atom_type_count
2+
3+
from Geom3D.datasets.dataset_PCQM4Mv2 import PCQM4Mv2
4+
5+
from Geom3D.datasets.dataset_QM9 import MoleculeDatasetQM9
6+
7+
from Geom3D.datasets.dataset_MD17 import DatasetMD17
8+
9+
from Geom3D.datasets.dataset_3D import Molecule3DDataset
10+
from Geom3D.datasets.dataset_3D_Radius import MoleculeDataset3DRadius
11+
12+
from Geom3D.datasets.dataset_MoleculeNet_2D import MoleculeNetDataset2D

Geom3D/datasets/dataset_3D.py

+179
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
2+
import os
3+
import pandas as pd
4+
import numpy as np
5+
from itertools import repeat
6+
import torch
7+
from torch_geometric.data import Data, InMemoryDataset
8+
from torch_geometric.utils import subgraph, to_networkx, remove_self_loops, to_dense_adj, dense_to_sparse
9+
from torch_sparse import coalesce, spspmm
10+
11+
12+
def extend_graph(data):
    """Attach a multi-hop ``extended_edge_index`` to ``data`` and return ``data``.

    Two identical expansion rounds are run. In each round the current sparse
    connectivity is multiplied with itself (``spspmm``), self-loops are removed
    from the product, the resulting index pairs are appended to the current
    edges, and duplicates are merged with ``coalesce``.

    ``data.edge_index`` is not reassigned; the only attribute written is
    ``data.extended_edge_index``.
    """
    num_nodes = data.num_nodes
    connectivity = data.edge_index

    for _ in range(2):
        weights = connectivity.new_ones((connectivity.size(1),), dtype=torch.float)
        hop_index, hop_weights = spspmm(
            connectivity, weights, connectivity, weights, num_nodes, num_nodes, num_nodes
        )
        # Only the sparsity pattern of the product is used; its values are zeroed.
        hop_weights.fill_(0)
        hop_index, hop_weights = remove_self_loops(hop_index, hop_weights)

        merged = torch.cat([connectivity, hop_index], dim=1)
        connectivity, _ = coalesce(merged, None, num_nodes, num_nodes)

    data.extended_edge_index = connectivity
    return data
36+
37+
38+
class Molecule3DDataset(InMemoryDataset):
    """In-memory dataset of already-processed 3D molecular graphs.

    The collated ``(data, slices)`` pair is loaded from ``processed_paths[0]``
    (file name ``geometric_data_processed.pt``); ``process`` and ``_download``
    are no-ops, so the processed file must already exist.

    Per-sample options applied in ``get``, in this order:

    * ``use_extend_graph``: add ``extended_edge_index`` via ``extend_graph``.
    * ``mask_ratio > 0``: keep a random BFS-grown node subset (``subgraph``).
    * ``remove_center``: subtract the mean of ``positions``.

    Args:
        root: Dataset root directory, forwarded to ``InMemoryDataset``.
        dataset: Dataset name; stored and printed, not otherwise used here.
        mask_ratio: Fraction used to size the retained subgraph.
        remove_center: Whether to centre ``positions`` in ``get``.
        transform, pre_transform, pre_filter: Forwarded to ``InMemoryDataset``.
        empty: If true, skip loading the processed file.
        use_extend_graph: Whether to compute ``extended_edge_index`` in ``get``.
    """

    def __init__(self, root, dataset, mask_ratio=0, remove_center=False, transform=None, pre_transform=None, pre_filter=None, empty=False, use_extend_graph=False):
        self.root = root
        self.dataset = dataset
        self.mask_ratio = mask_ratio
        self.remove_center = remove_center
        self.use_extend_graph = use_extend_graph

        self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter
        super(Molecule3DDataset, self).__init__(root, transform, pre_transform, pre_filter)

        if not empty:
            self.data, self.slices = torch.load(self.processed_paths[0])
        print('Dataset: {}\nData: {}'.format(self.dataset, self.data))

    def subgraph(self, data):
        """Restrict ``data`` to a randomly grown node subset and return it.

        Starting from a random node, nodes are added one at a time from the
        current neighbour frontier (a random unvisited node is used when the
        frontier is empty). ``x``, ``positions``, ``edge_index``/``edge_attr``
        and, when present, ``radius_edge_index`` and ``extended_edge_index``
        are restricted to the kept nodes and relabelled.
        """
        G = to_networkx(data)
        node_num = data.x.size()[0]
        sub_num = int(node_num * (1 - self.mask_ratio))

        idx_sub = [np.random.randint(node_num, size=1)[0]]
        idx_neigh = set([n for n in G.neighbors(idx_sub[-1])])

        # BFS
        # NOTE: the `<=` is kept from the original, so sub_num + 1 nodes are
        # retained. The extra `< node_num` bound stops the loop once every node
        # has been taken; without it np.random.choice is called on an empty
        # list when sub_num >= node_num.
        while len(idx_sub) <= sub_num and len(idx_sub) < node_num:
            if len(idx_neigh) == 0:
                idx_unsub = list(set([n for n in range(node_num)]).difference(set(idx_sub)))
                idx_neigh = set([np.random.choice(idx_unsub)])
            sample_node = np.random.choice(list(idx_neigh))

            idx_sub.append(sample_node)
            idx_neigh = idx_neigh.union(
                set([n for n in G.neighbors(idx_sub[-1])])).difference(set(idx_sub))

        idx_nondrop = idx_sub
        idx_nondrop.sort()

        edge_idx, edge_attr = subgraph(
            subset=idx_nondrop,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr,
            relabel_nodes=True,
            num_nodes=node_num
        )
        data.edge_index = edge_idx
        data.edge_attr = edge_attr
        data.x = data.x[idx_nondrop]
        data.positions = data.positions[idx_nondrop]
        data.__num_nodes__ = data.x.size()[0]

        if "radius_edge_index" in data:
            radius_edge_index, _ = subgraph(
                subset=idx_nondrop,
                edge_index=data.radius_edge_index,
                relabel_nodes=True,
                num_nodes=node_num)
            data.radius_edge_index = radius_edge_index
        if "extended_edge_index" in data:
            # TODO: may consider extended_edge_attr
            extended_edge_index, _ = subgraph(
                subset=idx_nondrop,
                edge_index=data.extended_edge_index,
                relabel_nodes=True,
                num_nodes=node_num)
            data.extended_edge_index = extended_edge_index
        # TODO: will also need to do this for other edge_index
        return data

    def get(self, idx):
        """Build the ``Data`` object for sample ``idx`` from the collated storage."""
        data = Data()
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            # Index with a tuple: a list of slices is the deprecated
            # "non-tuple sequence" form of multidimensional indexing.
            data[key] = item[tuple(s)]

        if self.use_extend_graph:
            data = extend_graph(data)

        if self.mask_ratio > 0:
            data = self.subgraph(data)

        if self.remove_center:
            center = data.positions.mean(dim=0)
            # Out-of-place on purpose: without masking, data.positions is a
            # slice view of self.data's tensor, and `-=` would overwrite the
            # dataset's stored coordinates.
            data.positions = data.positions - center

        return data

    def _download(self):
        """No-op: nothing is downloaded for this dataset."""
        return

    @property
    def processed_file_names(self):
        """Name of the processed file that ``__init__`` loads."""
        return 'geometric_data_processed.pt'

    def process(self):
        """No-op: the processed file is expected to exist already."""
        return
135+
136+
137+
if __name__ == "__main__":

    def extend_graph(data):
        """Debug variant for this demo: extend ``data`` with 2- and 3-hop pairs.

        Prints the intermediate sparse products, then stores the coalesced
        union of the original, 2-hop and 3-hop index pairs (self-loops removed)
        as ``data.extended_edge_index``.
        """
        num_nodes = data.num_nodes
        one_hop = data.edge_index
        one_hop_value = one_hop.new_ones((one_hop.size(1), ), dtype=torch.float)

        two_hop, two_hop_value = spspmm(
            one_hop, one_hop_value, one_hop, one_hop_value, num_nodes, num_nodes, num_nodes)
        print("edge_index_2_hop", two_hop)
        print("value_2_hop", two_hop_value)
        two_hop_value.fill_(1)

        three_hop, three_hop_value = spspmm(
            one_hop, one_hop_value, two_hop, two_hop_value, num_nodes, num_nodes, num_nodes)
        print("edge_index_3_hop", three_hop)
        print("value_3_hop", three_hop_value)
        three_hop_value.fill_(1)

        stacked_index = torch.cat([one_hop, two_hop, three_hop], dim=-1)
        stacked_value = torch.cat([one_hop_value, two_hop_value, three_hop_value], dim=-1)
        stacked_index, stacked_value = remove_self_loops(stacked_index, stacked_value)

        candidate_edges = torch.cat([one_hop, stacked_index], dim=1)

        data.extended_edge_index, _ = coalesce(candidate_edges, None, num_nodes, num_nodes)
        return data

    from torch import Tensor

    # Toy input: a 5-node path graph 0-1-2-3-4 with both edge directions listed.
    node_features = Tensor([0, 1, 2, 3, 4])
    sources = Tensor([0, 1, 1, 2, 2, 3, 3, 4])
    targets = Tensor([1, 0, 2, 1, 3, 2, 4, 3])
    path_edges = torch.stack([sources, targets]).long()
    data = Data(x=node_features, edge_index=path_edges)
    print(data)

    data = extend_graph(data)
    print()
    print(data.extended_edge_index)
    print(data)

Geom3D/datasets/dataset_3D_Radius.py

+168
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import os
2+
import shutil
3+
from itertools import repeat
4+
import numpy as np
5+
6+
import torch
7+
from tqdm import tqdm
8+
from torch_geometric.data import Data, InMemoryDataset
9+
from torch_geometric.nn import radius_graph
10+
from torch_geometric.utils import subgraph, to_networkx
11+
from .dataset_3D import extend_graph
12+
13+
14+
class MoleculeDataset3DRadius(InMemoryDataset):
    """Wrap a processed 3D molecule dataset and add radius-graph edges.

    ``process`` reads every sample from ``preprcessed_dataset``, attaches
    ``radius_edge_index = radius_graph(positions, r=radius, loop=False)``,
    and saves the collated result to ``processed_paths[0]``, which
    ``__init__`` then loads.

    Per-sample options applied in ``get``, in this order:

    * ``use_extend_graph``: add ``extended_edge_index`` via ``extend_graph``.
    * ``mask_ratio > 0``: keep a random BFS-grown node subset (``subgraph``).
    * ``remove_center``: subtract the mean of ``positions``.

    Args:
        root: Dataset root directory, forwarded to ``InMemoryDataset``.
        preprcessed_dataset: Source dataset; its ``dataset`` name,
            ``transform``, ``pre_transform`` and ``pre_filter`` are reused.
            (The parameter name is kept as-is for caller compatibility.)
        radius: Cutoff passed to ``radius_graph`` as ``r``.
        mask_ratio: Fraction used to size the retained subgraph.
        remove_center: Whether to centre ``positions`` in ``get``.
        use_extend_graph: Whether to compute ``extended_edge_index`` in ``get``.
    """

    def __init__(self, root, preprcessed_dataset, radius, mask_ratio=0, remove_center=False, use_extend_graph=False):
        self.root = root
        self.dataset = preprcessed_dataset.dataset
        self.preprcessed_dataset = preprcessed_dataset
        self.radius = radius
        self.mask_ratio = mask_ratio
        self.remove_center = remove_center
        self.use_extend_graph = use_extend_graph

        # TODO: rotation_transform is left for the future
        # self.rotation_transform = preprcessed_dataset.rotation_transform
        self.transform = preprcessed_dataset.transform
        self.pre_transform = preprcessed_dataset.pre_transform
        self.pre_filter = preprcessed_dataset.pre_filter

        super(MoleculeDataset3DRadius, self).__init__(root, self.transform, self.pre_transform, self.pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])
        print("Dataset: {}\nData: {}".format(self.dataset, self.data))

    def _stacked_targets(self):
        """Stack ``y`` of every sample (obtained through ``get``) along a new dim 0."""
        return torch.stack([self.get(i).y for i in range(len(self))], dim=0)

    def mean(self):
        """Return the mean of ``y`` over all samples (reduced over dim 0)."""
        return self._stacked_targets().mean(dim=0)

    def std(self):
        """Return the standard deviation of ``y`` over all samples (dim 0)."""
        return self._stacked_targets().std(dim=0)

    def subgraph(self, data):
        """Restrict ``data`` to a randomly grown node subset and return it.

        Starting from a random node, nodes are added one at a time from the
        current neighbour frontier (a random unvisited node is used when the
        frontier is empty). ``x``, ``positions``, ``edge_index``/``edge_attr``,
        ``radius_edge_index`` and, when present, ``extended_edge_index`` are
        restricted to the kept nodes and relabelled.
        """
        G = to_networkx(data)
        node_num = data.x.size()[0]
        sub_num = int(node_num * (1 - self.mask_ratio))

        idx_sub = [np.random.randint(node_num, size=1)[0]]
        idx_neigh = set([n for n in G.neighbors(idx_sub[-1])])

        # BFS
        # NOTE: the `<=` is kept from the original, so sub_num + 1 nodes are
        # retained. The extra `< node_num` bound stops the loop once every node
        # has been taken; without it np.random.choice is called on an empty
        # list when sub_num >= node_num.
        while len(idx_sub) <= sub_num and len(idx_sub) < node_num:
            if len(idx_neigh) == 0:
                idx_unsub = list(set([n for n in range(node_num)]).difference(set(idx_sub)))
                idx_neigh = set([np.random.choice(idx_unsub)])
            sample_node = np.random.choice(list(idx_neigh))

            idx_sub.append(sample_node)
            idx_neigh = idx_neigh.union(
                set([n for n in G.neighbors(idx_sub[-1])])).difference(set(idx_sub))

        idx_nondrop = idx_sub
        idx_nondrop.sort()

        edge_idx, edge_attr = subgraph(
            subset=idx_nondrop,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr,
            relabel_nodes=True,
            num_nodes=node_num
        )
        data.edge_index = edge_idx
        data.edge_attr = edge_attr
        data.x = data.x[idx_nondrop]
        data.positions = data.positions[idx_nondrop]
        data.__num_nodes__ = data.x.size()[0]

        radius_edge_index, _ = subgraph(
            subset=idx_nondrop,
            edge_index=data.radius_edge_index,
            relabel_nodes=True,
            num_nodes=node_num)
        data.radius_edge_index = radius_edge_index

        if "extended_edge_index" in data:
            extended_edge_index, _ = subgraph(
                subset=idx_nondrop,
                edge_index=data.extended_edge_index,
                relabel_nodes=True,
                num_nodes=node_num)
            data.extended_edge_index = extended_edge_index
        # TODO: will also need to do this for other edge_index

        return data

    def get(self, idx):
        """Build the ``Data`` object for sample ``idx`` from the collated storage."""
        data = Data()
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            # Index with a tuple: a list of slices is the deprecated
            # "non-tuple sequence" form of multidimensional indexing.
            data[key] = item[tuple(s)]

        if self.use_extend_graph:
            data = extend_graph(data)

        if self.mask_ratio > 0:
            data = self.subgraph(data)

        if self.remove_center:
            center = data.positions.mean(dim=0)
            # Out-of-place on purpose: without masking, data.positions is a
            # slice view of self.data's tensor, and `-=` would overwrite the
            # dataset's stored coordinates.
            data.positions = data.positions - center

        return data

    @property
    def processed_file_names(self):
        """Name of the processed file written by ``process``."""
        return "geometric_data_processed.pt"

    def _copy_from_preprocessed(self, file_name):
        """Copy ``file_name`` from the source dataset's processed dir into this one's."""
        shutil.copyfile(
            os.path.join(self.preprcessed_dataset.processed_dir, file_name),
            os.path.join(self.processed_dir, file_name),
        )

    def process(self):
        """Add radius edges to every source sample and save the collated result.

        Dataset-specific side files (``smiles.csv``, and ``name.csv`` for QM9)
        are copied from the source dataset's processed directory first.
        """
        print("Preprocessing on {} with Radius Edges ...".format(self.dataset))

        if self.dataset == "qm9":
            print("Preprocessing on QM9 Radius ...")
            self._copy_from_preprocessed("smiles.csv")
            self._copy_from_preprocessed("name.csv")

        elif self.dataset == "md17":
            # No side files are copied for MD17.
            print("Preprocessing on MD17 Radius ...")

        elif self.dataset == "Molecule3D":
            print("Preprocessing on Molecule3D Radius ...")
            self._copy_from_preprocessed("smiles.csv")

        elif "GEOM" in self.dataset:
            print("Preprocessing on GEOM Radius ...")
            self._copy_from_preprocessed("smiles.csv")

        data_list = []
        for i in tqdm(range(len(self.preprcessed_dataset))):
            data = self.preprcessed_dataset.get(i)
            data.radius_edge_index = radius_graph(data.positions, r=self.radius, loop=False)
            data_list.append(data)

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

0 commit comments

Comments
 (0)