12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | # |
15 | | -import os |
| 15 | + |
16 | 16 | from typing import Callable, Union |
17 | 17 |
18 | 18 | import torch |
| 19 | +from torch._inductor.compile_fx import cudagraphify_impl |
| 20 | +from torch._inductor.utils import dynamo_utils |
| 21 | +from torch._subclasses import FakeTensor |
19 | 22 |
20 | 23 |
21 | | -def cuda_graphs_wrapper( |
22 | | - model: Callable, |
23 | | - inputs: Union[list[torch.Tensor], tuple[torch.Tensor]], |
24 | | - copy_outputs: bool = False, |
25 | | - pool: (int, int) = torch.cuda.graph_pool_handle(), |
26 | | -): |
27 | | - """ |
28 | | - From torchdynamo |
29 | | - """ |
30 | | - assert isinstance(inputs, (list, tuple)), f"inputs is of type {type(inputs)} instead of list" |
31 | | - |
32 | | - # required warmup, not just for perf but for correctness |
33 | | - torch.cuda.synchronize() |
34 | | - stream = torch.cuda.Stream() |
35 | | - stream.wait_stream(torch.cuda.current_stream()) |
36 | | - with torch.cuda.stream(stream): |
37 | | - # 2 rounds, 1 to build the model (triton kernels, casting, etc.), |
38 | | - # and 1 for warmup |
39 | | - for _ in range(2): |
40 | | - model(*inputs) |
41 | | - stream.synchronize() |
42 | | - torch.cuda.current_stream().wait_stream(stream) |
43 | | - torch.cuda.synchronize() |
44 | | - # copy inputs after executing the warmup in case it mutates them at the first iteration |
45 | | - static_inputs = [torch.zeros_like(x) for x in inputs] |
| 24 | +def cuda_graphs_wrapper(model: Callable, inputs: Union[list[torch.Tensor], tuple[torch.Tensor]]): |
| 25 | + assert isinstance(inputs, (list, tuple)) |
| 26 | + # if using fake tensors, defer cudagraphs until we get real inputs at runtime |
| 27 | + if not any(isinstance(inp, FakeTensor) for inp in inputs): |
| 28 | + model(*inputs) # additional warmup needed when input is mutated by some kernel |
| 29 | + f = cudagraphify_impl(lambda args: model(*args), inputs) |
| 30 | + return lambda args: f(list(args)) |
46 | 31 |
47 | | - # record |
48 | | - graph = torch.cuda.CUDAGraph() |
49 | | - with torch.cuda.graph(graph, stream=stream, pool=pool): |
50 | | - static_outputs = model(*static_inputs) |
51 | | - if not isinstance(static_outputs, (list, tuple)): |
52 | | - static_outputs = (static_outputs,) |
| 32 | + compiled_fn = None |
53 | 33 |
54 | 34 | def run(*new_inputs): |
55 | | - if "PYTEST_CURRENT_TEST" not in os.environ: # for benchmarks, we may want to avoid input copy overhead |
56 | | - assert isinstance(new_inputs, (list, tuple)), f"inputs is of type {type(new_inputs)} instead of list" |
57 | | - assert len(static_inputs) == len(new_inputs), f"{len(static_inputs)} == {len(new_inputs)}" |
58 | | - for dst, src in zip(static_inputs, new_inputs): |
59 | | - dst.copy_(src) # cuda graph can only read data from the same address |
60 | | - graph.replay() |
61 | | - if copy_outputs: |
62 | | - return [x.clone() for x in static_outputs] |
63 | | - else: |
64 | | - return static_outputs |
| 35 | + nonlocal compiled_fn |
| 36 | + if compiled_fn is None: |
| 37 | + with dynamo_utils.preserve_rng_state(): |
| 38 | + model(*new_inputs) # additional warmup needed when input is mutated by some kernel |
| 39 | + f = cudagraphify_impl(lambda args: model(*args), new_inputs) |
| 40 | + |
| 41 | + def compiled_fn(args): |
| 42 | + return f(list(args)) |
| 43 | + |
| 44 | + return compiled_fn(new_inputs) |
65 | 45 |
66 | 46 | return run |
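
For context, a minimal usage sketch of the rewritten wrapper. This is hypothetical and not part of the diff: `net`, the shapes, and the availability of a CUDA device are illustrative assumptions, and `cuda_graphs_wrapper` is assumed importable from the module changed above.

```python
import torch

# `net` is a stand-in for whatever callable the wrapper receives.
net = torch.nn.Linear(64, 64).cuda().eval()
inputs = [torch.randn(8, 64, device="cuda")]

# Real (non-fake) tensors: warmup and graph capture happen inside the
# wrapper, and the returned callable takes the input list as one argument.
run = cuda_graphs_wrapper(net, inputs)
out = run(inputs)  # replays the captured CUDA graph

# New data must keep the same shapes/dtypes; it is copied into the graph's
# static input buffers before each replay.
out = run([torch.randn(8, 64, device="cuda")])
```

Two caveats under these assumptions: the returned tensors are the graph's static output buffers, so they are overwritten by the next replay and should be cloned if results must outlive a call; and when the wrapper is handed `FakeTensor` inputs (e.g. at trace time), capture is deferred to the first invocation of `run` with real tensors, as the new branch in the diff shows.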