@@ -35,7 +35,7 @@ def main():
3535
3636 # Example callbacks
3737 checkpoint_callback = ModelCheckpoint (
38- dirpath = "checkpoints" , # save to checkpoints/
38+ dirpath = "meta_pomo/checkpoints" , # save to meta_pomo/checkpoints/
3939 filename = "epoch_{epoch:03d}" , # save as epoch_XXX.ckpt
4040 save_top_k = 1 , # save only the best model
4141 save_last = True , # save the last model
@@ -47,8 +47,8 @@ def main():
4747 # Meta callbacks
4848 meta_callback = ReptileCallback (
4949 num_tasks = 1 , # the number of tasks in a mini-batch, i.e. `B` in the original paper
50- alpha = 0.99 , # initial weight of the task model for the outer-loop optimization of reptile
51- alpha_decay = 0.999 , # weight decay of the task model for the outer-loop optimization of reptile
50+ alpha = 0.9 , # initial weight of the task model for the outer-loop optimization of reptile
51+ alpha_decay = 1 , # weight decay of the task model for the outer-loop optimization of reptile. No decay performs better.
5252 min_size = 20 , # minimum of sampled size in meta tasks (only supported in cross-size generalization)
5353 max_size = 150 , # maximum of sampled size in meta tasks (only supported in cross-size generalization)
5454 data_type = "size_distribution" , # choose from ["size", "distribution", "size_distribution"]
@@ -63,7 +63,7 @@ def main():
6363
6464 # Adjust your trainer to the number of epochs you want to run
6565 trainer = RL4COTrainer (
66- max_epochs = 20000 , # (the number of meta_model updates) * (the number of tasks in a mini-batch)
66+ max_epochs = 15000 , # (the number of meta_model updates) * (the number of tasks in a mini-batch)
6767 callbacks = callbacks ,
6868 accelerator = "gpu" ,
6969 devices = [device_id ],
0 commit comments