Fix transformer 2.9.0 (torch 2.9.1 used by SGLang 0.5.5) build #2445
base: release_v2.9
Conversation
Greptile Overview

Greptile Summary

This PR fixes build issues for Transformer Engine 2.9.0 when used with torch 2.9.1 (as required by SGLang 0.5.5). The changes address three main issues:

- a circular include between common.h and cuda_driver.h, fixed by moving the transformer_engine.h include after the local headers;
- the `assert te_installed_via_pypi` installation check in the Python __init__.py, which is commented out to allow non-PyPI installations;
- the flash attention 3 imports in backends.py, which gain a fallback to flash_attn_interface when flash_attn_3 cannot be imported.

The C++ header fixes are solid. However, the Python changes weaken installation safety checks, and the flash attention fallback logic appears incomplete: the bare flash_attn_interface import will raise ModuleNotFoundError if that module is not available at the top level.

Confidence Score: 3/5
Important Files Changed

File Analysis
Sequence Diagram

```mermaid
sequenceDiagram
    participant Build as Build System
    participant PyInit as Python __init__.py
    participant CommonH as common.h
    participant CUDADriver as cuda_driver.h
    participant NVTE as transformer_engine.h
    participant Backends as backends.py
    participant FlashAttn as Flash Attention

    Note over Build,FlashAttn: Transformer Engine 2.9.0 Build Process

    Build->>CommonH: Compile C++ headers
    CommonH->>CUDADriver: #include cuda_driver.h
    CUDADriver->>CommonH: #include common.h (circular!)
    Note over CommonH,CUDADriver: Fixed: Move NVTE include after local headers
    CommonH->>NVTE: #include transformer_engine.h (reordered)

    Build->>PyInit: Import transformer_engine
    PyInit->>PyInit: Check package installation
    Note over PyInit: Commented out: assert te_installed_via_pypi
    PyInit->>PyInit: Allow non-PyPI installations

    Build->>Backends: Import backends.py
    Backends->>FlashAttn: try: from flash_attn_3...
    alt flash_attn_3 module exists
        FlashAttn-->>Backends: Import successful
    else ModuleNotFoundError
        Note over Backends: Print error message
        Backends->>FlashAttn: Fallback: from flash_attn_interface...
        FlashAttn-->>Backends: Import may fail
    end
```
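For the "Check package installation" step above, here is a rough sketch of how a flag like te_installed_via_pypi could be derived from package metadata. It is illustrative only: the distribution names transformer-engine and transformer-engine-torch are assumptions, and the real __init__.py logic may differ.

```python
from importlib.metadata import PackageNotFoundError, version


def _dist_version(name):
    """Return the installed distribution's version, or None if it is absent."""
    try:
        return version(name)
    except PackageNotFoundError:
        return None


# Assumed distribution names for illustration; the packages actually checked
# by Transformer Engine's __init__.py may differ.
core_version = _dist_version("transformer-engine")
torch_ext_version = _dist_version("transformer-engine-torch")

# "Installed via PyPI with matching versions" is approximated here as: both
# distributions appear in the environment's package metadata and report the
# same version string.
te_installed_via_pypi = (
    core_version is not None
    and torch_ext_version is not None
    and core_version == torch_ext_version
)
```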
4 files reviewed, 2 comments
```diff
 # extension are all installed via PyPI and have matching versions.
 if te_framework_installed:
-    assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
+    # assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
```
style: Commenting out this validation weakens installation safety checks. Users with mismatched installations (e.g., PyPI framework package + source-built core) may encounter runtime issues. Consider a more targeted fix that specifically allows the SGLang 0.5.5 scenario while keeping validation for other cases.
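One possible shape for such a targeted fix, as a minimal sketch: keep the assert but add an explicit opt-out for source builds. The NVTE_ALLOW_NON_PYPI_INSTALL environment variable and the stubbed flags below are hypothetical, not part of the actual patch.

```python
import os

# `te_framework_installed` / `te_installed_via_pypi` are computed earlier in
# transformer_engine/__init__.py; hard-coded here only so the sketch runs
# standalone.
te_framework_installed = True
te_installed_via_pypi = True

# Hypothetical escape hatch (NVTE_ALLOW_NON_PYPI_INSTALL is not a real flag):
# keep the safety check by default, but let source builds such as the
# SGLang 0.5.5 setup opt out explicitly instead of deleting the assert.
allow_non_pypi = os.getenv("NVTE_ALLOW_NON_PYPI_INSTALL", "0") == "1"

if te_framework_installed and not allow_non_pypi:
    assert te_installed_via_pypi, (
        "Could not find `transformer-engine` PyPI package. "
        "Set NVTE_ALLOW_NON_PYPI_INSTALL=1 to skip this check for source builds."
    )
```

This would keep the safety check for normal PyPI installations while giving the SGLang source-build scenario an explicit, documented escape hatch.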
```python
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_interface import (
    flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
from flash_attn_interface import (
    flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
)
from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
```
logic: This fallback will fail if the flash_attn_interface module doesn't exist. The flash-attn-3 package was detected, but the imports from flash_attn_3.flash_attn_interface failed. If the goal is to support an alternative location such as flash_attn.flash_attn_interface (from flash-attn v2), these imports should use flash_attn.flash_attn_interface instead of bare flash_attn_interface. Without the proper module path, they will raise ModuleNotFoundError and the _v3 variables will remain undefined, causing issues later.
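A minimal sketch of the guarded fallback the comment is suggesting, assuming the alternative location is flash_attn.flash_attn_interface from flash-attn v2; the exact module layout depends on the installed flash-attn version, and the symbol names mirror the diff above.

```python
# Sketch only: assumes flash_attn.flash_attn_interface (flash-attn v2) as the
# fallback location; adjust to whichever layout is actually installed.
try:
    # Preferred path: flash-attn 3 layout.
    from flash_attn_3.flash_attn_interface import (
        flash_attn_func as flash_attn_func_v3,
        flash_attn_varlen_func as flash_attn_varlen_func_v3,
        flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
    )
except ModuleNotFoundError:
    try:
        # Fallback: full module path, not the bare `flash_attn_interface`
        # used in the diff above.
        from flash_attn.flash_attn_interface import (
            flash_attn_func as flash_attn_func_v3,
            flash_attn_varlen_func as flash_attn_varlen_func_v3,
            flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
        )
    except ModuleNotFoundError:
        # Set the symbols to None so later code can check availability
        # instead of hitting a NameError on an undefined name.
        flash_attn_func_v3 = None
        flash_attn_varlen_func_v3 = None
        flash_attn_with_kvcache_v3 = None
```

Falling back to None-valued symbols makes the failure explicit at the call site rather than surfacing as an import-time crash or an undefined name.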
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: