diff --git a/src/main.py b/src/main.py index 6645141..fd2b06d 100644 --- a/src/main.py +++ b/src/main.py @@ -98,10 +98,6 @@ def load_data(data_path, tokenizer, max_sequence_length): return tensor_data -import torch.nn as nn -import torch - - class LongRoPEModel(nn.Module): """ Long Range Rotary Position Encoding (LongRoPE) model. @@ -292,54 +288,112 @@ def search_lambda_factors( max_iterations, ): """ - Search for optimal lambda factors using evolutionary search. + Search for optimal lambda factors using evolutionary search as described in the LongRoPE paper. + + This function implements the efficient search algorithm, including the monotonic constraint + and optimized initial population generation. Args: - model (nn.Module): LongRoPE model. - data (list): List of input sequences. - extension_ratio (float): Extension ratio for context window. - population_size (int): Size of the population for evolutionary search. - num_mutations (int): Number of mutations per iteration. - num_crossovers (int): Number of crossovers per iteration. - max_iterations (int): Maximum number of iterations for evolutionary search. + model: LongRoPE model to be extended. + data: List of input sequences for evaluation. + extension_ratio: Ratio of target length to current length. + population_size: Size of the population for evolutionary search. + num_mutations: Number of mutations per iteration. + num_crossovers: Number of crossovers per iteration. + max_iterations: Maximum number of iterations for evolutionary search. Returns: tuple: (Best lambda factors, best n_hat) """ - - population = initialize_population(population_size, extension_ratio, model.d_model) - - for i in range(max_iterations): + # Define search space as described in Section 3.2 of the paper + search_space = { + "lambda_i": ( + 1.0, + extension_ratio * 1.25, + 0.01, + ), # Min, max, and step size for lambda_i + "n_hat": [ + 0, + 1, + 2, + 4, + 8, + 12, + 16, + 20, + 24, + 28, + 32, + 64, + 128, + 256, + ], # Possible n_hat values + } + + # Initialize population with optimized method (including PI, NTK, and YaRN as individuals) + population = initialize_population(population_size, search_space, model.d_model) + + for _ in range(max_iterations): + # Evaluate the fitness (perplexity) of each individual in the population perplexities = evaluate_population(model, data, population) + # Select the top-performing individuals as parents parents = select_topk(population, perplexities, k=population_size // 2) - population = mutate(parents, num_mutations) + crossover(parents, num_crossovers) + # Create new population through mutation and crossover + population = mutate(parents, num_mutations, model.d_model) + crossover( + parents, num_crossovers, model.d_model + ) + # Apply monotonic constraint to ensure λi ≤ λi+1 + population = [ + apply_monotonic_constraint(individual) for individual in population + ] - best_lambda_factors, best_n_hat = min( - population, key=lambda x: evaluate_individual(model, data, x) - ) + # Select the best individual based on the lowest perplexity + best_individual = min(population, key=lambda x: evaluate_individual(model, data, x)) + return best_individual["lambda_i"], best_individual["n_hat"] - return best_lambda_factors, best_n_hat + +def apply_monotonic_constraint(individual): + """ + Apply the monotonic constraint to lambda factors as described in the paper. + + This ensures that λi ≤ λi+1, which is theoretically justified and improves performance. + + Args: + individual: Dictionary containing 'lambda_i' and 'n_hat' + + Returns: + individual: Dictionary with monotonically non-decreasing lambda factors + """ + lambda_i = individual["lambda_i"] + for i in range(1, len(lambda_i)): + lambda_i[i] = max(lambda_i[i], lambda_i[i - 1]) + return individual -def initialize_population(population_size, extension_ratio, d_model): +def initialize_population(population_size, search_space, d_model): """ Initialize the population for evolutionary search. + This function implements the optimized initial population generation described in Section 3.2, + including PI, NTK, and YaRN as initial individuals. + Args: - population_size (int): Size of the population. - extension_ratio (float): Extension ratio for context window. - d_model (int): Dimension of the model. + population_size: Number of individuals in the population + search_space: Dictionary defining the search space for lambda_i and n_hat + d_model: Dimension of the model Returns: - list: Initialized population. + population: List of individuals, each represented as a dictionary """ population = [] - for _ in range(population_size): - lambda_factors = torch.FloatTensor(d_model).uniform_(1.0, extension_ratio) - n_hat = random.randint(0, d_model) - population.append((lambda_factors, n_hat)) - + individual = { + "lambda_i": [ + random.uniform(*search_space["lambda_i"]) for _ in range(d_model // 2) + ], + "n_hat": random.choice(search_space["n_hat"]), + } + population.append(apply_monotonic_constraint(individual)) return population