Skip to content

Commit

Permalink
Added argparse to manage hyperparameters, making it easier to experiment with different settings
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jul 6, 2024
1 parent c670e3f commit 35fe319
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import wandb
import os
import logging
import argparse

from evaluation import evaluate_passkey_retrieval

Expand Down Expand Up @@ -125,6 +126,26 @@ def preprocess_data(data, tokenizer, max_length, overlap):
return sequences


def get_args(argv=None):
    """Parse command-line hyperparameters for LongRoPE training.

    Args:
        argv: Optional list of argument strings to parse. Defaults to
            None, in which case argparse falls back to ``sys.argv[1:]``
            (the original behavior). Passing an explicit list makes the
            function testable without mutating process-level state.

    Returns:
        argparse.Namespace holding batch_size, lr, epochs, max_len,
        d_model, n_heads, and num_layers.
    """
    parser = argparse.ArgumentParser(description="Train LongRoPE model")
    parser.add_argument(
        "--batch_size", type=int, default=8, help="Batch size for training"
    )
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument(
        "--max_len", type=int, default=2048000, help="Maximum sequence length"
    )
    parser.add_argument("--d_model", type=int, default=4096, help="Model dimension")
    parser.add_argument(
        "--n_heads", type=int, default=32, help="Number of attention heads"
    )
    parser.add_argument(
        "--num_layers", type=int, default=6, help="Number of transformer layers"
    )
    return parser.parse_args(argv)


def compute_perplexity(loss):
    """Convert a cross-entropy loss tensor into perplexity via exp(loss)."""
    perplexity = torch.exp(loss)
    return perplexity

Expand Down Expand Up @@ -275,6 +296,9 @@ def main():
"""
Main function to set up and run the LongRoPE model training process.
"""

args = get_args()

# Initialize Weights & Biases for experiment tracking
wandb.init(project="longrope", entity="your-entity-name")

Expand Down

0 comments on commit 35fe319

Please sign in to comment.