From 1a03e5d08757604a69fab2c8e84c863cd21b549b Mon Sep 17 00:00:00 2001 From: Lawrence McAfee Date: Wed, 19 Jul 2023 13:43:29 -0700 Subject: [PATCH] Test #2: Memory, timing --- megatron/model/transformer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 61ce2890ae..24278a6d1e 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -27,8 +27,10 @@ try: from flash_attn.flash_attn_interface import flash_attn_unpadded_func except ImportError: - flash_attn_unpadded_func = None - + try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func + except ImportError: + flash_attn_unpadded_func = None """ We use the following notation throughout this file: h: hidden size