Skip to content

Commit

Permalink
test(tests/test_prompt.py): eval code
Browse files Browse the repository at this point in the history
  • Loading branch information
AdityaNG committed May 5, 2024
1 parent 359f8d5 commit 8a544dd
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 1 deletion.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ python -m kan_gpt.prompt --prompt "Bangalore is often described as the " --model
- [x] Train a dummy batch w/o any memory issues
- [x] Mini training POC for MLP-GPT
- [x] Train MLP-GPT on the webtext dataset as a baseline
- [x] Train KAN-GPT on the webtext dataset as a baseline
- [ ] Metrics comparing KAN-GPT and MLP-GPT
- [x] Auto Save checkpoints
- [x] Auto Save checkpoints to W&B
- [ ] Auto Download model weights from git / huggingface
Expand Down
18 changes: 18 additions & 0 deletions tests/test_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from kan_gpt.prompt import main

VOCAB_SIZE = 8
BLOCK_SIZE = 16
MODEL_TYPE = "gpt-pico"


def test_train():
class Args:
model_type = MODEL_TYPE
model_path = None
max_tokens = 3
prompt = "Bangalore is often described as the "
architecture = "KAN"
device = "cpu"

args = Args()
main(args)
20 changes: 19 additions & 1 deletion tests/test_train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import torch
from kan_gpt.mingpt.model import GPT as MLP_GPT
from kan_gpt.train import save_model
from kan_gpt.train import save_model, main

VOCAB_SIZE = 8
BLOCK_SIZE = 16
Expand Down Expand Up @@ -33,3 +33,21 @@ def test_save_model():
f"Model not saved correctly at {save_path}, parameter "
f"{name} does not match original model"
)


# def test_train():

# # TODO: Download mini dataset for testing

# class Args:
# model_type = MODEL_TYPE
# dummy_dataset = True
# learning_rate = 5e-3
# max_iters = 1
# num_workers = 0
# batch_size = 1
# architecture = "KAN"
# device = "cpu"

# args = Args()
# main(args)

0 comments on commit 8a544dd

Please sign in to comment.