Conversation

@karthikvetrivel (Member) commented Jan 13, 2026

[None][fix] Fix FakeTensorMode mismatch in pattern matcher for FP8 models

Summary by CodeRabbit

Release Notes

  • Bug Fixes
    • Fixed a FakeTensorMode mismatch during the tracing phase of model export and compilation, improving reliability for quantized (FP8) models.


Description

Problem

FP8 models like nvidia/Llama-3.1-8B-Instruct-FP8 failed during the fuse_fp8_linear transform with the following error:

AssertionError: fake mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x...>) 
from fake tensor input 0 doesn't match mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x...>) 
from fake tensor input 1

Root Cause

During pattern matching in fuse_quant.py, PyTorch's pattern matcher calls trace_fn(replace_fn, args) to trace replacement patterns. The args are FakeTensor objects created under the model's FakeTensorMode. However, _trace_to_gm was calling torch_export_to_gm, which creates a new FakeTensorMode context. PyTorch's detect_fake_mode asserts that all fake tensors come from the same mode, so mixing tensors from the two modes triggered the failure above.

Call stack:

fuse_quant.py:289 → patterns.apply()
  → torch/_inductor/pattern_matcher.py:1982 → entry.extra_check(m)
    → pattern_matcher.py:1437 → detect_fake_mode(args)  ← CONFLICT
      → pattern_matcher.py:1520 → trace_fn(replace_fn, args)
        → _trace_to_gm → torch_export_to_gm()  ← CREATES NEW MODE
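
For intuition, here is a minimal standalone repro of the assertion (a hypothetical sketch, not code from this PR or the pattern matcher): detect_fake_mode refuses to mix fake tensors owned by two different FakeTensorMode instances.

import torch
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode

# Hypothetical repro: two independent fake modes.
mode_a, mode_b = FakeTensorMode(), FakeTensorMode()
with mode_a:
    x = torch.empty(2, 2)  # FakeTensor owned by mode_a
with mode_b:
    y = torch.empty(2, 2)  # FakeTensor owned by mode_b

# Trips the same assertion as above: "fake mode (...) from fake tensor
# input 0 doesn't match mode (...) from fake tensor input 1"
detect_fake_mode((x, y))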

Solution

Modified _trace_to_gm in pattern_matcher.py to detect whether the input tensors are FakeTensors and, if so, use make_fx (which respects the existing fake-mode context) instead of torch_export_to_gm (which creates a new one).

def _trace_to_gm(fn: Callable, args: Sequence[torch.Tensor]) -> GraphModule:
    """Exports a function or Module into a GraphModule.

    Uses make_fx when the inputs are FakeTensors, so the existing
    FakeTensorMode is respected; otherwise exports via torch_export_to_gm.
    """
    from torch._guards import detect_fake_mode
    from torch.fx.experimental.proxy_tensor import make_fx

    module = fn if isinstance(fn, torch.nn.Module) else _WrapperModule(fn)

    # Use make_fx for FakeTensors to avoid a FakeTensorMode mismatch during pattern matching.
    if detect_fake_mode(args) is not None:
        return make_fx(module, tracing_mode="fake")(*args)

    return torch_export_to_gm(module, tuple(args))
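
As a rough illustration of how the two paths are exercised (a hedged sketch; replace_fn and the shapes below are hypothetical stand-ins, not actual registered patterns):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def replace_fn(x, w):  # hypothetical stand-in for a replacement pattern
    return torch.nn.functional.linear(x, w)

real_args = (torch.randn(4, 8), torch.randn(8, 8))
gm_real = _trace_to_gm(replace_fn, real_args)  # no fake mode -> torch_export_to_gm path

fake_mode = FakeTensorMode()
fake_args = tuple(fake_mode.from_tensor(t) for t in real_args)
gm_fake = _trace_to_gm(replace_fn, fake_args)  # FakeTensor inputs -> make_fx path, mode preserved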

Test Coverage

Verified on H100 system with:

python3 examples/auto_deploy/build_and_run_ad.py \
  --model nvidia/Llama-3.1-8B-Instruct-FP8 \
  --args.yaml-extra examples/auto_deploy/model_registry/configs/dashboard_default.yaml \
  --args.yaml-extra examples/auto_deploy/model_registry/configs/world_size_2.yaml

Results:

  • fuse_fp8_linear: 224 pattern matches (previously failed)
  • fuse_allreduce_residual_rmsnorm: 64 pattern matches
  • Model inference produces coherent outputs

PR Checklist

  • Please check this after reviewing the above items as appropriate for this PR.

@karthikvetrivel karthikvetrivel requested a review from a team as a code owner January 13, 2026 22:10
@coderabbitai bot (Contributor) commented Jan 13, 2026

📝 Walkthrough

This change adds runtime detection and handling for FakeTensor inputs in the _trace_to_gm function, routing them through make_fx with fake tracing mode while preserving the existing torch_export_to_gm path for standard tensors.

Changes

Cohort / File(s): FakeTensor Tracing Path (tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py)
Summary: Added conditional logic to detect FakeTensors and route them through make_fx(tracing_mode="fake") when present; otherwise the existing torch_export_to_gm export path is preserved.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Title check ✅ Passed: The title accurately describes the main fix: addressing a FakeTensorMode mismatch in the pattern matcher for FP8 models.
  • Description check ✅ Passed: The description covers the problem statement, root-cause analysis with call stack, solution with code snippet, and detailed test-coverage results.
  • Docstring coverage ✅ Passed: Docstring coverage is 100.00%, which meets the required threshold of 80.00%.




@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (1)

83-96: Add required NVIDIA copyright header to the file.

Per the coding guidelines, all TensorRT-LLM source files must include an NVIDIA copyright header. Add the following at the very beginning of the file (before the module docstring):

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

Also update the docstring to reflect the dual-path behavior:

The docstring at lines 85-86 states the function exports "via torch_export_to_gm" but the function now conditionally uses either make_fx or torch_export_to_gm. Update to clarify both paths:

def _trace_to_gm(fn: Callable, args: Sequence[torch.Tensor]) -> GraphModule:
    """
    Exports a function or Module into a GraphModule.
    
    Uses make_fx for FakeTensor inputs to avoid FakeTensorMode mismatch,
    otherwise uses torch_export_to_gm.
    """

The use of internal PyTorch APIs (torch._guards.detect_fake_mode and torch.fx.experimental.proxy_tensor.make_fx) is acceptable for this use case but note these may change across PyTorch versions.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ccdfa43 and 2d476e8.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces. Do not use tabs
Always maintain the namespace when importing Python modules, even if only one class or function from a module is used
Python filenames should use snake_case (e.g., some_file.py)
Python classes should use PascalCase (e.g., class SomeClass)
Python functions and methods should use snake_case (e.g., def my_awesome_function():)
Python local variables should use snake_case, with prefix k for variable names that start with a number (e.g., k_99th_percentile)
Python global variables should use upper snake_case with prefix G (e.g., G_MY_GLOBAL)
Python constants should use upper snake_case (e.g., MY_CONSTANT)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Use comments in Python for code within a function, or interfaces that are local to a file
Use Google-style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with the format """<type>: Description"""
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of errors possible
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block for the main logic

Files:

  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
**/*.{cpp,cc,cxx,h,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification

Files:

  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
🧠 Learnings (1)
📚 Learning: 2025-10-20T16:54:09.824Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

Comment on lines +83 to +94
def _trace_to_gm(fn: Callable, args: Sequence[torch.Tensor]) -> GraphModule:
    """
    Exports a function or Module into a GraphModule via torch_export_to_gm.
    """
    from torch._guards import detect_fake_mode
    from torch.fx.experimental.proxy_tensor import make_fx

    module = fn if isinstance(fn, torch.nn.Module) else _WrapperModule(fn)

    # Use make_fx for FakeTensors to avoid FakeTensorMode mismatch during pattern matching
    if detect_fake_mode(args) is not None:
        return make_fx(module, tracing_mode="fake")(*args)
A Member commented:

@Fridah-nv, please take a look

@lucaslie lucaslie requested review from Fridah-nv and removed request for MrGeva January 13, 2026 22:32
@lucaslie lucaslie moved this from Backlog to In review in AutoDeploy Board Jan 13, 2026
@Fridah-nv (Collaborator) left a comment

This is an interesting fix, thank you!
A while ago, when I tried to use make_fx as the trace_fn to trace Python code, I found that the graph output by make_fx differs from torch_export_to_gm/torch.export.export; I wonder if you observed something similar?
Since our initial graph is produced by torch_export_to_gm, there could be more matching failures if the pattern graph generated with make_fx is slightly different from the actual pattern in the target graph.
Please double-check whether this can be an issue for other pattern matchers; for fuse_fp8_linear it should be fine, since its pattern contains only one op.

@karthikvetrivel (Member, Author) commented
@Fridah-nv Thanks for taking a look! I think you're right: if we need to maintain strict graph semantics, make_fx could eventually give us a different graph. I think this can serve as a band-aid fix for now, so that models whose replacement patterns are simple single-op substitutions (like fuse_fp8_linear) and that hit this fake-tensor issue start working. These models don't work with AutoDeploy as-is today.

I was exploring a few larger ideas, but they didn't seem worth the complexity. I'm open to exploring other ideas though and would love to hear your thoughts.

@karthikvetrivel (Member, Author) commented Jan 14, 2026

@Fridah-nv I ran tests to see whether the graphs differed, and found:

15/15 patterns produce identical graphs:

[fuse_quant.py] FP8/NVFP4 Linear: 4/4 
(4/4 meaning fp8_linear_no_bias, fp8_linear_with_bias, etc.)
[rms_norm.py] RMSNorm:            2/2   
[attention.py] Attention:         3/3 
[rope.py] RoPE:                   3/3 
[collectives.py] AllReduce+Norm:  1/1 
[fused_add_rms_norm.py]:          1/1 
[mxfp4_moe.py] Dense MoE:         1/1 

All replacement functions in register_ad_pattern call custom ops (torch.ops.auto_deploy.*), which cannot be decomposed differently by either tracing method, so the graphs are identical. The concern about graph differences applies to standard PyTorch ops (like F.gelu or F.layer_norm), which can be decomposed differently; since all our replacements are custom-op calls, this isn't an issue here.
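
For reference, one way to run this kind of check (a hedged sketch with hypothetical helpers, not the actual test code) is to trace a replacement function through both paths and compare the sequences of call_function targets, ignoring node names and metadata:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

class _Wrap(torch.nn.Module):  # hypothetical stand-in for _WrapperModule
    def __init__(self, fn):
        super().__init__()
        self._fn = fn

    def forward(self, *args):
        return self._fn(*args)

def _op_targets(gm):
    # Sequence of called ops; node names and metadata are ignored.
    return [n.target for n in gm.graph.nodes if n.op == "call_function"]

def graphs_match(fn, args) -> bool:
    gm_fx = make_fx(_Wrap(fn), tracing_mode="fake")(*args)
    gm_export = torch.export.export(_Wrap(fn), tuple(args)).module()
    return _op_targets(gm_fx) == _op_targets(gm_export)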

@Fridah-nv (Collaborator) commented

15/15 patterns produce identical graphs

That sounds good! I think it's fair to accept this fix as long as the unit tests pass. @lucaslie
My observation was from a while ago and might not apply to the patterns we need.
A small question: the torch._inductor pattern matcher exports both the target pattern graph and the replacement graph with trace_fn; will your make_fx fix apply to only the replacement graph, or to both?

@karthikvetrivel (Member, Author) commented
/bot run

@tensorrt-cicd (Collaborator) commented

PR_Github #32021 [ run ] triggered by Bot. Commit: 2d476e8

@karthikvetrivel (Member, Author) commented

I believe the make_fx fix applies to only the replacement graph. Just to confirm: we won't find any FakeTensors when we're pattern tracing, right?

The detect_fake_mode(args) check ensures torch_export_to_gm is used for pattern tracing (with the regular dummy_args), while make_fx is used for replacement tracing (when FakeTensors are present at match time). This relies on the assumption that FakeTensors don't pop up during pattern tracing.

@Fridah-nv (Collaborator) commented

I believe the make_fx fix applies to only the replacement graph. We won't find any FakeTensors when we're pattern tracing.

Makes sense to me, thanks.

@tensorrt-cicd (Collaborator) commented

PR_Github #32021 [ run ] completed with state SUCCESS. Commit: 2d476e8
/LLM/main/L0_MergeRequest_PR pipeline #24811 completed with status: 'SUCCESS'
