@@ -524,6 +524,148 @@ def write_vocab_gguf(dir_model):
     print("Done. Output file: " + fname_out)
     print("")

+def chatglm4_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
+    print("GLM-4 converting: ")
+    list_vars = model.state_dict()
+    for name in list_vars.keys():
+        print(name, list_vars[name].shape, list_vars[name].dtype)
+
+    fout = open(fname_out, "wb")
+
+    print(hparams)
+
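+    # file header: magic 0x67676d66 ("ggmf" in ASCII) and a format version,
+    # followed by the model hyperparameters in a fixed order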
+    fout.write(struct.pack("i", 0x67676d66))
+    fout.write(struct.pack("i", 1))
+
+    fout.write(struct.pack("i", hparams["padded_vocab_size"]))
+    fout.write(struct.pack("i", hparams["hidden_size"]))
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", hparams["num_attention_heads"]))
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", hparams["num_layers"]))
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", ftype))
+    fout.write(struct.pack("i", hparams["seq_length"]))
+    fout.write(struct.pack("f", 0))
+    fout.write(struct.pack("f", 0))
+    fout.write(struct.pack("i", 0))
+
+    fout.write(struct.pack("i", 0))  # word_embed_proj_dim (for opt)
+    fout.write(struct.pack("i", 0))  # do_layer_norm_before (for opt)
+
+    fout.write(struct.pack("i", hparams["multi_query_group_num"]))
+    fout.write(struct.pack("i", hparams["ffn_hidden_size"]))
+    fout.write(struct.pack("i", 0))
+    fout.write(struct.pack("i", 0))  # n_experts
+    fout.write(struct.pack("i", 0))  # n_expert_used
+    fout.write(struct.pack("i", 0))  # n_embd_head_k for gemma
+    fout.write(struct.pack("f", hparams.get("layernorm_epsilon", 1e-5)))  # rms_norm_eps or layer_norm_eps
+    fout.write(struct.pack("f", 10000.0))  # freq_base
+    fout.write(struct.pack("f", hparams.get("rope_ratio", 1)))  # rope_factor
+
+    fout.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
+    fout.write(struct.pack("i", 0))  # rope_scaling.original_max_position_embeddings
+    fout.write(struct.pack("i", 0))  # 1 if params["rope_scaling"]["type"] == "yarn" else 0
+
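+    # special token ids, falling back to defaults when the tokenizer does not define them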
+    fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
+    fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
+    fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
+    fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
+
+
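+    # vocabulary: each entry is a length-prefixed UTF-8 string plus a float score;
+    # ids beyond the tokenizer's real vocab are padded with the last token and a very low score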
+    for i in range(hparams["vocab_size"]):
+        if i < tokenizer.vocab_size:
+            text = tokenizer.decode([i]).encode('utf-8')
+            fout.write(struct.pack("i", len(text)))
+            fout.write(text)
+            fout.write(struct.pack("f", 0.0 - i))
+        else:
+            text = tokenizer.decode([tokenizer.vocab_size - 1]).encode('utf-8')
+            fout.write(struct.pack("i", len(text)))
+            fout.write(text)
+            fout.write(struct.pack("f", -10000))
+
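+    # tensors: for each weight, write n_dims, name length and dtype flag,
+    # then the shape (innermost dim first), the name bytes, and the raw data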
+    for name in list_vars.keys():
+        data = list_vars[name].float().squeeze().numpy()
+        data = data.astype(np.float32)
+        if name == "transformer.rotary_pos_emb.inv_freq":
+            continue
+        # No gradients for these
+
+        n_dims = len(data.shape)
+        print(name, n_dims, data.shape)
+
+        # default type is fp32
+        ftype_cur = 0
+        if ftype == 1 and n_dims > 1:
+            print(" Converting to float16", data.shape, data[:3, :3].tolist())
+            data = data.astype(np.float16)
+            ftype_cur = 1
+        else:
+            print(" Converting to float32", data.shape, data[:3, :3].tolist() if n_dims > 1 else data[:3].tolist())
+            data = data.astype(np.float32)
+
+        # header
+        str = name.encode('utf-8')
+        fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+        for i in range(n_dims):
+            fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+        print(str)
+        fout.write(str)
+
+        # data
+        data.tofile(fout)
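+        # the fused dense_h_to_4h projection is additionally split in half along dim 0
+        # and written as two extra tensors (dense_h_to_4h_0 / dense_h_to_4h_1)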
618+ if "mlp.dense_h_to_4h" in name :
619+ name_0 = name .replace ("dense_h_to_4h" , "dense_h_to_4h_0" )
620+ name_1 = name .replace ("dense_h_to_4h" , "dense_h_to_4h_1" )
621+ shape_0 = data .shape [0 ]
622+ half_shape_0 = int (shape_0 / 2 )
623+ data_0 = data [0 :half_shape_0 , :]
624+ data_1 = data [half_shape_0 :shape_0 , :]
625+
626+ print ("Converting: %-75s" % name_0 , " shape: " , data_0 .shape )
627+ print ("Converting: %-75s" % name_1 , " shape: " , data_1 .shape )
628+
629+ n_dims = len (data_0 .shape )
630+ assert (len (data_0 .shape ) == len (data_1 .shape ))
631+ # ftype == 0 -> float32, ftype == 1 -> float16
632+ ftype_cur = 0
633+ if ftype != 0 :
634+ if name_0 [- 7 :] == ".weight" and n_dims == 2 :
635+ print (" to float16" .rjust (15 ))
636+ data_0 = data_0 .astype (np .float16 )
637+ data_1 = data_1 .astype (np .float32 )
638+ ftype_cur = 1
639+ else :
640+ print (" to float32" .rjust (15 ))
641+ data_0 = data_0 .astype (np .float32 )
642+ data_1 = data_1 .astype (np .float32 )
643+ ftype_cur = 0
644+ else :
645+ if data_0 .dtype != np .float32 :
646+ print (" to float32" .rjust (15 ))
647+ data_0 = data_0 .astype (np .float32 )
648+ data_1 = data_1 .astype (np .float32 )
649+ ftype_cur = 0
650+
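+            # each half is written with the same per-tensor header layout used above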
+            str_0 = name_0.encode("utf-8")
+            fout.write(struct.pack("iii", n_dims, len(str_0), ftype_cur))
+            for i in range(n_dims):
+                fout.write(struct.pack("i", data_0.shape[n_dims - 1 - i]))
+            fout.write(str_0)
+            data_0.tofile(fout)
+
+            str_1 = name_1.encode("utf-8")
+            fout.write(struct.pack("iii", n_dims, len(str_1), ftype_cur))
+            for i in range(n_dims):
+                fout.write(struct.pack("i", data_1.shape[n_dims - 1 - i]))
+            fout.write(str_1)
+            data_1.tofile(fout)
+
+    fout.close()
+
+    print("Done. Output file: " + fname_out)
+    print("")

 def chatglm3_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
     print("ChatGLM-3 converting: ")
@@ -973,7 +1115,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
     # ChatGLM3 shares the same architecture and model config with ChatGLM2,
     # but its tokenizer further supports system prompts,
     # so we can check the system token to discriminate ChatGLM3 from ChatGLM2.
-    if hasattr(tokenizer, "tokenizer") and "<|system|>" in tokenizer.tokenizer.special_tokens:
+    # GLM4-9B has 40 transformer layers, so num_layers distinguishes it from ChatGLM2/3.
+    if model.config.num_layers == 40:
+        chatglm4_convert(model, tokenizer, dir_model, fname_out, ftype, hparams)
+    elif hasattr(tokenizer, "tokenizer") and "<|system|>" in tokenizer.tokenizer.special_tokens:
         if args.format == "GGUF":
             chatglm3_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams)
         else: