From 8a544dd5a5458cf5b1163058143b99d50eeedf11 Mon Sep 17 00:00:00 2001 From: Aditya NG Date: Sun, 5 May 2024 23:15:51 +0530 Subject: [PATCH] test(tests/test_prompt.py): eval code --- README.md | 2 ++ tests/test_prompt.py | 18 ++++++++++++++++++ tests/test_train.py | 20 +++++++++++++++++++- 3 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 tests/test_prompt.py diff --git a/README.md b/README.md index 241da82..5059f7e 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,8 @@ python -m kan_gpt.prompt --prompt "Bangalore is often described as the " --model - [x] Train a dummy batch w/o any memory issues - [x] Mini training POC for MLP-GPT - [x] Train MLP-GPT on the webtext dataset as a baseline +- [x] Train KAN-GPT on the webtext dataset as a baseline +- [ ] Metrics comparing KAN-GPT and MLP-GPT - [x] Auto Save checkpoints - [x] Auto Save checkpoints to W&B - [ ] Auto Download model weights from git / huggingface diff --git a/tests/test_prompt.py b/tests/test_prompt.py new file mode 100644 index 0000000..6288e28 --- /dev/null +++ b/tests/test_prompt.py @@ -0,0 +1,18 @@ +from kan_gpt.prompt import main + +VOCAB_SIZE = 8 +BLOCK_SIZE = 16 +MODEL_TYPE = "gpt-pico" + + +def test_train(): + class Args: + model_type = MODEL_TYPE + model_path = None + max_tokens = 3 + prompt = "Bangalore is often described as the " + architecture = "KAN" + device = "cpu" + + args = Args() + main(args) diff --git a/tests/test_train.py b/tests/test_train.py index 3245d9f..4c1f5c6 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -1,7 +1,7 @@ import os import torch from kan_gpt.mingpt.model import GPT as MLP_GPT -from kan_gpt.train import save_model +from kan_gpt.train import save_model, main VOCAB_SIZE = 8 BLOCK_SIZE = 16 @@ -33,3 +33,21 @@ def test_save_model(): f"Model not saved correctly at {save_path}, parameter " f"{name} does not match original model" ) + + +# def test_train(): + +# # TODO: Download mini dataset for testing + +# class Args: +# model_type = MODEL_TYPE +# dummy_dataset = True +# learning_rate = 5e-3 +# max_iters = 1 +# num_workers = 0 +# batch_size = 1 +# architecture = "KAN" +# device = "cpu" + +# args = Args() +# main(args)