diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index c98fc8c07e..87297c9958 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -232,14 +232,43 @@ def _make_graphed_callables( set(warmup_func_idx) ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique." - # Run warmup. + # Filter the TE modules that cudagraph can access. + visited_te_modules = set() + + def hook_fn(module, input, output): + if ( + isinstance(module, TransformerEngineBaseModule) + and FP8GlobalStateManager.is_fp8_enabled() + ): + visited_te_modules.add(module) + + # Filter the weights without gradients in backward. These weights are not in the computation graph so can be removed from cudagraph inputs. + no_grad_weights = None + + def get_no_grads(static_input_surface, grad_inputs, func_idx): + grad_index = 0 + none_grads = [] + for i in range(len(static_input_surface)): + if static_input_surface[i].requires_grad: + if grad_inputs[grad_index] is None and i >= len(flatten_sample_args[func_idx]): + none_grads.append(i) + grad_index += 1 + return set(none_grads) + + # Run warmup and do the above two filtering. with torch.cuda.stream(torch.cuda.Stream()): for func_idx, func in zip(warmup_func_idx, warmup_func): args = sample_args[func_idx] kwargs = sample_kwargs[func_idx] static_input_surface = per_callable_static_input_surfaces[func_idx] for _ in range(num_warmup_iters): + hooks = [] + for module in func.modules(): + hook = module.register_forward_hook(hook_fn) + hooks.append(hook) outputs, _ = _tree_flatten(func(*args, **kwargs)) + for hook in hooks: + hook.remove() grad_inputs = torch.autograd.grad( outputs=tuple(o for o in outputs if o.requires_grad), inputs=tuple(i for i in static_input_surface if i.requires_grad), @@ -247,10 +276,24 @@ def _make_graphed_callables( only_inputs=True, allow_unused=allow_unused_input, ) + if no_grad_weights is None: + no_grad_weights = get_no_grads(static_input_surface, grad_inputs, func_idx) del outputs, grad_inputs for module in func.modules(): if hasattr(module, "is_first_microbatch"): module.is_first_microbatch = True + if len(no_grad_weights) > 0: + per_callable_static_input_surfaces[func_idx] = tuple( + inp + for i, inp in enumerate(per_callable_static_input_surfaces[func_idx]) + if i not in no_grad_weights + ) + per_callable_module_params[func_idx] = tuple( + param + for i, param in enumerate(per_callable_module_params[func_idx]) + if i + len(flatten_sample_args[func_idx]) not in no_grad_weights + ) + no_grad_weights = None torch.cuda.synchronize() # All captures here share a mempool. To avoid replays corrupting each other's memory, @@ -495,6 +538,17 @@ def new_fwd(*user_args, **user_kwargs): isinstance(m, TransformerEngineBaseModule) and FP8GlobalStateManager.is_fp8_enabled() ): + if m not in visited_te_modules: + # Only Set the FP8 meta for the modules included by forward + continue + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if ( + not fp8_recipe.fp8_mha + and not fp8_recipe.fp8_dpa + and hasattr(m, "attention_dropout") + and m.deterministic + ): + continue m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(