diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py
index 83b9251..73f2e36 100644
--- a/CodonTransformer/CodonPrediction.py
+++ b/CodonTransformer/CodonPrediction.py
@@ -22,6 +22,7 @@
 
 from CodonTransformer.CodonData import get_merged_seq
 from CodonTransformer.CodonUtils import (
+    AMINO_ACID_TO_INDEX,
     INDEX2TOKEN,
     NUM_ORGANISMS,
     ORGANISM2ID,
@@ -41,6 +42,7 @@ def predict_dna_sequence(
     temperature: float = 0.2,
     top_p: float = 0.95,
     num_sequences: int = 1,
+    match_protein: bool = False,
 ) -> Union[DNASequencePrediction, List[DNASequencePrediction]]:
     """
     Predict the DNA sequence(s) for a given protein using the CodonTransformer model.
@@ -83,6 +85,9 @@ def predict_dna_sequence(
             The value must be a float between 0 and 1. Defaults to 0.95.
         num_sequences (int, optional): The number of DNA sequences to generate. Only applicable
             when deterministic is False. Defaults to 1.
+        match_protein (bool, optional): If True, masks the logits so that each position can
+            only yield codons of the corresponding input amino acid, so the predicted DNA
+            translates back to the input protein. Defaults to False.
 
     Returns:
         Union[DNASequencePrediction, List[DNASequencePrediction]]: An object or list of objects
@@ -198,6 +203,25 @@ def predict_dna_sequence(
     # Get the model predictions
     output_dict = model(**tokenized_input, return_dict=True)
     logits = output_dict.logits.detach().cpu()
+    logits = logits[:, 1:-1, :]  # Remove [CLS] and [SEP] tokens
+
+    # Mask the logits of codons that do not correspond to the input protein sequence
+    if match_protein:
+        possible_tokens_per_position = [
+            AMINO_ACID_TO_INDEX[token[0]] for token in merged_seq.split(" ")
+        ]
+        # Guard against logits covering fewer positions than merged_seq has tokens
+        # (presumably possible if the tokenizer truncates long inputs -- confirm);
+        # indexing the mask past its sequence length would raise IndexError.
+        possible_tokens_per_position = possible_tokens_per_position[: logits.shape[1]]
+
+        # 0 for the allowed codons and -inf for all others: adding the mask leaves
+        # the allowed logits unchanged and pushes every other logit to -inf.
+        mask = torch.full_like(logits, float("-inf"))
+        for pos, possible_tokens in enumerate(possible_tokens_per_position):
+            mask[:, pos, possible_tokens] = 0
+
+        logits = mask + logits
 
     predictions = []
     for _ in range(num_sequences):
@@ -211,7 +235,7 @@ def predict_dna_sequence(
 
         predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices))
         predicted_dna = (
-            "".join([token[-3:] for token in predicted_dna[1:-1]]).strip().upper()
+            "".join([token[-3:] for token in predicted_dna]).strip().upper()
         )
 
         predictions.append(
diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py
index 310193e..dfce56f 100644
--- a/tests/test_CodonPrediction.py
+++ b/tests/test_CodonPrediction.py
@@ -492,6 +492,93 @@ def test_predict_dna_sequence_multi_diversity(self):
             translated_protein = get_amino_acid_sequence(prediction.predicted_dna[:-3])
             self.assertEqual(translated_protein, protein_sequence)
 
+    def test_predict_dna_sequence_match_protein_repetitive(self):
+        """Test that match_protein=True correctly handles highly repetitive and unconventional sequences."""
+        test_sequences = (
+            "QQQQQQQQQQQQQQQQ_",
+            "KRKRKRKRKRKRKRKR_",
+            "PGPGPGPGPGPGPGPG_",
+            "DEDEDEDEDEDEDEDEDE_",
+            "M_M_M_M_M_",
+            "MMMMMMMMMM_",
+            "WWWWWWWWWW_",
+            "CCCCCCCCCC_",
+            "MWCHMWCHMWCH_",
+            "Q_QQ_QQQ_QQQQ_",
+            "MWMWMWMWMWMW_",
+            "CCCHHHMMMWWW_",
+            "_",
+            "M_",
+            "MGWC_",
+        )
+
+        organism = "Homo sapiens"
+
+        for protein_sequence in test_sequences:
+            # subTest keeps checking the remaining sequences after one of them fails
+            with self.subTest(protein_sequence=protein_sequence):
+                # Generate sequence with match_protein=True
+                result = predict_dna_sequence(
+                    protein=protein_sequence,
+                    organism=organism,
+                    device=self.device,
+                    tokenizer=self.tokenizer,
+                    model=self.model,
+                    deterministic=False,
+                    temperature=20,  # High temperature to test protein matching
+                    match_protein=True,
+                )
+
+                dna_sequence = result.predicted_dna
+                translated_protein = get_amino_acid_sequence(dna_sequence)
+
+                self.assertEqual(
+                    translated_protein,
+                    protein_sequence,
+                    f"Translated protein must match original when match_protein=True. Failed for sequence: {protein_sequence}",
+                )
+
+    def test_predict_dna_sequence_match_protein_rare_amino_acids(self):
+        """Test match_protein with rare amino acids that have limited codon options."""
+        # Methionine (M) and Tryptophan (W) have only one codon each
+        # While Leucine (L) has 6 codons - testing contrast
+        protein_sequence = "MWLLLMWLLL"
+        organism = "Escherichia coli general"
+        valid_codons = {
+            "M": {"ATG"},
+            "W": {"TGG"},
+            "L": {"TTA", "TTG", "CTT", "CTC", "CTA", "CTG"},
+        }
+
+        # Run multiple predictions
+        results = []
+        num_iterations = 10
+
+        for _ in range(num_iterations):
+            result = predict_dna_sequence(
+                protein=protein_sequence,
+                organism=organism,
+                device=self.device,
+                tokenizer=self.tokenizer,
+                model=self.model,
+                deterministic=False,
+                temperature=20,  # High temperature to test protein matching
+                match_protein=True,
+            )
+            results.append(result.predicted_dna)
+
+        # Check every codon against the codons allowed for its amino acid; positions
+        # are derived from protein_sequence instead of hard-coded index lists
+        for dna_sequence in results:
+            for pos, amino_acid in enumerate(protein_sequence):
+                codon = dna_sequence[pos * 3 : (pos + 1) * 3]
+                with self.subTest(dna_sequence=dna_sequence, pos=pos):
+                    self.assertIn(
+                        codon,
+                        valid_codons[amino_acid],
+                        f"Invalid codon for {amino_acid} at position {pos}.",
+                    )
+
 
 if __name__ == "__main__":
     unittest.main()