diff --git a/stereo/image/__init__.py b/stereo/image/__init__.py index 87ee1c8d..e26a5b8b 100644 --- a/stereo/image/__init__.py +++ b/stereo/image/__init__.py @@ -4,7 +4,7 @@ try: from .pyramid import merge_pyramid, create_pyramid from .segmentation.segment import cell_seg - from .cellbin.modules.cell_segmentation import cell_seg_v3 + from .segmentation.seg_utils.v3 import CellSegPipeV3 as cell_seg_v3 from . import tissue_cut from .segmentation_deepcell.segment import cell_seg_deepcell diff --git a/stereo/image/cellbin/dnn/__init__.py b/stereo/image/cellbin/dnn/__init__.py deleted file mode 100644 index d4b9e427..00000000 --- a/stereo/image/cellbin/dnn/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Deep Neural Networks (dnn module) -# (Pytorch, TensorFlow) models with ONNX. -# In this section you will find the functions, which describe how to run classification, segmentation and detection -# DNN models with ONNX. - -from abc import ABC -from abc import abstractmethod - - -class BaseNet(ABC): - @abstractmethod - def _f_load_model(self): - return - - @abstractmethod - def f_predict(self, img): - return diff --git a/stereo/image/cellbin/dnn/cseg/__init__.py b/stereo/image/cellbin/dnn/cseg/__init__.py deleted file mode 100644 index ef003fa5..00000000 --- a/stereo/image/cellbin/dnn/cseg/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from abc import ABC -from abc import abstractmethod - - -class CellSegmentation(ABC): - @abstractmethod - def f_predict(self, img): - """ - input img output cell mask - :param img:CHANGE - :return: 掩模大图 - """ - return diff --git a/stereo/image/cellbin/dnn/cseg/cell_trace.py b/stereo/image/cellbin/dnn/cseg/cell_trace.py deleted file mode 100644 index d8c7e663..00000000 --- a/stereo/image/cellbin/dnn/cseg/cell_trace.py +++ /dev/null @@ -1,13 +0,0 @@ -import cv2 - - -def get_trace(mask): - num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8) - h, w = mask.shape[: 2] - output = [] - for i in range(num_labels): - box_w, box_h, area = stats[i][2:] - if box_h == h and box_w == w: - continue - output.append([box_h, box_w, area]) - return output diff --git a/stereo/image/cellbin/dnn/cseg/detector.py b/stereo/image/cellbin/dnn/cseg/detector.py deleted file mode 100644 index 56575495..00000000 --- a/stereo/image/cellbin/dnn/cseg/detector.py +++ /dev/null @@ -1,71 +0,0 @@ -import numpy as np - -from stereo.image.cellbin.dnn.onnx_net import OnnxNet -from stereo.image.cellbin.image.wsi_split import SplitWSI -from . import CellSegmentation -from .predict import CellPredict -from .processing import ( - f_prepocess, - f_preformat, - f_postformat, - f_preformat_mesmer, - f_postformat_mesmer, - f_padding, - f_fusion -) - - -# TensorRT/ONNX -# HE/DAPI/mIF -class Segmentation(CellSegmentation): - - def __init__(self, model_path="", net="bcdu", mode="onnx", gpu="-1", num_threads=0, - win_size=(256, 256), intput_size=(256, 256, 1), overlap=16): - """ - - :param model_path: - :param net: - :param mode: - :param gpu: - :param num_threads: - """ - # self.PREPROCESS_SIZE = (8192, 8192) - - self._win_size = win_size - self._input_size = intput_size - self._overlap = overlap - - self._net = net - self._gpu = gpu - self._mode = mode - # self._model_path = model_path - self._model = None - self._sess = None - self._num_threads = num_threads - # self._f_init_model() - - def f_init_model(self, model_path): - """ - init model - """ - self._model = OnnxNet(model_path, self._gpu, self._num_threads) - - if self._net == "mesmer": - self._sess = CellPredict(self._model, f_preformat_mesmer, f_postformat_mesmer) - else: - self._sess = CellPredict(self._model, f_preformat, f_postformat) - - def f_predict(self, img): - """ - - :param img:CHANGE - :return: - """ - img = f_prepocess(img) - sp_run = SplitWSI(img, self._win_size, self._overlap, 100, True, True, False, np.uint8) - sp_run.f_set_run_fun(self._sess.f_predict) - sp_run.f_set_pre_fun(f_padding, self._win_size) - sp_run.f_set_fusion_fun(f_fusion) - _, _, pred = sp_run.f_split2run() - pred[pred > 0] = 1 - return pred diff --git a/stereo/image/cellbin/dnn/cseg/predict.py b/stereo/image/cellbin/dnn/cseg/predict.py deleted file mode 100644 index abef101e..00000000 --- a/stereo/image/cellbin/dnn/cseg/predict.py +++ /dev/null @@ -1,63 +0,0 @@ -from queue import Queue -from threading import Event -from threading import Thread - -import numpy as np - - -class CellPredict(object): - def __init__(self, model, f_preformat, f_postformat): - self._model = model - self._f_preformat = f_preformat - self._f_postformat = f_postformat - self._t_queue_maxsize = 100 - self._t_workdone = Event() - self._t_queue = Queue(maxsize=self._t_queue_maxsize) - - def _f_productor(self, img_lst): - self._t_workdone.set() - for img in img_lst: - val_sum = np.sum(img) - if val_sum <= 0.0: - pred = np.zeros(img.shape, np.uint8) - else: - pred = self._model.f_predict(self._f_preformat(img)) - self._t_queue.put([pred, val_sum], block=True) - self._t_workdone.clear() - return - - def _f_consumer(self, pred_lst): - while (self._t_workdone.is_set()) or (not self._t_queue.empty()): - pred, val_sum = self._t_queue.get(block=True) - if val_sum > 0: - pred = self._f_postformat(pred) - pred_lst.append(pred) - return - - def _f_clear(self): - self._t_queue = Queue(maxsize=self._t_queue_maxsize) - - def _run_batch(self, img_lst): - self._f_clear() - pred_lst = [] - t_productor = Thread(target=self._f_productor, args=(img_lst,)) - t_consumer = Thread(target=self._f_consumer, args=(pred_lst,)) - t_productor.start() - t_consumer.start() - t_productor.join() - t_consumer.join() - self._f_clear() - return pred_lst - - def f_predict(self, img_lst): - img = img_lst - - if isinstance(img_lst, list): - return self._run_batch(img_lst) - - if np.sum(img) < 1: - pred = np.zeros(img.shape, np.uint8) - else: - pred = self._model.f_predict(self._f_preformat(img)) - pred = self._f_postformat(pred) - return pred diff --git a/stereo/image/cellbin/dnn/cseg/processing.py b/stereo/image/cellbin/dnn/cseg/processing.py deleted file mode 100644 index cd38757a..00000000 --- a/stereo/image/cellbin/dnn/cseg/processing.py +++ /dev/null @@ -1,86 +0,0 @@ -import cv2 -import numpy as np - -from stereo.image.cellbin.image.augmentation import f_padding as f_pad -from stereo.image.cellbin.image.augmentation import ( - f_percentile_threshold, - f_histogram_normalization, - f_equalize_adapthist, - f_rgb2gray, - f_ij_16_to_8 -) -from stereo.image.cellbin.image.mask import f_instance2semantics -from stereo.image.cellbin.image.morphology import f_deep_watershed - - -def f_prepocess(img): - img = np.squeeze(img) - img = f_ij_16_to_8(img) - img = f_rgb2gray(img, True) - img = f_percentile_threshold(img) - img = f_equalize_adapthist(img, 128) - img = f_histogram_normalization(img) - img = np.array(img).astype(np.float32) - img = np.ascontiguousarray(img) - return img - - -def f_postpocess(pred): - pred = pred[0, :, :, 0] - - pred = f_instance2semantics(pred) - return pred - - -def f_preformat(img): - img = np.expand_dims(img, axis=2) - img = np.expand_dims(img, axis=0) - return img - - -def f_postformat(pred): - if not isinstance(pred, list): - pred = [pred] - pred = f_deep_watershed(pred, - maxima_threshold=0.1, - maxima_smooth=0, - interior_threshold=0.2, - interior_smooth=2, - fill_holes_threshold=15, - small_objects_threshold=15, - radius=10, - watershed_line=0) - return f_postpocess(pred) - - -def f_preformat_mesmer(img): - img = np.stack((img, img), axis=-1) - img = np.expand_dims(img, axis=0) - return img - - -def f_postformat_mesmer(pred): - if isinstance(pred, list): - pred = [pred[0], pred[1][..., 1:2]] - pred = f_deep_watershed(pred, - maxima_threshold=0.075, - maxima_smooth=0, - interior_threshold=0.2, - interior_smooth=2, - small_objects_threshold=15, - fill_holes_threshold=15, - radius=2, - watershed_line=0) - return f_postpocess(pred) - - -def f_padding(img, shape, mode='constant'): - h, w = img.shape[:2] - win_h, win_w = shape[:2] - img = f_pad(img, 0, abs(win_h - h), 0, abs(win_w - w), mode) - return img - - -def f_fusion(img1, img2): - img1 = cv2.bitwise_or(img1, img2) - return img1 diff --git a/stereo/image/cellbin/dnn/onnx_net.py b/stereo/image/cellbin/dnn/onnx_net.py deleted file mode 100644 index 1e37bb02..00000000 --- a/stereo/image/cellbin/dnn/onnx_net.py +++ /dev/null @@ -1,44 +0,0 @@ -from os import path - -import onnxruntime - -from . import BaseNet - - -class OnnxNet(BaseNet): - def __init__(self, model_path, gpu="-1", num_threads=0): - super(OnnxNet, self).__init__() - self._providers = ['CPUExecutionProvider'] - self._providers_id = [{'device_id': -1}] - self._model = None - self._gpu = int(gpu) - self._model_path = model_path - self._input_name = 'input_1' - self._output_name = None - self._num_threads = num_threads - self._f_init() - - def _f_init(self): - if self._gpu > -1: - self._providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] - self._providers_id = [{'device_id': self._gpu}, {'device_id': -1}] - self._f_load_model() - - def _f_load_model(self): - if path.exists(self._model_path): - sessionOptions = onnxruntime.SessionOptions() - if (self._gpu < 0) and (self._num_threads > 0): - sessionOptions.intra_op_num_threads = self._num_threads - self._model = onnxruntime.InferenceSession( - self._model_path, - providers=self._providers, - provider_options=self._providers_id, - sess_options=sessionOptions - ) - self._input_name = self._model.get_inputs()[0].name - else: - raise Exception(f"Weight path '{self._model_path}' does not exist") - - def f_predict(self, data): - pred = self._model.run(self._output_name, {self._input_name: data}) - return pred diff --git a/stereo/image/cellbin/image/__init__.py b/stereo/image/cellbin/image/__init__.py deleted file mode 100644 index 1e33d235..00000000 --- a/stereo/image/cellbin/image/__init__.py +++ /dev/null @@ -1,90 +0,0 @@ -import os - -import cv2 -import numpy as np -import tifffile - - -class Image(object): - def __init__(self): - self.suffix: str = '' - self.image = None - self.channel: int = 1 - self.dtype = None - self.width: int = 0 - self.height: int = 0 - self.depth: int = 8 - self.ndim = 1 - - def read(self, image, buffer=None): - """ - update by dengzhonghan on 2023/3/1 - - support zeiss 2 channel image (channel at first) - - support get specific channel - - Args: - image (): image path in string format or image in numpy array format - - Returns: - 1: Fail - 0: Success - - """ - if type(image) is str: - self.suffix = os.path.splitext(image)[1] - if self.suffix in ['.tif', '.tiff']: - self.image = tifffile.imread(image) # 3 channel is RGB?? - elif self.suffix in ['.png']: - self.image = cv2.imread(image, -1) - else: - return 1 - elif type(image) is np.ndarray: - self.image = image - elif type(image) is list and len(image) == 4: - assert buffer is not None - y0, y1, x0, x1 = image - if buffer.ndim == 3: - self.image = buffer[y0: y1, x0: x1, :] - else: - self.image = buffer[y0: y1, x0: x1] - else: - return 1 - if self.image is None or len(self.image) == 0: - raise Exception(f"Reading {image} error!") - self.ndim = self.image.ndim - self.dtype = self.image.dtype - if self.dtype == 'uint8': - self.depth = 8 - elif self.dtype == 'uint16': - self.depth = 16 - - if self.ndim == 3: - shape = self.image.shape - if shape[0] in [1, 2, 3, 4]: - self.image = self.image.transpose(1, 2, 0) - self.height, self.width, self.channel = self.image.shape - else: - self.height, self.width = self.image.shape - self.channel = 1 - - return 0 - - @staticmethod - def write_s(image, output_path: str, compression=False): - try: - if compression: - tifffile.imwrite(output_path, image, compression="zlib", compressionargs={"level": 8}) - else: - tifffile.imwrite(output_path, image) - except Exception as e: - print(e) - print("Write image has some error, will write without compression.") - tifffile.imwrite(output_path, image) - - def get_channel(self, ch): - if self.channel == 1 or ch == -1: - return - else: - self.image = np.array(self.image[:, :, ch]) # cv circle raise error if no np.array - self.channel = 1 - return diff --git a/stereo/image/cellbin/image/augmentation.py b/stereo/image/cellbin/image/augmentation.py deleted file mode 100644 index 8bbe6acd..00000000 --- a/stereo/image/cellbin/image/augmentation.py +++ /dev/null @@ -1,273 +0,0 @@ -import copy - -import cv2 -import numpy as np -from PIL import Image -from skimage.exposure import equalize_adapthist -from skimage.exposure import rescale_intensity - - -def f_rgb2gray(img, need_not=False): - """ - rgb2gray - - :param img: (CHANGE) np.array - :param need_not: if need bitwise_not - :return: np.array - """ - if img.ndim == 3: - if img.shape[0] == 3 and img.shape[1] > 3 and img.shape[2] > 3: - img = img.transpose(1, 2, 0) - img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) - if need_not: - img = cv2.bitwise_not(img) - return img - - -def f_gray2bgr(img): - """ - gray2bgr - - :param img: (CHANGE) np.array - :return: np.array - """ - - img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - return img - - -def f_padding(img, top, bot, left, right, mode='constant', value=0): - """ - update by dengzhonghan on 2023/2/23 - 1. support 3d array padding. - 2. not support 1d array padding. - - Args: - img (): numpy ndarray (2D or 3D). - top (): number of values padded to the top direction. - bot (): number of values padded to the bottom direction. - left (): number of values padded to the left direction. - right (): number of values padded to the right direction. - mode (): padding mode in numpy, default is constant. - value (): constant value when using constant mode, default is 0. - - Returns: - pad_img: padded image. - - """ - - if mode == 'constant': - if img.ndim == 2: - pad_img = np.pad(img, ((top, bot), (left, right)), mode, constant_values=value) - elif img.ndim == 3: - pad_img = np.pad(img, ((top, bot), (left, right), (0, 0)), mode, constant_values=value) - else: - if img.ndim == 2: - pad_img = np.pad(img, ((top, bot), (left, right)), mode) - elif img.ndim == 3: - pad_img = np.pad(img, ((top, bot), (left, right), (0, 0)), mode) - return pad_img - - -def f_resize(img, shape=(1024, 2048), mode="NEAREST"): - """ - resize img with pillow - - :param img: (CHANGE) np.array - :param shape: tuple - :param mode: An optional resampling filter. This can be one of Resampling.NEAREST, - Resampling.BOX, Resampling.BILINEAR, Resampling.HAMMING, Resampling.BICUBIC or Resampling.LANCZOS. - If the image has mode “1” or “P”, it is always set to Resampling.NEAREST. - If the image mode specifies a number of bits, such as “I;16”, then the default filter is Resampling.NEAREST. - Otherwise, the default filter is Resampling.BICUBIC - :return:np.array - """ - imode = Image.NEAREST - if mode == "BILINEAR": - imode = Image.BILINEAR - elif mode == "BICUBIC": - imode = Image.BICUBIC - elif mode == "LANCZOS": - imode = Image.LANCZOS - elif mode == "HAMMING": - imode = Image.HAMMING - elif mode == "BOX": - imode = Image.BOX - if img.dtype != 'uint8': - imode = Image.NEAREST - img = Image.fromarray(img) - img = img.resize((shape[1], shape[0]), resample=imode) - img = np.array(img).astype(np.uint8) - return img - - -def f_percentile_threshold(img, percentile=99.9): - """ - Threshold an image to reduce bright spots - - :param img: (CHANGE) numpy array of image data - :param percentile: cutoff used to threshold image - :return: np.array: thresholded version of input image - """ - - non_zero_vals = img[img > 0] - - # only threshold if channel isn't blank - if len(non_zero_vals) > 0: - img_max = np.percentile(non_zero_vals, percentile) - - # threshold values down to max - threshold_mask = img > img_max - img[threshold_mask] = img_max - - return img - - -def f_equalize_adapthist(img, kernel_size=None): - """ - Pre-process images using Contrast Limited Adaptive - Histogram Equalization (CLAHE). - - :param img: (CHANGE) (numpy.array): numpy array of phase image data. - :param kernel_size: (integer): Size of kernel for CLAHE, - defaults to 1/8 of image size. - :return: numpy.array:Pre-processed image - """ - return equalize_adapthist(img, kernel_size=kernel_size) - - -def f_histogram_normalization(img): - """ - If one of the inputs is a constant-value array, it will - be normalized as an array of all zeros of the same shape. - - :param img: (CHANGE) (numpy.array): numpy array of phase image data. - :return: numpy.array:image data with dtype float32. - """ - - img = img.astype('float32') - sample_value = img[(0,) * img.ndim] - if (img == sample_value).all(): - return np.zeros_like(img) - img = rescale_intensity(img, out_range=(0.0, 1.0)) - - return img - - -def f_ij_16_to_8(img, chunk_size=1000): - """ - 16 bits img to 8 bits - - :param img: (CHANGE) np.array - :param chunk_size: chunk size (bit) - :return: np.array - """ - - if img.dtype == 'uint8': - return img - dst = np.zeros(img.shape, np.uint8) - p_max = np.max(img) - p_min = np.min(img) - scale = 256.0 / (p_max - p_min + 1) - for idx in range(img.shape[0] // chunk_size + 1): - sl = slice(idx * chunk_size, (idx + 1) * chunk_size) - win_img = copy.deepcopy(img[sl]) - win_img = np.int16(win_img) - win_img = (win_img & 0xffff) - win_img = win_img - p_min - win_img[win_img < 0] = 0 - win_img = win_img * scale + 0.5 - win_img[win_img > 255] = 255 - dst[sl] = np.array(win_img).astype(np.uint8) - return dst - - -def enhance(arr, mode, thresh): - """ - Only support 2D array - - Args: - arr (): 2D numpy array - mode (): enhance mode - thresh (): threshold - - Returns: - - """ - data = arr.ravel() - min_v = np.min(data) - data_ = data[np.where(data <= thresh)] - if len(data_) == 0: - return 0, 0 - if mode == 'median': - var_ = np.std(data_) - thr = np.median(data_) - max_v = thr + var_ - elif mode == 'hist': - freq_count, bins = np.histogram(data_, range(min_v, int(thresh + 1))) - count = np.sum(freq_count) - freq = freq_count / count - thr = bins[np.argmax(freq)] - max_v = thr + (thr - min_v) - else: - raise Exception('Only support median and histogram') - - return min_v, max_v - - -def encode(arr, min_v, max_v): - """ - Encode image with min and max pixel value - - Args: - arr (): 2D numpy array - min_v (): min value obtained from enhance method - max_v (): max value - - Returns: - mat: encoded mat - - """ - if min_v >= max_v: - arr = arr.astype(np.uint8) - return arr - mat = np.zeros((arr.shape[0], arr.shape[1]), dtype=np.uint8) - v_w = max_v - min_v - mat[arr < min_v] = 0 - mat[arr > max_v] = 255 - pos = (arr >= min_v) & (arr <= max_v) - mat[pos] = (arr[pos] - min_v) * (255 / v_w) - return mat - - -def f_ij_auto_contrast(img): - limit = img.size / 10 - threshold = img.size / 5000 - if img.dtype != 'uint8': - bit_max = 65536 - else: - bit_max = 256 - hist, _ = np.histogram(img.flatten(), 256, [0, bit_max]) - hmin = 0 - hmax = bit_max - 1 - for i in range(1, len(hist) - 1): - count = hist[i] - if count > limit: - continue - if count > threshold: - hmin = i - break - for i in range(len(hist) - 2, 0, -1): - count = hist[i] - if count > limit: - continue - if count > threshold: - hmax = i - break - if hmax > hmin: - hmax = int(hmax * bit_max / 256) - hmin = int(hmin * bit_max / 256) - img[img < hmin] = hmin - img[img > hmax] = hmax - cv2.normalize(img, img, 0, bit_max - 1, cv2.NORM_MINMAX) - return img diff --git a/stereo/image/cellbin/image/mask.py b/stereo/image/cellbin/image/mask.py deleted file mode 100644 index 68369963..00000000 --- a/stereo/image/cellbin/image/mask.py +++ /dev/null @@ -1,68 +0,0 @@ -import cv2 -import numpy as np - - -def f_fill_all_hole(mask_in): - """ - fill all holes in the mask - - :param mask_in: np.array np.uint8 - :return: np.array np.uint8 - """ - ''' 对二值图像进行孔洞填充 ''' - im_floodfill = cv2.copyMakeBorder(mask_in, 2, 2, 2, 2, cv2.BORDER_CONSTANT, value=[0]) - h, w = im_floodfill.shape[:2] - mask = np.zeros((h + 2, w + 2), np.uint8) - - cv2.floodFill(im_floodfill, mask, (0, 0), 255) - im_floodfill_inv = cv2.bitwise_not(im_floodfill[2:-2, 2:-2]) - - # Combine the two images to get the foreground. - return mask_in | im_floodfill_inv - - -def f_instance2semantics(ins): - """ - update by cenweixuan on 2023/3/07 - :param ins: - :return: - """ - h, w = ins.shape[:2] - tmp0 = ins[1:, 1:] - ins[:h - 1, :w - 1] - ind0 = np.where(tmp0 != 0) - - tmp1 = ins[1:, :w - 1] - ins[:h - 1, 1:] - ind1 = np.where(tmp1 != 0) - ins[ind1] = 0 - ins[ind0] = 0 - ins[np.where(ins > 0)] = 1 - return np.array(ins, dtype=np.uint8) - - -def iou(a, b, epsilon=1e-5): - """ - add by jqc on 2023/04/10 - Args: - a (): - b (): - epsilon (): - - Returns: - - """ - # 首先将a和b按照0/1的方式量化 - a = (a > 0).astype(int) - b = (b > 0).astype(int) - - # 计算交集(intersection) - intersection = np.logical_and(a, b) - intersection = np.sum(intersection) - - # 计算并集(union) - union = np.logical_or(a, b) - union = np.sum(union) - - # 计算IoU - iou = intersection / (union + epsilon) - - return iou diff --git a/stereo/image/cellbin/image/morphology.py b/stereo/image/cellbin/image/morphology.py deleted file mode 100644 index a5e01db1..00000000 --- a/stereo/image/cellbin/image/morphology.py +++ /dev/null @@ -1,238 +0,0 @@ -"""Functions for pre- and post-processing image data""" -from _warnings import warn - -import numpy as np -import scipy.ndimage as nd -from skimage.feature import peak_local_max -from skimage.measure import label -from skimage.measure import regionprops -from skimage.morphology import ( - disk, - ball, - square, - cube, - dilation -) -from skimage.morphology import h_maxima -from skimage.morphology import remove_small_holes -from skimage.morphology import remove_small_objects -from skimage.segmentation import find_boundaries -from skimage.segmentation import relabel_sequential -from skimage.segmentation import watershed - - -def f_deep_watershed(outputs, - radius=10, - maxima_threshold=0.1, - interior_threshold=0.01, - maxima_smooth=0, - interior_smooth=1, - maxima_index=0, - interior_index=-1, - label_erosion=0, - small_objects_threshold=0, - fill_holes_threshold=0, - pixel_expansion=None, - watershed_line=1, - maxima_algorithm='h_maxima', - **kwargs): - """ - Uses ``maximas`` and ``interiors`` to perform watershed segmentation. - ``maximas`` are used as the watershed seeds for each object and - ``interiors`` are used as the watershed mask. - - :param outputs:(list): List of [maximas, interiors] model outputs. - Use `maxima_index` and `interior_index` if list is longer than 2, - or if the outputs are in a different order. - :param radius: (int): Radius of disk used to search for maxima - :param maxima_threshold:(float): Threshold for the maxima prediction. - :param interior_threshold:(float): Threshold for the interior prediction. - :param maxima_smooth:(int): smoothing factor to apply to ``interiors``. - Use ``0`` for no smoothing. - :param interior_smooth:(int): smoothing factor to apply to ``interiors``. - Use ``0`` for no smoothing. - :param maxima_index:(int): The index of the maxima prediction in ``outputs``. - :param interior_index:(int): The index of the interior prediction in ``outputs``. - :param label_erosion:(int): Number of pixels to erode segmentation labels. - :param small_objects_threshold:(int): Removes objects smaller than this size. - :param fill_holes_threshold:(int): Maximum size for holes within segmented - objects to be filled. - :param pixel_expansion:(int): Number of pixels to expand ``interiors``. - :param watershed_line:(int): If need watershed line. - :param maxima_algorithm:(str): Algorithm used to locate peaks in ``maximas``. - One of ``h_maxima`` (default) or ``peak_local_max``. - ``peak_local_max`` is much faster but seems to underperform when - given regious of ambiguous maxima. - :param kwargs: - :return:numpy.array: Integer label mask for instance segmentation. - - Raises: - ValueError: ``outputs`` is not properly formatted. - """ - - try: - maximas = outputs[maxima_index] - interiors = outputs[interior_index] - except (TypeError, KeyError, IndexError): - raise ValueError('`outputs` should be a list of at least two ' - 'NumPy arryas of equal shape.') - - valid_algos = {'h_maxima', 'peak_local_max'} - if maxima_algorithm not in valid_algos: - raise ValueError('Invalid value for maxima_algorithm: {}. ' - 'Must be one of {}'.format(maxima_algorithm, valid_algos)) - - total_pixels = maximas.shape[1] * maximas.shape[2] - if maxima_algorithm == 'h_maxima' and total_pixels > 5000 ** 2: - print('h_maxima peak finding algorithm was selected, ' - 'but the provided image is larger than 5k x 5k pixels.' - 'This will lead to slow prediction performance.') - # Handle deprecated arguments - min_distance = kwargs.pop('min_distance', None) - if min_distance is not None: - radius = min_distance - warn('`min_distance` is now deprecated in favor of `radius`. ' - 'The value passed for `radius` will be used.') - - # distance_threshold vs interior_threshold - distance_threshold = kwargs.pop('distance_threshold', None) - if distance_threshold is not None: - interior_threshold = distance_threshold - warn('`distance_threshold` is now deprecated in favor of ' - '`interior_threshold`. The value passed for ' - '`distance_threshold` will be used.', - DeprecationWarning) - - # detection_threshold vs maxima_threshold - detection_threshold = kwargs.pop('detection_threshold', None) - if detection_threshold is not None: - maxima_threshold = detection_threshold - warn('`detection_threshold` is now deprecated in favor of ' - '`maxima_threshold`. The value passed for ' - '`detection_threshold` will be used.', - DeprecationWarning) - - if maximas.shape[:-1] != interiors.shape[:-1]: - raise ValueError( - 'All input arrays must have the same shape. Got {} and {}'.format(maximas.shape, interiors.shape)) - - if maximas.ndim not in {4, 5}: - raise ValueError('maxima and interior tensors must be rank 4 or 5. ' - 'Rank 4 is 2D data of shape (batch, x, y, c). ' - 'Rank 5 is 3D data of shape (batch, frames, x, y, c).') - - input_is_3d = maximas.ndim > 4 - - # fill_holes is not supported in 3D - if fill_holes_threshold and input_is_3d: - warn('`fill_holes` is not supported for 3D data.') - fill_holes_threshold = 0 - - label_images = [] - for maxima, interior in zip(maximas, interiors): - # squeeze out the channel dimension if passed - maxima = nd.gaussian_filter(maxima[..., 0], maxima_smooth) - interior = nd.gaussian_filter(interior[..., 0], interior_smooth) - - if pixel_expansion: - fn = cube if input_is_3d else square - interior = dilation(interior, selem=fn(pixel_expansion * 2 + 1)) - - # peak_local_max is much faster but has poorer performance - # when dealing with more ambiguous local maxima - if maxima_algorithm == 'peak_local_max': - coords = peak_local_max( - maxima, - min_distance=radius, - threshold_abs=maxima_threshold, - exclude_border=kwargs.get('exclude_border', False)) - - markers = np.zeros_like(maxima) - slc = tuple(coords[:, i] for i in range(coords.shape[1])) - markers[slc] = 1 - else: - # Find peaks and merge equal regions - fn = ball if input_is_3d else disk - markers = h_maxima(image=maxima, - h=maxima_threshold, - footprint=fn(radius)) - - markers = label(markers) - label_image = watershed(-1 * interior, markers, - mask=interior > interior_threshold, - watershed_line=watershed_line) - - if label_erosion: - label_image = f_erode_edges(label_image, label_erosion) - - # Remove small objects - if small_objects_threshold: - label_image = remove_small_objects(label_image, - min_size=small_objects_threshold) - - # fill in holes that lie completely within a segmentation label - if fill_holes_threshold > 0: - label_image = f_fill_holes(label_image, size=fill_holes_threshold) - - # Relabel the label image - label_image, _, _ = relabel_sequential(label_image) - - label_images.append(label_image) - - label_images = np.stack(label_images, axis=0) - label_images = np.expand_dims(label_images, axis=-1) - - return label_images - - -def f_erode_edges(mask, erosion_width): - """ - Erode edge of objects to prevent them from touching - - :param mask: (numpy.array): uniquely labeled instance mask - :param erosion_width: erosion_width (int): integer value for pixel width to erode edges - :return: numpy.array: mask where each instance has had the edges eroded - - Raises: - ValueError: mask.ndim is not 2 or 3 - """ - - if mask.ndim not in {2, 3}: - raise ValueError('erode_edges expects arrays of ndim 2 or 3.' - 'Got ndim: {}'.format(mask.ndim)) - if erosion_width: - new_mask = np.copy(mask) - for _ in range(erosion_width): - boundaries = find_boundaries(new_mask, mode='inner') - new_mask[boundaries > 0] = 0 - return new_mask - - return mask - - -def f_fill_holes(label_img, size=10, connectivity=1): - """ - Fills holes located completely within a given label with pixels of the same value - - :param label_img: (numpy.array): a 2D labeled image - :param size: (int): maximum size for a hole to be filled in - :param connectivity: (int): the connectivity used to define the hole - :return:numpy.array: a labeled image with no holes smaller than ``size`` - contained within any label. - """ - - output_image = np.copy(label_img) - - props = regionprops(np.squeeze(label_img.astype('int')), cache=False) - for prop in props: - if prop.euler_number < 1: - patch = output_image[prop.slice] - - filled = remove_small_holes( - ar=(patch == prop.label), - area_threshold=size, - connectivity=connectivity) - - output_image[prop.slice] = np.where(filled, prop.label, patch) - - return output_image diff --git a/stereo/image/cellbin/image/wsi_split.py b/stereo/image/cellbin/image/wsi_split.py deleted file mode 100644 index ae5aa31c..00000000 --- a/stereo/image/cellbin/image/wsi_split.py +++ /dev/null @@ -1,160 +0,0 @@ -from math import ceil - -import numpy as np -from tqdm import tqdm - - -class SplitWSI(object): - def __init__( - self, - img, - win_shape, - overlap=0, - batch_size=0, - need_fun_ret=False, - need_combine_ret=True, - editable=False, - tar_dtype=np.uint8 - ): - """ - update by cenweixuan on 2023/3/07 - help split the img and run the function piece by piece then combine the pieces into img - :param img:(ndarry) - :param win_shape:(tuple)pieces shape - :param overlap:(int) - :param need_fun_ret: fun's batch return - :param need_combine_ret: if need combine ret - :param editable:True to overwrite the img with dst - :param batch_size:>0 your fun must support to input a list - :param tar_dtype:output dtype - """ - self._img = img - self._win_shape = win_shape - self._overlap = overlap - self._editable = editable - self._batch_size = batch_size - self._tar_dtype = tar_dtype - - self._need_fun_ret = need_fun_ret - self._need_combine_ret = need_combine_ret - - self._box_lst = [] - self._dst = np.array([]) - self._fun_ret = [] - - self._y_nums = 0 - self._x_nums = 0 - - self._runfun = None - self._runfun_args = None - self._is_set_runfun = 0 - - self._prefun = None - self._prefun_args = None - self._is_set_prefun = 0 - - self._fusion = None - self._fusion_args = None - self._is_set_fusion_fun = 0 - - self._f_init() - - def _f_init(self): - if self._need_combine_ret: - if self._editable: - self._dst = self._img - else: - self._dst = np.zeros(self._img.shape, self._tar_dtype) - - def get_nums(self): - return self._x_nums, self._y_nums - - def f_set_run_fun(self, fun, *args): - self._runfun = fun - self._runfun_args = args - self._is_set_runfun = 1 - - def f_set_pre_fun(self, fun, *args): - self._prefun = fun - self._prefun_args = args - self._is_set_prefun = 1 - - def f_set_fusion_fun(self, fun, *args): - self._fusion = fun - self._fusion_args = args - self._is_set_fusion_fun = 1 - - def _f_split(self): - h, w = self._img.shape[:2] - win_h, win_w = self._win_shape[:2] - self._y_nums = ceil(h / (win_h - self._overlap)) - self._x_nums = ceil(w / (win_w - self._overlap)) - for y_temp in range(self._y_nums): - for x_temp in range(self._x_nums): - x_begin = int(max(0, x_temp * (win_w - self._overlap))) - y_begin = int(max(0, y_temp * (win_h - self._overlap))) - x_end = int(min(x_begin + win_w, w)) - y_end = int(min(y_begin + win_h, h)) - if y_begin >= y_end or x_begin >= x_end: - continue - self._box_lst.append([y_begin, y_end, x_begin, x_end]) - return - - def _f_get_batch_input(self, batch_box): - batch_input = [] - for box in batch_box: - y_begin, y_end, x_begin, x_end = box - img_win = self._img[y_begin: y_end, x_begin: x_end] - if self._is_set_prefun: - img_win = self._prefun(img_win, *self._prefun_args) - batch_input.append(img_win) - return batch_input - - def _f_set_img(self, box, img_win): - h, w = self._dst.shape[:2] - win_h, win_w = img_win.shape[:2] - win_y_begin, win_x_begin = 0, 0 - y_begin, y_end, x_begin, x_end = box - if self._overlap != 0: - if y_begin != 0: - y_begin = min(y_begin + self._overlap // 2, h - 1) - win_y_begin = min(win_y_begin + self._overlap // 2, win_h - 1) - if x_begin != 0: - x_begin = min(x_begin + self._overlap // 2, w - 1) - win_x_begin = min(win_x_begin + self._overlap // 2, win_w - 1) - if y_end != h: - y_end = y_end - self._overlap // 2 - if x_end != w: - x_end = x_end - self._overlap // 2 - if self._is_set_fusion_fun: - self._dst[y_begin: y_end, x_begin: x_end, ...] = self._fusion( - self._dst[y_begin: y_end, x_begin: x_end, ...], - img_win[win_y_begin: win_y_begin + y_end - y_begin, win_x_begin: win_x_begin + x_end - x_begin, ...], - *self._fusion_args) - else: - self._dst[y_begin: y_end, x_begin: x_end, ...] = \ - img_win[win_y_begin: win_y_begin + y_end - y_begin, win_x_begin: win_x_begin + x_end - x_begin, ...] - return - - def _f_run(self): - for i in tqdm(range(0, len(self._box_lst), self._batch_size)): - batch_box = self._box_lst[i:min(i + self._batch_size, len(self._box_lst))] - batch_input = self._f_get_batch_input(batch_box) - if self._batch_size > 1: - batch_output = self._runfun(batch_input, *self._runfun_args) - else: - batch_output = [self._runfun(batch_input[0], *self._runfun_args)] - - if self._need_fun_ret: - self._fun_ret.append(batch_output) - - if self._need_combine_ret: - for box, pred in zip(batch_box, batch_output): - self._f_set_img(box, pred) - return - - def f_split2run(self): - self._f_split() - if self._is_set_runfun and (self._runfun is not None): - self._f_run() - return self._box_lst, self._fun_ret, self._dst diff --git a/stereo/image/cellbin/modules/__init__.py b/stereo/image/cellbin/modules/__init__.py deleted file mode 100644 index 8352c984..00000000 --- a/stereo/image/cellbin/modules/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -import enum - - -class StainType(enum.Enum): - ssDNA = 'ssdna' - DAPI = 'dapi' - HE = 'HE' - mIF = 'mIF' - - -class CellBinElement(object): - def __init__(self): - self.schedule = None - self.task_name = '' - self.sub_task_name = '' diff --git a/stereo/image/cellbin/modules/cell_segmentation.py b/stereo/image/cellbin/modules/cell_segmentation.py deleted file mode 100644 index 09c443c8..00000000 --- a/stereo/image/cellbin/modules/cell_segmentation.py +++ /dev/null @@ -1,160 +0,0 @@ -import os - -import matplotlib.pyplot as plt -import numpy as np -from tifffile import tifffile - -from stereo import logger -from stereo.image.tissue_cut import ( - SingleStrandDNATissueCut, - DEEP, - INTENSITY -) -from . import CellBinElement -from ..dnn.cseg.cell_trace import get_trace as get_t -from ..dnn.cseg.detector import Segmentation - - -class CellSegmentation(CellBinElement): - def __init__(self, model_path, gpu="-1", num_threads=0): - super(CellSegmentation, self).__init__() - - self._MODE = "onnx" - self._NET = "bcdu" - self._WIN_SIZE = (256, 256) - self._INPUT_SIZE = (256, 256, 1) - self._OVERLAP = 16 - - self._gpu = gpu - self._model_path = model_path - self._num_threads = num_threads - - self._cell_seg = Segmentation( - net=self._NET, - mode=self._MODE, - gpu=self._gpu, - num_threads=self._num_threads, - win_size=self._WIN_SIZE, - intput_size=self._INPUT_SIZE, - overlap=self._OVERLAP - ) - self._cell_seg.f_init_model(model_path=self._model_path) - - def run(self, img): - mask = self._cell_seg.f_predict(img) - return mask - - @staticmethod - def get_trace(mask): - return get_t(mask) - - -def _get_tissue_mask(img_path, model_path, method, dst_img_path): - if method is None: - method = DEEP - if not model_path or len(model_path) == 0: - method = INTENSITY - ssDNA_tissue_cut = SingleStrandDNATissueCut( - src_img_path=img_path, - model_path=model_path, - dst_img_path=dst_img_path, - seg_method=method - ) - ssDNA_tissue_cut.tissue_seg() - return ssDNA_tissue_cut.mask[0] - - -def _get_img_filter(img, tissue_mask): - """get tissue image by tissue mask""" - img_filter = np.multiply(img, tissue_mask) - return img_filter - - -def cell_seg_v3( - model_path: str, - img_path: str, - out_path: str, - gpu="-1", - num_threads=0, - need_tissue_cut=True, - tissue_seg_model_path: str = None, - tissue_seg_method: str = None, - tissue_seg_dst_img_path=None, - -): - """ - Implement cell segmentation v3 by deep learning model. - - Parameters - ----------------- - model_path - the path to deep learning model. - img_path - the path to image file. - out_path - the path to output mask result. - gpu - set gpu id, if `'-1'`, use cpu for prediction. - num_threads - multi threads num of the model reading process - need_tissue_cut - whether cut image as tissue before cell segmentation - tissue_seg_model_path - the path of deep learning model of tissue segmentation, if set it to None, it would use OpenCV to process. - tissue_seg_method - the method of tissue segmentation, 1 is based on deep learning and 0 is based on OpenCV. - tissue_seg_dst_img_path - default to the img_path's directory. - Returns - ------------ - None - - """ - cell_bcdu = CellSegmentation( - model_path=model_path, - gpu=gpu, - num_threads=num_threads - ) - if img_path.split('.')[-1] == "tif": - img = tifffile.imread(img_path) - elif img_path.split('.')[-1] == "png": - img = plt.imread(img_path) - if img.dtype == np.float32: - img.astype('uint32') - img = transfer_32bit_to_8bit(img) - else: - raise Exception("cell seg only support tif and png") - - # img must be 16 bit ot 8 bit, and 16 bit image finally will be transferred to 8 bit - assert img.dtype == np.uint16 or img.dtype == np.uint8, f'{img.dtype} is not supported' - if img.dtype == np.uint16: - img = transfer_16bit_to_8bit(img) - if need_tissue_cut: - if tissue_seg_dst_img_path is None: - tissue_seg_dst_img_path = os.path.dirname(img_path) - tissue_mask = _get_tissue_mask(img_path, tissue_seg_model_path, tissue_seg_method, tissue_seg_dst_img_path) - img = _get_img_filter(img, tissue_mask) - mask = cell_bcdu.run(img) - CellSegmentation.get_trace(mask) - file_name = img_path.split('.')[0].split('/')[-1] - file_path = '/'.join([out_path, file_name + r'.cell_cut.tif']) - if not os.path.exists(out_path): - os.makedirs(out_path) - tifffile.imwrite(file_path, mask) - logger.info('seg results saved in %s' % file_path) - - -def transfer_16bit_to_8bit(image_16bit): - min_16bit = np.min(image_16bit) - max_16bit = np.max(image_16bit) - image_8bit = np.array(np.rint(255 * ((image_16bit - min_16bit) / (max_16bit - min_16bit))), dtype=np.uint8) - return image_8bit - - -def transfer_32bit_to_8bit(image_32bit): - min_32bit = np.min(image_32bit) - max_32bit = np.max(image_32bit) - image_8bit = np.array( - np.rint(255 * ((image_32bit - min_32bit) / (max_32bit - min_32bit))), dtype=np.uint8 - ) - return image_8bit diff --git a/stereo/image/segmentation/seg_utils/v1_pro/__init__.py b/stereo/image/segmentation/seg_utils/v1_pro/__init__.py deleted file mode 100644 index 53234636..00000000 --- a/stereo/image/segmentation/seg_utils/v1_pro/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from stereo.image.segmentation.seg_utils.v1_pro.cell_seg_pipeline_v1_pro import CellSegPipeV1Pro # noqa diff --git a/stereo/image/segmentation/seg_utils/v3/__init__.py b/stereo/image/segmentation/seg_utils/v3/__init__.py new file mode 100644 index 00000000..555abc29 --- /dev/null +++ b/stereo/image/segmentation/seg_utils/v3/__init__.py @@ -0,0 +1 @@ +from stereo.image.segmentation.seg_utils.v3.cell_seg_pipeline_v3 import CellSegPipeV3 # noqa diff --git a/stereo/image/segmentation/seg_utils/v1_pro/cell_seg_pipeline_v1_pro.py b/stereo/image/segmentation/seg_utils/v3/cell_seg_pipeline_v3.py similarity index 83% rename from stereo/image/segmentation/seg_utils/v1_pro/cell_seg_pipeline_v1_pro.py rename to stereo/image/segmentation/seg_utils/v3/cell_seg_pipeline_v3.py index de89dd35..b5f079c9 100644 --- a/stereo/image/segmentation/seg_utils/v1_pro/cell_seg_pipeline_v1_pro.py +++ b/stereo/image/segmentation/seg_utils/v3/cell_seg_pipeline_v3.py @@ -6,12 +6,13 @@ from stereo.image.segmentation.seg_utils.base_cell_seg_pipe.cell_seg_pipeline import CellSegPipe from stereo.log_manager import logger +from stereo.tools.tools import make_dirs -class CellSegPipeV1Pro(CellSegPipe): +class CellSegPipeV3(CellSegPipe): def run(self): - logger.info('Start do cell mask, this will take some minutes.') + logger.info('Start do cell mask, the method is v3, this will take some minutes.') cell_seg = CellSegmentation( model_path=self.model_path, gpu=self.kwargs.get('gpu', '-1'), @@ -29,6 +30,7 @@ def run(self): self.save_cell_mask() def save_cell_mask(self): + make_dirs(self.out_path) cell_mask_path = os.path.join(self.out_path, f"{self.file_name[-1]}_mask.tif") Image.write_s(self.mask, cell_mask_path, compression=True) logger.info('Result saved : %s ' % (cell_mask_path)) diff --git a/stereo/image/segmentation/segment.py b/stereo/image/segmentation/segment.py index 6e97995d..1191a106 100644 --- a/stereo/image/segmentation/segment.py +++ b/stereo/image/segmentation/segment.py @@ -5,8 +5,7 @@ from stereo.constant import VersionType from stereo.image.segmentation.seg_utils.v1 import CellSegPipeV1 -from stereo.image.segmentation.seg_utils.v1_pro import CellSegPipeV1Pro -from ..cellbin.modules.cell_segmentation import cell_seg_v3 +from stereo.image.segmentation.seg_utils.v3 import CellSegPipeV3 def cell_seg( @@ -23,7 +22,7 @@ def cell_seg( tissue_seg_dst_img_path=None, num_threads: int = 0, need_tissue_cut=True, - version: str = 'v1', + method: str = 'v1', ): """ Implement cell segmentation by deep learning model. @@ -56,22 +55,22 @@ def cell_seg( multi threads num of the model reading process need_tissue_cut whether cut image as tissue before cell segmentation - version - the version, version must be `v1` , `v1_pro`, `v3` + method + the method, method must be `v1` , `v1_pro`, `v3` Returns ------------ None """ - if version not in VersionType.get_version_list(): + if method not in VersionType.get_version_list(): raise Exception("version must be %s" % ('、'.join(VersionType.get_version_list()))) if not model_path: raise Exception("cell_seg() missing 1 required keyword argument: 'model_path'") os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu) - if version == VersionType.v1.value: + if method == VersionType.v1.value: cell_seg_pipeline = CellSegPipeV1( img_path, out_path, @@ -84,8 +83,8 @@ def cell_seg( model_path=model_path ) cell_seg_pipeline.run() - elif version == VersionType.v1_pro.value: - cell_seg_pipeline = CellSegPipeV1Pro( + elif method == VersionType.v3.value: + cell_seg_pipeline = CellSegPipeV3( img_path, out_path, is_water, @@ -94,15 +93,3 @@ def cell_seg( model_path=model_path ) cell_seg_pipeline.run() - else: - cell_seg_v3( - model_path, - img_path, - out_path, - gpu=gpu, - num_threads=num_threads, - need_tissue_cut=need_tissue_cut, - tissue_seg_model_path=tissue_seg_model_path, - tissue_seg_method=tissue_seg_method, - tissue_seg_dst_img_path=tissue_seg_dst_img_path, - )