diff --git a/dacapo/__init__.py b/dacapo/__init__.py
index 3b06000aa..e40e7277a 100644
--- a/dacapo/__init__.py
+++ b/dacapo/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.3.2"
+__version__ = "0.3.5"
 __version_info__ = tuple(int(i) for i in __version__.split("."))
 
 from .options import Options  # noqa
diff --git a/dacapo/compute_context/local_torch.py b/dacapo/compute_context/local_torch.py
index a547b7dd7..08e813712 100644
--- a/dacapo/compute_context/local_torch.py
+++ b/dacapo/compute_context/local_torch.py
@@ -56,9 +56,10 @@ def device(self):
         if self._device is None:
             if torch.cuda.is_available():
                 # TODO: make this more sophisticated, for multiple GPUs for instance
-                free = torch.cuda.mem_get_info()[0] / 1024**3
-                if free < self.oom_limit:  # less than 1 GB free, decrease chance of OOM
-                    return torch.device("cpu")
+                # The commented-out code below checked free memory and fell back on the CPU: when the model was on the GPU and memory ran low, it was moved to the CPU.
+                # free = torch.cuda.mem_get_info()[0] / 1024**3
+                # if free < self.oom_limit:  # less than 1 GB free, decrease chance of OOM
+                #     return torch.device("cpu")
                 return torch.device("cuda")
             # Multiple MPS ops are not available yet : https://github.com/pytorch/pytorch/issues/77764
             # got error aten::max_pool3d_with_indices
diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py
index d89e902ac..c064305b1 100644
--- a/dacapo/experiments/architectures/cnnectome_unet.py
+++ b/dacapo/experiments/architectures/cnnectome_unet.py
@@ -3,6 +3,8 @@
 import torch
 import torch.nn as nn
 
+from funlib.geometry import Coordinate
+
 import math
 
 
@@ -176,7 +178,7 @@ def __init__(self, architecture_config):
         self.unet = self.module()
 
     @property
-    def eval_shape_increase(self):
+    def eval_shape_increase(self) -> Coordinate:
         """
         The increase in shape due to the U-Net.
 
@@ -192,7 +194,7 @@ def eval_shape_increase(self):
         """
         if self._eval_shape_increase is None:
             return super().eval_shape_increase
-        return self._eval_shape_increase
+        return Coordinate(self._eval_shape_increase)
 
     def module(self):
         """
@@ -235,16 +237,15 @@ def module(self):
         """
         fmaps_in = self.fmaps_in
         levels = len(self.downsample_factors) + 1
-        dims = len(self.downsample_factors[0])
 
-        if hasattr(self, "kernel_size_down"):
+        if self.kernel_size_down is not None:
             kernel_size_down = self.kernel_size_down
         else:
-            kernel_size_down = [[(3,) * dims, (3,) * dims]] * levels
-        if hasattr(self, "kernel_size_up"):
+            kernel_size_down = [[(3,) * self.dims, (3,) * self.dims]] * levels
+        if self.kernel_size_up is not None:
             kernel_size_up = self.kernel_size_up
         else:
-            kernel_size_up = [[(3,) * dims, (3,) * dims]] * (levels - 1)
+            kernel_size_up = [[(3,) * self.dims, (3,) * self.dims]] * (levels - 1)
 
         # downsample factors has to be a list of tuples
         downsample_factors = [tuple(x) for x in self.downsample_factors]
@@ -280,7 +281,7 @@ def module(self):
             conv = ConvPass(
                 self.fmaps_out,
                 self.fmaps_out,
-                [(3,) * len(upsample_factor)] * 2,
+                kernel_size_up[-1],
                 activation="ReLU",
                 batch_norm=self.batch_norm,
             )
@@ -306,11 +307,11 @@ def scale(self, voxel_size):
             The voxel size should be given as a tuple ``(z, y, x)``.
         """
         for upsample_factor in self.upsample_factors:
-            voxel_size = voxel_size / upsample_factor
+            voxel_size = voxel_size / Coordinate(upsample_factor)
         return voxel_size
 
     @property
-    def input_shape(self):
+    def input_shape(self) -> Coordinate:
         """
         Return the input shape of the U-Net.
@@ -324,7 +325,7 @@ def input_shape(self): Note: The input shape should be given as a tuple ``(batch, channels, [length,] depth, height, width)``. """ - return self._input_shape + return Coordinate(self._input_shape) @property def num_in_channels(self) -> int: diff --git a/dacapo/experiments/architectures/cnnectome_unet_config.py b/dacapo/experiments/architectures/cnnectome_unet_config.py index 7eab80115..643921386 100644 --- a/dacapo/experiments/architectures/cnnectome_unet_config.py +++ b/dacapo/experiments/architectures/cnnectome_unet_config.py @@ -128,6 +128,6 @@ class CNNectomeUNetConfig(ArchitectureConfig): }, ) batch_norm: bool = attr.ib( - default=True, + default=False, metadata={"help_text": "Whether to use batch normalization."}, ) diff --git a/dacapo/experiments/datasplits/datasets/dataset_config.py b/dacapo/experiments/datasplits/datasets/dataset_config.py index 4217eb00e..4a4ba3018 100644 --- a/dacapo/experiments/datasplits/datasets/dataset_config.py +++ b/dacapo/experiments/datasplits/datasets/dataset_config.py @@ -63,3 +63,6 @@ def verify(self) -> Tuple[bool, str]: This method is used to validate the configuration of the dataset. """ return True, "No validation for this DataSet" + + def __hash__(self) -> int: + return hash(self.name) diff --git a/dacapo/experiments/datasplits/simple_config.py b/dacapo/experiments/datasplits/simple_config.py index 53a66945b..9ee88283a 100644 --- a/dacapo/experiments/datasplits/simple_config.py +++ b/dacapo/experiments/datasplits/simple_config.py @@ -44,7 +44,7 @@ def get_paths(self, group_name: str) -> list[Path]: len(level_2_matches) == 0 ), f"Found raw data at {level_1} and {level_2}" return [Path(x).parent for x in level_1_matches] - elif len(level_2_matches).parent > 0: + elif len(level_2_matches) > 0: return [Path(x) for x in level_2_matches] raise Exception(f"No raw data found at {level_0} or {level_1} or {level_2}") diff --git a/dacapo/experiments/tasks/hot_distance_task.py b/dacapo/experiments/tasks/hot_distance_task.py index 630e58ed5..3d86da131 100644 --- a/dacapo/experiments/tasks/hot_distance_task.py +++ b/dacapo/experiments/tasks/hot_distance_task.py @@ -4,6 +4,8 @@ from .predictors import HotDistancePredictor from .task import Task +import warnings + class HotDistanceTask(Task): """ @@ -34,10 +36,19 @@ def __init__(self, task_config): >>> task = HotDistanceTask(task_config) """ + + if task_config.kernel_size is None: + warnings.warn( + "The default kernel size of 3 will be changing to 1. " + "Please specify the kernel size explicitly.", + DeprecationWarning, + ) + task_config.kernel_size = 3 self.predictor = HotDistancePredictor( channels=task_config.channels, scale_factor=task_config.scale_factor, mask_distances=task_config.mask_distances, + kernel_size=task_config.kernel_size, ) self.loss = HotDistanceLoss() self.post_processor = ThresholdPostProcessor() diff --git a/dacapo/experiments/tasks/hot_distance_task_config.py b/dacapo/experiments/tasks/hot_distance_task_config.py index 18cab91b3..7e0cc37aa 100644 --- a/dacapo/experiments/tasks/hot_distance_task_config.py +++ b/dacapo/experiments/tasks/hot_distance_task_config.py @@ -56,3 +56,7 @@ class HotDistanceTaskConfig(TaskConfig): "is less than the distance to object boundary." 
         },
     )
+
+    kernel_size: int | None = attr.ib(
+        default=None,
+    )
diff --git a/dacapo/experiments/tasks/one_hot_task.py b/dacapo/experiments/tasks/one_hot_task.py
index 870140f50..55d115d15 100644
--- a/dacapo/experiments/tasks/one_hot_task.py
+++ b/dacapo/experiments/tasks/one_hot_task.py
@@ -4,6 +4,8 @@
 from .predictors import OneHotPredictor
 from .task import Task
 
+import warnings
+
 
 class OneHotTask(Task):
     """
@@ -30,7 +32,17 @@ def __init__(self, task_config):
         Examples:
             >>> task = OneHotTask(task_config)
         """
-        self.predictor = OneHotPredictor(classes=task_config.classes)
+
+        if task_config.kernel_size is None:
+            warnings.warn(
+                "The default kernel size of 3 will be changing to 1. "
+                "Please specify the kernel size explicitly.",
+                DeprecationWarning,
+            )
+            task_config.kernel_size = 3
+        self.predictor = OneHotPredictor(
+            classes=task_config.classes, kernel_size=task_config.kernel_size
+        )
         self.loss = DummyLoss()
         self.post_processor = ArgmaxPostProcessor()
         self.evaluator = DummyEvaluator()
diff --git a/dacapo/experiments/tasks/one_hot_task_config.py b/dacapo/experiments/tasks/one_hot_task_config.py
index de4817a0e..4207448de 100644
--- a/dacapo/experiments/tasks/one_hot_task_config.py
+++ b/dacapo/experiments/tasks/one_hot_task_config.py
@@ -28,3 +28,6 @@ class OneHotTaskConfig(TaskConfig):
     classes: List[str] = attr.ib(
         metadata={"help_text": "The classes corresponding with each id starting from 0"}
     )
+    kernel_size: int | None = attr.ib(
+        default=None,
+    )
diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py
index f736d3e17..34cb0245d 100644
--- a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py
+++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py
@@ -133,7 +133,7 @@ def process(
             overwrite=True,
         )
 
-        read_roi = Roi((0, 0, 0), block_size[-self.prediction_array.dims :])
+        read_roi = Roi((0,) * block_size.dims, block_size)
         input_array = open_ds(
             f"{self.prediction_array_identifier.container.path}/{self.prediction_array_identifier.dataset}"
         )
diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py
index 59059e516..24ecead7a 100644
--- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py
+++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py
@@ -18,6 +18,9 @@
 from funlib.persistence import Array
 from typing import Iterable
 
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 class ThresholdPostProcessor(PostProcessor):
@@ -108,13 +111,15 @@ def process(
         if self.prediction_array._source_data.chunks is not None:
             block_size = self.prediction_array._source_data.chunks
 
-        write_size = [
-            b * v
-            for b, v in zip(
-                block_size[-self.prediction_array.dims :],
-                self.prediction_array.voxel_size,
-            )
-        ]
+        write_size = Coordinate(
+            [
+                b * v
+                for b, v in zip(
+                    block_size[-self.prediction_array.dims :],
+                    self.prediction_array.voxel_size,
+                )
+            ]
+        )
         output_array = create_from_identifier(
             output_array_identifier,
             self.prediction_array.axis_names,
@@ -125,7 +130,7 @@ def process(
             overwrite=True,
         )
 
-        read_roi = Roi((0, 0, 0), write_size[-self.prediction_array.dims :])
+        read_roi = Roi(write_size * 0, write_size)
         input_array = open_ds(
             f"{self.prediction_array_identifier.container.path}/{self.prediction_array_identifier.dataset}"
         )
@@ -135,7 +140,7 @@ def process_block(block):
             data = input_array[write_roi] > parameters.threshold
             data = data.astype(np.uint8)
             if int(data.max()) == 0:
-                print("No data in block", write_roi)
+                logger.debug("No data in block %s", write_roi)
                 return
             output_array[write_roi] = data
diff --git a/dacapo/experiments/tasks/predictors/distance_predictor.py b/dacapo/experiments/tasks/predictors/distance_predictor.py
index 861a9e1dd..741e14db6 100644
--- a/dacapo/experiments/tasks/predictors/distance_predictor.py
+++ b/dacapo/experiments/tasks/predictors/distance_predictor.py
@@ -223,6 +223,10 @@ def create_distance_mask(
             >>> predictor.create_distance_mask(distances, mask, voxel_size, normalize, normalize_args)
         """
+        no_channel_dim = len(mask.shape) == len(distances.shape) - 1
+        if no_channel_dim:
+            mask = mask[np.newaxis]
+
         mask_output = mask.copy()
         for i, (channel_distance, channel_mask) in enumerate(zip(distances, mask)):
             tmp = np.zeros(
@@ -231,9 +235,11 @@ def create_distance_mask(
             )
             slices = tmp.ndim * (slice(1, -1),)
             tmp[slices] = channel_mask
+            sampling = tuple(float(v) / 2 for v in voxel_size)
+            sampling = sampling[-len(tmp.shape) :]
             boundary_distance = distance_transform_edt(
                 tmp,
-                sampling=voxel_size,
+                sampling=sampling,
             )
             if self.epsilon is None:
                 add = 0
@@ -273,6 +279,8 @@ def create_distance_mask(
                         np.sum(channel_mask_output)
                     )
                 )
+        if no_channel_dim:
+            mask_output = mask_output[0]
         return mask_output
 
     def process(
@@ -298,7 +306,20 @@ def process(
             >>> predictor.process(labels, voxel_size, normalize, normalize_args)
         """
+
+        num_dims = len(labels.shape)
+        if num_dims == voxel_size.dims:
+            channel_dim = False
+        elif num_dims == voxel_size.dims + 1:
+            channel_dim = True
+        else:
+            raise ValueError("Cannot handle multiple channel dims")
+
+        if not channel_dim:
+            labels = labels[np.newaxis]
+
         all_distances = np.zeros(labels.shape, dtype=np.float32) - 1
+
         for ii, channel in enumerate(labels):
             boundaries = self.__find_boundaries(channel)
 
@@ -315,13 +336,15 @@ def process(
                 distances = np.ones(channel.shape, dtype=np.float32) * max_distance
             else:
                 # get distances (voxel_size/2 because image is doubled)
-                distances = distance_transform_edt(
-                    boundaries, sampling=tuple(float(v) / 2 for v in voxel_size)
-                )
+                sampling = tuple(float(v) / 2 for v in voxel_size)
+                # fixing the sampling for 2D images
+                if len(boundaries.shape) < len(sampling):
+                    sampling = sampling[-len(boundaries.shape) :]
+                distances = distance_transform_edt(boundaries, sampling=sampling)
                 distances = distances.astype(np.float32)
 
                 # restore original shape
-                downsample = (slice(None, None, 2),) * len(voxel_size)
+                downsample = (slice(None, None, 2),) * distances.ndim
                 distances = distances[downsample]
 
                 # todo: inverted distance
@@ -354,7 +377,7 @@ def __find_boundaries(self, labels: np.ndarray):
         # bound.: 00000001000100000001000 2n - 1
 
         if labels.dtype == bool:
-            raise ValueError("Labels should not be bools")
+            # raise ValueError("Labels should not be bools")
             labels = labels.astype(np.uint8)
 
         logger.debug(f"computing boundaries for {labels.shape}")
diff --git a/dacapo/experiments/tasks/predictors/dummy_predictor.py b/dacapo/experiments/tasks/predictors/dummy_predictor.py
index 3293f6423..46da2f6d9 100644
--- a/dacapo/experiments/tasks/predictors/dummy_predictor.py
+++ b/dacapo/experiments/tasks/predictors/dummy_predictor.py
@@ -50,7 +50,7 @@ def create_model(self, architecture):
             >>> model = predictor.create_model(architecture)
         """
         head = torch.nn.Conv3d(
-            architecture.num_out_channels, self.embedding_dims, kernel_size=3
+            architecture.num_out_channels, self.embedding_dims, kernel_size=1
         )
         return Model(architecture, head)
 
@@ -71,9 +71,8 @@ def create_target(self, gt):
         # zeros
         return np_to_funlib_array(
             np.zeros((self.embedding_dims,) + gt.data.shape[-gt.dims :]),
-            gt.roi,
+            gt.roi.offset,
             gt.voxel_size,
-            ["c^"] + gt.axis_names,
         )
 
     def create_weight(self, gt, target, mask, moving_class_counts=None):
@@ -96,9 +95,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
         return (
             np_to_funlib_array(
                 np.ones(target.data.shape),
-                target.roi,
+                target.roi.offset,
                 target.voxel_size,
-                target.axis_names,
             ),
             None,
         )
diff --git a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py
index 607c426f0..f2ec4f874 100644
--- a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py
+++ b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py
@@ -49,7 +49,13 @@ class HotDistancePredictor(Predictor):
         This is a subclass of Predictor.
     """
 
-    def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool):
+    def __init__(
+        self,
+        channels: List[str],
+        scale_factor: float,
+        mask_distances: bool,
+        kernel_size: int,
+    ):
         """
         Initializes the HotDistancePredictor.
 
@@ -64,6 +70,7 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: boo
         Note:
             The channels argument is a list of strings, each string is the name of a class that is being segmented.
         """
+        self.kernel_size = kernel_size
         self.channels = (
             channels * 2
         )  # one hot + distance (TODO: add hot/distance to channel names)
@@ -119,11 +126,11 @@ def create_model(self, architecture):
         """
         if architecture.dims == 2:
             head = torch.nn.Conv2d(
-                architecture.num_out_channels, self.embedding_dims, kernel_size=3
+                architecture.num_out_channels, self.embedding_dims, self.kernel_size
             )
         elif architecture.dims == 3:
             head = torch.nn.Conv3d(
-                architecture.num_out_channels, self.embedding_dims, kernel_size=3
+                architecture.num_out_channels, self.embedding_dims, self.kernel_size
             )
         return Model(architecture, head)
 
@@ -141,12 +148,11 @@ def create_target(self, gt):
         Examples:
             >>> target = predictor.create_target(gt)
         """
-        target = self.process(gt.data, gt.voxel_size, self.norm, self.dt_scale_factor)
+        target = self.process(gt[:], gt.voxel_size, self.norm, self.dt_scale_factor)
         return np_to_funlib_array(
             target,
-            gt.roi,
+            gt.roi.offset,
             gt.voxel_size,
-            gt.axis_names,
         )
 
     def create_weight(self, gt, target, mask, moving_class_counts=None):
@@ -209,9 +215,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
         return (
             np_to_funlib_array(
                 weights,
-                gt.roi,
+                gt.roi.offset,
                 gt.voxel_size,
-                gt.axis_names,
             ),
             moving_class_counts,
         )
diff --git a/dacapo/experiments/tasks/predictors/inner_distance_predictor.py b/dacapo/experiments/tasks/predictors/inner_distance_predictor.py
index b2f50b59a..a6f18d865 100644
--- a/dacapo/experiments/tasks/predictors/inner_distance_predictor.py
+++ b/dacapo/experiments/tasks/predictors/inner_distance_predictor.py
@@ -120,9 +120,8 @@ def create_target(self, gt):
         )
         return np_to_funlib_array(
             distances,
-            gt.roi,
+            gt.roi.offset,
             gt.voxel_size,
-            gt.axis_names,
         )
 
     def create_weight(self, gt, target, mask, moving_class_counts=None):
@@ -155,9 +154,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
         return (
             np_to_funlib_array(
                 weights,
-                gt.roi,
+                gt.roi.offset,
                 gt.voxel_size,
-                gt.axis_names,
             ),
             moving_class_counts,
         )
diff --git a/dacapo/experiments/tasks/predictors/one_hot_predictor.py b/dacapo/experiments/tasks/predictors/one_hot_predictor.py
index 1ad7fdeec..ff6e21db6 100644
--- a/dacapo/experiments/tasks/predictors/one_hot_predictor.py
+++ b/dacapo/experiments/tasks/predictors/one_hot_predictor.py
@@ -30,7 +30,7 @@ class OneHotPredictor(Predictor):
         This is a subclass of Predictor.
     """
 
-    def __init__(self, classes: List[str]):
+    def __init__(self, classes: List[str], kernel_size: int):
        """
         Initialize the OneHotPredictor.
 
@@ -42,6 +42,7 @@ def __init__(self, classes: List[str]):
             >>> predictor = OneHotPredictor(classes)
         """
         self.classes = classes
+        self.kernel_size = kernel_size
 
     @property
     def embedding_dims(self):
@@ -70,8 +71,17 @@ def create_model(self, architecture):
         Examples:
             >>> model = predictor.create_model(architecture)
         """
-        head = torch.nn.Conv3d(
-            architecture.num_out_channels, self.embedding_dims, kernel_size=3
+
+        if architecture.dims == 3:
+            conv_layer = torch.nn.Conv3d
+        elif architecture.dims == 2:
+            conv_layer = torch.nn.Conv2d
+        else:
+            raise Exception(f"Unsupported number of dimensions: {architecture.dims}")
+        head = conv_layer(
+            architecture.num_out_channels,
+            self.embedding_dims,
+            kernel_size=self.kernel_size,
         )
         return Model(architecture, head)
diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py
index dcb40c115..507151ad7 100644
--- a/dacapo/experiments/trainers/gunpowder_trainer.py
+++ b/dacapo/experiments/trainers/gunpowder_trainer.py
@@ -173,6 +173,8 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
             assert isinstance(dataset.weight, int), dataset
 
             raw_source = gp.ArraySource(raw_key, dataset.raw)
+            if dataset.raw.channel_dims == 0:
+                raw_source += gp.Unsqueeze([raw_key], axis=0)
             if self.clip_raw:
                 raw_source += gp.Crop(
                     raw_key, dataset.gt.roi.snap_to_grid(dataset.raw.voxel_size)
@@ -266,13 +268,13 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
         request.add(weight_key, output_size)
         request.add(
             mask_placeholder,
-            prediction_voxel_size * self.mask_integral_downsample_factor,
+            prediction_voxel_size,
         )
         # request additional keys for snapshots
         request.add(gt_key, output_size)
         request.add(mask_key, output_size)
         request[mask_placeholder].roi = request[mask_placeholder].roi.snap_to_grid(
-            prediction_voxel_size * self.mask_integral_downsample_factor
+            prediction_voxel_size
         )
 
         self._request = request
diff --git a/dacapo/predict_local.py b/dacapo/predict_local.py
index 674d00a40..f1760ff9f 100644
--- a/dacapo/predict_local.py
+++ b/dacapo/predict_local.py
@@ -44,10 +44,12 @@ def predict(
     else:
         input_roi = output_roi.grow(context, context)
 
-    read_roi = Roi((0, 0, 0), input_size)
+    read_roi = Roi((0,) * input_size.dims, input_size)
     write_roi = read_roi.grow(-context, -context)
 
-    axes = ["c^", "z", "y", "x"]
+    axes = raw_array.axis_names
+    if "c^" not in axes:
+        axes = ["c^"] + axes
 
     num_channels = model.num_out_channels
 
@@ -71,6 +73,12 @@ def predict(
     compute_context = create_compute_context()
     device = compute_context.device
 
+    model_device = str(next(model.parameters()).device).split(":")[0]
+
+    assert model_device == str(
+        device
+    ), f"Model is not on the right device, Model: {model_device}, Compute device: {device}"
+
     def predict_fn(block):
         raw_input = raw_array.to_ndarray(block.read_roi)
 
@@ -97,7 +105,7 @@ def predict_fn(block):
         predictions = Array(
             predictions,
             block.write_roi.offset,
-            raw_array.voxel_size,
+            output_voxel_size,
             axis_names,
             raw_array.units,
         )
@@ -114,7 +122,7 @@ def predict_fn(block):
     task = daisy.Task(
         f"predict_{out_container}_{out_dataset}",
         total_roi=input_roi,
-        read_roi=Roi((0, 0, 0), input_size),
+        read_roi=Roi((0,) * input_size.dims, input_size),
         write_roi=Roi(context, output_size),
         process_function=predict_fn,
         check_function=None,
diff --git a/dacapo/utils/balance_weights.py b/dacapo/utils/balance_weights.py
index e713745c6..5a53852ef 100644
--- a/dacapo/utils/balance_weights.py
+++ b/dacapo/utils/balance_weights.py
@@ -69,6 +69,9 @@ def balance_weights(
         scale_slab *= np.take(w, labels_slab)
     """
 
+    if label_data.dtype == bool:
+        label_data = label_data.astype(np.uint8)
+
     if moving_counts is None:
         moving_counts = []
     unique_labels = np.unique(label_data)
diff --git a/dacapo/validate.py b/dacapo/validate.py
index 4e091ff55..6e92430c9 100644
--- a/dacapo/validate.py
+++ b/dacapo/validate.py
@@ -246,6 +246,9 @@ def validate_run(run: Run, iteration: int, datasets_config=None):
             #     validation_dataset.name,
             #     criterion,
             # )
+            dataset_iteration_scores.append(
+                [getattr(scores, criterion) for criterion in scores.criteria]
+            )
         except:
             logger.error(
                 f"Could not evaluate run {run.name} on dataset {validation_dataset.name} with parameters {parameters}.",
@@ -257,10 +260,6 @@ def validate_run(run: Run, iteration: int, datasets_config=None):
             # the evaluator
             # array_store.remove(output_array_identifier)
 
-        dataset_iteration_scores.append(
-            [getattr(scores, criterion) for criterion in scores.criteria]
-        )
-
         iteration_scores.append(dataset_iteration_scores)
         # array_store.remove(prediction_array_identifier)
 
diff --git a/tests/conf.py b/tests/conf.py
deleted file mode 100644
index 57a8708d5..000000000
--- a/tests/conf.py
+++ /dev/null
@@ -1,3 +0,0 @@
-import multiprocessing as mp
-
-mp.set_start_method("fork", force=True)
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 000000000..9a90c5cab
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,28 @@
+import multiprocessing as mp
+import os
+import yaml
+
+from dacapo.options import Options
+
+import pytest
+
+
+@pytest.fixture(params=["fork", "spawn"], autouse=True)
+def context(request, monkeypatch):
+    ctx = mp.get_context(request.param)
+    monkeypatch.setattr(mp, "Queue", ctx.Queue)
+    monkeypatch.setattr(mp, "Process", ctx.Process)
+    monkeypatch.setattr(mp, "Event", ctx.Event)
+    monkeypatch.setattr(mp, "Value", ctx.Value)
+
+
+@pytest.fixture(autouse=True)
+def runs_base_dir(tmpdir):
+    options_file = tmpdir / "dacapo.yaml"
+    os.environ["DACAPO_OPTIONS_FILE"] = f"{options_file}"
+
+    with open(options_file, "w") as f:
+        f.write(yaml.safe_dump({"runs_base_dir": f"{tmpdir}"}))
+
+    assert Options.config_file() == options_file
+    assert Options.instance().runs_base_dir == tmpdir
diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
index 3ea282acc..e0d4a47a0 100644
--- a/tests/fixtures/__init__.py
+++ b/tests/fixtures/__init__.py
@@ -1,11 +1,33 @@
 from .db import options
-from .architectures import dummy_architecture
+from .architectures import (
+    dummy_architecture,
+    unet_architecture,
+    unet_3d_architecture,
+)
 from .arrays import dummy_array, zarr_array, cellmap_array
-from .datasplits import dummy_datasplit, twelve_class_datasplit, six_class_datasplit
+from .datasplits import (
+    dummy_datasplit,
+    twelve_class_datasplit,
+    six_class_datasplit,
+    upsample_six_class_datasplit,
+)
 from .evaluators import binary_3_channel_evaluator
 from .losses import dummy_loss
 from .post_processors import argmax, threshold
 from .predictors import distance_predictor, onehot_predictor
-from .runs import dummy_run, distance_run, onehot_run
-from .tasks import dummy_task, distance_task, onehot_task
+from .runs import (
+    dummy_run,
+    distance_run,
+    onehot_run,
+    unet_2d_distance_run,
+    unet_3d_distance_run,
+    hot_distance_run,
+)
+from .tasks import (
+    dummy_task,
+    distance_task,
+    onehot_task,
+    six_onehot_task,
+    hot_distance_task,
+)
 from .trainers import dummy_trainer, gunpowder_trainer
diff --git a/tests/fixtures/architectures.py b/tests/fixtures/architectures.py
index 6980c8f6b..79e7f9fca 100644
--- a/tests/fixtures/architectures.py
+++ b/tests/fixtures/architectures.py
@@ -1,4 +1,7 @@
-from dacapo.experiments.architectures import DummyArchitectureConfig
+from dacapo.experiments.architectures import (
+    DummyArchitectureConfig,
+    CNNectomeUNetConfig,
+)
 
 import pytest
 
@@ -8,3 +11,36 @@ def dummy_architecture():
     yield DummyArchitectureConfig(
         name="dummy_architecture", num_in_channels=1, num_out_channels=12
     )
+
+
+@pytest.fixture()
+def unet_architecture():
+    yield CNNectomeUNetConfig(
+        name="tmp_unet_architecture",
+        input_shape=(1, 132, 132),
+        eval_shape_increase=(1, 32, 32),
+        fmaps_in=1,
+        num_fmaps=8,
+        fmaps_out=8,
+        fmap_inc_factor=2,
+        downsample_factors=[(1, 4, 4), (1, 4, 4)],
+        kernel_size_down=[[(1, 3, 3)] * 2] * 3,
+        kernel_size_up=[[(1, 3, 3)] * 2] * 2,
+        constant_upsample=True,
+        padding="valid",
+    )
+
+
+@pytest.fixture()
+def unet_3d_architecture():
+    yield CNNectomeUNetConfig(
+        name="tmp_unet_3d_architecture",
+        input_shape=(188, 188, 188),
+        eval_shape_increase=(72, 72, 72),
+        fmaps_in=1,
+        num_fmaps=6,
+        fmaps_out=6,
+        fmap_inc_factor=2,
+        downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
+        constant_upsample=True,
+    )
diff --git a/tests/fixtures/datasplits.py b/tests/fixtures/datasplits.py
index 448c9c834..e94aee0c6 100644
--- a/tests/fixtures/datasplits.py
+++ b/tests/fixtures/datasplits.py
@@ -73,10 +73,10 @@ def twelve_class_datasplit(tmp_path):
             gt_dataset[:] += random_data > i
         raw_dataset[:] = random_data
         raw_dataset.attrs["offset"] = (0, 0, 0)
-        raw_dataset.attrs["resolution"] = (2, 2, 2)
+        raw_dataset.attrs["voxel_size"] = (2, 2, 2)
         raw_dataset.attrs["axis_names"] = ("z", "y", "x")
         gt_dataset.attrs["offset"] = (0, 0, 0)
-        gt_dataset.attrs["resolution"] = (2, 2, 2)
+        gt_dataset.attrs["voxel_size"] = (2, 2, 2)
         gt_dataset.attrs["axis_names"] = ("z", "y", "x")
 
     crop1 = RawGTDatasetConfig(name="crop1", raw_config=crop1_raw, gt_config=crop1_gt)
@@ -184,10 +184,127 @@ def six_class_datasplit(tmp_path):
             gt_dataset[:] += random_data > i
         raw_dataset[:] = random_data
         raw_dataset.attrs["offset"] = (0, 0, 0)
-        raw_dataset.attrs["resolution"] = (2, 2, 2)
+        raw_dataset.attrs["voxel_size"] = (2, 2, 2)
         raw_dataset.attrs["axis_names"] = ("z", "y", "x")
         gt_dataset.attrs["offset"] = (0, 0, 0)
-        gt_dataset.attrs["resolution"] = (2, 2, 2)
+        gt_dataset.attrs["voxel_size"] = (2, 2, 2)
+        gt_dataset.attrs["axis_names"] = ("z", "y", "x")
+
+    crop1 = RawGTDatasetConfig(
+        name="crop1", raw_config=crop1_raw, gt_config=crop1_distances
+    )
+    crop2 = RawGTDatasetConfig(
+        name="crop2", raw_config=crop2_raw, gt_config=crop2_distances
+    )
+    crop3 = RawGTDatasetConfig(
+        name="crop3", raw_config=crop3_raw, gt_config=crop3_distances
+    )
+
+    six_class_distances_datasplit_config = TrainValidateDataSplitConfig(
+        name="six_class_distances_datasplit",
+        train_configs=[crop1, crop2],
+        validate_configs=[crop3],
+    )
+    return six_class_distances_datasplit_config
+
+
+@pytest.fixture()
+def upsample_six_class_datasplit(tmp_path):
+    """
+    two crops for training, one for validation. Raw data is normally distributed
+    around 0 with std 1.
+    gt is provided as distances. First, gt is generated as a 12 class problem:
+    gt has 12 classes where class i in [0, 11] is all voxels with raw intensity
+    between (raw.min() + i(raw.max()-raw.min())/12, raw.min() + (i+1)(raw.max()-raw.min())/12).
+    Then we pair up classes (i, i+1) for i in [0,2,4,6,8,10], and compute distances to
+    the nearest voxel in the pair. This leaves us with 6 distance channels.
+    """
+    twelve_class_zarr = zarr.open(tmp_path / "twelve_class.zarr", "w")
+    crop1_raw = ZarrArrayConfig(
+        name="crop1_raw",
+        file_name=tmp_path / "twelve_class.zarr",
+        dataset=f"volumes/crop1/raw",
+    )
+    crop1_gt = ZarrArrayConfig(
+        name="crop1_gt",
+        file_name=tmp_path / "twelve_class.zarr",
+        dataset=f"volumes/crop1/gt",
+    )
+    crop1_distances = BinarizeArrayConfig(
+        "crop1_distances",
+        source_array_config=crop1_gt,
+        groupings=[
+            ("a", [0, 1]),
+            ("b", [2, 3]),
+            ("c", [4, 5]),
+            ("d", [6, 7]),
+            ("e", [8, 9]),
+            ("f", [10, 11]),
+        ],
+    )
+    crop2_raw = ZarrArrayConfig(
+        name="crop2_raw",
+        file_name=tmp_path / "twelve_class.zarr",
+        dataset=f"volumes/crop2/raw",
+    )
+    crop2_gt = ZarrArrayConfig(
+        name="crop2_gt",
+        file_name=tmp_path / "twelve_class.zarr",
+        dataset=f"volumes/crop2/gt",
+    )
+    crop2_distances = BinarizeArrayConfig(
+        "crop2_distances",
+        source_array_config=crop2_gt,
+        groupings=[
+            ("a", [0, 1]),
+            ("b", [2, 3]),
+            ("c", [4, 5]),
+            ("d", [6, 7]),
+            ("e", [8, 9]),
+            ("f", [10, 11]),
+        ],
+    )
+    crop3_raw = ZarrArrayConfig(
+        name="crop3_raw",
+        file_name=tmp_path / "twelve_class.zarr",
+        dataset=f"volumes/crop3/raw",
+    )
+    crop3_gt = ZarrArrayConfig(
+        name="crop3_gt",
+        file_name=tmp_path / "twelve_class.zarr",
+        dataset=f"volumes/crop3/gt",
+    )
+    crop3_distances = BinarizeArrayConfig(
+        "crop3_distances",
+        source_array_config=crop3_gt,
+        groupings=[
+            ("a", [0, 1]),
+            ("b", [2, 3]),
+            ("c", [4, 5]),
+            ("d", [6, 7]),
+            ("e", [8, 9]),
+            ("f", [10, 11]),
+        ],
+    )
+    for raw, gt in zip(
+        [crop1_raw, crop2_raw, crop3_raw], [crop1_gt, crop2_gt, crop3_gt]
+    ):
+        raw_dataset = twelve_class_zarr.create_dataset(
+            raw.dataset, shape=(40, 20, 20), dtype=np.float32
+        )
+        gt_dataset = twelve_class_zarr.create_dataset(
+            gt.dataset, shape=(40, 20, 20), dtype=np.uint8
+        )
+        random_data = np.random.rand(40, 20, 20)
+        # as intensities increase so does the class
+        for i in list(np.linspace(random_data.min(), random_data.max(), 13))[1:]:
+            gt_dataset[:] += random_data > i
+        raw_dataset[:] = random_data
+        raw_dataset.attrs["offset"] = (0, 0, 0)
+        raw_dataset.attrs["voxel_size"] = (4, 4, 4)
+        raw_dataset.attrs["axis_names"] = ("z", "y", "x")
+        gt_dataset.attrs["offset"] = (0, 0, 0)
+        gt_dataset.attrs["voxel_size"] = (2, 2, 2)
         gt_dataset.attrs["axis_names"] = ("z", "y", "x")
 
     crop1 = RawGTDatasetConfig(
diff --git a/tests/fixtures/predictors.py b/tests/fixtures/predictors.py
index cc93369cf..c6dd6de51 100644
--- a/tests/fixtures/predictors.py
+++ b/tests/fixtures/predictors.py
@@ -10,4 +10,4 @@ def distance_predictor():
 
 @pytest.fixture()
 def onehot_predictor():
-    yield OneHotPredictor(classes=["a", "b", "c"])
+    yield OneHotPredictor(classes=["a", "b", "c"], kernel_size=1)
diff --git a/tests/fixtures/runs.py b/tests/fixtures/runs.py
index 99c4d3269..c842db118 100644
--- a/tests/fixtures/runs.py
+++ b/tests/fixtures/runs.py
@@ -17,7 +17,25 @@ def distance_run(
         trainer_config=gunpowder_trainer,
         datasplit_config=six_class_datasplit,
         repetition=0,
-        num_iterations=100,
+        num_iterations=10,
+    )
+
+
+@pytest.fixture()
+def hot_distance_run(
+    six_class_datasplit,
+    dummy_architecture,
+    hot_distance_task,
+    gunpowder_trainer,
+):
+    yield RunConfig(
+        name="hot_distance_run",
+        task_config=hot_distance_task,
+        architecture_config=dummy_architecture,
+        trainer_config=gunpowder_trainer,
+        datasplit_config=six_class_datasplit,
+        repetition=0,
+        num_iterations=10,
     )
 
 
@@ -35,7 +53,7 @@ def dummy_run(
         trainer_config=dummy_trainer,
         datasplit_config=dummy_datasplit,
         repetition=0,
-        num_iterations=100,
+        num_iterations=10,
     )
 
 
@@ -53,5 +71,41 @@ def onehot_run(
         trainer_config=gunpowder_trainer,
         datasplit_config=twelve_class_datasplit,
         repetition=0,
-        num_iterations=100,
+        num_iterations=10,
+    )
+
+
+@pytest.fixture()
+def unet_2d_distance_run(
+    six_class_datasplit,
+    unet_architecture,
+    distance_task,
+    gunpowder_trainer,
+):
+    yield RunConfig(
+        name="unet_2d_distance_run",
+        task_config=distance_task,
+        architecture_config=unet_architecture,
+        trainer_config=gunpowder_trainer,
+        datasplit_config=six_class_datasplit,
+        repetition=0,
+        num_iterations=10,
+    )
+
+
+@pytest.fixture()
+def unet_3d_distance_run(
+    six_class_datasplit,
+    unet_3d_architecture,
+    distance_task,
+    gunpowder_trainer,
+):
+    yield RunConfig(
+        name="unet_3d_distance_run",
+        task_config=distance_task,
+        architecture_config=unet_3d_architecture,
+        trainer_config=gunpowder_trainer,
+        datasplit_config=six_class_datasplit,
+        repetition=0,
+        num_iterations=10,
     )
diff --git a/tests/fixtures/tasks.py b/tests/fixtures/tasks.py
index fcd2c673e..5792811b4 100644
--- a/tests/fixtures/tasks.py
+++ b/tests/fixtures/tasks.py
@@ -2,6 +2,7 @@
     DistanceTaskConfig,
     DummyTaskConfig,
     OneHotTaskConfig,
+    HotDistanceTaskConfig,
 )
 
 import pytest
@@ -28,9 +29,36 @@ def distance_task():
     )
 
 
+@pytest.fixture()
+def hot_distance_task():
+    yield HotDistanceTaskConfig(
+        name="hot_distance_task",
+        channels=[
+            "a",
+            "b",
+            "c",
+            "d",
+            "e",
+            "f",
+        ],
+        clip_distance=5,
+        tol_distance=10,
+    )
+
+
 @pytest.fixture()
 def onehot_task():
     yield OneHotTaskConfig(
         name="one_hot_task",
         classes=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"],
+        kernel_size=1,
+    )
+
+
+@pytest.fixture()
+def six_onehot_task():
+    yield OneHotTaskConfig(
+        name="one_hot_task",
+        classes=["a", "b", "c", "d", "e", "f"],
+        kernel_size=1,
     )
diff --git a/tests/operations/helpers.py b/tests/operations/helpers.py
new file mode 100644
index 000000000..74fb43208
--- /dev/null
+++ b/tests/operations/helpers.py
@@ -0,0 +1,166 @@
+import numpy as np
+from funlib.persistence import prepare_ds
+from funlib.geometry import Coordinate
+
+from dacapo.experiments.datasplits import SimpleDataSplitConfig
+from dacapo.experiments.tasks import (
+    DistanceTaskConfig,
+    OneHotTaskConfig,
+    AffinitiesTaskConfig,
+)
+from dacapo.experiments.architectures import CNNectomeUNetConfig
+
+from pathlib import Path
+
+
+def build_test_data_config(
+    tmpdir: Path, data_dims: int, channels: bool, upsample: bool, task_type: str
+):
+    """
+    Builds the simplest possible datasplit given the parameters.
+
+    Labels are alternating planes/lines of 0/1 in the last dimension.
+    Intensities are random where labels are > 0, else 0. (If channels, stack twice.)
+    if task_type is "semantic", labels are binarized via labels > 0.
+
+    if upsampling, labels are upsampled by a factor of 2 in each dimension
+    """
+
+    data_shape = (32, 32, 32)[-data_dims:]
+    axis_names = ["z", "y", "x"][-data_dims:]
+    mesh = np.meshgrid(
+        *[np.linspace(0, dim - 1, dim * (1 + upsample)) for dim in data_shape]
+    )
+    labels = mesh[-1] * (mesh[-1] % 2 > 0.75)
+
+    intensities = np.random.rand(*labels.shape) * (labels > 0)
+
+    if channels:
+        intensities = np.stack([intensities, intensities], axis=0)
+
+    intensities_array = prepare_ds(
+        tmpdir / "test_data.zarr/raw",
+        intensities.shape,
+        offset=(0,) * data_dims,
+        voxel_size=(2,) * data_dims,
+        axis_names=["c^"] * int(channels) + axis_names,
+        dtype=intensities.dtype,
+        mode="w",
+    )
+    intensities_array[:] = intensities
+
+    if task_type == "semantic":
+        labels = labels > 0
+
+    labels_array = prepare_ds(
+        tmpdir / "test_data.zarr/labels",
+        labels.shape,
+        offset=(0,) * data_dims,
+        voxel_size=(2 - upsample,) * data_dims,
+        axis_names=axis_names,
+        dtype=labels.dtype,
+        mode="w",
+    )
+    labels_array[:] = labels
+
+    return SimpleDataSplitConfig(name="test_data", path=tmpdir / "test_data.zarr")
+
+
+def build_test_task_config(task, data_dims: int, architecture_dims: int):
+    """
+    Build the simplest task config given the parameters.
+    """
+    if task == "distance":
+        return DistanceTaskConfig(
+            name="test_distance_task",
+            channels=["fg"],
+            clip_distance=4,
+            tol_distance=4,
+            scale_factor=8,
+        )
+    if task == "onehot":
+        return OneHotTaskConfig(
+            name="test_onehot_task", classes=["bg", "fg"], kernel_size=1
+        )
+    if task == "affs":
+        # TODO: should configs be able to take any sequence for the neighborhood?
+        if data_dims == 2:
+            # 2D
+            neighborhood = [Coordinate(1, 0), Coordinate(0, 1)]
+        elif data_dims == 3 and architecture_dims == 2:
+            # 3D but only generate 2D affs
+            neighborhood = [Coordinate(0, 1, 0), Coordinate(0, 0, 1)]
+        elif data_dims == 3 and architecture_dims == 3:
+            # 3D
+            neighborhood = [
+                Coordinate(1, 0, 0),
+                Coordinate(0, 1, 0),
+                Coordinate(0, 0, 1),
+            ]
+        return AffinitiesTaskConfig(name="test_affs_task", neighborhood=neighborhood)
+
+
+def build_test_architecture_config(
+    data_dims: int,
+    architecture_dims: int,
+    channels: bool,
+    batch_norm: bool,
+    upsample: bool,
+    use_attention: bool,
+    padding: str,
+):
+    """
+    Build the simplest architecture config given the parameters.
+ """ + if data_dims == 2: + input_shape = (32, 32) + eval_shape_increase = (8, 8) + downsample_factors = [(2, 2)] + upsample_factors = [(2, 2)] * int(upsample) + + kernel_size_down = [[(3, 3)] * 2] * 2 + kernel_size_up = [[(3, 3)] * 2] * 1 + kernel_size_down = None # the default should work + kernel_size_up = None # the default should work + + elif data_dims == 3 and architecture_dims == 2: + input_shape = (1, 32, 32) + eval_shape_increase = (15, 8, 8) + downsample_factors = [(1, 2, 2)] + + # test data upsamples in all dimensions so we have + # to here too + upsample_factors = [(2, 2, 2)] * int(upsample) + + # we have to force the 3D kernels to be 2D + kernel_size_down = [[(1, 3, 3)] * 2] * 2 + kernel_size_up = [[(1, 3, 3)] * 2] * 1 + + elif data_dims == 3 and architecture_dims == 3: + input_shape = (32, 32, 32) + eval_shape_increase = (8, 8, 8) + downsample_factors = [(2, 2, 2)] + upsample_factors = [(2, 2, 2)] * int(upsample) + + kernel_size_down = [[(3, 3, 3)] * 2] * 2 + kernel_size_up = [[(3, 3, 3)] * 2] * 1 + kernel_size_down = None # the default should work + kernel_size_up = None # the default should work + + return CNNectomeUNetConfig( + name="test_cnnectome_unet", + input_shape=input_shape, + eval_shape_increase=eval_shape_increase, + fmaps_in=1 + channels, + num_fmaps=2, + fmaps_out=2, + fmap_inc_factor=2, + downsample_factors=downsample_factors, + kernel_size_down=kernel_size_down, + kernel_size_up=kernel_size_up, + constant_upsample=True, + upsample_factors=upsample_factors, + batch_norm=batch_norm, + use_attention=use_attention, + padding=padding, + ) diff --git a/tests/operations/test_architecture.py b/tests/operations/test_architecture.py new file mode 100644 index 000000000..e3e569a4b --- /dev/null +++ b/tests/operations/test_architecture.py @@ -0,0 +1,84 @@ +from ..fixtures import * + +import pytest +from pytest_lazy_fixtures import lf +import torch.nn as nn +from dacapo.experiments import Run +import logging + +logging.basicConfig(level=logging.INFO) + + +@pytest.mark.parametrize( + "architecture_config", + [ + lf("dummy_architecture"), + lf("unet_architecture"), + ], +) +def test_architecture( + architecture_config, +): + architecture = architecture_config.architecture_type(architecture_config) + assert architecture.dims is not None, f"Architecture dims are None {architecture}" + + +@pytest.mark.parametrize( + "architecture_config", + [ + lf("dummy_architecture"), + lf("unet_architecture"), + lf("unet_3d_architecture"), + ], +) +def test_stored_architecture( + architecture_config, +): + from dacapo.store.create_store import create_config_store + + config_store = create_config_store() + try: + config_store.store_architecture_config(architecture_config) + except: + config_store.delete_architecture_config(architecture_config.name) + config_store.store_architecture_config(architecture_config) + + retrieved_arch_config = config_store.retrieve_architecture_config( + architecture_config.name + ) + + architecture = retrieved_arch_config.architecture_type(retrieved_arch_config) + + assert architecture.dims is not None, f"Architecture dims are None {architecture}" + + +@pytest.mark.parametrize( + "architecture_config", + [ + lf("unet_3d_architecture"), + lf("unet_architecture"), + ], +) +def test_conv_dims( + architecture_config, +): + architecture = architecture_config.architecture_type(architecture_config) + for name, module in architecture.named_modules(): + if isinstance(module, nn.Conv2d): + raise ValueError(f"Conv2d found in 3d unet {name}") + + +@pytest.mark.parametrize( + 
"run_config", + [ + lf("unet_3d_distance_run"), + ], +) +def test_3d_conv_unet_in_run( + run_config, +): + run = Run(run_config) + model = run.model + for name, module in model.named_modules(): + if isinstance(module, nn.Conv2d): + raise ValueError(f"Conv2d found in 3d unet {name}") diff --git a/tests/operations/test_context.py b/tests/operations/test_context.py new file mode 100644 index 000000000..b2924e721 --- /dev/null +++ b/tests/operations/test_context.py @@ -0,0 +1,22 @@ +import torch +from dacapo.compute_context import create_compute_context +import pytest + + +@pytest.mark.parametrize("device", [""]) +def test_create_compute_context(device): + compute_context = create_compute_context() + assert compute_context is not None + assert compute_context.device is not None + if torch.cuda.is_available(): + assert compute_context.device == torch.device( + "cuda" + ), "Model is not on CUDA when CUDA is available {}".format( + compute_context.device + ) + else: + assert compute_context.device == torch.device( + "cpu" + ), "Model is not on CPU when CUDA is not available {}".format( + compute_context.device + ) diff --git a/tests/operations/test_mini.py b/tests/operations/test_mini.py new file mode 100644 index 000000000..f50705538 --- /dev/null +++ b/tests/operations/test_mini.py @@ -0,0 +1,85 @@ +from ..fixtures import * +from .helpers import ( + build_test_data_config, + build_test_task_config, + build_test_architecture_config, +) + +from dacapo.experiments import Run +from dacapo.train import train_run +from dacapo.validate import validate_run + +import pytest +from pytest_lazy_fixtures import lf + +from dacapo.experiments.run_config import RunConfig + +import pytest + + +# TODO: Move unet parameters that don't affect interaction with other modules +# to a separate architcture test +@pytest.mark.parametrize("data_dims", [2, 3]) +@pytest.mark.parametrize("channels", [True, False]) +@pytest.mark.parametrize("task", ["distance", "onehot", "affs"]) +@pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")]) +@pytest.mark.parametrize("architecture_dims", [2, 3]) +@pytest.mark.parametrize("upsample", [True, False]) +# @pytest.mark.parametrize("batch_norm", [True, False]) +@pytest.mark.parametrize("batch_norm", [False]) +# @pytest.mark.parametrize("use_attention", [True, False]) +@pytest.mark.parametrize("use_attention", [False]) +@pytest.mark.parametrize("padding", ["valid", "same"]) +@pytest.mark.parametrize("func", ["train", "validate"]) +def test_mini( + tmpdir, + data_dims, + channels, + task, + trainer, + architecture_dims, + batch_norm, + upsample, + use_attention, + padding, + func, +): + # Invalid configurations: + if data_dims == 2 and architecture_dims == 3: + # cannot train a 3D model on 2D data + # TODO: maybe check that an appropriate warning is raised somewhere + return + + data_config = build_test_data_config( + tmpdir, + data_dims, + channels, + upsample, + "instance" if task == "affs" else "semantic", + ) + task_config = build_test_task_config(task, data_dims, architecture_dims) + architecture_config = build_test_architecture_config( + data_dims, + architecture_dims, + channels, + batch_norm, + upsample, + use_attention, + padding, + ) + + run_config = RunConfig( + name=f"test_{func}", + task_config=task_config, + architecture_config=architecture_config, + trainer_config=trainer, + datasplit_config=data_config, + repetition=0, + num_iterations=1, + ) + run = Run(run_config) + + if func == "train": + train_run(run) + elif func == "validate": + validate_run(run, 1) diff --git 
a/tests/operations/test_train.py b/tests/operations/test_train.py index a852101be..ad45b848c 100644 --- a/tests/operations/test_train.py +++ b/tests/operations/test_train.py @@ -9,9 +9,7 @@ import pytest from pytest_lazy_fixtures import lf -import logging - -logging.basicConfig(level=logging.INFO) +import pytest # skip the test for the Apple Paravirtual device @@ -23,9 +21,10 @@ lf("distance_run"), lf("dummy_run"), lf("onehot_run"), + lf("hot_distance_run"), ], ) -def test_train( +def test_large( options, run_config, ): diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py index 860f941e9..8a4d8cf26 100644 --- a/tests/operations/test_validate.py +++ b/tests/operations/test_validate.py @@ -1,6 +1,3 @@ -import os -from upath import UPath as Path -import shutil from ..fixtures import * from dacapo.experiments import Run @@ -22,78 +19,22 @@ lf("onehot_run"), ], ) -def test_validate( +def test_large( options, run_config, ): - # set debug to True to run the test in a specific directory (for debugging) - debug = False - if debug: - tmp_path = f"{Path(__file__).parent}/tmp" - if os.path.exists(tmp_path): - shutil.rmtree(tmp_path, ignore_errors=True) - os.makedirs(tmp_path, exist_ok=True) - old_path = os.getcwd() - os.chdir(tmp_path) - # when done debugging, delete "tests/operations/tmp" - # ------------------------------------- store = create_config_store() + weights_store = create_weights_store() store.store_run_config(run_config) + # validate validate(run_config.name, 0) - # weights_store.store_weights(run, 1) - # validate_run(run_config.name, 1) + + # validate_run + run = Run(run_config) + weights_store.store_weights(run, 1) + validate_run(run, 1) # test validating weights that don't exist with pytest.raises(FileNotFoundError): validate(run_config.name, 2) - - if debug: - os.chdir(old_path) - - -@pytest.mark.parametrize( - "run_config", - [ - lf("distance_run"), - lf("onehot_run"), - ], -) -def test_validate_run( - options, - run_config, -): - # set debug to True to run the test in a specific directory (for debugging) - debug = False - if debug: - tmp_path = f"{Path(__file__).parent}/tmp" - if os.path.exists(tmp_path): - shutil.rmtree(tmp_path, ignore_errors=True) - os.makedirs(tmp_path, exist_ok=True) - old_path = os.getcwd() - os.chdir(tmp_path) - # when done debugging, delete "tests/operations/tmp" - # ------------------------------------- - - # create a store - - store = create_config_store() - weights_store = create_weights_store() - - # store the configs - - store.store_run_config(run_config) - - run_config = store.retrieve_run_config(run_config.name) - run = Run(run_config) - - # ------------------------------------- - - # validate - - # test validating iterations for which we know there are weights - weights_store.store_weights(run, 0) - validate_run(run, 0) - - if debug: - os.chdir(old_path)
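
Editor's note (not part of the patch): the task changes above add an explicit `kernel_size` for the prediction head, falling back to the legacy value of 3 with a DeprecationWarning when it is omitted. A minimal usage sketch, assuming the usual dacapo pattern where each config carries its `task_type`; the `example_one_hot`/`legacy_one_hot` names are hypothetical:

    import warnings
    from dacapo.experiments.tasks import OneHotTaskConfig

    # Explicit kernel_size: the prediction head becomes a 1x1(x1) convolution.
    config = OneHotTaskConfig(name="example_one_hot", classes=["a", "b"], kernel_size=1)
    task = config.task_type(config)

    # Omitted kernel_size: warns and falls back to the legacy size of 3.
    legacy = OneHotTaskConfig(name="legacy_one_hot", classes=["a", "b"])
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        legacy.task_type(legacy)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)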