Skip to content

Commit

Permalink
heavy unet tests - v0.3.5 (#327)
Browse files Browse the repository at this point in the history
multiple unitteststs + fixes
  • Loading branch information
mzouink authored Nov 19, 2024
2 parents fc4a5e6 + 53217ed commit f5c4b36
Show file tree
Hide file tree
Showing 35 changed files with 816 additions and 151 deletions.
2 changes: 1 addition & 1 deletion dacapo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.3.2"
__version__ = "0.3.5"
__version_info__ = tuple(int(i) for i in __version__.split("."))

from .options import Options # noqa
Expand Down
7 changes: 4 additions & 3 deletions dacapo/compute_context/local_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ def device(self):
if self._device is None:
if torch.cuda.is_available():
# TODO: make this more sophisticated, for multiple GPUs for instance
free = torch.cuda.mem_get_info()[0] / 1024**3
if free < self.oom_limit: # less than 1 GB free, decrease chance of OOM
return torch.device("cpu")
# commented out code below is for checking free memory and falling back on CPU, whhen model in GPU and memory is low model get moved to CPU
# free = torch.cuda.mem_get_info()[0] / 1024**3
# if free < self.oom_limit: # less than 1 GB free, decrease chance of OOM
# return torch.device("cpu")
return torch.device("cuda")
# Multiple MPS ops are not available yet : https://github.com/pytorch/pytorch/issues/77764
# got error aten::max_pool3d_with_indices
Expand Down
23 changes: 12 additions & 11 deletions dacapo/experiments/architectures/cnnectome_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
import torch.nn as nn

from funlib.geometry import Coordinate

import math


Expand Down Expand Up @@ -176,7 +178,7 @@ def __init__(self, architecture_config):
self.unet = self.module()

@property
def eval_shape_increase(self):
def eval_shape_increase(self) -> Coordinate:
"""
The increase in shape due to the U-Net.
Expand All @@ -192,7 +194,7 @@ def eval_shape_increase(self):
"""
if self._eval_shape_increase is None:
return super().eval_shape_increase
return self._eval_shape_increase
return Coordinate(self._eval_shape_increase)

def module(self):
"""
Expand Down Expand Up @@ -235,16 +237,15 @@ def module(self):
"""
fmaps_in = self.fmaps_in
levels = len(self.downsample_factors) + 1
dims = len(self.downsample_factors[0])

if hasattr(self, "kernel_size_down"):
if self.kernel_size_down is not None:
kernel_size_down = self.kernel_size_down
else:
kernel_size_down = [[(3,) * dims, (3,) * dims]] * levels
if hasattr(self, "kernel_size_up"):
kernel_size_down = [[(3,) * self.dims, (3,) * self.dims]] * levels
if self.kernel_size_up is not None:
kernel_size_up = self.kernel_size_up
else:
kernel_size_up = [[(3,) * dims, (3,) * dims]] * (levels - 1)
kernel_size_up = [[(3,) * self.dims, (3,) * self.dims]] * (levels - 1)

# downsample factors has to be a list of tuples
downsample_factors = [tuple(x) for x in self.downsample_factors]
Expand Down Expand Up @@ -280,7 +281,7 @@ def module(self):
conv = ConvPass(
self.fmaps_out,
self.fmaps_out,
[(3,) * len(upsample_factor)] * 2,
kernel_size_up[-1],
activation="ReLU",
batch_norm=self.batch_norm,
)
Expand All @@ -306,11 +307,11 @@ def scale(self, voxel_size):
The voxel size should be given as a tuple ``(z, y, x)``.
"""
for upsample_factor in self.upsample_factors:
voxel_size = voxel_size / upsample_factor
voxel_size = voxel_size / Coordinate(upsample_factor)
return voxel_size

@property
def input_shape(self):
def input_shape(self) -> Coordinate:
"""
Return the input shape of the U-Net.
Expand All @@ -324,7 +325,7 @@ def input_shape(self):
Note:
The input shape should be given as a tuple ``(batch, channels, [length,] depth, height, width)``.
"""
return self._input_shape
return Coordinate(self._input_shape)

@property
def num_in_channels(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion dacapo/experiments/architectures/cnnectome_unet_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,6 @@ class CNNectomeUNetConfig(ArchitectureConfig):
},
)
batch_norm: bool = attr.ib(
default=True,
default=False,
metadata={"help_text": "Whether to use batch normalization."},
)
3 changes: 3 additions & 0 deletions dacapo/experiments/datasplits/datasets/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,6 @@ def verify(self) -> Tuple[bool, str]:
This method is used to validate the configuration of the dataset.
"""
return True, "No validation for this DataSet"

def __hash__(self) -> int:
return hash(self.name)

Check warning on line 68 in dacapo/experiments/datasplits/datasets/dataset_config.py

View check run for this annotation

Codecov / codecov/patch

dacapo/experiments/datasplits/datasets/dataset_config.py#L68

Added line #L68 was not covered by tests
2 changes: 1 addition & 1 deletion dacapo/experiments/datasplits/simple_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_paths(self, group_name: str) -> list[Path]:
len(level_2_matches) == 0
), f"Found raw data at {level_1} and {level_2}"
return [Path(x).parent for x in level_1_matches]
elif len(level_2_matches).parent > 0:
elif len(level_2_matches) > 0:
return [Path(x) for x in level_2_matches]

Check warning on line 48 in dacapo/experiments/datasplits/simple_config.py

View check run for this annotation

Codecov / codecov/patch

dacapo/experiments/datasplits/simple_config.py#L46-L48

Added lines #L46 - L48 were not covered by tests

raise Exception(f"No raw data found at {level_0} or {level_1} or {level_2}")

Check warning on line 50 in dacapo/experiments/datasplits/simple_config.py

View check run for this annotation

Codecov / codecov/patch

dacapo/experiments/datasplits/simple_config.py#L50

Added line #L50 was not covered by tests
Expand Down
11 changes: 11 additions & 0 deletions dacapo/experiments/tasks/hot_distance_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from .predictors import HotDistancePredictor
from .task import Task

import warnings


class HotDistanceTask(Task):
"""
Expand Down Expand Up @@ -34,10 +36,19 @@ def __init__(self, task_config):
>>> task = HotDistanceTask(task_config)
"""

if task_config.kernel_size is None:
warnings.warn(
"The default kernel size of 3 will be changing to 1. "
"Please specify the kernel size explicitly.",
DeprecationWarning,
)
task_config.kernel_size = 3
self.predictor = HotDistancePredictor(
channels=task_config.channels,
scale_factor=task_config.scale_factor,
mask_distances=task_config.mask_distances,
kernel_size=task_config.kernel_size,
)
self.loss = HotDistanceLoss()
self.post_processor = ThresholdPostProcessor()
Expand Down
4 changes: 4 additions & 0 deletions dacapo/experiments/tasks/hot_distance_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,7 @@ class HotDistanceTaskConfig(TaskConfig):
"is less than the distance to object boundary."
},
)

kernel_size: int | None = attr.ib(
default=None,
)
14 changes: 13 additions & 1 deletion dacapo/experiments/tasks/one_hot_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from .predictors import OneHotPredictor
from .task import Task

import warnings


class OneHotTask(Task):
"""
Expand All @@ -30,7 +32,17 @@ def __init__(self, task_config):
Examples:
>>> task = OneHotTask(task_config)
"""
self.predictor = OneHotPredictor(classes=task_config.classes)

if task_config.kernel_size is None:
warnings.warn(

Check warning on line 37 in dacapo/experiments/tasks/one_hot_task.py

View check run for this annotation

Codecov / codecov/patch

dacapo/experiments/tasks/one_hot_task.py#L37

Added line #L37 was not covered by tests
"The default kernel size of 3 will be changing to 1. "
"Please specify the kernel size explicitly.",
DeprecationWarning,
)
task_config.kernel_size = 3

Check warning on line 42 in dacapo/experiments/tasks/one_hot_task.py

View check run for this annotation

Codecov / codecov/patch

dacapo/experiments/tasks/one_hot_task.py#L42

Added line #L42 was not covered by tests
self.predictor = OneHotPredictor(
classes=task_config.classes, kernel_size=task_config.kernel_size
)
self.loss = DummyLoss()
self.post_processor = ArgmaxPostProcessor()
self.evaluator = DummyEvaluator()
3 changes: 3 additions & 0 deletions dacapo/experiments/tasks/one_hot_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ class OneHotTaskConfig(TaskConfig):
classes: List[str] = attr.ib(
metadata={"help_text": "The classes corresponding with each id starting from 0"}
)
kernel_size: int | None = attr.ib(
default=None,
)
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def process(
overwrite=True,
)

read_roi = Roi((0, 0, 0), block_size[-self.prediction_array.dims :])
read_roi = Roi((0,) * block_size.dims, block_size)
input_array = open_ds(
f"{self.prediction_array_identifier.container.path}/{self.prediction_array_identifier.dataset}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from funlib.persistence import Array

from typing import Iterable
import logging

logger = logging.getLogger(__name__)


class ThresholdPostProcessor(PostProcessor):
Expand Down Expand Up @@ -108,13 +111,15 @@ def process(
if self.prediction_array._source_data.chunks is not None:
block_size = self.prediction_array._source_data.chunks

write_size = [
b * v
for b, v in zip(
block_size[-self.prediction_array.dims :],
self.prediction_array.voxel_size,
)
]
write_size = Coordinate(
[
b * v
for b, v in zip(
block_size[-self.prediction_array.dims :],
self.prediction_array.voxel_size,
)
]
)
output_array = create_from_identifier(
output_array_identifier,
self.prediction_array.axis_names,
Expand All @@ -125,7 +130,7 @@ def process(
overwrite=True,
)

read_roi = Roi((0, 0, 0), write_size[-self.prediction_array.dims :])
read_roi = Roi(write_size * 0, write_size)
input_array = open_ds(
f"{self.prediction_array_identifier.container.path}/{self.prediction_array_identifier.dataset}"
)
Expand All @@ -135,7 +140,7 @@ def process_block(block):
data = input_array[write_roi] > parameters.threshold
data = data.astype(np.uint8)
if int(data.max()) == 0:
print("No data in block", write_roi)
logger.debug("No data in block", write_roi)
return
output_array[write_roi] = data

Check warning on line 145 in dacapo/experiments/tasks/post_processors/threshold_post_processor.py

View check run for this annotation

Codecov / codecov/patch

dacapo/experiments/tasks/post_processors/threshold_post_processor.py#L145

Added line #L145 was not covered by tests

Expand Down
35 changes: 29 additions & 6 deletions dacapo/experiments/tasks/predictors/distance_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ def create_distance_mask(
>>> predictor.create_distance_mask(distances, mask, voxel_size, normalize, normalize_args)
"""
no_channel_dim = len(mask.shape) == len(distances.shape) - 1
if no_channel_dim:
mask = mask[np.newaxis]

mask_output = mask.copy()
for i, (channel_distance, channel_mask) in enumerate(zip(distances, mask)):
tmp = np.zeros(
Expand All @@ -231,9 +235,11 @@ def create_distance_mask(
)
slices = tmp.ndim * (slice(1, -1),)
tmp[slices] = channel_mask
sampling = tuple(float(v) / 2 for v in voxel_size)
sampling = sampling[-len(tmp.shape) :]
boundary_distance = distance_transform_edt(
tmp,
sampling=voxel_size,
sampling=sampling,
)
if self.epsilon is None:
add = 0
Expand Down Expand Up @@ -273,6 +279,8 @@ def create_distance_mask(
np.sum(channel_mask_output)
)
)
if no_channel_dim:
mask_output = mask_output[0]
return mask_output

def process(
Expand All @@ -298,7 +306,20 @@ def process(
>>> predictor.process(labels, voxel_size, normalize, normalize_args)
"""

num_dims = len(labels.shape)
if num_dims == voxel_size.dims:
channel_dim = False
elif num_dims == voxel_size.dims + 1:
channel_dim = True
else:
raise ValueError("Cannot handle multiple channel dims")

Check warning on line 316 in dacapo/experiments/tasks/predictors/distance_predictor.py

View check run for this annotation

Codecov / codecov/patch

dacapo/experiments/tasks/predictors/distance_predictor.py#L316

Added line #L316 was not covered by tests

if not channel_dim:
labels = labels[np.newaxis]

all_distances = np.zeros(labels.shape, dtype=np.float32) - 1

for ii, channel in enumerate(labels):
boundaries = self.__find_boundaries(channel)

Expand All @@ -315,13 +336,15 @@ def process(
distances = np.ones(channel.shape, dtype=np.float32) * max_distance
else:
# get distances (voxel_size/2 because image is doubled)
distances = distance_transform_edt(
boundaries, sampling=tuple(float(v) / 2 for v in voxel_size)
)
sampling = tuple(float(v) / 2 for v in voxel_size)
# fixing the sampling for 2D images
if len(boundaries.shape) < len(sampling):
sampling = sampling[-len(boundaries.shape) :]

Check warning on line 342 in dacapo/experiments/tasks/predictors/distance_predictor.py

View check run for this annotation

Codecov / codecov/patch

dacapo/experiments/tasks/predictors/distance_predictor.py#L342

Added line #L342 was not covered by tests
distances = distance_transform_edt(boundaries, sampling=sampling)
distances = distances.astype(np.float32)

# restore original shape
downsample = (slice(None, None, 2),) * len(voxel_size)
downsample = (slice(None, None, 2),) * distances.ndim
distances = distances[downsample]

# todo: inverted distance
Expand Down Expand Up @@ -354,7 +377,7 @@ def __find_boundaries(self, labels: np.ndarray):
# bound.: 00000001000100000001000 2n - 1

if labels.dtype == bool:
raise ValueError("Labels should not be bools")
# raise ValueError("Labels should not be bools")
labels = labels.astype(np.uint8)

logger.debug(f"computing boundaries for {labels.shape}")
Expand Down
8 changes: 3 additions & 5 deletions dacapo/experiments/tasks/predictors/dummy_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def create_model(self, architecture):
>>> model = predictor.create_model(architecture)
"""
head = torch.nn.Conv3d(
architecture.num_out_channels, self.embedding_dims, kernel_size=3
architecture.num_out_channels, self.embedding_dims, kernel_size=1
)

return Model(architecture, head)
Expand All @@ -71,9 +71,8 @@ def create_target(self, gt):
# zeros
return np_to_funlib_array(

Check warning on line 72 in dacapo/experiments/tasks/predictors/dummy_predictor.py

View check run for this annotation

Codecov / codecov/patch

dacapo/experiments/tasks/predictors/dummy_predictor.py#L72

Added line #L72 was not covered by tests
np.zeros((self.embedding_dims,) + gt.data.shape[-gt.dims :]),
gt.roi,
gt.roi.offset,
gt.voxel_size,
["c^"] + gt.axis_names,
)

def create_weight(self, gt, target, mask, moving_class_counts=None):
Expand All @@ -96,9 +95,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
return (
np_to_funlib_array(
np.ones(target.data.shape),
target.roi,
target.roi.offset,
target.voxel_size,
target.axis_names,
),
None,
)
Expand Down
Loading

0 comments on commit f5c4b36

Please sign in to comment.