Fall back to flash_attn_varlen_func when flash_attn_unpadded_func is missing.

If importing flash_attn_unpadded_func from flash_attn.flash_attn_interface
raises ImportError, try importing flash_attn_varlen_func from the same module
under the old name. Only when both imports fail is flash_attn_unpadded_func
set to None, so existing call sites keep using the single name.

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Line structure follows the hunk header
(-27,8 +27,10); leading whitespace on context/added lines is reconstructed
and should be confirmed against the target file (or applied with
`git apply --ignore-whitespace`).

diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 61ce2890ae..24278a6d1e 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -27,8 +27,10 @@
 try:
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
 except ImportError:
-    flash_attn_unpadded_func = None
-
+    try:
+        from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
+    except ImportError:
+        flash_attn_unpadded_func = None
 
 """ We use the following notation throughout this file:
      h: hidden size