Deprecation: Transformers will no longer support past_key_values to be tuples #1962
Comments
@BenjaminBossan @ArthurZucker Can you provide more information regarding the deprecation of get_seq_length()? I have some idea about prompt learning methods and the transformers KV cache. Current implementation of cache utilities: Can you point me to the correct branch?
Hi @piyush-llm 👋 Alongside a Cache, the …
Thanks for the details. Also check out #869 for discussion and code snippets.
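A minimal sketch of the Cache API under discussion (tensor shapes are arbitrary, chosen only for illustration): a DynamicCache is filled per layer via update(), and get_seq_length() replaces reading the sequence length out of the old tuple format.

    import torch
    from transformers import DynamicCache

    cache = DynamicCache()
    # One layer's key/value states: (batch=1, heads=2, seq_len=4, head_dim=8).
    key = torch.randn(1, 2, 4, 8)
    value = torch.randn(1, 2, 4, 8)
    cache.update(key, value, layer_idx=0)

    # With the legacy tuple format this used to be past_key_values[0][0].shape[-2].
    print(cache.get_seq_length())  # 4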
Is it OK to just convert past_key_values from a tuple to a list?
Unfortunately not, it would be expected to be a Cache instance.
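To illustrate the point above with a small sketch (random tensors, arbitrary shapes): turning the tuple into a list keeps the legacy format, whereas wrapping the same tuples in a Cache subclass is what the newer transformers code path expects.

    import torch
    from transformers import DynamicCache

    # Legacy format: one (key, value) pair per layer.
    legacy = tuple(
        (torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8)) for _ in range(2)
    )

    # list(legacy) would still be the legacy format; converting to a Cache works:
    cache = DynamicCache.from_legacy_cache(legacy)
    print(type(cache).__name__, cache.get_seq_length())  # DynamicCache 4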
An update from the internal discussion: I had hoped that this could be avoided, but it smells like this whole approach needs a rewrite. One idea would be to calculate the virtual embeddings as we do right now (…). For generating, we'll probably still need a separate code path, since we do want to use caching for generation.
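As a rough sketch of the generation code path mentioned above (the function name and the per-layer prefix key/value tensors are hypothetical, not PEFT's actual implementation), the prefix produced by prompt learning could be seeded into a cache before generation:

    import torch
    from transformers import DynamicCache

    def seed_cache_with_prefix(prefix_keys, prefix_values):
        # prefix_keys / prefix_values: hypothetical lists with one tensor per layer,
        # each of shape (batch, num_heads, num_virtual_tokens, head_dim).
        cache = DynamicCache()
        for layer_idx, (k, v) in enumerate(zip(prefix_keys, prefix_values)):
            cache.update(k, v, layer_idx=layer_idx)
        return cache

    # The resulting cache could then be passed as past_key_values to model.generate(),
    # so the virtual tokens behave like an already-cached prefix.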
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
not stale
I have encountered the same issue in prefix tuning. If completely addressing this issue takes time, is there an alternative way to work around it?
@JerryLife Could you please try if this suggestion fixes it for you: … Please let me know if it works or not; the more data I have, the better I can plan the fix.
Thank you for your prompt help! It works on my side. Here are more details. My workaround:
from typing import Optional

import torch
from peft import PeftType
from peft.utils import TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING
from transformers import DynamicCache


def custom_get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Returns the virtual prompts to use for Peft. Only applicable when using a prompt learning method.
    """
    peft_config = self.active_peft_config
    prompt_encoder = self.prompt_encoder[self.active_adapter]
    prompt_tokens = (
        self.prompt_tokens[self.active_adapter]
        .unsqueeze(0)
        .expand(batch_size, -1)
        .to(prompt_encoder.embedding.weight.device)
    )
    if peft_config.peft_type == PeftType.PREFIX_TUNING:
        prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens]
        if peft_config.inference_mode:
            past_key_values = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
        else:
            past_key_values = prompt_encoder(prompt_tokens)
        if self.base_model_torch_dtype is not None:
            past_key_values = past_key_values.to(self.base_model_torch_dtype)
        past_key_values = past_key_values.view(
            batch_size,
            peft_config.num_virtual_tokens,
            peft_config.num_layers * 2,
            peft_config.num_attention_heads,
            peft_config.token_dim // peft_config.num_attention_heads,
        )
        if peft_config.num_transformer_submodules == 2:
            past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
            peft_config.num_transformer_submodules * 2
        )
        if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
            post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
            past_key_values = post_process_fn(past_key_values)
        ############################## Workaround for PEFT 0.13 ##############################
        # Wrap the legacy tuple format in a transformers Cache instance.
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        ######################################################################################
        return past_key_values
    else:
        if peft_config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
            prompts = prompt_encoder(prompt_tokens, task_ids)
        else:
            if peft_config.inference_mode:
                prompts = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
            else:
                prompts = prompt_encoder(prompt_tokens)
        return prompts


import gorilla
from peft import PeftModel

# Monkey-patch PeftModel.get_prompt with the workaround above.
patch = gorilla.Patch(PeftModel, "get_prompt", custom_get_prompt, settings=gorilla.Settings(allow_hit=True))
gorilla.apply(patch)

Then, the training works on my side. Thank you again for your assistance!
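For reference, a usage sketch of the patched model (the checkpoint name and hyperparameters below are placeholders, not taken from the report above): after gorilla.apply(patch), a prefix-tuned model can be run through a normal forward/backward step.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PrefixTuningConfig, TaskType, get_peft_model

    base = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=20)
    model = get_peft_model(base, peft_config)

    inputs = tokenizer("Hello, my dog is", return_tensors="pt")
    # The forward pass goes through the patched get_prompt, which now returns a DynamicCache.
    outputs = model(**inputs, labels=inputs["input_ids"])
    outputs.loss.backward()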
Thanks a lot @JerryLife, this really helps us move forward.
If you could give me access to PEFT version 0.13, that would make my life a lot easier ;-)
@BenjaminBossan I am sorry for the typo. I have corrected it as …
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
not stale |
See #869, #1962. Fix several issues caused by changes to the cache in transformers. In particular, past_key_values for prefix tuning is now converted to a transformers Cache instance. Co-authored-by: Raushan Turganbay <[email protected]>
As reported by @ArthurZucker:
This will presumably affect all prompt learning methods and thus needs to be fixed soon.
For Llama, I identified the following tests, which would result in past_key_values being tuples and can serve as a starting point to work on this issue: … (note that there are more tests that would fail, this is just a selection)
I haven't really worked on the prompt learning methods in PEFT and know little about the inner workings of the transformers cache, so any support would be welcome.