Skip to content

Commit c499046

Browse files
update for filtering
1 parent bac2598 commit c499046

File tree

3 files changed

+56
-23
lines changed

3 files changed

+56
-23
lines changed

stereo/core/cell.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -174,23 +174,35 @@ def _set_batch(self, batch: Union[np.ndarray, list, int, str]):
174174

175175
def sub_set(self, index):
176176
"""
177-
get the subset of Cell by the index info the Cell object will be inplaced by the subset.
177+
get the subset of Cell by the index info, the Cell object will be inplaced by the subset.
178178
179179
:param index: a numpy array of index info.
180180
:return: the subset of Cell object.
181181
"""
182182

183183
if self.cell_border is not None:
184184
self.cell_border = self.cell_border[index]
185-
if isinstance(index, list) or isinstance(index, slice):
186-
self._obs = self._obs.iloc[index].copy()
187-
elif isinstance(index, np.ndarray):
188-
if index.dtype == bool:
189-
self._obs = self._obs[index].copy()
185+
# if isinstance(index, list) or isinstance(index, slice):
186+
# self._obs = self._obs.iloc[index].copy()
187+
# elif isinstance(index, np.ndarray):
188+
# if index.dtype == bool:
189+
# self._obs = self._obs[index].copy()
190+
# else:
191+
# self._obs = self._obs.iloc[index].copy()
192+
# else:
193+
# self._obs = self._obs.iloc[index].copy()
194+
if isinstance(index, pd.Series):
195+
index = index.to_numpy()
196+
self._obs = self._obs.iloc[index].copy()
197+
for col in self._obs.columns:
198+
if self._obs[col].dtype.name == 'category':
199+
self._obs[col] = self._obs[col].cat.remove_unused_categories()
200+
for key, value in self._matrix.items():
201+
if isinstance(value, pd.DataFrame):
202+
self._matrix[key] = value.iloc[index].copy()
203+
self._matrix[key].reset_index(drop=True, inplace=True)
190204
else:
191-
self._obs = self._obs.iloc[index].copy()
192-
else:
193-
self._obs = self._obs.iloc[index].copy()
205+
self._matrix[key] = value[index]
194206
return self
195207

196208
def get_property(self, name):

stereo/core/gene.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -148,15 +148,21 @@ def sub_set(self, index):
148148
:param index: a numpy array of index info.
149149
:return: the subset of Gene object.
150150
"""
151-
if isinstance(index, list) or isinstance(index, slice):
152-
self._var = self._var.iloc[index].copy()
153-
elif isinstance(index, np.ndarray):
154-
if index.dtype == bool:
155-
self._var = self._var[index].copy()
156-
else:
157-
self._var = self._var.iloc[index].copy()
158-
else:
159-
self._var = self._var.iloc[index].copy()
151+
# if isinstance(index, list) or isinstance(index, slice):
152+
# self._var = self._var.iloc[index].copy()
153+
# elif isinstance(index, np.ndarray):
154+
# if index.dtype == bool:
155+
# self._var = self._var[index].copy()
156+
# else:
157+
# self._var = self._var.iloc[index].copy()
158+
# else:
159+
# self._var = self._var.iloc[index].copy()
160+
if isinstance(index, pd.Series):
161+
index = index.to_numpy()
162+
self._var = self._var.iloc[index].copy()
163+
for col in self._var.columns:
164+
if self._var[col].dtype.name == 'category':
165+
self._var[col] = self._var[col].cat.remove_unused_categories()
160166
return self
161167

162168
def to_df(self, copy=False):

stereo/preprocess/filter.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ def filter_cells(
3535
max_n_genes_by_counts=None,
3636
pct_counts_mt=None,
3737
cell_list=None,
38-
inplace=True):
38+
excluded=False,
39+
inplace=True,
40+
):
3941
"""
4042
filter cells based on numbers of genes expressed.
4143
@@ -46,6 +48,7 @@ def filter_cells(
4648
:param max_n_genes_by_counts: Maximum number of n_genes_by_counts for a cell pass filtering.
4749
:param pct_counts_mt: Maximum number of pct_counts_mt for a cell pass filtering.
4850
:param cell_list: the list of cells which will be filtered.
51+
:param excluded: set it to True to exclude the cells which are specified by parameter `cell_list` while False to include.
4952
:param inplace: whether inplace the original data or return a new data.
5053
5154
:return: StereoExpData object.
@@ -67,7 +70,10 @@ def filter_cells(
6770
if pct_counts_mt:
6871
cell_subset &= data.cells.pct_counts_mt <= pct_counts_mt
6972
if cell_list is not None:
70-
cell_subset &= np.isin(data.cells.cell_name, cell_list)
73+
if excluded:
74+
cell_subset &= ~np.isin(data.cells.cell_name, cell_list)
75+
else:
76+
cell_subset &= np.isin(data.cells.cell_name, cell_list)
7177
data.sub_by_index(cell_index=cell_subset)
7278
return data
7379

@@ -80,6 +86,8 @@ def filter_genes(
8086
max_count=None,
8187
gene_list=None,
8288
mean_umi_gt=None,
89+
excluded=False,
90+
filter_mt_genes=False,
8391
inplace=True
8492
):
8593
"""
@@ -90,14 +98,16 @@ def filter_genes(
9098
:param max_cell: Maximun number of cells for a gene pass filtering.
9199
:param mean_umi_gt: Filter genes whose mean umi greater than this value.
92100
:param gene_list: the list of genes which will be filtered.
101+
:param excluded: set it to True to exclude the genes which are specified by parameter `gene_list` while False to include.
93102
:param inplace: whether inplace the original data or return a new data.
94103
95104
:return: StereoExpData object.
96105
"""
97106
data = data if inplace else copy.deepcopy(data)
98-
if min_cell is None and max_cell is None \
107+
if not filter_mt_genes and \
108+
(min_cell is None and max_cell is None \
99109
and min_count is None and max_count is None \
100-
and gene_list is None and mean_umi_gt is None:
110+
and gene_list is None and mean_umi_gt is None):
101111
raise ValueError('please set any of `min_cell`, `max_cell`, `min_count`, `max_count`, `gene_list` and `mean_umi_gt`')
102112
cal_genes_indicators(data)
103113
gene_subset = np.ones(data.genes.size, dtype=np.bool8)
@@ -110,9 +120,14 @@ def filter_genes(
110120
if max_count:
111121
gene_subset &= data.genes.n_counts <= max_count
112122
if gene_list is not None:
113-
gene_subset &= np.isin(data.gene_names, gene_list)
123+
if excluded:
124+
gene_subset &= ~np.isin(data.gene_names, gene_list)
125+
else:
126+
gene_subset &= np.isin(data.gene_names, gene_list)
114127
if mean_umi_gt is not None:
115128
gene_subset &= data.genes.mean_umi > mean_umi_gt
129+
if filter_mt_genes:
130+
gene_subset &= ~np.char.startswith(np.char.lower(data.gene_names), 'mt-')
116131
data.sub_by_index(gene_index=gene_subset)
117132
return data
118133

0 commit comments

Comments
 (0)