diff --git a/stereo/algorithm/cell_pose/cell_pose.py b/stereo/algorithm/cell_pose/cell_pose.py index 67ba9624..dbb86c6a 100644 --- a/stereo/algorithm/cell_pose/cell_pose.py +++ b/stereo/algorithm/cell_pose/cell_pose.py @@ -21,15 +21,18 @@ class CellPose: - def __init__(self, - img_path: str, - out_path: str, - photo_size: Optional[int] = 2048, - photo_step: Optional[int] = 2000, - model_type: Optional[str] = 'cyto2', - dmin: Optional[int] = 10, - dmax: Optional[int] = 40, - step: Optional[int] = 10): + def __init__( + self, + img_path: str, + out_path: str, + photo_size: Optional[int] = 2048, + photo_step: Optional[int] = 2000, + model_type: Optional[str] = 'cyto2', + dmin: Optional[int] = 10, + dmax: Optional[int] = 40, + step: Optional[int] = 10, + gpu: Optional[bool] = False + ): """ :param img_path: input file path. @@ -44,6 +47,7 @@ def __init__(self, :param dmin: cell minimum diameter, default is 10. :param dmax: cell diameter, default is 40. :param step: the step size of cell diameter search, default is 10. + :param gpu: Whether to use gpu acceleration, the default is False. """ self.img_path = img_path self.out_path = out_path @@ -51,6 +55,7 @@ def __init__(self, self.photo_step = photo_step self.dmin = dmin self.dmax = dmax + self.gpu = gpu self.step = step self.model_type = model_type self.segment_cells() @@ -75,7 +80,7 @@ def _process_image(self): patches = patchify.patchify(regray_image, (self.photo_size, self.photo_size), step=self.photo_step) wid = patches.shape[0] high = patches.shape[1] - model = models.Cellpose(gpu=True, model_type=self.model_type) + model = models.Cellpose(gpu=self.gpu, model_type=self.model_type) a_patches = np.full((wid, high, (self.photo_step), (self.photo_step)), 255) for i in range(wid): for j in range(high):