
Commit a562531

fix #6499

Parent: f8e80d5

3 files changed, +13 -26 lines changed

src/llamafactory/data/processors/unsupervised.py (+4)

@@ -16,6 +16,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras import logging
+from ...extras.constants import IGNORE_INDEX
 from ..data_utils import Role
 from .processor_utils import infer_seqlen

@@ -98,5 +99,8 @@ def preprocess_unsupervised_dataset(
 
 
 def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+    valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
     print("input_ids:\n{}".format(example["input_ids"]))
     print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+    print("label_ids:\n{}".format(example["labels"]))
+    print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")

src/llamafactory/train/sft/trainer.py (+9 -22)

@@ -111,40 +111,27 @@ def prediction_step(
         inputs: Dict[str, Union["torch.Tensor", Any]],
         prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,
+        **gen_kwargs,
     ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
         r"""
         Removes the prompt part in the generated tokens.
 
         Subclass and override to inject custom behavior.
         """
-        labels = inputs["labels"] if "labels" in inputs else None
-        if self.args.predict_with_generate:
-            assert self.processing_class.padding_side == "left", "This method only accepts left-padded tensor."
-            labels = labels.detach().clone() if labels is not None else None  # backup labels
-            prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
-            if prompt_len > label_len:
-                inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
-            if label_len > prompt_len:  # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
-                inputs["labels"] = inputs["labels"][:, :prompt_len]
-
-        loss, generated_tokens, _ = super().prediction_step(  # ignore the returned labels (may be truncated)
-            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+        if self.args.predict_with_generate:  # do not pass labels to model when generate
+            labels = inputs.pop("labels", None)
+        else:
+            labels = inputs.get("labels")
+
+        loss, generated_tokens, _ = super().prediction_step(
+            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
         )
         if generated_tokens is not None and self.args.predict_with_generate:
-            generated_tokens[:, :prompt_len] = self.processing_class.pad_token_id
+            generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
             generated_tokens = generated_tokens.contiguous()
 
         return loss, generated_tokens, labels
 
-    def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
-        r"""
-        Pads the tensor to the same length as the target tensor.
-        """
-        assert self.processing_class.pad_token_id is not None, "Pad token is required."
-        padded_tensor = self.processing_class.pad_token_id * torch.ones_like(tgt_tensor)
-        padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
-        return padded_tensor.contiguous()  # in contiguous memory
-
     def save_predictions(
         self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
     ) -> None:
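The rewrite drops the label pad/truncate bookkeeping entirely: under predict_with_generate, labels are popped so the backbone never computes loss on them, gen_kwargs are forwarded to the parent prediction_step, and the echoed prompt prefix is blanked out of the generated ids. A small sketch of that last masking step, with made-up tensors and left-padded inputs (the layout the trainer assumes):

    # Sketch with made-up tensors (not repo code): mask the echoed prompt.
    import torch

    pad_token_id = 0
    input_ids = torch.tensor([[0, 0, 5, 6]])        # left-padded prompt, length 4
    generated = torch.tensor([[0, 0, 5, 6, 7, 8]])  # prompt echo + 2 new tokens

    generated[:, : input_ids.size(-1)] = pad_token_id
    print(generated)  # tensor([[0, 0, 0, 0, 7, 8]]) -- only the completion survives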

src/llamafactory/train/sft/workflow.py (-4)

@@ -117,17 +117,13 @@ def run_sft(
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
-        if training_args.predict_with_generate:  # eval_loss will be wrong if predict_with_generate is enabled
-            metrics.pop("eval_loss", None)
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
 
     # Predict
    if training_args.do_predict:
         logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
         predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
-        if training_args.predict_with_generate:  # predict_loss will be wrong if predict_with_generate is enabled
-            predict_results.metrics.pop("predict_loss", None)
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
         trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
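Because prediction_step now pops the labels before generating, the misleading eval_loss/predict_loss is never produced in the first place, so these metric-popping workarounds can go. For context, a hypothetical evaluation call; the gen_kwargs keys below are standard transformers generation parameters, not values taken from this repo:

    # Hypothetical usage; keys are standard transformers generation parameters.
    gen_kwargs = {"max_new_tokens": 128, "do_sample": False}
    metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
    trainer.log_metrics("eval", metrics)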
