Commit

Implemented a two-stage approach: first extending to 256k, then to 2048k.

Added specific fine-tuning steps: 400 steps for 128k, 600 steps for 256k.
Reduced search parameters for the 2048k extension.
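
For reference, a quick arithmetic sketch of the per-stage rescaling ratios this schedule passes to the search (the 4k base context below is an illustrative assumption, not a value taken from this commit):

    # Illustrative only: a 4k base context is assumed here, not read from the repo.
    base_length = 4_000
    ratio_128k = 128_000 / base_length    # 32.0 -> first search, then 400 fine-tuning steps
    ratio_256k = 256_000 / 128_000        # 2.0  -> second search, then 600 fine-tuning steps
    ratio_2048k = 2_048_000 / 256_000     # 8.0  -> final search, no further fine-tuning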
jshuadvd committed Jun 30, 2024
1 parent 6fbcfb7 commit 6a62df9
Showing 1 changed file with 48 additions and 23 deletions.
src/main.py: 71 changes (48 additions & 23 deletions)
@@ -513,48 +513,73 @@ def progressive_extension(
     max_iterations,
 ):
     """
-    Progressively extend the context window of the model.
+    Progressively extend the context window of the model using the two-stage approach described in the LongRoPE paper.
 
     Args:
-        model (nn.Module): LongRoPE model.
-        data (list): List of input sequences.
-        base_length (int): Base context window length.
-        target_length (int): Target context window length.
+        model (nn.Module): LongRoPE model to be extended.
+        data (list): List of input sequences for fine-tuning and evaluation.
+        base_length (int): Original context window length of the model.
+        target_length (int): Target context window length (2048k in the paper).
         population_size (int): Size of the population for evolutionary search.
-        num_mutations (int): Number of mutations per iteration.
-        num_crossovers (int): Number of crossovers per iteration.
+        num_mutations (int): Number of mutations per iteration in the search.
+        num_crossovers (int): Number of crossovers per iteration in the search.
         max_iterations (int): Maximum number of iterations for evolutionary search.
 
     Returns:
-        tuple: (Extended model, lambda factors, n_hat, base lambda factors, base n_hat)
+        tuple: (Extended model, 2048k lambda factors, 2048k n_hat, 256k lambda factors, 256k n_hat)
     """
+    # Stage 1: Extend to 256k
     curr_model = model
     curr_length = base_length
 
-    while curr_length < target_length:
-        lambda_factors, n_hat = search_lambda_factors(
-            curr_model,
-            data,
-            curr_length / base_length,
-            population_size,
-            num_mutations,
-            num_crossovers,
-            max_iterations,
-        )
-        curr_model = fine_tune(curr_model, data, curr_length, lambda_factors, n_hat)
-        curr_length *= 2
+    # First extend to 128k
+    lambda_factors_128k, n_hat_128k = search_lambda_factors(
+        curr_model,
+        data,
+        128000 / base_length,
+        population_size,
+        num_mutations,
+        num_crossovers,
+        max_iterations,
+    )
+    # Fine-tune for 400 steps as specified in the paper
+    curr_model = fine_tune(
+        curr_model, data, 128000, lambda_factors_128k, n_hat_128k, steps=400
+    )
 
-    lambda_factors_base, n_hat_base = search_lambda_factors(
+    # Then extend to 256k
+    lambda_factors_256k, n_hat_256k = search_lambda_factors(
         curr_model,
         data,
-        curr_length / base_length,
+        256000 / 128000,
         population_size,
         num_mutations,
         num_crossovers,
         max_iterations,
     )
+    # Fine-tune for 600 steps as specified in the paper
+    curr_model = fine_tune(
+        curr_model, data, 256000, lambda_factors_256k, n_hat_256k, steps=600
+    )
 
-    return curr_model, lambda_factors, n_hat, lambda_factors_base, n_hat_base
+    # Stage 2: Extend to 2048k without further fine-tuning
+    lambda_factors_2048k, n_hat_2048k = search_lambda_factors(
+        curr_model,
+        data,
+        2048000 / 256000,
+        population_size // 2,  # Reduce population size for efficiency in 2048k search
+        num_mutations // 2,
+        num_crossovers // 2,
+        max_iterations // 2,
+    )
+
+    return (
+        curr_model,
+        lambda_factors_2048k,
+        n_hat_2048k,
+        lambda_factors_256k,
+        n_hat_256k,
+    )
 
 
 def short_context_recovery(model, data, base_length, lambda_factors_base, n_hat_base):
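
A minimal calling sketch for the updated function. Only the progressive_extension signature and return shape come from this diff; the model constructor, data helper, and hyperparameter values are hypothetical placeholders:

    # Hypothetical caller: LongRoPEModel and load_sequences are assumed names,
    # and every hyperparameter value here is illustrative, not from this commit.
    model = LongRoPEModel()
    data = load_sequences("train_corpus.jsonl")

    extended, lam_2048k, n_hat_2048k, lam_256k, n_hat_256k = progressive_extension(
        model,
        data,
        base_length=4000,
        target_length=2048000,
        population_size=64,
        num_mutations=16,
        num_crossovers=16,
        max_iterations=10,
    )
    # The 256k factors are returned alongside the 2048k ones so that
    # short_context_recovery can restore performance at the original length.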
