
Merge pull request #6363 from hiyouga/hiyouga/control_skip_eos
[infer] support control eos
hiyouga authored Dec 17, 2024
2 parents 9708a39 + eda76de commit 6973828
Showing 5 changed files with 21 additions and 7 deletions.
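In short, the commit threads a new skip_special_tokens generating argument (default True, matching the previously hard-coded behavior) through the HF engine, the vLLM engine, and the SFT prediction path, so callers can choose to keep special tokens such as EOS in decoded output. A minimal usage sketch, assuming LLaMA-Factory's ChatModel API; the model path and template below are placeholders, not part of this diff:

# Hedged sketch: pass the new flag alongside the usual engine arguments.
from llamafactory.chat import ChatModel

chat_model = ChatModel({
    "model_name_or_path": "path/to/model",  # placeholder
    "template": "default",                  # placeholder template name
    "skip_special_tokens": False,           # new flag: keep EOS and other special tokens
})

for response in chat_model.chat([{"role": "user", "content": "Hello"}]):
    print(response.response_text)  # may now end with the literal EOS token text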
8 changes: 6 additions & 2 deletions src/llamafactory/chat/hf_engine.py
@@ -205,7 +205,9 @@ def _chat(
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
-        response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        response = tokenizer.batch_decode(
+            response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
+        )
         results = []
         for i in range(len(response)):
             eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
@@ -249,7 +251,9 @@ def _stream_chat(
             videos,
             input_kwargs,
         )
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+        streamer = TextIteratorStreamer(
+            tokenizer, skip_prompt=True, skip_special_tokens=generating_args["skip_special_tokens"]
+        )
         gen_kwargs["streamer"] = streamer
         thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
         thread.start()
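What the flag toggles at the Transformers level, as a self-contained sketch (the gpt2 tokenizer is only a stand-in for illustration, not used by the PR):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer
ids = tokenizer("hello world").input_ids + [tokenizer.eos_token_id]

# The old code hard-coded skip_special_tokens=True; it is now configurable.
print(tokenizer.decode(ids, skip_special_tokens=True))   # hello world
print(tokenizer.decode(ids, skip_special_tokens=False))  # hello world<|endoftext|>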
2 changes: 1 addition & 1 deletion src/llamafactory/chat/vllm_engine.py
@@ -170,7 +170,7 @@ async def _generate(
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=max_tokens,
-            skip_special_tokens=True,
+            skip_special_tokens=self.generating_args["skip_special_tokens"],
         )

         if images is not None:  # add image features
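On the vLLM side, skip_special_tokens is an existing SamplingParams field (default True); the hunk forwards the user's choice instead of pinning it. A minimal sketch, assuming vllm is installed; the token budget is a placeholder:

from vllm import SamplingParams

# Mirrors the construction in _generate, trimmed to the relevant fields.
sampling_params = SamplingParams(
    max_tokens=128,             # placeholder budget
    skip_special_tokens=False,  # the value generating_args would carry
)
print(sampling_params.skip_special_tokens)  # False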
4 changes: 4 additions & 0 deletions src/llamafactory/hparams/generating_args.py
@@ -64,6 +64,10 @@ class GeneratingArguments:
         default=None,
         metadata={"help": "Default system message to use in chat completion."},
     )
+    skip_special_tokens: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to remove special tokens in the decoding."},
+    )

     def to_dict(self) -> Dict[str, Any]:
         args = asdict(self)
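The new field follows the same dataclasses.field pattern as the rest of GeneratingArguments; a self-contained sketch trimmed to just this field:

from dataclasses import asdict, dataclass, field
from typing import Any, Dict

@dataclass
class GeneratingArguments:  # trimmed to the new field for illustration
    skip_special_tokens: bool = field(
        default=True,
        metadata={"help": "Whether or not to remove special tokens in the decoding."},
    )

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

print(GeneratingArguments(skip_special_tokens=False).to_dict())
# {'skip_special_tokens': False}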
12 changes: 9 additions & 3 deletions src/llamafactory/train/sft/trainer.py
@@ -141,7 +141,9 @@ def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor")
         padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
         return padded_tensor.contiguous()  # in contiguous memory

-    def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None:
+    def save_predictions(
+        self, dataset: "Dataset", predict_results: "PredictionOutput", gen_kwargs: Dict[str, Any]
+    ) -> None:
         r"""
         Saves model predictions to `output_dir`.
@@ -168,8 +170,12 @@ def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput"
             preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

         decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
-        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=True)
-        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=True)
+        decoded_preds = self.processing_class.batch_decode(
+            preds, skip_special_tokens=gen_kwargs["skip_special_tokens"]
+        )
+        decoded_labels = self.processing_class.batch_decode(
+            labels, skip_special_tokens=gen_kwargs["skip_special_tokens"]
+        )

         with open(output_prediction_file, "w", encoding="utf-8") as f:
             for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
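With the new signature, the decode calls in save_predictions read the flag from the gen_kwargs dict rather than a literal. A small stand-alone illustration (the gpt2 tokenizer stands in for processing_class; the ids are fabricated for the example):

import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for processing_class
gen_kwargs = {"skip_special_tokens": True}         # assembled from GeneratingArguments

preds = np.array([tokenizer("a prediction").input_ids + [tokenizer.eos_token_id]])
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=gen_kwargs["skip_special_tokens"])
print(decoded_preds)  # ['a prediction']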
2 changes: 1 addition & 1 deletion src/llamafactory/train/sft/workflow.py
@@ -130,7 +130,7 @@ def run_sft(
         predict_results.metrics.pop("predict_loss", None)
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
-        trainer.save_predictions(dataset_module["eval_dataset"], predict_results)
+        trainer.save_predictions(dataset_module["eval_dataset"], predict_results, gen_kwargs)

     # Create model card
     create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
