Skip to content

Commit

Permalink
update for totalVI
Browse files Browse the repository at this point in the history
  • Loading branch information
tanliwei-coder committed Dec 27, 2024
1 parent e469e57 commit 30930ac
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions stereo/algorithm/total_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def main(
res_key: str = 'totalVI',
rna_use_raw: bool = False,
protein_use_raw: bool = False,
use_gpu: Union[int, str, bool] = None,
# use_gpu: Union[int, str, bool] = None,
accelerator: Union[str, None] = None,
devices: Union[str, list[int], int, None] = None,
num_threads: int = None,
train_kwargs: Optional[dict] = {},
**kwags
Expand Down Expand Up @@ -136,7 +138,8 @@ def main(

total_vi = scvi.model.TOTALVI(mdata, **kwags)
scvi.settings.dl_num_workers = num_threads
total_vi.train(use_gpu=use_gpu, **train_kwargs)
# total_vi.train(use_gpu=use_gpu, **train_kwargs)
total_vi.train(accelerator=accelerator, devices=devices, **train_kwargs)

if not self._use_hvg:
rna = rna_data
Expand Down Expand Up @@ -166,7 +169,8 @@ def save_result(
out_dir: str = None,
diff_exp_file_name: str = None,
h5mu_file_name: str = None,
fragment: int = 5
fragment: int = 5,
batch_size: int = 64
):
import os.path as opth
if out_dir is None or not opth.exists(out_dir):
Expand Down Expand Up @@ -222,7 +226,7 @@ def save_result(
gene_list = rna.var_names[rna_start:]
denoised_rna_, denoised_protein_ = \
self._total_vi_instance.get_normalized_expression(n_samples=25,
batch_size=64,
batch_size=batch_size,
gene_list=gene_list,
return_mean=True)
denoised_rna_list.append(denoised_rna_)
Expand Down Expand Up @@ -260,6 +264,7 @@ def filter_from_diff_exp(
public_thresholds: dict = None,
rna_thresholds: dict = None,
protein_thresholds: dict = None,
protein_list: list = None
):
if self._use_hvg:
rna_data = self._hvg_data
Expand All @@ -279,13 +284,16 @@ def filter_from_diff_exp(
for column, threshold in public_thresholds.items():
cell_type_df = cell_type_df[cell_type_df[column] > threshold]

pro_rows = cell_type_df.index.str.contains("_")
data_pro = cell_type_df.iloc[pro_rows]
# pro_rows = cell_type_df.index.str.contains("_")
# data_pro = cell_type_df.iloc[pro_rows]
pro_rows = np.intersect1d(cell_type_df.index, protein_list)
data_pro = cell_type_df[cell_type_df.index.isin(pro_rows)]
if protein_thresholds is not None:
for column, threshold in protein_thresholds.items():
data_pro = data_pro[data_pro[column] > threshold]

data_rna = cell_type_df.iloc[~pro_rows]
# data_rna = cell_type_df.iloc[~pro_rows]
data_rna = cell_type_df[~cell_type_df.index.isin(pro_rows)]
if rna_thresholds is not None:
for column, threshold in rna_thresholds.items():
data_rna = data_rna[data_rna[column] > threshold]
Expand Down

0 comments on commit 30930ac

Please sign in to comment.