Skip to content

Commit

Permalink
SC_48 make _plot public (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneDefauw authored Jan 28, 2025
1 parent f869e74 commit 00e0314
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 23 deletions.
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ Plotting functions.
.. autosummary::
:toctree: generated
pl.plot
pl.plot_image
pl.plot_shapes
pl.plot_labels
Expand Down
2 changes: 1 addition & 1 deletion docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pip install basicpy

On Mac, please comment out the line `mkl=2024.0.0` in `environment.yml`.

For a mimimal list of requirements for `Harpy`, we refer to the [setup.cfg](../setup.cfg).
For a mimimal list of requirements for `Harpy`, we refer to the [pyproject.toml](../pyproject.toml).

## 2. Install `Harpy`:

Expand Down
2 changes: 1 addition & 1 deletion src/harpy/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ._clustering import cluster
from ._enrichment import nhood_enrichment
from ._flowsom import pixel_clusters, pixel_clusters_heatmap
from ._plot import plot_image, plot_labels, plot_shapes
from ._plot import plot, plot_image, plot_labels, plot_shapes
from ._preprocess import preprocess_transcriptomics
from ._qc_cells import plot_adata, ridgeplot_channel, ridgeplot_channel_sample
from ._qc_image import (
Expand Down
40 changes: 26 additions & 14 deletions src/harpy/plot/_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas as pd
from geopandas.geodataframe import GeoDataFrame
from geopandas.geoseries import GeoSeries
from matplotlib.axes import Axes
from scipy.sparse import issparse
from shapely.affinity import translate
from spatialdata import SpatialData
Expand Down Expand Up @@ -151,7 +152,10 @@ def plot_shapes(
output: str | Path | None = None,
) -> None:
"""
Plot shapes and/or images/labels from a SpatialData object.
Plots a SpatialData object.
This function support plotting of a raster (`img_layer` or `labels_layer`), together with a `shapes_layer` respresenting (cell) boundaries.
These shapes can be colored if a `table_layer` and a `column` is specified.
The number of provided `img_layer` or `labels_layer` and `shapes_layer` should be equal if both are iterables and if their length is greater than 1.
Expand Down Expand Up @@ -195,8 +199,10 @@ def plot_shapes(
Labels layer(s) to be plotted.
Displayed as columns in the plot, if multiple are provided.
shapes_layer
Specifies which shapes to plot. If set to None, no shapes_layer is plotted.
Displayed as columns in the plot, if multiple are provided.
Specifies which shapes to plot. Default is 'segmentation_mask_boundaries'. If set to None, no shapes_layer is plot.
Can be colored by `column` in `sdata.tables[table_layer].obs` or `sdata.tables[table_layer].var`.
For this the index of the `shapes_layer` will be matched with `sdata.tables[table_layer].obs[_INSTANCE_KEY]` for those observations for which
`sdata.tables[table_layer].obs[_REGION_KEY]` equals `region` (if `region` is not `None`).
table_layer
Table layer to be plotted (i.e. to base cell colors on) if `column` is specified.
column
Expand Down Expand Up @@ -278,7 +284,7 @@ def plot_shapes(
"""
if img_layer is not None and labels_layer is not None:
raise ValueError(
"Both img_layer and labels_layer is not None. " "Please specify either img_layer or labels_layer, not both."
"Both img_layer and labels_layer is not None. Please specify either img_layer or labels_layer, not both."
)

if column is not None and table_layer is None:
Expand Down Expand Up @@ -363,8 +369,8 @@ def plot_shapes(

idx = 0
for _channel in channels:
for _layer, _shapes_layer in zip(layer, shapes_layer):
_plot(
for _layer, _shapes_layer in zip(layer, shapes_layer, strict=True):
plot(
sdata,
axes[idx],
img_layer=_layer if img_layer_type else None,
Expand Down Expand Up @@ -404,9 +410,9 @@ def plot_shapes(
plt.close()


def _plot(
def plot(
sdata: SpatialData,
ax: plt.Axes,
ax: Axes,
img_layer: str | None = None,
labels_layer: str | None = None,
shapes_layer: str | None = "segmentation_mask_boundaries",
Expand All @@ -432,22 +438,28 @@ def _plot(
shapes_title: bool = False,
channel_title: bool = True,
aspect: str = "equal",
) -> plt.Axes:
) -> Axes:
"""
Plots a SpatialData object.
This function support plotting of a raster (`img_layer` or `labels_layer`), together with a `shapes_layer` respresenting (cell) boundaries.
These shapes can be colored if a `table_layer` and a `column` is specified.
Parameters
----------
sdata
Data containing spatial information for plotting.
ax
Axes object to plot on.
Matplotlib axes object to plot on.
img_layer
Image layer to be plotted. By default, the last added image layer is plotted.
labels_layer
Labels layer to be plotted.
shapes_layer
Specifies which shapes to plot. Default is 'segmentation_mask_boundaries'. If set to None, no shapes_layer is plot.
Can be colored by `column` in `sdata.tables[table_layer].obs` or `sdata.tables[table_layer].var`.
For this the index of the `shapes_layer` will be matched with `sdata.tables[table_layer].obs[_INSTANCE_KEY]` for those observations for which
`sdata.tables[table_layer].obs[_REGION_KEY]` equals `region` (if `region` is not `None`).
table_layer
Table layer to be plotted (i.e. to base cell colors on) if `column` is specified.
column
Expand Down Expand Up @@ -500,7 +512,7 @@ def _plot(
Returns
-------
The axes with the plotted SpatialData.
The Axes object.
Raises
------
Expand All @@ -525,7 +537,7 @@ def _plot(
"""
if img_layer is not None and labels_layer is not None:
raise ValueError(
"Both img_layer and labels_layer is not None. " "Please specify either img_layer or labels_layer, not both."
"Both img_layer and labels_layer is not None. Please specify either img_layer or labels_layer, not both."
)

if column is not None and table_layer is None:
Expand Down Expand Up @@ -653,7 +665,7 @@ def _plot(
mask_polygons = polygons.index.isin(adata_view.obs[_INSTANCE_KEY])
if (~mask_polygons).any():
log.warning(
f"There are '{sum( ~mask_polygons )}' cells in provided shapes_layer '{shapes_layer}' not found in 'sdata.tables[{table_layer}]' (linked through '{_INSTANCE_KEY}'), these cells will not be plotted."
f"There are '{sum(~mask_polygons)}' cells in provided shapes_layer '{shapes_layer}' not found in 'sdata.tables[{table_layer}]' (linked through '{_INSTANCE_KEY}'), these cells will not be plotted."
)
polygons = polygons[mask_polygons]

Expand Down Expand Up @@ -736,7 +748,7 @@ def _plot(
else:
log.info(
f"Layer '{layer}' has 3 spatial dimensions, but no z-slice was specified. "
f"By default the z-slice located at the midpoint of the z-dimension ({_se.shape[0]//2}) will be utilized."
f"By default the z-slice located at the midpoint of the z-dimension ({_se.shape[0] // 2}) will be utilized."
)
_se = _se[_se.shape[0] // 2, ...]

Expand Down
16 changes: 9 additions & 7 deletions src/harpy/table/_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,9 @@ def bin_counts(

# Sanity check that every barcode that could be assigned to a bin is assigned exactly ones to a bin.
_mask = cell_counts == 1
assert _mask.all(), f"Some spots, given by 'sdata.tables[{table_layer}].obsm[{_SPATIAL}]', where assigned to more than one cell defined in '{labels_layer}'."
assert _mask.all(), (
f"Some spots, given by 'sdata.tables[{table_layer}].obsm[{_SPATIAL}]', where assigned to more than one cell defined in '{labels_layer}'."
)
cell_counts = cell_counts.reset_index(level=_CELL_INDEX)
assert cell_counts.index.is_unique, "Spots should not be assigned to more than one cell."

Expand All @@ -283,9 +285,9 @@ def bin_counts(
# get adata
adata_in = sdata.tables[table_layer].copy() # should we do a copy here? otherwise in memory adata will be changed
merged = pd.merge(adata_in.obs, cell_counts[_CELL_INDEX], left_index=True, right_index=True, how="inner")
assert (
merged.shape[0] != 0
), "Result after merging AnnData object, passed via 'table_layer' parameter with aggregated spots is empty."
assert merged.shape[0] != 0, (
"Result after merging AnnData object, passed via 'table_layer' parameter with aggregated spots is empty."
)
adata_in = adata_in[merged.index]
adata_in.obs = merged

Expand Down Expand Up @@ -430,13 +432,13 @@ def _process_partition(_chunk, _chunk_info, ddf_partition):
# Create a list to store delayed operations
delayed_objects = []

for _chunk, _chunk_info in zip(delayed_chunks, chunk_info):
for _chunk, _chunk_info in zip(delayed_chunks, chunk_info, strict=True):
# Query the partition lazily without computing it
z_start, y_start, x_start = _chunk_info[0]
_chunk_shape = _chunk_info[1]

y_query = f"{y_start + coords.y0 } <= {name_y} < {y_start + coords.y0 + _chunk_shape[1]}"
x_query = f"{x_start + coords.x0 } <= {name_x} < {x_start + coords.x0 + _chunk_shape[2]}"
y_query = f"{y_start + coords.y0} <= {name_y} < {y_start + coords.y0 + _chunk_shape[1]}"
x_query = f"{x_start + coords.x0} <= {name_x} < {x_start + coords.x0 + _chunk_shape[2]}"
query = f"{y_query} and {x_query}"

if name_z in ddf.columns:
Expand Down

0 comments on commit 00e0314

Please sign in to comment.