[T170073014] Rewrite distributed examples for Tensor Parallel, Sequence Parallel, 2D (FSDP + TP) #1201

Merged: 20 commits, merged Nov 22, 2023

Changes from 1 commit

Commits (20)
21a5fcf
update requirements.txt
lessw2020 Nov 15, 2023
f962b60
add torchrun support, move to init_device_mesh
lessw2020 Nov 15, 2023
bc3c1dd
update twod fully working
lessw2020 Nov 16, 2023
11a3bb2
ensure proper dp group seeding for synth data
lessw2020 Nov 16, 2023
9cebdf0
swiglu model added
lessw2020 Nov 16, 2023
2447883
sequential running of custom, auto, seq parallel models
lessw2020 Nov 16, 2023
a388c20
streamline to 2D TP only for two_d_parallel example
lessw2020 Nov 17, 2023
842c3f0
sequence parallel working...needs init_device_mesh update
lessw2020 Nov 18, 2023
3aa1c53
seq parallel now using init_device_mesh
lessw2020 Nov 21, 2023
b54e2ec
tp and sp examples all working and updated
lessw2020 Nov 21, 2023
4889e3b
updates from code review
lessw2020 Nov 21, 2023
b215178
remove utils.py. Sample models created in example files
lessw2020 Nov 22, 2023
242c328
remove originals.py, leftover imports, various updates from code revi…
lessw2020 Nov 22, 2023
2f4a083
code linting via ruff
lessw2020 Nov 22, 2023
742966b
code formatting via ruff
lessw2020 Nov 22, 2023
7da71bc
move rank_log to utils.py, update example files
lessw2020 Nov 22, 2023
836f798
move logging imports and config to log_utils, update examples with ne…
lessw2020 Nov 22, 2023
2de0144
add gpu verification, update run_python_examples.sh
lessw2020 Nov 22, 2023
77fe3d8
update min gpu = 4 for fsdp+tp
lessw2020 Nov 22, 2023
5f4a5d3
move gpu check to top of examples, but before import init_device_mesh…
lessw2020 Nov 22, 2023
9 changes: 8 additions & 1 deletion distributed/tensor_parallelism/fsdp_tp_example.py
@@ -1,3 +1,4 @@
import sys
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -13,7 +14,7 @@

from torch.distributed._tensor.device_mesh import init_device_mesh
import os
from log_utils import rank_log, get_logger
from log_utils import rank_log, get_logger, verify_min_gpu_count

"""
This is the script to test 2D Parallel which combines Tensor/Sequence
@@ -46,6 +47,12 @@
https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
"""

_min_gpu_count = 2

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
    print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
    sys.exit(0)


def find_multiple(n: int, k: int) -> int:
    """function to find resizing multiple for SwiGLU MLP"""
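For reference, the following is a minimal sketch (not part of this diff) of how a 2D FSDP + TP setup like the one in fsdp_tp_example.py can be wired together. The tensor-parallel degree, the stand-in nn.Linear model, and the ColwiseParallel plan are assumptions for illustration, and the script is assumed to be launched with torchrun so that WORLD_SIZE and LOCAL_RANK are set.

import os
import torch
import torch.nn as nn
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))   # LOCAL_RANK is set by torchrun

tp_size = 2                                  # assumed tensor-parallel degree
world_size = int(os.environ["WORLD_SIZE"])   # total GPU count, set by torchrun
dp_size = world_size // tp_size              # data-parallel degree for FSDP

# One 2D mesh: outer dim for FSDP (data parallel), inner dim for tensor parallel.
mesh_2d = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]
dp_mesh = mesh_2d["dp"]

model = nn.Linear(16, 16).cuda()             # stand-in for the example's SwiGLU model

# Tensor-parallel shard the module first, then wrap the result with FSDP over the dp mesh.
model = parallelize_module(model, tp_mesh, ColwiseParallel())
model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)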
8 changes: 8 additions & 0 deletions distributed/tensor_parallelism/log_utils.py
@@ -1,4 +1,5 @@
import logging
import torch

logging.basicConfig(
format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO
@@ -12,3 +13,10 @@ def rank_log(_rank, logger, msg):
"""helper function to log only on global rank 0"""
if _rank == 0:
logger.info(f" {msg}")


def verify_min_gpu_count(min_gpus: int = 2) -> bool:
    """ verification that we have at least 2 gpus to run dist examples """
    has_cuda = torch.cuda.is_available()
    gpu_count = torch.cuda.device_count()
    return has_cuda and gpu_count >= min_gpus
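A short usage sketch of these helpers, roughly as the example scripts call them. get_logger() is the existing helper in log_utils.py (its definition falls outside this hunk), and the RANK environment variable is assumed to be set by torchrun.

import os
from log_utils import get_logger, rank_log, verify_min_gpu_count

if not verify_min_gpu_count(min_gpus=2):
    raise SystemExit("this example needs at least 2 GPUs")

logger = get_logger()
_rank = int(os.environ["RANK"])              # set by torchrun for each process
rank_log(_rank, logger, "only global rank 0 logs this message")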
8 changes: 7 additions & 1 deletion distributed/tensor_parallelism/sequence_parallel_example.py
@@ -1,4 +1,5 @@
import os
import sys
import torch
import torch.nn as nn

@@ -11,7 +12,7 @@
RowwiseParallel,
)

from log_utils import rank_log, get_logger
from log_utils import rank_log, get_logger, verify_min_gpu_count


"""
@@ -29,6 +30,11 @@
now is different so that we need one all-gather for input and one reduce-scatter
at the end of the second linear layer.
"""
_min_gpu_count = 2

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
    print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
    sys.exit(0)


class ToyModel(nn.Module):
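As a reference for the sequence-parallel plan the docstring above describes, here is a minimal sketch under a few assumptions: a ToyModel with in_proj and out_proj Linear layers (as in the example), two GPUs, and a torchrun launch. Sharding the input on dim 0 before the column-wise layer and the output after the row-wise layer is what yields the single all-gather at the input and the reduce-scatter after the second linear layer.

import os
import torch
import torch.nn as nn
from torch.distributed._tensor import Shard
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyModel(nn.Module):
    """Two-layer MLP with the layer names the example's plan targets."""

    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(torch.relu(self.in_proj(x)))


torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))   # LOCAL_RANK is set by torchrun
device_mesh = init_device_mesh("cuda", (2,))            # assumed 2 GPUs
model = ToyModel().cuda()

sp_model = parallelize_module(
    module=model,
    device_mesh=device_mesh,
    parallelize_plan={
        # the sequence-sharded input is all-gathered once before in_proj
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        # out_proj reduce-scatters its output back to a sequence-sharded layout
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
    },
)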
10 changes: 9 additions & 1 deletion distributed/tensor_parallelism/tensor_parallel_example.py
@@ -1,4 +1,5 @@
import os
import sys
import torch
import torch.nn as nn

@@ -11,7 +12,9 @@
)


from log_utils import rank_log, get_logger
from log_utils import rank_log, get_logger, verify_min_gpu_count




"""
@@ -45,6 +48,11 @@
Parallelism APIs in this example to show users how to use them.
"""

_min_gpu_count = 2

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
    print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
    sys.exit(0)

class ToyModel(nn.Module):
    """MLP based model"""
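For comparison with the sequence-parallel sketch above, a plain tensor-parallel plan is sketched below. The stand-in Sequential MLP, mesh size, and input shape are assumptions (the real example uses named in_proj/out_proj layers), and a torchrun launch is assumed. With the default layouts, the first Linear is column-sharded, the second is row-sharded, and every rank feeds a replicated input.

import os
import torch
import torch.nn as nn
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))   # LOCAL_RANK is set by torchrun
device_mesh = init_device_mesh("cuda", (2,))            # assumed 2 GPUs

# Stand-in two-layer MLP; Sequential children are named "0", "1", "2".
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 5)).cuda()

tp_model = parallelize_module(
    module=model,
    device_mesh=device_mesh,
    parallelize_plan={
        "0": ColwiseParallel(),                         # first Linear: column-wise shard
        "2": RowwiseParallel(),                         # second Linear: row-wise shard
    },
)

inp = torch.rand(20, 10, device="cuda")                 # replicated input on every rank
out = tp_model(inp)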
9 changes: 4 additions & 5 deletions run_python_examples.sh
@@ -63,8 +63,8 @@ function distributed() {
start
python tensor_parallelism/tensor_parallel_example.py || error "tensor parallel example failed"
python tensor_parallelism/sequence_parallel_example.py || error "sequence parallel example failed"
python tensor_parallelism/two_d_parallel_example.py || error "2D parallel example failed"
python ddp/main.py || error "ddp example failed"
python tensor_parallelism/fsdp_tp_parallel_example.py || error "2D parallel example failed"
python ddp/main.py || error "ddp example failed"
}

function fast_neural_style() {
Expand Down Expand Up @@ -96,7 +96,7 @@ function mnist() {
python main.py --epochs 1 --dry-run || error "mnist example failed"
}
function mnist_forward_forward() {
start
start
python main.py --epochs 1 --no_mps --no_cuda || error "mnist forward forward failed"

}
@@ -212,9 +212,8 @@ function clean() {
function run_all() {
# cpp
dcgan
# distributed
fast_neural_style
distributed
fast_neural_style
imagenet
mnist
mnist_forward_forward