qllm/modeling/base.py (33 additions, 10 deletions)
@@ -4,6 +4,7 @@
 from transformers import AutoModelForCausalLM
 from pathlib import Path
 import tqdm
+import functools
 import glob
 import json
 import contextlib
@@ -46,6 +47,14 @@ def no_init_weights(attrs: list = None):
         if old_attr[idx] is not None:
             setattr(torch.Tensor, attr, old_attr[idx])

+@contextlib.contextmanager
+def patch_cache_file_in_parallel():
+    old_cached_file_func = transformers.utils.hub.cached_file
+    old_get_checkpoint_shard_files_func = transformers.utils.hub.get_checkpoint_shard_files
+    yield
+    transformers.utils.hub.cached_file = old_cached_file_func
+    transformers.utils.hub.get_checkpoint_shard_files = old_get_checkpoint_shard_files_func
+
 def get_no_split_layer_type_name(model:torch.nn.Module):
     try:
         return model._get_no_split_modules("auto")
@@ -92,6 +101,16 @@ def _get_resolved_weight_or_index_file(model_name_or_path):
     return str(weight_or_index_file)


+def parallel_download_decorator(task_func_shard, *args, **kwargs):
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        def cached_file_func_in_thread(task_func, *args, **kwargs):
+            return executor.submit(task_func, *args, **kwargs)
+        transformers.utils.hub.cached_file = functools.partial(cached_file_func_in_thread, transformers.utils.hub.cached_file)
+        result = task_func_shard(*args, **kwargs)
+        result_0 = [future.result() for future in result[0]]
+        return result_0, result[1]
+
+
 def _load_check_point(model, model_name_or_path, get_keys_only: bool = False):
     weight_or_index_file = _get_resolved_weight_or_index_file(model_name_or_path)
     all_keys = set()
@@ -183,16 +202,20 @@ def from_pretrained(

         torch_dtype = kwargs.pop("torch_dtype", auto_conf.torch_dtype)

-        llm = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path,
-            torch_dtype=torch_dtype,
-            trust_remote_code=trust_remote_code,
-            attn_implementation=attn_implementation,
-            # device_map="auto",
-            # low_cpu_mem_usage=True,
-            # max_memory={0: 1*1024 * 1024 * 1024, "cpu": 5*1024 * 1024 * 1024},
-            # offload_folder="/tmp/a2"
-        )
+        with patch_cache_file_in_parallel():
+            transformers.utils.hub.get_checkpoint_shard_files = functools.partial(
+                parallel_download_decorator, transformers.utils.hub.get_checkpoint_shard_files
+            )
+            llm = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path,
+                torch_dtype=torch_dtype,
+                trust_remote_code=trust_remote_code,
+                attn_implementation=attn_implementation,
+                # device_map="auto",
+                # low_cpu_mem_usage=True,
+                # max_memory={0: 1*1024 * 1024 * 1024, "cpu": 5*1024 * 1024 * 1024},
+                # offload_folder="/tmp/a2"
+            )
         return llm

     @classmethod
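For context on what the merged change does: cached_file in transformers.utils.hub is wrapped so that, while shard files are being resolved, each call is submitted to a ThreadPoolExecutor and returns a Future instead of blocking on the download; get_checkpoint_shard_files is wrapped so the list of Futures it collects is resolved back into local file paths before returning; and patch_cache_file_in_parallel restores the original helpers after loading. Below is a minimal, self-contained sketch of the same pattern, not the merged code: the parallel_shard_downloads name and max_workers parameter are illustrative, and the restore steps here are wrapped in try/finally so the originals come back even if loading raises.

import concurrent.futures
import contextlib

from transformers.utils import hub


@contextlib.contextmanager
def parallel_shard_downloads(max_workers=8):
    # Illustrative helper (not part of the PR): patch the hub helpers so that
    # checkpoint shards are fetched through a thread pool, then restore them.
    orig_cached_file = hub.cached_file
    orig_get_shard_files = hub.get_checkpoint_shard_files

    def patched_get_shard_files(*args, **kwargs):
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
            # While the original helper runs, every cached_file call is submitted
            # to the pool and yields a Future instead of a local file path.
            hub.cached_file = lambda *a, **kw: pool.submit(orig_cached_file, *a, **kw)
            try:
                filenames, sharded_metadata = orig_get_shard_files(*args, **kwargs)
            finally:
                hub.cached_file = orig_cached_file
        # Resolve Futures back into the file paths callers expect; local-folder
        # checkpoints return plain paths, so pass those through unchanged.
        resolved = [f.result() if isinstance(f, concurrent.futures.Future) else f
                    for f in filenames]
        return resolved, sharded_metadata

    hub.get_checkpoint_shard_files = patched_get_shard_files
    try:
        yield
    finally:
        # Always restore the originals, even if model loading raises.
        hub.cached_file = orig_cached_file
        hub.get_checkpoint_shard_files = orig_get_shard_files


# Usage sketch (needs network access; the model id is only an example).
# Note: some transformers versions bind these helpers at import time in
# transformers.modeling_utils, so the patch may need to be applied there too.
#
# from transformers import AutoModelForCausalLM
# with parallel_shard_downloads():
#     model = AutoModelForCausalLM.from_pretrained("facebook/opt-6.7b")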