Maintained the dimension-wise scaling (lambda_factors[i]).
Kept the position-based mask for initial tokens (n_hat).
This ensures that both forms of non-uniformity are explicitly considered.
jshuadvd committed Jun 30, 2024
1 parent 6a62df9 commit 24a5251
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions src/main.py
@@ -28,34 +28,36 @@ def forward(self, positions):

 def non_uniform_interpolation(pos_embed, extension_ratio, lambda_factors, n_hat):
     """
-    Perform non-uniform interpolation on position embeddings.
+    Perform non-uniform interpolation on position embeddings as described in the LongRoPE paper.
+    This function implements the two forms of non-uniformities:
+    1. Varying RoPE dimensions (lambda_factors)
+    2. Token positions (n_hat)
     Args:
-        pos_embed (torch.Tensor): Position embeddings.
-        extension_ratio (float): Extension ratio for context window.
-        lambda_factors (list): Lambda factors for interpolation.
-        n_hat (int): Threshold for applying interpolation.
+        pos_embed (torch.Tensor): Original position embeddings.
+        extension_ratio (float): Ratio of target length to original length.
+        lambda_factors (list): Lambda factors for each RoPE dimension.
+        n_hat (int): Number of initial tokens to keep without interpolation.
     Returns:
         torch.Tensor: Interpolated position embeddings.
     """

     if extension_ratio is None:
         raise ValueError("extension_ratio cannot be None")

     if lambda_factors is None:
         raise ValueError("lambda_factors cannot be None")

     d_model = pos_embed.shape[-1]
     interpolated_pos = pos_embed.clone()

     for i in range(d_model // 2):
         # Apply different scaling based on token position
         mask = torch.arange(pos_embed.shape[-2], device=pos_embed.device) < n_hat
         scale = torch.where(
             mask,
             torch.ones_like(pos_embed[..., 0], device=pos_embed.device),
-            1 / (lambda_factors[i] * extension_ratio),
+            1
+            / (
+                lambda_factors[i] * extension_ratio
+            ),  # Use dimension-specific lambda factors
         )
         # Apply scaling to both sine and cosine components
         interpolated_pos[..., 2 * i] *= scale
         interpolated_pos[..., 2 * i + 1] *= scale
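
The scaling rule in the loop above can be checked by hand on a toy input. The following is a minimal pure-Python sketch, not part of the commit: the helper names `scale_factor` and `scale_table` are hypothetical, and it only mirrors the multiplier that `torch.where` selects (1 for positions below `n_hat`, otherwise `1 / (lambda_factors[i] * extension_ratio)`). It does not touch real embeddings.

```python
def scale_factor(position, dim_pair, extension_ratio, lambda_factors, n_hat):
    """Multiplier for components (2*dim_pair, 2*dim_pair + 1) at `position`.

    Hypothetical helper that mirrors the mask/where logic in the diff.
    """
    if position < n_hat:
        # Initial tokens are kept as-is (position-based non-uniformity).
        return 1.0
    # Later tokens are scaled per RoPE dimension pair (dimension-wise non-uniformity).
    return 1.0 / (lambda_factors[dim_pair] * extension_ratio)


def scale_table(seq_len, extension_ratio, lambda_factors, n_hat):
    """Rows are token positions, columns are RoPE dimension pairs."""
    return [
        [
            scale_factor(pos, i, extension_ratio, lambda_factors, n_hat)
            for i in range(len(lambda_factors))
        ]
        for pos in range(seq_len)
    ]


# Toy values chosen only for illustration.
table = scale_table(seq_len=4, extension_ratio=2.0, lambda_factors=[1.0, 2.0], n_hat=2)
for row in table:
    print(row)
```

With these toy values, the rows for positions 0 and 1 stay at 1.0 in every column, while the rows for positions 2 and 3 differ per column, which is the combination of the two non-uniformities the commit message refers to.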
