Skip to content

Commit

Permalink
do some updates
Browse files Browse the repository at this point in the history
  • Loading branch information
tanliwei-coder committed Jul 31, 2024
1 parent 780b4d1 commit 29a6960
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 23 deletions.
27 changes: 15 additions & 12 deletions stereo/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def read_gem(
bin_type: str = "bins",
bin_size: int = 100,
is_sparse: bool = True,
bin_coord_offset: bool = False,
center_coordinates: bool = False,
gene_name_index: bool = False
):
"""
Expand All @@ -66,9 +66,9 @@ def read_gem(
the size of bin to merge, when `bin_type` is set to `'bins'`.
is_sparse
the expression matrix is sparse matrix, if `True`, otherwise `np.ndarray`.
bin_coord_offset
if set it to True, the coordinates of bins are calculated as
((gene_coordinates - min_coordinates) // bin_size) * bin_size + min_coordinates + bin_size/2
center_coordinates
if set it to True, the coordinate of each bin will be the center of the bin,
otherwise, the coordinate of each bin will be the left-top corner of the bin.
gene_name_index
In a v0.1 gem file, the column geneID actually contains the gene name, but in a v0.2 file,
geneID is just an ID for genes and there is an additional column called geneName that contains the gene name,
Expand Down Expand Up @@ -106,7 +106,7 @@ def read_gem(
if data.bin_type == 'cell_bins':
gdf = parse_cell_bin_coor(df)
else:
if bin_coord_offset:
if center_coordinates:
df = parse_bin_coor(df, bin_size)
else:
df = parse_bin_coor_no_offset(df, bin_size)
Expand Down Expand Up @@ -153,7 +153,7 @@ def read_gem(
'maxExp': data.exp_matrix.max(), # noqa
'resolution': resolution,
}
data.bin_coord_offset = bin_coord_offset
data.center_coordinates = center_coordinates
logger.info(f'the martrix has {data.cell_names.size} cells, and {data.gene_names.size} genes.')
return data

Expand Down Expand Up @@ -732,6 +732,7 @@ def read_h5ad(
flavor: str = 'scanpy',
bin_type: str = None,
bin_size: int = None,
spatial_key: str = 'spatial',
**kwargs
) -> Union[StereoExpData, AnnBasedStereoExpData]:
"""
Expand All @@ -742,20 +743,22 @@ def read_h5ad(
file_path
the path of the h5ad file.
anndata
the object of AnnData which to be loaded, only available while `flavor` is `'scanpy'`.
the object of AnnData to be loaded, only available while `flavor` is `'scanpy'`.
Only one of `file_path` and `anndata` can be provided.
flavor
the format of the h5ad file, defaults to `'scanpy'`.
`scanpy`: AnnData format of scanpy
`stereopy`: h5ad format of stereo
`stereopy`: h5 format of stereo
bin_type
the bin type includes `'bins'` or `'cell_bins'`.
the bin type includes `'bins'` and `'cell_bins'`.
bin_size
the size of bin to merge, when `bin_type` is set to `'bins'`.
spatial_key
the key of spatial information in AnnData.obsm, defaults to `'spatial'`.
Only available when `flavor` is `'scanpy'`.
Returns
---------------
An object of StereoExpData when `flavor` is `'stereopy'`,
or an object of AnnBasedStereoExpData when `flavor` is `'scanpy'`.
"""
flavor = flavor.lower()
Expand All @@ -775,7 +778,7 @@ def read_h5ad(
if file_path is not None and anndata is not None:
raise Exception("'file_path' and 'anndata' only can input one of them")
return AnnBasedStereoExpData(h5ad_file_path=file_path, based_ann_data=anndata, bin_type=bin_type,
bin_size=bin_size, **kwargs)
bin_size=bin_size, spatial_key=spatial_key, **kwargs)
else:
raise ValueError("Invalid value for 'flavor'")

Expand Down
6 changes: 3 additions & 3 deletions stereo/plots/interact_plot/interactive_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def generate_gem_file(

if selected_areas.bin_type == 'bins':
@nb.njit(cache=True, nogil=True, parallel=True)
def __get_filtering_flag(data, bin_size, position, bin_coord_offset, num_threads):
def __get_filtering_flag(data, bin_size, position, center_coordinates, num_threads):
num_threads = min(position.shape[0], num_threads)
num_per_thread = position.shape[0] // num_threads
num_left = position.shape[0] % num_threads
Expand All @@ -183,7 +183,7 @@ def __get_filtering_flag(data, bin_size, position, bin_coord_offset, num_threads
end = interval[i + 1]
for j in range(start, end):
x_start, y_start = position[j]
if bin_coord_offset:
if center_coordinates:
x_start -= bin_size // 2
y_start -= bin_size // 2
x_end = x_start + bin_size
Expand All @@ -198,7 +198,7 @@ def __get_filtering_flag(data, bin_size, position, bin_coord_offset, num_threads
original_gem_df[['x', 'y', 'UMICount']].to_numpy(),
selected_areas.bin_size,
selected_areas.position,
selected_areas.bin_coord_offset,
selected_areas.center_coordinates,
nb.get_num_threads(),
)
selected_gem_df = original_gem_df[flag]
Expand Down
6 changes: 3 additions & 3 deletions stereo/plots/interact_plot/poly_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def generate_gem_file(

if selected_areas.bin_type == 'bins':
@nb.njit(cache=True, nogil=True, parallel=True)
def __get_filtering_flag(data, bin_size, position, bin_coord_offset, num_threads, drop):
def __get_filtering_flag(data, bin_size, position, center_coordinates, num_threads, drop):
num_threads = min(position.shape[0], num_threads)
num_per_thread = position.shape[0] // num_threads
num_left = position.shape[0] % num_threads
Expand All @@ -227,7 +227,7 @@ def __get_filtering_flag(data, bin_size, position, bin_coord_offset, num_threads
end = interval[i + 1]
for j in range(start, end):
x_start, y_start = position[j]
if bin_coord_offset:
if center_coordinates:
x_start -= bin_size // 2
y_start -= bin_size // 2
x_end = x_start + bin_size
Expand All @@ -242,7 +242,7 @@ def __get_filtering_flag(data, bin_size, position, bin_coord_offset, num_threads
original_gem_df[['x', 'y', 'UMICount']].to_numpy(),
selected_areas.bin_size,
selected_areas.position,
selected_areas.bin_coord_offset,
selected_areas.center_coordinates,
nb.get_num_threads(),
drop
)
Expand Down
12 changes: 7 additions & 5 deletions stereo/utils/data_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,12 @@ def __merge_for_stereo_exp_data(
new_data.sn[str(batch)] = data.sn
if i == 0:
new_data.exp_matrix = data.exp_matrix.copy()
new_data.cells = Cell(cell_name=cell_names, cell_border=data.cells.cell_border, batch=data.cells.batch)
new_data.genes = Gene(gene_name=data.gene_names)
new_data.cells._obs = data.cells._obs.copy(deep=True)
new_data.cells._obs.index = cell_names
# new_data.cells = Cell(cell_name=cell_names, cell_border=data.cells.cell_border, batch=data.cells.batch)
# new_data.genes = Gene(gene_name=data.gene_names)
# new_data.cells._obs = data.cells._obs.copy(deep=True)
# new_data.cells._obs.index = cell_names
new_data.cells = Cell(obs=data.cells._obs.copy(deep=True), cell_border=data.cells.cell_border, batch=data.cells.batch)
new_data.genes = Gene(var=data.genes._var.copy(deep=True))
new_data.position = data.position
if data.position_z is None:
new_data.position_z = np.repeat([[0]], repeats=data.position.shape[0], axis=0).astype(
Expand All @@ -337,7 +339,7 @@ def __merge_for_stereo_exp_data(
new_data.offset_y = data.offset_y
new_data.attr = data.attr
else:
current_obs = data.cells._obs.copy()
current_obs = data.cells._obs.copy(deep=True)
current_obs.index = cell_names
new_data.cells._obs = pd.concat([new_data.cells._obs, current_obs])
if new_data.cell_borders is not None and data.cell_borders is not None:
Expand Down

0 comments on commit 29a6960

Please sign in to comment.