Skip to content

Commit

Permalink
Increase test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
earmingol committed Dec 18, 2023
1 parent b209840 commit 6043437
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
15 changes: 15 additions & 0 deletions sccellfie/expression/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,27 @@ def test_agg_expression_cells_specific_gene_present():
adata.obs[groupby] = ['group1' if i < adata.n_obs // 2 else 'group2' for i in range(adata.n_obs)]
gene_symbols = ['gene1', 'gene10']

# Test aggregation with one gene
agg_result = agg_expression_cells(adata, groupby, gene_symbols=gene_symbols[0], agg_func='mean')
assert agg_result.shape == (len(adata.obs[groupby].unique()), 1), "Shape mismatch"
assert gene_symbols[0] in agg_result.columns, "Missing gene in result"

# Test aggregation with specific genes
agg_result = agg_expression_cells(adata, groupby, gene_symbols=gene_symbols, agg_func='mean')
assert agg_result.shape == (len(adata.obs[groupby].unique()), len(gene_symbols)), "Shape mismatch"
assert all(gene in agg_result.columns for gene in gene_symbols), "Missing genes in result"


def test_agg_expression_cells_layer():
adata = create_random_adata(layers='gene_scores')
groupby = 'group'
adata.obs[groupby] = ['group1' if i < adata.n_obs // 2 else 'group2' for i in range(adata.n_obs)]

# Test aggregation with one gene
agg_result = agg_expression_cells(adata, groupby, layer='gene_scores', agg_func='mean')
assert agg_result.shape == (len(adata.obs[groupby].unique()), adata.shape[1]), "Shape mismatch"


def test_agg_expression_cells_invalid_agg_func():
adata = create_random_adata()
groupby = 'group'
Expand Down
2 changes: 2 additions & 0 deletions sccellfie/tests/toy_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ def create_random_adata(n_obs=100, n_vars=50, layers=None):
adata = sc.AnnData(X=X, obs=obs, var=var)

if layers:
if isinstance(layers, str):
layers = [layers]
for layer in layers:
adata.layers[layer] = np.random.rand(n_obs, n_vars)
return adata
Expand Down

0 comments on commit 6043437

Please sign in to comment.