diff --git a/src/nimbus_inference/utils.py b/src/nimbus_inference/utils.py index fc95a23..e22956d 100644 --- a/src/nimbus_inference/utils.py +++ b/src/nimbus_inference/utils.py @@ -120,6 +120,9 @@ def handle_qupath_segmentation_map(instance_mask: np.array): logging.warning("QuPath RGB segmentation map detected. Converting to instance map by") logging.warning("combining the RGB channels into a single channel via the following formula:") logging.warning("label = RED*256**2 + GREEN * 256 + BLUE") + # move channel axis to last if not already + if instance_mask.shape.index(3) == 0: + instance_mask = np.moveaxis(instance_mask, 0, -1) instance_mask_handled = instance_mask[..., 0] * 256**2 + instance_mask[..., 1] * 256 \ + instance_mask[..., 2] instance_mask_handled = instance_mask_handled.round(0).astype(np.uint64) @@ -263,10 +266,11 @@ def get_segmentation(self, fov: str): fov_path = self.fov_paths[idx] instance_path = self.segmentation_naming_convention(fov_path) if isinstance(instance_path, str): - instance_mask = np.squeeze(io.imread(instance_path)) + instance_mask = io.imread(instance_path) else: - instance_mask = np.squeeze(instance_path) - if len(instance_mask.shape) == 3 and instance_mask.shape[-1] == 3: + instance_mask = instance_path + instance_mask = np.squeeze(instance_mask) + if len(instance_mask.shape) == 3: instance_mask = handle_qupath_segmentation_map(instance_mask) instance_mask = instance_mask.astype(np.uint32) return instance_mask