
Commit

add SpaSEG algorithm
tanliwei-coder committed Aug 2, 2024
1 parent f4f9f13 commit 3e82793
Showing 13 changed files with 1,969 additions and 6 deletions.
3 changes: 2 additions & 1 deletion docs/source/Tutorials(Multi-sample)/Multi_sample.rst
@@ -15,4 +15,5 @@ Likewise, functions performed on one sample could also be used on multi-sample d
Time_Series_Analysis
3D_Cell_Cell_Communication
3D_Gene_Regulatory_Network
ST_Gears
ST_Gears
SpaSEG
2 changes: 1 addition & 1 deletion docs/source/Tutorials(Multi-sample)/ST_Gears.ipynb
@@ -17,7 +17,7 @@
" \n",
"3. `stack_slices_pairwise_elas_field` eliminates distorsions through Gaussian Smoothed Elastic Fields. Validity is proved mathematically.\n",
"\n",
"This function can be ran on GPU, if you want to use GPU, you need to create an environment according to the guide [Clustering_by_GPU](../Tutorials/Clustering_by_GPU.html).\n",
"This function can be run on GPU, if you want to use GPU, you need to create an environment according to the guide [Clustering_by_GPU](../Tutorials/Clustering_by_GPU.html).\n",
"\n",
"Before anlysing, you also need to install a necessary package: **torch**\n",
"\n",
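The tutorial cell above says the function can run on a GPU once **torch** is installed; a minimal device check (a generic PyTorch pattern, not code from this commit) looks like:

import torch

# fall back to CPU when no CUDA device is visible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"running on {device}")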
1,254 changes: 1,254 additions & 0 deletions docs/source/Tutorials(Multi-sample)/SpaSEG.ipynb

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions docs/source/content/07_References.rst
@@ -68,12 +68,16 @@ References
.. [Zhang23]
Chao Zhang, Qiang Kang, Mei Li, Hongqing Xie, Shuangsang Fang, Xun Xu,
*BatchEval Pipeline: Batch Effect Evaluation Workflow for Multiple Datasets Joint Analysis*, bioRxiv
*BatchEval Pipeline: Batch Effect Evaluation Workflow for Multiple Datasets Joint Analysis*, bioRxiv.
.. [Zheng17]
Grace X Y Zheng, Jessica M Terry, Phillip Belgrader, Paul Ryvkin, Zachary W Bent, Ryan Wilson, Solongo B Ziraldo, Tobias D Wheeler, Geoff P McDermott, Junjie Zhu, Mark T Gregory, Joe Shuga, Luz Montesclaros, Jason G Underwood, Donald A Masquelier, Stefanie Y Nishimura, Michael Schnall-Levin, Paul W Wyatt, Christopher M Hindson, Rajiv Bharadwaj, Alexander Wong, Kevin D Ness, Lan W Beppu, H Joachim Deeg, Christopher McFarland, Keith R Loeb, William J Valente, Nolan G Ericson, Emily A Stevens, Jerald P Radich, Tarjei S Mikkelsen, Benjamin J Hindson, Jason H Bielas,
*Massively parallel digital transcriptional profiling of single cells*, Nature Communications.
.. [Xia23]
Tianyi Xia, Luni Hu, Lulu Zuo, Yunjia Zhang, Mengyang Xu, Qin Lu, Lei Zhang, Lei Cao, Taotao Pan, Bohan Zhang, Bowen Ma, Chuan Chen, Junfu Guo, Chang Shi, Mei Li, Chao Liu, Yuxiang Li, Yong Zhang, Shuangsang Fang,
*ST-GEARS: Advancing 3D Downstream Research through Accurate Spatial Information Recovery*, bioRxiv
*ST-GEARS: Advancing 3D Downstream Research through Accurate Spatial Information Recovery*, bioRxiv.
.. [Bai23]
Yong Bai, Xiangyu Guo, Keyin Liu, Bingjie Zheng, Yingyue Wang, Qiuhong Luo, Jianhua Yin, Liang Wu, Yuxiang Li, Yong Zhang, Ao Chen, Xun Xu, Xin Jin,
*Efficient reliability analysis of spatially resolved transcriptomics at varying resolutions using SpaSEG*, bioRxiv.
3 changes: 2 additions & 1 deletion stereo/algorithm/ms_algorithm_base.py
@@ -1,11 +1,12 @@
from dataclasses import dataclass
from abc import ABCMeta

from stereo.algorithm.algorithm_base import AlgorithmBase, _camel_to_snake
from stereo.core.ms_data import MSData


@dataclass
class MSDataAlgorithmBase(AlgorithmBase):
class MSDataAlgorithmBase(metaclass=ABCMeta):
    ms_data: MSData = None
    pipeline_res: dict = None

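With this change `MSDataAlgorithmBase` becomes an abstract base in its own right instead of a subclass of `AlgorithmBase`. A purely hypothetical subclass, assuming the repo's usual `main` entry-point convention, might look like:

from dataclasses import dataclass

from stereo.algorithm.ms_algorithm_base import MSDataAlgorithmBase


@dataclass
class MyMSAlgorithm(MSDataAlgorithmBase):
    # `ms_data` and `pipeline_res` are inherited dataclass fields
    def main(self, n_clusters: int = 10):
        # treating `ms_data.data_list` as the per-sample container is an assumption
        for data in self.ms_data.data_list:
            ...
        return self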
1 change: 1 addition & 0 deletions stereo/algorithm/spa_seg/__init__.py
@@ -0,0 +1 @@
from .main import SpaSeg
13 changes: 13 additions & 0 deletions stereo/algorithm/spa_seg/_constant.py
@@ -0,0 +1,13 @@
from enum import Enum


class SpotSize(Enum):
    MERFISH_SPOT_SIZE = 20
    SEQFISH_SPOT_SIZE = 0.03
    SLIDESEQV2_SPOT_SIZE = 15
    VISIUM_SPOT_SIZE = 100
    STEREO_SPOT_SIZE = 1


class CellbinSize(Enum):
    # The average size of a mammalian cell is approximately a Stereo-seq bin14 x bin14 area
    CELLBIN_SIZE = 14
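These enum values feed the coordinate scaling in `_preprocessing.py` below. As a quick worked example of the cell-bin factor (the printed numbers are just the arithmetic):

from stereo.algorithm.spa_seg._constant import CellbinSize, SpotSize

# one average cell (~bin14 x bin14) is shrunk to span ~1 grid unit
scale = 1.0 / CellbinSize.CELLBIN_SIZE.value
print(round(scale, 4))                  # 0.0714
print(SpotSize.VISIUM_SPOT_SIZE.value)  # 100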
39 changes: 39 additions & 0 deletions stereo/algorithm/spa_seg/_preprocessing.py
@@ -0,0 +1,39 @@
import numpy as np
from anndata import AnnData

from ._constant import CellbinSize, SpotSize


def add_spot_pos(
    adata: AnnData,
    bin_type: str,
    spatial_key: str
):
    adata.obs["array_row"] = adata.obsm[spatial_key][:, 0]
    adata.obs["array_col"] = adata.obsm[spatial_key][:, 1]

    # shift coordinates so they are non-negative
    if np.min(adata.obsm[spatial_key]) < 0:
        adata.obs['array_col'] = (adata.obs['array_col'].values - adata.obs['array_col'].values.min())
        adata.obs['array_row'] = (adata.obs['array_row'].values - adata.obs['array_row'].values.min())

    # The scale factor follows the code in stLearn:
    # https://github.com/BiomedicalMachineLearning/stLearn/blob/master/stlearn/wrapper/read.py
    if bin_type == 'cell_bins':
        scale = 1.0 / CellbinSize.CELLBIN_SIZE.value
        # adata.uns["spot_size"] = SpotSize.STEREO_SPOT_SIZE.value
    elif bin_type == "bins":
        scale = 1
        # adata.uns["spot_size"] = SpotSize.STEREO_SPOT_SIZE.value
    else:
        raise ValueError("Invalid bin type, available options: 'cell_bins', 'bins'")

    adata.obs['array_col'] = adata.obs['array_col'] * scale
    adata.obs['array_row'] = adata.obs['array_row'] * scale
    adata.obsm['spatial_original'] = adata.obsm[spatial_key].copy()
    adata.obsm[spatial_key][:, 0] = adata.obs['array_row'].to_numpy(copy=True)
    adata.obsm[spatial_key][:, 1] = adata.obs['array_col'].to_numpy(copy=True)

    return adata
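A toy sketch of `add_spot_pos` behaviour; the coordinates below are made up:

import numpy as np
from anndata import AnnData

from stereo.algorithm.spa_seg._preprocessing import add_spot_pos

# three spots, one with a negative coordinate
adata = AnnData(
    X=np.ones((3, 5), dtype=np.float32),
    obsm={"spatial": np.array([[-2.0, 0.0], [0.0, 7.0], [14.0, 14.0]])},
)
adata = add_spot_pos(adata, bin_type="cell_bins", spatial_key="spatial")
# rows/cols are shifted to be non-negative, then scaled by 1/14;
# the untouched coordinates are kept in adata.obsm["spatial_original"]
print(adata.obs[["array_row", "array_col"]])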
160 changes: 160 additions & 0 deletions stereo/algorithm/spa_seg/_utils.py
@@ -0,0 +1,160 @@
import random
import os
import torch
# import scanpy as sc
import numpy as np
import anndata as ad

from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics import calinski_harabasz_score
from sklearn.metrics import silhouette_score
from sklearn.metrics import davies_bouldin_score

from stereo.algorithm.scale import scale

def seed_torch(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False
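Pinning cuDNN to deterministic mode and disabling benchmarking trades training speed for run-to-run reproducibility. A typical call site, with an arbitrary seed:

seed_torch(1029)  # call once, before any model or data loader is created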

def get_max_H_W(adata_list):
    H = 0
    W = 0
    for adata in adata_list:
        col, row = adata.obs['array_col'].values.astype(int), adata.obs['array_row'].values.astype(int)
        H = max(H, row.max() + 1)
        W = max(W, col.max() + 1)
    return H, W

def get_3d_expMatrix(adata, channel, H, W):
    # x_pca_scale = sc.pp.scale(adata.obsm['X_pca'], copy=True)
    x_pca_scale = scale(adata.obsm['X_pca'], zero_center=True, max_value=None)
    col = adata.obs['array_col'].values.astype(int)  # y-coordinate
    row = adata.obs['array_row'].values.astype(int)  # x-coordinate
    poss = zip(row, col)  # (x, y)
    # build and fill the 3D matrix of each spot with the corresponding PCA
    # mxt = np.zeros((channel, max(row) + 1, max(col) + 1))
    # H: x-coordinate, W: y-coordinate
    mxt = np.zeros((channel, H, W), dtype=x_pca_scale.dtype)

    for i, idx in enumerate(poss):  # (x, y)
        mxt[:, idx[0], idx[1]] = x_pca_scale[i, :]
    return mxt, col, row
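`get_3d_expMatrix` rasterises per-spot PCA vectors into a `(channel, H, W)` image so the slide can be processed like a picture. A self-contained illustration of the fill step, using fabricated data:

import numpy as np

# 4 spots with 2 PCA channels on a 2 x 2 grid
x_pca = np.arange(8, dtype=float).reshape(4, 2)
row = np.array([0, 0, 1, 1])
col = np.array([0, 1, 0, 1])

mxt = np.zeros((2, 2, 2))
for i, (r, c) in enumerate(zip(row, col)):
    mxt[:, r, c] = x_pca[i]  # drop each channel vector at its grid cell
print(mxt[:, 1, 0])  # -> [4. 5.], the PCA vector of the third spot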

def add_embedding(adata, H_embedding, W_embedding, embedding, opt):
    SpaSEG_embedding = embedding.reshape((H_embedding, W_embedding, opt.nChannel))
    col = adata.obs['array_col'].values.astype(int)
    row = adata.obs['array_row'].values.astype(int)
    shape = adata.obsm['X_pca'].shape

    poss = zip(col, row)
    SpaSEG_pca = np.zeros(shape)

    for i, idx in enumerate(poss):
        SpaSEG_pca[i, :] = SpaSEG_embedding[idx[1], idx[0], :]

    adata.obsm["SpaSEG_embedding"] = SpaSEG_pca

    return adata


def outlier(arrayMatrix):
    arraystd = np.std(arrayMatrix)
    arraymean = np.mean(arrayMatrix)
    arrayoutlier = np.where(np.abs(arrayMatrix - arraymean) > arraystd)  # or 2 * arraystd
    # arrayoutlier = np.transpose(np.where(np.abs(arrayMatrix - arraymean) > arraystd))  # or 2 * arraystd
    return arrayoutlier
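`outlier` flags entries more than one standard deviation away from the mean and returns their indices. A quick worked example with made-up distances:

import numpy as np

dists = np.array([1.0, 1.1, 0.2, 1.05])
mean, std = dists.mean(), dists.std()
# the same test `outlier` applies
print(np.where(np.abs(dists - mean) > std))  # -> (array([2]),)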


def merge_outlier(target, data, im, device):
    # ---------------- obtain cutoff ----------------
    labels = np.array(target.cpu())
    labels = labels.reshape(im.shape[0] * im.shape[1] * im.shape[2])
    u_labels = np.unique(labels)
    l_inds = []
    for i in range(len(u_labels)):
        l_inds.append(np.where(labels == u_labels[i])[0])

    # mean feature vector of every current label
    dataxx = data.permute(0, 2, 3, 1).contiguous().view(-1, data.shape[1])
    l_avg_a = []
    for i in range(len(u_labels)):
        xxx = dataxx[l_inds[i], :]

        l_avg = np.mean(xxx.cpu().detach().numpy(), axis=0)
        l_avg_a.append(l_avg.reshape(-1))

    # pairwise distances between label means
    dist = []
    dist_ind_i = []
    dist_ind_j = []
    for i in range(len(u_labels)):
        for j in range(len(u_labels)):
            if j > i:
                dist.append(np.linalg.norm(l_avg_a[i] - l_avg_a[j]))
                dist_ind_i.append(i)
                dist_ind_j.append(j)
    output_outlier = outlier(np.array(dist))

    # merge the two closest labels when their distance is an outlier;
    # `output_outlier[0]` is an index array, so test emptiness via `.size`
    idx = np.where(dist == np.min(dist))[0][0]
    if output_outlier[0].size > 0 and idx in output_outlier[0]:
        index_need_change = np.where(labels == u_labels[dist_ind_j[idx]])
        target[index_need_change] = torch.as_tensor(u_labels[dist_ind_i[idx]]).to(device)

    return target

def cal_metric(adata, pred_labels_column=None, true_labels_column=None, result_prefix='SpaSEG'):
    if true_labels_column:
        # adata.obs['ground_truth_code'] = adata.obs[ground_truth_index].cat.codes
        # ground_truth = adata.obs['ground_truth_code']
        true_labels = adata.obs[true_labels_column]
        pred_labels = adata.obs[pred_labels_column]
        # nmi
        NMI = normalized_mutual_info_score(np.array(true_labels), np.array(pred_labels))
        # ari
        ARI = adjusted_rand_score(np.array(true_labels), np.array(pred_labels))
        metric_dict = {"ARI": ARI, "NMI": NMI}
        adata.uns[f"{result_prefix}_metrics_1"] = metric_dict
    else:
        input_feature_X = adata.obsm['X_pca']
        pred_labels = adata.obs[pred_labels_column]
        CHS = calinski_harabasz_score(input_feature_X, pred_labels)
        SC = silhouette_score(input_feature_X, pred_labels)
        DBS = davies_bouldin_score(input_feature_X, pred_labels)
        metric_dict = {"CHS": CHS, "SC": SC, "DBS": DBS}
        adata.uns[f"{result_prefix}_metrics_2"] = metric_dict

    print(metric_dict)

    return adata
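A sketch of how `cal_metric` might be called; the AnnData below and its `SpaSEG_clusters` column are fabricated for illustration:

import numpy as np
from anndata import AnnData

rng = np.random.default_rng(0)
adata = AnnData(X=rng.random((20, 5)).astype(np.float32))
adata.obsm['X_pca'] = rng.random((20, 5))
adata.obs['SpaSEG_clusters'] = np.repeat(['0', '1'], 10)

# no ground truth given, so CHS/SC/DBS land in adata.uns["SpaSEG_metrics_2"];
# passing true_labels_column would store ARI/NMI in "SpaSEG_metrics_1" instead
adata = cal_metric(adata, pred_labels_column='SpaSEG_clusters')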

# def batch_umap_plot(adata_list, sample_id_list):
#     adata_map = {sample_id: adata for sample_id, adata in zip(sample_id_list, adata_list)}
#     adatas = ad.concat(adata_map, join="inner", index_unique="_", label="batch")
#     # adatas.obs['SpaSEG_batch_clusters'] = adatas.obs['SpaSEG_clusters'].astype('str')

#     # visualize UMAP before batch correction using embedding from PCA
#     sc.pp.neighbors(adatas, use_rep="X_pca", key_added="neighbor_X_pca")
#     sc.tl.umap(adatas, neighbors_key="neighbor_X_pca")
#     sc.pl.umap(adatas, color="batch", neighbors_key="neighbor_X_pca", title="Uncorrected",
#                save='SpaSEG_Uncorrected_batch.pdf', show=False, frameon=False)
#     sc.pl.umap(adatas, color="SpaSEG_clusters", neighbors_key="neighbor_X_pca", title="Uncorrected",
#                save='SpaSEG_Uncorrected_clusters.pdf', show=False, frameon=False)

#     # visualize UMAP after batch correction using embedding from SpaSEG
#     sc.pp.neighbors(adatas, use_rep="SpaSEG_embedding", key_added="neighbor_SpaSEG")
#     sc.tl.umap(adatas, neighbors_key="neighbor_SpaSEG")
#     sc.pl.umap(adatas, color="batch", neighbors_key="neighbor_SpaSEG", title="SpaSEG",
#                save='SpaSEG_corrected_batch.pdf', show=False, frameon=False)
#     sc.pl.umap(adatas, color="SpaSEG_clusters", neighbors_key="neighbor_SpaSEG", title="SpaSEG",
#                save='SpaSEG_corrected_clusters.pdf', show=False, frameon=False)