Update the train() function. The update provides more comprehensive tracking of the model's performance.

The improvements make the training process more robust, more efficient, and easier to monitor and debug. Key differences and improvements (sketched in the code below):

- Gradient accumulation: the updated function introduces gradient accumulation (gradient_accumulation_steps), which allows effectively larger batch sizes without increasing memory usage. The loss is divided by gradient_accumulation_steps and gradients are accumulated over multiple forward passes before the model is updated.
- Mixed precision training: the updated function uses autocast and GradScaler for mixed precision training, which can speed up training and reduce memory usage, especially for large models.
- Learning rate scheduling: a learning rate scheduler is now stepped (scheduler.step()), which can help convergence.
- Improved loss tracking: the total loss is accumulated over an epoch, giving a more accurate average loss.
- Perplexity calculation: the updated function computes perplexity, a common metric for language models.
- Logging: various metrics are logged with wandb, allowing better tracking and visualization of the training process.
- Checkpointing: the model state is saved after each epoch, allowing training to resume if interrupted.
- Early stopping: early stopping is implemented, which can prevent overfitting and save computation time.
- Best model saving: the best model is saved based on validation loss, ensuring the best-performing weights are kept.
- Removed input size check: the original function's check for inputs exceeding model.rope.max_len is removed in the updated version.
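A minimal sketch of the kind of epoch loop the first six items describe, assuming a PyTorch model trained with cross-entropy on token batches. The names train_one_epoch, train_loader, criterion, and gradient_accumulation_steps are illustrative assumptions, not the repository's actual identifiers.

```python
import math

import torch
import wandb
from torch.cuda.amp import GradScaler, autocast


def train_one_epoch(model, train_loader, optimizer, scheduler, criterion,
                    device, gradient_accumulation_steps=4):
    # Hypothetical epoch loop: gradient accumulation + AMP + LR scheduling +
    # loss/perplexity tracking + wandb logging, as listed in the commit message.
    model.train()
    scaler = GradScaler()
    total_loss = 0.0
    optimizer.zero_grad()

    for step, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Mixed precision forward pass
        with autocast():
            logits = model(inputs)
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
            # Divide so the accumulated gradient matches the larger effective batch
            loss = loss / gradient_accumulation_steps

        scaler.scale(loss).backward()
        total_loss += loss.item() * gradient_accumulation_steps

        # Update the model only every `gradient_accumulation_steps` batches
        if (step + 1) % gradient_accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

    avg_loss = total_loss / len(train_loader)
    perplexity = math.exp(avg_loss)
    wandb.log({"train/loss": avg_loss, "train/perplexity": perplexity})
    return avg_loss, perplexity
```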
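The checkpointing, best-model saving, and early stopping items could look roughly like the sketch below. The evaluate helper, the patience value, and the file names are assumptions made for illustration.

```python
import torch


@torch.no_grad()
def evaluate(model, val_loader, criterion, device):
    # Average validation loss; assumed helper, not part of the original code.
    model.eval()
    total = 0.0
    for inputs, targets in val_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(inputs)
        total += criterion(logits.view(-1, logits.size(-1)), targets.view(-1)).item()
    return total / len(val_loader)


def fit(model, train_loader, val_loader, optimizer, scheduler, criterion,
        device, epochs=10, patience=3):
    best_val_loss = float("inf")
    epochs_without_improvement = 0

    for epoch in range(epochs):
        train_loss, train_ppl = train_one_epoch(
            model, train_loader, optimizer, scheduler, criterion, device)
        val_loss = evaluate(model, val_loader, criterion, device)

        # Checkpoint after every epoch so training can be resumed if interrupted
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "val_loss": val_loss,
        }, f"checkpoint_epoch_{epoch}.pt")

        # Keep the weights that achieve the lowest validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            torch.save(model.state_dict(), "best_model.pt")
        else:
            epochs_without_improvement += 1

        # Early stopping once validation loss stops improving
        if epochs_without_improvement >= patience:
            break
```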