Skip to content

Commit

Permalink
Merge pull request #13 from Adibvafa/enhance_prediction
Browse files Browse the repository at this point in the history
Add support for multiple sequence generation
  • Loading branch information
Adibvafa authored Sep 21, 2024
2 parents c646420 + 34c9c29 commit dedc17e
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 34 deletions.
87 changes: 53 additions & 34 deletions CodonTransformer/CodonPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,20 @@ def predict_dna_sequence(
deterministic: bool = True,
temperature: float = 0.2,
top_p: float = 0.95,
) -> DNASequencePrediction:
num_sequences: int = 1,
) -> Union[DNASequencePrediction, List[DNASequencePrediction]]:
"""
Predict the DNA sequence for a given protein using the CodonTransformer model.
Predict the DNA sequence(s) for a given protein using the CodonTransformer model.
This function takes a protein sequence and an organism (as ID or name) as input
and returns the predicted DNA sequence using the CodonTransformer model. It can use
and returns the predicted DNA sequence(s) using the CodonTransformer model. It can use
either provided tokenizer and model objects or load them from specified paths.
Args:
protein (str): The input protein sequence for which to predict the DNA sequence.
organism (Union[int, str]): Either the ID of the organism or its name (e.g.,
"Escherichia coli general"). If a string is provided, it will be converted
to the corresponding ID using `ORGANISM2ID`.
to the corresponding ID using ORGANISM2ID.
device (torch.device): The device (CPU or GPU) to run the model on.
tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file
path to load the tokenizer from, a pre-loaded tokenizer object, or None. If
Expand All @@ -77,25 +78,29 @@ def predict_dna_sequence(
- High randomness: 0.8
The temperature must be a positive float. Defaults to 0.2.
top_p (float, optional): The cumulative probability threshold for nucleus sampling.
Tokens with cumulative probability up to `top_p` are considered for sampling.
Tokens with cumulative probability up to top_p are considered for sampling.
This parameter helps balance diversity and coherence in the predicted DNA sequences.
The value must be a float between 0 and 1. Defaults to 0.95.
num_sequences (int, optional): The number of DNA sequences to generate. Must be a
positive integer; values greater than 1 require deterministic to be False.
Defaults to 1.
Returns:
DNASequencePrediction: An object containing the prediction results:
Union[DNASequencePrediction, List[DNASequencePrediction]]: A single object when
num_sequences is 1, otherwise a list of num_sequences objects, each
containing the prediction results:
- organism (str): Name of the organism used for prediction.
- protein (str): Input protein sequence for which DNA sequence is predicted.
- processed_input (str): Processed input sequence (merged protein and DNA).
- predicted_dna (str): Predicted DNA sequence.
Raises:
ValueError: If the protein sequence is empty, if the organism is invalid,
if the temperature is not a positive float, or if `top_p` is not between 0 and 1.
if the temperature is not a positive float, if top_p is not between 0 and 1,
if num_sequences is not a positive integer, or if num_sequences is greater
than 1 while deterministic is True.
Note:
This function uses `ORGANISM2ID` and `INDEX2TOKEN` dictionaries imported from
`CodonTransformer.CodonUtils`. `ORGANISM2ID` maps organism names to their
corresponding IDs. `INDEX2TOKEN` maps model output indices (token IDs) to
This function uses ORGANISM2ID and INDEX2TOKEN dictionaries imported from
CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their
corresponding IDs. INDEX2TOKEN maps model output indices (token IDs) to
respective codons.
Example:
Expand All @@ -116,7 +121,7 @@ def predict_dna_sequence(
>>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"
>>> organism = "Escherichia coli general"
>>>
>>> # Predict DNA sequence with deterministic decoding
>>> # Predict DNA sequence with deterministic decoding (single sequence)
>>> output = predict_dna_sequence(
... protein=protein,
... organism=organism,
Expand All @@ -127,7 +132,7 @@ def predict_dna_sequence(
... deterministic=True
... )
>>>
>>> # Predict DNA sequence with low randomness and top_p sampling
>>> # Predict multiple DNA sequences with low randomness and top_p sampling
>>> output_random = predict_dna_sequence(
... protein=protein,
... organism=organism,
Expand All @@ -137,23 +142,33 @@ def predict_dna_sequence(
... attention_type="original_full",
... deterministic=False,
... temperature=0.2,
... top_p=0.95
... top_p=0.95,
... num_sequences=3
... )
>>>
>>> print(format_model_output(output))
>>> print(format_model_output(output_random))
>>> for i, seq in enumerate(output_random, 1):
... print(f"Sequence {i}:")
... print(format_model_output(seq))
... print()
"""
if not protein:
raise ValueError("Protein sequence cannot be empty.")

# Validate temperature
if not isinstance(temperature, (float, int)) or temperature <= 0:
raise ValueError("Temperature must be a positive float.")

# Validate top_p
if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0:
raise ValueError("top_p must be a float between 0 and 1.")

if not isinstance(num_sequences, int) or num_sequences < 1:
raise ValueError("num_sequences must be a positive integer.")

if deterministic and num_sequences > 1:
raise ValueError(
"Multiple sequences can only be generated in non-deterministic mode."
)

# Load tokenizer
if not isinstance(tokenizer, PreTrainedTokenizerFast):
tokenizer = load_tokenizer(tokenizer)
Expand Down Expand Up @@ -184,27 +199,31 @@ def predict_dna_sequence(
output_dict = model(**tokenized_input, return_dict=True)
logits = output_dict.logits.detach().cpu()

# Decode the predicted DNA sequence from the model output
if deterministic:
predicted_indices = logits.argmax(dim=-1).squeeze().tolist()
else:
predicted_indices = sample_non_deterministic(
logits=logits, temperature=temperature, top_p=top_p
predictions = []
for _ in range(num_sequences):
# Decode the predicted DNA sequence from the model output
if deterministic:
predicted_indices = logits.argmax(dim=-1).squeeze().tolist()
else:
predicted_indices = sample_non_deterministic(
logits=logits, temperature=temperature, top_p=top_p
)

predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices))
predicted_dna = (
"".join([token[-3:] for token in predicted_dna[1:-1]]).strip().upper()
)

predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices))

# Skip special tokens [CLS] and [SEP] to create the predicted_dna
predicted_dna = (
"".join([token[-3:] for token in predicted_dna[1:-1]]).strip().upper()
)
predictions.append(
DNASequencePrediction(
organism=organism_name,
protein=protein,
processed_input=merged_seq,
predicted_dna=predicted_dna,
)
)

return DNASequencePrediction(
organism=organism_name,
protein=protein,
processed_input=merged_seq,
predicted_dna=predicted_dna,
)
return predictions[0] if num_sequences == 1 else predictions


def sample_non_deterministic(
Expand Down
76 changes: 76 additions & 0 deletions tests/test_CodonPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
AMINO_ACIDS,
ORGANISM2ID,
STOP_SYMBOLS,
DNASequencePrediction,
)


Expand Down Expand Up @@ -416,6 +417,81 @@ def test_predict_dna_sequence_long_protein_over_max_length(self):
"Translated protein does not match the original protein sequence up to the maximum length supported.",
)

def test_predict_dna_sequence_multi_output(self):
    """Sampling with num_sequences > 1 returns a list of valid predictions.

    Requests 20 non-deterministic sequences and checks that the result is a
    list of exactly that length, that every element is a
    DNASequencePrediction made only of A/T/C/G, and that each predicted DNA
    translates back to the input protein.
    """
    expected_protein = "MFQLLAPWY"
    requested = 20
    valid_bases = set("ATCG")

    predictions = predict_dna_sequence(
        protein=expected_protein,
        organism="Escherichia coli general",
        device=self.device,
        tokenizer=self.tokenizer,
        model=self.model,
        deterministic=False,
        num_sequences=requested,
    )

    self.assertIsInstance(predictions, list)
    self.assertEqual(len(predictions), requested)

    for item in predictions:
        self.assertIsInstance(item, DNASequencePrediction)

        dna = item.predicted_dna
        # Every character of the prediction must be one of the four bases.
        self.assertTrue(set(dna) <= valid_bases)

        # The last three nucleotides are dropped before translating
        # (presumably the stop codon -- confirm against the model output).
        round_trip = get_amino_acid_sequence(dna[:-3])
        self.assertEqual(round_trip, expected_protein)

def test_predict_dna_sequence_deterministic_multi_raises_error(self):
    """Asking for several sequences with deterministic=True must raise ValueError."""
    call_kwargs = {
        "protein": "MFWY",
        "organism": "Escherichia coli general",
        "device": self.device,
        "tokenizer": self.tokenizer,
        "model": self.model,
        "deterministic": True,
        "num_sequences": 3,
    }

    # Callable form of assertRaises: the function is invoked with the given
    # keyword arguments and the test fails unless ValueError is raised.
    self.assertRaises(ValueError, predict_dna_sequence, **call_kwargs)

def test_predict_dna_sequence_multi_diversity(self):
    """Sampling 10 sequences at temperature 0.8 yields more than 2 distinct DNAs.

    Also verifies that every sampled DNA still translates back to the input
    protein.
    """
    expected_protein = "MFWYMFWY"

    sampled = predict_dna_sequence(
        protein=expected_protein,
        organism="Escherichia coli general",
        device=self.device,
        tokenizer=self.tokenizer,
        model=self.model,
        deterministic=False,
        num_sequences=10,
        temperature=0.8,
    )

    # Collapse duplicates so only distinct DNA strings are counted.
    distinct_dna = {item.predicted_dna for item in sampled}
    self.assertGreater(
        len(distinct_dna),
        2,
        "Multiple sequence generation should produce diverse results",
    )

    # Diversity must not come at the cost of correctness: each sequence, with
    # its last three nucleotides dropped, has to encode the original protein.
    for item in sampled:
        round_trip = get_amino_acid_sequence(item.predicted_dna[:-3])
        self.assertEqual(round_trip, expected_protein)


if __name__ == "__main__":
unittest.main()

0 comments on commit dedc17e

Please sign in to comment.