diff --git a/train.py b/train.py
index 4a3a115..44ac730 100644
--- a/train.py
+++ b/train.py
@@ -17,7 +17,6 @@
 import wandb
 import os
 import logging
-import argparse
 
 from evaluation import evaluate_passkey_retrieval
 
@@ -126,26 +125,6 @@ def preprocess_data(data, tokenizer, max_length, overlap):
     return sequences
 
 
-def get_args():
-    parser = argparse.ArgumentParser(description="Train LongRoPE model")
-    parser.add_argument(
-        "--batch_size", type=int, default=8, help="Batch size for training"
-    )
-    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
-    parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
-    parser.add_argument(
-        "--max_len", type=int, default=2048000, help="Maximum sequence length"
-    )
-    parser.add_argument("--d_model", type=int, default=4096, help="Model dimension")
-    parser.add_argument(
-        "--n_heads", type=int, default=32, help="Number of attention heads"
-    )
-    parser.add_argument(
-        "--num_layers", type=int, default=6, help="Number of transformer layers"
-    )
-    return parser.parse_args()
-
-
 def compute_perplexity(loss):
     return torch.exp(loss)
 
@@ -297,8 +276,6 @@ def main():
     Main function to set up and run the LongRoPE model training process.
     """
-    args = get_args()
-
     # Initialize Weights & Biases for experiment tracking
     wandb.init(project="longrope", entity="your-entity-name")
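
Note: this patch removes the argparse-based get_args() but does not show where the training hyperparameters now come from. A minimal sketch of one possible replacement, assuming the old defaults are routed through wandb.config so W&B sweeps can still override them; the DEFAULTS dict below mirrors the removed parser and is otherwise hypothetical:

import wandb

# Former argparse defaults, preserved as a W&B config (assumption: the
# patch's replacement mechanism is not shown in these hunks).
DEFAULTS = {
    "batch_size": 8,      # was --batch_size
    "lr": 1e-4,           # was --lr
    "epochs": 10,         # was --epochs
    "max_len": 2048000,   # was --max_len
    "d_model": 4096,      # was --d_model
    "n_heads": 32,        # was --n_heads
    "num_layers": 6,      # was --num_layers
}

wandb.init(project="longrope", entity="your-entity-name", config=DEFAULTS)
config = wandb.config  # attribute access mirrors the old args, e.g. config.lr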