Skip to content

Commit

Permalink
allowing downsample to save .mrcs number formats other than float32
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-g committed Sep 9, 2024
1 parent ae6458d commit dc182ff
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 22 deletions.
4 changes: 2 additions & 2 deletions cryodrgn/commands/downsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def transform_fn(chunk, indices):
logger.info(f"Saving {out_fl}")

header = MRCHeader.make_default_header(
nz=src.n, ny=new_D, nx=new_D, Apix=new_apix, data=None, is_vol=False
nz=src.n, ny=new_D, nx=new_D, Apix=new_apix, dtype=src.dtype, is_vol=False
)
src.write_mrc(
output_file=out_fl,
Expand All @@ -181,7 +181,7 @@ def transform_fn(chunk, indices):
ny=new_D,
nx=new_D,
Apix=new_apix,
data=None,
dtype=src.dtype,
is_vol=False,
)

Expand Down
71 changes: 51 additions & 20 deletions cryodrgn/mrcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
import numpy as np
import torch

import logging

logger = logging.getLogger(__name__)


class MRCHeader:
"""A class for representing the headers of .mrc files which store metadata.
Expand Down Expand Up @@ -120,7 +124,9 @@ def __str__(self):
return f"Header: {self.fields}\nExtended header: {self.extended_header}"

@classmethod
def parse(cls, fname):
def parse(cls, fname: str) -> Self:
"""Create a `MRCHeader` object by reading in the header from a .mrc(s) file."""

with open(fname, "rb") as f:
f.seek(cls.MACHST_OFFSET)
cls.ENDIANNESS = cls.ENDIANNESS_FOR_MACHST.get(f.read(2), "=")
Expand All @@ -143,30 +149,36 @@ def make_default_header(
ny: Optional[int] = None,
nx: Optional[int] = None,
data: Optional[Union[np.ndarray, torch.Tensor]] = None,
dtype: Optional[Union[str, np.dtype]] = None,
is_vol: bool = True,
Apix: float = 1.0,
xorg: float = 0.0,
yorg: float = 0.0,
zorg: float = 0.0,
) -> Self:
if dtype is not None:
data_dtype = np.dtype(dtype)
else:
data_dtype = np.dtype("float32") # default to np.float32 mode

if data is not None:
nz, ny, nx = data.shape
if isinstance(data, torch.Tensor):
try:
data_dtype = np.dtype(str(data.dtype).split(".")[1])
except TypeError:
data_dtype = np.dtype("float32")
else:
data_dtype = data.dtype

if data_dtype in cls.MODE_FOR_DTYPE:
use_mode = cls.MODE_FOR_DTYPE[data_dtype]
elif data_dtype.type in cls.MODE_FOR_DTYPE:
use_mode = cls.MODE_FOR_DTYPE[data_dtype.type]
else:
use_mode = 2

if dtype is None:
if isinstance(data, torch.Tensor):
try:
data_dtype = np.dtype(str(data.dtype).split(".")[1])
except TypeError:
data_dtype = np.dtype("float32")
else:
data_dtype = data.dtype

if data_dtype in cls.MODE_FOR_DTYPE:
use_mode = cls.MODE_FOR_DTYPE[data_dtype]
elif data_dtype.type in cls.MODE_FOR_DTYPE:
use_mode = cls.MODE_FOR_DTYPE[data_dtype.type]
else:
use_mode = 2 # default to np.float32 mode as above
use_mode = 2

assert nz is not None
assert ny is not None
Expand Down Expand Up @@ -264,7 +276,7 @@ def origin(self, value: tuple[float, float, float]) -> None:


def parse_mrc(fname: str) -> Tuple[np.ndarray, MRCHeader]:
# parse the header
"""Read in the array of data values and the header data stored in a .mrc(s) file."""
header = MRCHeader.parse(fname)

# get the number of bytes in extended header
Expand Down Expand Up @@ -301,14 +313,15 @@ def get_mrc_header(
return header


def fix_mrc_header(header: Optional[MRCHeader]) -> MRCHeader:
# Older versions of MRCHeader had incorrect cmap and stamp fields.
# Fix these before writing to disk.
def fix_mrc_header(header: MRCHeader) -> MRCHeader:
"""Fix older versions of MRCHeader with incorrect `cmap` and `stamp` fields."""
header.fields["cmap"] = b"MAP "

if header.ENDIANNESS == "=":
endianness = {"little": "<", "big": ">"}[sys.byteorder]
else:
endianness = header.ENDIANNESS

header.fields["stamp"] = header.MACHST_FOR_ENDIANNESS[endianness]

return header
Expand All @@ -322,9 +335,27 @@ def write_mrc(
transform_fn: Optional[Callable] = None,
**header_args,
) -> None:
"""Save an image stack or volume to disk as an .mrc(s) file.
Arguments
---------
filename Where the .mrc(s) will be saved.
array The image stack or volume to save to file.
header Optionally supply an MRCHeader instead of using the default one.
is_vol Don't infer whether this is a volume from the array itself.
transform_fn Apply this function to the array values before saving.
header_args Additional keyword arguments passed to `MRCHeader` if not using
your own header.
"""
if header is None:
header = get_mrc_header(array, is_vol, **header_args)
else:
if header_args:
logger.warning(
f"Passed header arguments {header_args} to `write_mrc` but these will "
"not be used as header was also given!"
)
header = fix_mrc_header(header=header)

if transform_fn is None:
Expand Down

0 comments on commit dc182ff

Please sign in to comment.