Update comments in the main() method
jshuadvd committed Jul 5, 2024
1 parent 306ba56 commit 45de677
Showing 1 changed file with 24 additions and 3 deletions.
train.py (27 changes: 24 additions & 3 deletions)
@@ -243,35 +243,49 @@ def train(

# %%
def main():
    """
    Main function to set up and run the LongRoPE model training process.
    """
    # Initialize Weights & Biases for experiment tracking
    wandb.init(project="longrope", entity="your-entity-name")
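    # NOTE: "your-entity-name" above is a placeholder; replace it with your own W&B entity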

    # Load and configure the tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.model_max_length = 2048000  # Set maximum sequence length to 2048k tokens

    # Load the raw data
    data = load_data("../data/raw/enwik8.gz")

    # Set parameters for data preprocessing
    max_length = 65536
    overlap = 4096

    # Preprocess the data into sequences
    sequences = preprocess_data(data, tokenizer, max_length, overlap)

    # Create target sequences (shifted by one token)
    targets = [seq[1:] + [tokenizer.eos_token_id] for seq in sequences]

    # Validate that all target indices are within the vocabulary size
    validate_targets(targets, tokenizer.vocab_size)

    # Create a custom dataset from sequences and targets
    dataset = CustomDataset(sequences, targets)

    # Split the dataset into training and validation sets
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )

    # Create data loaders for training and validation
    train_loader = DataLoader(
        train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn
    )
    val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=collate_fn)

    # Initialize the LongRoPE model
    model = LongRoPEModel(
        d_model=4096,
        n_heads=32,
@@ -280,14 +294,17 @@ def main():
        max_len=2048000,
    )

    # Set up optimizer, loss function, and learning rate scheduler
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    scheduler = CosineAnnealingLR(optimizer, T_max=10)

    # Prepare model, optimizer, data loaders, and scheduler for distributed training
    model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, val_loader, scheduler
    )
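    # NOTE: `accelerator` above is assumed to be defined elsewhere in train.py (e.g., accelerate.Accelerator())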

    # Extend the context window of the model
    extended_model = model.extend_context(
        data_path="../data/raw/enwik8.gz",
        target_length=2048000,
@@ -299,17 +316,21 @@ def main():
        max_iterations=10,
    )

    # Recover performance on shorter contexts
    recovered_model = extended_model.recover_short_context(
        data_path="../data/raw/enwik8.gz",
        max_sequence_length=48192,
        tokenizer=tokenizer,
    )

    # Create new optimizer and scheduler for the recovered model
    optimizer = optim.AdamW(recovered_model.parameters(), lr=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=10)

    # Train the recovered model
    train(recovered_model, train_loader, val_loader, optimizer, criterion, scheduler)

    # Finish logging and close the Weights & Biases run
    wandb.finish()


