diff --git a/pyproject.toml b/pyproject.toml
index b5202d9..394a110 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,6 +38,7 @@ dependencies = [
     "ipywidgets",
     "natsort",
     "ipython",
+    "zarr",
 ]
diff --git a/src/nimbus_inference/nimbus.py b/src/nimbus_inference/nimbus.py
index 4850e8c..3bdffc0 100644
--- a/src/nimbus_inference/nimbus.py
+++ b/src/nimbus_inference/nimbus.py
@@ -1,11 +1,8 @@
 from alpineer import io_utils, misc_utils
 from skimage.util.shape import view_as_windows
 import nimbus_inference
-from nimbus_inference.utils import (
-    prepare_normalization_dict,
-    predict_fovs,
-    predict_ome_fovs,
-    nimbus_preprocess,
+from nimbus_inference.utils import (prepare_normalization_dict,
+    predict_fovs, nimbus_preprocess, MultiplexDataset
 )
 from huggingface_hub import hf_hub_download
 from nimbus_inference.unet import UNet
@@ -69,7 +66,7 @@ def segmentation_naming_convention(fov_path):
         Returns:
             seg_path (str): paths to segmentation fovs
         """
-        fov_name = os.path.basename(fov_path)
+        fov_name = os.path.basename(fov_path).replace(".ome.tiff", "")
         return os.path.join(deepcell_output_dir, fov_name + "_whole_cell.tiff")
     return segmentation_naming_convention
@@ -79,18 +76,15 @@ class Nimbus(nn.Module):
     """Nimbus application class for predicting marker activity for cells in multiplexed images."""
 
     def __init__(
-        self, fov_paths, segmentation_naming_convention, output_dir, save_predictions=True,
-        include_channels=[], half_resolution=True, batch_size=4, test_time_aug=True,
-        input_shape=[1024, 1024], suffix=".tiff", device="auto",
+        self, dataset: MultiplexDataset, output_dir: str, save_predictions: bool=True,
+        half_resolution: bool=True, batch_size: int=4, test_time_aug: bool=True,
+        input_shape: list=[1024, 1024], device: str="auto",
     ):
         """Initializes a Nimbus Application.
         Args:
-            fov_paths (list): List of paths to fovs to be analyzed.
-            segmentation_naming_convention (function): Function that returns the path to the
-                segmentation mask for a given fov path.
+            dataset (MultiplexDataset): Dataset object that provides access to the fovs to be
+                analyzed and their segmentation masks.
             output_dir (str): Path to directory to save output.
             save_predictions (bool): Whether to save predictions.
-            include_channels (list): List of channels to include in analysis.
             half_resolution (bool): Whether to run model on half resolution images.
             batch_size (int): Batch size for model inference.
             test_time_aug (bool): Whether to use test time augmentation.
@@ -100,9 +94,7 @@ def __init__(
                 , with "cpu" as a fallback), "cpu", "cuda", or "mps". Defaults to "auto".
         """
         super(Nimbus, self).__init__()
-        self.fov_paths = fov_paths
-        self.include_channels = include_channels
-        self.segmentation_naming_convention = segmentation_naming_convention
+        self.dataset = dataset
         self.output_dir = output_dir
         self.half_resolution = half_resolution
         self.save_predictions = save_predictions
@@ -110,7 +102,6 @@ def __init__(
         self.checked_inputs = False
         self.test_time_aug = test_time_aug
         self.input_shape = input_shape
-        self.suffix = suffix
         if self.output_dir != "":
             os.makedirs(self.output_dir, exist_ok=True)
@@ -127,25 +118,9 @@ def __init__(
 
     def check_inputs(self):
         """check inputs for Nimbus model"""
-        # check if all paths in fov_paths exists
-        if not isinstance(self.fov_paths, (list, tuple)):
-            self.fov_paths = [self.fov_paths]
-        io_utils.validate_paths(self.fov_paths)
-
-        # check if segmentation_naming_convention returns valid paths
-        path_to_segmentation = self.segmentation_naming_convention(self.fov_paths[0])
-        if not os.path.exists(path_to_segmentation):
-            raise FileNotFoundError(
-                "Function segmentation_naming_convention does not return valid\
-                path. Segmentation path {} does not exist.".format(
-                    path_to_segmentation
-                )
-            )
         # check if output_dir exists
         io_utils.validate_paths([self.output_dir])
-        if isinstance(self.include_channels, str):
-            self.include_channels = [self.include_channels]
         self.checked_inputs = True
         print("All inputs are valid.")
@@ -161,14 +136,14 @@ def initialize_model(self, padding="reflect"):
         self.checkpoint_path = os.path.join(
             path,
             "assets",
-            "resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt"
+            "resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32_cw_0.8.pt"
         )
         if not os.path.exists(self.checkpoint_path):
             local_dir = os.path.join(path, "assets")
             print("Downloading weights from Hugging Face Hub...")
             self.checkpoint_path = hf_hub_download(
                 repo_id="JLrumberger/Nimbus-Inference",
-                filename="resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt",
+                filename="resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32_cw_0.8.pt",
                 local_dir=local_dir,
                 local_dir_use_symlinks=False,
             )
@@ -192,13 +167,11 @@ def prepare_normalization_dict(
         if os.path.exists(self.normalization_dict_path) and not overwrite:
             self.normalization_dict = json.load(open(self.normalization_dict_path))
         else:
             n_jobs = os.cpu_count() if multiprocessing else 1
             self.normalization_dict = prepare_normalization_dict(
-                self.fov_paths, self.output_dir, quantile, self.include_channels, n_subset, n_jobs
+                self.dataset, self.output_dir, quantile, n_subset, n_jobs
             )
-            if self.include_channels == []:
-                self.include_channels = list(self.normalization_dict.keys())
 
     def predict_fovs(self):
         """Predicts cell classification for input data.
@@ -214,24 +187,12 @@ def predict_fovs(self):
         print("Available GPUs: ", gpus)
         print("Predictions will be saved in {}".format(self.output_dir))
         print("Iterating through fovs will take a while...")
-        if self.suffix.lower() in [".ome.tif", ".ome.tiff"]:
-            self.cell_table = predict_ome_fovs(
-                nimbus=self, fov_paths=self.fov_paths, output_dir=self.output_dir,
-                normalization_dict=self.normalization_dict,
-                segmentation_naming_convention=self.segmentation_naming_convention,
-                include_channels=self.include_channels, save_predictions=self.save_predictions,
-                half_resolution=self.half_resolution, batch_size=self.batch_size,
-                test_time_augmentation=self.test_time_aug, suffix=self.suffix,
-            )
-        elif self.suffix.lower() in [".tiff", ".tif", ".jpg", ".jpeg", ".png"]:
-            self.cell_table = predict_fovs(
-                nimbus=self, fov_paths=self.fov_paths, output_dir=self.output_dir,
-                normalization_dict=self.normalization_dict,
-                segmentation_naming_convention=self.segmentation_naming_convention,
-                include_channels=self.include_channels, save_predictions=self.save_predictions,
-                half_resolution=self.half_resolution, batch_size=self.batch_size,
-                test_time_augmentation=self.test_time_aug, suffix=self.suffix,
-            )
+        self.cell_table = predict_fovs(
+            nimbus=self, dataset=self.dataset, output_dir=self.output_dir,
+            normalization_dict=self.normalization_dict, save_predictions=self.save_predictions,
+            half_resolution=self.half_resolution, batch_size=self.batch_size,
+            test_time_augmentation=self.test_time_aug, suffix=self.dataset.suffix,
+        )
         self.cell_table.to_csv(os.path.join(self.output_dir, "nimbus_cell_table.csv"), index=False)
         return self.cell_table
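For context, the entry point is now driven entirely by a `MultiplexDataset` instead of raw fov paths. A minimal sketch of the new calling convention, assuming the example-dataset layout used in the notebook further down (paths and marker names are placeholders):

```python
from nimbus_inference.nimbus import Nimbus, prep_naming_convention
from nimbus_inference.utils import MultiplexDataset

# placeholder paths, following the example-dataset layout from the notebook
fov_paths = ["../data/example_dataset/image_data/fov_0"]
naming_convention = prep_naming_convention("../data/example_dataset/segmentation/deepcell_output")

# channel folders and multi-channel .ome.tiffs are handled behind the same interface
dataset = MultiplexDataset(
    fov_paths, naming_convention, include_channels=["CD45", "CD3"], suffix=".tiff"
)
nimbus = Nimbus(dataset=dataset, output_dir="../data/example_dataset/nimbus_output")
nimbus.prepare_normalization_dict()
cell_table = nimbus.predict_fovs()  # also writes nimbus_cell_table.csv to output_dir
```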
diff --git a/src/nimbus_inference/utils.py b/src/nimbus_inference/utils.py
index 7d74fea..9e35350 100644
--- a/src/nimbus_inference/utils.py
+++ b/src/nimbus_inference/utils.py
@@ -6,80 +6,234 @@
 import numpy as np
 import pandas as pd
 import imageio as io
-# from skimage import io
+from copy import copy
 from tqdm.autonotebook import tqdm
 from joblib import Parallel, delayed
 from joblib.externals.loky import get_reusable_executor
 from skimage.segmentation import find_boundaries
 from skimage.measure import regionprops_table
+from pyometiff import OMETIFFReader
+from pyometiff.omexml import OMEXML
+from alpineer import io_utils, misc_utils
+from typing import Callable
+import tifffile
+import zarr
+import logging
+import os, sys
 
-def calculate_normalization(channel_path, quantile):
-    """Calculates the normalization value for a given channel
-    Args:
-        channel_path (str): path to channel
-        quantile (float): quantile to use for normalization
-    Returns:
-        normalization_value (float): normalization value
-    """
-    mplex_img = io.imread(channel_path)
-    mplex_img = mplex_img.astype(np.float32)
-    foreground = mplex_img[mplex_img > 0]
-    normalization_value = np.quantile(foreground, quantile)
-    chan = os.path.basename(channel_path).split(".")[0]
-    return chan, normalization_value
-
-
-def prepare_normalization_dict(
-        fov_paths, output_dir, quantile=0.999, include_channels=[], n_subset=10, n_jobs=1,
-        output_name="normalization_dict.json"
-    ):
-    """Prepares the normalization dict for a list of fovs
-    Args:
-        fov_paths (list): list of paths to fovs
-        output_dir (str): path to output directory
-        quantile (float): quantile to use for normalization
-        exclude_channels (list): list of channels to exclude
-        n_subset (int): number of fovs to use for normalization
-        n_jobs (int): number of jobs to use for joblib multiprocessing
-        output_name (str): name of output file
-    Returns:
-        normalization_dict (dict): dict with channel names as keys and norm factors as values
-    """
-    normalization_dict = {}
-    if n_subset is not None:
-        random.shuffle(fov_paths)
-        fov_paths = fov_paths[:n_subset]
-    print("Iterate over fovs...")
-    for fov_path in tqdm(fov_paths):
-        channels = os.listdir(fov_path)
-        if include_channels:
-            channels = [
-                channel for channel in channels if channel.split(".")[0] in include_channels
-            ]
-        channel_paths = [os.path.join(fov_path, channel) for channel in channels]
-        if n_jobs > 1:
-            normalization_values = Parallel(n_jobs=n_jobs)(
-                delayed(calculate_normalization)(channel_path, quantile)
-                for channel_path in channel_paths
-            )
-        else:
-            normalization_values = [
-                calculate_normalization(channel_path, quantile)
-                for channel_path in channel_paths
-            ]
-        for channel, normalization_value in normalization_values:
-            if channel not in normalization_dict:
-                normalization_dict[channel] = []
-            normalization_dict[channel].append(normalization_value)
-    if n_jobs > 1:
-        get_reusable_executor().shutdown(wait=True)
-    for channel in normalization_dict.keys():
-        normalization_dict[channel] = np.mean(normalization_dict[channel])
-    # save normalization dict
-    with open(os.path.join(output_dir, output_name), 'w') as f:
-        json.dump(normalization_dict, f)
-    return normalization_dict
+class HidePrints:
+    def __enter__(self):
+        self._original_stdout = sys.stdout
+        sys.stdout = open(os.devnull, 'w')
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        sys.stdout.close()
+        sys.stdout = self._original_stdout
+
+
+class LazyOMETIFFReader(OMETIFFReader):
+    def __init__(self, fpath: str):
+        """Lazy OMETIFFReader class that reads channels only when needed
+        Args:
+            fpath (str): path to ome.tif file
+        """
+        super().__init__(fpath)
+        self.metadata = self.get_metadata()
+        self.channels = self.get_channel_names()
+        self.shape = self.get_shape()
+
+    def get_metadata(self):
+        """Get the metadata of the OME-TIFF file
+        Returns:
+            metadata (dict): metadata of the OME-TIFF file
+        """
+        with tifffile.TiffFile(str(self.fpath)) as tif:
+            if tif.is_ome:
+                omexml_string = tif.ome_metadata
+                with HidePrints():
+                    metadata = self.parse_metadata(omexml_string)
+                return metadata
+            else:
+                raise ValueError("File is not an OME-TIFF file.")
+
+    def get_channel_names(self):
+        """Get the channel names of the OME-TIFF file
+        Returns:
+            channel_names (list): list of channel names
+        """
+        if hasattr(self, "metadata"):
+            return list(self.metadata["Channels"].keys())
+        else:
+            return []
+
+    def get_shape(self):
+        """Get the shape of the OME-TIFF file array data
+        Returns:
+            shape (tuple): shape of the array data
+        """
+        with tifffile.imread(str(self.fpath), aszarr=True) as store:
+            z = zarr.open(store, mode='r')
+            shape = z.shape
+        return shape
+
+    def get_channel(self, channel_name: str):
+        """Get an individual channel from the OME-TIFF file by name
+        Args:
+            channel_name (str): name of the channel
+        Returns:
+            channel (np.array): channel image
+        """
+        idx = self.channels.index(channel_name)
+        with tifffile.imread(str(self.fpath), aszarr=True) as store:
+            z = zarr.open(store, mode='r')
+            # correct DimOrder, often DimOrder is TZCYX, but the image is stored as CYX,
+            # thus we remove the leading dimensions
+            dim_order = self.metadata["DimOrder"]
+            dim_order = dim_order[-len(z.shape):]
+            channel_idx = dim_order.find("C")
+            # take full slices on every axis before the channel axis, then index the channel
+            channel = z[(slice(None),) * channel_idx + (idx,)]
+        return channel
+
+
+class MultiplexDataset():
+    def __init__(
+        self, fov_paths: list, segmentation_naming_convention: Callable = None,
+        include_channels: list = [], suffix: str = ".tiff", silent=False,
+    ):
+        """Multiplex dataset class that gives a common interface for data loading of multiplex
+        datasets stored as individual channel images in folders or as multi-channel tiffs.
+        Args:
+            fov_paths (list): list of paths to fovs
+            segmentation_naming_convention (function): function to get instance mask path from
+                fov path
+            include_channels (list): list of channels to include; if empty, all channels are used
+            suffix (str): suffix of channel images
+            silent (bool): whether to print messages
+        """
+        self.fov_paths = fov_paths
+        self.segmentation_naming_convention = segmentation_naming_convention
+        self.suffix = suffix
+        self.silent = silent
+        self.include_channels = include_channels
+        self.multi_channel = self.is_multi_channel_tiff(fov_paths[0])
+        self.channels = self.get_channels()
+        self.check_inputs()
+        self.fovs = self.get_fovs()
+        self.channels = self.filter_channels(self.channels)
+
+    def filter_channels(self, channels):
+        """Filter channels based on include_channels
+        Args:
+            channels (list): list of channel names
+        Returns:
+            channels (list): filtered list of channel names
+        """
+        if self.include_channels:
+            return [channel for channel in channels if channel in self.include_channels]
+        return channels
+
+    def check_inputs(self):
+        """check inputs for Nimbus model"""
+        # check if all paths in fov_paths exists
+        if not isinstance(self.fov_paths, (list, tuple)):
+            self.fov_paths = [self.fov_paths]
+        io_utils.validate_paths(self.fov_paths)
+        if isinstance(self.include_channels, str):
+            self.include_channels = [self.include_channels]
+        misc_utils.verify_in_list(
+            include_channels=self.include_channels, dataset_channels=self.channels
+        )
+        if not self.silent:
+            print("All inputs are valid")
+
+    def __len__(self):
+        """Return the number of fovs in the dataset"""
+        return len(self.fov_paths)
+
+    def is_multi_channel_tiff(self, fov_path: str):
+        """Check if fov is a multi-channel tiff
+        Args:
+            fov_path (str): path to fov
+        Returns:
+            multi_channel (bool): whether fov is multi-channel
+        """
+        multi_channel = False
+        if fov_path.lower().endswith(("ome.tif", "ome.tiff")):
+            self.img_reader = LazyOMETIFFReader(fov_path)
+            if len(self.img_reader.shape) > 2:
+                multi_channel = True
+        return multi_channel
+
+    def get_channels(self):
+        """Get the channel names for the dataset"""
+        if self.multi_channel:
+            return self.img_reader.channels
+        else:
+            channels = [
+                channel.replace(self.suffix, "") for channel in os.listdir(self.fov_paths[0])
+                if channel.endswith(self.suffix)
+            ]
+            return channels
+
+    def get_fovs(self):
+        """Get the fovs in the dataset"""
+        return [os.path.basename(fov).replace(self.suffix, "") for fov in self.fov_paths]
+
+    def get_channel(self, fov: str, channel: str):
+        """Get a channel from a fov
+        Args:
+            fov (str): name of a fov
+            channel (str): channel name
+        Returns:
+            channel (np.array): channel image
+        """
+        if self.multi_channel:
+            return self.get_channel_stack(fov, channel)
+        else:
+            return self.get_channel_single(fov, channel)
+
+    def get_channel_single(self, fov: str, channel: str):
+        """Get a channel from a fov stored as a folder with individual channel images
+        Args:
+            fov (str): name of a fov
+            channel (str): channel name
+        Returns:
+            channel (np.array): channel image
+        """
+        idx = self.fovs.index(fov)
+        fov_path = self.fov_paths[idx]
+        channel_path = os.path.join(fov_path, channel + self.suffix)
+        channel = np.squeeze(io.imread(channel_path))
+        return channel
+
+    def get_channel_stack(self, fov: str, channel: str):
+        """Get a channel from a multi-channel tiff
+        Args:
+            fov (str): name of a fov
+            channel (str): channel name
+        Returns:
+            channel (np.array): channel image
+        """
+        idx = self.fovs.index(fov)
+        fov_path = self.fov_paths[idx]
+        self.img_reader = LazyOMETIFFReader(fov_path)
+        return np.squeeze(self.img_reader.get_channel(channel))
+
+    def get_segmentation(self, fov: str):
+        """Get the instance mask for a fov
+        Args:
+            fov (str): name of a fov
+        Returns:
+            instance_mask (np.array): instance mask
+        """
+        idx = self.fovs.index(fov)
+        fov_path = self.fov_paths[idx]
+        instance_path = self.segmentation_naming_convention(fov_path)
+        instance_mask = np.squeeze(io.imread(instance_path))
+        return instance_mask
 
 
 def prepare_input_data(mplex_img, instance_mask):
@@ -120,7 +274,7 @@ def test_time_aug(
     Args:
         input_data (np.array): input data for segmentation model, mplex_img and binary mask
        channel (str): channel name
-        app (tf.keras.Model): segmentation model
+        app (Nimbus): segmentation model
        normalization_dict (dict): dict with channel names as keys and norm factors as values
        rotate (bool): whether to rotate
        flip (bool): whether to flip
@@ -169,19 +323,17 @@ def test_time_aug(
 
 
 def predict_fovs(
-    nimbus, fov_paths, normalization_dict, segmentation_naming_convention, output_dir,
-    suffix, include_channels=[], save_predictions=True, half_resolution=False, batch_size=4,
-    test_time_augmentation=True
+    nimbus, dataset: MultiplexDataset, normalization_dict: dict, output_dir: str,
+    suffix: str=".tiff", save_predictions: bool=True, half_resolution: bool=False,
+    batch_size: int=4, test_time_augmentation: bool=True
 ):
     """Predicts the segmentation map for each mplex image in each fov
     Args:
         nimbus (Nimbus): nimbus object
-        fov_paths (list): list of fov paths
+        dataset (MultiplexDataset): dataset object
         normalization_dict (dict): dict with channel names as keys and norm factors as values
-        segmentation_naming_convention (function): function to get instance mask path from fov path
         output_dir (str): path to output dir
         suffix (str): suffix of mplex images
-        include_channels (list): list of channels to include
         save_predictions (bool): whether to save predictions
         half_resolution (bool): whether to use half resolution
         batch_size (int): batch size
@@ -190,22 +342,15 @@ def predict_fovs(
         cell_table (pd.DataFrame): cell table with predicted confidence scores per fov and cell
     """
     fov_dict_list = []
-    for fov_path in fov_paths:
+    for fov_path, fov in zip(dataset.fov_paths, dataset.fovs):
         print(f"Predicting {fov_path}...")
         out_fov_path = os.path.join(
             os.path.normpath(output_dir), os.path.basename(fov_path)
         )
         df_fov = pd.DataFrame()
-        for channel in tqdm(os.listdir(fov_path)):
-            channel_path = os.path.join(fov_path, channel)
-            channel_ = channel.split(".")[0]
-            if not channel.endswith(suffix) or (
-                include_channels != [] and channel_ not in include_channels
-            ):
-                continue
-            mplex_img = np.squeeze(io.imread(channel_path))
-            instance_path = segmentation_naming_convention(fov_path)
-            instance_mask = np.squeeze(io.imread(instance_path))
+        instance_mask = dataset.get_segmentation(fov)
+        for channel_name in tqdm(dataset.channels):
+            mplex_img = dataset.get_channel(fov, channel_name)
             input_data = prepare_input_data(mplex_img, instance_mask)
             if half_resolution:
                 scale = 0.5
@@ -218,13 +363,13 @@ def predict_fovs(
             input_data = np.stack([img, binary_mask], axis=0)[np.newaxis,...]
             if test_time_augmentation:
                 prediction = test_time_aug(
-                    input_data, channel, nimbus, normalization_dict, batch_size=batch_size
+                    input_data, channel_name, nimbus, normalization_dict, batch_size=batch_size
                 )
             else:
                 prediction = nimbus.predict_segmentation(
                     input_data,
                     preprocess_kwargs={
-                        "normalize": True, "marker": channel,
+                        "normalize": True, "marker": channel_name,
                         "normalization_dict": normalization_dict
                     },
                 )
@@ -235,12 +380,12 @@ def predict_fovs(
             if df_fov.empty:
                 df_fov["label"] = df["label"]
                 df_fov["fov"] = os.path.basename(fov_path)
-            df_fov[channel.split(".")[0]] = df["intensity_mean"]
+            df_fov[channel_name] = df["intensity_mean"]
             if save_predictions:
                 os.makedirs(out_fov_path, exist_ok=True)
                 pred_int = (prediction*255.0).astype(np.uint8)
                 io.imwrite(
-                    os.path.join(out_fov_path, channel), pred_int, photometric="minisblack",
+                    os.path.join(out_fov_path, channel_name + suffix), pred_int,
+                    photometric="minisblack",
                     # compress=0,
                 )
         fov_dict_list.append(df_fov)
@@ -275,5 +420,72 @@ def nimbus_preprocess(image, **kwargs):
     return output
 
 
-def predict_ome_fovs():
-    pass
\ No newline at end of file
+def calculate_normalization(dataset: MultiplexDataset, quantile: float):
+    """Calculates the normalization values for the channels of a dataset
+    Args:
+        dataset (MultiplexDataset): dataset object
+        quantile (float): quantile to use for normalization
+    Returns:
+        normalization_values (dict): dict with channel names as keys and norm factors as values
+    """
+    normalization_values = {}
+    for channel in dataset.channels:
+        mplex_img = dataset.get_channel(dataset.fovs[0], channel)
+        mplex_img = mplex_img.astype(np.float32)
+        foreground = mplex_img[mplex_img > 0]
+        normalization_values[channel] = np.quantile(foreground, quantile)
+    return normalization_values
+
+
+def prepare_normalization_dict(
+        dataset: MultiplexDataset, output_dir: str, quantile: float=0.999, n_subset: int=10,
+        n_jobs: int=1, output_name: str="normalization_dict.json"
+    ):
+    """Prepares the normalization dict for a MultiplexDataset
+    Args:
+        dataset (MultiplexDataset): dataset object
+        output_dir (str): path to output directory
+        quantile (float): quantile to use for normalization
+        n_subset (int): number of fovs to use for normalization
+        n_jobs (int): number of jobs to use for joblib multiprocessing
+        output_name (str): name of output file
+    Returns:
+        normalization_dict (dict): dict with channel names as keys and norm factors as values
+    """
+    normalization_dict = {}
+    fov_paths = copy(dataset.fov_paths)
+    if n_subset is not None:
+        random.shuffle(fov_paths)
+        fov_paths = fov_paths[:n_subset]
+    print("Iterate over fovs...")
+    if n_jobs > 1:
+        normalization_values = Parallel(n_jobs=n_jobs)(
+            delayed(calculate_normalization)(
+                MultiplexDataset(
+                    [fov_path], dataset.segmentation_naming_convention, dataset.channels,
+                    dataset.suffix, True
+                ), quantile)
+            for fov_path in fov_paths
+        )
+    else:
+        normalization_values = [
+            calculate_normalization(
+                MultiplexDataset(
+                    [fov_path], dataset.segmentation_naming_convention, dataset.channels,
+                    dataset.suffix, True
+                ), quantile)
+            for fov_path in fov_paths
+        ]
+    for norm_dict in normalization_values:
+        for channel, normalization_value in norm_dict.items():
+            if channel not in normalization_dict:
+                normalization_dict[channel] = []
+            normalization_dict[channel].append(normalization_value)
+    if n_jobs > 1:
+        get_reusable_executor().shutdown(wait=True)
+    for channel in normalization_dict.keys():
+        normalization_dict[channel] = np.mean(normalization_dict[channel])
+    # save normalization dict
+    with open(os.path.join(output_dir, output_name), 'w') as f:
+        json.dump(normalization_dict, f)
+    return normalization_dict
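Both data layouts end up behind one interface; a short usage sketch of the new classes, mirroring the tests further down (file names are hypothetical):

```python
from nimbus_inference.utils import LazyOMETIFFReader, MultiplexDataset

# hypothetical fov stored as one multi-channel OME-TIFF
reader = LazyOMETIFFReader("fov_0.ome.tiff")
print(reader.channels)             # e.g. ["CD4", "CD56"]
cd4 = reader.get_channel("CD4")    # only this plane is read from disk via zarr

# the same file accessed through the dataset abstraction
dataset = MultiplexDataset(["fov_0.ome.tiff"], suffix=".ome.tiff", silent=True)
assert (dataset.get_channel(fov="fov_0", channel="CD4") == cd4).all()
```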
+ """ + seg_img = self.dataset.get_segmentation(self.fov_select.value) + seg_boundaries = find_boundaries(seg_img, mode='inner') + seg_img[seg_boundaries] = 0 + seg_img = np.clip(seg_img, 0, 1) + seg_img = np.repeat(seg_img[..., np.newaxis], 3, axis=-1) * np.max(composite_image) + background_mask = composite_image < np.max(composite_image) * 0.2 + if add_overlay: + composite_image[background_mask] += (seg_img[background_mask] * 0.2).astype( + composite_image.dtype + ) + if add_boundaries: + val = (np.max(composite_image, axis=(0,1))*0.5).astype(composite_image.dtype) + val = np.min(val[val>0]) + composite_image[seg_boundaries] = [val]*3 + else: + seg_boundaries = None + return composite_image, seg_boundaries + + def create_composite_from_dataset(self, path_dict): + """Creates composite image from input paths. + Args: + path_dict (dict): Dictionary of paths to images. + Returns: + composite_image (np.array): Composite image. + """ + for k in ["red", "green", "blue"]: + if k not in path_dict.keys(): + path_dict[k] = None + output_image = [] + for p in list(path_dict.values()): + if p: + img = self.dataset.get_channel(p["fov"], p["channel"]) + output_image.append(img) + else: + p = [p for p in path_dict.values() if p][0] + img = self.dataset.get_channel(p["fov"], p["channel"]) + output_image.append(img*0) + composite_image = np.stack(output_image, axis=-1) + return composite_image def create_composite_image(self, path_dict, add_overlay=True, add_boundaries=False): """Creates composite image from input paths. @@ -94,29 +138,7 @@ def create_composite_image(self, path_dict, add_overlay=True, add_boundaries=Fal output_image.append(img*0) # add overlay of instances composite_image = np.stack(output_image, axis=-1) - if self.segmentation_naming_convention and add_overlay: - fov_path = os.path.split(list(path_dict.values())[0])[0] - seg_path = self.segmentation_naming_convention(fov_path) - seg_img = io.imread(seg_path) - seg_boundaries = find_boundaries(seg_img, mode='inner') - seg_img[seg_boundaries] = 0 - seg_img = np.clip(seg_img, 0, 1) - seg_img = np.repeat(seg_img[..., np.newaxis], 3, axis=-1) * np.max(composite_image) - background_mask = composite_image < np.max(composite_image) * 0.2 - composite_image[background_mask] += (seg_img[background_mask] * 0.2).astype( - composite_image.dtype - ) - elif self.segmentation_naming_convention and add_boundaries: - fov_path = os.path.split(list(path_dict.values())[0])[0] - seg_path = self.segmentation_naming_convention(fov_path) - seg_img = io.imread(seg_path) - seg_boundaries = find_boundaries(seg_img, mode='inner') - val = (np.max(composite_image, axis=(0,1))*0.5).astype(composite_image.dtype) - val = np.min(val[val>0]) - composite_image[seg_boundaries] = [val]*3 - else: - seg_boundaries = None - return composite_image, seg_boundaries + return composite_image def layout(self): """Creates layout for viewer.""" @@ -192,25 +214,34 @@ def update_composite(self): in_path_dict = copy(path_dict) if self.red_select.value: path_dict["red"] = os.path.join( - self.output_dir, self.fov_select.value, self.red_select.value + self.output_dir, self.fov_select.value, self.red_select.value + self.suffix ) - in_path_dict["red"] = self.search_for_similar(self.red_select.value) + in_path_dict["red"] = {"fov": self.fov_select.value, "channel": self.red_select.value} if self.green_select.value: path_dict["green"] = os.path.join( - self.output_dir, self.fov_select.value, self.green_select.value + self.output_dir, self.fov_select.value, self.green_select.value + self.suffix ) 
- in_path_dict["green"] = self.search_for_similar(self.green_select.value) + in_path_dict["green"] = { + "fov": self.fov_select.value, "channel": self.green_select.value + } if self.blue_select.value: path_dict["blue"] = os.path.join( - self.output_dir, self.fov_select.value, self.blue_select.value + self.output_dir, self.fov_select.value, self.blue_select.value + self.suffix ) - in_path_dict["blue"] = self.search_for_similar(self.blue_select.value) + in_path_dict["blue"] = { + "fov": self.fov_select.value, "channel": self.blue_select.value + } non_none = [p for p in path_dict.values() if p] if not non_none: return - composite_image, _ = self.create_composite_image(path_dict) - in_composite_image, seg_boundaries = self.create_composite_image( - in_path_dict, add_overlay=False, add_boundaries=self.overlay_checkbox.value + composite_image = self.create_composite_image(path_dict) + composite_image, _ = self.overlay( + composite_image, add_overlay=True + ) + + in_composite_image = self.create_composite_from_dataset(in_path_dict) + in_composite_image, seg_boundaries = self.overlay( + in_composite_image, add_boundaries=self.overlay_checkbox.value ) in_composite_image = in_composite_image / np.quantile( in_composite_image, 0.999, axis=(0,1) @@ -230,4 +261,4 @@ def display(self): """Displays viewer.""" self.select_fov(None) self.layout() - self.update_composite() \ No newline at end of file + self.update_composite() diff --git a/templates/1_Nimbus_Predict.ipynb b/templates/1_Nimbus_Predict.ipynb index 9da6ce4..f8067bb 100644 --- a/templates/1_Nimbus_Predict.ipynb +++ b/templates/1_Nimbus_Predict.ipynb @@ -22,6 +22,7 @@ "from IPython.display import display, HTML\n", "display(HTML(\"\"))\n", "from nimbus_inference.nimbus import Nimbus, prep_naming_convention\n", + "from nimbus_inference.utils import MultiplexDataset\n", "from alpineer import io_utils\n", "from ark.utils import example_dataset\n", "from nimbus_inference.viewer_widget import NimbusViewer" @@ -33,7 +34,8 @@ "metadata": {}, "source": [ "## 0: Set root directory and download example dataset\n", - "Here we are using the example data located in `/data/example_dataset/input_data`. To modify this notebook to run using your own data, simply change `base_dir` to point to your own sub-directory within the data folder. Set `base_dir`, the path to all of your imaging data (i.e. multiplexed images and segmentation masks). Subdirectory `nimbus_output` will contain all of the data generated by this notebook. In the following, we expect this folder structure:\n", + "Here we are using the example data located in `/data/example_dataset/input_data`. To modify this notebook to run using your own data, simply change `base_dir` to point to your own sub-directory within the data folder. Set `base_dir`, the path to all of your imaging data (i.e. multiplexed images and segmentation masks). Subdirectory `nimbus_output` will contain all of the data generated by this notebook. 
diff --git a/templates/1_Nimbus_Predict.ipynb b/templates/1_Nimbus_Predict.ipynb
index 9da6ce4..f8067bb 100644
--- a/templates/1_Nimbus_Predict.ipynb
+++ b/templates/1_Nimbus_Predict.ipynb
@@ -22,6 +22,7 @@
     "from IPython.display import display, HTML\n",
     "display(HTML(\"\"))\n",
     "from nimbus_inference.nimbus import Nimbus, prep_naming_convention\n",
+    "from nimbus_inference.utils import MultiplexDataset\n",
     "from alpineer import io_utils\n",
     "from ark.utils import example_dataset\n",
     "from nimbus_inference.viewer_widget import NimbusViewer"
@@ -33,7 +34,7 @@
    "metadata": {},
    "source": [
     "## 0: Set root directory and download example dataset\n",
-    "Here we are using the example data located in `/data/example_dataset/input_data`. To modify this notebook to run using your own data, simply change `base_dir` to point to your own sub-directory within the data folder. Set `base_dir`, the path to all of your imaging data (i.e. multiplexed images and segmentation masks). Subdirectory `nimbus_output` will contain all of the data generated by this notebook. In the following, we expect this folder structure:\n",
+    "Here we are using the example data located in `/data/example_dataset/input_data`. To modify this notebook to run using your own data, simply change `base_dir` to point to your own sub-directory within the data folder. Set `base_dir`, the path to all of your imaging data (i.e. multiplexed images and segmentation masks). Subdirectory `nimbus_output` will contain all of the data generated by this notebook. In the following, we expect this folder structure, with `fov_1` and `fov_2` either being folders of individual channel images or `.ome.tiff` files that contain all channels in a single file.\n",
     "```\n",
     "|-- base_dir\n",
     "|   |-- image_data\n",
@@ -53,8 +55,7 @@
    "outputs": [],
    "source": [
     "# set up the base directory\n",
-    "base_dir = os.path.normpath(\"../data/example_dataset\")\n",
-    "# base_dir = os.path.normpath(\"C:/Users/lorenz/Desktop/angelo_lab/data/example_dataset\")"
+    "base_dir = os.path.normpath(\"../data/example_dataset\")"
   ]
  },
 {
@@ -172,6 +173,29 @@
     "    print(\"Segmentation data does not exist for fov 0 or naming convention is incorrect\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "e7717960",
+   "metadata": {},
+   "source": [
+    "Next we will use the `MultiplexDataset` class to abstract away differences in data representation. The class takes `fov_paths`, `segmentation_naming_convention` and a `suffix`, and provides the methods `.get_channel(fov, channel)` and `.get_segmentation(fov)` to access the data. The `suffix` is used to filter out files that do not end with the specified suffix. When you use `.ome.tiff` files, make sure to set the suffix to `.ome.tiff`, otherwise the ViewerWidget won't be able to display the images."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "50997492",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = MultiplexDataset(\n",
+    "    fov_paths=fov_paths,\n",
+    "    suffix=\".tiff\", # or .png, .jpg, .jpeg, .tif or .ome.tiff\n",
+    "    include_channels=include_channels,\n",
+    "    segmentation_naming_convention=segmentation_naming_convention,\n",
+    ")"
+   ]
+  },
  {
   "cell_type": "markdown",
   "id": "839e5240",
@@ -189,15 +213,12 @@
    "outputs": [],
    "source": [
     "nimbus = Nimbus(\n",
-    "    fov_paths=fov_paths,\n",
-    "    segmentation_naming_convention=segmentation_naming_convention,\n",
+    "    dataset=dataset,\n",
     "    output_dir=nimbus_output_dir,\n",
-    "    include_channels=include_channels,\n",
     "    save_predictions=True,\n",
     "    batch_size=4,\n",
     "    test_time_aug=True,\n",
     "    input_shape=[1024,1024],\n",
-    "    suffix=\".tiff\",\n",
     "    device=\"auto\",\n",
     ")\n",
     "\n",
@@ -275,7 +296,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "viewer = NimbusViewer(input_dir=tiff_dir, output_dir=nimbus_output_dir)\n",
+    "viewer = NimbusViewer(dataset=dataset, output_dir=nimbus_output_dir)\n",
     "viewer.display()"
   ]
  }
diff --git a/tests/test_cell_analyzer.py b/tests/test_cell_analyzer.py
index e34d96d..d156060 100644
--- a/tests/test_cell_analyzer.py
+++ b/tests/test_cell_analyzer.py
@@ -1,6 +1,6 @@
 from nimbus_inference.cell_analyzer import CellAnalyzer
 from nimbus_inference.nimbus import Nimbus, prep_naming_convention
-from tests.test_utils import prepare_ome_tif_data, prepare_tif_data
+from tests.test_utils import prepare_tif_data
 import pandas as pd
 import numpy as np
 import tempfile
diff --git a/tests/test_nimbus.py b/tests/test_nimbus.py
index a954e20..cc2955d 100644
--- a/tests/test_nimbus.py
+++ b/tests/test_nimbus.py
@@ -1,5 +1,6 @@
 from tests.test_utils import prepare_ome_tif_data, prepare_tif_data
 import tempfile
+from nimbus_inference.utils import MultiplexDataset
 from nimbus_inference.nimbus import Nimbus, prep_naming_convention
 from nimbus_inference.unet import UNet
 from skimage.data import astronaut
@@ -15,16 +16,15 @@ def test_check_inputs():
         selected_markers = ["CD45", "CD3", "CD8", "ChyTr"]
         fov_paths, _ = prepare_tif_data(num_samples, temp_dir, selected_markers)
         naming_convention = prep_naming_convention(os.path.join(temp_dir, "deepcell_output"))
-        nimbus = Nimbus(
-            fov_paths=fov_paths, segmentation_naming_convention=naming_convention,
-            output_dir=temp_dir
-        )
+        dataset = MultiplexDataset(fov_paths, naming_convention)
+        nimbus = Nimbus(dataset=dataset, output_dir=temp_dir)
         nimbus.check_inputs()
 
 
 def test_initialize_model():
+    dataset = MultiplexDataset(["tests"])
     nimbus = Nimbus(
-        fov_paths=[""], segmentation_naming_convention="", output_dir="",
+        dataset, output_dir="",
         input_shape=[512,512], batch_size=4
     )
     nimbus.initialize_model(padding="valid")
@@ -43,19 +43,17 @@ def test_prepare_normalization_dict():
         selected_markers = ["CD45", "CD3", "CD8", "ChyTr"]
         fov_paths,_ = prepare_tif_data(num_samples, temp_dir, selected_markers)
         naming_convention = prep_naming_convention(os.path.join(temp_dir, "deepcell_output"))
-        nimbus = Nimbus(
-            fov_paths, naming_convention, temp_dir,
-            include_channels=["CD45", "CD3", "CD8"]
-        )
+        dataset = MultiplexDataset(
+            fov_paths, naming_convention, include_channels=["CD45", "CD3", "CD8"]
+        )
+        nimbus = Nimbus(dataset, temp_dir)
         # test if normalization dict gets prepared and saved
         nimbus.prepare_normalization_dict(overwrite=True)
         assert os.path.exists(os.path.join(temp_dir, "normalization_dict.json"))
         assert "ChyTr" not in nimbus.normalization_dict.keys()
         # test if normalization dict gets loaded
-        nimbus_2 = Nimbus(
-            fov_paths, naming_convention, temp_dir, include_channels=["CD45", "CD3", "CD8"]
-        )
+        nimbus_2 = Nimbus(dataset, temp_dir)
         nimbus_2.prepare_normalization_dict()
         assert nimbus_2.normalization_dict == nimbus.normalization_dict
 
@@ -64,7 +62,8 @@ def test_tile_input():
     image = torch.rand([1,2,768,768])
     tile_size = (512, 512)
     output_shape = (320,320)
-    nimbus = Nimbus(fov_paths=[""], segmentation_naming_convention="", output_dir="")
+    dataset = MultiplexDataset(["tests"])
+    nimbus = Nimbus(dataset, output_dir="")
     nimbus.model = lambda x: x[..., 96:-96, 96:-96]
     tiled_input, padding = nimbus._tile_input(image, tile_size, output_shape)
     assert tiled_input.shape == (3,3,1,2,512,512)
@@ -76,8 +75,7 @@ def test_tile_and_stitch():
     image = rescale(astronaut(), 1.5, channel_axis=-1)
     image = np.moveaxis(image, -1, 0)[np.newaxis, ...]
     nimbus = Nimbus(
-        fov_paths=[""], segmentation_naming_convention="", output_dir="",
-        input_shape=[512,512], batch_size=4
+        dataset=MultiplexDataset(["tests"]), output_dir="", input_shape=[512,512], batch_size=4
     )
     # check if tile and stitch works for mock model unequal input and output shape
     # mock model only center crops the input, so that the stitched output is equal to the input
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 7882bae..f1195be 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,9 +1,10 @@
-from nimbus_inference.utils import prepare_normalization_dict, calculate_normalization
-from nimbus_inference.utils import predict_fovs, predict_ome_fovs, prepare_input_data
+from nimbus_inference.utils import (prepare_normalization_dict, calculate_normalization,
+    predict_fovs, prepare_input_data, MultiplexDataset, LazyOMETIFFReader)
 from nimbus_inference.utils import test_time_aug as tt_aug
 from nimbus_inference.nimbus import Nimbus
 from skimage import io
 from pyometiff import OMETIFFWriter
+import pytest
 import numpy as np
 import tempfile
 import torch
@@ -56,25 +57,30 @@ def prepare_tif_data(num_samples, temp_dir, selected_markers, random=False, std=
 def prepare_ome_tif_data(num_samples, temp_dir, selected_markers, random=False, std=1):
     np.random.seed(42)
     metadata_dict = {
-        "PhysicalSizeX" : "0.88",
+        "SizeX" : 256,
+        "SizeY" : 256,
+        "SizeC" : len(selected_markers) + 3,
+        "PhysicalSizeX" : 0.5,
         "PhysicalSizeXUnit" : "µm",
-        "PhysicalSizeY" : "0.88",
+        "PhysicalSizeY" : 0.5,
         "PhysicalSizeYUnit" : "µm",
-        "PhysicalSizeZ" : "3.3",
-        "PhysicalSizeZUnit" : "µm",
     }
-
+    fov_paths = []
+    inst_paths = []
+    if isinstance(std, (int, float)) or len(std) != len(selected_markers):
+        std = [std] * len(selected_markers)
     for i in range(num_samples):
         metadata_dict["Channels"] = {}
         channels = []
-        for marker in zip(selected_markers):
+        for j, (marker, s) in enumerate(zip(selected_markers, std)):
             if random:
-                img = np.random.rand(256, 256) * std
+                img = np.random.rand(256, 256) * s
             else:
                 img = np.ones([256, 256])
-                channels.append(img)
+            channels.append(img)
             metadata_dict["Channels"][marker] = {
                 "Name" : marker,
+                "ID": str(j),
                 "SamplesPerPixel": 1,
             }
         channel_data = np.stack(channels, axis=0)
@@ -87,49 +93,74 @@ def prepare_ome_tif_data(num_samples, temp_dir, selected_markers, random=False,
             metadata=metadata_dict,
             explicit_tiffdata=False)
         writer.write()
-    return None
+        deepcell_dir = os.path.join(temp_dir, "deepcell_output")
+        os.makedirs(deepcell_dir, exist_ok=True)
+        inst_path = os.path.join(deepcell_dir, f"fov_{i}_whole_cell.tiff")
+        io.imsave(
+            inst_path, np.array(
+                [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+            ).repeat(64, axis=1).repeat(64, axis=0)
+        )
+        fov_paths.append(sample_name)
+        inst_paths.append(inst_path)
+    return fov_paths, inst_paths
 
 
 def test_calculate_normalization():
     with tempfile.TemporaryDirectory() as temp_dir:
-        fov_paths, _ = prepare_tif_data(
+        # test for single channel data
+        tif_fov_paths, _ = prepare_tif_data(
             num_samples=1, temp_dir=temp_dir, selected_markers=["CD4"], random=True, std=[0.5]
         )
         channel = "CD4"
-        channel_path = os.path.join(fov_paths[0], channel + ".tiff")
-        channel_out, norm_val = calculate_normalization(channel_path, 0.999)
-        # test if we get the correct channel and normalization value
-        assert channel_out == channel
-        assert np.isclose(norm_val, 0.5, 0.01)
+        tif_dataset = MultiplexDataset(tif_fov_paths, suffix=".tiff")
+        ome_fov_paths, _ = prepare_ome_tif_data(
+            num_samples=1, temp_dir=temp_dir, selected_markers=["CD4", "CD56"],
+            random=True, std=[0.5]
+        )
+        ome_dataset = MultiplexDataset(ome_fov_paths, suffix=".ome.tiff")
+        for dataset in [tif_dataset, ome_dataset]:
+            norm_dict = calculate_normalization(dataset, 0.999)
+            channel_out, norm_val = list(norm_dict.items())[0]
+            # test if we get the correct channel and normalization value
+            assert channel_out == channel
+            assert np.isclose(norm_val, 0.5, 0.01)
 
 
 def test_prepare_normalization_dict():
     with tempfile.TemporaryDirectory() as temp_dir:
         scales = [0.5, 1.0, 1.5, 2.0, 5.0]
         channels = ["CD4", "CD11c", "CD14", "CD56", "CD57"]
-        fov_paths, _ = prepare_tif_data(
+        tif_fov_paths, _ = prepare_tif_data(
             num_samples=5, temp_dir=temp_dir, selected_markers=channels, random=True, std=scales
         )
-        normalization_dict = prepare_normalization_dict(
-            fov_paths, temp_dir, quantile=0.999, n_subset=10, n_jobs=1,
-            output_name="normalization_dict.json"
-        )
-        # test if normalization dict got saved
-        assert os.path.exists(os.path.join(temp_dir, "normalization_dict.json"))
-        assert normalization_dict == json.load(
-            open(os.path.join(temp_dir, "normalization_dict.json"))
+        tif_dataset = MultiplexDataset(tif_fov_paths, suffix=".tiff")
+
+        # test if everything works for multi channel data
+        ome_fov_paths, _ = prepare_ome_tif_data(
+            num_samples=5, temp_dir=temp_dir, selected_markers=channels, random=True, std=scales
         )
-        # test if normalization dict is correct
-        for channel, scale in zip(channels, scales):
-            assert np.isclose(normalization_dict[channel], scale, 0.01)
+        ome_dataset = MultiplexDataset(ome_fov_paths, suffix=".ome.tiff")
+        for dataset in [tif_dataset, ome_dataset]:
+            normalization_dict = prepare_normalization_dict(
+                dataset, temp_dir, quantile=0.999, n_subset=10, n_jobs=1,
+                output_name="normalization_dict.json"
+            )
+            # test if normalization dict got saved
+            assert os.path.exists(os.path.join(temp_dir, "normalization_dict.json"))
+            assert normalization_dict == json.load(
+                open(os.path.join(temp_dir, "normalization_dict.json"))
+            )
+            # test if normalization dict is correct
+            for channel, scale in zip(channels, scales):
+                assert np.isclose(normalization_dict[channel], scale, 0.01)
 
-        # test if multiprocessing yields approximately the same results
-        normalization_dict_mp = prepare_normalization_dict(
-            fov_paths, temp_dir, quantile=0.999, n_subset=10, n_jobs=2,
-            output_name="normalization_dict.json"
-        )
-        for key in normalization_dict.keys():
-            assert np.isclose(normalization_dict[key], normalization_dict_mp[key], 1e-6)
+            # test if multiprocessing yields approximately the same results
+            normalization_dict_mp = prepare_normalization_dict(
+                dataset, temp_dir, quantile=0.999, n_subset=10, n_jobs=2,
+                output_name="normalization_dict.json"
+            )
+            for key in normalization_dict.keys():
+                assert np.isclose(normalization_dict[key], normalization_dict_mp[key], 1e-6)
 
 
 def test_prepare_input_data():
@@ -161,9 +192,8 @@ def segmentation_naming_convention(fov_path):
             num_samples=1, temp_dir=temp_dir, selected_markers=[channel]
         )
         output_dir = os.path.join(temp_dir, "nimbus_output")
-        nimbus = Nimbus(
-            fov_paths, segmentation_naming_convention, output_dir,
-        )
+        dataset = MultiplexDataset(fov_paths, segmentation_naming_convention, suffix=".tiff")
+        nimbus = Nimbus(dataset, output_dir)
         nimbus.prepare_normalization_dict()
         mplex_img = io.imread(os.path.join(fov_paths[0], channel+".tiff"))
         instance_mask = io.imread(inst_paths[0])
@@ -206,16 +236,14 @@ def segmentation_naming_convention(fov_path):
         fov_paths, _ = prepare_tif_data(
             num_samples=1, temp_dir=temp_dir, selected_markers=["CD4", "CD56"]
         )
+        dataset = MultiplexDataset(fov_paths, segmentation_naming_convention, suffix=".tiff")
         output_dir = os.path.join(temp_dir, "nimbus_output")
-        nimbus = Nimbus(
-            fov_paths, segmentation_naming_convention, output_dir,
-        )
-        output_dir = os.path.join(temp_dir, "nimbus_output")
+        nimbus = Nimbus(dataset, output_dir)
         nimbus.prepare_normalization_dict()
         cell_table = predict_fovs(
-            nimbus=nimbus, fov_paths=fov_paths, output_dir=output_dir,
-            normalization_dict=nimbus.normalization_dict,
-            segmentation_naming_convention=segmentation_naming_convention, suffix=".tiff",
+            nimbus=nimbus, dataset=dataset, output_dir=output_dir,
+            normalization_dict=nimbus.normalization_dict, suffix=".tiff",
             save_predictions=False, half_resolution=True,
         )
         # check if we get the correct number of cells
@@ -230,10 +258,64 @@ def segmentation_naming_convention(fov_path):
         # run again with save_predictions=True and check if predictions get written to output_dir
         cell_table = predict_fovs(
-            nimbus=nimbus, fov_paths=fov_paths, output_dir=output_dir,
-            normalization_dict=nimbus.normalization_dict,
-            segmentation_naming_convention=segmentation_naming_convention, suffix=".tiff",
+            nimbus=nimbus, dataset=dataset, output_dir=output_dir,
+            normalization_dict=nimbus.normalization_dict, suffix=".tiff",
             save_predictions=True, half_resolution=True,
         )
         assert os.path.exists(os.path.join(output_dir, "fov_0", "CD4.tiff"))
         assert os.path.exists(os.path.join(output_dir, "fov_0", "CD56.tiff"))
+
+
+def test_LazyOMETIFFReader():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        fov_paths, _ = prepare_ome_tif_data(
+            num_samples=1, temp_dir=temp_dir, selected_markers=["CD4", "CD56"]
+        )
+        reader = LazyOMETIFFReader(fov_paths[0])
+        assert hasattr(reader, "metadata")
+        assert reader.channels == ["CD4", "CD56"]
+        cd4_channel = reader.get_channel("CD4")
+        cd56_channel = reader.get_channel("CD56")
+        assert cd4_channel.shape == (256, 256)
+        assert cd56_channel.shape == (256, 256)
+
+
+def test_MultiplexDataset():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        def segmentation_naming_convention(fov_path):
+            temp_dir_, fov_ = os.path.split(fov_path)
+            fov_ = fov_.split(".")[0]
+            return os.path.join(temp_dir_, "deepcell_output", fov_ + "_whole_cell.tiff")
+
+        fov_paths, _ = prepare_ome_tif_data(
+            num_samples=1, temp_dir=temp_dir, selected_markers=["CD4", "CD56"]
+        )
+        # check if check_inputs raises an error when inputs are incorrect
+        with pytest.raises(FileNotFoundError):
+            dataset = MultiplexDataset(["abc"], segmentation_naming_convention, suffix=".ome.tiff")
+        # check if we get the correct channels and fov_paths
+        dataset = MultiplexDataset(fov_paths, segmentation_naming_convention, suffix=".ome.tiff")
+        assert len(dataset) == 1
+        assert dataset.channels == ["CD4", "CD56"]
+        assert dataset.fov_paths == fov_paths
+        assert dataset.multi_channel
+        cd4_channel = io.imread(fov_paths[0])[0]
+        cd4_channel_ = dataset.get_channel(fov="fov_0", channel="CD4")
+        assert np.all(cd4_channel == cd4_channel_)
+        fov_0_seg = io.imread(segmentation_naming_convention(fov_paths[0]))
+        fov_0_seg_ = dataset.get_segmentation(fov="fov_0")
+        assert np.all(fov_0_seg == fov_0_seg_)
+
+        # test everything again with single channel data
+        fov_paths, _ = prepare_tif_data(
+            num_samples=1, temp_dir=temp_dir, selected_markers=["CD4", "CD56"]
+        )
+        dataset = MultiplexDataset(fov_paths, segmentation_naming_convention, suffix=".tiff")
+        assert len(dataset) == 1
+        assert dataset.channels == ["CD4", "CD56"]
+        assert dataset.fov_paths == fov_paths
+        assert not dataset.multi_channel
+        cd4_channel_ = dataset.get_channel(fov="fov_0", channel="CD4")
+        assert np.all(cd4_channel == cd4_channel_)
+        fov_0_seg_ = dataset.get_segmentation(fov="fov_0")
+        assert np.all(fov_0_seg == fov_0_seg_)
diff --git a/tests/test_viewer_widget.py b/tests/test_viewer_widget.py
index 9e37108..fb266ab 100644
--- a/tests/test_viewer_widget.py
+++ b/tests/test_viewer_widget.py
@@ -1,5 +1,6 @@
 from nimbus_inference.viewer_widget import NimbusViewer
 from nimbus_inference.nimbus import Nimbus, prep_naming_convention
+from nimbus_inference.utils import MultiplexDataset
 from tests.test_utils import prepare_ome_tif_data, prepare_tif_data
 import numpy as np
 import tempfile
@@ -8,36 +9,67 @@
 
 def test_NimbusViewer():
     with tempfile.TemporaryDirectory() as temp_dir:
-        _ = prepare_tif_data(
+        fov_paths, _ = prepare_tif_data(
             num_samples=2, temp_dir=temp_dir, selected_markers=["CD4", "CD11c", "CD56"]
         )
-        viewer_widget = NimbusViewer(temp_dir, temp_dir)
+        dataset = MultiplexDataset(fov_paths)
+        viewer_widget = NimbusViewer(dataset, temp_dir)
         assert isinstance(viewer_widget, NimbusViewer)
 
 
 def test_composite_image():
     with tempfile.TemporaryDirectory() as temp_dir:
-        _ = prepare_tif_data(
+        fov_paths, _ = prepare_tif_data(
             num_samples=2, temp_dir=temp_dir, selected_markers=["CD4", "CD11c", "CD56"]
         )
-        viewer_widget = NimbusViewer(temp_dir, temp_dir)
+        dataset = MultiplexDataset(fov_paths)
+        viewer_widget = NimbusViewer(dataset, temp_dir)
         path_dict = {
             "red": os.path.join(temp_dir, "fov_0", "CD4.tiff"),
             "green": os.path.join(temp_dir, "fov_0", "CD11c.tiff"),
         }
-        composite_image, _ = viewer_widget.create_composite_image(path_dict)
+        composite_image = viewer_widget.create_composite_image(path_dict)
         assert isinstance(composite_image, np.ndarray)
         assert composite_image.shape == (256, 256, 3)
 
         path_dict["blue"] = os.path.join(temp_dir, "fov_0", "CD56.tiff")
-        composite_image, _ = viewer_widget.create_composite_image(path_dict)
+        composite_image = viewer_widget.create_composite_image(path_dict)
         assert composite_image.shape == (256, 256, 3)
+
+
+def test_create_composite_from_dataset():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        fov_paths, _ = prepare_tif_data(
+            num_samples=2, temp_dir=temp_dir, selected_markers=["CD4", "CD11c", "CD56"]
+        )
+        dataset = MultiplexDataset(fov_paths)
+        viewer_widget = NimbusViewer(dataset, temp_dir)
+        path_dict = {
+            "red": {"fov": "fov_0", "channel": "CD4"},
+            "green": {"fov": "fov_0", "channel": "CD11c"},
+        }
+        composite_image = viewer_widget.create_composite_from_dataset(path_dict)
+        assert isinstance(composite_image, np.ndarray)
+        assert composite_image.shape == (256, 256, 3)
+
+
+def test_overlay():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        fov_paths, _ = prepare_tif_data(
+            num_samples=2, temp_dir=temp_dir, selected_markers=["CD4", "CD11c", "CD56"]
+        )
+        path_dict = {
+            "red": os.path.join(temp_dir, "fov_0", "CD4.tiff"),
+            "green": os.path.join(temp_dir, "fov_0", "CD11c.tiff"),
+        }
         # test if segmentation gets added
         naming_convention = prep_naming_convention(os.path.join(temp_dir, "deepcell_output"))
-        viewer_widget = NimbusViewer(
-            temp_dir, temp_dir, segmentation_naming_convention=naming_convention
-        )
-        composite_image, seg_boundaries = viewer_widget.create_composite_image(path_dict)
+        dataset = MultiplexDataset(fov_paths, naming_convention)
+        viewer_widget = NimbusViewer(dataset, temp_dir)
+        composite_image = viewer_widget.create_composite_image(path_dict)
+        composite_image, seg_boundaries = viewer_widget.overlay(
+            composite_image, add_boundaries=True
+        )
         assert composite_image.shape == (256, 256, 3)
         assert seg_boundaries.shape == (256, 256)
         assert np.unique(seg_boundaries).tolist() == [0, 1]
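The subtlest piece of the new reader is the `DimOrder` trimming in `LazyOMETIFFReader.get_channel`; its indexing logic can be checked in isolation with plain numpy (all values below are made up):

```python
import numpy as np

# stand-in for the zarr array of a CYX-stored image with two channels
z = np.zeros((2, 256, 256))
dim_order = "TZCYX"                # OME metadata often lists axes the stored array lacks

dim_order = dim_order[-z.ndim:]    # drop the leading axes -> "CYX"
channel_idx = dim_order.find("C")  # position of the channel axis -> 0
idx = 1                            # e.g. index of "CD56" in reader.channels

# full slices on every axis before the channel axis, then the channel index
channel = z[(slice(None),) * channel_idx + (idx,)]
assert channel.shape == (256, 256)
```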