-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add MemoryProfileCallback Signed-off-by: Shriya Palsamudram <[email protected]> * Apply isort and black reformatting Signed-off-by: ShriyaPalsamudram <[email protected]> * Remove reference cycles, save snapshot on specific ranks Signed-off-by: Shriya Palsamudram <[email protected]> * Remove unnecessary imports Signed-off-by: Shriya Palsamudram <[email protected]> * Apply isort and black reformatting Signed-off-by: ShriyaPalsamudram <[email protected]> * Update docstring Signed-off-by: Shriya Palsamudram <[email protected]> --------- Signed-off-by: Shriya Palsamudram <[email protected]> Signed-off-by: ShriyaPalsamudram <[email protected]> Signed-off-by: Shriya Rishab <[email protected]> Co-authored-by: ShriyaPalsamudram <[email protected]>
- Loading branch information
1 parent
1c90b5e
commit 6d1be93
Showing
2 changed files
with
80 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import os | ||
|
||
import torch | ||
from pytorch_lightning.callbacks.callback import Callback | ||
from torch.utils.viz._cycles import warn_tensor_cycles | ||
|
||
from nemo.lightning import io | ||
from nemo.utils import logging | ||
from nemo.utils.get_rank import get_rank | ||
|
||
|
||
class MemoryProfileCallback(Callback, io.IOMixin): | ||
""" | ||
This callback enables recording a timeline of memory allocations during training. | ||
The generated .pickle profiles can be analyzed at https://pytorch.org/memory_viz | ||
More info about the profiles can be found [here](https://pytorch.org/blog/understanding-gpu-memory-1/). | ||
Args: | ||
dir (Optional[str]): Directory to store the memory profile dump | ||
warn_cycles (Optional[bool]): Whether to enable [reference cycle detection](https://pytorch.org/blog/understanding-gpu-memory-2/) | ||
rank (Optional[list[int]]): List of ranks to collect snapshot on, defaults to all if list is empty | ||
Example: | ||
>>> callback = MemoryProfileCallback(dir="/mem_profile", ranks=[0]) | ||
>>> trainer = Trainer(callbacks=[callback]) | ||
""" | ||
|
||
def __init__(self, dir: str = "/mem_profile", warn_cycles=True, ranks=[]): | ||
|
||
self.dir = dir | ||
self.ranks = ranks | ||
|
||
os.makedirs(self.dir, exist_ok=True) | ||
logging.info(f"Torch memory profiles will be written to: {self.dir}") | ||
|
||
if warn_cycles: | ||
logging.info("Enabling reference cycle detector") | ||
warn_tensor_cycles() | ||
|
||
def enable_on_rank(self) -> bool: | ||
if not self.ranks: | ||
return True | ||
return get_rank() in self.ranks | ||
|
||
def setup(self, trainer, pl_module, stage) -> None: | ||
"""PyTorch Lightning hook: | ||
https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end | ||
We use it here to start recording the memory profiler. | ||
""" | ||
|
||
if trainer.max_steps > 1000: | ||
logging.warning( | ||
f"Memory profiling creates snapshots during the entire training process, \ | ||
where every iteration increases the size of the snapshot. \ | ||
Try reducing trainer.max_steps to avoid running into issues" | ||
) | ||
|
||
if torch.distributed.is_initialized() and self.enable_on_rank(): | ||
torch.cuda.memory._record_memory_history(max_entries=100000) | ||
|
||
def on_train_end(self, trainer, pl_module) -> None: | ||
"""PyTorch Lightning hook: | ||
https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end | ||
We use it here to finish memory profiling and write the snapshot. | ||
""" | ||
|
||
logging.info( | ||
f"on_train_batch_end rank: {get_rank()} mem: {torch.cuda.memory_allocated()/1024/1024/1024} / {torch.cuda.max_memory_reserved()/1024/1024/1024}" | ||
) | ||
|
||
if torch.distributed.is_initialized() and self.enable_on_rank(): | ||
rank = get_rank() | ||
_snapshot_path = f"{self.dir}/memory_snapshot-rank{rank}.pickle" | ||
logging.info(f"Writing memory profile snapshot to {_snapshot_path}") | ||
torch.cuda.memory._dump_snapshot(f"{_snapshot_path}") | ||
torch.cuda.memory._record_memory_history(enabled=None) | ||
logging.info(f"Finished writing memory profile snapshot: {_snapshot_path}") |