Conversation

@yiakwy-xpu-ml-framework-team

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team changed the base branch from main to release_v2.9 December 2, 2025 22:57
@greptile-apps

greptile-apps bot commented Dec 2, 2025

Greptile Overview

Greptile Summary

This PR fixes build issues for Transformer Engine 2.9.0 when used with torch 2.9.1 (as required by SGLang 0.5.5). The changes address four issues:

  • C++ circular dependency: Reordered header includes in common.h to break the circular dependency between common.h and cuda_driver.h
  • Missing include: Added a cuda_driver.h include in nvshmem_waitkernel.cu for the NVTE_CHECK macro
  • PyPI validation: Commented out assertions that prevented non-PyPI installations from working
  • Flash Attention 3 fallback: Added fallback imports for cases where the flash_attn_3.flash_attn_interface module structure differs

The C++ header fixes are solid. However, the Python changes weaken installation safety checks, and the Flash Attention fallback logic appears incomplete: the bare flash_attn_interface import will likely fail without a parent module path.
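
As a quick way to check this concern, the snippet below probes which flash-attn-3 layout, if any, an environment actually exposes. It is a minimal sketch: the two module paths are the ones discussed in this review, and locate_fa3_interface is just an illustrative helper name.

import importlib.util

def locate_fa3_interface():
    """Return the first importable flash-attn-3 interface module path, or None."""
    for name in ("flash_attn_3.flash_attn_interface", "flash_attn_interface"):
        try:
            if importlib.util.find_spec(name) is not None:
                return name
        except ModuleNotFoundError:
            # Raised when the parent package (e.g. flash_attn_3) is not installed at all.
            continue
    return None

print(locate_fa3_interface())  # None means neither layout is importable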

Confidence Score: 3/5

  • This PR is moderately safe to merge with known risks in Python import handling
  • The C++ header fixes are correct and necessary (circular dependency resolution and missing include). However, the Python changes have issues: (1) commenting out the PyPI validation weakens installation safety without a targeted fix, potentially allowing mismatched installations, and (2) the Flash Attention fallback logic appears broken, since importing from bare flash_attn_interface without a parent module path will fail. The PR may work for the specific SGLang use case but could cause issues in other scenarios.
  • Pay close attention to transformer_engine/pytorch/attention/dot_product_attention/backends.py - the fallback import logic needs verification

Important Files Changed

File Analysis

Filename | Score | Overview
transformer_engine/common/__init__.py | 3/5 | Commented out PyPI package validation checks to allow non-PyPI installations
transformer_engine/common/common.h | 5/5 | Fixed circular include dependency by reordering header includes
transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu | 5/5 | Added missing cuda_driver.h include required for NVTE_CHECK macro
transformer_engine/pytorch/attention/dot_product_attention/backends.py | 2/5 | Added fallback import for flash-attn-3 with non-standard module structure, but imports may fail at runtime

Sequence Diagram

sequenceDiagram
    participant Build as Build System
    participant PyInit as Python __init__.py
    participant CommonH as common.h
    participant CUDADriver as cuda_driver.h
    participant NVTE as transformer_engine.h
    participant Backends as backends.py
    participant FlashAttn as Flash Attention

    Note over Build,FlashAttn: Transformer Engine 2.9.0 Build Process

    Build->>CommonH: Compile C++ headers
    CommonH->>CUDADriver: #include cuda_driver.h
    CUDADriver->>CommonH: #include common.h (circular!)
    Note over CommonH,CUDADriver: Fixed: Move NVTE include after local headers
    CommonH->>NVTE: #include transformer_engine.h (reordered)
    
    Build->>PyInit: Import transformer_engine
    PyInit->>PyInit: Check package installation
    Note over PyInit: Commented out: assert te_installed_via_pypi
    PyInit->>PyInit: Allow non-PyPI installations
    
    Build->>Backends: Import backends.py
    Backends->>FlashAttn: try: from flash_attn_3...
    alt flash_attn_3 module exists
        FlashAttn-->>Backends: Import successful
    else ModuleNotFoundError
        Note over Backends: Print error message
        Backends->>FlashAttn: Fallback: from flash_attn_interface...
        FlashAttn-->>Backends: Import may fail
    end


@greptile-apps greptile-apps bot left a comment

4 files reviewed, 2 comments

  # extension are all installed via PyPI and have matching versions.
  if te_framework_installed:
-     assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
+     # assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Commenting out this validation weakens installation safety checks. Users with mismatched installations (e.g., PyPI framework package + source-built core) may encounter runtime issues. Consider a more targeted fix that specifically allows the SGLang 0.5.5 scenario while keeping validation for other cases.
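
One possible shape for such a targeted fix is sketched below, assuming te_framework_installed and te_installed_via_pypi are the variables shown in the diff above; NVTE_ALLOW_NON_PYPI_INSTALL is a hypothetical environment variable name, not an existing Transformer Engine flag.

import os
import warnings

if te_framework_installed:
    if not te_installed_via_pypi:
        # Hypothetical opt-out for intentionally mixed installs (e.g. a source-built
        # core alongside a PyPI framework package, as in the SGLang 0.5.5 setup).
        if os.getenv("NVTE_ALLOW_NON_PYPI_INSTALL") == "1":
            warnings.warn(
                "Could not find `transformer-engine` PyPI package; continuing "
                "because NVTE_ALLOW_NON_PYPI_INSTALL=1."
            )
        else:
            raise RuntimeError("Could not find `transformer-engine` PyPI package.")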

Comment on lines +145 to +153
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_interface import (
    flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
from flash_attn_interface import (
    flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
)
from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3

logic: This fallback will fail if the flash_attn_interface module doesn't exist. The flash-attn-3 package was detected, but imports from flash_attn_3.flash_attn_interface failed. If the goal is to support an alternative location such as flash_attn.flash_attn_interface (from flash-attn v2), the imports should use flash_attn.flash_attn_interface rather than bare flash_attn_interface. Without the proper module path, these imports will raise ModuleNotFoundError and the variables will remain undefined, causing issues later.
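
A more defensive version of the fallback is sketched below; it keeps the *_v3 names defined even when no flash-attn-3 layout is importable. The module paths tried and the _flash_attn_3_is_installed flag are illustrative, not taken from the PR or from Transformer Engine.

_flash_attn_3_is_installed = True
try:
    from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
    from flash_attn_3.flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_varlen_func_v3,
    )
except ModuleNotFoundError:
    try:
        # Fall back to a top-level flash_attn_interface module, if one is installed.
        from flash_attn_interface import flash_attn_func as flash_attn_func_v3
        from flash_attn_interface import (
            flash_attn_varlen_func as flash_attn_varlen_func_v3,
        )
    except ModuleNotFoundError:
        # Neither layout is available: record that and leave usable sentinels
        # instead of undefined names.
        _flash_attn_3_is_installed = False
        flash_attn_func_v3 = None
        flash_attn_varlen_func_v3 = None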
