Skip to content

Commit

Permalink
Small performance boost and bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
ianfd committed Jan 25, 2024
1 parent 4f149e3 commit ad5ca91
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions src/mousipy/mousipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ def translate_multiple(adata, original_data, multiple, stay_sparse=False, verbos
"""
var = adata.var.copy()
ortholog_indices = {gene: i for i, gene in enumerate(var.index)}
new_genes = []

if stay_sparse:
# Sparse implementation remains unchanged
Expand All @@ -200,7 +199,11 @@ def translate_multiple(adata, original_data, multiple, stay_sparse=False, verbos

for hgene in hgenes:
if hgene not in ortholog_indices:
new_genes.append(hgene)
# Create a new DataFrame row for the new gene
new_row = pd.DataFrame({col: pd.NA for col in var.columns}, index=[hgene])
new_row['original_gene_symbol'] = 'multiple'
var = pd.concat([var, new_row])

X = csr_matrix(np.hstack((X.toarray(), mgene_data.reshape(-1, 1))))
ortholog_indices[hgene] = X.shape[1] - 1
else:
Expand All @@ -220,7 +223,11 @@ def translate_multiple(adata, original_data, multiple, stay_sparse=False, verbos

for hgene in hgenes:
if hgene not in ortholog_indices:
new_genes.append(hgene)
# Create a new DataFrame row for the new gene
new_row = pd.DataFrame({col: pd.NA for col in var.columns}, index=[hgene])
new_row['original_gene_symbol'] = 'multiple'
var = pd.concat([var, new_row])

new_data[:, next_new_gene_idx] = mgene_data.ravel()
ortholog_indices[hgene] = next_new_gene_idx
next_new_gene_idx += 1
Expand All @@ -230,11 +237,12 @@ def translate_multiple(adata, original_data, multiple, stay_sparse=False, verbos

X = new_data

# Creating new DataFrame rows for new genes
for hgene in new_genes:
new_row = pd.DataFrame({col: pd.NA for col in var.columns}, index=[hgene])
new_row['original_gene_symbol'] = 'multiple'
var = pd.concat([var, new_row])
# Check the dimensions of X and var
if X.shape[1] != var.shape[0]:
# If they do not match, modify var to match the dimensions
missing_rows = X.shape[1] - var.shape[0]
additional_rows = pd.DataFrame(index=range(var.shape[0], X.shape[1]))
var = pd.concat([var, additional_rows])

return AnnData(X, adata.obs, var, adata.uns, adata.obsm)

Expand Down

0 comments on commit ad5ca91

Please sign in to comment.