1
+
2
+ import os
3
+ import pandas as pd
4
+ import numpy as np
5
+ from itertools import repeat
6
+ import torch
7
+ from torch_geometric .data import Data , InMemoryDataset
8
+ from torch_geometric .utils import subgraph , to_networkx , remove_self_loops , to_dense_adj , dense_to_sparse
9
+ from torch_sparse import coalesce , spspmm
10
+
11
+
12
def extend_graph(data):
    """Attach an ``extended_edge_index`` built by two rounds of adjacency squaring.

    Each round multiplies the current sparse adjacency with itself through
    ``spspmm``, removes self-loops from the product, appends the surviving
    pairs to the current edges and coalesces the result. The first round
    starts from ``data.edge_index``; the second round starts from the output
    of the first.

    Args:
        data: graph object exposing ``edge_index`` and ``num_nodes``.

    Returns:
        The same ``data`` object with ``extended_edge_index`` set.
        ``data.edge_index`` is not reassigned.
    """
    num_nodes = data.num_nodes
    current = data.edge_index

    for _ in range(2):
        # Unit weights: the product is only used for its sparsity pattern.
        ones = current.new_ones((current.size(1),), dtype=torch.float)
        product_index, product_value = spspmm(
            current, ones, current, ones, num_nodes, num_nodes, num_nodes
        )
        # The values are discarded further down (coalesce receives ``None``).
        product_value.fill_(0)
        product_index, product_value = remove_self_loops(product_index, product_value)

        merged = torch.cat([current, product_index], dim=1)
        current, _ = coalesce(merged, None, num_nodes, num_nodes)

    data.extended_edge_index = current
    return data
36
+
37
+
38
class Molecule3DDataset(InMemoryDataset):
    """In-memory graph dataset loaded from an already processed file.

    The collated data is read from ``processed_paths[0]``; ``process`` and
    ``_download`` are no-ops, so the processed file must already exist.
    On access (``get``) a graph can optionally be extended with extra edges,
    reduced to a randomly grown subgraph, and have its positions centered.

    Args:
        root: dataset root directory, forwarded to ``InMemoryDataset``.
        dataset: dataset name; only stored and printed by this class.
        mask_ratio: fraction used to size the sampled subgraph; sampling is
            applied in ``get`` only when this is greater than 0.
        remove_center: if true, subtract the mean position from
            ``data.positions`` in ``get``.
        transform, pre_transform, pre_filter: forwarded to
            ``InMemoryDataset``.
        empty: if true, skip loading the processed file.
        use_extend_graph: if true, call ``extend_graph`` on each graph in
            ``get``.
    """

    def __init__(self, root, dataset, mask_ratio=0, remove_center=False, transform=None,
                 pre_transform=None, pre_filter=None, empty=False, use_extend_graph=False):
        self.root = root
        self.dataset = dataset
        self.mask_ratio = mask_ratio
        self.remove_center = remove_center
        self.use_extend_graph = use_extend_graph

        self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter
        super(Molecule3DDataset, self).__init__(root, transform, pre_transform, pre_filter)

        if not empty:
            self.data, self.slices = torch.load(self.processed_paths[0])
        print('Dataset: {}\n Data: {}'.format(self.dataset, self.data))

    def subgraph(self, data):
        """Reduce ``data`` to a randomly grown node-induced subgraph.

        Starting from a random node, nodes are added one at a time by sampling
        from the current frontier (neighbours of already selected nodes); when
        the frontier is empty a random unselected node is used instead.
        ``edge_index``/``edge_attr``, ``x`` and ``positions`` are restricted to
        the kept nodes, as are ``radius_edge_index`` and
        ``extended_edge_index`` when present. Node indices are relabelled.

        Args:
            data: graph to reduce; it is modified and returned.

        Returns:
            The modified ``data`` object.
        """
        G = to_networkx(data)
        node_num = data.x.size()[0]
        sub_num = int(node_num * (1 - self.mask_ratio))

        idx_sub = [np.random.randint(node_num, size=1)[0]]
        idx_neigh = set([n for n in G.neighbors(idx_sub[-1])])

        # Frontier expansion. The loop keeps sub_num + 1 nodes, as before.
        # The extra `< node_num` bound stops once every node is selected;
        # without it np.random.choice is called on an empty list (ValueError)
        # whenever sub_num >= node_num.
        while len(idx_sub) <= sub_num and len(idx_sub) < node_num:
            if len(idx_neigh) == 0:
                idx_unsub = list(set([n for n in range(node_num)]).difference(set(idx_sub)))
                idx_neigh = set([np.random.choice(idx_unsub)])
            sample_node = np.random.choice(list(idx_neigh))

            idx_sub.append(sample_node)
            idx_neigh = idx_neigh.union(
                set([n for n in G.neighbors(idx_sub[-1])])).difference(set(idx_sub))

        idx_nondrop = idx_sub
        idx_nondrop.sort()

        edge_idx, edge_attr = subgraph(
            subset=idx_nondrop,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr,
            relabel_nodes=True,
            num_nodes=node_num
        )
        data.edge_index = edge_idx
        data.edge_attr = edge_attr
        data.x = data.x[idx_nondrop]
        data.positions = data.positions[idx_nondrop]
        data.__num_nodes__ = data.x.size()[0]

        if "radius_edge_index" in data:
            radius_edge_index, _ = subgraph(
                subset=idx_nondrop,
                edge_index=data.radius_edge_index,
                relabel_nodes=True,
                num_nodes=node_num)
            data.radius_edge_index = radius_edge_index
        if "extended_edge_index" in data:
            # TODO: may consider extended_edge_attr
            extended_edge_index, _ = subgraph(
                subset=idx_nondrop,
                edge_index=data.extended_edge_index,
                relabel_nodes=True,
                num_nodes=node_num)
            data.extended_edge_index = extended_edge_index
        # TODO: will also need to do this for other edge_index
        return data

    def get(self, idx):
        """Build the ``Data`` object for graph ``idx`` from the collated storage.

        Every key of ``self.data`` is sliced along its concatenation dimension
        using ``self.slices``. Afterwards the optional steps configured in the
        constructor are applied in this order: graph extension, subgraph
        sampling, position centering.

        Args:
            idx: index of the graph to fetch.

        Returns:
            The assembled ``Data`` object.
        """
        data = Data()
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            # Index with a tuple: a list of slices is the deprecated
            # non-tuple form of multidimensional indexing.
            data[key] = item[tuple(s)]

        if self.use_extend_graph:
            data = extend_graph(data)

        if self.mask_ratio > 0:
            data = self.subgraph(data)

        if self.remove_center:
            center = data.positions.mean(dim=0)
            # Out-of-place on purpose: data.positions may still be a slice
            # view of the collated tensor in self.data, and `-=` would write
            # the centered coordinates back into that shared storage.
            data.positions = data.positions - center

        return data

    def _download(self):
        """Do nothing; this dataset has no download step."""
        return

    @property
    def processed_file_names(self):
        """Name of the processed file that ``__init__`` loads."""
        return 'geometric_data_processed.pt'

    def process(self):
        """Do nothing; the processed file is expected to exist already."""
        return
135
+
136
+
137
if __name__ == "__main__":
    # Small manual check on a 5-node path graph (0-1-2-3-4).

    def extend_graph(data):
        """Demo variant: merge 1-, 2- and 3-step products into ``extended_edge_index``.

        Prints the intermediate products, then stores the coalesced union
        (without self-loops from the stacked part) on ``data`` and returns it.
        """
        one_hop = data.edge_index
        num_nodes = data.num_nodes
        ones = one_hop.new_ones((one_hop.size(1), ), dtype=torch.float)

        # adjacency x adjacency
        two_hop, two_hop_value = spspmm(
            one_hop, ones, one_hop, ones, num_nodes, num_nodes, num_nodes
        )
        print("edge_index_2_hop", two_hop)
        print("value_2_hop", two_hop_value)
        two_hop_value.fill_(1)

        # adjacency x (two-step product with unit weights)
        three_hop, three_hop_value = spspmm(
            one_hop, ones, two_hop, two_hop_value, num_nodes, num_nodes, num_nodes
        )
        print("edge_index_3_hop", three_hop)
        print("value_3_hop", three_hop_value)
        three_hop_value.fill_(1)

        stacked_index = torch.cat([one_hop, two_hop, three_hop], dim=-1)
        stacked_value = torch.cat([ones, two_hop_value, three_hop_value], dim=-1)
        stacked_index, stacked_value = remove_self_loops(stacked_index, stacked_value)

        merged = torch.cat([one_hop, stacked_index], dim=1)

        data.extended_edge_index, _ = coalesce(merged, None, num_nodes, num_nodes)
        return data

    from torch import Tensor

    x = Tensor([0, 1, 2, 3, 4])
    row = Tensor([0, 1, 1, 2, 2, 3, 3, 4])
    col = Tensor([1, 0, 2, 1, 3, 2, 4, 3])
    edge_index = torch.stack([row, col]).long()
    data = Data(x=x, edge_index=edge_index)
    print(data)

    data = extend_graph(data)
    print()
    print(data.extended_edge_index)
    print(data)
0 commit comments