Fix issue with top_p sampling.
Adibvafa committed Sep 21, 2024
1 parent df61ef9 commit dfa53da
Showing 1 changed file with 21 additions and 28 deletions.
49 changes: 21 additions & 28 deletions CodonTransformer/CodonPrediction.py
@@ -241,39 +241,32 @@ def sample_non_deterministic(
     """
     if not isinstance(temperature, (float, int)) or temperature <= 0:
         raise ValueError("Temperature must be a positive float.")
 
     if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0:
         raise ValueError("top_p must be a float between 0 and 1.")
 
-    # Apply temperature scaling and compute probabilities
-    logits = logits / temperature
-    probabilities = torch.softmax(logits, dim=-1)
+    # Compute probabilities using temperature scaling
+    logits /= temperature
+    probs = torch.softmax(logits, dim=-1)
+
     # Remove batch dimension if present
-    if probabilities.dim() == 3 and probabilities.size(0) == 1:
-        probabilities = probabilities.squeeze(0)  # Shape: [seq_len, vocab_size]
-
-    predicted_indices = []
-    for probs in probabilities:
-        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
-        cumulative_probs = torch.cumsum(sorted_probs, dim=0)
-
-        # Find the cutoff index where cumulative_probs exceeds top_p
-        cutoff_index = torch.where(cumulative_probs > top_p)[0]
-        if len(cutoff_index) > 0:
-            cutoff_index = cutoff_index[0].item()
-            # Keep only tokens up to the cutoff index
-            sorted_probs = sorted_probs[: cutoff_index + 1]
-            sorted_indices = sorted_indices[: cutoff_index + 1]
-
-        # Re-normalize the probabilities after filtering
-        filtered_probs = sorted_probs / sorted_probs.sum()
-
-        # Sample from the filtered distribution
-        sampled_index = torch.multinomial(filtered_probs, num_samples=1).item()
-        predicted_index = sorted_indices[sampled_index].item()
-        predicted_indices.append(predicted_index)
-
-    return predicted_indices
+    if probs.dim() == 3:
+        probs = probs.squeeze(0)  # Shape: [seq_len, vocab_size]
+
+    # Sort probabilities in descending order
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > top_p
+
+    # Zero out probabilities for tokens beyond the top-p threshold
+    probs_sort[mask] = 0.0
+
+    # Renormalize the probabilities
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs_sort, num_samples=1)
+    predicted_indices = torch.gather(probs_idx, -1, next_token).squeeze(-1)
+
+    return predicted_indices.tolist()
 
 
 def load_model(
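Aside (not part of the commit): both versions implement top-p (nucleus) filtering. The removed loop keeps sorted tokens up to the first position where the cumulative probability exceeds top_p; the new mask probs_sum - probs_sort > top_p drops exactly those tokens whose higher-ranked cumulative mass already exceeds top_p, which selects the same token set without a Python loop. The standalone sketch below checks this on a toy distribution; the helper names are illustrative only.

import torch

def top_p_filter_loop(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Per-row cutoff search, as in the removed code; returns the filtered,
    renormalized distribution instead of sampling from it."""
    out = torch.zeros_like(probs)
    for i, row in enumerate(probs):
        sorted_probs, sorted_idx = torch.sort(row, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=0)
        cutoff = torch.where(cumulative > top_p)[0]
        if len(cutoff) > 0:
            k = cutoff[0].item()
            sorted_probs = sorted_probs[: k + 1]
            sorted_idx = sorted_idx[: k + 1]
        out[i, sorted_idx] = sorted_probs / sorted_probs.sum()
    return out

def top_p_filter_vectorized(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Masking rule from the new code: drop token i when the cumulative
    mass of strictly higher-ranked tokens already exceeds top_p."""
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[probs_sum - probs_sort > top_p] = 0.0
    probs_sort /= probs_sort.sum(dim=-1, keepdim=True)
    # Scatter the filtered values back to their original vocabulary slots.
    return torch.zeros_like(probs).scatter_(-1, probs_idx, probs_sort)

probs = torch.softmax(
    torch.tensor([[2.0, 1.0, 0.5, -1.0], [0.1, 0.1, 3.0, 0.2]]), dim=-1
)
same = torch.allclose(
    top_p_filter_loop(probs, 0.9), top_p_filter_vectorized(probs, 0.9)
)
print(same)  # True: both keep the same nucleus and renormalize identically

Since both paths rank tokens with torch.sort in descending order, the kept nucleus coincides row by row; the vectorized form simply filters all rows at once.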

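A minimal usage sketch for the updated function, assuming the signature exposes the three names used in the body (logits, temperature, top_p); parameter order and defaults are not visible in this hunk. Note that the new code scales logits in place via logits /= temperature, so callers should not reuse the input tensor afterwards.

import torch

from CodonTransformer.CodonPrediction import sample_non_deterministic

# Toy logits: 5 sequence positions over a 64-token vocabulary; a leading
# batch dimension of size 1 is also accepted and squeezed internally.
logits = torch.randn(1, 5, 64)

token_ids = sample_non_deterministic(logits, temperature=0.8, top_p=0.95)
print(token_ids)  # a plain Python list, one sampled token id per position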