Skip to content

Commit

Permalink
Log GPU memory usage
Browse files: browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jul 14, 2024
1 parent cba5999 commit bd0ede3
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
from importlib import reload
import src.main
from accelerate import Accelerator
from tqdm import tqdm
import wandb
import os
import logging
import hashlib
import pickle
from tqdm import tqdm
import GPUtil


from evaluation import evaluate_passkey_retrieval

Expand Down Expand Up @@ -329,6 +331,12 @@ def train(
f"Val Loss: {avg_val_loss:.4f}, Val Perplexity: {val_perplexity:.4f}"
)

# Log GPU memory usage
for gpu in GPUtil.getGPUs():
gpu_memory_used = gpu.memoryUsed
logger.info(f"GPU {gpu.id} memory use: {gpu_memory_used}MB")
wandb.log({f"GPU_{gpu.id}_memory_used": gpu_memory_used})

# Save checkpoint
accelerator.save_state(
{
Expand Down

0 comments on commit bd0ede3

Please sign in to comment.