Skip to content

Commit

Permalink
feat: speedup contact map alignment
Browse files Browse the repository at this point in the history
- move function to Cython
- speedup for ESM databases
  • Loading branch information
valentynbez committed Mar 31, 2024
1 parent 8066f93 commit 29d9cca
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 113 deletions.
108 changes: 107 additions & 1 deletion mDeepFRI/alignment_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ cimport cython
import numpy as np

cimport numpy as np
from libc.math cimport sqrt
from libc.stdlib cimport free, malloc
from libc.string cimport strlen

Expand Down Expand Up @@ -64,3 +63,110 @@ cpdef pairwise_sqeuclidean(float[:, ::1] X):
D[j, i] = d

return np.asarray(D)


cpdef align_contact_map(str query_alignment,
str target_alignment,
np.ndarray[np.int32_t, ndim=2] sparse_target_contact_map,
int generated_contacts=2):
"""
Aligns a contact map based on the alignments of query and target sequences.
Args:
query_alignment: The alignment of the query sequence.
target_alignment: The alignment of the target sequence.
sparse_target_contact_map: The sparse contact map of the target
sequence represented as a list of tuples (i, j)
indicating contacts between residues iand j.
generated_contacts: The number of generated contacts to add for gapped
regions in the query alignment. Defaults to 2.
Returns:
The aligned contact map as a numpy array.
Algorithm:
1. Initialize an empty list `sparse_query_contact_map` to store the contacts in the aligned contact map.
2. Initialize variables `target_index` and `query_index` to track the indices of residues in the target
and query proteins, respectively.
3. Initialize an empty dictionary `target_to_query_indices` to map target residues to query residues
using shift resulting from the alignments.
4. Iterate over each position in the query alignment:
- If the query residue is '-', increment the `target_index` and do not add a contact
to the aligned contact map.
- If the query residue is not '-', check the target residue:
- If the target residue is '-', add contacts for the generated region in the query alignment:
- For each generated contact, add the contact (query_index + j, query_index)
and (query_index - j, query_index) to the `sparse_query_contact_map` according to generated_contacts.
- Increment the `query_index`.
- If the target residue is not '-', map the target residue to the query residue by adding
an entry in the `target_to_query_indices` dictionary.
- Increment both the `query_index` and `target_index`.
5. Translate the target residue indices to query residue indices
in the `sparse_target_contact_map` by using the `target_to_query_indices` dictionary.
6. Filter out the contacts that are not present in the query alignment by removing contacts
with '-1' indices from the `sparse_target_contact_map`.
7. Add the filtered contacts from the filtered `sparse_target_contact_map` to the `sparse_query_contact_map`.
8. Build the output contact map with dimensions (query_index, query_index) initialized as all zeros.
Query index is the number of residues in the query sequence.
9. Set the diagonal elements of the output contact map to 1.
10. For each contact (i, j) in the `sparse_query_contact_map`:
- If i is less than 0 or greater than or equal to `query_index`, skip the contact.
- Otherwise, set the corresponding elements in the output contact map to 1 symmetrically.
11. Return the aligned contact map as a numpy array.
"""

cdef int target_index = 0
cdef int query_index = 0
cdef int i, j, k, n
cdef int *target_to_query_indices = <int *>malloc(len(target_alignment) * sizeof(int))
cdef int *sparse_query_contact_map = <int *>malloc(len(query_alignment) * len(target_alignment) * 2 * sizeof(int))
cdef int sparse_map_size = 0
cdef int output_contact_map_size = query_index * query_index

# Map target residues to query residues based on the alignments
for i in range(len(query_alignment)):
if query_alignment[i] == "-":
target_to_query_indices[target_index] = -1
target_index += 1
else:
if target_alignment[i] == "-":
for j in range(1, generated_contacts + 1):
sparse_query_contact_map[sparse_map_size] = query_index + j
sparse_query_contact_map[sparse_map_size + 1] = query_index
sparse_map_size += 2
sparse_query_contact_map[sparse_map_size] = query_index - j
sparse_query_contact_map[sparse_map_size + 1] = query_index
sparse_map_size += 2
query_index += 1
else:
target_to_query_indices[target_index] = query_index
query_index += 1
target_index += 1

# Translate the target residues index to query residues index
for i in range(sparse_target_contact_map.shape[0]):
if (target_to_query_indices[sparse_target_contact_map[i, 0]] != -1 and
target_to_query_indices[sparse_target_contact_map[i, 1]] != -1):
sparse_query_contact_map[sparse_map_size] = target_to_query_indices[sparse_target_contact_map[i, 0]]
sparse_query_contact_map[sparse_map_size + 1] = target_to_query_indices[sparse_target_contact_map[i, 1]]
sparse_map_size += 2

# Build the output contact map
cdef np.ndarray[np.int32_t, ndim=2] output_contact_map = np.zeros((query_index, query_index), dtype=np.int32)

# Fill the diagonal
for i in range(query_index):
output_contact_map[i, i] = 1

# Fill the contacts from the sparse query contact map
for i in range(0, sparse_map_size, 2):
j = sparse_query_contact_map[i]
k = sparse_query_contact_map[i + 1]
if j < query_index and k < query_index:
output_contact_map[j, k] = 1
output_contact_map[k, j] = 1

free(target_to_query_indices)
free(sparse_query_contact_map)

return output_contact_map
116 changes: 4 additions & 112 deletions mDeepFRI/bio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from biotite.structure.io.pdbx import PDBxFile, get_structure
from pysam import FastaFile, FastxFile

from mDeepFRI.alignment_utils import alignment_identity, pairwise_sqeuclidean
from mDeepFRI.alignment_utils import (align_contact_map, alignment_identity,
pairwise_sqeuclidean)

logging.basicConfig(
level=logging.DEBUG,
Expand Down Expand Up @@ -178,7 +179,7 @@ def calculate_contact_map(coordinates: np.ndarray,
cmap = (distances < threshold).astype(np.int32)

if mode == "sparse":
cmap = np.argwhere(cmap == 1).astype(np.uint32)
cmap = np.argwhere(cmap == 1).astype(np.int32)
else:
pass

Expand Down Expand Up @@ -259,7 +260,6 @@ def retrieve_structure_features(idx: str,
structure = requests.get(url).text

else:
chain = None
with foldcomp.open(database_path, ids=[idx]) as db:
for _, pdb in db:
structure = pdb
Expand All @@ -272,119 +272,11 @@ def retrieve_structure_features(idx: str,
for _, pdb in db:
structure = pdb

residues, coords = extract_residues_coordinates(structure, chain=chain)
residues, coords = extract_residues_coordinates(structure)

return (residues, coords)


def align_contact_map(query_alignment: str,
target_alignment: str,
sparse_target_contact_map: List[Tuple[int, int]],
generated_contacts: int = 2) -> np.ndarray:
"""
Aligns a contact map based on the alignments of query and target sequences.
Args:
query_alignment: The alignment of the query sequence.
target_alignment: The alignment of the target sequence.
sparse_target_contact_map: The sparse contact map of the target
sequence represented as a list of tuples (i, j)
indicating contacts between residues iand j.
generated_contacts: The number of generated contacts to add for gapped
regions in the query alignment. Defaults to 2.
Returns:
The aligned contact map as a numpy array.
Algorithm:
1. Initialize an empty list `sparse_query_contact_map` to store the contacts in the aligned contact map.
2. Initialize variables `target_index` and `query_index` to track the indices of residues in the target
and query proteins, respectively.
3. Initialize an empty dictionary `target_to_query_indices` to map target residues to query residues
using shift resulting from the alignments.
4. Iterate over each position in the query alignment:
- If the query residue is '-', increment the `target_index` and do not add a contact
to the aligned contact map.
- If the query residue is not '-', check the target residue:
- If the target residue is '-', add contacts for the generated region in the query alignment:
- For each generated contact, add the contact (query_index + j, query_index)
and (query_index - j, query_index) to the `sparse_query_contact_map` according to generated_contacts.
- Increment the `query_index`.
- If the target residue is not '-', map the target residue to the query residue by adding
an entry in the `target_to_query_indices` dictionary.
- Increment both the `query_index` and `target_index`.
5. Translate the target residue indices to query residue indices
in the `sparse_target_contact_map` by using the `target_to_query_indices` dictionary.
6. Filter out the contacts that are not present in the query alignment by removing contacts
with '-1' indices from the `sparse_target_contact_map`.
7. Add the filtered contacts from the filtered `sparse_target_contact_map` to the `sparse_query_contact_map`.
8. Build the output contact map with dimensions (query_index, query_index) initialized as all zeros.
Query index is the number of residues in the query sequence.
9. Set the diagonal elements of the output contact map to 1.
10. For each contact (i, j) in the `sparse_query_contact_map`:
- If i is less than 0 or greater than or equal to `query_index`, skip the contact.
- Otherwise, set the corresponding elements in the output contact map to 1 symmetrically.
11. Return the aligned contact map as a numpy array.
"""
# The sparse contact map of the query sequence will contain all contacts. Will be used to create dense contact map
sparse_query_contact_map: List[Tuple[int, int]] = []

# The index of the residues in sequences
target_index: int = 0
query_index: int = 0

# Map target residues to query residues based on the alignments
target_to_query_indices: Dict[int, int] = {}

# Map target residues to query residues based on the alignments
for i in range(len(query_alignment)):
# If the query residue is a gap, skip target residue
if query_alignment[i] == "-":
target_to_query_indices[target_index] = -1
target_index += 1
else:
# If the target residue is a gap, add contacts to the query residue
# connected to generated_contacts nearest residues
if target_alignment[i] == "-":
for j in range(1, generated_contacts + 1):
sparse_query_contact_map.append(
(query_index + j, query_index))
sparse_query_contact_map.append(
(query_index - j, query_index))
query_index += 1
else:
# If there is an alignment match, map target residue to query residue
target_to_query_indices[target_index] = query_index
query_index += 1
target_index += 1

# Translate the target residues index to query residues index
sparse_map = list(
map(
lambda x:
(target_to_query_indices[x[0]], target_to_query_indices[x[1]]),
sparse_target_contact_map))
# Filter out the contacts that are not in the query alignment by removing columns and rows from sparse contact map
sparse_map = list(filter(lambda x: x[0] != -1 and x[1] != -1, sparse_map))
# Add the contacts to the output contact map
sparse_query_contact_map.extend(sparse_map)

# Build the output contact map
output_contact_map = np.zeros((query_index, query_index))
# Fill the diagonal
for i in range(query_index):
output_contact_map[i, i] = 1
# Fill the contacts from the sparse query contact map
for i, j in sparse_query_contact_map:
if i >= query_index:
continue
# Apply symmetryR
output_contact_map[i, j] = 1
output_contact_map[j, i] = 1

return output_contact_map


def retrieve_align_contact_map(
alignment: AlignmentResult,
database: str = "pdb100",
Expand Down

0 comments on commit 29d9cca

Please sign in to comment.