1
1
import re
2
+ import warnings
2
3
from contextlib import contextmanager
3
4
from multiprocessing import Pool
4
- from typing import List
5
+ from typing import List , Optional
5
6
6
7
import numpy as np
7
8
import torch
10
11
from torch import nn
11
12
from torch .utils .data import Dataset
12
13
13
- from .torch_layers import IndexTensor , IndexTuple , Reverse , SamePadding1d , Transpose
14
+ from fcd .torch_layers import IndexTensor , IndexTuple , Reverse , SamePadding1d , Transpose
14
15
15
16
# fmt: off
16
17
__vocab = ["C" ,"N" ,"O" ,"H" ,"F" ,"Cl" ,"P" ,"B" ,"Br" ,"S" ,"I" ,"Si" ,"#" ,"(" ,")" ,"+" ,"-" ,"1" ,"2" ,"3" ,"4" ,"5" ,"6" ,"7" ,"8" ,"=" ,"[" ,"]" ,"@" ,"c" ,"n" ,"o" ,"s" ,"X" ,"." ]
@@ -42,7 +43,7 @@ def tokenize(smiles: str) -> List[str]:
42
43
return tok_smile
43
44
44
45
45
- def get_one_hot (smiles : str , pad_len : int = - 1 ) -> np .ndarray :
46
+ def get_one_hot (smiles : str , pad_len : Optional [ int ] = None ) -> np .ndarray :
46
47
"""Generate one-hot representation of a Smiles string.
47
48
48
49
Args:
@@ -52,10 +53,13 @@ def get_one_hot(smiles: str, pad_len: int = -1) -> np.ndarray:
52
53
Returns:
53
54
np.ndarray: Array containing the one-hot encoded Smiles
54
55
"""
56
+ # add end token
55
57
smiles = smiles + "."
56
58
57
59
# initialize array
58
- array_length = len (smiles ) if pad_len < 0 else pad_len
60
+ array_length = len (smiles ) if pad_len is None else pad_len
61
+ assert array_length >= len (smiles ), "Pad length must be greater than the length of the input SMILES string + 1."
62
+
59
63
vocab_size = len (__vocab )
60
64
one_hot = np .zeros ((array_length , vocab_size ))
61
65
@@ -106,22 +110,57 @@ def load_imported_model(keras_config):
106
110
107
111
108
112
class SmilesDataset (Dataset ):
109
- __PAD_LEN = 350
113
+ """
114
+ A dataset class for handling SMILES data.
115
+
116
+ Args:
117
+ smiles_list (list): A list of SMILES strings.
118
+ pad_len (int, optional): The length to pad the SMILES strings to. If not provided, the default pad length of 350 will be used.
119
+ warn (bool, optional): Whether to display a warning message if the specified pad length is different from the default. Defaults to True.
120
+
121
+ Attributes:
122
+ smiles_list (list): A list of SMILES strings.
123
+ pad_len (int): The length to pad the SMILES strings to.
124
+
125
+ """
110
126
111
- def __init__ (self , smiles_list ):
127
+ def __init__ (self , smiles_list , pad_len = None , warn = True ):
112
128
super ().__init__ ()
129
+ DEFAULT_PAD_LEN = 350
130
+
113
131
self .smiles_list = smiles_list
132
+ max_len = max (len (smiles ) for smiles in smiles_list ) + 1 # plus one for the end token
133
+
134
+ if pad_len is None :
135
+ pad_len = max (DEFAULT_PAD_LEN , max_len )
136
+ else :
137
+ if pad_len < max_len :
138
+ raise ValueError (f"Specified pad_len { pad_len } is less than max_len { max_len } " )
139
+
140
+ if pad_len != DEFAULT_PAD_LEN :
141
+ warnings .warn (
142
+ """Padding lengths differing from the default of 350 may affect FCD scores. See https://github.com/hogru/GuacaMolEval.
143
+ Use warn=False to suppress this warning."""
144
+ )
145
+
146
+ self .pad_len = pad_len
114
147
115
148
def __getitem__ (self , idx ):
116
149
smiles = self .smiles_list [idx ]
117
- features = get_one_hot (smiles , 350 )
150
+ features = get_one_hot (smiles , pad_len = self . pad_len )
118
151
return features / features .shape [1 ]
119
152
120
153
def __len__ (self ):
121
154
return len (self .smiles_list )
122
155
123
156
124
- def calculate_frechet_distance (mu1 , sigma1 , mu2 , sigma2 , eps = 1e-6 ):
157
+ def calculate_frechet_distance (
158
+ mu1 : np .ndarray ,
159
+ sigma1 : np .ndarray ,
160
+ mu2 : np .ndarray ,
161
+ sigma2 : np .ndarray ,
162
+ eps : float = 1e-6 ,
163
+ ) -> float :
125
164
"""Numpy implementation of the Frechet Distance.
126
165
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
127
166
and X_2 ~ N(mu_2, C_2) is
@@ -151,21 +190,20 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
151
190
sigma1 = np .atleast_2d (sigma1 )
152
191
sigma2 = np .atleast_2d (sigma2 )
153
192
154
- assert (
155
- mu1 .shape == mu2 .shape
156
- ), "Training and test mean vectors have different lengths"
157
- assert (
158
- sigma1 .shape == sigma2 .shape
159
- ), "Training and test covariances have different dimensions"
193
+ assert mu1 .shape == mu2 .shape , "Training and test mean vectors have different lengths"
194
+ assert sigma1 .shape == sigma2 .shape , "Training and test covariances have different dimensions"
160
195
161
196
diff = mu1 - mu2
162
197
163
198
# product might be almost singular
164
199
covmean , _ = linalg .sqrtm (sigma1 .dot (sigma2 ), disp = False )
165
- if not np .isfinite (covmean ).all ():
200
+ is_real = np .allclose (np .diagonal (covmean ).imag , 0 , atol = 1e-3 )
201
+
202
+ if not np .isfinite (covmean ).all () or not is_real :
166
203
offset = np .eye (sigma1 .shape [0 ]) * eps
167
204
covmean = linalg .sqrtm ((sigma1 + offset ).dot (sigma2 + offset ))
168
205
206
+ assert isinstance (covmean , np .ndarray )
169
207
# numerical error might give slight imaginary component
170
208
if np .iscomplexobj (covmean ):
171
209
if not np .allclose (np .diagonal (covmean ).imag , 0 , atol = 1e-3 ):
@@ -175,7 +213,7 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
175
213
176
214
tr_covmean = np .trace (covmean )
177
215
178
- return diff .dot (diff ) + np .trace (sigma1 ) + np .trace (sigma2 ) - 2 * tr_covmean
216
+ return float ( diff .dot (diff ) + np .trace (sigma1 ) + np .trace (sigma2 ) - 2 * tr_covmean )
179
217
180
218
181
219
@contextmanager
@@ -188,11 +226,11 @@ def todevice(model, device):
188
226
189
227
def canonical (smi ):
190
228
try :
191
- return Chem .MolToSmiles (Chem .MolFromSmiles (smi ))
192
- except :
229
+ return Chem .MolToSmiles (Chem .MolFromSmiles (smi )) # type: ignore
230
+ except Exception :
193
231
return None
194
232
195
233
196
- def canonical_smiles (smiles , njobs = 32 ):
234
+ def canonical_smiles (smiles , njobs = - 1 ):
197
235
with Pool (njobs ) as pool :
198
236
return pool .map (canonical , smiles )
0 commit comments