Skip to content

Commit

Permalink
Add support for matching protein sequences.
Browse files Browse the repository at this point in the history
  • Loading branch information
Adibvafa committed Oct 29, 2024
1 parent 7c20ff1 commit f9fb38a
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 1 deletion.
20 changes: 19 additions & 1 deletion CodonTransformer/CodonPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from CodonTransformer.CodonData import get_merged_seq
from CodonTransformer.CodonUtils import (
AMINO_ACID_TO_INDEX,
INDEX2TOKEN,
NUM_ORGANISMS,
ORGANISM2ID,
Expand All @@ -41,6 +42,7 @@ def predict_dna_sequence(
temperature: float = 0.2,
top_p: float = 0.95,
num_sequences: int = 1,
match_protein: bool = False,
) -> Union[DNASequencePrediction, List[DNASequencePrediction]]:
"""
Predict the DNA sequence(s) for a given protein using the CodonTransformer model.
Expand Down Expand Up @@ -83,6 +85,9 @@ def predict_dna_sequence(
The value must be a float between 0 and 1. Defaults to 0.95.
num_sequences (int, optional): The number of DNA sequences to generate. Only applicable
when deterministic is False. Defaults to 1.
match_protein (bool, optional): Ensures the predicted DNA sequence is translated
to the input protein sequence by sampling from only the respective codons of
given amino acids. Defaults to False.
Returns:
Union[DNASequencePrediction, List[DNASequencePrediction]]: An object or list of objects
Expand Down Expand Up @@ -198,6 +203,19 @@ def predict_dna_sequence(
# Get the model predictions
output_dict = model(**tokenized_input, return_dict=True)
logits = output_dict.logits.detach().cpu()
logits = logits[:, 1:-1, :] # Remove [CLS] and [SEP] tokens

# Mask the logits of codons that do not correspond to the input protein sequence
if match_protein:
possible_tokens_per_position = [
AMINO_ACID_TO_INDEX[token[0]] for token in merged_seq.split(" ")
]
mask = torch.full_like(logits, float("-inf"))

for pos, possible_tokens in enumerate(possible_tokens_per_position):
mask[:, pos, possible_tokens] = 0

logits = mask + logits

predictions = []
for _ in range(num_sequences):
Expand All @@ -211,7 +229,7 @@ def predict_dna_sequence(

predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices))
predicted_dna = (
"".join([token[-3:] for token in predicted_dna[1:-1]]).strip().upper()
"".join([token[-3:] for token in predicted_dna]).strip().upper()
)

predictions.append(
Expand Down
97 changes: 97 additions & 0 deletions tests/test_CodonPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,103 @@ def test_predict_dna_sequence_multi_diversity(self):
translated_protein = get_amino_acid_sequence(prediction.predicted_dna[:-3])
self.assertEqual(translated_protein, protein_sequence)

def test_predict_dna_sequence_match_protein_repetitive(self):
"""Test that match_protein=True correctly handles highly repetitive and unconventional sequences."""
test_sequences = (
"QQQQQQQQQQQQQQQQ_",
"KRKRKRKRKRKRKRKR_",
"PGPGPGPGPGPGPGPG_",
"DEDEDEDEDEDEDEDEDE_",
"M_M_M_M_M_",
"MMMMMMMMMM_",
"WWWWWWWWWW_",
"CCCCCCCCCC_",
"MWCHMWCHMWCH_",
"Q_QQ_QQQ_QQQQ_",
"MWMWMWMWMWMW_",
"CCCHHHMMMWWW_",
"_",
"M_",
"MGWC_",
)

organism = "Homo sapiens"

for protein_sequence in test_sequences:
# Generate sequence with match_protein=True
result = predict_dna_sequence(
protein=protein_sequence,
organism=organism,
device=self.device,
tokenizer=self.tokenizer,
model=self.model,
deterministic=False,
temperature=20, # High temperature to test protein matching
match_protein=True,
)

dna_sequence = result.predicted_dna
translated_protein = get_amino_acid_sequence(dna_sequence)

self.assertEqual(
translated_protein,
protein_sequence,
f"Translated protein must match original when match_protein=True. Failed for sequence: {protein_sequence}",
)

def test_predict_dna_sequence_match_protein_rare_amino_acids(self):
"""Test match_protein with rare amino acids that have limited codon options."""
# Methionine (M) and Tryptophan (W) have only one codon each
# While Leucine (L) has 6 codons - testing contrast
protein_sequence = "MWLLLMWLLL"
organism = "Escherichia coli general"

# Run multiple predictions
results = []
num_iterations = 10

for _ in range(num_iterations):
result = predict_dna_sequence(
protein=protein_sequence,
organism=organism,
device=self.device,
tokenizer=self.tokenizer,
model=self.model,
deterministic=False,
temperature=20, # High temperature to test protein matching
match_protein=True,
)
results.append(result.predicted_dna)

# Check all sequences
for dna_sequence in results:
# Verify M always uses ATG
m_positions = [0, 5] # Known positions of M in sequence
for pos in m_positions:
self.assertEqual(
dna_sequence[pos * 3 : (pos + 1) * 3],
"ATG",
"Methionine must use ATG codon.",
)

# Verify W always uses TGG
w_positions = [1, 6] # Known positions of W in sequence
for pos in w_positions:
self.assertEqual(
dna_sequence[pos * 3 : (pos + 1) * 3],
"TGG",
"Tryptophan must use TGG codon.",
)

# Verify all L codons are valid
l_positions = [2, 3, 4, 7, 8, 9] # Known positions of L in sequence
l_codons = [dna_sequence[pos * 3 : (pos + 1) * 3] for pos in l_positions]
valid_l_codons = {"TTA", "TTG", "CTT", "CTC", "CTA", "CTG"}
self.assertTrue(
all(codon in valid_l_codons for codon in l_codons),
"All Leucine codons must be valid",
)


if __name__ == "__main__":
unittest.main()

0 comments on commit f9fb38a

Please sign in to comment.