From bf4affe08d677bc1d85395cf45cdf0caed8417da Mon Sep 17 00:00:00 2001
From: Aditya NG
Date: Wed, 8 May 2024 00:55:45 +0530
Subject: [PATCH] feat(sweep): sweep script for getting a vast hyperparam
 sweep

---
 kan_gpt/VERSION  |  2 +-
 kan_gpt/sweep.py | 71 ++++++++++++++++++++++++++++++++++++++++++++++++
 kan_gpt/train.py |  5 ++--
 3 files changed, 75 insertions(+), 3 deletions(-)
 create mode 100644 kan_gpt/sweep.py

diff --git a/kan_gpt/VERSION b/kan_gpt/VERSION
index 0ea3a94..0d91a54 100644
--- a/kan_gpt/VERSION
+++ b/kan_gpt/VERSION
@@ -1 +1 @@
-0.2.0
+0.3.0
diff --git a/kan_gpt/sweep.py b/kan_gpt/sweep.py
new file mode 100644
index 0000000..9b9af7b
--- /dev/null
+++ b/kan_gpt/sweep.py
@@ -0,0 +1,71 @@
+import wandb
+from kan_gpt.train import main
+
+
+def wandb_sweep():
+    run = wandb.init(resume="allow", anonymous="must")
+
+    class Args:
+        model_type = wandb.config.model_type
+        dummy_dataset = wandb.config.dummy_dataset
+        learning_rate = wandb.config.learning_rate
+        max_iters = wandb.config.max_iters
+        num_workers = wandb.config.num_workers
+        batch_size = wandb.config.batch_size
+        dataset = wandb.config.dataset
+        architecture = wandb.config.architecture
+        device = wandb.config.device
+
+    run_args = Args()
+
+    main(args=run_args, run=run)
+
+
+def sweep(args):
+    sweep_configuration = {
+        "method": "random",
+        "name": "sweep",
+        "metric": {"goal": "minimize", "name": "test_loss"},
+        "parameters": {
+            "model_type": {"values": args.model_type},
+            "batch_size": {"values": args.batch_size},
+            "dummy_dataset": {"values": args.dummy_dataset},
+            "learning_rate": {"values": args.learning_rate},
+            "max_iters": {"values": args.max_iters},
+            "num_workers": {"values": args.num_workers},
+            "dataset": {"values": args.dataset},
+            "architecture": {"values": args.architecture},
+            "device": {"values": args.device},
+        },
+    }
+
+    sweep_id = wandb.sweep(sweep_configuration, project="KAN-GPT")
+    print("sweep_id (generated)", sweep_id)
+
+    wandb.agent(sweep_id, function=wandb_sweep)
+
+
+if __name__ == "__main__":
+
+    class SweepArgs:
+        model_type = ["gpt-mini", "gpt-nano", "gpt2"]
+        dummy_dataset = [
+            False,
+        ]
+        learning_rate = [5e-3, 5e-4, 5e-5, 5e-6]
+        max_iters = [
+            1000,
+        ]
+        num_workers = [
+            0,
+        ]
+        batch_size = [1, 2, 4, 8, 12, 16]
+        dataset = [
+            "tinyshakespeare",
+        ]
+        architecture = ["MLP", "KAN"]
+        device = [
+            "auto",
+        ]
+
+    sweep(SweepArgs())
diff --git a/kan_gpt/train.py b/kan_gpt/train.py
index 3f71f31..99a11e6 100644
--- a/kan_gpt/train.py
+++ b/kan_gpt/train.py
@@ -62,7 +62,7 @@ def save_model(
     return save_path
 
 
-def main(args):
+def main(args, run=None):
     config = {
         "model_type": args.model_type,
         "batch_size": args.batch_size,
@@ -116,7 +116,8 @@ def main(args):
         train_config.device = args.device
 
     trainer = Trainer(train_config, model, train_dataset)
-    run = wandb.init(project="KAN-GPT", config=config)
+    if run is None:
+        run = wandb.init(project="KAN-GPT", config=config)
     wandb.watch(model)
 
     def batch_end_callback(trainer):
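
With this patch applied, python -m kan_gpt.sweep registers the sweep and
starts an agent in-process via the __main__ block above. The run=None
parameter added to main() lets the sweep reuse the run that wandb_sweep()
already opened instead of calling wandb.init a second time. A narrower sweep
can be launched the same way. A minimal sketch, assuming wandb is installed
and authenticated; QuickSweepArgs and its values are illustrative, not part
of the patch:

    from kan_gpt.sweep import sweep

    class QuickSweepArgs:
        # same attributes that sweep() reads, but over a smaller grid
        model_type = ["gpt-nano"]
        dummy_dataset = [False]
        learning_rate = [5e-4, 5e-5]
        max_iters = [100]
        num_workers = [0]
        batch_size = [4]
        dataset = ["tinyshakespeare"]
        architecture = ["MLP", "KAN"]
        device = ["auto"]

    sweep(QuickSweepArgs())

Note that the sweep method is "random" and wandb.agent() is called without a
count, so the agent keeps sampling configurations until it is interrupted.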
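
Because sweep() prints the generated sweep_id, additional agents can be
attached from other processes or machines to parallelize the search. A
sketch, assuming the printed id replaces the placeholder below; count caps
how many trials this agent runs:

    import wandb

    from kan_gpt.sweep import wandb_sweep

    # attach one more agent to an existing sweep; "<sweep_id>" is a placeholder
    wandb.agent("<sweep_id>", function=wandb_sweep, project="KAN-GPT", count=5)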