Sc 45 instanseg #85

Merged · 8 commits · Jan 22, 2025
1 change: 1 addition & 0 deletions docs/api.md
@@ -57,6 +57,7 @@ Operations on image and labels layers.
im.segment
im.segment_points
im.cellpose_callable
im.instanseg_callable
im.add_grid_labels_layer
im.expand_labels_layer
im.align_labels_layers
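For context, `im.instanseg_callable` is registered next to `im.cellpose_callable` as a model that can be handed to `im.segment`. A minimal usage sketch, with layer names and keyword arguments lifted from the test added in this PR; it assumes an existing `SpatialData` object `sdata` with a multichannel image layer named "combine", and that extra keyword arguments such as `instanseg_model` and `output` are forwarded to the callable:

```python
import harpy
from instanseg import InstanSeg  # provided by the instanseg-torch package

# Assumed: `sdata` is an existing spatialdata.SpatialData with an image layer "combine".
instanseg_model = InstanSeg("fluorescence_nuclei_and_cells", device="cpu")
sdata = harpy.im.segment(
    sdata,
    img_layer="combine",
    model=harpy.im.instanseg_callable,
    output_labels_layer=["labels_nuclei_instanseg", "labels_cells_instanseg"],
    output_shapes_layer=["shapes_nuclei_instanseg", "shapes_cells_instanseg"],
    instanseg_model=instanseg_model,  # forwarded to instanseg_callable (assumption)
    output="all_outputs",  # request both nuclei and whole-cell masks
)
```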
963 changes: 193 additions & 770 deletions docs/tutorials/advanced/Harpy_instanseg.ipynb

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions setup.cfg
@@ -121,9 +121,7 @@ clustering =
 #napari-convpaint @ git+ssh://git@github.com/guiwitz/napari-convpaint.git

 instanseg =
-    instanseg
-    torchvision
-    monai
+    instanseg-torch

[options.package_data]
* = *.yaml
39 changes: 39 additions & 0 deletions src/harpy/_tests/test_image/test_segmentation/test_segmentation.py
@@ -1,11 +1,13 @@
import importlib.util

import dask.array as da
import dask.dataframe as dd
import pandas as pd
import pytest
from dask.dataframe import DataFrame
from spatialdata import SpatialData

from harpy.image._image import _get_spatial_element
from harpy.image.segmentation._segmentation import segment, segment_points
from harpy.image.segmentation.segmentation_models._baysor import _dummy
from harpy.image.segmentation.segmentation_models._cellpose import cellpose_callable
@@ -124,3 +126,40 @@ def test_segment_points(sdata_multi_c_no_backed: SpatialData):
        assert _output_labels_layer in sdata_multi_c_no_backed.labels
    for _output_shapes_layer in output_shapes_layer:
        assert _output_shapes_layer in sdata_multi_c_no_backed.shapes


@pytest.mark.skipif(not importlib.util.find_spec("instanseg"), reason="requires the instanseg library")
def test_segment_instanseg(sdata_multi_c_no_backed: SpatialData):
    from instanseg import InstanSeg

    from harpy.image.segmentation.segmentation_models._instanseg import instanseg_callable

    instanseg_fluorescence = InstanSeg("fluorescence_nuclei_and_cells", verbosity=1, device="cpu")

    output_labels_layer = ["labels_nuclei_instanseg", "labels_cells_instanseg"]
    output_shapes_layer = ["shapes_nuclei_instanseg", "shapes_cells_instanseg"]
    sdata_multi_c_no_backed = segment(
        sdata_multi_c_no_backed,
        img_layer="combine",
        model=instanseg_callable,
        output_labels_layer=output_labels_layer,
        output_shapes_layer=output_shapes_layer,
        labels_layer_align="labels_cells_instanseg",
        trim=False,
        chunks=50,
        overwrite=True,
        depth=30,
        crd=[10, 110, 0, 100],
        scale_factors=[2, 2, 2, 2],
        instanseg_model=instanseg_fluorescence,
        output="all_outputs",
    )

    for _output_labels_layer in output_labels_layer:
        assert _output_labels_layer in sdata_multi_c_no_backed.labels
    for _output_shapes_layer in output_shapes_layer:
        assert _output_shapes_layer in sdata_multi_c_no_backed.shapes

    for _output_labels_layer in output_labels_layer:
        se = _get_spatial_element(sdata_multi_c_no_backed, layer=_output_labels_layer)
        assert da.any(se.data).compute()
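For orientation, the model call that `instanseg_callable` wraps looks roughly like the sketch below, based on the public `instanseg-torch` API (`InstanSeg.eval_small_image` with a `pixel_size` argument); the input shape and the `pixel_size` value here are illustrative:

```python
import numpy as np
from instanseg import InstanSeg

model = InstanSeg("fluorescence_nuclei_and_cells", verbosity=1, device="cpu")
# Dummy (channels, y, x) fluorescence image; real data would come from the image layer.
img = np.random.rand(2, 256, 256).astype(np.float32)
# Returns a labeled tensor; for this model it can hold both nuclei and whole-cell masks.
labeled_output, _ = model.eval_small_image(img, pixel_size=0.5)
```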
6 changes: 2 additions & 4 deletions src/harpy/_tests/test_notebooks.py
@@ -116,10 +116,8 @@ def test_notebooks_flowsom(notebook):

 @pytest.mark.skip
 @pytest.mark.skipif(
-    not importlib.util.find_spec("InstanSeg")
-    or not importlib.util.find_spec("monai")
-    or not importlib.util.find_spec("torchvision"),
-    reason="requires the InstanSeg, monai and torchvision libraries",
+    not importlib.util.find_spec("instanseg"),
+    reason="requires the instanseg library",
 )
 @pytest.mark.parametrize(
     "notebook",
1 change: 1 addition & 0 deletions src/harpy/image/__init__.py
@@ -23,3 +23,4 @@
)
from .segmentation._segmentation import segment, segment_points
from .segmentation.segmentation_models._cellpose import cellpose_callable
from .segmentation.segmentation_models._instanseg import instanseg_callable
107 changes: 0 additions & 107 deletions src/harpy/image/segmentation/_align_masks.py
@@ -1,24 +1,11 @@
from __future__ import annotations

from typing import Any

import dask.array as da
import numpy as np
from dask.array import Array
from numpy.typing import NDArray
from spatialdata import SpatialData
from spatialdata.models.models import ScaleFactors_t

from harpy.image.segmentation._map import map_labels
from harpy.image.segmentation._utils import (
    _SEG_DTYPE,
    _add_depth_to_chunks_size,
    _check_boundary,
    _clean_up_masks,
    _merge_masks,
    _rechunk_overlap,
    _substract_depth_from_chunks_size,
)
from harpy.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
@@ -117,100 +104,6 @@ def align_labels_layers(
    return sdata


def _align_dask_arrays(
    x_label_1: Array,
    x_label_2: Array,
    **kwargs: Any,  # keyword arguments to be passed to map_overlap/map_blocks
):
    # we will align labels of x_label_1 with labels of x_label_2.

    assert x_label_1.shape == x_label_2.shape, "Only arrays with same shape are currently supported."

    chunks = kwargs.pop("chunks", None)
    depth = kwargs.pop("depth", 100)
    boundary = kwargs.pop("boundary", "reflect")

    if isinstance(depth, int):
        depth = {0: 0, 1: depth, 2: depth}
    else:
        assert len(depth) == x_label_1.ndim, f"Please provide depth for each dimension ({x_label_1.ndim})."
        if x_label_1.ndim == 2:
            depth = {0: 0, 1: depth[0], 2: depth[1]}

    assert depth[0] == 0, "Depth not equal to 0 for 'z' dimension is not supported"

    if chunks is None:
        assert (
            x_label_1.chunksize == x_label_2.chunksize
        ), "If chunks is not specified, please ensure Dask arrays have the same chunksize."

    _check_boundary(boundary)

    _to_squeeze = False
    if x_label_1.ndim == 2:
        _to_squeeze = True
        x_label_1 = x_label_1[None, ...]
        x_label_2 = x_label_2[None, ...]

    # rechunk so that we ensure minimum chunksize, in order to control output_chunks sizes.
    x_label_1 = _rechunk_overlap(x_label_1, depth=depth, chunks=chunks)
    x_label_2 = _rechunk_overlap(x_label_2, depth=depth, chunks=chunks)

    assert (
        x_label_1.numblocks[0] == 1
    ), f"Expected the number of blocks in the Z-dimension to be `1`, found `{x_label_1.numblocks[0]}`."
    assert (
        x_label_2.numblocks[0] == 1
    ), f"Expected the number of blocks in the Z-dimension to be `1`, found `{x_label_2.numblocks[0]}`."

    # output_chunks can be derived from either x_label_1 or x_label_2
    output_chunks = _add_depth_to_chunks_size(x_label_1.chunks, depth)

    x_labels = da.map_overlap(
        lambda m, f: _relabel_array_1_to_array_2_per_chunk(m, f),
        x_label_1,
        x_label_2,
        dtype=_SEG_DTYPE,
        allow_rechunk=False,  # already dealt with correcting for the case where depth > chunksize
        chunks=output_chunks,  # e.g. ((1024+60, 1024+60, 452+60), (1024+60, 1024+60, 452+60)),
        depth=depth,
        trim=False,
        boundary="reflect",
        # this reflect is useless for this use case, but _clean_up_masks and _merge_masks only support
        # results from map_overlap generated with "reflect", "nearest" and "constant"
    )

    x_labels = da.map_blocks(
        _clean_up_masks,
        x_labels,
        dtype=_SEG_DTYPE,
        depth=depth,
    )

    output_chunks = _substract_depth_from_chunks_size(x_labels.chunks, depth=depth)

    x_labels = da.map_overlap(
        _merge_masks,
        x_labels,
        dtype=_SEG_DTYPE,
        num_blocks=x_labels.numblocks,
        trim=False,
        allow_rechunk=False,  # already dealt with correcting for the case where depth > chunksize
        chunks=output_chunks,  # e.g. ((7,), (1024, 1024, 452), (1024, 1024, 452), (1,)),
        depth=depth,
        boundary="reflect",
        _depth=depth,
    )

    x_labels = x_labels.rechunk(x_labels.chunksize)

    # squeeze if a trivial dimension was added.
    if _to_squeeze:
        x_labels = x_labels.squeeze(0)

    return x_labels


def _relabel_array_1_to_array_2_per_chunk(array_1: NDArray, array_2: NDArray) -> NDArray:
    assert array_1.shape == array_2.shape

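The hunk above removes the duplicated chunk-level alignment pipeline from this module; `align_labels_layers` now routes through `map_labels` instead. For readers new to the pattern, a per-chunk relabeling step like `_relabel_array_1_to_array_2_per_chunk` typically assigns each label in the first array the maximally overlapping label of the second; the following is a hedged sketch of that idea, not necessarily Harpy's exact implementation:

```python
import numpy as np
from numpy.typing import NDArray


def relabel_by_max_overlap(array_1: NDArray, array_2: NDArray) -> NDArray:
    # For every non-background label in array_1, pick the array_2 label
    # it overlaps most (0 is treated as background).
    out = np.zeros_like(array_2)
    for label in np.unique(array_1):
        if label == 0:
            continue
        mask = array_1 == label
        candidates, counts = np.unique(array_2[mask], return_counts=True)
        out[mask] = candidates[np.argmax(counts)]
    return out
```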
43 changes: 37 additions & 6 deletions src/harpy/image/segmentation/_map.py
@@ -1,6 +1,10 @@
from __future__ import annotations

import os
import shutil
import uuid
from collections.abc import Callable, Iterable, Mapping
from pathlib import Path
from types import MappingProxyType
from typing import Any

@@ -11,6 +15,7 @@
from spatialdata import SpatialData
from spatialdata.models.models import ScaleFactors_t
from spatialdata.transformations import Translation, get_transformation
from upath import UPath

from harpy.image._image import (
    _get_spatial_element,
@@ -154,9 +159,9 @@ def _get_layers(sdata: SpatialData, labels_layers: list[str]) -> tuple[list[Arra
)

     # Ensure the translation is the same as the first label layer
-    assert (
-        transformations == first_transformations
-    ), f"Provided labels layers '{labels_layers}' should all have the same transformations defined on them."
+    assert transformations == first_transformations, (
+        f"Provided labels layers '{labels_layers}' should all have the same transformations defined on them."
+    )

     labels_data.append(x_label)

Expand All @@ -171,13 +176,19 @@ def _get_layers(sdata: SpatialData, labels_layers: list[str]) -> tuple[list[Arra
kwargs.setdefault("iou_depth", iou_depth)
kwargs.setdefault("iou_threshold", iou_threshold)

if sdata.is_backed():
_temp_path = UPath(sdata.path).parent / f"tmp_{uuid.uuid4()}"
else:
_temp_path = None

# labels_arrays is a list of dask arrays
# do some processing on the labels
array = _combine_dask_arrays(
labels_arrays,
relabel_chunks=relabel_chunks,
trim=trim,
func=func,
temp_path=_temp_path,
fn_kwargs=fn_kwargs,
**kwargs,
)
@@ -205,6 +216,10 @@ def _get_layers(sdata: SpatialData, labels_layers: list[str]) -> tuple[list[Arra
        overwrite=overwrite,
    )

    if _temp_path is not None:
        # TODO this will not work if sdata is remote (e.g. s3 bucket).
        shutil.rmtree(_temp_path)

    return sdata


@@ -213,6 +228,7 @@ def _combine_dask_arrays(
    relabel_chunks: bool,
    trim: bool,
    func: Callable[..., NDArray],
    temp_path: str | Path,
    fn_kwargs: Mapping[str, Any] = MappingProxyType({}),  # keyword arguments to be passed to func
    **kwargs: Any,  # keyword arguments to be passed to map_overlap/map_blocks
) -> Array:
@@ -268,9 +284,9 @@ def _fix_depth(_depth):
     for i, x_label in enumerate(_labels_arrays):
         # rechunk so that we ensure minimum chunksize, in order to control output_chunks sizes.
         x_label = _rechunk_overlap(x_label, depth=depth, chunks=chunks)
-        assert (
-            x_label.numblocks[0] == 1
-        ), f"Expected the number of blocks in the Z-dimension to be `1`, found `{x_label.numblocks[0]}`."
+        assert x_label.numblocks[0] == 1, (
+            f"Expected the number of blocks in the Z-dimension to be `1`, found `{x_label.numblocks[0]}`."
+        )

         if i == 0:
             # output_chunks can be derived from any rechunked x_label in labels_arrays
@@ -307,6 +323,21 @@ def _fix_depth(_depth):
    # return x_labels.squeeze(0)

    if not trim:
        # write to an intermediate zarr store if sdata is backed, to reduce RAM usage.
        if temp_path is not None:
            zarr_path = os.path.join(temp_path, f"labels_{uuid.uuid4()}.zarr")
            _chunks = x_labels.chunks
            x_labels.rechunk(x_labels.chunksize).to_zarr(
                zarr_path,
                overwrite=False,
            )
            x_labels = da.from_zarr(zarr_path)
            x_labels = x_labels.rechunk(_chunks)
        else:
            x_labels = x_labels.persist()

        log.info("Linking labels across chunks.")

        iou_depth = da.overlap.coerce_depth(len(depth), iou_depth)

        if any(iou_depth[ax] > depth[ax] for ax in depth.keys()):
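The new `if not trim:` branch trades memory for disk: when the SpatialData object is backed, the intermediate label array is materialized into a temporary zarr store next to it and re-read lazily, instead of being kept in RAM via `persist()`. A self-contained sketch of that round-trip pattern (store location and array contents are illustrative):

```python
import os
import tempfile

import dask.array as da


def roundtrip_via_zarr(x: da.Array, store_dir: str) -> da.Array:
    """Compute `x` into a zarr store and return a lazy array backed by it."""
    zarr_path = os.path.join(store_dir, "labels_tmp.zarr")
    original_chunks = x.chunks
    # Writing cuts the task graph: upstream tasks run once and results land on disk.
    # `to_zarr` needs regular chunks, hence the rechunk to `chunksize` first.
    x.rechunk(x.chunksize).to_zarr(zarr_path, overwrite=False)
    # Re-open lazily and restore the original (possibly irregular) chunking.
    return da.from_zarr(zarr_path).rechunk(original_chunks)


x = da.random.randint(0, 5, size=(1024, 1024), chunks=(256, 256))
with tempfile.TemporaryDirectory() as tmp:
    y = roundtrip_via_zarr(x, tmp)
    print(y.sum().compute())
```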