Skip to content

Commit

Permalink
better tests for filtering .star files; fixing pytest fixture outdir …
Browse files Browse the repository at this point in the history
…issues
  • Loading branch information
michal-g committed May 6, 2024
1 parent 010a587 commit 993556c
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 23 deletions.
24 changes: 17 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,8 @@ def train_load_epoch(self, load_epoch: int, train_epochs: int) -> None:
def train_dir(request, tmpdir_factory) -> TrainDir:
"""Run an experiment to generate output; remove this output when finished."""
args = TrainDir.parse_request(request.param)
args.update(
dict(out_lbl=tmpdir_factory.mktemp(f"output_{request.node.__class__.__name__}"))
)
out_lbl = f"train-outs_{request.node.__class__.__name__}"
args.update(dict(out_lbl=tmpdir_factory.mktemp(out_lbl)))

tdir = TrainDir(**args)
yield tdir
Expand Down Expand Up @@ -324,12 +323,19 @@ def __init__(
zdim: int,
dataset: str = "hand",
epochs: int = 2,
out_lbl: Optional[str] = None,
seed: Optional[int] = None,
) -> None:
self.zdim = zdim
self.dataset = dataset
self.particles, _ = get_testing_datasets(dataset)
self.outdir = os.path.abspath(f"test-output_{dataset}")

if out_lbl is None:
self.outdir = os.path.abspath(f"test-output_{dataset}")
else:
self.outdir = os.path.abspath(out_lbl)

shutil.rmtree(self.outdir, ignore_errors=True)
os.makedirs(self.outdir)
self.epochs = epochs
self.seed = seed
Expand Down Expand Up @@ -411,8 +417,12 @@ def view_config(self) -> None:
assert err == "", err


@pytest.fixture(scope="session")
def abinit_dir(request) -> AbInitioDir:
adir = AbInitioDir(**AbInitioDir.parse_request(request.param))
@pytest.fixture
def abinit_dir(request, tmpdir_factory) -> AbInitioDir:
args = AbInitioDir.parse_request(request.param)
out_lbl = f"abinit-outs_{request.function.__name__}"
args.update(dict(out_lbl=tmpdir_factory.mktemp(out_lbl)))

adir = AbInitioDir(**args)
yield adir
shutil.rmtree(adir.outdir)
14 changes: 10 additions & 4 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,15 @@ def test_analyze(self, outdir, particles, poses):
@pytest.mark.parametrize("nb_lbl", ["cryoDRGN_figures"])
def test_notebooks(self, outdir, particles, poses, nb_lbl):
"""Execute the demonstration Jupyter notebooks produced by analysis."""
orig_cwd = os.path.abspath(os.getcwd())
os.chdir(os.path.join(outdir, "analyze.2"))
assert os.path.exists(f"{nb_lbl}.ipynb")

with open(f"{nb_lbl}.ipynb") as ff:
nb_in = nbformat.read(ff, nbformat.NO_CONVERT)

ExecutePreprocessor(timeout=600, kernel_name="python3").preprocess(nb_in)
os.chdir(os.path.join("..", ".."))
os.chdir(orig_cwd)

def test_landscape(self, outdir, particles, poses):
args = analyze_landscape.add_args(argparse.ArgumentParser()).parse_args(
Expand Down Expand Up @@ -257,14 +258,15 @@ def test_analyze(self, outdir, particles, ctf, indices):
@pytest.mark.parametrize("nb_lbl", ["cryoDRGN_figures"])
def test_notebooks(self, outdir, particles, ctf, indices, nb_lbl):
"""Execute the demonstration Jupyter notebooks produced by analysis."""
orig_cwd = os.path.abspath(os.getcwd())
os.chdir(os.path.join(outdir, "analyze.1"))
assert os.path.exists(f"{nb_lbl}.ipynb"), "Upstream tests have failed!"

with open(f"{nb_lbl}.ipynb") as ff:
nb_in = nbformat.read(ff, nbformat.NO_CONVERT)

ExecutePreprocessor(timeout=600, kernel_name="python3").preprocess(nb_in)
os.chdir(os.path.join("..", ".."))
os.chdir(orig_cwd)


@pytest.mark.parametrize("particles", ["tilts.star"], indirect=True)
Expand Down Expand Up @@ -450,17 +452,19 @@ def test_backproject(
@pytest.mark.parametrize("nb_lbl", ["cryoDRGN_figures", "cryoDRGN_ET_viz"])
def test_notebooks(self, outdir, particles, indices, poses, ctf, datadir, nb_lbl):
"""Execute the demonstration Jupyter notebooks produced by analysis."""
orig_cwd = os.path.abspath(os.getcwd())
os.chdir(os.path.join(outdir, "analyze.4"))
assert os.path.exists(f"{nb_lbl}.ipynb"), "Upstream tests have failed!"

with open(f"{nb_lbl}.ipynb") as ff:
nb_in = nbformat.read(ff, nbformat.NO_CONVERT)

ExecutePreprocessor(timeout=600, kernel_name="python3").preprocess(nb_in)
os.chdir(os.path.join("..", ".."))
os.chdir(orig_cwd)

def test_refiltering(self, outdir, particles, indices, poses, ctf, datadir):
"""Use particle index creating during analysis."""
orig_cwd = os.path.abspath(os.getcwd())
os.chdir(os.path.join(outdir, "analyze.4"))
assert os.path.exists("tmp_ind_selected.pkl"), "Upstream tests have failed!"

Expand Down Expand Up @@ -499,6 +503,7 @@ def test_refiltering(self, outdir, particles, indices, poses, ctf, datadir):

args = train_vae.add_args(argparse.ArgumentParser()).parse_args(args)
train_vae.main(args)
os.chdir(orig_cwd)


@pytest.mark.parametrize("particles", ["tilts.star"], indirect=True)
Expand Down Expand Up @@ -635,11 +640,12 @@ def test_analyze(self, outdir, particles, indices, poses, ctf):
)
def test_notebooks(self, outdir, particles, indices, poses, ctf, nb_lbl):
"""Execute the demonstration Jupyter notebooks produced by analysis."""
orig_cwd = os.path.abspath(os.getcwd())
os.chdir(os.path.join(outdir, "analyze.2"))
assert os.path.exists(f"{nb_lbl}.ipynb"), "Upstream tests have failed!"

with open(f"{nb_lbl}.ipynb") as ff:
nb_in = nbformat.read(ff, nbformat.NO_CONVERT)

ExecutePreprocessor(timeout=600, kernel_name="python3").preprocess(nb_in)
os.chdir(os.path.join("..", ".."))
os.chdir(orig_cwd)
51 changes: 39 additions & 12 deletions tests/test_read_filter_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pickle
import numpy as np
import torch
from itertools import product
from cryodrgn.source import ImageSource
from cryodrgn.commands import parse_ctf_star
from cryodrgn.commands_utils import filter_star, write_cs, write_star
Expand Down Expand Up @@ -39,31 +40,57 @@ def test_read_starfile(particles, datadir):

@pytest.mark.parametrize("particles", ["toy.star", "toy.star-13"], indirect=True)
@pytest.mark.parametrize("datadir", ["default-datadir"], indirect=True)
def test_filter(outdir, particles, datadir):
indices_pkl = os.path.join(outdir, "indices.pkl")
with open(indices_pkl, "wb") as f:
# 0-based indices into the input star file
# Note that these indices are simply the 0-indexed row numbers in the starfile,
# and have nothing to do with the index of the individual particle in the MRCS
# file (e.g. 00042@mymrcs.mrcs)
pickle.dump([11, 3, 2, 4], f)
@pytest.mark.parametrize(
"index_pair",
[([11, 3, 2, 4], [1, 2, 3]), ([5, 8, 11], [0, 7, 10])],
ids=("inds1", "inds2"),
)
def test_filter(outdir, particles, datadir, index_pair):
indices_pkl1 = os.path.join(outdir, "indices1.pkl")
indices_pkl2 = os.path.join(outdir, "indices2.pkl")

with open(indices_pkl1, "wb") as f:
pickle.dump(index_pair[0], f)
with open(indices_pkl2, "wb") as f:
pickle.dump(index_pair[1], f)

out_fl = os.path.join(outdir, "issue150_filtered.star")
parser = argparse.ArgumentParser()
filter_star.add_args(parser)
filter_star.main(parser.parse_args([particles, "--ind", indices_pkl, "-o", out_fl]))
filter_star.main(
parser.parse_args([particles, "--ind", indices_pkl1, "-o", out_fl])
)

data = ImageSource.from_file(out_fl, lazy=False, datadir=datadir).images()
assert isinstance(data, torch.Tensor)
assert data.shape == (4, 30, 30)
data1 = ImageSource.from_file(out_fl, lazy=False, datadir=datadir).images()
assert isinstance(data1, torch.Tensor)
assert data1.shape == (len(index_pair[0]), 30, 30)
os.remove(out_fl)

filter_star.main(
parser.parse_args([particles, "--ind", indices_pkl2, "-o", out_fl])
)

data2 = ImageSource.from_file(out_fl, lazy=False, datadir=datadir).images()
assert isinstance(data2, torch.Tensor)
assert data2.shape == (len(index_pair[1]), 30, 30)
os.remove(out_fl)

for (i, ind1), (j, ind2) in product(
enumerate(index_pair[0]), enumerate(index_pair[1])
):
if ind1 == ind2:
assert np.allclose(data1[i, ...], data2[j, ...])


@pytest.mark.parametrize("datadir", ["default-datadir"], indirect=True)
@pytest.mark.parametrize("particles", ["tilts.star"], indirect=True)
class TestFilterStar:
def test_filter_with_indices(self, outdir, particles, datadir):
indices_pkl = os.path.join(outdir, "indices.pkl")
# 0-based indices into the input star file
# Note that these indices are simply the 0-indexed row numbers in the starfile,
# and have nothing to do with the index of the individual particle in the MRCS
# file (e.g. 00042@mymrcs.mrcs)
with open(indices_pkl, "wb") as f:
pickle.dump([1, 3, 4], f)

Expand Down

0 comments on commit 993556c

Please sign in to comment.