Commit

Sc 45 instanseg (#85)
* SC_45 instanseg

* SC_45 instanseg

* SC_45 instanseg intermediate

* SC_45 instanseg

* SC_45 fix iou=False bug when there are multiple output masks + remove docs

* SC_45 docs

* SC_45 map labels intermediate zarr store to reduce RAM

* SC_45 instanseg
ArneDefauw authored Jan 22, 2025
1 parent 96f425b commit 20e5ac9
Showing 12 changed files with 377 additions and 957 deletions.
1 change: 1 addition & 0 deletions docs/api.md
@@ -57,6 +57,7 @@ Operations on image and labels layers.
im.segment
im.segment_points
im.cellpose_callable
im.instanseg_callable
im.add_grid_labels_layer
im.expand_labels_layer
im.align_labels_layers
963 changes: 193 additions & 770 deletions docs/tutorials/advanced/Harpy_instanseg.ipynb

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions setup.cfg
@@ -121,9 +121,7 @@ clustering =
     #napari-convpaint @ git+ssh://git@github.com/guiwitz/napari-convpaint.git

 instanseg =
-    instanseg
-    torchvision
-    monai
+    instanseg-torch

 [options.package_data]
 * = *.yaml
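
The instanseg extra now resolves to the single instanseg-torch distribution, whose importable module is named instanseg (the updated test skips below probe for exactly that name). A minimal sketch of the optional-import guard this implies; the fallback behaviour is an assumption, not code from this commit:

import importlib.util

# The PyPI distribution is `instanseg-torch`; the module it installs is `instanseg`.
if importlib.util.find_spec("instanseg") is not None:
    from instanseg import InstanSeg  # optional dependency is available
else:
    InstanSeg = None  # callers should skip InstanSeg-specific code paths
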
39 changes: 39 additions & 0 deletions src/harpy/_tests/test_image/test_segmentation/test_segmentation.py
@@ -1,11 +1,13 @@
import importlib.util

import dask.array as da
import dask.dataframe as dd
import pandas as pd
import pytest
from dask.dataframe import DataFrame
from spatialdata import SpatialData

from harpy.image._image import _get_spatial_element
from harpy.image.segmentation._segmentation import segment, segment_points
from harpy.image.segmentation.segmentation_models._baysor import _dummy
from harpy.image.segmentation.segmentation_models._cellpose import cellpose_callable
@@ -124,3 +126,40 @@ def test_segment_points(sdata_multi_c_no_backed: SpatialData):
        assert _output_labels_layer in sdata_multi_c_no_backed.labels
    for _output_shapes_layer in output_shapes_layer:
        assert _output_shapes_layer in sdata_multi_c_no_backed.shapes


@pytest.mark.skipif(not importlib.util.find_spec("instanseg"), reason="requires the instanseg library")
def test_segment_instanseg(sdata_multi_c_no_backed: SpatialData):
    from instanseg import InstanSeg

    from harpy.image.segmentation.segmentation_models._instanseg import instanseg_callable

    instanseg_fluorescence = InstanSeg("fluorescence_nuclei_and_cells", verbosity=1, device="cpu")

    output_labels_layer = ["labels_nuclei_instanseg", "labels_cells_instanseg"]
    output_shapes_layer = ["shapes_nuclei_instanseg", "shapes_cells_instanseg"]
    sdata_multi_c_no_backed = segment(
        sdata_multi_c_no_backed,
        img_layer="combine",
        model=instanseg_callable,
        output_labels_layer=output_labels_layer,
        output_shapes_layer=output_shapes_layer,
        labels_layer_align="labels_cells_instanseg",
        trim=False,
        chunks=50,
        overwrite=True,
        depth=30,
        crd=[10, 110, 0, 100],
        scale_factors=[2, 2, 2, 2],
        instanseg_model=instanseg_fluorescence,
        output="all_outputs",
    )

    for _output_labels_layer in output_labels_layer:
        assert _output_labels_layer in sdata_multi_c_no_backed.labels
    for _output_shapes_layer in output_shapes_layer:
        assert _output_shapes_layer in sdata_multi_c_no_backed.shapes

    for _output_labels_layer in output_labels_layer:
        se = _get_spatial_element(sdata_multi_c_no_backed, layer=_output_labels_layer)
        assert da.any(se.data).compute()
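
Outside the test suite, the same call pattern applies directly to a SpatialData object. A minimal sketch mirroring the test above; the dataset path, image layer name, and the chunks/depth values are illustrative assumptions, not values prescribed by this commit:

from instanseg import InstanSeg
from spatialdata import read_zarr

from harpy.image import instanseg_callable, segment

sdata = read_zarr("my_experiment.zarr")  # hypothetical dataset

# Pretrained InstanSeg model that predicts a nuclei mask and a cell mask.
instanseg_model = InstanSeg("fluorescence_nuclei_and_cells", verbosity=1, device="cpu")

sdata = segment(
    sdata,
    img_layer="combine",
    model=instanseg_callable,
    output_labels_layer=["labels_nuclei_instanseg", "labels_cells_instanseg"],
    output_shapes_layer=["shapes_nuclei_instanseg", "shapes_cells_instanseg"],
    labels_layer_align="labels_cells_instanseg",
    chunks=1024,
    depth=100,
    instanseg_model=instanseg_model,
    output="all_outputs",
    overwrite=True,
)

As in the test above, output="all_outputs" requests both masks (one labels layer each), and labels_layer_align selects the labels layer the other masks are aligned to.
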
6 changes: 2 additions & 4 deletions src/harpy/_tests/test_notebooks.py
@@ -116,10 +116,8 @@ def test_notebooks_flowsom(notebook):

 @pytest.mark.skip
 @pytest.mark.skipif(
-    not importlib.util.find_spec("InstanSeg")
-    or not importlib.util.find_spec("monai")
-    or not importlib.util.find_spec("torchvision"),
-    reason="requires the InstanSeg, monai and torchvision libraries",
+    not importlib.util.find_spec("instanseg"),
+    reason="requires the instanseg library",
 )
 @pytest.mark.parametrize(
     "notebook",
1 change: 1 addition & 0 deletions src/harpy/image/__init__.py
@@ -23,3 +23,4 @@
)
from .segmentation._segmentation import segment, segment_points
from .segmentation.segmentation_models._cellpose import cellpose_callable
from .segmentation.segmentation_models._instanseg import instanseg_callable
107 changes: 0 additions & 107 deletions src/harpy/image/segmentation/_align_masks.py
@@ -1,24 +1,11 @@
from __future__ import annotations

from typing import Any

import dask.array as da
import numpy as np
from dask.array import Array
from numpy.typing import NDArray
from spatialdata import SpatialData
from spatialdata.models.models import ScaleFactors_t

from harpy.image.segmentation._map import map_labels
from harpy.image.segmentation._utils import (
    _SEG_DTYPE,
    _add_depth_to_chunks_size,
    _check_boundary,
    _clean_up_masks,
    _merge_masks,
    _rechunk_overlap,
    _substract_depth_from_chunks_size,
)
from harpy.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
@@ -117,100 +104,6 @@ def align_labels_layers(
    return sdata


def _align_dask_arrays(
    x_label_1: Array,
    x_label_2: Array,
    **kwargs: Any,  # keyword arguments to be passed to map_overlap/map_blocks
):
    # We will align the labels of x_label_1 with the labels of x_label_2.

    assert x_label_1.shape == x_label_2.shape, "Only arrays with same shape are currently supported."

    chunks = kwargs.pop("chunks", None)
    depth = kwargs.pop("depth", 100)
    boundary = kwargs.pop("boundary", "reflect")

    if isinstance(depth, int):
        depth = {0: 0, 1: depth, 2: depth}
    else:
        assert len(depth) == x_label_1.ndim, f"Please provide depth for each dimension ({x_label_1.ndim})."
        if x_label_1.ndim == 2:
            depth = {0: 0, 1: depth[0], 2: depth[1]}

    assert depth[0] == 0, "Depth not equal to 0 for 'z' dimension is not supported"

    if chunks is None:
        assert (
            x_label_1.chunksize == x_label_2.chunksize
        ), "If chunks is not specified, please ensure Dask arrays have the same chunksize."

    _check_boundary(boundary)

    _to_squeeze = False
    if x_label_1.ndim == 2:
        _to_squeeze = True
        x_label_1 = x_label_1[None, ...]
        x_label_2 = x_label_2[None, ...]

    # rechunk so that we ensure a minimum chunksize, in order to control output_chunks sizes
    x_label_1 = _rechunk_overlap(x_label_1, depth=depth, chunks=chunks)
    x_label_2 = _rechunk_overlap(x_label_2, depth=depth, chunks=chunks)

    assert (
        x_label_1.numblocks[0] == 1
    ), f"Expected the number of blocks in the Z-dimension to be `1`, found `{x_label_1.numblocks[0]}`."
    assert (
        x_label_2.numblocks[0] == 1
    ), f"Expected the number of blocks in the Z-dimension to be `1`, found `{x_label_2.numblocks[0]}`."

    # output_chunks can be derived from either x_label_1 or x_label_2
    output_chunks = _add_depth_to_chunks_size(x_label_1.chunks, depth)

    x_labels = da.map_overlap(
        lambda m, f: _relabel_array_1_to_array_2_per_chunk(m, f),
        x_label_1,
        x_label_2,
        dtype=_SEG_DTYPE,
        allow_rechunk=False,  # already dealt with correcting for the case where depth > chunksize
        chunks=output_chunks,  # e.g. ((1024+60, 1024+60, 452+60), (1024+60, 1024+60, 452+60)),
        depth=depth,
        trim=False,
        boundary="reflect",
        # This "reflect" is not needed for this use case, but _clean_up_masks and _merge_masks only support
        # results from map_overlap generated with "reflect", "nearest" and "constant".
    )

    x_labels = da.map_blocks(
        _clean_up_masks,
        x_labels,
        dtype=_SEG_DTYPE,
        depth=depth,
    )

    output_chunks = _substract_depth_from_chunks_size(x_labels.chunks, depth=depth)

    x_labels = da.map_overlap(
        _merge_masks,
        x_labels,
        dtype=_SEG_DTYPE,
        num_blocks=x_labels.numblocks,
        trim=False,
        allow_rechunk=False,  # already dealt with correcting for the case where depth > chunksize
        chunks=output_chunks,  # e.g. ((7,), (1024, 1024, 452), (1024, 1024, 452), (1,)),
        depth=depth,
        boundary="reflect",
        _depth=depth,
    )

    x_labels = x_labels.rechunk(x_labels.chunksize)

    # squeeze if a trivial dimension was added.
    if _to_squeeze:
        x_labels = x_labels.squeeze(0)

    return x_labels


def _relabel_array_1_to_array_2_per_chunk(array_1: NDArray, array_2: NDArray) -> NDArray:
    assert array_1.shape == array_2.shape

43 changes: 37 additions & 6 deletions src/harpy/image/segmentation/_map.py
@@ -1,6 +1,10 @@
from __future__ import annotations

import os
import shutil
import uuid
from collections.abc import Callable, Iterable, Mapping
from pathlib import Path
from types import MappingProxyType
from typing import Any

@@ -11,6 +15,7 @@
from spatialdata import SpatialData
from spatialdata.models.models import ScaleFactors_t
from spatialdata.transformations import Translation, get_transformation
from upath import UPath

from harpy.image._image import (
    _get_spatial_element,
@@ -154,9 +159,9 @@ def _get_layers(sdata: SpatialData, labels_layers: list[str]) -> tuple[list[Arra
)

         # Ensure the translation is the same as the first label layer
-        assert (
-            transformations == first_transformations
-        ), f"Provided labels layers '{labels_layers}' should all have the same transformations defined on them."
+        assert transformations == first_transformations, (
+            f"Provided labels layers '{labels_layers}' should all have the same transformations defined on them."
+        )

         labels_data.append(x_label)

@@ -171,13 +176,19 @@ def _get_layers(sdata: SpatialData, labels_layers: list[str]) -> tuple[list[Arra
kwargs.setdefault("iou_depth", iou_depth)
kwargs.setdefault("iou_threshold", iou_threshold)

if sdata.is_backed():
_temp_path = UPath(sdata.path).parent / f"tmp_{uuid.uuid4()}"
else:
_temp_path = None

# labels_arrays is a list of dask arrays
# do some processing on the labels
array = _combine_dask_arrays(
labels_arrays,
relabel_chunks=relabel_chunks,
trim=trim,
func=func,
temp_path=_temp_path,
fn_kwargs=fn_kwargs,
**kwargs,
)
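
The bookkeeping above amounts to: when sdata is backed, create a uniquely named scratch directory next to its zarr store, pass it to _combine_dask_arrays, and remove it afterwards (see the cleanup hunk below). A standalone sketch of that lifecycle, assuming a local, non-remote store; the helper name is hypothetical:

import shutil
import uuid
from contextlib import contextmanager

from upath import UPath

@contextmanager
def _scratch_dir_next_to(store_path: str):
    # Unique scratch path as a sibling of the backed zarr store
    # (created lazily when the first intermediate store is written under it).
    temp_path = UPath(store_path).parent / f"tmp_{uuid.uuid4()}"
    try:
        yield temp_path  # intermediate zarr stores get written under here
    finally:
        # shutil.rmtree assumes a local filesystem; a remote store (e.g. an
        # s3 bucket) would need different cleanup, as the TODO below notes.
        shutil.rmtree(temp_path, ignore_errors=True)
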
@@ -205,6 +216,10 @@ def _get_layers(sdata: SpatialData, labels_layers: list[str]) -> tuple[list[Arra
        overwrite=overwrite,
    )

    if _temp_path is not None:
        # TODO this will not work if sdata is remote (e.g. s3 bucket).
        shutil.rmtree(_temp_path)

    return sdata


@@ -213,6 +228,7 @@ def _combine_dask_arrays(
    relabel_chunks: bool,
    trim: bool,
    func: Callable[..., NDArray],
    temp_path: str | Path | None,
    fn_kwargs: Mapping[str, Any] = MappingProxyType({}),  # keyword arguments to be passed to func
    **kwargs: Any,  # keyword arguments to be passed to map_overlap/map_blocks
) -> Array:
@@ -268,9 +284,9 @@ def _fix_depth(_depth):
     for i, x_label in enumerate(_labels_arrays):
         # rechunk so that we ensure minimum chunksize, in order to control output_chunks sizes.
         x_label = _rechunk_overlap(x_label, depth=depth, chunks=chunks)
-        assert (
-            x_label.numblocks[0] == 1
-        ), f"Expected the number of blocks in the Z-dimension to be `1`, found `{x_label.numblocks[0]}`."
+        assert x_label.numblocks[0] == 1, (
+            f"Expected the number of blocks in the Z-dimension to be `1`, found `{x_label.numblocks[0]}`."
+        )

         if i == 0:
             # output_chunks can be derived from any rechunked x_label in labels_arrays
@@ -307,6 +323,21 @@ def _fix_depth(_depth):
    # return x_labels.squeeze(0)

    if not trim:
        # Write to an intermediate zarr store if sdata is backed, to reduce RAM usage.
        if temp_path is not None:
            zarr_path = os.path.join(temp_path, f"labels_{uuid.uuid4()}.zarr")
            _chunks = x_labels.chunks
            x_labels.rechunk(x_labels.chunksize).to_zarr(
                zarr_path,
                overwrite=False,
            )
            x_labels = da.from_zarr(zarr_path)
            x_labels = x_labels.rechunk(_chunks)
        else:
            x_labels = x_labels.persist()

        log.info("Linking labels across chunks.")

        iou_depth = da.overlap.coerce_depth(len(depth), iou_depth)

        if any(iou_depth[ax] > depth[ax] for ax in depth.keys()):
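
The spill-to-zarr branch above trades disk I/O for peak memory: rather than persist()-ing the fully assembled labels array in RAM, it is materialized chunk by chunk into a temporary zarr store, read back lazily, and restored to its original (possibly irregular) chunking. A standalone sketch of that round-trip; the array contents and path are purely illustrative:

import dask.array as da
import numpy as np

labels = da.from_array(
    np.random.randint(0, 7, size=(2048, 2048), dtype=np.uint32), chunks=512
)

_chunks = labels.chunks  # remember the exact (possibly irregular) chunking

# zarr requires a regular chunk grid, hence the rechunk to `chunksize`;
# writing evaluates the graph block by block, keeping peak RAM low.
labels.rechunk(labels.chunksize).to_zarr("/tmp/labels_tmp.zarr", overwrite=True)

# Reading back yields a lazy array whose graph is just "load from disk".
labels = da.from_zarr("/tmp/labels_tmp.zarr").rechunk(_chunks)

Reading from zarr keeps only the chunks currently being processed in memory, which is why this path replaces persist() for backed SpatialData objects.
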