Commit e604c11

feat: update torch and modify cuda graphs (#220)
1 parent 07cf0c6 commit e604c11

File tree: 8 files changed, +35 −68 lines changed

Dockerfile

Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@ RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.8 1 &&
 
 RUN python3.9 -m ensurepip --default-pip --upgrade
 
-RUN pip install --pre torch==1.14.0.dev20221029+cu117 --extra-index-url https://download.pytorch.org/whl/nightly/cu117
+RUN pip install --pre torch==2.0.0.dev20221214+cu117 --extra-index-url https://download.pytorch.org/whl/nightly/cu117
 
 
 WORKDIR /syncback

requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 triton==2.0.0.dev20221202
-torch==1.14.0.dev20221029+cu117
+torch==2.0.0.dev20221214+cu117
 pytest
 tabulate
 termcolor

src/kernl/implementations/cuda_graph.py

Lines changed: 22 additions & 42 deletions

@@ -12,55 +12,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import os
+
 from typing import Callable, Union
 
 import torch
+from torch._inductor.compile_fx import cudagraphify_impl
+from torch._inductor.utils import dynamo_utils
+from torch._subclasses import FakeTensor
 
 
-def cuda_graphs_wrapper(
-    model: Callable,
-    inputs: Union[list[torch.Tensor], tuple[torch.Tensor]],
-    copy_outputs: bool = False,
-    pool: (int, int) = torch.cuda.graph_pool_handle(),
-):
-    """
-    From torchdynamo
-    """
-    assert isinstance(inputs, (list, tuple)), f"inputs is of type {type(inputs)} instead of list"
-
-    # required warmup, not just for perf but for correctness
-    torch.cuda.synchronize()
-    stream = torch.cuda.Stream()
-    stream.wait_stream(torch.cuda.current_stream())
-    with torch.cuda.stream(stream):
-        # 2 rounds, 1 to build the model (triton kernels, casting, etc.),
-        # and 1 for warmup
-        for _ in range(2):
-            model(*inputs)
-        stream.synchronize()
-    torch.cuda.current_stream().wait_stream(stream)
-    torch.cuda.synchronize()
-    # copy inputs after executing the warmup in case it mutates them at the first iteration
-    static_inputs = [torch.zeros_like(x) for x in inputs]
+def cuda_graphs_wrapper(model: Callable, inputs: Union[list[torch.Tensor], tuple[torch.Tensor]]):
+    assert isinstance(inputs, (list, tuple))
+    # if using fake tensors, defer cudagraphs until we get real inputs at runtime
+    if not any(isinstance(inp, FakeTensor) for inp in inputs):
+        model(*inputs)  # additional warmup needed when input is mutated by some kernel
+        f = cudagraphify_impl(lambda args: model(*args), inputs)
+        return lambda args: f(list(args))
 
-    # record
-    graph = torch.cuda.CUDAGraph()
-    with torch.cuda.graph(graph, stream=stream, pool=pool):
-        static_outputs = model(*static_inputs)
-    if not isinstance(static_outputs, (list, tuple)):
-        static_outputs = (static_outputs,)
+    compiled_fn = None
 
     def run(*new_inputs):
-        if "PYTEST_CURRENT_TEST" not in os.environ:  # for benchmarks, we may want to avoid input copy overhead
-            assert isinstance(new_inputs, (list, tuple)), f"inputs is of type {type(new_inputs)} instead of list"
-            assert len(static_inputs) == len(new_inputs), f"{len(static_inputs)} == {len(new_inputs)}"
-            for dst, src in zip(static_inputs, new_inputs):
-                dst.copy_(src)  # cuda graph can only read data from the same address
-        graph.replay()
-        if copy_outputs:
-            return [x.clone() for x in static_outputs]
-        else:
-            return static_outputs
+        nonlocal compiled_fn
+        if compiled_fn is None:
+            with dynamo_utils.preserve_rng_state():
+                model(*new_inputs)  # additional warmup needed when input is mutated by some kernel
+                f = cudagraphify_impl(lambda args: model(*args), new_inputs)
+
+                def compiled_fn(args):
+                    return f(list(args))
+
+        return compiled_fn(new_inputs)
 
     return run
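
The rewritten wrapper still returns outputs wrapped in a tuple/list, but changes the input convention: callers now pass inputs as a single list instead of unpacking them positionally, which is why the tests below move from run(tensor) to run([tensor]). A minimal usage sketch; the add_one toy model and tensor shape are illustrative, not part of the commit:

import torch

from kernl.implementations.cuda_graph import cuda_graphs_wrapper


def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1


x = torch.ones(8, device="cuda")
# with real (non-fake) inputs, warmup and graph capture happen here via cudagraphify_impl
run = cuda_graphs_wrapper(add_one, [x])
# inputs go in as one list; the output comes back wrapped in a tuple/list
out = run([x])[0]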

src/kernl/model_optimization.py

Lines changed: 1 addition & 5 deletions

@@ -22,15 +22,11 @@
 from kernl.optimizer.dynamo_backend import dynamo_backend_ofi
 
 
-# single shared pool by default
-_pool: (int, int) = torch.cuda.graph_pool_handle()
-
-
 # needs to be generated once to be reused several times, like encoder/decoder models
 # https://github.com/pytorch/torchdynamo/issues/1816
 def _compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
     dynamo_backend_ofi(gm)
-    return cuda_graphs_wrapper(gm, example_inputs, pool=_pool)
+    return cuda_graphs_wrapper(gm, example_inputs)
 
 
 def optimize_model(original_model: PreTrainedModel) -> None:
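
_compiler is the backend that optimize_model hands to TorchDynamo, so dropping the shared pool handle leaves the public entry point a single call. A sketch of end-to-end usage, assuming kernl's usual inference setup; the model name and input shapes are illustrative:

import torch
from transformers import AutoModel

from kernl.model_optimization import optimize_model

model = AutoModel.from_pretrained("bert-base-uncased").eval().cuda()
optimize_model(model)  # installs the dynamo backend built on _compiler

inputs = {
    "input_ids": torch.randint(0, 25_000, (1, 128), device="cuda"),
    "attention_mask": torch.ones((1, 128), dtype=torch.int64, device="cuda"),
}
with torch.inference_mode(), torch.cuda.amp.autocast():
    outputs = model(**inputs)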

test/models/bert.py

Lines changed: 1 addition & 3 deletions

@@ -17,7 +17,6 @@
 
 import torch
 import torch._dynamo as torchdynamo
-from torch._dynamo.optimizations import BACKENDS
 from transformers import AutoModel
 
 from kernl.implementations.cuda_graph import cuda_graphs_wrapper

@@ -35,8 +34,7 @@ def get_model_baseline(base):
 
 def get_model_dynamo_cuda_graphs(base):
     def compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
-        compiled = BACKENDS["cudagraphs"](gm, example_inputs)
-        return compiled
+        return cuda_graphs_wrapper(gm, example_inputs)
 
     @torchdynamo.optimize(compiler)
     def run(*args, **kwargs):

test/test_attention.py

Lines changed: 3 additions & 4 deletions

@@ -191,7 +191,6 @@ def test_benchmark_skinny_cross_attention(benchmark, implementation, shape):
     v = torch.rand_like(k)
     sm_scale = 0.3
 
-    p = torch.cuda.graph_pool_handle()
     expected = attention_reference(
         q=q.float(),
         k=k.float(),

@@ -203,8 +202,8 @@ def test_benchmark_skinny_cross_attention(benchmark, implementation, shape):
     )
     output = torch.empty_like(q)
     fn = implementations_skinny_cross_attention[implementation](output, sm_scale)
-    r = cuda_graphs_wrapper(fn, [q, k, v], pool=p)
-    _ = r(q, k, v)[0]
-    result = benchmark(r)[0]
+    r = cuda_graphs_wrapper(fn, [q, k, v])
+    _ = r([q, k, v])[0]
+    result = benchmark(r, [q, k, v])[0]
 
     assert_all_close(a=expected, b=result.float(), atol=1e-2)

test/test_layer_norm.py

Lines changed: 4 additions & 4 deletions

@@ -65,9 +65,9 @@ def test_benchmark_layer_norm(benchmark, shape: int, dtype, cuda_graphs: bool, i
 
     fn = implementations_layer_norm[implementation](layer_weight, layer_bias, eps)
     if cuda_graphs:
-        run = cuda_graphs_wrapper(model=fn, inputs=[x], copy_outputs=False)
+        run = cuda_graphs_wrapper(model=fn, inputs=[x])
         # CUDA graphs wraps output in a tuple
-        fn = lambda tensor: run(tensor)[0]  # noqa: E731
+        fn = lambda tensor: run([tensor])[0]  # noqa: E731
 
     value = benchmark(fn, x)
     assert_all_close(value.float(), expected, atol=1e-1)

@@ -99,9 +99,9 @@ def test_benchmark_rms_norm(benchmark, shape: int, dtype, cuda_graphs: bool, imp
 
     fn = implementations_rms_norm[implementation](layer_weight, eps)
     if cuda_graphs:
-        run = cuda_graphs_wrapper(model=fn, inputs=[x], copy_outputs=False)
+        run = cuda_graphs_wrapper(model=fn, inputs=[x])
         # CUDA graphs wraps output in a tuple
-        fn = lambda tensor: run(tensor)[0]  # noqa: E731
+        fn = lambda tensor: run([tensor])[0]  # noqa: E731
 
     value = benchmark(fn, x)
     assert_all_close(value.float(), expected, atol=1e-1)

test/test_linear_layer.py

Lines changed: 2 additions & 8 deletions

@@ -24,11 +24,6 @@
 from kernl.implementations.linear_layer import linear_layer
 
 
-@pytest.fixture
-def cuda_graphs_pool() -> (int, int):
-    return torch.cuda.graph_pool_handle()
-
-
 def get_pytorch_activation(activation: str) -> Callable:
     if activation == "gelu":
         return torch.nn.functional.gelu

@@ -71,7 +66,6 @@ def test_benchmark(
     bias: bool,
     activation: str,
     contiguous: bool,
-    cuda_graphs_pool: (int, int),
 ):
     batch, M, N, K = shape
 

@@ -96,9 +90,9 @@ def test_benchmark(
 
     fn = implementations[implementation](layer_weight, layer_bias, activation)
     if cuda_graphs:
-        run = cuda_graphs_wrapper(model=fn, inputs=[x], pool=cuda_graphs_pool)
+        run = cuda_graphs_wrapper(model=fn, inputs=[x])
         # CUDA graphs wraps output in a tuple
-        fn = lambda tensor: run(tensor)[0]  # noqa: E731
+        fn = lambda tensor: run([tensor])[0]  # noqa: E731
 
     value = benchmark(fn, x)
 
