From 2b50df42b21de707abc08b53f67419d3adf27ced Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 17:38:41 -0400 Subject: [PATCH 01/11] Add more tests to check predict_dna_sequence. --- tests/test_CodonPrediction.py | 83 +++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index 28cb49e..16a6136 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -87,6 +87,89 @@ def test_predict_dna_sequence_invalid_inputs(self): model=self.model, ) + def test_predict_dna_sequence_top_p_effect(self): + """Test that changing top_p affects the diversity of outputs.""" + protein_sequence = "MFWY" + organism = "Escherichia coli general" + num_iterations = 50 + temperature = 0.5 + top_p_values = [0.8, 0.95] + outputs_by_top_p = {top_p: set() for top_p in top_p_values} + + for top_p in top_p_values: + for _ in range(num_iterations): + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + temperature=temperature, + top_p=top_p, + ) + outputs_by_top_p[top_p].add( + result.predicted_dna[:-3] + ) # Remove stop codon + + # Assert that higher top_p results in more diverse outputs + diversity_lower_top_p = len(outputs_by_top_p[0.8]) + diversity_higher_top_p = len(outputs_by_top_p[0.95]) + self.assertGreaterEqual( + diversity_higher_top_p, + diversity_lower_top_p, + "Higher top_p should result in more diverse outputs", + ) + + def test_predict_dna_sequence_invalid_temperature_and_top_p(self): + """Test that invalid temperature and top_p values raise ValueError.""" + protein_sequence = "MWWMW" + organism = "Escherichia coli general" + invalid_params = [ + {"temperature": -0.1, "top_p": 0.95}, + {"temperature": 0, "top_p": 0.95}, + {"temperature": 0.5, "top_p": -0.1}, + {"temperature": 0.5, "top_p": 1.1}, + ] + + for params in invalid_params: + with self.subTest(params=params): + with self.assertRaises(ValueError): + predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + temperature=params["temperature"], + top_p=params["top_p"], + ) + + def test_predict_dna_sequence_translation_consistency(self): + """Test that the predicted DNA translates back to the original protein.""" + from CodonTransformer.CodonData import get_amino_acid_sequence + + protein_sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVE" + organism = "Escherichia coli general" + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + + # Translate predicted DNA back to protein + translated_protein = get_amino_acid_sequence(result.predicted_dna[:-3]) + + self.assertEqual( + translated_protein, + protein_sequence, + "Translated protein does not match the original protein sequence", + ) + if __name__ == "__main__": unittest.main() From dfd11f69a438737fcf67586891c9c77dc4e7b245 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 17:38:56 -0400 Subject: [PATCH 02/11] Add support for top_p in non-deterministic generation. --- CodonTransformer/CodonPrediction.py | 98 +++++++++++++++++++++++++---- 1 file changed, 85 insertions(+), 13 deletions(-) diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index 8c5a4f6..52f4df3 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -40,6 +40,7 @@ def predict_dna_sequence( attention_type: str = "original_full", deterministic: bool = True, temperature: float = 0.2, + top_p: float = 0.95, ) -> DNASequencePrediction: """ Predict the DNA sequence for a given protein using the CodonTransformer model. @@ -76,6 +77,10 @@ def predict_dna_sequence( - Medium randomness: 0.5 - High randomness: 0.8 The temperature must be a positive float. Defaults to 0.2. + top_p (float, optional): The cumulative probability threshold for nucleus sampling. + Tokens with cumulative probability up to `top_p` are considered for sampling. + This parameter helps balance diversity and coherence in the predicted DNA sequences. + The value must be a float between 0 and 1. Defaults to 0.95. Returns: DNASequencePrediction: An object containing the prediction results: @@ -86,7 +91,7 @@ def predict_dna_sequence( Raises: ValueError: If the protein sequence is empty, if the organism is invalid, - or if the temperature is not a positive float. + if the temperature is not a positive float, or if `top_p` is not between 0 and 1. Note: This function uses `ORGANISM2ID` and `INDEX2TOKEN` dictionaries imported from @@ -123,7 +128,7 @@ def predict_dna_sequence( ... deterministic=True ... ) >>> - >>> # Predict DNA sequence with low randomness + >>> # Predict DNA sequence with low randomness and top_p sampling >>> output_random = predict_dna_sequence( ... protein=protein, ... organism=organism, @@ -132,7 +137,8 @@ def predict_dna_sequence( ... model=model, ... attention_type="original_full", ... deterministic=False, - ... temperature=0.2 + ... temperature=0.2, + ... top_p=0.95 ... ) >>> >>> print(format_model_output(output)) @@ -149,6 +155,10 @@ def predict_dna_sequence( if not isinstance(temperature, (float, int)) or temperature <= 0: raise ValueError("Temperature must be a positive float.") + # Validate top_p + if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0: + raise ValueError("top_p must be a float between 0 and 1.") + # Load tokenizer if not isinstance(tokenizer, PreTrainedTokenizerFast): tokenizer = load_tokenizer(tokenizer) @@ -181,18 +191,11 @@ def predict_dna_sequence( # Decode the predicted DNA sequence from the model output if deterministic: - # Select the most probable tokens (argmax) predicted_indices = logits.argmax(dim=-1).squeeze().tolist() else: - # Sample tokens according to their probability distribution - # Apply temperature scaling and convert logits to probabilities - logits = logits / temperature - probabilities = torch.softmax(logits, dim=-1) - - # Sample from the probability distribution at each position - probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size] - predicted_indices = ( - torch.multinomial(probabilities, num_samples=1).squeeze(-1).tolist() + # Use the standalone non-deterministic sampling function + predicted_indices = sample_non_deterministic( + logits=logits, temperature=temperature, top_p=top_p ) predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices)) @@ -210,6 +213,75 @@ def predict_dna_sequence( ) +def sample_non_deterministic( + logits: torch.Tensor, + temperature: float = 1.0, + top_p: float = 0.95, +) -> List[int]: + """ + Sample token indices from logits using temperature scaling and nucleus (top-p) sampling. + + This function applies temperature scaling to the logits, computes probabilities, + and then performs nucleus sampling to select token indices. It is used for + non-deterministic decoding in language models to introduce randomness while + maintaining coherence in the generated sequences. + + Args: + logits (torch.Tensor): The logits output from the model of shape + [seq_len, vocab_size] or [batch_size, seq_len, vocab_size]. + temperature (float, optional): Temperature value for scaling logits. + Must be a positive float. Defaults to 1.0. + top_p (float, optional): Cumulative probability threshold for nucleus sampling. + Must be a float between 0 and 1. Tokens with cumulative probability up to + `top_p` are considered for sampling. Defaults to 0.95. + + Returns: + List[int]: A list of sampled token indices corresponding to the predicted tokens. + + Raises: + ValueError: If `temperature` is not a positive float or if `top_p` is not between 0 and 1. + + Example: + >>> logits = model_output.logits # Assume logits is a tensor of shape [seq_len, vocab_size] + >>> predicted_indices = sample_non_deterministic(logits, temperature=0.7, top_p=0.9) + """ + if not isinstance(temperature, (float, int)) or temperature <= 0: + raise ValueError("Temperature must be a positive float.") + if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0: + raise ValueError("top_p must be a float between 0 and 1.") + + # Apply temperature scaling and compute probabilities + logits = logits / temperature + probabilities = torch.softmax(logits, dim=-1) + + # Remove batch dimension if present + if probabilities.dim() == 3 and probabilities.size(0) == 1: + probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size] + + predicted_indices = [] + for probs in probabilities: + sorted_probs, sorted_indices = torch.sort(probs, descending=True) + cumulative_probs = torch.cumsum(sorted_probs, dim=0) + + # Find the cutoff index where cumulative_probs exceeds top_p + cutoff_index = torch.where(cumulative_probs > top_p)[0] + if len(cutoff_index) > 0: + cutoff_index = cutoff_index[0].item() + # Keep only tokens up to the cutoff index + sorted_probs = sorted_probs[: cutoff_index + 1] + sorted_indices = sorted_indices[: cutoff_index + 1] + + # Re-normalize the probabilities after filtering + filtered_probs = sorted_probs / sorted_probs.sum() + + # Sample from the filtered distribution + sampled_index = torch.multinomial(filtered_probs, num_samples=1).item() + predicted_index = sorted_indices[sampled_index].item() + predicted_indices.append(predicted_index) + + return predicted_indices + + def load_model( model_path: Optional[str] = None, device: torch.device = None, From e6c3ed59adcc9a4683a8b93ac0afc976713479c8 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 17:40:37 -0400 Subject: [PATCH 03/11] Improve style. --- CodonTransformer/CodonPrediction.py | 1 - 1 file changed, 1 deletion(-) diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index 52f4df3..41288de 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -193,7 +193,6 @@ def predict_dna_sequence( if deterministic: predicted_indices = logits.argmax(dim=-1).squeeze().tolist() else: - # Use the standalone non-deterministic sampling function predicted_indices = sample_non_deterministic( logits=logits, temperature=temperature, top_p=top_p ) From bbfe03aa290307023cfb246d2fb893f383264097 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 18:03:19 -0400 Subject: [PATCH 04/11] Add extensive testing for predict_dna_sequence. --- tests/test_CodonPrediction.py | 266 +++++++++++++++++++++++++++++++++- 1 file changed, 264 insertions(+), 2 deletions(-) diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index 16a6136..e82b535 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -1,8 +1,16 @@ import unittest import warnings +import random import torch +from CodonTransformer.CodonData import get_amino_acid_sequence +from CodonTransformer.CodonUtils import ( + AMINO_ACIDS, + INDEX2TOKEN, + STOP_SYMBOLS, + ORGANISM2ID, +) from CodonTransformer.CodonPrediction import ( load_model, load_tokenizer, @@ -148,8 +156,6 @@ def test_predict_dna_sequence_invalid_temperature_and_top_p(self): def test_predict_dna_sequence_translation_consistency(self): """Test that the predicted DNA translates back to the original protein.""" - from CodonTransformer.CodonData import get_amino_acid_sequence - protein_sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVE" organism = "Escherichia coli general" result = predict_dna_sequence( @@ -170,6 +176,262 @@ def test_predict_dna_sequence_translation_consistency(self): "Translated protein does not match the original protein sequence", ) + def test_predict_dna_sequence_long_protein_sequence(self): + """Test the function with a very long protein sequence to check performance and correctness.""" + protein_sequence = ( + "M" + + "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" + * 20 + + STOP_SYMBOLS[0] + ) + organism = "Escherichia coli general" + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + + # Check that the predicted DNA translates back to the original protein + dna_sequence = result.predicted_dna[:-3] + translated_protein = get_amino_acid_sequence(dna_sequence) + self.assertEqual( + translated_protein, + protein_sequence[:-1], + "Translated protein does not match the original long protein sequence", + ) + + def test_predict_dna_sequence_edge_case_organisms(self): + """Test the function with organism IDs at the boundaries of the mapping.""" + protein_sequence = "MWWMW" + # Assuming ORGANISM2ID has IDs starting from 0 to N + min_organism_id = min(ORGANISM2ID.values()) + max_organism_id = max(ORGANISM2ID.values()) + organisms = [min_organism_id, max_organism_id] + + for organism_id in organisms: + with self.subTest(organism_id=organism_id): + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism_id, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + self.assertIsInstance(result.predicted_dna, str) + self.assertTrue( + all(nucleotide in "ATCG" for nucleotide in result.predicted_dna) + ) + + def test_predict_dna_sequence_concurrent_calls(self): + """Test the function's behavior under concurrent execution.""" + import threading + + protein_sequence = "MWWMW" + organism = "Escherichia coli general" + results = [] + + def call_predict(): + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + results.append(result.predicted_dna) + + threads = [threading.Thread(target=call_predict) for _ in range(10)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + self.assertEqual(len(results), 10) + self.assertTrue(all(dna == results[0] for dna in results)) + + def test_predict_dna_sequence_random_seed_consistency(self): + """Test that setting a random seed results in consistent outputs in non-deterministic mode.""" + protein_sequence = "MFWY" + organism = "Escherichia coli general" + temperature = 0.5 + top_p = 0.95 + torch.manual_seed(42) + + result1 = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + temperature=temperature, + top_p=top_p, + ) + + torch.manual_seed(42) + + result2 = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + temperature=temperature, + top_p=top_p, + ) + + self.assertEqual( + result1.predicted_dna, + result2.predicted_dna, + "Outputs should be consistent when random seed is set", + ) + + def test_predict_dna_sequence_invalid_tokenizer_and_model(self): + """Test that providing invalid tokenizer or model raises appropriate exceptions.""" + protein_sequence = "MWWMW" + organism = "Escherichia coli general" + + with self.subTest("Invalid tokenizer"): + with self.assertRaises(Exception): + predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer="invalid_tokenizer_path", + model=self.model, + ) + + with self.subTest("Invalid model"): + with self.assertRaises(Exception): + predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model="invalid_model_path", + ) + + def test_predict_dna_sequence_stop_codon_handling(self): + """Test the function's handling of protein sequences ending with a non '_' or '*' stop symbol.""" + protein_sequence = "MWW/" + organism = "Escherichia coli general" + + with self.assertRaises(ValueError): + predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + ) + + def test_predict_dna_sequence_ambiguous_amino_acids(self): + """Test the function's response to ambiguous or non-standard amino acids.""" + protein_sequence = "MWWBXZ" + organism = "Escherichia coli general" + + with self.assertRaises(ValueError): + predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + ) + + def test_predict_dna_sequence_device_compatibility(self): + """Test that the function works correctly on both CPU and GPU devices.""" + protein_sequence = "MWWMW" + organism = "Escherichia coli general" + + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + + for device in devices: + with self.subTest(device=device): + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + self.assertIsInstance(result.predicted_dna, str) + self.assertTrue( + all(nucleotide in "ATCG" for nucleotide in result.predicted_dna) + ) + + def test_predict_dna_sequence_random_proteins(self): + """Test random proteins to ensure translated DNA matches the original protein.""" + organism = "Escherichia coli general" + num_tests = 200 + + for _ in range(num_tests): + # Generate a random protein sequence of random length between 10 and 50 + protein_length = random.randint(10, 500) + protein_sequence = "M" + "".join( + random.choices(AMINO_ACIDS, k=protein_length - 1) + ) + protein_sequence += random.choice(STOP_SYMBOLS) + + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + + # Remove stop codon from predicted DNA + dna_sequence = result.predicted_dna[:-3] + + # Translate predicted DNA back to protein + translated_protein = get_amino_acid_sequence(dna_sequence) + self.assertEqual( + translated_protein, + protein_sequence[:-1], # Remove stop symbol + f"Translated protein does not match the original protein sequence for protein: {protein_sequence}", + ) + + def test_predict_dna_sequence_long_protein_over_max_length(self): + """Test that the model handles protein sequences longer than 2048 amino acids.""" + # Create a protein sequence longer than 2048 amino acids + base_sequence = ( + "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" + ) + protein_sequence = base_sequence * 100 # Length > 2048 amino acids + organism = "Escherichia coli general" + + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + + # Remove stop codon from predicted DNA + dna_sequence = result.predicted_dna[:-3] + translated_protein = get_amino_acid_sequence(dna_sequence) + + # Due to potential model limitations, compare up to the model's max supported length + max_length = len(translated_protein) + self.assertEqual( + translated_protein[:max_length], + protein_sequence[:max_length], + "Translated protein does not match the original protein sequence up to the maximum length supported.", + ) + if __name__ == "__main__": unittest.main() From 9334cdc3f9e6e7de0a86117059dc9d8cafcd1f84 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 18:03:35 -0400 Subject: [PATCH 05/11] Add a list of possible stop symbols. --- CodonTransformer/CodonPrediction.py | 3 ++- CodonTransformer/CodonUtils.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index 41288de..6d3e1b1 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -23,6 +23,7 @@ from CodonTransformer.CodonData import get_merged_seq from CodonTransformer.CodonUtils import ( AMINO_ACIDS, + STOP_SYMBOLS, INDEX2TOKEN, NUM_ORGANISMS, ORGANISM2ID, @@ -148,7 +149,7 @@ def predict_dna_sequence( raise ValueError("Protein sequence cannot be empty.") # Ensure the protein sequence contains only valid amino acids - if not all(aminoacid in AMINO_ACIDS for aminoacid in protein): + if not all(aminoacid in AMINO_ACIDS + STOP_SYMBOLS for aminoacid in protein): raise ValueError("Invalid amino acid found in protein sequence.") # Validate temperature diff --git a/CodonTransformer/CodonUtils.py b/CodonTransformer/CodonUtils.py index dd91e3f..84d8bea 100644 --- a/CodonTransformer/CodonUtils.py +++ b/CodonTransformer/CodonUtils.py @@ -38,6 +38,7 @@ "W", # Tryptophan "Y", # Tyrosine ] +STOP_SYMBOLS = ["_", "*"] # Stop codon symbols # Dictionary ambiguous amino acids to standard amino acids AMBIGUOUS_AMINOACID_MAP: Dict[str, str] = { From 00b9e3f76d06b4d3da2891991f255c5b33c535a5 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 18:06:14 -0400 Subject: [PATCH 06/11] Add docstrings for sample_non_deterministic and STOP_SYMBOLS. --- CodonTransformer/CodonPrediction.py | 2 +- README.md | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index 6d3e1b1..af54a50 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -215,7 +215,7 @@ def predict_dna_sequence( def sample_non_deterministic( logits: torch.Tensor, - temperature: float = 1.0, + temperature: float = 0.2, top_p: float = 0.95, ) -> List[int]: """ diff --git a/README.md b/README.md index 8187f77..2baaa62 100644 --- a/README.md +++ b/README.md @@ -275,6 +275,10 @@ This subpackage contains functions and classes that handle the core prediction f Predict the DNA sequence for a given protein using the CodonTransformer model. +- `sample_non_deterministic(logits: torch.Tensor, temperature: float = 0.2, top_p: float = 0.95) -> List[int]` + + Sample token indices from logits using temperature scaling and nucleus (top-p) sampling. + - `load_model(path: str, device: torch.device = None, num_organisms: int = None, remove_prefix: bool = True, attention_type: str = "original_full") -> torch.nn.Module` Load a BigBirdForMaskedLM model from a file or checkpoint. @@ -383,6 +387,7 @@ The CodonUtils subpackage contains constants and helper functions essential for #### Constants - `AMINO_ACIDS`: List of all standard amino acids +- `STOP_SYMBOLS`: List of possible stop symbols to end the protein with - `AMBIGUOUS_AMINOACID_MAP`: Mapping of ambiguous amino acids to standard amino acids - `START_CODONS` and `STOP_CODONS`: Lists of start and stop codons - `TOKEN2INDEX` and `INDEX2TOKEN`: Mappings between tokens and their indices From 1fe2f3d33679a04c80a479a347390e4bc5fb0240 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 18:10:05 -0400 Subject: [PATCH 07/11] Remove checking for protein sequence validity and bring it to preprocessing function. --- CodonTransformer/CodonData.py | 11 ++++++----- CodonTransformer/CodonPrediction.py | 6 ------ 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/CodonTransformer/CodonData.py b/CodonTransformer/CodonData.py index b6f6d86..bce9015 100644 --- a/CodonTransformer/CodonData.py +++ b/CodonTransformer/CodonData.py @@ -17,6 +17,7 @@ from tqdm import tqdm from CodonTransformer.CodonUtils import ( + STOP_SYMBOLS, AMBIGUOUS_AMINOACID_MAP, AMINO2CODON_TYPE, AMINO_ACIDS, @@ -177,13 +178,13 @@ def preprocess_protein_sequence(protein: str) -> str: ) # Check for sequence validity - if any( - aminoacid not in AMINO_ACIDS + ["*", STOP_SYMBOL] for aminoacid in protein[:-1] - ): + if any(aminoacid not in AMINO_ACIDS + STOP_SYMBOLS for aminoacid in protein): raise ValueError("Invalid characters in protein sequence.") - if protein[-1] not in AMINO_ACIDS + ["*", STOP_SYMBOL]: - raise ValueError("Protein sequence must end with *, or _, or an amino acid.") + if protein[-1] not in AMINO_ACIDS + STOP_SYMBOLS: + raise ValueError( + "Protein sequence must end with `*`, or `_`, or an amino acid." + ) # Replace '*' at the end of protein with STOP_SYMBOL if present if protein[-1] == "*": diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index af54a50..f147105 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -22,8 +22,6 @@ from CodonTransformer.CodonData import get_merged_seq from CodonTransformer.CodonUtils import ( - AMINO_ACIDS, - STOP_SYMBOLS, INDEX2TOKEN, NUM_ORGANISMS, ORGANISM2ID, @@ -148,10 +146,6 @@ def predict_dna_sequence( if not protein: raise ValueError("Protein sequence cannot be empty.") - # Ensure the protein sequence contains only valid amino acids - if not all(aminoacid in AMINO_ACIDS + STOP_SYMBOLS for aminoacid in protein): - raise ValueError("Invalid amino acid found in protein sequence.") - # Validate temperature if not isinstance(temperature, (float, int)) or temperature <= 0: raise ValueError("Temperature must be a positive float.") From a7cb8466edf4908f1e60eda995cfcc7c344bae95 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 18:12:33 -0400 Subject: [PATCH 08/11] Remove test_predict_dna_sequence_ambiguous_amino_acids test. --- tests/test_CodonPrediction.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index e82b535..75f1456 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -7,7 +7,6 @@ from CodonTransformer.CodonData import get_amino_acid_sequence from CodonTransformer.CodonUtils import ( AMINO_ACIDS, - INDEX2TOKEN, STOP_SYMBOLS, ORGANISM2ID, ) @@ -79,7 +78,7 @@ def test_predict_dna_sequence_non_deterministic(self): def test_predict_dna_sequence_invalid_inputs(self): test_cases = [ - ("MKTZZFVLLL", "Escherichia coli general", "invalid protein sequence"), + ("MKTZZFVLLL?", "Escherichia coli general", "invalid protein sequence"), ("MKTFFVLLL", "Alien $%#@!", "invalid organism code"), ("", "Escherichia coli general", "empty protein sequence"), ] @@ -331,20 +330,6 @@ def test_predict_dna_sequence_stop_codon_handling(self): model=self.model, ) - def test_predict_dna_sequence_ambiguous_amino_acids(self): - """Test the function's response to ambiguous or non-standard amino acids.""" - protein_sequence = "MWWBXZ" - organism = "Escherichia coli general" - - with self.assertRaises(ValueError): - predict_dna_sequence( - protein=protein_sequence, - organism=organism, - device=self.device, - tokenizer=self.tokenizer, - model=self.model, - ) - def test_predict_dna_sequence_device_compatibility(self): """Test that the function works correctly on both CPU and GPU devices.""" protein_sequence = "MWWMW" From 318625099669b19858526c52e66590ba93ba6157 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 18:20:17 -0400 Subject: [PATCH 09/11] Improve style. --- CodonTransformer/CodonData.py | 2 +- tests/test_CodonPrediction.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/CodonTransformer/CodonData.py b/CodonTransformer/CodonData.py index bce9015..9eea9e9 100644 --- a/CodonTransformer/CodonData.py +++ b/CodonTransformer/CodonData.py @@ -17,7 +17,6 @@ from tqdm import tqdm from CodonTransformer.CodonUtils import ( - STOP_SYMBOLS, AMBIGUOUS_AMINOACID_MAP, AMINO2CODON_TYPE, AMINO_ACIDS, @@ -25,6 +24,7 @@ START_CODONS, STOP_CODONS, STOP_SYMBOL, + STOP_SYMBOLS, find_pattern_in_fasta, get_taxonomy_id, sort_amino2codon_skeleton, diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index 75f1456..1fb0d43 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -1,20 +1,20 @@ +import random import unittest import warnings -import random import torch from CodonTransformer.CodonData import get_amino_acid_sequence -from CodonTransformer.CodonUtils import ( - AMINO_ACIDS, - STOP_SYMBOLS, - ORGANISM2ID, -) from CodonTransformer.CodonPrediction import ( load_model, load_tokenizer, predict_dna_sequence, ) +from CodonTransformer.CodonUtils import ( + AMINO_ACIDS, + ORGANISM2ID, + STOP_SYMBOLS, +) class TestCodonPrediction(unittest.TestCase): From df61ef9f1f481cf5448a137d5a326de0e6544fcc Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 20:40:44 -0400 Subject: [PATCH 10/11] Fix issue with top_p sampling. --- tests/test_CodonPrediction.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index 1fb0d43..079cbe0 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -60,7 +60,6 @@ def test_predict_dna_sequence_non_deterministic(self): "ATGTTTTGGTAC", "ATGTTCTGGTAC", } - for _ in range(num_iterations): for temperature in temperatures: result = predict_dna_sequence( From dfa53daf2925493eea33166d98af203dbcb4d8e0 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 20:41:08 -0400 Subject: [PATCH 11/11] Fix issue with top_p sampling. --- CodonTransformer/CodonPrediction.py | 49 +++++++++++++---------------- 1 file changed, 21 insertions(+), 28 deletions(-) diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index f147105..adb85ed 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -241,39 +241,32 @@ def sample_non_deterministic( """ if not isinstance(temperature, (float, int)) or temperature <= 0: raise ValueError("Temperature must be a positive float.") + if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0: raise ValueError("top_p must be a float between 0 and 1.") - # Apply temperature scaling and compute probabilities - logits = logits / temperature - probabilities = torch.softmax(logits, dim=-1) + # Compute probabilities using temperature scaling + logits /= temperature + probs = torch.softmax(logits, dim=-1) # Remove batch dimension if present - if probabilities.dim() == 3 and probabilities.size(0) == 1: - probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size] - - predicted_indices = [] - for probs in probabilities: - sorted_probs, sorted_indices = torch.sort(probs, descending=True) - cumulative_probs = torch.cumsum(sorted_probs, dim=0) - - # Find the cutoff index where cumulative_probs exceeds top_p - cutoff_index = torch.where(cumulative_probs > top_p)[0] - if len(cutoff_index) > 0: - cutoff_index = cutoff_index[0].item() - # Keep only tokens up to the cutoff index - sorted_probs = sorted_probs[: cutoff_index + 1] - sorted_indices = sorted_indices[: cutoff_index + 1] - - # Re-normalize the probabilities after filtering - filtered_probs = sorted_probs / sorted_probs.sum() - - # Sample from the filtered distribution - sampled_index = torch.multinomial(filtered_probs, num_samples=1).item() - predicted_index = sorted_indices[sampled_index].item() - predicted_indices.append(predicted_index) - - return predicted_indices + if probs.dim() == 3: + probs = probs.squeeze(0) # Shape: [seq_len, vocab_size] + + # Sort probabilities in descending order + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > top_p + + # Zero out probabilities for tokens beyond the top-p threshold + probs_sort[mask] = 0.0 + + # Renormalize the probabilities + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + predicted_indices = torch.gather(probs_idx, -1, next_token).squeeze(-1) + + return predicted_indices.tolist() def load_model(