Skip to content

Commit 1800f8c — "fix #6499"
Parent: f8e80d5

File tree

5 files changed: 18 additions (+), 28 deletions (-)

src/llamafactory/data/preprocess.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from .processors.feedback import preprocess_feedback_dataset
1919
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
20-
from .processors.pretrain import preprocess_pretrain_dataset
20+
from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example
2121
from .processors.supervised import (
2222
preprocess_packed_supervised_dataset,
2323
preprocess_supervised_dataset,
@@ -47,7 +47,7 @@ def get_preprocess_and_print_func(
4747
tokenizer=tokenizer,
4848
data_args=data_args,
4949
)
50-
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
50+
print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer)
5151
elif stage == "sft" and not do_generate:
5252
if data_args.packing:
5353
if data_args.neat_packing: # hack datasets to have int32 attention mask

src/llamafactory/data/processors/pretrain.py

+5
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,8 @@ def preprocess_pretrain_dataset(
5252
result["input_ids"][i][0] = tokenizer.bos_token_id
5353

5454
return result
55+
56+
57+
def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
58+
print("input_ids:\n{}".format(example["input_ids"]))
59+
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

src/llamafactory/data/processors/unsupervised.py

+2
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,5 @@ def preprocess_unsupervised_dataset(
100100
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
101101
print("input_ids:\n{}".format(example["input_ids"]))
102102
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
103+
print("label_ids:\n{}".format(example["labels"]))
104+
print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False)))

src/llamafactory/train/sft/trainer.py

+9-22
Original file line numberDiff line numberDiff line change
@@ -111,40 +111,27 @@ def prediction_step(
111111
inputs: Dict[str, Union["torch.Tensor", Any]],
112112
prediction_loss_only: bool,
113113
ignore_keys: Optional[List[str]] = None,
114+
**gen_kwargs,
114115
) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
115116
r"""
116117
Removes the prompt part in the generated tokens.
117118
118119
Subclass and override to inject custom behavior.
119120
"""
120-
labels = inputs["labels"] if "labels" in inputs else None
121-
if self.args.predict_with_generate:
122-
assert self.processing_class.padding_side == "left", "This method only accepts left-padded tensor."
123-
labels = labels.detach().clone() if labels is not None else None # backup labels
124-
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
125-
if prompt_len > label_len:
126-
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
127-
if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
128-
inputs["labels"] = inputs["labels"][:, :prompt_len]
129-
130-
loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated)
131-
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
121+
if self.args.predict_with_generate: # do not pass labels to model when generate
122+
labels = inputs.pop("labels", None)
123+
else:
124+
labels = inputs.get("labels")
125+
126+
loss, generated_tokens, _ = super().prediction_step(
127+
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
132128
)
133129
if generated_tokens is not None and self.args.predict_with_generate:
134-
generated_tokens[:, :prompt_len] = self.processing_class.pad_token_id
130+
generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
135131
generated_tokens = generated_tokens.contiguous()
136132

137133
return loss, generated_tokens, labels
138134

139-
def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
140-
r"""
141-
Pads the tensor to the same length as the target tensor.
142-
"""
143-
assert self.processing_class.pad_token_id is not None, "Pad token is required."
144-
padded_tensor = self.processing_class.pad_token_id * torch.ones_like(tgt_tensor)
145-
padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding
146-
return padded_tensor.contiguous() # in contiguous memory
147-
148135
def save_predictions(
149136
self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
150137
) -> None:

src/llamafactory/train/sft/workflow.py

-4
Original file line numberDiff line numberDiff line change
@@ -117,17 +117,13 @@ def run_sft(
117117
# Evaluation
118118
if training_args.do_eval:
119119
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
120-
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
121-
metrics.pop("eval_loss", None)
122120
trainer.log_metrics("eval", metrics)
123121
trainer.save_metrics("eval", metrics)
124122

125123
# Predict
126124
if training_args.do_predict:
127125
logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
128126
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
129-
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
130-
predict_results.metrics.pop("predict_loss", None)
131127
trainer.log_metrics("predict", predict_results.metrics)
132128
trainer.save_metrics("predict", predict_results.metrics)
133129
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)

0 commit comments