Add MemoryProfileCallback (#10166)
* Add MemoryProfileCallback

Signed-off-by: Shriya Palsamudram <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ShriyaPalsamudram <[email protected]>

* Remove reference cycles, save snapshot on specific ranks

Signed-off-by: Shriya Palsamudram <[email protected]>

* Remove unnecessary imports

Signed-off-by: Shriya Palsamudram <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ShriyaPalsamudram <[email protected]>

* Update docstring

Signed-off-by: Shriya Palsamudram <[email protected]>

---------

Signed-off-by: Shriya Palsamudram <[email protected]>
Signed-off-by: ShriyaPalsamudram <[email protected]>
Signed-off-by: Shriya Rishab <[email protected]>
Co-authored-by: ShriyaPalsamudram <[email protected]>
ShriyaPalsamudram and ShriyaPalsamudram authored Aug 23, 2024
1 parent 1c90b5e commit 6d1be93
Showing 2 changed files with 80 additions and 0 deletions.
2 changes: 2 additions & 0 deletions nemo/lightning/pytorch/callbacks/__init__.py
@@ -1,4 +1,5 @@
from nemo.lightning.pytorch.callbacks.ddp_parity_checker import DdpParityChecker
from nemo.lightning.pytorch.callbacks.memory_profiler import MemoryProfileCallback
from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
from nemo.lightning.pytorch.callbacks.nsys import NsysCallback
@@ -8,6 +9,7 @@
from nemo.lightning.pytorch.callbacks.progress_printer import ProgressPrinter

__all__ = [
"MemoryProfileCallback",
"ModelCheckpoint",
"ModelTransform",
"PEFT",
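With MemoryProfileCallback re-exported from the package __init__, it should be importable directly from nemo.lightning.pytorch.callbacks. A minimal sketch, assuming a NeMo build that includes this commit:

# Hypothetical import check, not part of this commit.
from nemo.lightning.pytorch.callbacks import MemoryProfileCallback

callback = MemoryProfileCallback(dir="./mem_profile", ranks=[0])
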
78 changes: 78 additions & 0 deletions nemo/lightning/pytorch/callbacks/memory_profiler.py
@@ -0,0 +1,78 @@
import os

import torch
from pytorch_lightning.callbacks.callback import Callback
from torch.utils.viz._cycles import warn_tensor_cycles

from nemo.lightning import io
from nemo.utils import logging
from nemo.utils.get_rank import get_rank


class MemoryProfileCallback(Callback, io.IOMixin):
    """
    This callback enables recording a timeline of memory allocations during training.
    The generated .pickle profiles can be analyzed at https://pytorch.org/memory_viz
    More info about the profiles can be found [here](https://pytorch.org/blog/understanding-gpu-memory-1/).

    Args:
        dir (Optional[str]): Directory to store the memory profile dump.
        warn_cycles (Optional[bool]): Whether to enable [reference cycle detection](https://pytorch.org/blog/understanding-gpu-memory-2/).
        ranks (Optional[list[int]]): List of ranks to collect the snapshot on; collects on all ranks if the list is empty.

    Example:
        >>> callback = MemoryProfileCallback(dir="/mem_profile", ranks=[0])
        >>> trainer = Trainer(callbacks=[callback])
    """

    def __init__(self, dir: str = "/mem_profile", warn_cycles=True, ranks=[]):
        self.dir = dir
        self.ranks = ranks

        os.makedirs(self.dir, exist_ok=True)
        logging.info(f"Torch memory profiles will be written to: {self.dir}")

        if warn_cycles:
            logging.info("Enabling reference cycle detector")
            warn_tensor_cycles()

    def enable_on_rank(self) -> bool:
        """Return True if memory history should be recorded on the current rank (all ranks when ``ranks`` is empty)."""
        if not self.ranks:
            return True
        return get_rank() in self.ranks

    def setup(self, trainer, pl_module, stage) -> None:
        """PyTorch Lightning ``setup`` hook.

        Used here to start recording the CUDA memory allocation history.
        """
        if trainer.max_steps > 1000:
            logging.warning(
                "Memory profiling records allocations for the entire training run, "
                "and every iteration increases the size of the snapshot. "
                "Consider reducing trainer.max_steps to keep the snapshot manageable."
            )

        if torch.distributed.is_initialized() and self.enable_on_rank():
            torch.cuda.memory._record_memory_history(max_entries=100000)

    def on_train_end(self, trainer, pl_module) -> None:
        """PyTorch Lightning hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end
        We use it here to finish memory profiling and write the snapshot.
        """
        logging.info(
            f"on_train_end rank: {get_rank()} "
            f"mem (GB): {torch.cuda.memory_allocated() / 1024 / 1024 / 1024} / {torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024}"
        )

        if torch.distributed.is_initialized() and self.enable_on_rank():
            rank = get_rank()
            _snapshot_path = f"{self.dir}/memory_snapshot-rank{rank}.pickle"
            logging.info(f"Writing memory profile snapshot to {_snapshot_path}")
            torch.cuda.memory._dump_snapshot(_snapshot_path)
            torch.cuda.memory._record_memory_history(enabled=None)
            logging.info(f"Finished writing memory profile snapshot: {_snapshot_path}")
