
Commit 8cb93ff

FP8 cuda graphs
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
1 parent: 29b0c9c


7 files changed: +589 -3 lines changed


docs/api/pytorch.rst

+2
@@ -41,4 +41,6 @@ pyTorch
 
 .. autoapifunction:: transformer_engine.pytorch.onnx_export
 
+.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables
+
 .. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
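For orientation, here is a minimal usage sketch of the newly documented function, pieced together from the test code added in this commit. The module choice, shapes, and dtype are illustrative (they mirror the defaults in tests/pytorch/test_cuda_graphs.py); the authoritative signature lives in transformer_engine/pytorch/graph.py.

import torch
import transformer_engine.pytorch as te

# A single TE module to capture; sizes mirror the test defaults (hdim=768, seq=2048, bs=2).
model = te.Linear(768, 768, device="cuda", params_dtype=torch.bfloat16)

# Static sample input used during warmup/graph capture (the test uses torch.ones for warmup data).
sample_input = torch.ones(2048, 2, 768, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Wrap the module so its forward/backward run as CUDA graphs; enabled=True turns on FP8, as in the test.
model = te.make_graphed_callables(model, (sample_input,), num_warmup_iters=3, enabled=True)

# A replayed training step, matching the loop in the test script.
data = torch.randn(2048, 2, 768, device="cuda", dtype=torch.bfloat16, requires_grad=True)
with te.fp8_autocast(enabled=True):
    out = model(data)
out.sum().backward()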

tests/pytorch/test_cuda_graphs.py

+168
@@ -0,0 +1,168 @@
"""Cuda graphs tests."""
import argparse

import torch
import transformer_engine.pytorch as te
import apex


def str_to_optimizer(optim):
    """Get optimizer."""
    if optim == "sgd":
        return torch.optim.SGD
    if optim == "adamw":
        return torch.optim.AdamW
    if optim == "fused_sgd":
        return apex.optimizers.FusedSGD
    return apex.optimizers.FusedAdam


def str_to_torch_dtype(dtype):
    """Get pytorch dtype."""
    if dtype == "bf16":
        return torch.bfloat16
    if dtype == "fp16":
        return torch.float16
    return torch.float32


def manual_seed(seed):
    """Set seed."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def generate_data(args, warmup=False, gen_labels=False):
    """Generate synthetic data."""
    dtype = str_to_torch_dtype(args.dtype)
    gen_func = torch.ones if warmup else torch.randn
    if args.module == "dpa":
        inputs = [gen_func(
            args.seq_length, args.bs, args.nheads,
            args.embed, device="cuda", requires_grad=True, dtype=dtype
        ) for _ in range(3)]
    else:
        inputs = [gen_func(args.seq_length, args.bs,
            args.hdim, device="cuda", requires_grad=True, dtype=dtype)]

    if not gen_labels:
        return inputs

    target = torch.randn(args.seq_length, args.bs, args.hdim, device="cuda", dtype=dtype)
    return inputs, target


def print_values(model, output):
    """Debug."""
    values = []
    for param in model.parameters():
        values.append(param.sum().item())
        if param.grad is not None:
            values.append(param.grad.sum().item())
    values.append(output.sum().item())
    print(values)


def parse_args():
    """Arguments."""
    parser = argparse.ArgumentParser(description="Args for testing CUDA graphs with TE layers.")
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--dtype', type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
    parser.add_argument('--optimizer', type=str, default="fused_adamw",
                        choices=["fused_adamw", "fused_sgd", "sgd", "adamw"])
    parser.add_argument('--num-layers', type=int, default=1)
    parser.add_argument('--module', default="linear",
                        choices=['linear', 'layernorm_linear', 'layernorm_mlp',
                                 'transformer', 'dpa', 'mha'])
    parser.add_argument('--fp8', action='store_true')
    parser.add_argument('--graph', action='store_true')
    parser.add_argument('--graph-mode', default="full", choices=['full', 'individual'])
    parser.add_argument('--num-warmup-iters', type=int, default=3)
    parser.add_argument('--steps', type=int, default=1)
    parser.add_argument('--hdim', type=int, default=768)
    parser.add_argument('--seq-length', type=int, default=2048)
    parser.add_argument('--bs', type=int, default=2)
    parser.add_argument('--nheads', type=int, default=12)
    parser.add_argument('--dropout', type=float, default=0.1)
    return parser.parse_args()


def train(args):
    """Train."""

    dtype = str_to_torch_dtype(args.dtype)

    # Create modules.
    if args.module == "transformer":
        modules = [te.TransformerLayer(
            args.hdim, args.hdim, args.nheads,
            hidden_dropout=args.dropout,
            attention_dropout=args.dropout,
            params_dtype=dtype,
        ) for _ in range(args.num_layers)]
    elif args.module == "layernorm_mlp":
        modules = [te.LayerNormMLP(
            args.hdim, args.hdim, params_dtype=dtype
        ) for _ in range(args.num_layers)]
    elif args.module == "layernorm_linear":
        modules = [te.LayerNormLinear(
            args.hdim, args.hdim, params_dtype=dtype
        ) for _ in range(args.num_layers)]
    elif args.module == "mha":
        modules = [te.MultiheadAttention(
            args.hdim, args.nheads, attention_dropout=args.dropout, params_dtype=dtype
        ) for _ in range(args.num_layers)]
    elif args.module == "dpa":
        assert args.hdim % args.nheads == 0, "Err."
        assert args.num_layers == 1, "Err."
        args.embed = args.hdim // args.nheads
        modules = [te.DotProductAttention(
            args.nheads, args.embed, attention_dropout=args.dropout
        ) for _ in range(args.num_layers)]
    else:
        modules = [te.Linear(
            args.hdim, args.hdim, device="cuda", params_dtype=dtype
        ) for _ in range(args.num_layers)]

    # Generate model and wrap API to return graphed version.
    if args.graph:
        # Graph entire module at once.
        if args.graph_mode == "full":
            model = modules[0] if args.module == "dpa" else torch.nn.Sequential(*modules)
            model = te.make_graphed_callables(
                model,
                generate_data(args, warmup=True),
                num_warmup_iters=args.num_warmup_iters,
                enabled=args.fp8)
        else:
            modules = [te.make_graphed_callables(
                module,
                generate_data(args, warmup=True),
                num_warmup_iters=args.num_warmup_iters,
                enabled=args.fp8) for module in modules]
            model = modules[0] if args.module == "dpa" else torch.nn.Sequential(*modules)
    else:
        model = modules[0] if args.module == "dpa" else torch.nn.Sequential(*modules)

    # Loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = str_to_optimizer(args.optimizer)(model.parameters(), lr=0.001)

    # Launch.
    for _ in range(args.steps):
        inputs, target = generate_data(args, gen_labels=True)
        with te.fp8_autocast(enabled=args.fp8):
            output = model(*inputs)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Debug.
    print_values(model, output)


if __name__ == "__main__":
    arguments = parse_args()
    manual_seed(arguments.seed)
    train(arguments)
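Illustrative command lines for the new test script, built from the flags defined in parse_args above (not part of the commit):

python tests/pytorch/test_cuda_graphs.py --module linear --fp8 --graph
python tests/pytorch/test_cuda_graphs.py --module transformer --num-layers 2 --dtype fp16 --graph --graph-mode individual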

tests/pytorch/test_numerics.py

+70 -2
@@ -20,8 +20,8 @@
     is_bf16_compatible,
 )
 from transformer_engine.pytorch import (
-    DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
-    MultiheadAttention, RMSNorm, TransformerLayer
+    DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, RMSNorm,
+    make_graphed_callables, MultiheadAttention, TransformerLayer
 )
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker

@@ -1199,6 +1199,7 @@ def test_gpt_fp8_parameters(dtype, bs, model):
     outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True)
     assert_all_equal(outputs, outputs_fp8_params)
 
+
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())

@@ -1275,3 +1276,70 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
     y_bshd = block_bshd(x_bshd)
 
     assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()])
+
+
+def _test_gpt_e2e_make_graphed_callables(block, forward_func, bs, dtype, config):
+    reset_rng_states()
+
+    inp = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True)
+
+    out = forward_func(inp)
+    loss = out.sum()
+    loss.backward()
+
+    grads = [inp.grad]
+    for p in block.parameters():
+        if p.requires_grad:
+            grads.append(p.grad)
+
+    return out, grads
+
+
+def get_forward_func(block):
+    def func(inp):
+        with fp8_autocast(enabled=fp8_available):
+            out = block(inp)
+        return out
+    return func
+
+
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("model", model_configs.keys())
+def test_gpt_make_graphed_callables(dtype, bs, model):
+    config = model_configs[model]
+
+    sigma = 0.023
+    init_method = init_method_normal(sigma)
+    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
+
+    block = (
+        TransformerLayer(
+            config.hidden_size,
+            4 * config.hidden_size,
+            config.num_attention_heads,
+            layernorm_epsilon=config.eps,
+            init_method=init_method,
+            output_layer_init_method=output_layer_init_method,
+            hidden_dropout=0.1,
+            attention_dropout=0.1,
+            kv_channels=config.embed,
+            apply_residual_connection_post_layernorm=False,
+            output_layernorm=False,
+            fuse_qkv_params=True,
+        )
+        .to(dtype=dtype)
+        .cuda()
+    )
+    graphed_block = copy.deepcopy(block)
+    graph_inp = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True)
+
+    forward_func = get_forward_func(block)
+    forward_func_graphed = make_graphed_callables(graphed_block, (graph_inp,), num_warmup_iters=3, enabled=fp8_available)
+
+    out, grads = _test_gpt_e2e_make_graphed_callables(block, forward_func, bs, dtype, config)
+    graphed_out, graphed_grads = _test_gpt_e2e_make_graphed_callables(graphed_block, forward_func_graphed, bs, dtype, config)
+
+    # Check that results match
+    assert_allclose(out, graphed_out, 1e-1)
+    # assert_allclose(grads, graphed_grads, 1e-1)
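To run just the new end-to-end comparison, a standard pytest filter works (shown for convenience; not part of the commit):

pytest tests/pytorch/test_numerics.py -k test_gpt_make_graphed_callables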

transformer_engine/pytorch/__init__.py

+1
@@ -14,6 +14,7 @@
 from .transformer import TransformerLayer
 from .fp8 import fp8_autocast
 from .fp8 import fp8_model_init
+from .graph import make_graphed_callables
 from .export import onnx_export
 from .distributed import checkpoint
 from .distributed import CudaRNGStatesTracker
