[None][fix] Fix FakeTensorMode mismatch in pattern matcher for FP8 models #10646
Conversation
…ake tensors Signed-off-by: Karthik Vetrivel <[email protected]>
📝 Walkthrough: This change adds runtime detection and handling for FakeTensor inputs in `_trace_to_gm`.

Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes. Pre-merge checks: 3 passed.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (1)
83-96: Add required NVIDIA copyright header to the file.

Per the coding guidelines, all TensorRT-LLM source files must include an NVIDIA copyright header. Add the following at the very beginning of the file (before the module docstring):

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
```

Also update the docstring to reflect the dual-path behavior. The docstring at lines 85-86 states the function exports "via torch_export_to_gm", but the function now conditionally uses either `make_fx` or `torch_export_to_gm`. Update it to clarify both paths:

```python
def _trace_to_gm(fn: Callable, args: Sequence[torch.Tensor]) -> GraphModule:
    """Exports a function or Module into a GraphModule.

    Uses make_fx for FakeTensor inputs to avoid FakeTensorMode mismatch,
    otherwise uses torch_export_to_gm.
    """
```

The use of internal PyTorch APIs (`torch._guards.detect_fake_mode` and `torch.fx.experimental.proxy_tensor.make_fx`) is acceptable for this use case, but note that these may change across PyTorch versions.
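Since these are private entry points, a tiny guard test (hypothetical, not part of this PR) could fail fast when a PyTorch upgrade moves them:

```python
# Hypothetical guard, not part of this PR: fail fast if the private
# PyTorch entry points the fix relies on move in a future release.
def test_private_torch_tracing_apis_importable():
    from torch._guards import detect_fake_mode  # noqa: F401
    from torch.fx.experimental.proxy_tensor import make_fx  # noqa: F401
```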
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces. Do not use tabs
Always maintain the namespace when importing Python modules, even if only one class or function from a module is used
Python filenames should use snake_case (e.g., `some_file.py`)
Python classes should use PascalCase (e.g., `class SomeClass`)
Python functions and methods should use snake_case (e.g., `def my_awesome_function():`)
Python local variables should use snake_case, with prefix `k` for variable names that start with a number (e.g., `k_99th_percentile`)
Python global variables should use upper snake_case with prefix `G` (e.g., `G_MY_GLOBAL`)
Python constants should use upper snake_case (e.g., `MY_CONSTANT`)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Use comments in Python for code within a function, or interfaces that are local to a file
Use Google-style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with the format `"""<type>: Description"""`
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of errors possible
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block for the main logic
Files:
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
**/*.{cpp,cc,cxx,h,hpp,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification
Files:
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
🧠 Learnings (1)
📚 Learning: 2025-10-20T16:54:09.824Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.
Applied to files:
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
```python
def _trace_to_gm(fn: Callable, args: Sequence[torch.Tensor]) -> GraphModule:
    """
    Exports a function or Module into a GraphModule via torch_export_to_gm.
    """
    from torch._guards import detect_fake_mode
    from torch.fx.experimental.proxy_tensor import make_fx

    module = fn if isinstance(fn, torch.nn.Module) else _WrapperModule(fn)

    # Use make_fx for FakeTensors to avoid FakeTensorMode mismatch during pattern matching
    if detect_fake_mode(args) is not None:
        return make_fx(module, tracing_mode="fake")(*args)
```
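For context, here is a minimal standalone sketch (not from the PR; the function and shapes are illustrative) of why this branch avoids the mismatch: tensors created under an active `FakeTensorMode` carry that mode, `detect_fake_mode` finds it, and `make_fx(tracing_mode="fake")` reuses it rather than creating a fresh one the way `torch_export_to_gm` does. This relies on internal PyTorch APIs and reflects behavior observed on recent 2.x releases.

```python
import torch
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx


def pattern(x, scale):
    # Stand-in for a replacement pattern the matcher would trace.
    return torch.mul(x, scale)


fake_mode = FakeTensorMode()
with fake_mode:
    # Factory calls under the mode produce FakeTensors owned by it.
    args = (torch.randn(4, 8), torch.randn(1))

# detect_fake_mode recovers the mode the inputs were created under.
assert detect_fake_mode(args) is fake_mode

# make_fx reuses that detected mode instead of spinning up a new one,
# so no "fake tensors from different modes" assertion fires.
gm = make_fx(pattern, tracing_mode="fake")(*args)
print(gm.graph)
```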
@Fridah-nv, please take a look.
Fridah-nv left a comment:
This is an interesting fix, thank you!
A while ago, when I tried to use make_fx as the trace_fn to trace Python code, I found that the graph output by make_fx is different from torch_export_to_gm/torch.export.export. I wonder if you observed something similar?
Since our initial graph is produced by torch_export_to_gm, there could be more matching failures if the pattern graph is generated with make_fx and is slightly different from the actual pattern in the target graph.
Please double-check whether this can be an issue for other pattern matchers; for `fuse_fp8_linear` it should be fine, as it only contains one op in the pattern.
@Fridah-nv Thanks for taking a look! I think you're right that this could be an issue if we need to maintain strict graph semantics. I was exploring a few larger ideas, but they didn't seem worth the complexity. I'm open to exploring other ideas, though, and would love to hear your thoughts.
@Fridah-nv I ran tests to see if the graphs were different and found that 15/15 patterns produce identical graphs: all replacement functions in …
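(For readers who want to reproduce that kind of check, here is a rough sketch of a comparison harness; it is hypothetical rather than the actual test, and `pattern` plus the op-sequence criterion are illustrative.)

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


class _Wrap(torch.nn.Module):
    """Tiny wrapper so a plain function can be passed to torch.export."""

    def __init__(self, fn):
        super().__init__()
        self._fn = fn

    def forward(self, *args):
        return self._fn(*args)


def pattern(x, w):
    return torch.nn.functional.linear(x, w)


args = (torch.randn(2, 8), torch.randn(4, 8))

gm_fx = make_fx(pattern, tracing_mode="fake")(*args)
gm_export = torch.export.export(_Wrap(pattern), args).module()

# Compare the call_function op sequences, which is roughly what the
# pattern matcher keys on; a differing decomposition would show up here.
ops_fx = [n.target for n in gm_fx.graph.nodes if n.op == "call_function"]
ops_export = [n.target for n in gm_export.graph.nodes if n.op == "call_function"]
print(ops_fx == ops_export)
```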
That sounds good! I think it's fair to accept this fix as long as the unit tests pass. @lucaslie
/bot run
PR_Github #32021 [ run ] triggered by Bot. Commit: …
I believe the … The …
PR_Github #32021 [ run ] completed with state |
[None][fix] Fix FakeTensorMode mismatch in pattern matcher for FP8 models
Description
Problem
FP8 models like `nvidia/Llama-3.1-8B-Instruct-FP8` failed during the `fuse_fp8_linear` transform with the following error:

…

Root Cause
During pattern matching in `fuse_quant.py`, PyTorch's pattern matcher calls `trace_fn(replace_fn, args)` to trace replacement patterns. The `args` are `FakeTensor` objects from the model's `FakeTensorMode`. However, `_trace_to_gm` was calling `torch_export_to_gm`, which creates a new `FakeTensorMode` context. This caused PyTorch's `detect_fake_mode` to assert that all fake tensors must come from the same mode.

Call stack:

…
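The mismatch can be reproduced in isolation with a small sketch (assumes recent PyTorch; `detect_fake_mode` is an internal API): mixing fake tensors from two different modes trips the same kind of assertion.

```python
import torch
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode

mode_a = FakeTensorMode()
mode_b = FakeTensorMode()

with mode_a:
    x = torch.randn(4)
with mode_b:
    y = torch.randn(4)

# detect_fake_mode requires every fake tensor to share a single mode;
# two modes in one input set raise an AssertionError, analogous to the
# failure when torch_export_to_gm creates its own FakeTensorMode.
try:
    detect_fake_mode((x, y))
except AssertionError as e:
    print(f"mismatch detected: {e}")
```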
Solution
Modified `_trace_to_gm` in `pattern_matcher.py` to detect whether the input tensors are `FakeTensor`s and, if so, use `make_fx` (which respects the existing fake mode context) instead of `torch_export_to_gm` (which creates a new context).

Test Coverage
Verified on an H100 system with:

…

Results:

- `fuse_fp8_linear`: 224 pattern matches (previously failed)
- `fuse_allreduce_residual_rmsnorm`: 64 pattern matches

PR Checklist