-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpre_model_convert.py
executable file
·39 lines (33 loc) · 1.03 KB
/
pre_model_convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import torch
import time
from safetensors.torch import save_model, load_model
from transformers import AutoModelForCausalLM
# Root directory scanned by the __main__ block below; every subdirectory
# found under it (at any depth, via os.walk) is treated as one Hugging Face
# checkpoint directory to convert to safetensors.
DIR_PATH = "checkpoints/baichuan2_7b_checkpoints"
def hfmodel_torch_2_safetensors(
    model_dir: str,
    filename: str = "model.safetensors",
) -> None:
    """Load a Hugging Face torch checkpoint and re-save it as safetensors.

    Loads the model found in ``model_dir`` with ``AutoModelForCausalLM`` and
    writes its weights to ``model_dir/filename``. Timing for the load and
    save phases is printed to stdout.

    Args:
        model_dir: Directory containing a Hugging Face model checkpoint.
        filename: Name of the safetensors file to write inside ``model_dir``
            (default: ``"model.safetensors"``).
    """
    tmp_time = time.time()
    # Use bf16 to halve memory during conversion when a bf16-capable GPU is
    # present; guard with is_available() so CPU-only hosts fall back to fp32
    # instead of erroring inside is_bf16_supported().
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=(
            torch.bfloat16
            if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            else torch.float32
        ),
    )
    print("loaded: ", time.time() - tmp_time)
    tmp_time = time.time()
    # save_model (rather than safetensors' save_file) correctly handles
    # models with shared/tied tensors.
    save_model(
        model=model,
        filename=os.path.join(model_dir, filename),
    )
    print("safetensors save time:", time.time() - tmp_time)
    # Free the model and release cached GPU memory before the caller moves
    # on to the next checkpoint directory.
    del model
    torch.cuda.empty_cache()
if __name__ == "__main__":
    # Recursively walk the checkpoint tree and convert each directory
    # encountered beneath DIR_PATH.
    for parent, subdir_names, _files in os.walk(DIR_PATH):
        for subdir_name in subdir_names:
            checkpoint_dir = os.path.join(parent, subdir_name)
            print(checkpoint_dir)
            hfmodel_torch_2_safetensors(model_dir=checkpoint_dir)