Merge pull request #8 from angelolab/ome_tiff
Added ome-tiff inference pipeline
JLrumberger authored Apr 11, 2024
2 parents 178dccd + e08fe2a commit 5250d0b
Showing 9 changed files with 609 additions and 271 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -38,6 +38,7 @@ dependencies = [
"ipywidgets",
"natsort",
"ipython",
"zarr",
]
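The new zarr dependency supports lazy, chunked access to OME-TIFF pixel data. Below is a minimal sketch of that access pattern; the file path is hypothetical and the use of tifffile alongside zarr is an assumption, not something shown in this diff.

import tifffile
import zarr

# Expose the OME-TIFF pixel data as a zarr store so channels/tiles can be read lazily.
store = tifffile.imread("/data/fovs/fov0.ome.tiff", aszarr=True)  # hypothetical path
pixels = zarr.open(store, mode="r")  # zarr Array or Group, depending on the file layout
print(pixels)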


75 changes: 18 additions & 57 deletions src/nimbus_inference/nimbus.py
@@ -1,11 +1,8 @@
from alpineer import io_utils, misc_utils
from skimage.util.shape import view_as_windows
import nimbus_inference
from nimbus_inference.utils import (
prepare_normalization_dict,
predict_fovs,
predict_ome_fovs,
nimbus_preprocess,
from nimbus_inference.utils import (prepare_normalization_dict,
predict_fovs, nimbus_preprocess, MultiplexDataset
)
from huggingface_hub import hf_hub_download
from nimbus_inference.unet import UNet
@@ -69,7 +66,7 @@ def segmentation_naming_convention(fov_path):
Returns:
seg_path (str): paths to segmentation fovs
"""
fov_name = os.path.basename(fov_path)
fov_name = os.path.basename(fov_path).replace(".ome.tiff", "")
return os.path.join(deepcell_output_dir, fov_name + "_whole_cell.tiff")

return segmentation_naming_convention
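The added .replace(".ome.tiff", "") makes OME-TIFF fovs resolve to the same mask names as directory-style fovs. A hedged usage sketch follows; the directory names are illustrative and not taken from this diff.

# Assuming the enclosing factory was built with deepcell_output_dir="/data/deepcell_output":
naming_fn = segmentation_naming_convention
naming_fn("/data/fovs/fov0.ome.tiff")  # -> "/data/deepcell_output/fov0_whole_cell.tiff"
naming_fn("/data/fovs/fov0")           # -> "/data/deepcell_output/fov0_whole_cell.tiff"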
@@ -79,18 +76,15 @@ class Nimbus(nn.Module):
"""Nimbus application class for predicting marker activity for cells in multiplexed images."""

def __init__(
self, fov_paths, segmentation_naming_convention, output_dir, save_predictions=True,
include_channels=[], half_resolution=True, batch_size=4, test_time_aug=True,
input_shape=[1024, 1024], suffix=".tiff", device="auto",
self, dataset: MultiplexDataset, output_dir: str, save_predictions: bool=True,
half_resolution: bool=True, batch_size: int=4, test_time_aug: bool=True,
input_shape: list=[1024, 1024], device: str="auto",
):
"""Initializes a Nimbus Application.
Args:
fov_paths (list): List of paths to fovs to be analyzed.
segmentation_naming_convention (function): Function that returns the path to the
segmentation mask for a given fov path.
dataset (MultiplexDataset): Dataset wrapping the fovs (and their segmentation masks) to be analyzed.
output_dir (str): Path to directory to save output.
save_predictions (bool): Whether to save predictions.
include_channels (list): List of channels to include in analysis.
half_resolution (bool): Whether to run model on half resolution images.
batch_size (int): Batch size for model inference.
test_time_aug (bool): Whether to use test time augmentation.
@@ -100,17 +94,14 @@ def __init__(
, with "cpu" as a fallback), "cpu", "cuda", or "mps". Defaults to "auto".
"""
super(Nimbus, self).__init__()
self.fov_paths = fov_paths
self.include_channels = include_channels
self.segmentation_naming_convention = segmentation_naming_convention
self.dataset = dataset
self.output_dir = output_dir
self.half_resolution = half_resolution
self.save_predictions = save_predictions
self.batch_size = batch_size
self.checked_inputs = False
self.test_time_aug = test_time_aug
self.input_shape = input_shape
self.suffix = suffix
if self.output_dir != "":
os.makedirs(self.output_dir, exist_ok=True)
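A sketch of constructing Nimbus against the new dataset-based signature. The MultiplexDataset keyword arguments are inferred from this diff and may not match the actual API; paths are hypothetical, and naming_fn continues the earlier sketch.

from nimbus_inference.nimbus import Nimbus
from nimbus_inference.utils import MultiplexDataset

# Hypothetical arguments; the real MultiplexDataset signature may differ.
dataset = MultiplexDataset(
    fov_paths=["/data/fovs/fov0.ome.tiff"],
    segmentation_naming_convention=naming_fn,  # from the sketch above
    suffix=".ome.tiff",
)
nimbus = Nimbus(
    dataset=dataset,
    output_dir="/data/nimbus_output",
    batch_size=4,
    device="auto",
)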

@@ -127,25 +118,9 @@ def __init__(

def check_inputs(self):
"""check inputs for Nimbus model"""
# check if all paths in fov_paths exists
if not isinstance(self.fov_paths, (list, tuple)):
self.fov_paths = [self.fov_paths]
io_utils.validate_paths(self.fov_paths)

# check if segmentation_naming_convention returns valid paths
path_to_segmentation = self.segmentation_naming_convention(self.fov_paths[0])
if not os.path.exists(path_to_segmentation):
raise FileNotFoundError(
"Function segmentation_naming_convention does not return valid\
path. Segmentation path {} does not exist.".format(
path_to_segmentation
)
)
# check if output_dir exists
io_utils.validate_paths([self.output_dir])

if isinstance(self.include_channels, str):
self.include_channels = [self.include_channels]
self.checked_inputs = True
print("All inputs are valid.")

@@ -161,14 +136,14 @@ def initialize_model(self, padding="reflect"):
self.checkpoint_path = os.path.join(
path,
"assets",
"resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt"
"resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32_cw_0.8.pt"
)
if not os.path.exists(self.checkpoint_path):
local_dir = os.path.join(path, "assets")
print("Downloading weights from Hugging Face Hub...")
self.checkpoint_path = hf_hub_download(
repo_id="JLrumberger/Nimbus-Inference",
filename="resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt",
filename="resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32_cw_0.8.pt",
local_dir=local_dir,
local_dir_use_symlinks=False,
)
@@ -192,13 +167,11 @@ def prepare_normalization_dict(
if os.path.exists(self.normalization_dict_path) and not overwrite:
self.normalization_dict = json.load(open(self.normalization_dict_path))
else:

n_jobs = os.cpu_count() if multiprocessing else 1
self.normalization_dict = prepare_normalization_dict(
self.fov_paths, self.output_dir, quantile, self.include_channels, n_subset, n_jobs
self.dataset, self.output_dir, quantile, n_subset,
n_jobs
)
if self.include_channels == []:
self.include_channels = list(self.normalization_dict.keys())
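For reference, a sketch of invoking the normalization step on the object constructed above. The names overwrite and multiprocessing are read off this hunk and assumed to be keyword parameters of the method; defaults are not shown in the diff.

# Re-uses the cached dict at self.normalization_dict_path unless overwrite=True;
# multiprocessing=True spreads the quantile computation over os.cpu_count() workers.
nimbus.prepare_normalization_dict(multiprocessing=True, overwrite=False)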

def predict_fovs(self):
"""Predicts cell classification for input data.
@@ -214,24 +187,12 @@ def predict_fovs(self):
print("Available GPUs: ", gpus)
print("Predictions will be saved in {}".format(self.output_dir))
print("Iterating through fovs will take a while...")
if self.suffix.lower() in [".ome.tif", ".ome.tiff"]:
self.cell_table = predict_ome_fovs(
nimbus=self, fov_paths=self.fov_paths, output_dir=self.output_dir,
normalization_dict=self.normalization_dict,
segmentation_naming_convention=self.segmentation_naming_convention,
include_channels=self.include_channels, save_predictions=self.save_predictions,
half_resolution=self.half_resolution, batch_size=self.batch_size,
test_time_augmentation=self.test_time_aug, suffix=self.suffix,
)
elif self.suffix.lower() in [".tiff", ".tif", ".jpg", ".jpeg", ".png"]:
self.cell_table = predict_fovs(
nimbus=self, fov_paths=self.fov_paths, output_dir=self.output_dir,
normalization_dict=self.normalization_dict,
segmentation_naming_convention=self.segmentation_naming_convention,
include_channels=self.include_channels, save_predictions=self.save_predictions,
half_resolution=self.half_resolution, batch_size=self.batch_size,
test_time_augmentation=self.test_time_aug, suffix=self.suffix,
)
self.cell_table = predict_fovs(
nimbus=self, dataset=self.dataset, output_dir=self.output_dir,
normalization_dict=self.normalization_dict, save_predictions=self.save_predictions,
half_resolution=self.half_resolution, batch_size=self.batch_size,
test_time_augmentation=self.test_time_aug, suffix=self.dataset.suffix,
)
self.cell_table.to_csv(os.path.join(self.output_dir, "nimbus_cell_table.csv"), index=False)
return self.cell_table
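Putting the pieces together, a minimal end-to-end sketch of the OME-TIFF inference path added here, re-using the hypothetical dataset, nimbus, and naming_fn objects from the sketches above.

nimbus.prepare_normalization_dict()
cell_table = nimbus.predict_fovs()
# cell_table is returned and also written to <output_dir>/nimbus_cell_table.csv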

