Skip to content

Commit

Permalink
Update the crossover method to use dictionaries instead of tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jul 4, 2024
1 parent 169b697 commit f510479
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,35 +561,33 @@ def mutate(parents, num_mutations, d_model):
return mutated_population


def crossover(parents, num_crossovers, d_model):
def mutate(parents, num_mutations, d_model):
"""
Perform crossover on the parent population.
Perform mutation on the parent population.
Args:
parents (list): Parent population.
num_crossovers (int): Number of crossovers to perform.
num_mutations (int): Number of mutations to perform.
d_model (int): Dimension of the model.
Returns:
list: Crossover population.
list: Mutated population.
"""
crossover_population = []
for _ in range(num_crossovers):
parent1_lambda, parent1_n_hat = random.choice(parents)
parent2_lambda, parent2_n_hat = random.choice(parents)
child_lambda = parent1_lambda.clone()
child_n_hat = parent1_n_hat
mutated_population = []
for _ in range(num_mutations):
parent = random.choice(parents)
child = {"lambda_i": parent["lambda_i"].clone(), "n_hat": parent["n_hat"]}

for i in range(d_model):
if random.random() < 0.5:
child_lambda[i] = parent2_lambda[i]
if random.random() < 0.1:
child["lambda_i"][i] *= random.uniform(0.8, 1.2)

if random.random() < 0.5:
child_n_hat = parent2_n_hat
if random.random() < 0.1:
child["n_hat"] = random.randint(0, d_model)

crossover_population.append((child_lambda, child_n_hat))
mutated_population.append(child)

return crossover_population
return mutated_population


def fine_tune(model, train_data, val_data, target_length, lambda_factors, n_hat, steps):
Expand Down

0 comments on commit f510479

Please sign in to comment.