Skip to content

Commit

Permalink
【Fix】 fix cell_segment_v3
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhenbin24 committed Dec 1, 2023
1 parent fb7e7bd commit ee5b222
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,19 @@
import numpy as np
import tifffile

from stereo.image.tissue_cut import DEEP
from stereo.image.tissue_cut import (
DEEP,
INTENSITY,
SingleStrandDNATissueCut
)
from stereo.log_manager import logger


class CellSegPipe(object):

def __init__(
self,
model_path,
img_path,
out_path,
is_water,
Expand All @@ -29,7 +34,6 @@ def __init__(
tissue_seg_model_path='',
tissue_seg_method=DEEP,
post_processing_workers=10,
model_path=None,
*args,
**kwargs
):
Expand Down Expand Up @@ -58,10 +62,12 @@ def __init__(
logger.info('Transform 16bit to 8bit : %.2f' % (t1 - t0))
self.tissue_mask = []
self.tissue_mask_thumb = []
self.tissue_seg_model_path = tissue_seg_model_path
self.tissue_seg_method = tissue_seg_method
self.tissue_num = [] # tissue num in each image
self.tissue_bbox = [] # tissue roi bbox in each image
self.img_filter = [] # image filtered by tissue mask
self.get_tissue_mask(tissue_seg_model_path, tissue_seg_method)
self.get_tissue_mask()
self.get_roi()
self.cell_mask = []
self.post_mask_list = []
Expand Down Expand Up @@ -96,14 +102,44 @@ def convert_gray(self):
logger.info('Image %s convert to gray!' % self.file[idx])
self.img_list[idx] = img[:, :, 0]

@staticmethod
def transfer_32bit_to_8bit(image_32bit):
min_32bit = np.min(image_32bit)
max_32bit = np.max(image_32bit)
return np.array(np.rint(255 * ((image_32bit - min_32bit) / (max_32bit - min_32bit))), dtype=np.uint8)

@staticmethod
def transfer_16bit_to_8bit(image_16bit):
min_16bit = np.min(image_16bit)
max_16bit = np.max(image_16bit)
return np.array(np.rint(255 * ((image_16bit - min_16bit) / (max_16bit - min_16bit))), dtype=np.uint8)

def trans16to8(self):
pass
from stereo.log_manager import logger
for idx, img in enumerate(self.img_list):
assert img.dtype in ['uint16', 'uint8']
if img.dtype != 'uint8':
logger.info('%s transfer to 8bit' % self.file[idx])
self.img_list[idx] = self.transfer_16bit_to_8bit(img)

def save_each_file_result(self, file_name, idx):
pass

def get_tissue_mask(self, tissue_seg_model_path, tissue_seg_method):
pass
def get_tissue_mask(self):
tissue_seg_model_path = self.tissue_seg_model_path
tissue_seg_method = self.tissue_seg_method
if tissue_seg_method is None:
tissue_seg_method = DEEP
if not tissue_seg_model_path or len(tissue_seg_model_path) == 0:
tissue_seg_method = INTENSITY
ss_dna_tissue_cut = SingleStrandDNATissueCut(
src_img_path=self.img_path,
model_path=tissue_seg_model_path,
dst_img_path=self.out_path,
seg_method=tissue_seg_method
)
ss_dna_tissue_cut.tissue_seg()
self.tissue_mask = ss_dna_tissue_cut.mask

@staticmethod
def filter_roi(props):
Expand All @@ -120,10 +156,24 @@ def filter_roi(props):
def get_roi(self):
pass

def tissue_cell_infer(self):
pass

def tissue_label_filter(self, tissue_cell_label):
"""filter cell mask in tissue area"""
tissue_cell_label_filter = []
for idx, label in enumerate(tissue_cell_label):
tissue_bbox = self.tissue_bbox[idx]
label_filter_list = []
for i in range(self.tissue_num[idx]):
tissue_bbox_temp = tissue_bbox[i]
label_filter = np.multiply(
label[i],
self.tissue_mask[idx][tissue_bbox_temp[0]: tissue_bbox_temp[2],
tissue_bbox_temp[1]: tissue_bbox_temp[3]] # noqa
).astype(np.uint8)
label_filter_list.append(label_filter)
tissue_cell_label_filter.append(label_filter_list)
return tissue_cell_label_filter

def tissue_cell_infer(self):
pass

def mosaic(self, tissue_cell_label_filter):
Expand Down
45 changes: 0 additions & 45 deletions stereo/image/segmentation/seg_utils/v1/cell_seg_pipeline_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@
utils,
cell_infer
)
from stereo.image.tissue_cut import (
SingleStrandDNATissueCut,
DEEP,
INTENSITY
)


class CellSegPipeV1(CellSegPipe):
Expand Down Expand Up @@ -56,14 +51,6 @@ def save_cell_mask(self):
self.file_name[0] + '_' + str(shapes[0]) + '_' + str(shapes[1]) + '_' +
str(x_list[idx]) + '_' + str(y_list[idx]) + '.tif'), score_list[idx])

def trans16to8(self):
from stereo.log_manager import logger
for idx, img in enumerate(self.img_list):
assert img.dtype in ['uint16', 'uint8']
if img.dtype != 'uint8':
logger.info('%s transfer to 8bit' % self.file[idx])
self.img_list[idx] = utils.transfer_16bit_to_8bit(img)

def get_roi(self):
for idx, tissue_mask in enumerate(self.tissue_mask):
label_image = measure.label(tissue_mask, connectivity=2)
Expand All @@ -80,37 +67,6 @@ def get_roi(self):
self.tissue_num.append(len(filtered_props))
self.tissue_bbox.append([p['bbox'] for p in filtered_props])

def get_tissue_mask(self, tissue_seg_model_path, tissue_seg_method):
if tissue_seg_method is None:
tissue_seg_method = DEEP
if not tissue_seg_model_path or len(tissue_seg_model_path) == 0:
tissue_seg_method = INTENSITY
ssDNA_tissue_cut = SingleStrandDNATissueCut(
src_img_path=self.img_path,
model_path=tissue_seg_model_path,
dst_img_path=self.out_path,
seg_method=tissue_seg_method
)
ssDNA_tissue_cut.tissue_seg()
self.tissue_mask = ssDNA_tissue_cut.mask

def tissue_label_filter(self, tissue_cell_label):
"""filter cell mask in tissue area"""
tissue_cell_label_filter = []
for idx, label in enumerate(tissue_cell_label):
tissue_bbox = self.tissue_bbox[idx]
label_filter_list = []
for i in range(self.tissue_num[idx]):
tissue_bbox_temp = tissue_bbox[i]
label_filter = np.multiply(
label[i],
self.tissue_mask[idx][tissue_bbox_temp[0]: tissue_bbox_temp[2],
tissue_bbox_temp[1]: tissue_bbox_temp[3]] # noqa
).astype(np.uint8)
label_filter_list.append(label_filter)
tissue_cell_label_filter.append(label_filter_list)
return tissue_cell_label_filter

def tissue_cell_infer(self, q=None):
"""cell segmentation in tissue area by neural network"""
tissue_cell_label = []
Expand Down Expand Up @@ -146,7 +102,6 @@ def watershed_score(self, cell_mask):

def __get_img_filter(self):
"""get tissue image by tissue mask"""
# for idx, img in enumerate(self.img_list):
for img, tissue_mask in zip(self.img_list, self.tissue_mask):
img_filter = np.multiply(img, tissue_mask).astype(np.uint8)
self.img_filter.append(img_filter)
Expand Down
32 changes: 27 additions & 5 deletions stereo/image/segmentation/seg_utils/v3/cell_seg_pipeline_v3.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# import image
import os

import matplotlib.pyplot as plt
import numpy as np
from cellbin.image import Image
from cellbin.modules.cell_segmentation import CellSegmentation
from tifffile import tifffile

from stereo.image.segmentation.seg_utils.base_cell_seg_pipe.cell_seg_pipeline import CellSegPipe
from stereo.log_manager import logger
Expand All @@ -11,6 +14,11 @@

class CellSegPipeV3(CellSegPipe):

def _get_img_filter(self, img, tissue_mask):
"""get tissue image by tissue mask"""
img_filter = np.multiply(img, tissue_mask)
return img_filter

def run(self):
logger.info('Start do cell mask, the method is v3, this will take some minutes.')
cell_seg = CellSegmentation(
Expand All @@ -19,14 +27,28 @@ def run(self):
num_threads=self.kwargs.get('num_threads', 0),
)
logger.info(f"Load {self.model_path}) finished.")

image = Image()
image.read(image=self.img_path)
if self.img_path.split('.')[-1] == "tif":
img = tifffile.imread(self.img_path)
elif self.img_path.split('.')[-1] == "png":
img = plt.imread(self.img_path)
if img.dtype == np.float32:
img.astype('uint32')
img = self.transfer_32bit_to_8bit(img)
else:
raise Exception("cell seg only support tif and png")

# img must be 16 bit ot 8 bit, and 16 bit image finally will be transferred to 8 bit
assert img.dtype == np.uint16 or img.dtype == np.uint8, f'{img.dtype} is not supported'
if img.dtype == np.uint16:
img = self.transfer_16bit_to_8bit(img)

if self.kwargs.get('need_tissue_cut', None):
self.get_tissue_mask()
img = self._get_img_filter(img, self.tissue_mask)

# Run cell segmentation
mask = cell_seg.run(image.image)
mask = cell_seg.run(img)
self.mask = mask

self.save_cell_mask()

def save_cell_mask(self):
Expand Down
21 changes: 11 additions & 10 deletions stereo/image/segmentation/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,16 @@


def cell_seg(
model_path: str,
img_path: str,
out_path: str,
model_path: str = None,
deep_crop_size: int = 20000,
overlap: int = 100,
gpu: str = '-1',
tissue_seg_model_path: str = None,
tissue_seg_method: str = None,
post_processing_workers: int = 10,
is_water: bool = False,
tissue_seg_dst_img_path=None,
num_threads: int = 0,
need_tissue_cut=True,
method: str = 'v1',
Expand Down Expand Up @@ -49,8 +48,6 @@ def cell_seg(
the number of processes for post-processing.
is_water:
The file name used to generate the mask. If true, the name ends with _watershed.
tissue_seg_dst_img_path
default to the img_path's directory.
num_threads
multi threads num of the model reading process
need_tissue_cut
Expand All @@ -66,30 +63,34 @@ def cell_seg(
if method not in VersionType.get_version_list():
raise Exception("version must be %s" % ('、'.join(VersionType.get_version_list())))

if not model_path:
raise Exception("cell_seg() missing 1 required keyword argument: 'model_path'")

os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
if method == VersionType.v1.value:
cell_seg_pipeline = CellSegPipeV1(
model_path,
img_path,
out_path,
is_water,
deep_crop_size,
overlap,
gpu=gpu,
need_tissue_cut=need_tissue_cut,
tissue_seg_model_path=tissue_seg_model_path,
tissue_seg_method=tissue_seg_method,
post_processing_workers=post_processing_workers,
model_path=model_path
)
cell_seg_pipeline.run()
elif method == VersionType.v3.value:
cell_seg_pipeline = CellSegPipeV3(
model_path,
img_path,
out_path,
is_water,
deep_crop_size,
overlap,
gpu=gpu,
num_threads=num_threads,
model_path=model_path
need_tissue_cut=need_tissue_cut,
tissue_seg_model_path=tissue_seg_model_path,
tissue_seg_method=tissue_seg_method,
post_processing_workers=post_processing_workers,
)
cell_seg_pipeline.run()

0 comments on commit ee5b222

Please sign in to comment.