Skip to content

Commit

Permalink
Merge pull request #5 from angelolab/mps-support
Browse files Browse the repository at this point in the history
MPS backend support and Dependency fix
  • Loading branch information
JLrumberger authored Mar 5, 2024
2 parents 3d0ae97 + 0f929b7 commit 5fc102e
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
# Ensure that if any job fails, all jobs are cancelled.
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11"]
os: [ubuntu-latest, macos-latest, windows-latest]

steps:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ urls.Documentation = "https://Nimbus-Inference.readthedocs.io/"
urls.Source = "https://github.com/angelolab/Nimbus-Inference"
urls.Home-page = "https://github.com/angelolab/Nimbus-Inference"
dependencies = [
"ark-analysis",
"torch==2.2.0",
"torchvision==0.17.0",
"alpineer",
Expand Down
20 changes: 16 additions & 4 deletions src/nimbus_inference/nimbus.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from alpineer import io_utils
from alpineer import io_utils, misc_utils
from skimage.util.shape import view_as_windows
import nimbus_inference
from nimbus_inference.utils import (
Expand Down Expand Up @@ -81,7 +81,7 @@ class Nimbus(nn.Module):
def __init__(
self, fov_paths, segmentation_naming_convention, output_dir, save_predictions=True,
include_channels=[], half_resolution=True, batch_size=4, test_time_aug=True,
input_shape=[1024, 1024], suffix=".tiff",
input_shape=[1024, 1024], suffix=".tiff", device="auto",
):
"""Initializes a Nimbus Application.
Args:
Expand All @@ -96,6 +96,8 @@ def __init__(
test_time_aug (bool): Whether to use test time augmentation.
input_shape (list): Shape of input images.
suffix (str): Suffix of images to load.
device (str): Device to run model on, either "auto" (either "mps" or "cuda"
, with "cpu" as a fallback), "cpu", "cuda", or "mps". Defaults to "auto".
"""
super(Nimbus, self).__init__()
self.fov_paths = fov_paths
Expand All @@ -111,7 +113,17 @@ def __init__(
self.suffix = suffix
if self.output_dir != "":
os.makedirs(self.output_dir, exist_ok=True)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device == "auto":
if torch.backends.mps.is_available():
self.device = torch.device("mps")
elif torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
else:
misc_utils.verify_in_list(device=[device], valid_devices=["cpu", "cuda", "mps"])
self.device = torch.device(device)

def check_inputs(self):
"""check inputs for Nimbus model"""
Expand Down Expand Up @@ -313,7 +325,7 @@ def _tile_input(self, image, tile_size, output_shape, pad_mode="reflect"):
image = np.pad(image, ((0, 0), (0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1)), mode=pad_mode)
b, c = image.shape[:2]
# tile image
view = np.squeeze(
view = np.squeeze(
view_as_windows(image, [b, c] + list(tile_size), step=[b, c] + list(output_shape)),
axis=(0,1)
)
Expand Down
5 changes: 3 additions & 2 deletions templates/1_Nimbus_Predict.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"source": [
"# set up the base directory\n",
"base_dir = os.path.normpath(\"../data/example_dataset\")\n",
"base_dir = os.path.normpath(\"C:/Users/lorenz/Desktop/angelo_lab/data/example_dataset\")"
"# base_dir = os.path.normpath(\"C:/Users/lorenz/Desktop/angelo_lab/data/example_dataset\")"
]
},
{
Expand Down Expand Up @@ -198,6 +198,7 @@
" test_time_aug=True,\n",
" input_shape=[1024,1024],\n",
" suffix=\".tiff\",\n",
" device=\"auto\",\n",
")\n",
"\n",
"# check if all inputs are valid\n",
Expand Down Expand Up @@ -238,7 +239,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"id": "76225704",
"metadata": {
"scrolled": true
Expand Down

0 comments on commit 5fc102e

Please sign in to comment.