Skip to content

Commit

Permalink
Merge pull request #6359 from hiyouga/hiyouga/fix_qwen2vl_infer
Browse files Browse the repository at this point in the history
[model] fix qwen2vl infer
  • Loading branch information
hiyouga authored Dec 17, 2024
2 parents e2fbd07 + 142191e commit 81815f0
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 16 deletions.
14 changes: 10 additions & 4 deletions src/llamafactory/data/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,15 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tenso
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
fake_input_ids = self.processor.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
if self.tokenizer.padding_side == "right":
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
else:
features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]

batch_images = fake_images
batch_input_ids[0] = features[0]["input_ids"]

Expand All @@ -123,7 +129,7 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tenso
features: Dict[str, "torch.Tensor"] = super().__call__(features)

if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
features["position_ids"], _ = self.model.get_rope_index(
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
input_ids=features["input_ids"],
image_grid_thw=mm_inputs.get("image_grid_thw", None),
video_grid_thw=mm_inputs.get("video_grid_thw", None),
Expand Down
26 changes: 15 additions & 11 deletions src/llamafactory/train/sft/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

if TYPE_CHECKING:
from torch.utils.data import Dataset
from transformers import ProcessorMixin
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.trainer import PredictionOutput

from ...hparams import FinetuningArguments
Expand All @@ -53,6 +53,8 @@ def __init__(
) -> None:
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
else:
self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")

super().__init__(**kwargs)
self.finetuning_args = finetuning_args
Expand Down Expand Up @@ -113,7 +115,7 @@ def prediction_step(
"""
labels = inputs["labels"] if "labels" in inputs else None
if self.args.predict_with_generate:
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
assert self.processing_class.padding_side == "left", "This method only accepts left-padded tensor."
labels = labels.detach().clone() if labels is not None else None # backup labels
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
if prompt_len > label_len:
Expand All @@ -125,7 +127,7 @@ def prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
if generated_tokens is not None and self.args.predict_with_generate:
generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
generated_tokens[:, :prompt_len] = self.processing_class.pad_token_id
generated_tokens = generated_tokens.contiguous()

return loss, generated_tokens, labels
Expand All @@ -134,8 +136,8 @@ def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "to
r"""
Pads the tensor to the same length as the target tensor.
"""
assert self.tokenizer.pad_token_id is not None, "Pad token is required."
padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
assert self.processing_class.pad_token_id is not None, "Pad token is required."
padded_tensor = self.processing_class.pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding
return padded_tensor.contiguous() # in contiguous memory

Expand All @@ -152,20 +154,22 @@ def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutpu
logger.info_rank0(f"Saving prediction results to {output_prediction_file}")

labels = np.where(
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id
)
preds = np.where(
predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
predict_results.predictions != IGNORE_INDEX,
predict_results.predictions,
self.processing_class.pad_token_id,
)

for i in range(len(preds)):
pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
if len(pad_len): # move pad token to last
preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=True)
decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=True)

with open(output_prediction_file, "w", encoding="utf-8") as f:
for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
Expand Down
2 changes: 1 addition & 1 deletion src/llamafactory/train/sft/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def run_sft(

data_collator = SFTDataCollatorWith4DAttentionMask(
template=template,
model=model,
model=model if not training_args.predict_with_generate else None,
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,
Expand Down

0 comments on commit 81815f0

Please sign in to comment.