Commit

updating various unit tests and fixing bug in filtering notebook
michal-g committed Feb 10, 2025
1 parent a0fd143 commit 6ac2736
Showing 4 changed files with 74 additions and 80 deletions.
14 changes: 7 additions & 7 deletions cryodrgn/templates/analysis_template.ipynb
@@ -111,11 +111,11 @@
    "outputs": [],
    "source": [
     "# Load kmeans\n",
-    "K = 20\n",
+    "KMEANS = None\n",
     "kmeans_labels = utils.load_pkl(os.path.join(WORKDIR, f\"analyze.{EPOCH}\",\n",
-    "                               f\"kmeans{K}\", \"labels.pkl\"))\n",
+    "                               f\"kmeans{KMEANS}\", \"labels.pkl\"))\n",
     "kmeans_centers = np.loadtxt(os.path.join(WORKDIR, f\"analyze.{EPOCH}\",\n",
-    "                            f\"kmeans{K}\", \"centers.txt\"))\n",
+    "                            f\"kmeans{KMEANS}\", \"centers.txt\"))\n",
     "# Or re-run kmeans with the desired number of classes\n",
     "#kmeans_labels, kmeans_centers = analysis.cluster_kmeans(z, 20)\n",
     "\n",
@@ -434,7 +434,7 @@
    "source": [
     "K = len(set(kmeans_labels))\n",
     "c = pca.transform(kmeans_centers) # transform to view with PCs\n",
-    "analysis.plot_by_cluster(pc[:,0], pc[:,1], K, \n",
+    "analysis.plot_by_cluster(pc[:,0], pc[:,1], KMEANS,\n",
     "                         kmeans_labels, \n",
     "                         centers=c,\n",
     "                         annotate=True)\n",
@@ -448,7 +448,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "fig, ax = analysis.plot_by_cluster_subplot(pc[:,0], pc[:,1], K,\n",
+    "fig, ax = analysis.plot_by_cluster_subplot(pc[:,0], pc[:,1], KMEANS,\n",
     "                                           kmeans_labels)"
    ]
   },
@@ -458,7 +458,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "analysis.plot_by_cluster(umap[:,0], umap[:,1], K, \n",
+    "analysis.plot_by_cluster(umap[:,0], umap[:,1], KMEANS,\n",
     "                         kmeans_labels, \n",
     "                         centers_ind=centers_ind,\n",
     "                         annotate=True)\n",
@@ -472,7 +472,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "fig, ax = analysis.plot_by_cluster_subplot(umap[:,0], umap[:,1], K, \n",
+    "fig, ax = analysis.plot_by_cluster_subplot(umap[:,0], umap[:,1], KMEANS,\n",
     "                                           kmeans_labels)"
    ]
   },
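
The notebook fix replaces the hard-coded cluster count K with the KMEANS placeholder that the template fills in, so every kmeans artifact is loaded from the directory that cryodrgn analyze actually wrote. As a sketch, the fixed cell's logic in plain Python (the WORKDIR, EPOCH, and KMEANS values here are illustrative stand-ins for what the template substitutes at generation time):

    import os
    import numpy as np
    from cryodrgn import analysis, utils

    WORKDIR = "/path/to/workdir"  # stand-in: filled in when the notebook is generated
    EPOCH = 9                     # stand-in: training epoch that was analyzed
    KMEANS = 20                   # stand-in: number of kmeans classes

    # Load the kmeans results written by `cryodrgn analyze`
    kmeans_dir = os.path.join(WORKDIR, f"analyze.{EPOCH}", f"kmeans{KMEANS}")
    kmeans_labels = utils.load_pkl(os.path.join(kmeans_dir, "labels.pkl"))
    kmeans_centers = np.loadtxt(os.path.join(kmeans_dir, "centers.txt"))

    # Or re-run kmeans with the desired number of classes
    # kmeans_labels, kmeans_centers = analysis.cluster_kmeans(z, KMEANS)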
47 changes: 23 additions & 24 deletions tests/test_dataset.py
@@ -1,7 +1,6 @@
 import pytest
 import os
 import numpy as np
-from typing import Sequence
 from torch.utils.data.sampler import BatchSampler, RandomSampler
 from torch.utils.data import DataLoader
 from cryodrgn.dataset import DataShuffler, ImageDataset, TiltSeriesData, make_dataloader
@@ -153,41 +152,41 @@ def test_loading_slow(self, particles, indices, ntilts, batch_size):
         dataset = TiltSeriesData(tiltstar=particles.path, ntilts=ntilts, ind=ind)
         data_loader = make_dataloader(dataset, batch_size=batch_size, shuffle=True)

-        # minibatch is a list of (particles, tilt, indices)
         for i, minibatch in enumerate(data_loader):
-            assert isinstance(minibatch, Sequence)
-            assert len(minibatch) == 3
+            assert isinstance(minibatch, dict)
+            assert len(minibatch) == 4
+            assert sorted(minibatch.keys()) == [
+                "indices",
+                "tilt_indices",
+                "y",
+                "y_real",
+            ]

-            # We have 100 particles. For all but the last iteration
+            # for all but the last iteration (100//7 = 14), we'll have 7 images each
+            D = dataset.D
             if i < (dataset.Np // batch_size):
-                assert minibatch[0][0].shape == (
-                    batch_size * ntilts,
-                    dataset.D - 1,
-                    dataset.D - 1,
-                )
-                assert minibatch[0][1].shape == (
-                    batch_size * ntilts,
-                    dataset.D,
-                    dataset.D,
-                )
-                assert minibatch[1].shape == (batch_size * ntilts,)
-                assert minibatch[2].shape == (batch_size,)
+                assert minibatch["y"].shape == (batch_size * ntilts, D, D)
+                assert minibatch["y_real"].shape == (batch_size * ntilts, D - 1, D - 1)
+                assert minibatch["indices"].shape == (batch_size,)
+                assert minibatch["tilt_indices"].shape == (batch_size * ntilts,)

             # and 100 % 7 = 2 for the very last one
             else:
-                assert minibatch[0][0].shape == (
+                assert minibatch["y"].shape == (
                     (dataset.Np % batch_size) * ntilts,
-                    dataset.D - 1,
-                    dataset.D - 1,
+                    D,
+                    D,
                 )
-                assert minibatch[0][1].shape == (
+                assert minibatch["y_real"].shape == (
                     (dataset.Np % batch_size) * ntilts,
-                    dataset.D,
-                    dataset.D,
+                    D - 1,
+                    D - 1,
                 )
-                assert minibatch[1].shape == ((dataset.Np % batch_size) * ntilts,)
-                assert minibatch[2].shape == (dataset.Np % batch_size,)
+                assert minibatch["indices"].shape == (dataset.Np % batch_size,)
+                assert minibatch["tilt_indices"].shape == (
+                    (dataset.Np % batch_size) * ntilts,
+                )


     @pytest.mark.parametrize(
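
The rewritten assertions codify TiltSeriesData's new batch format: minibatches are dicts keyed "y", "y_real", "indices", and "tilt_indices" rather than nested tuples. A self-contained sketch of that contract with a toy dataset (default DataLoader collation is assumed here, so the tilt axis is shown unflattened; cryodrgn's own loader yields the merged (batch_size * ntilts, D, D) shapes asserted above):

    import torch
    from torch.utils.data import DataLoader, Dataset

    # Toy stand-in for TiltSeriesData: 100 particles, 7 tilts each, D x D images
    class ToyTiltSeries(Dataset):
        def __init__(self, n_particles=100, ntilts=7, D=8):
            self.Np, self.ntilts, self.D = n_particles, ntilts, D

        def __len__(self):
            return self.Np

        def __getitem__(self, i):
            return {
                "y": torch.zeros(self.ntilts, self.D, self.D),
                "y_real": torch.zeros(self.ntilts, self.D - 1, self.D - 1),
                "indices": torch.tensor(i),
                "tilt_indices": torch.arange(self.ntilts) + i * self.ntilts,
            }

    loader = DataLoader(ToyTiltSeries(), batch_size=7)
    minibatch = next(iter(loader))

    # Each key is batched; merging the batch and tilt axes recovers the
    # (batch_size * ntilts, D, D) layout the test expects from cryodrgn
    assert sorted(minibatch.keys()) == ["indices", "tilt_indices", "y", "y_real"]
    assert minibatch["y"].flatten(0, 1).shape == (7 * 7, 8, 8)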
86 changes: 40 additions & 46 deletions tests/test_integration.py
@@ -69,38 +69,29 @@ def test_train_model(self, tmpdir_factory, particles, poses, ctf, indices):
         if indices.path is not None:
             args += ["--ind", indices.path]

-        train_vae.main(train_vae.add_args(argparse.ArgumentParser()).parse_args(args))
+        parser = argparse.ArgumentParser()
+        train_vae.add_args(parser)
+        train_vae.main(parser.parse_args(args))

     def test_analyze(self, tmpdir_factory, particles, poses, ctf, indices):
         """Produce standard analyses for the final epoch."""
         outdir = self.get_outdir(tmpdir_factory, particles, indices, poses, ctf)
         assert os.path.exists(
-            os.path.join(outdir, "weights.9.pkl")
+            os.path.join(outdir, "weights.10.pkl")
         ), "Upstream tests have failed!"

-        args = analyze.add_args(argparse.ArgumentParser()).parse_args(
-            [
-                outdir,
-                "9",  # Epoch number to analyze - 0-indexed
-                "--pc",
-                "2",  # Number of principal component traversals to generate
-                "--ksample",
-                "10",  # Number of kmeans samples to generate
-                "--vol-start-index",
-                "1",
-            ]
-        )
-        analyze.main(args)
-        assert os.path.exists(os.path.join(outdir, "analyze.9"))
-
-    @pytest.mark.parametrize(
-        "nb_lbl", ["cryoDRGN_figures", "cryoDRGN_filtering", "cryoDRGN_viz"]
-    )
+        parser = argparse.ArgumentParser()
+        analyze.add_args(parser)
+        analyze.main(parser.parse_args([outdir, "--pc", "2", "--ksample", "10"]))
+
+        assert os.path.exists(os.path.join(outdir, "analyze.10"))
+
+    @pytest.mark.parametrize("nb_lbl", ["cryoDRGN_figures", "analysis", "cryoDRGN_viz"])
     def test_notebooks(self, tmpdir_factory, particles, poses, ctf, indices, nb_lbl):
         """Execute the demonstration Jupyter notebooks produced by analysis."""
         outdir = self.get_outdir(tmpdir_factory, particles, indices, poses, ctf)
         orig_cwd = os.path.abspath(os.getcwd())
-        os.chdir(os.path.join(outdir, "analyze.9"))
+        os.chdir(os.path.join(outdir, "analyze.10"))
         assert os.path.exists(f"{nb_lbl}.ipynb"), "Upstream tests have failed!"

         with open(f"{nb_lbl}.ipynb") as ff:
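
The hunk ends before the cells are actually run; for orientation, a typical way a test executes a generated notebook end-to-end is nbformat plus nbclient, as sketched below. This is an assumption about the surrounding test body, which the diff does not show:

    import nbformat
    from nbclient import NotebookClient

    # Read the generated notebook and execute every cell in order,
    # raising if any cell errors out
    nb = nbformat.read("cryoDRGN_viz.ipynb", as_version=4)
    NotebookClient(nb, timeout=600).execute()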
@@ -122,33 +113,36 @@ def test_refiltering(self, tmpdir_factory, particles, poses, ctf, indices):
             ind_keep_fl = ind_keep_fl[0]

         ind_keep_fl = os.path.join(outdir, ind_keep_fl)
-        args = train_vae.add_args(argparse.ArgumentParser()).parse_args(
-            [
-                particles.path,
-                "-o",
-                outdir,
-                "--ctf",
-                ctf.path,
-                "--ind",
-                ind_keep_fl,
-                "--num-epochs",
-                "5",
-                "--poses",
-                poses.path,
-                "--zdim",
-                "4",
-                "--tdim",
-                "64",
-                "--enc-dim",
-                "64",
-                "--dec-dim",
-                "64",
-                "--pe-type",
-                "gaussian",
-                "--no-amp",
-            ]
-        )
-        train_vae.main(args)
+        parser = argparse.ArgumentParser()
+        train_vae.add_args(parser)
+        train_vae.main(
+            parser.parse_args(
+                [
+                    particles.path,
+                    "-o",
+                    outdir,
+                    "--ctf",
+                    ctf.path,
+                    "--ind",
+                    ind_keep_fl,
+                    "--num-epochs",
+                    "5",
+                    "--poses",
+                    poses.path,
+                    "--zdim",
+                    "4",
+                    "--tdim",
+                    "64",
+                    "--enc-dim",
+                    "64",
+                    "--dec-dim",
+                    "64",
+                    "--pe-type",
+                    "gaussian",
+                    "--no-amp",
+                ]
+            )
+        )

         shutil.rmtree(outdir)
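
All three test files converge on the same argparse refactor: construct the parser, hand it to add_args, then parse, instead of chaining parse_args off add_args' return value. That pattern stays correct even when add_args registers arguments in place and returns None. A minimal sketch of the convention (these add_args/main bodies are hypothetical stand-ins, not cryodrgn's real entry points):

    import argparse

    def add_args(parser: argparse.ArgumentParser) -> None:
        # Registers arguments in place and returns nothing
        parser.add_argument("particles")
        parser.add_argument("-o", "--outdir", required=True)

    def main(args: argparse.Namespace) -> None:
        print(f"training on {args.particles}, writing to {args.outdir}")

    # The call pattern used throughout this commit:
    parser = argparse.ArgumentParser()
    add_args(parser)
    main(parser.parse_args(["particles.mrcs", "-o", "out"]))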
7 changes: 4 additions & 3 deletions tests/test_parse.py
@@ -81,10 +81,11 @@ def test_write_star_from_mrcs(self, tmpdir_factory, particles_starfile, resoluti
 def test_parse_ctf_cs(tmpdir, particles):
     pkl_out = os.path.join(tmpdir, "ctf.pkl")
     png_out = os.path.join(tmpdir, "ctf.png")
-    args = parse_ctf_csparc.add_args(argparse.ArgumentParser()).parse_args(
-        [particles.path, "-o", pkl_out, "--png", png_out]
-    )
-    parse_ctf_csparc.main(args)
+    parser = argparse.ArgumentParser()
+    parse_ctf_csparc.add_args(parser)
+    parse_ctf_csparc.main(
+        parser.parse_args([particles.path, "-o", pkl_out, "--png", png_out])
+    )

     assert_pkl_close(pkl_out, os.path.join(pytest.DATADIR, "ctf2.pkl"))
