Skip to content

Commit

Permalink
Test multiple sequence generation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Adibvafa committed Sep 21, 2024
1 parent 31dacd4 commit 34c9c29
Showing 1 changed file with 76 additions and 0 deletions.
76 changes: 76 additions & 0 deletions tests/test_CodonPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
AMINO_ACIDS,
ORGANISM2ID,
STOP_SYMBOLS,
DNASequencePrediction,
)


Expand Down Expand Up @@ -416,6 +417,81 @@ def test_predict_dna_sequence_long_protein_over_max_length(self):
"Translated protein does not match the original protein sequence up to the maximum length supported.",
)

def test_predict_dna_sequence_multi_output(self):
    """Check that sampling with num_sequences > 1 returns that many predictions.

    Requests 20 sequences in non-deterministic mode and verifies that the
    result is a list of exactly that length, that every element is a
    DNASequencePrediction whose DNA consists solely of A/T/C/G, and that
    each sequence, minus its last three nucleotides, translates back to
    the input protein.
    """
    protein_sequence = "MFQLLAPWY"
    num_sequences = 20

    predictions = predict_dna_sequence(
        protein=protein_sequence,
        organism="Escherichia coli general",
        device=self.device,
        tokenizer=self.tokenizer,
        model=self.model,
        deterministic=False,
        num_sequences=num_sequences,
    )

    # Shape of the result: a list with one entry per requested sequence.
    self.assertIsInstance(predictions, list)
    self.assertEqual(len(predictions), num_sequences)

    for item in predictions:
        self.assertIsInstance(item, DNASequencePrediction)

        only_valid_bases = all(base in "ATCG" for base in item.predicted_dna)
        self.assertTrue(only_valid_bases)

        # The final three nucleotides are dropped before translating
        # (presumably the stop codon -- confirm against predict_dna_sequence).
        coding_region = item.predicted_dna[:-3]
        self.assertEqual(get_amino_acid_sequence(coding_region), protein_sequence)

def test_predict_dna_sequence_deterministic_multi_raises_error(self):
    """Check that deterministic mode rejects a request for several sequences.

    Calling predict_dna_sequence with deterministic=True and
    num_sequences=3 is expected to raise ValueError.
    """
    call_kwargs = dict(
        protein="MFWY",
        organism="Escherichia coli general",
        device=self.device,
        tokenizer=self.tokenizer,
        model=self.model,
        deterministic=True,
        num_sequences=3,
    )

    with self.assertRaises(ValueError):
        predict_dna_sequence(**call_kwargs)

def test_predict_dna_sequence_multi_diversity(self):
    """Check that sampled sequences differ from one another yet stay valid.

    Generates 10 sequences at temperature 0.8 and requires more than two
    distinct DNA strings among them. Every sequence, minus its last three
    nucleotides, must still translate back to the input protein.
    """
    protein_sequence = "MFWYMFWY"

    predictions = predict_dna_sequence(
        protein=protein_sequence,
        organism="Escherichia coli general",
        device=self.device,
        tokenizer=self.tokenizer,
        model=self.model,
        deterministic=False,
        num_sequences=10,
        temperature=0.8,
    )

    # Collapse duplicates so the count reflects distinct outputs only.
    distinct_dna = {item.predicted_dna for item in predictions}
    self.assertGreater(
        len(distinct_dna),
        2,
        "Multiple sequence generation should produce diverse results",
    )

    # Diversity must not come at the cost of correctness: each sequence
    # still has to encode the original protein.
    for item in predictions:
        recovered_protein = get_amino_acid_sequence(item.predicted_dna[:-3])
        self.assertEqual(recovered_protein, protein_sequence)


if __name__ == "__main__":
    # Run the test cases in this module when the file is executed directly.
    unittest.main()

0 comments on commit 34c9c29

Please sign in to comment.