Skip to content

Commit 003d799

Browse files
authored
Merge pull request #19 from bioinf-jku/dev
Fix numerical issues and efficiency issues
2 parents b4bcc22 + 2c45937 commit 003d799

File tree

9 files changed

+340
-80
lines changed

9 files changed

+340
-80
lines changed

.github/workflows/test_dev.yml

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2+
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3+
4+
name: Tests (dev)
5+
6+
on:
7+
push:
8+
branches: [ "dev" ]
9+
pull_request:
10+
branches: [ "dev" ]
11+
12+
jobs:
13+
build:
14+
15+
runs-on: ubuntu-latest
16+
strategy:
17+
fail-fast: false
18+
matrix:
19+
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
20+
21+
steps:
22+
- uses: actions/checkout@v4
23+
- name: Set up Python ${{ matrix.python-version }}
24+
uses: actions/setup-python@v5
25+
with:
26+
python-version: ${{ matrix.python-version }}
27+
- name: Install dependencies
28+
run: |
29+
python -m pip install --upgrade pip
30+
python -m pip install flake8 pytest
31+
python -m pip install -e .
32+
- name: Lint with flake8
33+
run: |
34+
# stop the build if there are Python syntax errors or undefined names
35+
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
36+
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
37+
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
38+
- name: Test with pytest
39+
run: |
40+
pytest

.github/workflows/test_master.yml

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2+
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3+
4+
name: Tests (master)
5+
6+
on:
7+
push:
8+
branches: [ "master"]
9+
pull_request:
10+
branches: [ "master"]
11+
12+
jobs:
13+
build:
14+
15+
runs-on: ubuntu-latest
16+
strategy:
17+
fail-fast: false
18+
matrix:
19+
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
20+
21+
steps:
22+
- uses: actions/checkout@v4
23+
- name: Set up Python ${{ matrix.python-version }}
24+
uses: actions/setup-python@v5
25+
with:
26+
python-version: ${{ matrix.python-version }}
27+
- name: Install dependencies
28+
run: |
29+
python -m pip install --upgrade pip
30+
python -m pip install flake8 pytest
31+
python -m pip install -e .
32+
- name: Lint with flake8
33+
run: |
34+
# stop the build if there are Python syntax errors or undefined names
35+
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
36+
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
37+
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
38+
- name: Test with pytest
39+
run: |
40+
pytest

README.md

+9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
11
# Fréchet ChemNet Distance
2+
![PyPI](https://img.shields.io/pypi/v/fcd)
3+
![PyPI - Python Version](https://img.shields.io/pypi/pyversions/fcd)
4+
![Tests (master)](https://github.com/bioinf-jku/fcd/actions/workflows/test_master.yml/badge.svg?branch=master)
5+
![Tests (dev)](https://github.com/bioinf-jku/fcd/actions/workflows/test_dev.yml/badge.svg?branch=dev)
6+
![PyPI - Downloads](https://img.shields.io/pypi/dm/fcd)
7+
![GitHub release (latest by date)](https://img.shields.io/github/v/release/bioinf-jku/fcd)
8+
![GitHub release date](https://img.shields.io/github/release-date/bioinf-jku/fcd)
9+
![GitHub](https://img.shields.io/github/license/bioinf-jku/fcd)
10+
211

312
Code for the paper "Fréchet ChemNet Distance: A Metric for Generative Models for Molecules in Drug Discovery"
413
[JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.8b00234) /

fcd/__init__.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1-
from .fcd import get_fcd, get_predictions, load_ref_model
2-
from .utils import calculate_frechet_distance, canonical_smiles
1+
# ruff: noqa: F401
32

4-
__version__ = "1.2"
3+
from fcd.fcd import get_fcd, get_predictions, load_ref_model
4+
from fcd.utils import calculate_frechet_distance, canonical_smiles
5+
6+
__all__ = [
7+
"get_fcd",
8+
"get_predictions",
9+
"load_ref_model",
10+
"calculate_frechet_distance",
11+
"canonical_smiles",
12+
]
13+
14+
__version__ = "1.2.1"

fcd/fcd.py

+29-18
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch import nn
1010
from torch.utils.data import DataLoader
1111

12-
from .utils import (
12+
from fcd.utils import (
1313
SmilesDataset,
1414
calculate_frechet_distance,
1515
load_imported_model,
@@ -31,6 +31,8 @@ def load_ref_model(model_path: Optional[str] = None):
3131
if model_path is None:
3232
chemnet_model_filename = "ChemNet_v0.13_pretrained.pt"
3333
model_bytes = pkgutil.get_data("fcd", chemnet_model_filename)
34+
if model_bytes is None:
35+
raise FileNotFoundError(f"Could not find model file {chemnet_model_filename}")
3436

3537
tmpdir = tempfile.TemporaryDirectory()
3638
model_path = os.path.join(tmpdir.name, chemnet_model_filename)
@@ -48,7 +50,7 @@ def get_predictions(
4850
smiles_list: List[str],
4951
batch_size: int = 128,
5052
n_jobs: int = 1,
51-
device: str = "cpu",
53+
device: Optional[str] = None,
5254
) -> np.ndarray:
5355
"""Calculate Chemnet activations
5456
@@ -65,46 +67,55 @@ def get_predictions(
6567
if len(smiles_list) == 0:
6668
return np.zeros((0, 512))
6769

68-
dataloader = DataLoader(
69-
SmilesDataset(smiles_list), batch_size=batch_size, num_workers=n_jobs
70-
)
70+
dataloader = DataLoader(SmilesDataset(smiles_list), batch_size=batch_size, num_workers=n_jobs)
71+
72+
if device is None:
73+
device = "cuda" if torch.cuda.is_available() else "cpu"
74+
7175
with todevice(model, device), torch.no_grad():
7276
chemnet_activations = []
7377
for batch in dataloader:
7478
chemnet_activations.append(
75-
model(batch.transpose(1, 2).float().to(device))
76-
.to("cpu")
77-
.detach()
78-
.numpy()
79+
model(batch.transpose(1, 2).float().to(device)).to("cpu").detach().numpy().astype(np.float32)
7980
)
8081
return np.row_stack(chemnet_activations)
8182

8283

83-
def get_fcd(smiles1: List[str], smiles2: List[str], model: nn.Module = None) -> float:
84+
def get_fcd(smiles1: List[str], smiles2: List[str], model: Optional[nn.Module] = None, device=None) -> float:
    """Calculate FCD between two sets of SMILES.

    Args:
        smiles1 (List[str]): First set of SMILES.
        smiles2 (List[str]): Second set of SMILES.
        model (nn.Module, optional): The model to use. Loads default model if None.
        device: The device to use for computation. Passed through unchanged to
            ``get_predictions``.

    Returns:
        float: The FCD score.

    Raises:
        ValueError: If either input is empty, or contains fewer than two
            entries (a covariance cannot be estimated from one sample).

    Example:
        >>> smiles1 = ['CCO', 'CCN']
        >>> smiles2 = ['CCO', 'CCC']
        >>> fcd_score = get_fcd(smiles1, smiles2)
    """
    # Use len() instead of truthiness so array-like inputs (e.g. numpy arrays)
    # do not trip over an ambiguous boolean conversion.
    n_first, n_second = len(smiles1), len(smiles2)
    if n_first == 0 or n_second == 0:
        raise ValueError("Input SMILES lists cannot be empty.")
    # np.cov over a single observation divides by zero degrees of freedom and
    # yields NaNs, so the resulting distance would be meaningless.
    if n_first < 2 or n_second < 2:
        raise ValueError("Each input SMILES list must contain at least two entries to estimate a covariance.")

    if model is None:
        model = load_ref_model()

    act1 = get_predictions(model, smiles1, device=device)
    act2 = get_predictions(model, smiles2, device=device)

    # Mean vector and covariance matrix of the activations of each set.
    mu1 = np.mean(act1, axis=0)
    sigma1 = np.cov(act1.T)

    mu2 = np.mean(act2, axis=0)
    sigma2 = np.cov(act2.T)

    fcd_score = calculate_frechet_distance(mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2)

    return fcd_score

fcd/utils.py

+57-19
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import re
2+
import warnings
23
from contextlib import contextmanager
34
from multiprocessing import Pool
4-
from typing import List
5+
from typing import List, Optional
56

67
import numpy as np
78
import torch
@@ -10,7 +11,7 @@
1011
from torch import nn
1112
from torch.utils.data import Dataset
1213

13-
from .torch_layers import IndexTensor, IndexTuple, Reverse, SamePadding1d, Transpose
14+
from fcd.torch_layers import IndexTensor, IndexTuple, Reverse, SamePadding1d, Transpose
1415

1516
# fmt: off
1617
__vocab = ["C","N","O","H","F","Cl","P","B","Br","S","I","Si","#","(",")","+","-","1","2","3","4","5","6","7","8","=","[","]","@","c","n","o","s","X","."]
@@ -42,7 +43,7 @@ def tokenize(smiles: str) -> List[str]:
4243
return tok_smile
4344

4445

45-
def get_one_hot(smiles: str, pad_len: int = -1) -> np.ndarray:
46+
def get_one_hot(smiles: str, pad_len: Optional[int] = None) -> np.ndarray:
4647
"""Generate one-hot representation of a Smiles string.
4748
4849
Args:
@@ -52,10 +53,13 @@ def get_one_hot(smiles: str, pad_len: int = -1) -> np.ndarray:
5253
Returns:
5354
np.ndarray: Array containing the one-hot encoded Smiles
5455
"""
56+
# add end token
5557
smiles = smiles + "."
5658

5759
# initialize array
58-
array_length = len(smiles) if pad_len < 0 else pad_len
60+
array_length = len(smiles) if pad_len is None else pad_len
61+
assert array_length >= len(smiles), "Pad length must be greater than the length of the input SMILES string + 1."
62+
5963
vocab_size = len(__vocab)
6064
one_hot = np.zeros((array_length, vocab_size))
6165

@@ -106,22 +110,57 @@ def load_imported_model(keras_config):
106110

107111

108112
class SmilesDataset(Dataset):
    """
    A dataset class for handling SMILES data.

    Args:
        smiles_list (list): A list of SMILES strings.
        pad_len (int, optional): The length to pad the SMILES strings to. If not provided,
            the default pad length of 350 is used, or the longest SMILES plus one
            (for the end token) if that is larger.
        warn (bool, optional): Whether to display a warning message if the effective pad
            length is different from the default. Defaults to True.

    Attributes:
        smiles_list (list): A list of SMILES strings.
        pad_len (int): The length to pad the SMILES strings to.

    Raises:
        ValueError: If ``pad_len`` is given and is shorter than the longest SMILES
            string plus one.
    """

    def __init__(self, smiles_list, pad_len=None, warn=True):
        super().__init__()
        DEFAULT_PAD_LEN = 350

        self.smiles_list = smiles_list
        # default=0 keeps an empty dataset constructible instead of failing in max().
        max_len = max((len(smiles) for smiles in smiles_list), default=0) + 1  # plus one for the end token

        if pad_len is None:
            pad_len = max(DEFAULT_PAD_LEN, max_len)
        elif pad_len < max_len:
            raise ValueError(f"Specified pad_len {pad_len} is less than max_len {max_len}")

        # Honour the `warn` flag: the message itself tells users to pass
        # warn=False, so the flag has to actually silence it.
        if warn and pad_len != DEFAULT_PAD_LEN:
            warnings.warn(
                """Padding lengths differing from the default of 350 may affect FCD scores. See https://github.com/hogru/GuacaMolEval.
                Use warn=False to suppress this warning."""
            )

        self.pad_len = pad_len

    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        features = get_one_hot(smiles, pad_len=self.pad_len)
        # Scale by the size of the second axis of the encoding
        # (presumably the vocabulary size - confirm against get_one_hot).
        return features / features.shape[1]

    def __len__(self):
        return len(self.smiles_list)
122155

123156

124-
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
157+
def calculate_frechet_distance(
158+
mu1: np.ndarray,
159+
sigma1: np.ndarray,
160+
mu2: np.ndarray,
161+
sigma2: np.ndarray,
162+
eps: float = 1e-6,
163+
) -> float:
125164
"""Numpy implementation of the Frechet Distance.
126165
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
127166
and X_2 ~ N(mu_2, C_2) is
@@ -151,21 +190,20 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
151190
sigma1 = np.atleast_2d(sigma1)
152191
sigma2 = np.atleast_2d(sigma2)
153192

154-
assert (
155-
mu1.shape == mu2.shape
156-
), "Training and test mean vectors have different lengths"
157-
assert (
158-
sigma1.shape == sigma2.shape
159-
), "Training and test covariances have different dimensions"
193+
assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
194+
assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
160195

161196
diff = mu1 - mu2
162197

163198
# product might be almost singular
164199
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
165-
if not np.isfinite(covmean).all():
200+
is_real = np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3)
201+
202+
if not np.isfinite(covmean).all() or not is_real:
166203
offset = np.eye(sigma1.shape[0]) * eps
167204
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
168205

206+
assert isinstance(covmean, np.ndarray)
169207
# numerical error might give slight imaginary component
170208
if np.iscomplexobj(covmean):
171209
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
@@ -175,7 +213,7 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
175213

176214
tr_covmean = np.trace(covmean)
177215

178-
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
216+
return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)
179217

180218

181219
@contextmanager
@@ -188,11 +226,11 @@ def todevice(model, device):
188226

189227
def canonical(smi):
190228
try:
191-
return Chem.MolToSmiles(Chem.MolFromSmiles(smi))
192-
except:
229+
return Chem.MolToSmiles(Chem.MolFromSmiles(smi)) # type: ignore
230+
except Exception:
193231
return None
194232

195233

196-
def canonical_smiles(smiles, njobs=32):
234+
def canonical_smiles(smiles, njobs=-1):
    """Canonicalize a collection of SMILES strings in parallel.

    Args:
        smiles: Iterable of SMILES strings.
        njobs (int, optional): Number of worker processes. Any value below 1
            (including the default of -1) or None uses all available CPUs.

    Returns:
        list: The result of ``canonical`` for every input, in input order.
    """
    # multiprocessing.Pool raises ValueError for process counts below 1, so the
    # "-1 means all cores" convention has to be translated to None, which makes
    # Pool fall back to the machine's CPU count.
    processes = njobs if njobs is not None and njobs >= 1 else None
    with Pool(processes) as pool:
        return pool.map(canonical, smiles)

0 commit comments

Comments
 (0)