Skip to content

Commit

Permalink
do some updates.
Browse files Browse the repository at this point in the history
  • Loading branch information
tanliwei-coder committed May 30, 2024
1 parent 481be61 commit 3ab1069
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 8 deletions.
4 changes: 4 additions & 0 deletions stereo/algorithm/paste/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ def pairwise_align(
if use_gpu:
try:
import torch
backend = ot.backend.TorchBackend()
except Exception:
logger.warning("We currently only have gpu support for Pytorch. Please install torch.")
backend = ot.backend.NumpyBackend()

if isinstance(backend, ot.backend.TorchBackend):
if torch.cuda.is_available():
Expand Down Expand Up @@ -225,8 +227,10 @@ def center_align(
if use_gpu:
try:
import torch
backend = ot.backend.TorchBackend()
except Exception:
logger.warning("We currently only have gpu support for Pytorch. Please install torch.")
backend = ot.backend.NumpyBackend()

if isinstance(backend, ot.backend.TorchBackend):
if torch.cuda.is_available():
Expand Down
11 changes: 10 additions & 1 deletion stereo/core/ms_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,15 @@ def remove_scopes_data(self, scope):


def integrate(self, scope=None, remove_existed=False, **kwargs):
"""
Integrate some single-samples specified by `scope` to a merged one.
:param scope: Which scope of samples to be integrated, defaults to None.
Each integrate sample is saved in memory, performing this function
by passing duplicate `scope` will return the saved one.
:param remove_existed: Whether to remove the saved integrate sample when passing a duplicate `scope`, defaults to False.
"""
from stereo.utils.data_helper import merge
if self._var_type not in {"union", "intersect"}:
raise Exception("Please specify the operation on samples with the parameter '_var_type'")
Expand Down Expand Up @@ -816,7 +825,7 @@ def to_integrate(
assert isinstance(item, str) or len(item) == len(self[_from]._names), "`item`'s length not equal to _from"
scope_names = self[scope]._names
scope_key = self.generate_scope_key(scope_names)
assert scope_key in self._scopes_data or self._merged_data, f"`to_integrate` need running function `integrate`"
assert scope_key in self._scopes_data or self._merged_data, f"`to_integrate` need running function `integrate` first"
if type == 'obs':
if scope_key in self._scopes_data:
self._scopes_data[scope_key].cells[res_key] = fill
Expand Down
30 changes: 30 additions & 0 deletions stereo/core/ms_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,22 @@ def result_keys(self, result_keys):
self._result_keys = result_keys
self._reset_result_keys()

@property
def mode(self):
return self.__mode

@mode.setter
def mode(self, mode):
self.__mode = mode

@property
def scope(self):
return self.__scope

@scope.setter
def scope(self, scope):
self.__scope = scope

def _reset_result_keys(self):
for scope_key, result_keys in self._result_keys.items():
self._result_keys[scope_key] = []
Expand Down Expand Up @@ -281,9 +297,23 @@ def set_scope_and_mode(
scope: slice = slice(None),
mode: str = "integrate"
):
"""
Set the `scope` and `mode` globally for Multi-slice analysis.
:param scope: the scope, defaults to slice(None)
:param mode: the mode, defaults to "integrate"
"""
assert mode in ("integrate", "isolated"), 'mode should be one of [`integrate`, `isolated`]'
self.__mode = mode
self.__scope = scope
if self.__class__.ATTR_NAME == 'tl':
self.ms_data.plt.scope = scope
self.ms_data.plt.mode = mode
elif self.__class__.ATTR_NAME == 'plt':
self.ms_data.tl.scope = scope
self.ms_data.tl.mode = mode
else:
pass


slice_generator = _scope_slice()
6 changes: 3 additions & 3 deletions stereo/core/st_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,10 +1009,10 @@ def phenograph(self,
@logit
def find_marker_genes(self,
cluster_res_key,
method: Literal['t_test', 'wilcoxon_test'] = 't_test',
method: Literal['t_test', 'wilcoxon_test', 'logreg'] = 't_test',
case_groups: Union[str, np.ndarray, list] = 'all',
control_groups: Union[str, np.ndarray, list] = 'rest',
corr_method: str = 'benjamini-hochberg',
corr_method: Literal['bonferroni', 'benjamini-hochberg'] = 'benjamini-hochberg',
use_raw: bool = True,
use_highly_genes: bool = True,
hvg_res_key: Optional[str] = 'highly_variable_genes',
Expand All @@ -1031,7 +1031,7 @@ def find_marker_genes(self,
:param method: choose method for statistics.
:param case_groups: case group, default all clusters.
:param control_groups: control group, default the rest of groups.
:param corr_method: correlation method.
:param corr_method: p-value correction method, only available for `t_test` and `wilcoxon_test`.
:param use_raw: whether to use raw express matrix for analysis, default True.
:param use_highly_genes: whether to use only the expression of hypervariable genes as input, default True.
:param hvg_res_key: the key of highly variable genes to get corresponding result.
Expand Down
4 changes: 2 additions & 2 deletions stereo/plots/plot_st_gears.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ class PlotStGears(MSDataPlotBase):
def __init__(self, ms_data, pipeline_res=None):
super().__init__(ms_data, pipeline_res)
from stereo.algorithm.st_gears.visual import RegisPlotter
self.__plotter = RegisPlotter(num_cols=3, dpi_val=1000)
self.__plotter = RegisPlotter(num_cols=3, dpi_val=100)
self.__scatter_size_factor = 11000


def plot_scatter_by_grid(
def scatter_for_st_gears(
self,
ctype: Literal['cell_label', 'num_anchors', 'pi_max', 'strong_anch', 'weight', 'sum_gene'] = 'cell_label',
lay_type: Literal['to_pre', 'to_next'] = 'to_pre',
Expand Down
7 changes: 5 additions & 2 deletions stereo/preprocess/sc_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from stereo.algorithm.sctransform import SCTransform
from stereo.core.stereo_exp_data import StereoExpData
from stereo.preprocess.filter import filter_genes


def sc_transform(
Expand Down Expand Up @@ -35,12 +36,14 @@ def sc_transform(
)
new_exp_matrix = res[0][exp_matrix_key]
if issparse(new_exp_matrix):
data.sub_by_index(gene_index=res[1]['umi_genes'])
# data.sub_by_index(gene_index=res[1]['umi_genes'])
filter_genes(data, gene_list=res[1]['umi_genes'], inplace=True)
data.exp_matrix = new_exp_matrix.T.tocsr()
# gene_index = np.isin(data.gene_names, res[1]['umi_genes'])
# data.genes = data.genes.sub_set(gene_index)
else:
data.sub_by_index(gene_index=new_exp_matrix.index.to_numpy())
# data.sub_by_index(gene_index=new_exp_matrix.index.to_numpy())
filter_genes(data, gene_list=new_exp_matrix.index.to_numpy(), inplace=True)
data.exp_matrix = new_exp_matrix.T.to_numpy()
# gene_index = np.isin(data.gene_names, new_exp_matrix.index.values)
# data.genes = data.genes.sub_set(gene_index)
Expand Down

0 comments on commit 3ab1069

Please sign in to comment.