diff --git a/brainglobe_template_builder/io.py b/brainglobe_template_builder/io.py index b8ea506..d7a3c79 100644 --- a/brainglobe_template_builder/io.py +++ b/brainglobe_template_builder/io.py @@ -1,4 +1,6 @@ +import os from pathlib import Path +from typing import Literal import nibabel as nib import numpy as np @@ -7,7 +9,11 @@ from brainglobe_utils.IO.image.save import to_tiff -def get_unique_folder_in_dir(search_dir: Path, search_str: str) -> Path: +def get_unique_folder_in_dir( + search_dir: Path, + search_str: str, + str_position: Literal["start", "end"] | None = None, +) -> Path: """ Find a folder in a directory that contains a unique string. @@ -17,6 +23,10 @@ def get_unique_folder_in_dir(search_dir: Path, search_str: str) -> Path: Directory to search in search_str : str String to search for in folder names + str_position: Literal["start", "end"] | None + If None (default), ``search_str`` can be anywhere in the folder name. + If "start", ``search_str`` must be at the start of the folder name. + If "end", ``search_str`` must be at the end of the folder name. Returns ------- @@ -25,6 +35,15 @@ def get_unique_folder_in_dir(search_dir: Path, search_str: str) -> Path: """ all_folders = [x for x in search_dir.iterdir() if x.is_dir()] folders_with_str = [x for x in all_folders if search_str in x.name] + if str_position == "start": + folders_with_str = [ + x for x in folders_with_str if x.name.startswith(search_str) + ] + elif str_position == "end": + folders_with_str = [ + x for x in folders_with_str if x.name.endswith(search_str) + ] + if len(folders_with_str) == 0: raise ValueError(f"No folders with {search_str} found") if len(folders_with_str) > 1: @@ -196,3 +215,27 @@ def nifti_to_tiff(nifti_path: Path, tiff_path: Path): """ stack = load_any(nifti_path.as_posix()) to_tiff(stack, tiff_path.as_posix()) + + +def get_path_from_env_variable(env_var: str, default_path: str) -> Path: + """ + Get a path from an environment variable, with a default fall-back path. + + This could be useful for debugging a script locally - i.e. by setting + and environment variable on your machine, you can override the default + path used in the script (which might point to a shared cluster directory). + + Parameters + ---------- + env_var : str + The name of the environment variable to read, e.g. "ATLAS_FORGE_DIR" + default_path : str + The default path to use if the environment variable is not set, + e.g. "/ceph/neuroinformatics/neuroinformatics/atlas-forge" + + Returns + ------- + atlas_dir : Path + The path to the directory + """ + return Path(os.getenv(env_var, default_path)) diff --git a/brainglobe_template_builder/plots.py b/brainglobe_template_builder/plots.py new file mode 100644 index 0000000..8cf515b --- /dev/null +++ b/brainglobe_template_builder/plots.py @@ -0,0 +1,331 @@ +from pathlib import Path +from typing import Literal + +import numpy as np +from brainglobe_space import AnatomicalSpace +from matplotlib import pyplot as plt + + +def plot_orthographic( + img: np.ndarray, + anat_space: str = "ASR", + voxel_sizes: tuple[float, float, float] = (1.0, 1.0, 1.0), + show_slices: tuple[int, int, int] | None = None, + mip_attenuation: float = 0.01, + save_path: Path | None = None, + **kwargs, +) -> tuple[plt.Figure, np.ndarray]: + """Plot image volume in three orthogonal views, plus a surface rendering. + + The surface rendering is a maximum intensity projection (MIP) along the + vertical (superior-inferior) axis and is shown from the top. + + Parameters + ---------- + img : np.ndarray + Image volume to plot. + anat_space : str, optional + Anatomical space of of the image volume according to the Brainglobe + definition (origin and order of axes), by default "ASR". + voxel_sizes : tuple, optional + Voxels sizes in micrometers per dimension, by default (1.0, 1.0, 1.0). + The relative sizes of the axes will be preserved in the plot. + show_slices : tuple, optional + Which slice to show per dimension. If None (default), show the middle + slice along each dimension. + mip_attenuation : float, optional + Attenuation factor for the MIP, by default 0.01. + A value of 0 means no attenuation. + save_path : Path, optional + Path to save the plot, by default None (no saving). + **kwargs + Additional keyword arguments to pass to ``matplotlib.pyplot.imshow``. + + Returns + ------- + tuple[plt.Figure, np.ndarray] + Matplotlib figure and axes objects + + """ + + space = AnatomicalSpace(anat_space) + vertical_axis = space.get_axis_idx("vertical") + + # Get middle slices if not specified + if show_slices is None: + slices_list = [s // 2 for s in img.shape] + else: + slices_list = list(show_slices) + + # Pad the image with zeros to make it cubic + img, pad_sizes = _pad_with_zeros(img, target=max(img.shape)) + slices_list = [s + pad_sizes[i] for i, s in enumerate(slices_list)] + + # Compute (attenuated) MIP along the vertical axis + mip, mip_label = _compute_attenuated_mip( + img, vertical_axis, mip_attenuation + ) + + # Create figure with 4 subplots (3 orthogonal views + MIP) + fig, axs = plt.subplots(1, 4, figsize=(14, 4)) + views = [img.take(slc, axis=i) for i, slc in enumerate(slices_list)] + views.append(mip) + axis_labels = [*space.axis_labels, space.axis_labels[vertical_axis]] + section_names = [s.capitalize() for s in space.sections] + [mip_label] + + kwargs = _set_imshow_defaults(img, kwargs) + + for j, (section, labels) in enumerate(zip(section_names, axis_labels)): + ax = axs[j] + ax.imshow(views[j], **kwargs) + ax.set_title(section) + ax.set_ylabel(labels[0]) + ax.set_xlabel(labels[1]) + ax = _clear_spines_and_ticks(ax) + plt.tight_layout() + + if save_path: + _save_and_close_figure( + fig, save_path.parent, save_path.name.split(".")[0] + ) + return fig, axs + + +def plot_grid( + img: np.ndarray, + anat_space="ASR", + section: Literal["frontal", "horizontal", "sagittal"] = "frontal", + n_slices: int = 12, + save_path: Path | None = None, + **kwargs, +) -> tuple[plt.Figure, np.ndarray]: + """Plot image volume as a grid of slices along a given anatomical section. + + Parameters + ---------- + img : np.ndarray + Image volume to plot. + anat_space : str, optional + Anatomical space of of the image volume according to the Brainglobe + definition (origin and order of axes), by default "ASR". + section : str, optional + Section to show, must be one of "frontal", "horizontal", or "sagittal", + by default "frontal". + n_slices : int, optional + Number of slices to show, by default 12. Slices will be evenly spaced, + starting from the first and ending with the last slice. + save_path : Path, optional + Path to save the plot, by default None (no saving). + **kwargs + Additional keyword arguments to pass to ``matplotlib.pyplot.imshow``. + + Returns + ------- + tuple[plt.Figure, np.ndarray] + Matplotlib figure and axes objects + + """ + space = AnatomicalSpace(anat_space) + section_to_axis = { # Mapping of section names to space axes + "frontal": "sagittal", + "horizontal": "vertical", + "sagittal": "frontal", + } + axis_idx = space.get_axis_idx(section_to_axis[section]) + + # Ensure n_slices is not greater than the number of slices in the image + n_slices = min(n_slices, img.shape[axis_idx]) + # ensure first and last slices are included + show_slices = np.linspace(0, img.shape[axis_idx] - 1, n_slices, dtype=int) + + # Get slices along the specified axis and arrange them in a grid + grid_img = _grid_from_slices( + [img.take(slc, axis=axis_idx) for slc in show_slices] + ) + + # Plot the grid image + fig, ax = plt.subplots(1, 1, figsize=(12, 12)) + kwargs = _set_imshow_defaults(img, kwargs) + ax.imshow(grid_img, **kwargs) + + section_name = section.capitalize() + ax.set_title(f"{section_name} slices") + ax.set_xlabel(space.axis_labels[axis_idx][1]) + ax.set_ylabel(space.axis_labels[axis_idx][0]) + ax = _clear_spines_and_ticks(ax) + plt.tight_layout() + + if save_path: + _save_and_close_figure( + fig, save_path.parent, save_path.name.split(".")[0] + ) + return fig, ax + + +def _compute_attenuated_mip( + img: np.ndarray, axis: int, attenuation_factor: float +) -> tuple[np.ndarray, str]: + """Compute the maximum intensity projection (MIP) with attenuation. + + If the image is zero-padded, attenuation is only applied within the + non-zero region along the specified axis. + + Parameters + ---------- + img : np.ndarray + Image volume. + axis : int + Axis along which to compute the MIP. + attenuation_factor : float + Attenuation factor for the MIP. 0 means no attenuation. + + Returns + ------- + tuple[np.ndarray, str] + MIP image and label. The label is "MIP" if no attenuation is applied, + and "MIP (attenuated)" otherwise. + """ + + mip_label = "MIP" + + if attenuation_factor < 0: + raise ValueError("Attenuation factor must be non-negative.") + + if attenuation_factor < 1e-6: + # If the factor is too small, skip attenuation + mip = np.max(img, axis=axis) + return mip, mip_label + + # Find the non-zero bounding box along the specified axis + other_axes = tuple(i for i in range(img.ndim) if i != axis) + non_zero_mask = np.any(img != 0, axis=other_axes) + non_zero_indices = np.nonzero(non_zero_mask)[0] + start, end = non_zero_indices[0], non_zero_indices[-1] + 1 + + # Trim the image along the attenuation axis (get rid of zero-padding) + slices = [slice(None)] * img.ndim + slices[axis] = slice(start, end) + trimmed_img = img[tuple(slices)] + + # Apply attenuation to the trimmed image + attenuation = np.exp( + -attenuation_factor * np.arange(trimmed_img.shape[axis]) + ) + attenuation_shape = [1] * trimmed_img.ndim + attenuation_shape[axis] = trimmed_img.shape[axis] + attenuation = attenuation.reshape(attenuation_shape) + attenuated_img = trimmed_img.astype(np.float32) * attenuation + + # Compute and return the attenuated MIP + mip = np.max(attenuated_img, axis=axis) + mip_label += " (attenuated)" + + return mip, mip_label + + +def _save_and_close_figure(fig: plt.Figure, plots_dir: Path, filename: str): + """Save figure in both PNG and PDF formats and close it.""" + fig.savefig(plots_dir / f"{filename}.png") + fig.savefig(plots_dir / f"{filename}.pdf") + plt.close(fig) + + +def _clear_spines_and_ticks(ax: plt.Axes) -> plt.Axes: + """Clear spines and ticks from a matplotlib axis.""" + ax.set_xticks([]) + ax.set_yticks([]) + for spine in ax.spines.values(): + spine.set_visible(False) + return ax + + +def _set_imshow_defaults(img: np.ndarray, kwargs: dict) -> dict: + """Set default values for imshow keyword arguments. + + These apply only if the user does not provide them explicitly. + """ + if "vmin" not in kwargs and "vmax" not in kwargs: + vmin, vmax = _auto_adjust_contrast(img) + kwargs.setdefault("vmin", vmin) + kwargs.setdefault("vmax", vmax) + + kwargs.setdefault("cmap", "gray") + kwargs.setdefault("aspect", "equal") + return kwargs + + +def _grid_from_slices(slices: list[np.ndarray]) -> np.ndarray: + """Create a grid image from a list of 2D slices. + + The number of rows is automatically determined based on the square root + of the number of slices, rounded up. + + Parameters + ---------- + slices : list[np.ndarray] + List of 2D slices to concatenate. + + Returns + ------- + np.ndarray + A 2D image, with the input slices arranged in a grid. + + """ + + n_slices = len(slices) + slice_height, slice_width = slices[0].shape + + # Form image mosaic grid by concatenating slices + n_rows = int(np.ceil(np.sqrt(n_slices))) + n_cols = int(np.ceil(n_slices / n_rows)) + grid_img = np.zeros( + (n_rows * slice_height, n_cols * slice_width), + ) + for i, slice in enumerate(slices): + row = i // n_cols + col = i % n_cols + grid_img[ + row * slice_height : (row + 1) * slice_height, + col * slice_width : (col + 1) * slice_width, + ] = slice + + return grid_img + + +def _pad_with_zeros( + img: np.ndarray, target: int = 512 +) -> tuple[np.ndarray, tuple[int, int, int]]: + """Pad the volume with zeros to reach the target size in all dimensions.""" + pad_sizes = [(target - s) // 2 for s in img.shape] + padded_img = np.pad( + img, + ( + (pad_sizes[0], pad_sizes[0]), + (pad_sizes[1], pad_sizes[1]), + (pad_sizes[2], pad_sizes[2]), + ), + mode="constant", + ) + return padded_img, tuple(pad_sizes) + + +def _auto_adjust_contrast(img, lower_percentile=1, upper_percentile=99): + """Adjust contrast of an image using percentile-based scaling.""" + # Mask near-zero voxels to exclude background + if np.issubdtype(img.dtype, np.integer): + background_threshold = 1 + else: + background_threshold = np.finfo(img.dtype).eps + + brain_mask = img > background_threshold + + # Exclude bright artifacts + vmax = np.percentile(img[brain_mask], upper_percentile) + artifact_mask = img <= vmax + combined_mask = brain_mask & artifact_mask + + # Compute vmin and vmax + vmin = np.percentile(img[combined_mask], lower_percentile) + vmax = np.percentile(img[combined_mask], upper_percentile) + + return vmin, vmax diff --git a/data/SWC_Rat_use-for-template.csv b/data/SWC_Rat_use-for-template.csv new file mode 100644 index 0000000..fd7f044 --- /dev/null +++ b/data/SWC_Rat_use-for-template.csv @@ -0,0 +1,35 @@ +subject_id,color,hemi,comment +CAP01,red,both,chipped cerebellum and cortex +CAP02,red,both ,dark tiles & a few labelled cells +CAP03,green,both,looks ok +CAP04,green,both ,looks ok +CAP05,red,both,left cortical damage & some labelled cells +CAP08,green,both ,missing tiles & some labelled neurons +CAP09,red,both,labelled cells in the cerebellum & SNR not great +CAP11,red,both ,missing tiles +CAP12,red,both,missing tiles & some labelled cells +CAP14,green,both ,few missing tiles +CAP18,green,both,missing tiles & chipped cortex +CAP22,red,both ,missing tiles +CAP24,red,both,chipped cerebellum & missing tiles & few labelled cells +CAP25,red,both ,chipped cortex & mangled OB +CAP26,red,both,bands of different brightness +CAP27,red,both ,missing tiles & bleedthrough +CAP28,green,both,labelled cells in the cerebellum +CAP29,red,none,very poor SNR & one hemisphere darker +CAP30,red,both,bands of different brightness +CAP31,red,both,chipped cerebellum & one very bright spot +CAP32,red,both,maybe a few missing slices otherwise OK +CAP33,red,both,looks very good +CAP34,red,both,few missing tiles but looks ok +CAP35,red,both,chipped brainstem & fully present OB +CAP38,green,none,extensive cortex damage and deformations +CAP40,red,left,right cortex is severely damaged +CAP42,green,none,cortex is damaged from recording & frontal pole missing +CAP43,red,both,some cortical damage & missing slices +CAP44,green,both,few dark tiles & some cortical damage +CAP45,red,both,cortical damage +CAP48,red,both,image is a bit dark +CAP49,green,both,chipped cortex +CAP50,green,both,damaged cerebellum and cortex +CAP54,green,both,little cerebellar damage but otherwise ok diff --git a/examples/rat/1_source_images.py b/examples/rat/1_source_images.py new file mode 100644 index 0000000..f20c721 --- /dev/null +++ b/examples/rat/1_source_images.py @@ -0,0 +1,281 @@ +""" +Identify source images for the SWC rat template +=============================================== +Set up project directory with all relevant data and identify source images +to be used for building the SWC rat brain template. + +This scipt must be run on the SWC HPC cluster. +""" + +# %% +# Imports +# ------- +import os +import shutil +from datetime import date +from pathlib import Path + +import pandas as pd +from loguru import logger +from tqdm import tqdm + +from brainglobe_template_builder.io import ( # type: ignore + get_path_from_env_variable, + get_unique_folder_in_dir, + load_tiff, +) +from brainglobe_template_builder.plots import plot_grid, plot_orthographic + +# Set up directories and logging +# ------------------------------ + +# Prepare directory structure +atlas_dir = get_path_from_env_variable( + "ATLAS_FORGE_DIR", "/ceph/neuroinformatics/neuroinformatics/atlas-forge" +) + +species_id = "Rat" # the subfolder names within atlas_dir +species_dir = atlas_dir / species_id + +species_dir.mkdir(parents=True, exist_ok=True) +# Make "rawdata", "derivatives", "templates", and "logs" directories +for folder in ["rawdata", "derivatives", "templates", "logs"]: + (species_dir / folder).mkdir(exist_ok=True) + +# Directory where source images are stored for this species +# This must contain subfolders for each subject +source_dir = get_path_from_env_variable( + "ATLAS_SOURCE_DIR", "/ceph/akrami/capsid_testing/imaging/2p" +) + +# Set up logging +today = date.today() +current_script_name = os.path.basename(__file__).replace(".py", "") +logger.add(species_dir / "logs" / f"{today}_{current_script_name}.log") +logger.info(f"Will save outputs to {species_dir}.") + +# %% +# Load a dataframe with all SWC brains used for atlases +# ------------------------------------------------------ + +path_of_this_script = Path(__file__).resolve() +source_csv_dir = path_of_this_script.parent.parent.parent / "data" +source_csv_path = source_csv_dir / "SWC_brain-list_for-atlases_2024-04-15.csv" +df = pd.read_csv(source_csv_path) + +# Strip trailing space from str columns and from column names +df = df.apply(lambda x: x.str.strip() if x.dtype == "object" else x) +df.columns = df.columns.str.strip() +# Select only data for the species of interest +species_common_name = "rat" # as in the source csv table +df = df[df["Common name"] == species_common_name] +logger.info(f"Found {len(df)} {species_id} subjects.") + +# %% +# Exract subject IDs and paths +# ---------------------------- + +# Rename "Specimen ID" to "subject_id" +df["subject_id"] = df["Specimen ID"] +# Assert that subject IDs are unique +assert len(df["subject_id"].unique()) == len(df), "Non-unique subject IDs" + + +# Use the Ceph path as the source data directory for all subjects +# (All data have been migrated there, after the table was created) +data_path_col = "Data path (raw)" # column name in the source csv +df[data_path_col] = source_dir.as_posix() + + +# Find each subject's source data folder +df["subject_path"] = df.apply( + lambda x: get_unique_folder_in_dir( + Path(x[data_path_col]), x["subject_id"], str_position="end" + ), + axis=1, +) + +sub_ids = df["subject_id"].values +logger.info(f"Identified {sub_ids.size} unique subject IDs. {sub_ids}") +sub_paths_msg = "Subject paths:\n" +for idx, row in df.iterrows(): + sub_id = row["subject_id"] + sub_path = row["subject_path"].as_posix() + sub_paths_msg += f" {sub_id}: {sub_path}\n" +logger.debug(sub_paths_msg) + +# %% +# Identify all downsampled images +# ------------------------------- +# We will identify all downsampled images for each subject, by iterating over +# the downsampled_stacks folders and the various channels. +# We will aggregate the image information and paths in a dataframe. +# We assume the standard naming conventions of SWC's serial two-photon output. + +valid_colors = ["far_red", "red", "green", "blue"] # order matters + +images_list_of_dicts = [] +for _, row in df.iterrows(): + sub = row["subject_id"] + + # Path to the downsampled stacks (often, but not always, in a subfolder) + sub_path_down = row["subject_path"] + if "downsampled_stacks" in os.listdir(sub_path_down): + sub_path_down = sub_path_down / "downsampled_stacks" + + # Stacks are stored for each resolution in a separate folder + # These folders are named like "010_micron", "025_micron", etc. + stacks = [f for f in os.listdir(sub_path_down) if f.endswith("micron")] + + for stack in sorted(stacks): + microns = int(stack.split("_")[0]) + stack_path = sub_path_down / stack + + images = [ + f + for f in os.listdir(stack_path) + if f.endswith(".tif") and "_ch0" in f and not f.startswith(".") + ] + for img in sorted(images): + ch_num = int(img.split("_ch0")[1].split("_")[0]) + ch_color = [ + c for c in valid_colors if img.split(".tif")[0].endswith(c) + ][0] + ch_color = ch_color if ch_color != "far_red" else "farred" + image_id = f"sub-{sub}_res-{microns}um_channel-{ch_color}" + images_list_of_dicts.append( + { + "subject_id": sub, + "microns": microns, + "channel": ch_num, + "color": ch_color, + "image_id": image_id, + "image_path": stack_path / img, + } + ) + +images_df = pd.DataFrame(images_list_of_dicts) +n_img = len(images_df) +logger.info(f"Found {n_img} images across subjects, resolutions and channels.") + +# %% +# Check for missing downsampled stacks +# ------------------------------------- +# Logs a warning if any of the expected resolutions are missing + +expected_microns = [10, 25, 50] +for sub in df["subject_id"]: + for microns in expected_microns: + if not ( + images_df[ + (images_df["subject_id"] == sub) + & (images_df["microns"] == microns) + ].shape[0] + ): + logger.warning(f"Subject {sub} lacks {microns} micron stack") + + +# %% +# Save dataframes and images to rawdata +# ------------------------------------- +# Save the dataframes to the species rawdata directory. + +rawdata_dir = species_dir / "rawdata" +subjects_csv = rawdata_dir / f"{today}_subjects.csv" +df.to_csv(subjects_csv, index=False) +logger.info(f"Saved subject information to csv: {subjects_csv}") + +images_csv = rawdata_dir / f"{today}_images.csv" +images_df.to_csv(images_csv, index=False) +logger.info(f"Saved image information to csv: {images_csv}") + + +# %% +# Save images to the species rawdata directory +# (doesn't overwrite existing images). +# High-resolution images (10um) are just symlinked to avoid duplication +# of large files. + +n_copied, n_symlinked = 0, 0 +for idx in tqdm(images_df.index): + row = images_df.loc[idx, :] + sub = row["subject_id"] + sub_dir = rawdata_dir / f"sub-{sub}" + sub_dir.mkdir(exist_ok=True) + microns = row["microns"] + image_id = row["image_id"] + img_source_path = row["image_path"] + img_dest_path = sub_dir / f"{image_id}.tif" + + print(img_dest_path) + # if the destination path exists, skip + if img_dest_path.exists(): + logger.debug(f"Skipping {img_dest_path} as it already exists.") + continue + + # if image is 10 microns, symlink it (to avoid duplication of large files) + if microns == 10: + img_dest_path.symlink_to(img_source_path) + logger.debug(f"Symlinked {img_dest_path} to {img_source_path}") + n_symlinked += 1 + # else copy the image + else: + shutil.copyfile(img_source_path, img_dest_path) + logger.debug(f"Copied {img_source_path} to {img_dest_path}") + n_copied += 1 + +logger.info( + f"Copied {n_copied} and symlinked {n_symlinked} " + f"images to {rawdata_dir}." +) + + +# %% +# Save diagnostic plots to get a quick overview of the images. +# Plots will only be generated for the low-resolution images (50um). +# Plots are saved in the rawdata/sub-/plots folder. + +subjects = sorted([f for f in os.listdir(rawdata_dir) if f.startswith("sub-")]) + +for sub in tqdm(subjects): + sub_dir = rawdata_dir / sub + + # Find all 50um images for the subject + images = [ + f + for f in os.listdir(sub_dir) + if f.endswith(".tif") and "res-50um" in f + ] + + # Plots will be saved in a subfolder + if images: + sub_plot_dir = sub_dir / "plots" + sub_plot_dir.mkdir(exist_ok=True) + logger.debug(f"Saving plots to {sub_plot_dir}...") + + for img in images: + # load the tiff image as numpy array + img_path = sub_dir / img + try: + img = load_tiff(img_path) + except Exception as e: + logger.error(f"Failed to load {img_path}: {e}") + continue + + # Plot frontal (coronal) sections in a grid + fig, _ = plot_grid( + img, + anat_space="PSL", + section="frontal", + n_slices=12, + save_path=sub_plot_dir / f"{img_path.stem}_grid", + ) + logger.debug(f"Saved grid plot for {img_path.stem}.") + # Plot the image in three orthogonal views + max intensity projection + fig, _ = plot_orthographic( + img, + anat_space="PSL", + save_path=sub_plot_dir / f"{img_path.stem}_orthographic", + mip_attenuation=0.02, + ) + logger.debug(f"Saved orthographic plot for {img_path.stem}.")