Commit c211fca, 1 parent 8d97674. Showing 132 changed files with 1,170,631 additions and 8 deletions.
@@ -0,0 +1,17 @@
env
env_2
env_pypy
cache
cache_lr
env_cpython
.vscode
# data
simulator_output
wandb

# python files
__pycache__/
*.py[cod]
*$py.class
.ipynb_checkpoints
*.log
@@ -0,0 +1,40 @@
.PHONY: help lint lint/flake8 lint/black lint/isort format format/black format/autopep8 format/isort
.DEFAULT_GOAL := help

lint/flake8: ## check style with flake8
	flake8 simulator

lint/black: ## check style with black
	black --check simulator

lint/isort: ## check style with isort
	isort --check-only --profile black simulator

lint: lint/black lint/isort ## check style

format/black: ## format code with black
	black simulator

format/autopep8: ## format code with autopep8
	autopep8 --in-place --aggressive --aggressive --recursive simulator/

format/isort: ## format code with isort
	isort --profile black simulator

format: format/isort format/black ## format code

run:
	python -m simulator.main

run-with-args:
	python -m simulator.main $(ARGS)

run-with-trace:
	python -m simulator.main --write_chrome_trace True; \
	trace_file=`ls -t simulator_output/*/chrome_trace.json | head -1`; \
	zip -r $$trace_file.zip $$trace_file

run-with-trace-and-args:
	python -m simulator.main --write_chrome_trace True $(ARGS); \
	trace_file=`ls -t simulator_output/*/chrome_trace.json | head -1`; \
	zip -r $$trace_file.zip $$trace_file
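
Usage note: the targets that reference ARGS forward extra flags to the simulator entry point on the make command line, for example make run-with-args ARGS="--some_flag value" (the flag name here is only a placeholder, not necessarily a real simulator option).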
Empty file.
@@ -0,0 +1,111 @@
import datetime
import itertools
import os
import random
from tqdm import tqdm

import ray
import torch.backends.cudnn as cudnn
import pandas as pd

from benchmark.all_gather_wrapper import AllGatherWrapper


NUM_GPUS = 2
OUTPUT_DIR = "all_gather_benchmarking_output"

# NUM_HEADS = [8, 16, 32, 64, 128, 256]
# EMBEDDING_DIMS = [1024, 2048, 4096, 8192, 16384, 32768]
# BATCH_SIZES = [2, 4, 8, 16, 32, 64]
# CONTEXT_LENGTHS = [256, 512, 1024, 2048, 4096, 8192]
# DECODE = [False, True]

# Debug config
# NUM_HEADS = [96]
# EMBEDDING_DIMS = [96 * 128]
# BATCH_SIZES = [2, 4, 8, 16, 32, 64]
# CONTEXT_LENGTHS = [256, 512, 1024, 2048, 4096, 8136]
# DECODE = [True, False]

# [llama 7b, llama 13b, llama 33b, llama 65b, gpt3 175b, code-llama 34b, llama 2 70b, falcon 7b, falcon 40b, falcon 180b]
# all llama and falcon models share the same vocab size so we only run one of them
VOCAB_SIZE = [32768, 50257, 65024]
NUM_TOKENS = \
    list(range(0, 128, 1)) + \
    list(range(128, 1536, 4)) + \
    list(range(1536, 98 * 1024, 256)) + \
    list(range(98 * 1024, 196 * 1024, 512))
NUM_TENSOR_PARALLEL_WORKERS = 2


def safe_ray_get(futures):
    outputs = []
    for future in futures:
        try:
            output = ray.get(future)
            outputs.append(output)
        except Exception as e:
            print(f"Error: {e}")
            outputs.append(None)
    return outputs


@ray.remote(num_gpus=1)
class ModelRunner:
    def run_all_gather(
        self, rank, num_workers, comm_id, vocab_size, num_tokens
    ):
        wrapper = AllGatherWrapper(rank, num_workers, comm_id, vocab_size, num_tokens)
        stats = wrapper.profile()
        return stats


def run_benchmark():
    runner_pool = [ModelRunner.remote() for _ in range(NUM_GPUS)]

    # create a dir in out dir with human readable timestamp
    output_dir = f"{OUTPUT_DIR}/{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    os.makedirs(output_dir, exist_ok=True)

    all_results = []

    params = itertools.product(VOCAB_SIZE, NUM_TOKENS)

    used_comm_ids = set()

    for (
        vocab_size,
        num_tokens,
    ) in tqdm(list(params)):
        # for each experiment generate a fresh communication id; it is used to build
        # a unique file-based init_method for the process group
        while True:
            comm_id = random.randint(65535, 655350000000)
            if comm_id not in used_comm_ids:
                used_comm_ids.add(comm_id)
                break

        promises = []
        for rank in range(NUM_TENSOR_PARALLEL_WORKERS):
            promise = runner_pool[rank].run_all_gather.remote(
                rank, NUM_TENSOR_PARALLEL_WORKERS, comm_id, vocab_size, num_tokens
            )
            promises.append(promise)

        for rank in range(NUM_TENSOR_PARALLEL_WORKERS):
            result = safe_ray_get([promises[rank]])[0]
            if result and rank == 0:
                all_results.append(result)
            if not result:
                runner_pool[rank] = ModelRunner.remote()

    df = pd.DataFrame(all_results)
    # the time_stats column is a nested dict, so expand it recursively into prefixed columns
    df = pd.json_normalize(df["time_stats"]).add_prefix("time_stats.").join(df.drop(columns=["time_stats"]))

    # write results to a csv file
    df.to_csv(f"{output_dir}/results.csv")


if __name__ == "__main__":
    cudnn.benchmark = True
    run_benchmark()
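
To make the json_normalize step above concrete, here is a minimal sketch with made-up rows; the field names inside time_stats are placeholders, not the actual TimerStatsStore output:

import pandas as pd

# Hypothetical rows shaped like the output of profile(): a nested time_stats dict
# plus flat metadata columns.
rows = [
    {"time_stats": {"all_reduce": {"mean": 0.42, "std": 0.03}}, "rank": 0, "num_tokens": 128},
    {"time_stats": {"all_reduce": {"mean": 0.57, "std": 0.02}}, "rank": 0, "num_tokens": 256},
]
df = pd.DataFrame(rows)
# Expand the nested dict into prefixed columns and re-attach the flat columns,
# mirroring the transformation in run_benchmark above.
df = (
    pd.json_normalize(df["time_stats"])
    .add_prefix("time_stats.")
    .join(df.drop(columns=["time_stats"]))
)
print(list(df.columns))
# ['time_stats.all_reduce.mean', 'time_stats.all_reduce.std', 'rank', 'num_tokens']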
@@ -0,0 +1,102 @@
import torch

from benchmark.cuda_timer import CudaTimer
from benchmark.timer_stats_store import TimerStatsStore


WARMUP_STEPS = 5
GRAPH_STEPS = 10
ACTIVE_STEPS = 10


class GraphAllGather:

    def __init__(
        self,
        vocab_size: int,
        context_length: int,
        num_workers: int,
        dtype: torch.dtype = torch.float16,
        disable_graph: bool = False,
    ) -> None:
        self.vocab_size = vocab_size
        self.disable_graph = disable_graph

        self.buffer = torch.empty(
            size=(context_length, vocab_size // num_workers),
            dtype=dtype,
            device='cuda',
        )
        self.gathered_list = [
            torch.empty_like(self.buffer)
            for _ in range(num_workers)
        ]
        if not self.disable_graph:
            self.graph = self._build_graph()

    def _build_graph(self) -> torch.cuda.CUDAGraph:
        # Warm up.
        torch.distributed.all_gather(self.gathered_list, self.buffer)
        torch.cuda.synchronize()

        # Build graph.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            torch.distributed.all_gather(self.gathered_list, self.buffer)
        torch.cuda.synchronize()
        return graph

    def launch(self) -> None:
        # Replay the captured all-gather, or run it eagerly when graphs are disabled.
        if self.disable_graph:
            torch.distributed.all_gather(self.gathered_list, self.buffer)
        else:
            self.graph.replay()


class AllGatherWrapper:
    def __init__(self, rank, num_workers, comm_id, vocab_size, num_tokens):
        self._rank = rank
        self._num_workers = num_workers
        self._vocab_size = vocab_size
        self._num_tokens = num_tokens
        self._comm_id = comm_id

        self._init_communication(comm_id)
        self._graph_all_reduce = GraphAllGather(vocab_size, num_tokens, num_workers, disable_graph=True)

    def _init_communication(self, comm_id):
        print(f"Initializing process group with comm id: {comm_id} for rank: {self._rank} with world size: {self._num_workers}")
        if torch.distributed.is_initialized():
            return

        torch.distributed.init_process_group(
            backend="nccl",
            rank=self._rank,
            world_size=self._num_workers,
            init_method=f"file:///tmp/sing_bm_{comm_id}",
        )

    def _run_all_reduce(self):
        torch.cuda.synchronize()
        with CudaTimer("all_reduce"):
            self._graph_all_reduce.launch()

        torch.cuda.synchronize()

    def profile(self):
        for _ in range(WARMUP_STEPS):
            self._run_all_reduce()

        TimerStatsStore.clear_stats()

        for _ in range(ACTIVE_STEPS):
            self._run_all_reduce()

        return {
            "time_stats": TimerStatsStore.get_stats(),
            "rank": self._rank,
            "num_workers": self._num_workers,
            "vocab_size": self._vocab_size,
            "num_tokens": self._num_tokens,
        }
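
CudaTimer and TimerStatsStore are imported from other modules in the benchmark package and are not part of this diff. As a rough illustration of how such an event-based timer can be built, here is a minimal sketch; the class name, storage format, and behavior are assumptions, not the repository's actual implementation:

import torch

class SimpleCudaTimer:
    """Hypothetical stand-in for benchmark.cuda_timer.CudaTimer."""

    # Shared dict mapping timer name -> list of elapsed times in milliseconds.
    stats = {}

    def __init__(self, name):
        self.name = name
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        self.start_event.record()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_event.record()
        # Block until the work enclosed by the two events has finished on the GPU.
        torch.cuda.synchronize()
        self.stats.setdefault(self.name, []).append(
            self.start_event.elapsed_time(self.end_event)
        )

It would be used the same way as the imported timer above: with SimpleCudaTimer("all_reduce"): wrapping the collective launch.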