
Commit

Merge branch 'mcore_cuda_graph' into 'main'
Enable CUDA graphs for MCore inference

See merge request ADLR/megatron-lm!2531
jaredcasper committed Jan 28, 2025
2 parents 3d1554d + 684facb commit d5069b8
Showing 7 changed files with 123 additions and 43 deletions.
123 changes: 84 additions & 39 deletions examples/inference/gpt/gpt_batch_inference.py
@@ -1,18 +1,30 @@
import os
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
from pretrain_gpt import model_provider
import torch
import sys
import time
import tqdm
import warnings
from argparse import Namespace
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
GPTInferenceWrapper,
)
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.text_generation_controllers.text_generation_controller import TextGenerationController
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)
from megatron.core.transformer.module import MegatronModule
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir, os.path.pardir)))
from megatron.legacy.model.module import Float16Module

sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
)

from megatron.training import get_args
from megatron.training import get_tokenizer
@@ -22,35 +34,47 @@
from megatron.training import get_model
from typing import List


def add_text_generate_args(parser):
"""Text generation arguments."""
group = parser.add_argument_group(title='text generation')

group.add_argument("--temperature", type=float, default=1.0,
help='Sampling temperature.')
group.add_argument("--top_k", type=int, default=1,
help='Top k sampling.')
group.add_argument("--top_p", type=float, default=0.0,
help='Top p sampling.')
group.add_argument("--return-log-probs", action='store_true', default=False,
help='Return the log probabilities of the final output tokens')
group.add_argument("--num-tokens-to-generate", type=int, default=30,
help='Number of tokens to generate for each prompt')
group.add_argument("--prompts", metavar='N', type=str, nargs='+',
help='Input prompts with each prompt within quotes and separated by space')
group.add_argument("--max-batch-size", type=int, default=1,
help='Max number of prompts to process at once')
group.add_argument("--temperature", type=float, default=1.0, help='Sampling temperature.')
group.add_argument("--top_k", type=int, default=1, help='Top k sampling.')
group.add_argument("--top_p", type=float, default=0.0, help='Top p sampling.')
group.add_argument(
"--return-log-probs",
action='store_true',
default=False,
help='Return the log probabilities of the final output tokens',
)
group.add_argument(
"--num-tokens-to-generate",
type=int,
default=30,
help='Number of tokens to generate for each prompt',
)
group.add_argument(
"--prompts",
metavar='N',
type=str,
nargs='+',
help='Input prompts with each prompt within quotes and separated by space',
)
group.add_argument(
"--max-batch-size", type=int, default=1, help='Max number of prompts to process at once'
)
return parser


def get_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngine:
"""Utility to get the relevant backend for running inference
This function will automatically choose the TRTLLMBackend when possible, and otherwise revert to the MCore backend if the user does not specify any backend. The TRT LLM backend is not implemented yet.
Args:
args (Namespace): The user arguments parsed from command line
model (MegatronModule): The Megatron model.
Returns:
AbstractBackend: The chosen backend
@@ -62,23 +86,32 @@ def get_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngi
inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold,
fp32_residual_connection=args.fp32_residual_connection,
params_dtype=args.params_dtype,
padded_vocab_size=args.padded_vocab_size
padded_vocab_size=args.padded_vocab_size,
)

inference_wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config)
text_generation_controller = TextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer)
return MCoreEngine(text_generation_controller=text_generation_controller, max_batch_size=args.max_batch_size)

text_generation_controller = TextGenerationController(
inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer
)
return MCoreEngine(
text_generation_controller=text_generation_controller, max_batch_size=args.max_batch_size
)


def main():
"""Main program."""

# Note: The default args passed here can be overwritten by using appropriate params (check arguments.py file)
# Micro batch size does not need to be set by the user. (It is calculated from the inference-batch-times-seqlen-threshold argument.)
initialize_megatron(extra_args_provider=add_text_generate_args,
args_defaults={'no_load_rng': True,
'no_load_optim': True,
'micro_batch_size': 1,
'exit_on_missing_checkpoint': True})
initialize_megatron(
extra_args_provider=add_text_generate_args,
args_defaults={
'no_load_rng': True,
'no_load_optim': True,
'micro_batch_size': 1,
'exit_on_missing_checkpoint': True,
},
)

# Set up model and load checkpoint
model = get_model(model_provider, wrap_with_ddp=False)
@@ -90,26 +123,38 @@ def main():
inference_engine = get_inference_engine(args, model)

sampling_params = SamplingParams(
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
return_log_probs=args.return_log_probs,
num_tokens_to_generate=args.num_tokens_to_generate)
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
return_log_probs=args.return_log_probs,
num_tokens_to_generate=args.num_tokens_to_generate,
)

if args.enable_cuda_graph:
print(f"Running warmup for CUDA graphs...")
inference_engine.generate(
prompts=args.prompts, sampling_params=sampling_params
)

start_time = time.perf_counter()
results: List[InferenceRequest] = inference_engine.generate(
prompts=args.prompts, sampling_params=sampling_params
)

end_time = time.perf_counter()
latency = end_time - start_time

if torch.distributed.get_rank() == 0:
for idx, result in enumerate(results):
print(f' \n------------- RESULT FOR PROMPT {idx} --------------- ')
result = {
'id': result.request_id,
'input_prompt': result.prompt,
'generated_text': result.generated_text,
'generated_tokens': result.generated_tokens
}
'generated_tokens': result.generated_tokens,
'latency': latency,
}
print(result)


if __name__ == "__main__":
main()
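The change to main() above boils down to a warmup pass followed by a timed pass: with --enable-cuda-graph, the first generate() call is where the CUDA graphs are recorded and instantiated, so it is kept out of the latency measurement. Below is a minimal sketch of that pattern as a standalone helper; it assumes an already-built MCoreEngine (inference_engine) and the same args/sampling_params objects as in the script, and timed_generate is a hypothetical name, not part of the repository.

import time

def timed_generate(inference_engine, args, sampling_params):
    # With CUDA graphs enabled, the first generate() call triggers graph capture
    # (create_cudagraphs() runs inside the text generation controller), so run it
    # once as warmup before timing anything.
    if args.enable_cuda_graph:
        inference_engine.generate(prompts=args.prompts, sampling_params=sampling_params)

    # Time only the steady-state call, which can replay the captured graphs.
    start = time.perf_counter()
    results = inference_engine.generate(prompts=args.prompts, sampling_params=sampling_params)
    latency = time.perf_counter() - start
    return results, latency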
@@ -135,6 +135,7 @@ def forward_pass_with_pipeline_parallel_small_input_batch(
tokens = inference_input["tokens"]
position_ids = inference_input["position_ids"]
attention_mask = inference_input["attention_mask"]

batch_size, seq_len = tokens.shape
recv_buffer = None
if not parallel_state.is_pipeline_first_stage():
@@ -237,7 +238,7 @@ def run_one_forward_step(self, inference_input: Dict[str, Any]) -> torch.Tensor:
Appropriate utility is called for the forward pass depending on the type of model parallelism used
Args:
inference_input (List): A list containing the inputs for the GPT model [tokens, position ids, attention mask]
inference_input (Dict[str, Any]): A dict containing the inputs for the GPT model [tokens, position ids, attention mask]
Returns:
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]. The logits are returned only in the last pipeline stage for PP models.
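The docstring update above reflects run_one_forward_step now taking a dict rather than a list. A minimal sketch of what that dict looks like, based on the keys read at the top of forward_pass_with_pipeline_parallel_small_input_batch; the tensor shapes and values here are illustrative only.

import torch

batch_size, seq_len = 2, 16
inference_input = {
    # Token ids for the batch, shape [batch_size, seq_len]
    "tokens": torch.zeros(batch_size, seq_len, dtype=torch.long),
    # Position ids, shape [batch_size, seq_len]
    "position_ids": torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1),
    # May be None, e.g. when the controller drops the mask for CUDA-graph decode steps
    "attention_mask": None,
}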
@@ -12,6 +12,7 @@
AbstractModelInferenceWrapper,
)
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.transformer.cuda_graphs import create_cudagraphs


class TextGenerationController:
@@ -329,6 +330,14 @@ def generate_all_output_tokens_static_batch(
batch_size, device=torch.cuda.current_device()
).cuda()

# Check whether CUDA graphs are enabled
if hasattr(self.inference_wrapped_model.model, "module"): # if model is Float16Module
enable_cuda_graph = self.inference_wrapped_model.model.module.config.enable_cuda_graph
else:
enable_cuda_graph = self.inference_wrapped_model.model.config.enable_cuda_graph

use_attention_mask = True

with torch.no_grad():

self.inference_wrapped_model.prep_model_for_inference(
@@ -349,11 +358,21 @@
)
)

if (
not use_attention_mask
and "attention_mask" in inference_input_for_context_window
):
inference_input_for_context_window["attention_mask"] = None

# Returns the final logits of shape [batch_size, context_length, vocab_size]
# Note: This is returned in all TP ranks or last PP stage in PP models
logits = self.inference_wrapped_model.run_one_forward_step(
inference_input_for_context_window
)

if enable_cuda_graph:
create_cudagraphs()

if self.model_is_pipeline_parallel:
context_length = context_end_position - context_start_position
logits = broadcast_from_last_pipeline_stage(
@@ -409,6 +428,10 @@ def generate_all_output_tokens_static_batch(
if all_prompts_done:
break

# Disable attention mask for CUDA graphs (decode only)
if use_attention_mask and enable_cuda_graph and torch.all(generation_started):
use_attention_mask = False

# Include all the generated tokens
batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)]
if sampling_params.return_log_probs:
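Two CUDA-graph-specific behaviors are threaded through the controller above: the enable_cuda_graph flag is read off the wrapped model's config (unwrapping Float16Module when present), and once every prompt in the batch has started generating, the attention mask is dropped for the remaining decode-only steps. A minimal sketch of the config check as a standalone helper; the helper name is hypothetical and model stands for inference_wrapped_model.model.

def cuda_graphs_enabled(model) -> bool:
    # Float16Module wraps the underlying model in a .module attribute,
    # so unwrap it before reading the TransformerConfig flag.
    if hasattr(model, "module"):
        return model.module.config.enable_cuda_graph
    return model.config.enable_cuda_graph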
6 changes: 5 additions & 1 deletion megatron/core/models/gpt/gpt_model.py
@@ -243,7 +243,11 @@ def forward(
packed_seq=packed_seq_params is not None
and packed_seq_params.qkv_format == 'thd',
)
if (self.config.enable_cuda_graph or self.config.flash_decode) and inference_params:
if (
(self.config.enable_cuda_graph or self.config.flash_decode)
and rotary_pos_cos is not None
and inference_params
):
sequence_len_offset = torch.tensor(
[inference_params.sequence_len_offset] * inference_params.current_batch_size,
dtype=torch.int32,
6 changes: 4 additions & 2 deletions megatron/core/transformer/cuda_graphs.py
@@ -110,7 +110,7 @@ def create_cudagraphs(cls):
return

# Otherwise, create all the recorded cudagraphs.
logging.getLogger(__name__).info(f"Creating {len(cls.cudagraph_record)} cudagraphs")
logging.getLogger(__name__).info(f"Creating {len(cls.cudagraph_record)} CUDA graphs")

has_te_modules = False
for g in cls.cudagraph_record:
@@ -371,11 +371,13 @@ def __init__(self, base_module, position):
self.backward_retain_grad = False
self.fp8_enabled = False
self.deallocate_pipeline_outputs = False
self.num_warmup_steps = 2
if isinstance(self.base_module.config, TransformerConfig):
self.fuse_wgrad_accumulation = self.base_module.config.gradient_accumulation_fusion
self.backward_retain_grad = self.base_module.config.cuda_graph_retain_backward_graph
self.fp8_enabled = self.base_module.config.fp8 is not None
self.deallocate_pipeline_outputs = self.base_module.config.deallocate_pipeline_outputs
self.num_warmup_steps = self.base_module.config.cuda_graph_warmup_steps

if self.fp8_enabled:
self.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
@@ -447,7 +449,7 @@ def create_fwd_graph(self, mempool, args, kwargs, clone_inputs=True):
self.fwd_graph.register_generator_state(state)

# warmup again, as graph capture mode may execute a different codepath
for _ in range(2):
for _ in range(self.num_warmup_steps):
with self.get_fp8_context():
outputs = self.base_module.forward(
*self.fwd_graph_input_args, **self.fwd_graph_input_kwargs
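The hard-coded warmup count replaced above follows the standard PyTorch capture recipe: run the forward a few times first so lazy allocations and alternate codepaths settle, then record a single invocation. A generic sketch of that recipe using the plain torch.cuda API, independent of the Megatron classes in this file; function and variable names are illustrative.

import torch

def capture_graph(fn, static_input, num_warmup_steps=2):
    # Warm up on a side stream so the warmup work is not recorded into the graph.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        for _ in range(num_warmup_steps):
            static_output = fn(static_input)
    torch.cuda.current_stream().wait_stream(side_stream)

    # Record one invocation; graph.replay() later reruns it on the same static buffers.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = fn(static_input)
    return graph, static_output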
3 changes: 3 additions & 0 deletions megatron/core/transformer/transformer_config.py
@@ -379,6 +379,9 @@ class TransformerConfig(ModelParallelConfig):
external_cuda_graph: bool = False
"""When set to true, TransformerLayer layers are swapped with user provided CUDA graphs."""

cuda_graph_warmup_steps: int = 3
"""Number of warmup steps for CUDA graphs"""

config_logger_dir: str = ""
"""When non-empty, dumps entry-point configs to config_logger_dir"""

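The cuda_graph_warmup_steps field added above sits next to enable_cuda_graph on TransformerConfig, so both can be set directly when building a config programmatically. A hypothetical snippet with placeholder model sizes; only the two CUDA-graph fields are the point here.

from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=2,                  # placeholder model sizes, not a recommendation
    hidden_size=128,
    num_attention_heads=4,
    enable_cuda_graph=True,        # capture and replay transformer layers at inference
    cuda_graph_warmup_steps=3,     # forward passes run before each graph is recorded
)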
2 changes: 2 additions & 0 deletions megatron/training/arguments.py
@@ -899,6 +899,8 @@ def _add_inference_args(parser):
help='Whether to use the flash decoding kernel.')
group.add_argument('--enable-cuda-graph', default=False, action="store_true",
help='Use CUDA graph capture and replay.')
group.add_argument("--cuda-graph-warmup-steps", type=int, default=2,
help="Number of CUDA graph warmup steps")
group.add_argument('--inference-max-seq-length', type=int, default=2560,
help='Maximum sequence length allocated for prefill during inference.',
dest='inference_max_seq_length')
