Skip to content

Commit

Permalink
do some updates for fixing bugs and add new features.
Browse files Browse the repository at this point in the history
  • Loading branch information
tanliwei-genomics-cn committed Nov 29, 2023
1 parent 6bb5082 commit 88cc086
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 6 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ param==1.12.2
hvplot==0.7.3
colorcet==2.0.6
datashader>=0.13.0,<=0.14.1rc1
anndata==0.7.5
anndata>=0.7.5
phenograph==1.5.7
requests>=2.31.0
urllib3==1.26.9
Expand Down
4 changes: 4 additions & 0 deletions stereo/algorithm/ccd/community_clustering_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,10 @@ def plot_clustering(self):
labels = np.unique(self.adata.obs[f'tissue_{self.method_key}'].values)
if 'unknown' in labels:
labels = labels[labels != 'unknown']
if len(labels) > len(self.cluster_palette):
logger.warning(f"Number of clusters ({len(labels)}) is larger than pallette size. All clusters will be colored gray.")
self.cluster_palette = {l: '#CCCCCC' for l in labels}
self.cluster_palette['unknown'] = '#CCCCCC'
plot_spatial(self.adata, annotation=f'tissue_{self.method_key}', palette=self.cluster_palette,
spot_size=self.spot_size, ax=ax, title=f'{self.adata.uns["sample_name"]}')
handles, labels = ax.get_legend_handles_labels()
Expand Down
7 changes: 5 additions & 2 deletions stereo/algorithm/community_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,11 @@ def _main(self, slices, annotation="annotation", **kwargs):
try:
algo.plot_stats()
except Exception as e:
print('plot_stats raise exception while running multi slice, err=%s', str(e))
print('plot_stats raise exception while running multi slice, err=%s' % str(e))
try:
algo.plot_celltype_table()
except Exception as e:
print('plot_celltype_table raise exception while running multi slice, err=%s', str(e))
print('plot_celltype_table raise exception while running multi slice, err=%s' % str(e))
if self.params['plotting'] > 2:
algo.plot_cluster_mixtures()
algo.boxplot_stats()
Expand Down Expand Up @@ -442,6 +442,9 @@ def plot_all_slices(self, img_name, clustering=False):
for (algo, ax) in zip(self.algo_list, axes.flatten()):
palette = algo.cluster_palette if clustering else algo.annotation_palette
annotation = f'tissue_{self.algo_list[0].method_key}' if clustering else self.algo_list[0].annotation
clusters = np.unique(algo.adata.obs[annotation].values)
if len(clusters) > len(cluster_palette):
logger.warning(f"Number of clusters ({len(clusters)}) is larger than pallette size. All clusters will be colored gray.")
plot_spatial(algo.adata, annotation=annotation, palette=palette, spot_size=algo.spot_size, ax=ax)
ax.get_legend().remove()
ax.set_title(f'{algo.filename}', fontsize=6, loc='center', wrap=True)
Expand Down
14 changes: 13 additions & 1 deletion stereo/core/ms_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from joblib import (
Parallel,
delayed,
Expand All @@ -6,7 +7,7 @@

from stereo import logger
from stereo.core import StPipeline
from stereo.plots.decorator import download
from stereo.plots.decorator import download, download_only


class _scope_slice(object):
Expand Down Expand Up @@ -161,6 +162,11 @@ def _run_isolated_method(self, item, *args, **kwargs):
if new_attr:
def log_delayed_task(idx, *arg, **kwargs):
logger.info(f'data_obj(idx={idx}) in ms_data start to run {item}')
if self.__class__.ATTR_NAME == 'plt':
out_path = kwargs.get('out_path', None)
if out_path is not None:
path_name, ext = os.path.splitext(out_path)
kwargs['out_path'] = f'{path_name}_{idx}{ext}'
new_attr(*arg, **kwargs)

Parallel(n_jobs=n_jobs, backend='threading', verbose=100)(
Expand All @@ -178,6 +184,12 @@ def log_delayed_task(idx, *arg, **kwargs):
def log_delayed_task(idx, obj, *arg, **kwargs):
logger.info(f'data_obj(idx={idx}) in ms_data start to run {item}')
new_attr = base.get_attribute_helper(item, obj, obj.tl.result)
if base == PlotBase:
out_path = kwargs.get('out_path', None)
if out_path is not None:
path_name, ext = os.path.splitext(out_path)
kwargs['out_path'] = f'{path_name}_{idx}{ext}'
new_attr = download_only(new_attr)
if new_attr:
new_attr(*arg, **kwargs)
else:
Expand Down
7 changes: 6 additions & 1 deletion stereo/core/st_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,6 +1321,8 @@ def annotation(
# df.group.cat.categories = new_annotation_list

cluster_res: pd.DataFrame = self.result[cluster_res_key]
if cluster_res['group'].dtype.name != 'category':
cluster_res['group'] = cluster_res['group'].astype('category')

if isinstance(annotation_information, (list, np.ndarray)) and \
len(annotation_information) != cluster_res['group'].cat.categories.size:
Expand All @@ -1333,7 +1335,10 @@ def annotation(
elif isinstance(annotation_information, dict):
new_categories_list = []
for i in cluster_res['group'].cat.categories:
new_categories_list.append(annotation_information[i])
if i in annotation_information:
new_categories_list.append(annotation_information[i])
else:
new_categories_list.append(i)
new_categories = np.array(new_categories_list, dtype='U')
else:
raise TypeError("The type of 'annotation_information' only supports list, ndarray or dict.")
Expand Down
2 changes: 1 addition & 1 deletion stereo/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,7 @@ def stereo_to_anndata(

if output is not None:
adata.write_h5ad(output)
logger.info("Finished output to {output}")
logger.info(f"Finished output to {output}")

return adata

Expand Down
18 changes: 18 additions & 0 deletions stereo/plots/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,24 @@ def _action(_, figure: Figure):

return wrapped

def download_only(func):
@wraps(func)
def wrapped(*args, **kwargs):
out_path = None
dpi = 100
if 'out_path' in kwargs:
out_path = kwargs['out_path']
del kwargs['out_path']
if 'out_dpi' in kwargs:
dpi = kwargs['out_dpi']
del kwargs['out_dpi']
fig: Figure = func(*args, **kwargs)
if type(fig) is Figure and out_path is not None:
fig.savefig(out_path, bbox_inches='tight', dpi=dpi)
return fig

return wrapped


def reorganize_coordinate(func):
@wraps(func)
Expand Down

0 comments on commit 88cc086

Please sign in to comment.