Skip to content

Commit

Permalink
more fixes to FSC correction and code style
Browse files (browse the repository at this point in the history)
  • Loading branch information
michal-g committed Sep 10, 2024
1 parent a36e078 commit d8d3cc9
Showing 1 changed file with 36 additions and 22 deletions.
58 changes: 36 additions & 22 deletions cryodrgn/commands_utils/fsc.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -99,19 +99,22 @@ def add_args(parser: argparse.ArgumentParser) -> None:
)


def get_fftn_dists(resolution: int) -> np.array:
x = np.arange(-resolution // 2, resolution // 2)
def get_fftn_center_dists(box_size: int) -> np.array:
"""Get distances from the center (and hence the resolution) for FFT co-ordinates."""

x = np.arange(-box_size // 2, box_size // 2)
x2, x1, x0 = np.meshgrid(x, x, x, indexing="ij")
coords = np.stack((x0, x1, x2), -1)
dists = (coords**2).sum(-1) ** 0.5
assert dists[resolution // 2, resolution // 2, resolution // 2] == 0.0
assert dists[box_size // 2, box_size // 2, box_size // 2] == 0.0

return dists


def calc_fsc(
def calculate_fsc(
v1: Union[np.ndarray, torch.Tensor], v2: Union[np.ndarray, torch.Tensor]
) -> float:
"""Calculate the Fourier Shell Correlation between two complex vectors."""
var = (np.vdot(v1, v1) * np.vdot(v2, v2)) ** 0.5

if var:
Expand All @@ -122,30 +125,32 @@ def calc_fsc(
return fsc


def calculate_fsc(
def get_fsc_curve(
vol1: torch.Tensor,
vol2: torch.Tensor,
initial_mask: Optional[torch.Tensor] = None,
out_file: Optional[str] = None,
) -> pd.DataFrame:
"""Calculate the FSCs between two volumes across all available resolutions."""

# Apply the given mask before applying the Fourier transform
maskvol1 = vol1 * initial_mask if initial_mask is not None else vol1.clone()
maskvol2 = vol2 * initial_mask if initial_mask is not None else vol2.clone()
res = vol1.shape[0]
dists = get_fftn_dists(res)
box_size = vol1.shape[0]
dists = get_fftn_center_dists(box_size)
maskvol1 = fft.fftn_center(maskvol1)
maskvol2 = fft.fftn_center(maskvol2)

# logger.info(r[D//2, D//2, D//2:])
prev_mask = np.zeros((res, res, res), dtype=bool)
prev_mask = np.zeros((box_size, box_size, box_size), dtype=bool)
fsc = [1.0]
for i in range(1, res // 2):
for i in range(1, box_size // 2):
mask = dists < i
shell = np.where(mask & np.logical_not(prev_mask))
fsc.append(calc_fsc(maskvol1[shell], maskvol2[shell]))
fsc.append(calculate_fsc(maskvol1[shell], maskvol2[shell]))
prev_mask = mask

fsc_vals = pd.DataFrame(
dict(pixres=np.arange(res // 2) / res, fsc=fsc), dtype=float
dict(pixres=np.arange(box_size // 2) / box_size, fsc=fsc), dtype=float
)
if out_file is not None:
logger.info(f"Saving FSC values to {out_file}")
Expand All @@ -157,6 +162,8 @@ def calculate_fsc(
def get_fsc_thresholds(
fsc_vals: pd.DataFrame, apix: float, verbose: bool = True
) -> tuple[float, float]:
"""Retrieve the max resolutions at which an FSC curve is above 0.5 and 0.143."""

if ((fsc_vals.pixres > 0) & (fsc_vals.fsc >= 0.5)).any():
res_05 = fsc_vals.pixres[fsc_vals.fsc >= 0.5].max()
if verbose:
Expand Down Expand Up @@ -207,12 +214,22 @@ def correct_fsc(
f"instead have {fsc_vals.shape[0]}!"
)

maskvol1 = vol1 * initial_mask if initial_mask is not None else vol1.clone()
maskvol2 = vol2 * initial_mask if initial_mask is not None else vol2.clone()
dists = get_fftn_dists(box_size)
# Randomize phases in the raw half-maps beyond the given threshold
dists = get_fftn_center_dists(box_size)
fftvol1 = fft.fftn_center(vol1)
fftvol2 = fft.fftn_center(vol2)
phase_res = int(randomization_threshold * box_size)
rand_shell = np.where(dists >= phase_res)
fftvol1[rand_shell] = fftvol1[rand_shell].apply_(randomize_phase)
fftvol2[rand_shell] = fftvol2[rand_shell].apply_(randomize_phase)
fftvol1 = fft.ifftn_center(fftvol1)
fftvol2 = fft.ifftn_center(fftvol2)

# Apply the given masks then go back into Fourier space
maskvol1 = fftvol1 * initial_mask if initial_mask is not None else fftvol1.clone()
maskvol2 = fftvol2 * initial_mask if initial_mask is not None else fftvol2.clone()
maskvol1 = fft.fftn_center(maskvol1)
maskvol2 = fft.fftn_center(maskvol2)
phase_res = int(randomization_threshold * box_size)

# re-calculate the FSCs past the resolution using the phase-randomized volumes
prev_mask = np.zeros((box_size, box_size, box_size), dtype=bool)
Expand All @@ -222,10 +239,7 @@ def correct_fsc(
shell = np.where(mask & np.logical_not(prev_mask))

if i > phase_res:
p = calc_fsc(
maskvol1[shell].apply_(randomize_phase),
maskvol2[shell].apply_(randomize_phase),
)
p = calculate_fsc(maskvol1[shell], maskvol2[shell])

# normalize the original FSC value using the phase-randomized value
if p == 1.0:
Expand Down Expand Up @@ -275,7 +289,7 @@ def calculate_cryosparc_fscs(
)

fsc_vals = {
mask_lbl: calculate_fsc(half_vol1, half_vol2, initial_mask=mask)
mask_lbl: get_fsc_curve(half_vol1, half_vol2, initial_mask=mask)
for mask_lbl, mask in masks.items()
}
fsc_thresh = {
Expand Down Expand Up @@ -364,7 +378,7 @@ def main(args: argparse.Namespace) -> None:
f"Calculating FSCs between `{args.volumes[0]}` and `{args.volumes[1]}`..."
)
if args.ref_volume is None:
fsc_vals = calculate_fsc(
fsc_vals = get_fsc_curve(
volumes[0].images(),
volumes[1].images(),
mask,
Expand Down

0 comments on commit d8d3cc9

Please sign in to comment.