Merge pull request #6362 from hiyouga/hiyouga/mllm_packing
[model] generalized packing
hiyouga authored Dec 17, 2024
2 parents 81815f0 + 2d107d3 commit 9708a39
Showing 4 changed files with 10 additions and 57 deletions.
2 changes: 1 addition & 1 deletion src/llamafactory/__init__.py
@@ -30,7 +30,7 @@
longlora:
transformers>=4.41.2,<=4.46.1
packing:
-    transformers>=4.41.2,<=4.46.1
+    transformers>=4.43.0,<=4.46.1
Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1
13 changes: 0 additions & 13 deletions src/llamafactory/extras/constants.py
@@ -81,19 +81,6 @@

STAGES_USE_PAIR_DATA = {"rm", "dpo"}

-SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
-    "cohere",
-    "falcon",
-    "gemma",
-    "gemma2",
-    "llama",
-    "mistral",
-    "phi",
-    "phi3",
-    "qwen2",
-    "starcoder2",
-}

SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
50 changes: 8 additions & 42 deletions src/llamafactory/model/model_utils/packing.py
@@ -44,13 +44,14 @@
from transformers.utils.versions import require_version

from ...extras import logging
-from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.packages import is_transformers_version_greater_than


-if TYPE_CHECKING:
-    from transformers import PretrainedConfig
+if is_transformers_version_greater_than("4.43.0"):
+    import transformers.modeling_flash_attention_utils


+if TYPE_CHECKING:
from ...hparams import ModelArguments


@@ -113,45 +114,10 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
return indices, cu_seqlens, max_seqlen_in_batch


-def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
-    if is_transformers_version_greater_than("4.43.0"):
-        import transformers.modeling_flash_attention_utils
-
-        transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
-        return
-
-    import transformers.models
-
-    if model_type == "cohere":
-        transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
-    elif model_type == "falcon":
-        transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
-    elif model_type == "gemma":
-        transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
-    elif model_type == "gemma2":
-        transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data
-    elif model_type == "llama":
-        transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
-    elif model_type == "mistral":
-        transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data
-    elif model_type == "phi":
-        transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data
-    elif model_type == "phi3":
-        transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data
-    elif model_type == "qwen2":
-        transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data
-    elif model_type == "starcoder2":
-        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data
-
-
-def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.block_diag_attn:
return

-    model_type = getattr(config, "model_type", None)
-    if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
-        _patch_for_block_diag_attn(model_type)
-        logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
-    else:
-        raise ValueError("Current model does not support block diagonal attention.")
+    require_version("transformers>=4.43.0,<=4.46.1", "To fix: pip install transformers>=4.43.0,<=4.46.1")
+    transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
+    logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
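Note: with this change, sequence packing no longer depends on a per-model allow-list; the single `_get_unpad_data` hook in `transformers.modeling_flash_attention_utils` (available since transformers 4.43.0) is monkey-patched once and covers every FlashAttention-2 model. The snippet below is only an illustrative sketch of what such a packing-aware unpad function computes (the body of `get_unpad_data` is not shown in this diff, and `get_unpad_data_sketch` is an assumed name). It assumes the collator labels tokens of the k-th packed sub-sequence with the value k and padding with 0.

# Illustrative sketch only, not the code from this PR.
from typing import Tuple

import torch
import torch.nn.functional as F


def get_unpad_data_sketch(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
    num_segments = int(attention_mask.max())  # packed sub-sequences per row
    per_segment = [(attention_mask == k).sum(dim=-1) for k in range(1, num_segments + 1)]
    # Flatten row by row and drop empty segments.
    seqlens_in_batch = torch.stack(per_segment, dim=-1).flatten()
    seqlens_in_batch = seqlens_in_batch[seqlens_in_batch > 0].to(torch.int32)
    # Indices of non-padding tokens in the flattened batch.
    indices = torch.nonzero(attention_mask.flatten() > 0, as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # cu_seqlens is the prefix sum of segment lengths with a leading zero,
    # the boundary format consumed by FlashAttention's varlen kernels.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch


# Example: one row packing two sequences of lengths 3 and 2, plus padding.
mask = torch.tensor([[1, 1, 1, 2, 2, 0]])
print(get_unpad_data_sketch(mask))
# (tensor([0, 1, 2, 3, 4]), tensor([0, 3, 5], dtype=torch.int32), 3)

Because the cu_seqlens boundaries stop at each packed sub-sequence, attention is effectively block diagonal: tokens never attend across the sequences packed into the same row.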
2 changes: 1 addition & 1 deletion src/llamafactory/model/patcher.py
@@ -96,7 +96,7 @@ def patch_config(
configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_moe(config, model_args, is_trainable)
configure_visual_model(config)
-    configure_packing(config, model_args, is_trainable)
+    configure_packing(model_args, is_trainable)

if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True)
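For context, the patched hook only takes effect when the data collator emits the segment-labelled attention mask described above. That collator lives in LLaMA-Factory's data pipeline and is not touched by this commit; the helper below is a hypothetical sketch (assumed name `pack_attention_mask`) of how such a mask can be built from the lengths of the sequences packed into one example.

# Hypothetical helper, not part of this PR: build the segment-labelled
# attention mask that a packing-aware _get_unpad_data expects.
import torch


def pack_attention_mask(lengths: list, cutoff_len: int) -> torch.Tensor:
    """Label tokens of the i-th packed sequence with the value i; 0 means padding."""
    mask = torch.zeros(cutoff_len, dtype=torch.long)
    offset = 0
    for i, length in enumerate(lengths, start=1):
        if offset + length > cutoff_len:
            break  # a real collator would start a new packed example instead
        mask[offset : offset + length] = i
        offset += length
    return mask


print(pack_attention_mask([3, 2], cutoff_len=8))
# tensor([1, 1, 1, 2, 2, 0, 0, 0])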
