
Commit ef42ce1

fix baichuan, chatglm1&2&3 acc issue (#285)
Signed-off-by: Yu Zhentao <[email protected]>
1 parent 04a8029 commit ef42ce1

3 files changed, +19 -2 lines changed

neural_speed/models/baichuan/baichuan.cpp

Lines changed: 1 addition & 1 deletion
@@ -317,7 +317,7 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
   }

   lctx.use_buf(ctx0, -1);
-  if (embd->ne[0] > 1) {
+  if (!lctx.logits_all && embd->ne[0] > 1) {
     inpL = ne_view_1d(ctx0, inpL, n_embd, (embd->ne[0] - 1) * n_embd * ne_element_size(inpL));
   }

neural_speed/models/chatglm/chatglm2.cpp

Lines changed: 1 addition & 1 deletion
@@ -337,7 +337,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
   }

   lctx.use_buf(ctx0, -1);
-  if (embd->ne[0] > 1) {
+  if (!lctx.logits_all && embd->ne[0] > 1) {
     inpL = ne_view_1d(ctx0, inpL, n_embd, (embd->ne[0] - 1) * n_embd * ne_element_size(inpL));
   }
   // lm_head
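Both C++ changes are the same one-line guard: when the caller sets logits_all (as an evaluation harness does for accuracy and perplexity runs), the graph must keep one hidden-state row per input token instead of narrowing inpL to the last token's row. A minimal sketch of the consumer-side contract, with illustrative names rather than neural_speed's actual API:

import torch

def sequence_logprob(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # logits: (seq_len, vocab), one row per input token; tokens: (seq_len,)
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    # position i predicts token i + 1, so every row except the last is scored;
    # with the old last-token-only view this computation has nothing to read
    return logprobs[:-1].gather(-1, tokens[1:].unsqueeze(-1)).squeeze(-1).sum()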

scripts/huggingface.py

Lines changed: 17 additions & 0 deletions
@@ -395,6 +395,13 @@ def __init__(
         self._rank = 0
         self._world_size = 1

+        self.model_type = self._config.model_type
+        if self.model_type == "chatglm" and "chatglm2" in self._config._name_or_path:
+            self.model_type = "chatglm2"
+        if self.model_type == "chatglm" and "chatglm3" in self._config._name_or_path:
+            # chatglm3 shares the chatglm2 model architecture.
+            self.model_type = "chatglm2"
+
     @property
     def config(self):
         # return the associated transformers.AutoConfig for the given pretrained model.
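The name sniffing is needed because, as the checks above imply, the HF config reports model_type == "chatglm" for all three ChatGLM generations; only the checkpoint name tells them apart, and chatglm3 is mapped to chatglm2 since it reuses that architecture. A standalone sketch, assuming the public THUDM checkpoints:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
model_type = config.model_type  # "chatglm" for all three generations
if model_type == "chatglm" and "chatglm3" in config._name_or_path:
    model_type = "chatglm2"  # chatglm3 reuses the chatglm2 architecture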
@@ -949,6 +956,11 @@ def tok_encode(
         if left_truncate_len:
             encoding = encoding[-left_truncate_len:]  # pylint: disable=E1130

+        if self.model_type == "chatglm":
+            # hacky code for chatGLM: its tokenizer appends gmask_token_id and
+            # bos_token_id, which must be removed before slicing input and label
+            encoding = encoding[:-2]
+
         return encoding

     def tok_batch_encode(
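The two trailing ids stripped here are the ones named in the next hunk's comment: ChatGLM1's tokenizer appends gmask_token_id and bos_token_id ([130001, 130004]) after the text tokens. An illustrative check, assuming the THUDM/chatglm-6b tokenizer loaded with trust_remote_code:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
ids = tok.encode("hello")
assert ids[-2:] == [tok.gmask_token_id, tok.bos_token_id]  # [130001, 130004]
encoding = ids[:-2]  # what tok_encode now returns for input/label slicing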
@@ -1296,6 +1308,11 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
                     dtype=torch.long,
                     device=self.device,
                 )
+                # hacky code for chatGLM1 inputs: the tokenizer appends
+                # [130001, 130004] (gmask_token_id, bos_token_id) at the end
+                if self.model_type == "chatglm":
+                    bos = torch.tensor([self.tokenizer.gmask_token_id, self.tokenizer.bos_token_id])
+                    inp = torch.cat((inp, bos), -1)
                 (inplen,) = inp.shape
             elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                 inp = torch.tensor(
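This hunk is the counterpart of the tok_encode change: the two special ids are trimmed there so context/continuation slicing lines up, then re-appended here just before the tensor reaches the model. A round-trip sketch with illustrative token ids:

import torch

encoding = [5, 6, 7, 130001, 130004]  # tokenizer output ends with [gmask, bos]
trimmed = encoding[:-2]               # tok_encode drops them for slicing
inp = torch.tensor(trimmed, dtype=torch.long)
bos = torch.tensor([130001, 130004])  # re-appended before the model call
inp = torch.cat((inp, bos), -1)
assert inp.tolist() == encoding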
