
Commit ea20cc2

[Model] enable glm4-9b (#291)
* enable glm4-9b
* Update __init__.py
* Update __init__.py
* add glm4 extension test
* update huggingface.py
* Update huggingface.py
* update huggingface.py
* update huggingface

Signed-off-by: intellinjun <[email protected]>
1 parent d4e5289 commit ea20cc2

6 files changed (+171, -4 lines)

docs/supported_models.md

Lines changed: 2 additions & 1 deletion
@@ -235,7 +235,8 @@ Neural Speed supports the following models:
     <tr>
       <td><a href="https://huggingface.co/THUDM/chatglm-6b" target="_blank" rel="noopener noreferrer">ChatGLM-6B</a>,
       <a href="https://huggingface.co/THUDM/chatglm2-6b" target="_blank" rel="noopener noreferrer">ChatGLM2-6B</a>,
-      <a href="https://huggingface.co/THUDM/chatglm3-6b" target="_blank" rel="noopener noreferrer">ChatGLM3-6B</a></td>
+      <a href="https://huggingface.co/THUDM/chatglm3-6b" target="_blank" rel="noopener noreferrer">ChatGLM3-6B</a>,
+      <a href="https://huggingface.co/THUDM/glm-4-9b" target="_blank" rel="noopener noreferrer">GLM-4-9B</a></td>
       <td>✅</td>
       <td> </td>
       <td> </td>

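With GLM-4-9B listed as supported, the model can be exercised the same way as the other ChatGLM-family entries. The sketch below is not part of the commit; it follows the usage pattern from the Neural Speed README, and the `weight_dtype`/`compute_dtype` values are illustrative rather than mandated by this change.

```python
# Minimal sketch, assuming the Model API shown in the Neural Speed README.
from transformers import AutoTokenizer, TextStreamer
from neural_speed import Model

model_name = "THUDM/glm-4-9b"
prompt = "What is the capital of France?"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
inputs = tokenizer(prompt, return_tensors="pt").input_ids
streamer = TextStreamer(tokenizer)

model = Model()
# Converts and quantizes the checkpoint on first use; dtypes are examples only.
model.init(model_name, weight_dtype="int4", compute_dtype="int8")
outputs = model.generate(inputs, streamer=streamer, max_new_tokens=128)
```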
neural_speed/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -95,7 +95,10 @@ def _get_model_type(model_config):
     if model_type == "chatglm" and "chatglm3" in model_config._name_or_path:
         # due to the same model architecture.
         model_type = "chatglm2"
-
+    # For GLM4
+    if model_type == "chatglm" and "glm-4" in model_config._name_or_path:
+        # due to the same model architecture.
+        model_type = "chatglm2"
     # for TheBloke/falcon-40b-instruct-GPTQ & TheBloke/Falcon-7B-Instruct-GPTQ
     if model_type == "RefinedWebModel" or model_type == "RefinedWeb":
         model_type = "falcon"

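For reference, the new branch keys on the GLM-4 config reporting `model_type == "chatglm"` while the model path contains `glm-4`, as the diff above shows. A small illustrative check (not part of the commit; `trust_remote_code` is required for THUDM configs):

```python
# Illustrative only: inspects the config fields the remapping above relies on.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("THUDM/glm-4-9b", trust_remote_code=True)
print(config.model_type)      # "chatglm"
print(config._name_or_path)   # contains "glm-4", so Neural Speed treats it as "chatglm2"
```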
neural_speed/convert/convert_chatglm.py

Lines changed: 146 additions & 1 deletion
@@ -524,6 +524,148 @@ def write_vocab_gguf(dir_model):
     print("Done. Output file: " + fname_out)
     print("")
 
+def chatglm4_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
+    print("GLM-4 converting: ")
+    list_vars = model.state_dict()
+    for name in list_vars.keys():
+        print(name, list_vars[name].shape, list_vars[name].dtype)
+
+    fout = open(fname_out, "wb")
+
+    print(hparams)
+
+    fout.write(struct.pack("i", 0x67676d66))
+    fout.write(struct.pack("i", 1))
+
+    fout.write(struct.pack("i", hparams["padded_vocab_size"]))
+    fout.write(struct.pack("i", hparams["hidden_size"]))
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", hparams["num_attention_heads"]))
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", hparams["num_layers"]))
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", ftype))
+    fout.write(struct.pack("i", hparams["seq_length"]))
+    fout.write(struct.pack("f", 0))
+    fout.write(struct.pack("f", 0))
+    fout.write(struct.pack("i", 0))
+
+    fout.write(struct.pack("i", 0))  # word_embed_proj_dim (for opt)
+    fout.write(struct.pack("i", 0))  # do_layer_norm_before (for opt)
+
+    fout.write(struct.pack("i", hparams["multi_query_group_num"]))
+    fout.write(struct.pack("i", hparams["ffn_hidden_size"]))
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", 0))  # n_experts
+    fout.write(struct.pack("i", 0))  # n_expert_used
+    fout.write(struct.pack("i", 0))  # n_embd_head_k for gemma
+    fout.write(struct.pack("f", hparams.get("layernorm_epsilon", 1e-5)))  # rms_norm_eps or layer_norm_eps
+    fout.write(struct.pack("f", 10000.0))  # freq_base
+    fout.write(struct.pack("f", hparams.get("rope_ratio", 1)))  # rope_factor
+
+    fout.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
+    fout.write(struct.pack("i", 0))  # rope_scaling.original_max_position_embeddings
+    fout.write(struct.pack("i", 0))  # params["rope_scaling"]["type"] =="yarn" else 0))
+
+    fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
+    fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
+    fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
+    fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
+
+
+    for i in range(hparams["vocab_size"]):
+        if i < tokenizer.vocab_size:
+            text = tokenizer.decode([i]).encode('utf-8')
+            fout.write(struct.pack("i", len(text)))
+            fout.write(text)
+            fout.write(struct.pack("f", 0.0 - i))
+        else:
+            text = tokenizer.decode([tokenizer.vocab_size - 1]).encode('utf-8')
+            fout.write(struct.pack("i", len(text)))
+            fout.write(text)
+            fout.write(struct.pack("f", -10000))
+
+    for name in list_vars.keys():
+        data = list_vars[name].float().squeeze().numpy()
+        data = data.astype(np.float32)
+        if name == "transformer.rotary_pos_emb.inv_freq":
+            continue
+        # No gradients for these
+
+        n_dims = len(data.shape)
+        print(name, n_dims, data.shape)
+
+        # default type is fp32
+        ftype_cur = 0
+        if ftype == 1 and n_dims > 1:
+            print(" Converting to float16", data.shape, data[:3, :3].tolist())
+            data = data.astype(np.float16)
+            ftype_cur = 1
+        else:
+            print(" Converting to float32", data.shape, data[:3, :3].tolist() if n_dims > 1 else data[:3].tolist())
+            data = data.astype(np.float32)
+
+        # header
+        str = name.encode('utf-8')
+        fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+        for i in range(n_dims):
+            fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+        print(str)
+        fout.write(str)
+
+        # data
+        data.tofile(fout)
+        if "mlp.dense_h_to_4h" in name:
+            name_0 = name.replace("dense_h_to_4h", "dense_h_to_4h_0")
+            name_1 = name.replace("dense_h_to_4h", "dense_h_to_4h_1")
+            shape_0 = data.shape[0]
+            half_shape_0 = int(shape_0 / 2)
+            data_0 = data[0:half_shape_0, :]
+            data_1 = data[half_shape_0:shape_0, :]
+
+            print("Converting: %-75s" % name_0, " shape: ", data_0.shape)
+            print("Converting: %-75s" % name_1, " shape: ", data_1.shape)
+
+            n_dims = len(data_0.shape)
+            assert (len(data_0.shape) == len(data_1.shape))
+            # ftype == 0 -> float32, ftype == 1 -> float16
+            ftype_cur = 0
+            if ftype != 0:
+                if name_0[-7:] == ".weight" and n_dims == 2:
+                    print(" to float16".rjust(15))
+                    data_0 = data_0.astype(np.float16)
+                    data_1 = data_1.astype(np.float16)
+                    ftype_cur = 1
+                else:
+                    print(" to float32".rjust(15))
+                    data_0 = data_0.astype(np.float32)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 0
+            else:
+                if data_0.dtype != np.float32:
+                    print(" to float32".rjust(15))
+                    data_0 = data_0.astype(np.float32)
+                    data_1 = data_1.astype(np.float32)
+                    ftype_cur = 0
+
+            str_0 = name_0.encode("utf-8")
+            fout.write(struct.pack("iii", n_dims, len(str_0), ftype_cur))
+            for i in range(n_dims):
+                fout.write(struct.pack("i", data_0.shape[n_dims - 1 - i]))
+            fout.write(str_0)
+            data_0.tofile(fout)
+
+            str_1 = name_1.encode("utf-8")
+            fout.write(struct.pack("iii", n_dims, len(str_1), ftype_cur))
+            for i in range(n_dims):
+                fout.write(struct.pack("i", data_1.shape[n_dims - 1 - i]))
+            fout.write(str_1)
+            data_1.tofile(fout)
+
+    fout.close()
+
+    print("Done. Output file: " + fname_out)
+    print("")
 
 def chatglm3_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
     print("ChatGLM-3 converting: ")
@@ -973,7 +1115,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
     # ChatGLM3 shares the same architecture and model config with ChatGLM2
     # but its tokenizer further supports system prompts,
     # so we can check system token to discriminate ChatGLM3 from ChatGLM2.
-    if hasattr(tokenizer, "tokenizer") and "<|system|>" in tokenizer.tokenizer.special_tokens:
+    # For GLM4-9B
+    if model.config.num_layers == 40:
+        chatglm4_convert(model, tokenizer, dir_model, fname_out, ftype, hparams)
+    elif hasattr(tokenizer, "tokenizer") and "<|system|>" in tokenizer.tokenizer.special_tokens:
         if args.format == "GGUF":
             chatglm3_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams)
         else:

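The converter writes a flat binary: a fixed header packed with `struct.pack`, then the vocabulary, then the tensors. A quick sanity check on a converted file is to read back the leading header fields in the same order they are written above. This sketch is not part of the commit, and the file name is hypothetical.

```python
# Reads back the first header fields written by chatglm4_convert, in the same
# order as the struct.pack calls in the diff above. File path is hypothetical.
import struct

def peek_header(path="glm4-ne-f32.bin"):
    with open(path, "rb") as f:
        magic, version = struct.unpack("2i", f.read(8))
        vocab, hidden, _, n_head, _, n_layer = struct.unpack("6i", f.read(24))
    assert magic == 0x67676D66, "unexpected magic number"
    print(f"version={version} vocab={vocab} hidden={hidden} heads={n_head} layers={n_layer}")

# peek_header("glm4-ne-f32.bin")
```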
neural_speed/models/chatglm/chatglm2.h

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,12 @@ static const model_scratch chatglm_mem_req(int n_layers, float scratch_size_rati
           static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
           static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
       };
+    case 40:
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+      };
     default:
       MODEL_ASSERT(false);
   }

scripts/huggingface.py

Lines changed: 6 additions & 0 deletions
@@ -401,6 +401,10 @@ def __init__(
         if self.model_type == "chatglm" and "chatglm3" in self._config._name_or_path:
             # due to the same model architecture.
             self.model_type = "chatglm2"
+        # For GLM4
+        if self.model_type == "chatglm" and "glm-4" in self._config._name_or_path:
+            # due to the same model architecture.
+            self.model_type = "chatglm2"
 
     @property
     def config(self):
@@ -594,6 +598,8 @@ def _create_model(
         if init_from_bin != "default_none":
             if self.config.model_type == "chatglm" and "chatglm2" in self.config._name_or_path:
                 model_type = "chatglm2"
+            elif self.config.model_type == "chatglm" and "glm-4" in self.config._name_or_path:
+                model_type = "chatglm2"
             else:
                 model_type = self.config.model_type

tests/model-test/cpp_graph_inference.sh

Lines changed: 7 additions & 1 deletion
@@ -170,6 +170,7 @@ model_name_map["baichuan13b-gptq"]="Baichuan2-13B-Chat-GPTQ"
 model_name_map["mistral-gptq"]="TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"
 model_name_map["phi3"]="microsoft/Phi-3-mini-128k-instruct"
 model_name_map["llama3"]="meta-llama/Meta-Llama-3-8B"
+model_name_map["glm4"]="THUDM/glm-4-9b"
 
 
 function main() {
@@ -256,6 +257,11 @@ function main() {
         extension=" --model_name chatglm3 --tokenizer $model_path"
         requirements_file="$working_dir/neural_speed/models/requirements/chatglm-6b.sh"
         input_list=(32 1024)
+    elif [[ "${model}" == "glm4" ]]; then
+        quant_script="./build/bin/quant_chatglm2"
+        convert_script="${convert_script}/convert_chatglm.py"
+        infer_cmd="./build/bin/run_chatglm2"
+        input_list=(32 1024)
     elif [[ "${model}" == "chatglm-6b" ]]; then
         quant_script="./build/bin/quant_chatglm"
         convert_script="${convert_script}/convert_chatglm.py"
@@ -474,7 +480,7 @@ function main() {
             $infer_cmd -f "/tf_dataset2/models/nlp_toolkit/whisper-tiny/jfk.wav" -m ${model}-${precision}.bin
         else
             real_ctx=$ctx # TODO(Zhenzhong): use same ctx for chatglm & baichuan
-            [[ "${model}" == "chatglm2" || "${model}" == "chatglm-6b" ||
+            [[ "${model}" == "chatglm2" || "${model}" == "chatglm-6b" || "${model}" == "glm4" ||
               "${model}" == "baichuan-13b" || "${model}" == "baichuan2-13b" ]] && real_ctx=2048
             if [[ "${model}" == *"gptq" ]]; then
                 NEURAL_SPEED_VERBOSE=1 OMP_NUM_THREADS=$cores_per_instance numactl -m 0 -C 0-$(($cores_per_instance - 1)) $infer_cmd 2>&1 | tee ${WORKSPACE}/${logs_file} || true &
