Skip to content

Commit

Permalink
update some plot functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
tanliwei-coder committed Aug 22, 2024
1 parent 2dcaff1 commit d7d82df
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 68 deletions.
1 change: 0 additions & 1 deletion stereo/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,6 @@ def read_gef(
}
logger.info(f'the matrix has {data.cell_names.size} cells, and {data.gene_names.size} genes.')
gef.cgef_close()
del gef
return data
else:
if is_cell_bin:
Expand Down
72 changes: 47 additions & 25 deletions stereo/plots/plot_clusters_heatmap.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Optional
from typing import Sequence
from typing import Optional, Sequence, Literal

import matplotlib.pylab as plt
import numpy as np
Expand All @@ -23,7 +22,9 @@ def clusters_genes_heatmap(
self,
cluster_res_key: str,
dendrogram_res_key: Optional[str] = None,
topn: Optional[int] = 5,
gene_names: Optional[Sequence[str]] = None,
expression_kind: Literal['mean', 'sum'] = 'mean',
groups: Optional[Sequence[str]] = None,
width: int = None,
height: int = None,
Expand All @@ -36,7 +37,10 @@ def clusters_genes_heatmap(
:param cluster_res_key: the key to get cluster result.
:param dendrogram_res_key: the key to get dendrogram result, defaults to None to avoid show dendrogram on plot.
:param topn: select `topn` expressed genes in each cluster, defaults to 5, ignored if `gene_names` is not None,
the number of genes shown in plot may be more than `topn` because the `topn` genes in each cluster are not the same.
:param gene_names: a list of genes to show, defaults to None to show all genes.
:param expression_kind: the kind of expression to show, 'mean' or 'sum', defaults to 'mean'.
:param groups: a list of cell clusters to show, defaults to None to show all cell clusters.
:param width: the figure width in pixels, defaults to None.
:param height: the figure height in pixels, defaults to None.
Expand All @@ -60,16 +64,20 @@ def clusters_genes_heatmap(
f'The cluster result used in dendrogram may not be the same as that specified '
f'by key {cluster_res_key}')

if gene_names is None:
gene_names = self.stereo_exp_data.gene_names
else:
if isinstance(gene_names, str):
gene_names = np.array([gene_names], dtype='U')
elif not isinstance(gene_names, np.ndarray):
gene_names = np.array(gene_names, dtype='U')
if gene_names is not None:
topn = None

if topn is None:
if gene_names is None:
gene_names = self.stereo_exp_data.gene_names
else:
if isinstance(gene_names, str):
gene_names = np.array([gene_names], dtype='U')
elif not isinstance(gene_names, np.ndarray):
gene_names = np.array(gene_names, dtype='U')

if len(gene_names) == 0:
return None
if len(gene_names) == 0:
return None

if groups is None or drg_res is not None:
cluster_res: pd.DataFrame = self.pipeline_res[cluster_res_key]
Expand All @@ -86,21 +94,35 @@ def clusters_genes_heatmap(
elif not isinstance(group_codes, np.ndarray):
group_codes = np.array(group_codes, dtype='U')

mean_expression: pd.DataFrame = cell_cluster_to_gene_exp_cluster(
self.stereo_exp_data,
cluster_res_key,
groups=groups,
genes=gene_names,
kind='mean'
)
mean_expression = mean_expression[group_codes]
if topn is None:
genes_expression: pd.DataFrame = cell_cluster_to_gene_exp_cluster(
self.stereo_exp_data,
cluster_res_key,
groups=groups,
genes=gene_names,
kind='mean'
)
else:
genes_expression = cell_cluster_to_gene_exp_cluster(
self.stereo_exp_data,
cluster_res_key,
groups=groups,
kind=expression_kind
)
gene_names = []
for c in genes_expression.columns:
gene_names.extend(genes_expression[c].sort_values(ascending=False).index[:topn].tolist())
gene_names = np.unique(gene_names)
genes_expression = genes_expression.loc[gene_names]

genes_expression = genes_expression[group_codes]

if standard_scale == 'cluster':
mean_expression -= mean_expression.min(0)
mean_expression = (mean_expression / mean_expression.max(0)).fillna(0)
genes_expression -= genes_expression.min(0)
genes_expression = (genes_expression / genes_expression.max(0)).fillna(0)
elif standard_scale == 'gene':
mean_expression = mean_expression.sub(mean_expression.min(1), axis=0)
mean_expression = mean_expression.div(mean_expression.max(1), axis=0).fillna(0)
genes_expression = genes_expression.sub(genes_expression.min(1), axis=0)
genes_expression = genes_expression.div(genes_expression.max(1), axis=0).fillna(0)
elif standard_scale is None:
pass
else:
Expand Down Expand Up @@ -160,14 +182,14 @@ def clusters_genes_heatmap(
ax_heatmap = fig.add_subplot(axs_main[1, 0])
ax_colorbar = fig.add_subplot(axs_on_right[1, 0])
heatmap(
mean_expression,
genes_expression,
ax=ax_heatmap,
plot_colorbar=True,
colorbar_ax=ax_colorbar,
cmap=colormap,
colorbar_orientation='horizontal',
colorbar_ticklocation='bottom',
colorbar_title='Mean expression in group',
colorbar_title=f'{expression_kind.capitalize()} expression in group',
show_xaxis=True,
show_yaxis=True,
)
Expand Down
94 changes: 58 additions & 36 deletions stereo/plots/plot_clusters_scatter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Optional
from typing import Sequence
from typing import Optional, Sequence, Literal

import matplotlib.pylab as plt
import numpy as np
Expand All @@ -24,7 +23,9 @@ def clusters_genes_scatter(
self,
cluster_res_key: str,
dendrogram_res_key: Optional[str] = None,
topn: Optional[int] = 5,
gene_names: Optional[Sequence[str]] = None,
expression_kind: Literal['mean', 'sum'] = 'mean',
groups: Optional[Sequence[str]] = None,
width: int = None,
height: int = None,
Expand All @@ -36,7 +37,10 @@ def clusters_genes_scatter(
:param cluster_res_key: the key to get cluster result.
:param dendrogram_res_key: the key to get dendrogram result, defaults to None to avoid show dendrogram on plot.
:param topn: select `topn` expressed genes in each cluster, defaults to 5, ignored if `gene_names` is not None,
the number of genes shown in plot may be more than `topn` because the `topn` genes in each cluster are not the same.
:param gene_names: a list of genes to show, defaults to None to show all genes.
:param expression_kind: the kind of expression to show, 'mean' or 'sum', defaults to 'mean'.
:param groups: a list of cell clusters to show, defaults to None to show all cell clusters.
:param width: the figure width in pixels, defaults to None
:param height: the figure height in pixels, defaults to None
Expand All @@ -59,16 +63,20 @@ def clusters_genes_scatter(
raise KeyError(f'The cluster result used in dendrogram may not be the same as that '
f'specified by key {cluster_res_key}')

if gene_names is None:
gene_names = self.stereo_exp_data.gene_names
else:
if isinstance(gene_names, str):
gene_names = np.array([gene_names], dtype='U')
elif not isinstance(gene_names, np.ndarray):
gene_names = np.array(gene_names, dtype='U')
if gene_names is not None:
topn = None

if topn is None:
if gene_names is None:
gene_names = self.stereo_exp_data.gene_names
else:
if isinstance(gene_names, str):
gene_names = np.array([gene_names], dtype='U')
elif not isinstance(gene_names, np.ndarray):
gene_names = np.array(gene_names, dtype='U')

if len(gene_names) == 0:
return None
if len(gene_names) == 0:
return None

if groups is None or drg_res is not None:
cluster_res: pd.DataFrame = self.pipeline_res[cluster_res_key]
Expand All @@ -85,6 +93,38 @@ def clusters_genes_scatter(
elif not isinstance(group_codes, np.ndarray):
group_codes = np.array(group_codes, dtype='U')

if topn is None:
genes_expression = cell_cluster_to_gene_exp_cluster(
self.stereo_exp_data,
cluster_res_key,
groups=groups,
genes=gene_names,
kind=expression_kind
)
else:
genes_expression = cell_cluster_to_gene_exp_cluster(
self.stereo_exp_data,
cluster_res_key,
groups=groups,
kind=expression_kind
)
gene_names = []
for c in genes_expression.columns:
gene_names.extend(genes_expression[c].sort_values(ascending=False).index[:topn].tolist())
gene_names = np.unique(gene_names)
genes_expression = genes_expression.loc[gene_names]

if standard_scale == 'cluster':
genes_expression -= genes_expression.min(0)
genes_expression = (genes_expression / genes_expression.max(0)).fillna(0)
elif standard_scale == 'gene':
genes_expression = genes_expression.sub(genes_expression.min(1), axis=0)
genes_expression = genes_expression.div(genes_expression.max(1), axis=0).fillna(0)
elif standard_scale is None:
pass
else:
logger.warning('Unknown type for standard_scale, ignored')

pct, _ = calc_pct_and_pct_rest(
self.stereo_exp_data,
cluster_res_key,
Expand All @@ -94,28 +134,9 @@ def clusters_genes_scatter(
if 'genes' in pct.columns:
pct.set_index('genes', inplace=True)

mean_expression = cell_cluster_to_gene_exp_cluster(
self.stereo_exp_data,
cluster_res_key,
groups=groups,
genes=gene_names,
kind='mean'
)

if standard_scale == 'cluster':
mean_expression -= mean_expression.min(0)
mean_expression = (mean_expression / mean_expression.max(0)).fillna(0)
elif standard_scale == 'gene':
mean_expression = mean_expression.sub(mean_expression.min(1), axis=0)
mean_expression = mean_expression.div(mean_expression.max(1), axis=0).fillna(0)
elif standard_scale is None:
pass
else:
logger.warning('Unknown type for standard_scale, ignored')

dot_plot_data = self._create_dot_plot_data(
pct,
mean_expression,
genes_expression,
group_codes,
gene_names
)
Expand Down Expand Up @@ -187,7 +208,7 @@ def clusters_genes_scatter(
)

ax_colorbar = fig.add_subplot(axs_on_right[1, 0])
self._plot_colorbar(ax_colorbar, main_im)
self._plot_colorbar(ax_colorbar, main_im, expression_kind)

ax_dot_size_map = fig.add_subplot(axs_on_right[3, 0])
self._plot_dot_size_map(ax_dot_size_map)
Expand Down Expand Up @@ -218,23 +239,24 @@ def _dotplot(
def _plot_colorbar(
self,
ax: Axes,
im
im,
expression_kind: str
):
ax.set_title('Mean expression in group', fontdict={'fontsize': self.__title_font_size})
ax.set_title(f'{expression_kind.capitalize()} expression in group', fontdict={'fontsize': self.__title_font_size})
plt.colorbar(im, cax=ax, orientation='horizontal', ticklocation='bottom')

def _create_dot_plot_data(
self,
pct: pd.DataFrame,
mean_expression: pd.DataFrame,
genes_expression: pd.DataFrame,
group_codes: Sequence[str],
gene_names: Sequence[str]
):
x = [i for i in range(len(group_codes))]
data_list = []
for y, g in enumerate(gene_names):
dot_size = pct.loc[g][group_codes] * 100
dot_color = mean_expression.loc[g][group_codes]
dot_color = genes_expression.loc[g][group_codes]
df = pd.DataFrame({
'x': x,
'y': y,
Expand Down
7 changes: 6 additions & 1 deletion stereo/plots/plot_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def cluster_scatter(
:param show_others: whether to show others when groups is not None.
by default, if `base_image` is None, `show_others` is True, otherwise `show_others` is False.
:param others_color: the color of others, only available when `groups` is not None and `show_others` is True.
:param title: the plot title.
:param title: the plot title, defaults to None to be set as `res_key`, set it to False to disable the title.
:param x_label: the x label.
:param y_label: the y label.
:param dot_size: the dot size.
Expand Down Expand Up @@ -972,6 +972,11 @@ def cluster_scatter(
marker = kwargs['marker']
del kwargs['marker']

if title is None:
title = res_key
elif title is False:
title = None

fig = base_scatter(
x, y,
hue=group_list,
Expand Down
11 changes: 7 additions & 4 deletions stereo/plots/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,16 @@ def base_scatter(
show_ticks: bool = False,
vmin: float = None,
vmax: float = None,
hue_order: any = None,
hue_order: Union[list, np.ndarray] = None,
width: float = None,
height: float = None,
boundary: list = None,
show_plotting_scale: bool = False,
plotting_scale_width: float = 2000,
data_resolution: int = None,
data_bin_offset: int = 1,
foreground_alpha: float = 0.5,
base_image: list = None,
foreground_alpha: float = None,
base_image: np.ndarray = None,
base_im_cmap: str = 'Greys',
base_im_boundary: list = None,
base_im_value_range: tuple = None,
Expand Down Expand Up @@ -247,8 +247,11 @@ def base_scatter(
bg_mask = np.where(base_image == bg_pixel, bg_value, 0)
base_image += bg_mask
ax.imshow(base_image, cmap=base_im_cmap, extent=base_im_boundary)
if foreground_alpha is None:
foreground_alpha = 0.5
else:
foreground_alpha = 1
if foreground_alpha is None:
foreground_alpha = 1

if color_bar:
colors = stereo_conf.linear_colors(palette, reverse=color_bar_reverse)
Expand Down
29 changes: 28 additions & 1 deletion stereo/plots/violin.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,32 @@ def _check_indices(

return col_keys, index_keys, index_aliases

def _check_order(order, count):
if order is None:
return [None] * count
elif not isinstance(order, (list, np.ndarray, pd.Index)):
raise ValueError("order must be a list, np.ndarray, pd.Index or None")

if isinstance(order, pd.Index):
order = [order] * count
elif isinstance(order, np.ndarray):
if order.ndim == 1:
order = [order] * count
elif order.ndim == 2:
if order.shape[0] < count:
raise ValueError("order must have the same number of rows as keys")
else:
raise ValueError("order must be 1D or 2D")
elif isinstance(order, list):
if len(order) == 0:
return [None] * count
elif isinstance(order[0], (list, np.ndarray, pd.Index)):
if len(order) != count:
raise ValueError("order must have the same number of elements as keys")
elif isinstance(order[0], (str, int, float, np.number)):
order = [order] * count
return order


def _get_array_values(
X,
Expand Down Expand Up @@ -309,7 +335,8 @@ def violin_distribution(
else:
axs = [ax]
fig = axs[0].figure
for ax, y, ylab in zip(axs, ys, y_label):
orders = _check_order(order, len(ys))
for ax, y, ylab, order in zip(axs, ys, y_label, orders):
ax = sns.violinplot(x=x, y=y, data=obs_df, order=order, orient='vertical', scale=scale, ax=ax,
palette=palette)
if show_stripplot:
Expand Down

0 comments on commit d7d82df

Please sign in to comment.