From 2d5ffb5b053b405dd53a92b9c9019ad7e3534ae4 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 29 Aug 2025 16:11:05 -0700 Subject: [PATCH 1/4] register sdpa variant --- .../lowering/passes/_aten_lowering_pass.py | 26 +- tools/llm/run_llm.py | 4 + tools/llm/torchtrt_ext/register_sdpa.py | 363 ++++++++++++++++++ tools/llm/torchtrt_ext/sdpa_converter.py | 62 ++- 4 files changed, 436 insertions(+), 19 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 1fc1b9b420..7f07154eb6 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Union import torch from torch_tensorrt.dynamo._settings import CompilationSettings @@ -53,20 +53,28 @@ def _aten_lowering_pass( *args: LoweringPassSignature, index: Optional[int] = None, + **kwargs: Any, ) -> Union[ LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature] ]: """Adds a lowering pass to the registry, at a specified index if desired If no index is specified, the lowering pass is inserted at the end of the list + + Additional keyword arguments can be passed to configure the lowering pass behavior. + These will be stored as metadata on the pass function. """ def add_lowering_pass( lowering_pass: LoweringPassSignature, ) -> LoweringPassSignature: + # Store additional parameters as metadata on the function + if kwargs: + lowering_pass._lowering_pass_config = kwargs + ATEN_POST_LOWERING_PASSES.add_pass_with_index(lowering_pass, index) logger.debug( - f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}" + f"Added lowering pass {lowering_pass} to list at index {index} with config {kwargs}, current passlist: {ATEN_POST_LOWERING_PASSES}" ) return lowering_pass @@ -81,7 +89,7 @@ def add_lowering_pass( f"aten_lowering_pass decorator called with invalid arguments {args} " "To specify an index to insert the pass, use the keyword 'index='" ) - # If no arguments are specified, the decorator was called with an index keyword + # If no arguments are specified, the decorator was called with keyword arguments else: return add_lowering_pass @@ -95,6 +103,18 @@ def _remove_lowering_pass(*, index: int) -> None: return +def get_lowering_pass_config(lowering_pass: LoweringPassSignature) -> dict[str, Any]: + """Get the configuration parameters for a lowering pass function + + Args: + lowering_pass: The lowering pass function + + Returns: + Dictionary containing the configuration parameters, or empty dict if none + """ + return getattr(lowering_pass, "_lowering_pass_config", {}) + + def post_lowering( gm: torch.fx.GraphModule, settings: CompilationSettings = CompilationSettings() ) -> torch.fx.GraphModule: diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index 7e50b515c2..075f3ace15 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -58,6 +58,10 @@ def get_model(args): .eval() .cuda() ) + if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None: + register_sdpa._SDPA_MAPPING[args.model](model_config=model.config) + else: + register_sdpa._SDPA_MAPPING["default"](model_config=model.config) if args.precision == "FP16": model = model.to(torch.float16) diff --git a/tools/llm/torchtrt_ext/register_sdpa.py 
b/tools/llm/torchtrt_ext/register_sdpa.py index 90a00a5798..038c9a47d7 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -124,3 +124,366 @@ def replace_variants_of_sdpa( "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention" ) return gm + + +""" +.. _run_llm: + +Running LLM inference with Torch-TensorRT +========================================================== + +This script illustrates Torch-TensorRT workflow with dynamo backend on popular LLM models. +""" + +import argparse +import copy +import os +import timeit +from contextlib import nullcontext + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +import torch +import torch_tensorrt +from torchtrt_ext import register_sdpa +from transformers import AutoModelForCausalLM, AutoTokenizer +from utils import ( + export_llm, + generate, + generate_with_static_cache, + record_stats, + time_generate, +) + +DEVICE = torch.device("cuda:0") + + +def get_model(args): + """ + Load and configure the language model for inference. + + This function loads a pre-trained causal language model using the specified + model name and configures it with the appropriate precision and settings + for inference. + + Args: + args: Parsed command line arguments containing: + - model (str): Name or path of the model to load + - precision (str): Precision to use ("FP16", "BF16", or "FP32") + + Returns: + torch.nn.Module: The loaded and configured model ready for inference, + moved to CUDA device with the specified precision + """ + with torch.no_grad(): + model = ( + AutoModelForCausalLM.from_pretrained( + args.model, + use_cache=False, + attn_implementation="sdpa", + ) + .eval() + .cuda() + ) + if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None: + register_sdpa._SDPA_MAPPING[args.model](model_config=model.config) + else: + register_sdpa._SDPA_MAPPING["default"](model_config=model.config) + + if args.precision == "FP16": + model = model.to(torch.float16) + elif args.precision == "BF16": + model = model.to(torch.bfloat16) + else: + model = model.to(torch.float32) + + return model + + +def compile_torchtrt(model, input_ids, args): + """ + Compile a PyTorch model to TensorRT using torch_tensorrt.dynamo.compile. + + This function exports the given model to a TorchScript representation and then + compiles it to TensorRT for optimized inference. The compilation process includes + precision-specific optimizations and various performance tuning parameters. 
+ + Args: + model (torch.nn.Module): The PyTorch model to compile + input_ids (torch.Tensor): Input token IDs tensor used for model export + args: Parsed command line arguments containing: + - num_tokens (int): Number of tokens to generate (used for max sequence length) + - precision (str): Precision to use ("FP16", "BF16", or "FP32") + - debug (bool): Whether to enable debug logging + - min_block_size (int): Minimum block size for TensorRT compilation + + Returns: + torch_tensorrt.dynamo.TorchTensorRTModule: The compiled TensorRT model ready + for optimized inference + """ + max_seq_len = input_ids.shape[1] + args.num_tokens + ep = export_llm(model, input_ids, max_seq_len=max_seq_len) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + # Set precision specific flags + use_fp32_acc = False + use_explicit_typing = False + if args.precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: + enabled_precisions = {torch.float32} + + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[input_ids, position_ids], + enabled_precisions=enabled_precisions, + # truncate_double=True, + use_explicit_typing=use_explicit_typing, + use_fp32_acc=use_fp32_acc, + device=DEVICE, + disable_tf32=True, + use_python_runtime=True, + debug=args.debug, + offload_module_to_cpu=True, + min_block_size=args.min_block_size, + ) + + return trt_model + + +def print_outputs(backend_name, gen_tokens, tokenizer): + """ + Print the generated tokens from the model. + """ + print(f"========= {backend_name} =========") + print( + f"{backend_name} model generated text: ", + tokenizer.decode(gen_tokens[0], skip_special_tokens=True), + ) + print("===================================") + + +def measure_perf(trt_model, input_signature, backend_name): + """ + Measure the performance of a TensorRT model by running it multiple times and + calculating the average time per iteration. + """ + total_time = 0 + iterations = 10 + + print("Running warmup iteration...") + # Warmup run + _ = trt_model(*input_signature) + torch.cuda.synchronize() + + print(f"Measuring performance over {iterations} iterations...") + for i in range(iterations): + start_time = timeit.default_timer() + _ = trt_model(*input_signature) + torch.cuda.synchronize() + end_time = timeit.default_timer() + iter_time = end_time - start_time + total_time += iter_time + + avg_time = total_time / iterations + print( + f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds" + ) + print( + f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second" + ) + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run inference on a model with random input values" + ) + arg_parser.add_argument( + "--model", + type=str, + default="meta-llama/Llama-3.2-1B-Instruct", + help="Name of LLM model", + ) + arg_parser.add_argument( + "--tokenizer", + type=str, + default="", + help="Name of LLM model tokenizer", + ) + arg_parser.add_argument( + "--prompt", type=str, default="What is parallel programming ?", help="Prompt" + ) + arg_parser.add_argument( + "--precision", + type=str, + default="FP16", + help="Precision to use in the model. Options: FP16, BF16, FP32", + ) + arg_parser.add_argument( + "--iterations", type=int, default=5, help="no. 
of iterations to run" + ) + arg_parser.add_argument( + "--min_block_size", type=int, default=1, help="no. of iterations to run" + ) + arg_parser.add_argument( + "--num_tokens", + type=int, + default=128, + help="no. of output tokens to be generated", + ) + arg_parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size used for benchmarking" + ) + arg_parser.add_argument( + "--isl", + type=int, + default=2048, + help="Input sequence length used for benchmarking", + ) + arg_parser.add_argument( + "--enable_pytorch_run", + action="store_true", + help="Enable pytorch run (default: False)", + ) + arg_parser.add_argument( + "--cache", + type=str, + default="", + help="Type of KV cache to use. Options: static_v1, static_v2", + ) + arg_parser.add_argument( + "--cudagraph", action="store_true", help="Enable cudagraphs (default: False)" + ) + arg_parser.add_argument( + "--debug", action="store_true", help="Enable debug (default: False)" + ) + arg_parser.add_argument( + "--benchmark", action="store_true", help="Enable benchmark (default: False)" + ) + + args = arg_parser.parse_args() + with torch.inference_mode(): + model = get_model(args) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model) + + # Prepare input for benchmarking or evaluation + if args.benchmark: + input_ids = torch.randint( + 1, 10000, (args.batch_size, args.isl), dtype=torch.int64 + ).to(model.device) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + else: + model_inputs = tokenizer(args.prompt, return_tensors="pt") + input_ids = model_inputs["input_ids"].to(DEVICE) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + + MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens + # Pyt + pyt_gen_tokens = None + pyt_timings = None + pyt_stats = None + + if args.enable_pytorch_run: + pyt_gen_tokens = generate( + model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id + ) + if args.benchmark: + pyt_timings = time_generate( + generate, + model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + pyt_stats = record_stats( + "PyTorch", + pyt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + if args.cache == "static_v1": + # This import is required to register static v1 KV cache transformations as lowering passes + import static_cache_v1 + if args.cache == "static_v2": + # This import is required to register static v2 KV cache transformations as lowering passes + import static_cache_v2 + + # Compile the model with Torch-TensorRT + trt_model = compile_torchtrt(model, input_ids, args) + + if args.cache == "static_v1" or args.cache == "static_v2": + if args.cudagraph: + # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
+ # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) + torch_tensorrt.runtime.set_cudagraphs_mode(True) + + trt_gen_tokens = generate_with_static_cache( + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + ) + + if args.benchmark: + trt_timings = time_generate( + generate_with_static_cache, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + else: + trt_gen_tokens = generate( + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + ) + if args.benchmark: + trt_timings = time_generate( + generate, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + + if args.benchmark: + trt_stats = record_stats( + "TensorRT", + trt_timings, + args.precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + + if not args.benchmark: + if args.enable_pytorch_run: + print_outputs("PyTorch", pyt_gen_tokens, tokenizer) + + print_outputs("TensorRT", trt_gen_tokens, tokenizer) + + if args.enable_pytorch_run: + print( + f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}" + ) + + if args.benchmark: + if args.enable_pytorch_run: + print("=========PyTorch PERFORMANCE============ \n") + print(pyt_stats) + print("===================== \n") + print("=========TensorRT PERFORMANCE============ \n") + print(trt_stats) diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index d05b0379a4..f7a7203f38 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -27,24 +27,50 @@ def tril( name: str, row: TRTTensor, col: TRTTensor, + sliding_window_size: Optional[int] = None, ) -> TRTTensor: row_arange_tensor = impl.arange.arange( ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 ) - row_reshape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1] - ) - col_arange_tensor = impl.arange.arange( ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1 ) - col_reshape_tensor = impl.shuffle.reshape( - ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col] + row_arange_tensor = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, name + "_unsqueeze_row", row_arange_tensor, -1 ) - - mask = impl.elementwise.ge( - ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor + col_arange_tensor = impl.unsqueeze.unsqueeze( + ctx, target, source_ir, name + "_unsqueeze_col", col_arange_tensor, 0 + ) + # sub will return the following mask tensor: + # [[0, -1, -2, -3], + # [1, 0, -1, -2], + # [2, 1, 0, -1], + # [3, 2, 1, 0]] + mask = impl.elementwise.sub( + ctx, target, source_ir, name + "_sub", row_arange_tensor, col_arange_tensor ) + ge_0_mask = impl.elementwise.ge(ctx, target, source_ir, name + "_ge_0", mask, 0.0) + if sliding_window_size is None: + # return the following lower triangular mask includes the main diagonal: + # 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False], + # 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, False], + # 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False], + # 3 ■ ■ ■ ■ ⬚ [ True, True, True, True, False], + # 4 ■ ■ ■ ■ ■ [ True, True, True, True, True]]]]) + return ge_0_mask + + lt_window_mask = impl.elementwise.lt( + ctx, target, source_ir, name + "_lt_window_size", mask, sliding_window_size + ) + mask = 
impl.elementwise.logical_and( + ctx, target, source_ir, name + "_logical_and", ge_0_mask, lt_window_mask + ) + # return the following mask if sliding_window_size is 3: + # 0 ■ ⬚ ⬚ ⬚ ⬚ tensor([[[[ True, False, False, False, False], + # 1 ■ ■ ⬚ ⬚ ⬚ [ True, True, False, False, False], + # 2 ■ ■ ■ ⬚ ⬚ [ True, True, True, False, False], + # 3 ⬚ ■ ■ ■ ⬚ [False, True, True, True, False], + # 4 ⬚ ⬚ ■ ■ ■ [False, False, True, True,True]]]]) return mask @@ -66,9 +92,13 @@ def scaled_dot_product_attention( # TODO: remove this once we have a better way to handle the causal mask scale = kwargs.get("scale", None) source_ir = SourceIR.ATEN - is_causal = True + + assert is_causal == True, "is_causal should be set to True" + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html use_fp32_acc = kwargs.get("use_fp32_acc", False) + sliding_window_size = kwargs.get("sliding_window_size", None) + query_dtype = query.dtype if scale is None: @@ -136,7 +166,9 @@ def scaled_dot_product_attention( S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2) # generate the mask tensor - tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S) + tril_tensor = tril( + ctx, target, source_ir, name + "_tril", L, S, sliding_window_size + ) temp_mask = impl.unary.logical_not( ctx, target, source_ir, name + "_logical_not", tril_tensor @@ -165,11 +197,9 @@ def scaled_dot_product_attention( attn_bias = impl.unary.log( ctx, target, source_ir, name + "_log", one_minus_temp_mask ) - - scaled_add_attn_bias = impl.elementwise.add( - ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias - ) - + scaled_add_attn_bias = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias + ) softmax = impl.normalization.softmax( ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False ) From 7490b02b247914b5916771ae4562678a7f2cf681 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 29 Aug 2025 16:15:24 -0700 Subject: [PATCH 2/4] test --- tools/llm/torchtrt_ext/register_sdpa.py | 587 ++++++------------------ 1 file changed, 153 insertions(+), 434 deletions(-) diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index 038c9a47d7..6284dc6d61 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -1,7 +1,7 @@ import copy import logging import operator -from typing import Callable, Sequence, Tuple +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type import torch from torch_tensorrt.dynamo._settings import CompilationSettings @@ -13,6 +13,7 @@ from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, ) +from transformers import AutoConfig, Gemma3TextConfig from .sdpa_converter import * @@ -33,457 +34,175 @@ torch.ops.aten._scaled_dot_product_flash_attention.default, } - -@_aten_lowering_pass -def replace_variants_of_sdpa( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Replace scaled_dot_product_attention with an equivalent - implementation which can be accurately converted to TRT - """ - attn_mask = None - is_causal = True - for node in gm.graph.nodes: - if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: - if ( - node.target - == torch.ops.aten._scaled_dot_product_efficient_attention.default - ): - if len(node.args) == 7: - ( - query, - key, - value, - attn_bias, - compute_log_sumexp, - 
dropout_p, - is_causal, - ) = node.args - elif len(node.args) == 5: - query, key, value, attn_mask, is_causal = node.args - dropout_p = 0.0 - - else: - raise ValueError( - f"Unexpected number of arguments for {node.target} in the graph" - ) - elif ( - node.target - == torch.ops.aten._scaled_dot_product_flash_attention.default - ): - if len(node.args) == 6: - query, key, value, dropout_p, is_causal, return_debug_mask = ( - node.args - ) - if len(node.args) == 5: - query, key, value, dropout_p, is_causal = node.args - elif len(node.args) == 3: - query, key, value = node.args - dropout_p = 0.0 - is_causal = True - else: - raise ValueError( - f"Unexpected number of arguments for {node.target} in the graph" - ) - - logger.warning( - f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations." - ) - modified_input_args = (query, key, value, None, dropout_p, True) - # Create a new node with torch.nn.functional.scaled_dot_product_attention - # The input args is (query, key, value, is_causal). kwargs has scale - with gm.graph.inserting_after(node): - new_node = gm.graph.call_function( - torch.nn.functional.scaled_dot_product_attention, - args=modified_input_args, - kwargs={ - "scale": node.kwargs.get("scale", None), - "use_fp32_acc": settings.use_fp32_acc, - }, - ) - - # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead. - new_node.meta = copy.copy(node.meta) - # Check if there's a getitem node following this attention node - for user in list(node.users): - if user.op == "call_function" and user.target == operator.getitem: - # If the getitem is extracting the first element (the output tensor) - if user.args[1] == 0: - # Replace all uses of the getitem with the new attention node - user.replace_all_uses_with(new_node) - new_node.meta["val"] = new_node.meta["val"][0] - # Replace all uses of the original node with the new node - node.replace_all_uses_with(new_node) - - gm.graph.erase_node(node) - - # Clean up the graph - clean_up_graph_after_modifications(gm) - - logger.debug( - "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention" - ) - return gm - - -""" -.. _run_llm: - -Running LLM inference with Torch-TensorRT -========================================================== - -This script illustrates Torch-TensorRT workflow with dynamo backend on popular LLM models. -""" - -import argparse -import copy -import os -import timeit -from contextlib import nullcontext - -# %% -# Imports and Model Definition -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -import torch -import torch_tensorrt -from torchtrt_ext import register_sdpa -from transformers import AutoModelForCausalLM, AutoTokenizer -from utils import ( - export_llm, - generate, - generate_with_static_cache, - record_stats, - time_generate, +from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import ( + get_lowering_pass_config, ) -DEVICE = torch.device("cuda:0") - - -def get_model(args): - """ - Load and configure the language model for inference. - - This function loads a pre-trained causal language model using the specified - model name and configures it with the appropriate precision and settings - for inference. 
- - Args: - args: Parsed command line arguments containing: - - model (str): Name or path of the model to load - - precision (str): Precision to use ("FP16", "BF16", or "FP32") - Returns: - torch.nn.Module: The loaded and configured model ready for inference, - moved to CUDA device with the specified precision - """ - with torch.no_grad(): - model = ( - AutoModelForCausalLM.from_pretrained( - args.model, - use_cache=False, - attn_implementation="sdpa", +def _process_sdpa_node( + gm: torch.fx.GraphModule, + node: torch.fx.Node, + settings: CompilationSettings, + sliding_window_size: Optional[int] = None, + use_gqa: bool = False, +) -> torch.fx.GraphModule: + """Helper function to process SDPA nodes with common logic.""" + + if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default: + if len(node.args) == 7: + ( + query, + key, + value, + attn_mask, + compute_log_sumexp, + dropout_p, + is_causal, + ) = node.args + elif len(node.args) == 5: + query, key, value, attn_mask, is_causal = node.args + dropout_p = 0.0 + else: + raise ValueError( + f"Unexpected number of arguments for {node.target} in the graph" ) - .eval() - .cuda() - ) - if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None: - register_sdpa._SDPA_MAPPING[args.model](model_config=model.config) + elif node.target == torch.ops.aten._scaled_dot_product_flash_attention.default: + if len(node.args) == 6: + ( + query, + key, + value, + dropout_p, + is_causal, + return_debug_mask, + ) = node.args + elif len(node.args) == 5: + query, key, value, dropout_p, is_causal = node.args + elif len(node.args) == 3: + query, key, value = node.args + dropout_p = 0.0 + is_causal = True else: - register_sdpa._SDPA_MAPPING["default"](model_config=model.config) - - if args.precision == "FP16": - model = model.to(torch.float16) - elif args.precision == "BF16": - model = model.to(torch.bfloat16) - else: - model = model.to(torch.float32) - - return model - - -def compile_torchtrt(model, input_ids, args): - """ - Compile a PyTorch model to TensorRT using torch_tensorrt.dynamo.compile. - - This function exports the given model to a TorchScript representation and then - compiles it to TensorRT for optimized inference. The compilation process includes - precision-specific optimizations and various performance tuning parameters. 
- - Args: - model (torch.nn.Module): The PyTorch model to compile - input_ids (torch.Tensor): Input token IDs tensor used for model export - args: Parsed command line arguments containing: - - num_tokens (int): Number of tokens to generate (used for max sequence length) - - precision (str): Precision to use ("FP16", "BF16", or "FP32") - - debug (bool): Whether to enable debug logging - - min_block_size (int): Minimum block size for TensorRT compilation - - Returns: - torch_tensorrt.dynamo.TorchTensorRTModule: The compiled TensorRT model ready - for optimized inference - """ - max_seq_len = input_ids.shape[1] + args.num_tokens - ep = export_llm(model, input_ids, max_seq_len=max_seq_len) - position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) - # Set precision specific flags - use_fp32_acc = False - use_explicit_typing = False - if args.precision == "FP16": - enabled_precisions = {torch.float32} - use_fp32_acc = True - use_explicit_typing = True - elif args.precision == "BF16": - enabled_precisions = {torch.bfloat16} - use_fp32_acc = False + raise ValueError( + f"Unexpected number of arguments for {node.target} in the graph" + ) else: - enabled_precisions = {torch.float32} + return gm - with torch_tensorrt.logging.debug() if args.debug else nullcontext(): - trt_model = torch_tensorrt.dynamo.compile( - ep, - inputs=[input_ids, position_ids], - enabled_precisions=enabled_precisions, - # truncate_double=True, - use_explicit_typing=use_explicit_typing, - use_fp32_acc=use_fp32_acc, - device=DEVICE, - disable_tf32=True, - use_python_runtime=True, - debug=args.debug, - offload_module_to_cpu=True, - min_block_size=args.min_block_size, + # Always set causal to True and generate attn_mask inside the sdpa operator + attn_mask = None + is_causal = True + dropout_p = 0.0 + + logger.warning( + f"SDPA converter configuration: attn_mask={attn_mask}, dropout_p={dropout_p}, " + f"is_causal={is_causal}, sliding_window_size={sliding_window_size}, use_gqa={use_gqa}" + ) + + modified_input_args = ( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + ) + + # Create a new node with torch.nn.functional.scaled_dot_product_attention + with gm.graph.inserting_after(node): + new_node = gm.graph.call_function( + torch.nn.functional.scaled_dot_product_attention, + args=modified_input_args, + kwargs={ + "scale": node.kwargs.get("scale", None), + "use_fp32_acc": settings.use_fp32_acc, + "sliding_window_size": sliding_window_size, + }, ) - return trt_model - - -def print_outputs(backend_name, gen_tokens, tokenizer): - """ - Print the generated tokens from the model. - """ - print(f"========= {backend_name} =========") - print( - f"{backend_name} model generated text: ", - tokenizer.decode(gen_tokens[0], skip_special_tokens=True), - ) - print("===================================") - - -def measure_perf(trt_model, input_signature, backend_name): - """ - Measure the performance of a TensorRT model by running it multiple times and - calculating the average time per iteration. 
- """ - total_time = 0 - iterations = 10 - - print("Running warmup iteration...") - # Warmup run - _ = trt_model(*input_signature) - torch.cuda.synchronize() - - print(f"Measuring performance over {iterations} iterations...") - for i in range(iterations): - start_time = timeit.default_timer() - _ = trt_model(*input_signature) - torch.cuda.synchronize() - end_time = timeit.default_timer() - iter_time = end_time - start_time - total_time += iter_time - - avg_time = total_time / iterations - print( - f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds" - ) - print( - f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second" - ) - - -if __name__ == "__main__": - arg_parser = argparse.ArgumentParser( - description="Run inference on a model with random input values" - ) - arg_parser.add_argument( - "--model", - type=str, - default="meta-llama/Llama-3.2-1B-Instruct", - help="Name of LLM model", - ) - arg_parser.add_argument( - "--tokenizer", - type=str, - default="", - help="Name of LLM model tokenizer", - ) - arg_parser.add_argument( - "--prompt", type=str, default="What is parallel programming ?", help="Prompt" - ) - arg_parser.add_argument( - "--precision", - type=str, - default="FP16", - help="Precision to use in the model. Options: FP16, BF16, FP32", - ) - arg_parser.add_argument( - "--iterations", type=int, default=5, help="no. of iterations to run" - ) - arg_parser.add_argument( - "--min_block_size", type=int, default=1, help="no. of iterations to run" - ) - arg_parser.add_argument( - "--num_tokens", - type=int, - default=128, - help="no. of output tokens to be generated", - ) - arg_parser.add_argument( - "--batch_size", type=int, default=1, help="Batch size used for benchmarking" - ) - arg_parser.add_argument( - "--isl", - type=int, - default=2048, - help="Input sequence length used for benchmarking", - ) - arg_parser.add_argument( - "--enable_pytorch_run", - action="store_true", - help="Enable pytorch run (default: False)", - ) - arg_parser.add_argument( - "--cache", - type=str, - default="", - help="Type of KV cache to use. Options: static_v1, static_v2", - ) - arg_parser.add_argument( - "--cudagraph", action="store_true", help="Enable cudagraphs (default: False)" - ) - arg_parser.add_argument( - "--debug", action="store_true", help="Enable debug (default: False)" - ) - arg_parser.add_argument( - "--benchmark", action="store_true", help="Enable benchmark (default: False)" - ) - - args = arg_parser.parse_args() - with torch.inference_mode(): - model = get_model(args) + # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead. 
+ new_node.meta = copy.copy(node.meta) + # Check if there's a getitem node following this attention node + for user in list(node.users): + if user.op == "call_function" and user.target == operator.getitem: + # If the getitem is extracting the first element (the output tensor) + if user.args[1] == 0: + # Replace all uses of the getitem with the new attention node + user.replace_all_uses_with(new_node) + new_node.meta["val"] = new_node.meta["val"][0] + # Replace all uses of the original node with the new node + node.replace_all_uses_with(new_node) + + gm.graph.erase_node(node) + return gm - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model) - # Prepare input for benchmarking or evaluation - if args.benchmark: - input_ids = torch.randint( - 1, 10000, (args.batch_size, args.isl), dtype=torch.int64 - ).to(model.device) - position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) +def register_gemma3_sdpa_pass(index: int = 0, model_config: Any = None) -> None: + @_aten_lowering_pass(index=index, model_config=model_config) + def gemma3_sdpa_pass( + gm: torch.fx.GraphModule, settings: CompilationSettings + ) -> torch.fx.GraphModule: + """SDPA pass specifically for Gemma3 models with sliding window attention.""" + config = get_lowering_pass_config(gemma3_sdpa_pass) + sliding_window = None + layer_types = None + model_config = config.get("model_config", None) + if not isinstance(model_config, Gemma3TextConfig): + logger.warning( + f"Expected Gemma3TextConfig, got {type(model_config)}, will use default SDPA replacement instead" + ) else: - model_inputs = tokenizer(args.prompt, return_tensors="pt") - input_ids = model_inputs["input_ids"].to(DEVICE) - position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) - - MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens - # Pyt - pyt_gen_tokens = None - pyt_timings = None - pyt_stats = None - - if args.enable_pytorch_run: - pyt_gen_tokens = generate( - model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id + sliding_window = getattr(model_config, "sliding_window", None) + layer_types = getattr(model_config, "layer_types", None) + logger.debug( + f"got Gemma3 config: sliding_window={sliding_window}, layer_types={layer_types}" ) - if args.benchmark: - pyt_timings = time_generate( - generate, - model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - iterations=args.iterations, - ) - pyt_stats = record_stats( - "PyTorch", - pyt_timings, - args.precision, - batch_size=args.batch_size, - compile_time_s=None, - ) - if args.cache == "static_v1": - # This import is required to register static v1 KV cache transformations as lowering passes - import static_cache_v1 - if args.cache == "static_v2": - # This import is required to register static v2 KV cache transformations as lowering passes - import static_cache_v2 + index = 0 + for node in gm.graph.nodes: + if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: + sliding_window_size = None + if ( + sliding_window is not None + and sliding_window > 0 + and layer_types is not None + and index < len(layer_types) + ): + if layer_types[index] == "sliding_attention": + sliding_window_size = sliding_window + index += 1 + + # Process the node + logger.debug( + f"Applying Gemma3-specific SDPA replacement with {node.name=}, {node.target=}, {sliding_window_size=}" + ) + gm = _process_sdpa_node(gm, node, settings, sliding_window_size) - # Compile the model with Torch-TensorRT - trt_model = compile_torchtrt(model, 
input_ids, args) + clean_up_graph_after_modifications(gm) + logger.debug("Applied Gemma3-specific SDPA replacement") + return gm - if args.cache == "static_v1" or args.cache == "static_v2": - if args.cudagraph: - # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. - # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) - torch_tensorrt.runtime.set_cudagraphs_mode(True) - trt_gen_tokens = generate_with_static_cache( - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - ) +def register_default_sdpa_pass(index: int = 0, model_config: Any = None) -> None: + @_aten_lowering_pass(index=index, model_config=model_config) + def default_sdpa_pass( + gm: torch.fx.GraphModule, + settings: CompilationSettings, + ) -> torch.fx.GraphModule: + """Default SDPA pass for models without specific implementations.""" - if args.benchmark: - trt_timings = time_generate( - generate_with_static_cache, - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - iterations=args.iterations, - ) - else: - trt_gen_tokens = generate( - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - ) - if args.benchmark: - trt_timings = time_generate( - generate, - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - iterations=args.iterations, - ) + for node in gm.graph.nodes: + if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS: + # Process the node with default logic + gm = _process_sdpa_node(gm, node, settings) - if args.benchmark: - trt_stats = record_stats( - "TensorRT", - trt_timings, - args.precision, - batch_size=args.batch_size, - compile_time_s=None, - ) - - if not args.benchmark: - if args.enable_pytorch_run: - print_outputs("PyTorch", pyt_gen_tokens, tokenizer) + clean_up_graph_after_modifications(gm) + logger.debug("Applied default SDPA replacement") + return gm - print_outputs("TensorRT", trt_gen_tokens, tokenizer) - if args.enable_pytorch_run: - print( - f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}" - ) - - if args.benchmark: - if args.enable_pytorch_run: - print("=========PyTorch PERFORMANCE============ \n") - print(pyt_stats) - print("===================== \n") - print("=========TensorRT PERFORMANCE============ \n") - print(trt_stats) +# Global registry for SDPA passes +_SDPA_MAPPING: Dict[str, Callable] = { + "google/gemma-3-1b-it": register_gemma3_sdpa_pass, + "default": register_default_sdpa_pass, +} From d8e1ae0244e6ca1e7f4ac398878de8816cfa8a80 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 2 Sep 2025 11:08:07 -0700 Subject: [PATCH 3/4] resolve comments --- tools/llm/README.md | 1 + tools/llm/run_llm.py | 1 + tools/llm/torchtrt_ext/register_sdpa.py | 69 +++++++++++++++++++++++- tools/llm/torchtrt_ext/sdpa_converter.py | 64 +++++++++++++++++++++- 4 files changed, 131 insertions(+), 4 deletions(-) diff --git a/tools/llm/README.md b/tools/llm/README.md index a141505517..00a02ecb7b 100644 --- a/tools/llm/README.md +++ b/tools/llm/README.md @@ -23,6 +23,7 @@ We have officially verified support for the following models: | LLaMA 3.2 | meta-llama/Llama-3.2-1B-Instruct
meta-llama/Llama-3.2-3B-Instruct | FP16, FP32 | Yes | | Qwen 2.5 | Qwen/Qwen2.5-0.5B-Instruct
Qwen/Qwen2.5-1.5B-Instruct
Qwen/Qwen2.5-4B-Instruct
Qwen/Qwen2.5-7B-Instruct | FP16, FP32 | Yes | | Qwen 3 | Qwen/Qwen3-0.6B
Qwen/Qwen3-1.7B
Qwen/Qwen3-4B
Qwen/Qwen3-8B | FP16, FP32 | Yes | +| Gemma 3 | google/gemma-3-1b-it | FP16, FP32 | Yes | ### Usage diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index 075f3ace15..ab9470cc61 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -58,6 +58,7 @@ def get_model(args): .eval() .cuda() ) + # register SDPA variant for the model if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None: register_sdpa._SDPA_MAPPING[args.model](model_config=model.config) else: diff --git a/tools/llm/torchtrt_ext/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py index 6284dc6d61..a650dc1387 100644 --- a/tools/llm/torchtrt_ext/register_sdpa.py +++ b/tools/llm/torchtrt_ext/register_sdpa.py @@ -46,7 +46,27 @@ def _process_sdpa_node( sliding_window_size: Optional[int] = None, use_gqa: bool = False, ) -> torch.fx.GraphModule: - """Helper function to process SDPA nodes with common logic.""" + """ + Helper function to process SDPA nodes with common logic. + + This function handles the replacement of various scaled dot product attention operations + with the standard torch.nn.functional.scaled_dot_product_attention function. It supports + both efficient attention and flash attention variants, and can handle sliding window + attention for models like Gemma3. + + Args: + gm: The graph module containing the SDPA nodes + node: The specific node to process (must be an SDPA operation) + settings: TensorRT compilation settings + sliding_window_size: Optional sliding window size for models with sliding attention + use_gqa: Whether the model uses Grouped Query Attention + + Returns: + The modified graph module with SDPA nodes replaced + + Raises: + ValueError: If the SDPA node has an unexpected number of arguments + """ if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default: if len(node.args) == 7: @@ -94,7 +114,7 @@ def _process_sdpa_node( is_causal = True dropout_p = 0.0 - logger.warning( + logger.debug( f"SDPA converter configuration: attn_mask={attn_mask}, dropout_p={dropout_p}, " f"is_causal={is_causal}, sliding_window_size={sliding_window_size}, use_gqa={use_gqa}" ) @@ -138,6 +158,27 @@ def _process_sdpa_node( def register_gemma3_sdpa_pass(index: int = 0, model_config: Any = None) -> None: + """ + Register SDPA pass for Gemma3 models with sliding window attention. + + This function creates and registers a specialized SDPA replacement pass for Gemma3 models. + The pass handles sliding window attention by extracting the sliding_window and layer_types + configuration from the model config and applying appropriate transformations. + + Args: + index: Position in the lowering pass list where this pass should be inserted + model_config: The model configuration object (should be Gemma3TextConfig) + + Example: + from transformers import AutoConfig + config = AutoConfig.from_pretrained("google/gemma-3-1b-it") + register_gemma3_sdpa_pass(index=0, model_config=config) + + Note: + This pass is specifically designed for Gemma3 models and will fall back to + default behavior if the model_config is not a Gemma3TextConfig. + """ + @_aten_lowering_pass(index=index, model_config=model_config) def gemma3_sdpa_pass( gm: torch.fx.GraphModule, settings: CompilationSettings @@ -184,6 +225,30 @@ def gemma3_sdpa_pass( def register_default_sdpa_pass(index: int = 0, model_config: Any = None) -> None: + """ + Register default SDPA pass for models without specific implementations. + + This function creates and registers a default SDPA replacement pass that can be used + for any model type. 
It provides basic SDPA replacement functionality without + model-specific optimizations. + + Args: + index: Position in the lowering pass list where this pass should be inserted + model_config: The model configuration object (optional, for consistency) + + Example: + # Register default pass at index 0 + register_default_sdpa_pass(index=0) + + # Or with model config for consistency + config = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B") + register_default_sdpa_pass(index=0, model_config=config) + + Note: + This is a fallback pass that should be used when no model-specific + SDPA pass is available or when you want generic SDPA replacement behavior. + """ + @_aten_lowering_pass(index=index, model_config=model_config) def default_sdpa_pass( gm: torch.fx.GraphModule, diff --git a/tools/llm/torchtrt_ext/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py index f7a7203f38..feded31023 100644 --- a/tools/llm/torchtrt_ext/sdpa_converter.py +++ b/tools/llm/torchtrt_ext/sdpa_converter.py @@ -29,6 +29,66 @@ def tril( col: TRTTensor, sliding_window_size: Optional[int] = None, ) -> TRTTensor: + """ + Create a lower triangular mask tensor for attention mechanisms. + + This function generates a lower triangular mask that can be used in attention + operations to enforce causal attention (each position can only attend to itself + and previous positions). It optionally supports sliding window attention by + limiting the attention span to a specified window size. + + The function creates the mask by: + 1. Generating row and column index tensors + 2. Computing the difference between row and column indices + 3. Creating a mask where row >= col (lower triangular) + 4. Optionally applying sliding window constraints + + Args: + ctx: TensorRT conversion context for managing the conversion process + target: Target operation identifier (usually the operation being converted) + source_ir: Source IR type (e.g., ATEN, TRT) - can be None + name: Base name for generated TensorRT operations (will be extended with suffixes) + row: Tensor representing the number of rows (sequence length dimension) + col: Tensor representing the number of columns (sequence length dimension) + sliding_window_size: Optional sliding window size for attention span limitation. + If None, creates a full lower triangular mask. + If specified, creates a sliding window mask where each position + can only attend to positions within the window. + + Returns: + TRTTensor: A boolean mask tensor with shape [batch, heads, seq_len, seq_len] + where True values indicate allowed attention positions. 
+ + Example: + # Create a full lower triangular mask for causal attention + mask = tril(ctx, target, source_ir, "causal_mask", seq_len, seq_len) + + # Create a sliding window mask with window size 3 + mask = tril(ctx, target, source_ir, "sliding_mask", seq_len, seq_len, 3) + + Mask Examples: + Without sliding window (sliding_window_size=None): + For seq_len=5, returns: + [[ True, False, False, False, False], + [ True, True, False, False, False], + [ True, True, True, False, False], + [ True, True, True, True, False], + [ True, True, True, True, True]] + + With sliding window (sliding_window_size=3): + For seq_len=5, returns: + [[ True, False, False, False, False], + [ True, True, False, False, False], + [ True, True, True, False, False], + [False, True, True, True, False], + [False, False, True, True, True]] + + Note: + This function is specifically designed for attention mechanisms in transformer + models and is used internally by the scaled_dot_product_attention converter. + The sliding window functionality is particularly useful for models like Gemma3 + that use sliding window attention to reduce computational complexity. + """ row_arange_tensor = impl.arange.arange( ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1 ) @@ -86,8 +146,8 @@ def scaled_dot_product_attention( kwargs: Dict[str, Any], name: str, ) -> TRTTensor: - # TODO: Handle attn_mask and is_causal arguments in the future - query, key, value, attn_mask, dropout_p, is_causal = args + # always create our own attn_mask + query, key, value, _, dropout_p, is_causal = args # TODO: remove this once we have a better way to handle the causal mask scale = kwargs.get("scale", None) From 0653f36c826169c7a7c42b2cd180f318abf278ec Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 3 Sep 2025 14:07:52 -0700 Subject: [PATCH 4/4] add llm test cases --- docsrc/tutorials/compile_hf_models.rst | 8 ++- tests/py/dynamo/models/test_llm_models.py | 60 +++++++++++++++++++++++ 2 files changed, 66 insertions(+), 2 deletions(-) create mode 100644 tests/py/dynamo/models/test_llm_models.py diff --git a/docsrc/tutorials/compile_hf_models.rst b/docsrc/tutorials/compile_hf_models.rst index f6da87b145..8e1a8bbf01 100644 --- a/docsrc/tutorials/compile_hf_models.rst +++ b/docsrc/tutorials/compile_hf_models.rst @@ -59,6 +59,10 @@ We have officially verified support for the following LLM families: | Qwen/Qwen2.5-7B-Instruct - FP16, FP32 - Yes + * - Gemma 3 + - | google/gemma-3-1b-it + - FP16, FP32 + - Yes Getting Started with run_llm.py ------------------------------- @@ -185,8 +189,8 @@ The number of key/value cache tensors is equal to the number of attention heads Generating Outputs ------------------- -We use custom `generate `_ function to generate the outputs. This function performs standard autoregressive decoding without KV caching. -There is also a `generate_with_static_cache `_ function that performs autoregressive decoding with KV caching. +We use custom `generate `_ function to generate the outputs. This function performs standard autoregressive decoding without KV caching. +There is also a `generate_with_static_cache `_ function that performs autoregressive decoding with KV caching. The ``generate_with_static_cache`` function takes care of preparing the inputs to the model compiled with static KV cache. The model inputs are ``input_ids``, ``position_ids``, ``key_cache_0``, ``value_cache_0``, ...., ``start_idx``, ``end_idx``. 
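
Editor's note: to make the static-KV-cache calling convention described above easier to follow, here is a minimal, illustrative sketch of a greedy, batch-size-1 decode loop. It is not the actual ``generate_with_static_cache`` implementation in ``tools/llm/utils.py``; the function name ``decode_with_static_cache``, the explicit ``kv_cache_inputs`` argument, and the assumption that the compiled module returns ``(logits, *updated_kv_caches)`` are all hypothetical and only meant to mirror the input layout (``input_ids``, ``position_ids``, ``key_cache_0``, ``value_cache_0``, ..., ``start_idx``, ``end_idx``) stated in the tutorial text.

    import torch

    def decode_with_static_cache(trt_model, input_ids, kv_cache_inputs,
                                 max_output_seq_length, eos_token_id):
        # Prefill: run the full prompt once; start_idx/end_idx mark the slice of
        # the pre-allocated KV caches that this call writes (assumption).
        seq_len = input_ids.shape[1]
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        start_idx = torch.tensor(0, device=input_ids.device)
        end_idx = torch.tensor(seq_len, device=input_ids.device)
        outputs = trt_model(input_ids, position_ids, *kv_cache_inputs, start_idx, end_idx)
        logits, kv_cache_inputs = outputs[0], list(outputs[1:])
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([input_ids, next_token], dim=-1)

        # Generate: feed one token per step; the updated caches carry the history,
        # so only the newest token and its position id are passed in.
        while generated.shape[1] < max_output_seq_length and next_token.item() != eos_token_id:
            position_ids = end_idx.view(1, 1)
            start_idx, end_idx = end_idx, end_idx + 1
            outputs = trt_model(next_token, position_ids, *kv_cache_inputs, start_idx, end_idx)
            logits, kv_cache_inputs = outputs[0], list(outputs[1:])
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)

        return generated

The real helpers additionally handle batching, sampling, and cache allocation; this sketch only shows how ``start_idx``/``end_idx`` advance between the prefill and generate phases.
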
diff --git a/tests/py/dynamo/models/test_llm_models.py b/tests/py/dynamo/models/test_llm_models.py new file mode 100644 index 0000000000..188954f68d --- /dev/null +++ b/tests/py/dynamo/models/test_llm_models.py @@ -0,0 +1,60 @@ +import os +import sys + +import pytest +import torch +import torch_tensorrt +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../tools/llm")) +import argparse + +from run_llm import compile_torchtrt +from torchtrt_ext import register_sdpa + + +@pytest.mark.unit +@pytest.mark.parametrize("precision", ["FP16", "BF16", "FP32"]) +def test_gemma3_decoder_layer(precision): + + with torch.inference_mode(): + args = argparse.Namespace() + args.debug = False + args.num_tokens = 128 + args.model = "google/gemma-3-1b-it" + args.precision = precision + args.min_block_size = 1 + args.prompt = "What is parallel programming ?" + if args.precision == "FP16": + dtype = torch.float16 + elif args.precision == "BF16": + dtype = torch.bfloat16 + else: + args.precision = "FP32" + dtype = torch.float32 + + model = ( + AutoModelForCausalLM.from_pretrained( + args.model, + use_cache=False, + attn_implementation="sdpa", + num_hidden_layers=1, + ) + .eval() + .to("cuda") + ) + + register_sdpa._SDPA_MAPPING[args.model](model_config=model.config) + model = model.to(dtype) + # use randint will generate nan values in the logits, use a fixed input_ids for now + # input_ids = torch.randint(0, model.config.vocab_size, (1, args.num_tokens)).to("cuda") + input_ids = torch.tensor([[2, 3689, 563, 10616, 14929, 2360]]).to("cuda") + + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to("cuda") + pyt_outputs = model(input_ids.clone(), position_ids=position_ids.clone()) + trt_model = compile_torchtrt(model, input_ids, args) + trt_outputs = trt_model(input_ids, position_ids=position_ids) + + torch.testing.assert_close( + pyt_outputs.logits, trt_outputs.logits, rtol=5e-1, atol=5e-1 + )
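
Editor's note: the sliding-window logic added to the TensorRT ``tril`` converter above can be checked against a plain-PyTorch reference. The sketch below is illustrative only; ``causal_sliding_window_mask`` is not part of the patch, but it reproduces the 5x5 mask examples shown in the converter comments (full lower-triangular mask when no window is given, banded mask when ``sliding_window_size=3``).

    from typing import Optional

    import torch


    def causal_sliding_window_mask(
        seq_len: int, sliding_window_size: Optional[int] = None
    ) -> torch.Tensor:
        # row - col is >= 0 on and below the main diagonal (standard causal mask);
        # additionally requiring row - col < sliding_window_size keeps only the
        # most recent tokens, the pattern used by Gemma3 sliding-attention layers.
        row = torch.arange(seq_len).unsqueeze(-1)  # shape [seq_len, 1]
        col = torch.arange(seq_len).unsqueeze(0)   # shape [1, seq_len]
        diff = row - col
        mask = diff >= 0
        if sliding_window_size is not None:
            mask = mask & (diff < sliding_window_size)
        return mask


    # Matches the examples in the tril() docstring: full causal mask, then the
    # banded mask where rows 3 and 4 drop the oldest positions.
    print(causal_sliding_window_mask(5))
    print(causal_sliding_window_mask(5, 3))
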