
Prompt Tuning Crash with Llama-3.2 in torch.embedding #2161

Open
2 of 4 tasks
hrsmanian opened this issue Oct 18, 2024 · 4 comments
Comments


hrsmanian commented Oct 18, 2024

System Info

peft==0.13.2
accelerate==1.0.1
torch==2.4.0

PEFT config

peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=len(tokenizer(prompt_tuning_init_text)["input_ids"]),
    prompt_tuning_init_text=prompt_tuning_init_text,
    tokenizer_name_or_path=base_model_id,
)

Stack trace

File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 361, in train
    output = super().train(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/transformers/trainer.py", line 2052, in train
    return inner_training_loop(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/transformers/trainer.py", line 2388, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/transformers/trainer.py", line 3485, in training_step
    loss = self.compute_loss(model, inputs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/transformers/trainer.py", line 3532, in compute_loss
    outputs = model(**inputs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/accelerate/utils/operations.py", line 820, in forward
    return model_forward(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/accelerate/utils/operations.py", line 808, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/peft/peft_model.py", line 1688, in forward
    prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/peft/peft_model.py", line 733, in get_prompt
    prompts = prompt_encoder(prompt_tokens)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/peft/tuners/prompt_tuning/model.py", line 90, in forward
    prompt_embeddings = self.embedding(indices)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 164, in forward
    return F.embedding(
  File "/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/functional.py", line 2267, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: 'weight' must be 2-D
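For context on the error itself: `torch.nn.functional.embedding` requires its `weight` argument to be a 2-D `(num_embeddings, embedding_dim)` tensor, so this RuntimeError fires whenever the prompt encoder's embedding weight reaches the lookup with any other rank (what flattens it in this particular setup is exactly what a reproducer would need to pin down). A minimal sketch, outside of PEFT, that triggers the same message:

```python
import torch
import torch.nn.functional as F

indices = torch.arange(4)

# Normal case: weight is 2-D (num_embeddings, embedding_dim) -- lookup works.
weight_2d = torch.randn(10, 8)
out = F.embedding(indices, weight_2d)
print(out.shape)  # torch.Size([4, 8])

# If the weight arrives with any other rank (here flattened to 1-D),
# torch raises exactly the error from the trace above.
weight_1d = weight_2d.flatten()
try:
    F.embedding(indices, weight_1d)
except RuntimeError as e:
    print(e)  # 'weight' must be 2-D
```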

Who can help?

@saya

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Any kind of causal LM task tuning shows this issue

Expected behavior

Expected training to run without error.

@BenjaminBossan
Member

Thanks for reporting this. With the information you've given, I could not reproduce the error. This is what I tried:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PromptTuningConfig, get_peft_model, TaskType, PromptTuningInit

model_id = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompt_tuning_init_text = "Think carefully"
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=len(tokenizer(prompt_tuning_init_text)["input_ids"]),
    prompt_tuning_init_text=prompt_tuning_init_text,
    tokenizer_name_or_path=model_id,
)
model = get_peft_model(model, peft_config)

sentence = "The quick brown fox jumps over the lazy dog."
sample_input = tokenizer(sentence, return_tensors="pt")
output = model(**sample_input)

Could you please provide a minimal reproducer for the error?

@saya

saya commented Oct 18, 2024

Can someone please edit the issue template so that I don't get pinged anymore? I am not affiliated with this project.

Thanks

@BenjaminBossan
Member

Saya, I'm very sorry about that. Your name is not on the issue template:

Library: @benjaminbossan @sayakpaul

We have "sayakpaul" but I honestly don't know why people are constantly pinging you. I think it must have something to do with how GitHub autocompletes: people type "@saya", "@sayakpaul" is suggested, they hit enter, but for some reason only "@saya" is entered. I'm not sure, as I can't reproduce it, but that's the only explanation I have. LMK if you have any idea how this can be remedied without outright removing "sayakpaul".

@saya

saya commented Oct 18, 2024

@BenjaminBossan I see. Thanks for the fast reply and sorry for not researching thoroughly. I will just ignore this repo as well in my notification settings, and I should be fine.
