Update FA version to 2.5.6
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
ksivaman committed Mar 11, 2024
1 parent 8255f87 commit 829ce91
Showing 2 changed files with 5 additions and 1 deletion.
setup.py — 2 changes: 1 addition & 1 deletion
@@ -265,7 +265,7 @@ def add_unique(l: List[str], vals: Union[str, List[str]]) -> None:

     # Framework-specific requirements
     if "pytorch" in frameworks():
-        add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"])
+        add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.5.6,!=2.0.9,!=2.1.0"])
         add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
     if "jax" in frameworks():
         if not found_pybind11():
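For context (not part of the commit), the widened requirement string can be sanity-checked with the packaging library, which pip relies on for version matching. A minimal sketch, assuming packaging is installed:

# Illustrative only: how the updated flash-attn requirement behaves.
from packaging.specifiers import SpecifierSet
from packaging.version import Version

spec = SpecifierSet(">=2.0.6,<=2.5.6,!=2.0.9,!=2.1.0")
for candidate in ["2.0.6", "2.0.9", "2.1.0", "2.4.2", "2.5.6", "2.5.7"]:
    allowed = Version(candidate) in spec
    print(candidate, "allowed" if allowed else "rejected")
# 2.0.9, 2.1.0, and 2.5.7 are rejected; the other versions are accepted.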
transformer_engine/pytorch/attention.py — 4 changes: 4 additions & 0 deletions
@@ -58,6 +58,7 @@

 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_version_required = packaging.version.Version("2.0.6")
+_flash_attn_max_version = packaging.version.Version("2.5.6")
 _flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
 _flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")
 _flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4")
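As an aside (not in the diff), here is a hedged sketch of how the module-level flags above evaluate for a hypothetical installed flash-attn 2.5.6, using the same packaging.version comparisons:

# Illustrative only: assumes flash-attn 2.5.6 is the installed version.
import packaging.version

installed = packaging.version.Version("2.5.6")
print(installed >= packaging.version.Version("2.1"))    # True  -> _flash_attn_2_1_plus
print(installed >= packaging.version.Version("2.3"))    # True  -> _flash_attn_2_3_plus
print(installed >= packaging.version.Version("2.4"))    # True  -> _flash_attn_2_4_plus
print(installed <= packaging.version.Version("2.5.6"))  # True  -> within the new maximum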
@@ -1656,6 +1657,9 @@ def __init__(
         assert (
             _flash_attn_version >= _flash_attn_version_required
         ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
+        assert (
+            _flash_attn_version <= _flash_attn_max_version
+        ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."

         self.norm_factor = norm_factor
         self.attention_dropout_ctx = attention_dropout_ctx
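Outside the diff, the new guard can be read as a standalone check: the installed flash-attn version must fall within [2.0.6, 2.5.6]. A minimal sketch of the equivalent logic, assuming the imports used at the top of attention.py; the helper name is hypothetical:

# Hedged sketch, not the module's actual API: the same min/max check
# expressed as a standalone helper.
from importlib.metadata import version
import packaging.version

def check_flash_attn_version() -> None:  # hypothetical helper name
    installed = packaging.version.Version(version("flash-attn"))
    required = packaging.version.Version("2.0.6")
    maximum = packaging.version.Version("2.5.6")
    assert installed >= required, (
        f"FlashAttention minimum version {required} is required."
    )
    assert installed <= maximum, (
        f"FlashAttention maximum version {maximum} is supported."
    )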
