Skip to content

Commit

Permalink
Merge pull request #285 from STOmics/multiAnalysis
Browse files Browse the repository at this point in the history
New demands for combined transcription and protein analysis
  • Loading branch information
tanliwei-genomics-cn authored May 27, 2024
2 parents 0583899 + f4b3962 commit d2da126
Showing 1 changed file with 40 additions and 3 deletions.
43 changes: 40 additions & 3 deletions stereo/algorithm/total_vi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import gc
from copy import deepcopy
from typing import Optional
from typing import Union

import numpy as np

import anndata
import pandas as pd

Expand Down Expand Up @@ -45,6 +48,7 @@ def main(
rna_use_raw: bool = False,
protein_use_raw: bool = False,
use_gpu: Union[int, str, bool] = None,
core: int = None,
train_kwargs: Optional[dict] = {},
**kwags
):
Expand Down Expand Up @@ -131,6 +135,7 @@ def main(
})

total_vi = scvi.model.TOTALVI(mdata, **kwags)
scvi.settings.num_threads = core
total_vi.train(use_gpu=use_gpu, **train_kwargs)

if not self._use_hvg:
Expand Down Expand Up @@ -160,7 +165,8 @@ def save_result(
use_cluster_res_key: str = None,
out_dir: str = None,
diff_exp_file_name: str = None,
h5mu_file_name: str = None
h5mu_file_name: str = None,
fragment: int = 5
):
import os.path as opth
if out_dir is None or not opth.exists(out_dir):
Expand All @@ -184,6 +190,8 @@ def save_result(
split_batches=False)
protein_adata.var['protein_names'] = protein_adata.var_names
if self._use_hvg:
mdata.mod['multiomics'].uns['omics'] = [['Transcriptomics'], ['Proteomics']]
mdata.mod['multiomics'].uns['leiden_resolution'] = 1
hvg_adata: anndata.AnnData = stereo_to_anndata(self._hvg_data, base_adata=mdata.mod['multiomics'],
split_batches=False)
mdata.update()
Expand All @@ -198,8 +206,37 @@ def save_result(
diff_exp_file_name = f'{self._rna_data.sn}_{self._rna_data.bin_size}_differential_expression.csv'
de_df.to_csv(f'{out_dir}/{diff_exp_file_name}')

denoised_rna, denoised_protein = self._total_vi_instance.get_normalized_expression(n_samples=25,
return_mean=True)
# Divide the 4000 gene list into 5 slices, After sharding,
# the output arrays need to be merged and the shapes need to be consistent.
frequency = len(rna.var_names) // fragment
assert (frequency != 0), 'The number of slices of genes is wrong, causing the array to go out of bounds'
denoised_rna_list = []
denoised_protein_list = []
rna_end = rna_start = 0
for i in range(fragment):
if i < fragment - 1:
rna_end += frequency
gene_list = rna.var_names[rna_start:rna_end]
rna_start = rna_end
else:
gene_list = rna.var_names[rna_start:]
denoised_rna_, denoised_protein_ = \
self._total_vi_instance.get_normalized_expression(n_samples=25,
batch_size=64,
gene_list=gene_list,
return_mean=True)
denoised_rna_list.append(denoised_rna_)
denoised_protein_list.append(denoised_protein_)
del denoised_rna_, denoised_protein_
gc.collect()

# Merging results after sharded model inference
denoised_rna = pd.concat(denoised_rna_list, axis=1)
denoised_protein = pd.concat(denoised_protein_list, axis=0)
# he protein uses the average of five results. Can the median be used
denoised_protein = denoised_protein.groupby(denoised_protein.index).sum()
denoised_protein = denoised_protein.div(fragment, fill_value=np.NaN)

denoised_protein = denoised_protein.loc[rna_adata.obs_names]

protein_foreground_prob = self._total_vi_instance.get_protein_foreground_probability(
Expand Down

0 comments on commit d2da126

Please sign in to comment.