Deprecation: Transformers will no longer support past_key_values to be tuples #1962

Open
BenjaminBossan opened this issue Jul 26, 2024 · 15 comments

@BenjaminBossan
Member

As reported by @ArthurZucker:

Quick question, I am seeing this in peft:

if uses_cache and (model_kwargs["past_key_values"] is not None):
where the code relies on get_seq_length(), which we are deprecating; in addition, we will stop automatically converting the cache to a tuple object in two releases.

This will presumably affect all prompt learning methods and thus needs to be fixed soon.

For Llama, I identified the following tests, which would result in past_key_values being tuples and can serve as a starting point to work on this issue:

tests/test_decoder_models.py::PeftDecoderModelTester::test_inference_safetensors_70_test_trl_internal_testing_tiny_random_LlamaForCausalLM_prefix_tuning
tests/test_decoder_models.py::PeftDecoderModelTester::test_passing_input_embeds_works_70_test_trl_internal_testing_tiny_random_LlamaForCausalLM_prefix_tuning
tests/test_decoder_models.py::PeftDecoderModelTester::test_training_prompt_learning_tasks_72_test_trl_internal_testing_tiny_random_LlamaForCausalLM_prefix_tuning

(Note that more tests would fail; this is just a selection.)

I haven't really worked on the prompt learning methods in PEFT and know little about the inner workings of transformers cache, so any support would be welcome.

@piyush-llm

piyush-llm commented Jul 30, 2024

@BenjaminBossan @ArthurZucker Can you provide more information regarding the deprecation of get_seq_length()? I have some familiarity with prompt learning methods and the transformers KV cache.

Current implementation of cache utilities:
https://github.com/huggingface/transformers/blob/f0bc49e7f61f74f055c47ad40e6010f57eed0b0b/src/transformers/cache_utils.py#L60

Can you point me to the correct branch?

@gante
Member

gante commented Aug 1, 2024

Hi @piyush-llm 👋

Alongside a Cache, the prepare_inputs_for_generation method also returns cache_position on models that have received the update. This new tensor contains the indices of the cache positions that will be updated, so its last value can be used to derive the sequence length :)
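To make this concrete: in Llama-style models in recent transformers releases, cache_position is built as a range that starts at the number of tokens already in the cache (roughly `torch.arange(past_seen_tokens, past_seen_tokens + seq_len)`). The sketch below assumes that 0-based layout and uses plain lists instead of tensors; the helper names are hypothetical. Under that assumption, the first value is the number of tokens cached before the step, and the last value plus one is the cache length after it.

```python
from typing import List, Tuple


def positions_for_step(past_len: int, new_tokens: int) -> List[int]:
    """Hypothetical helper: 0-based cache slots written in this step,
    mirroring arange(past_len, past_len + new_tokens)."""
    return list(range(past_len, past_len + new_tokens))


def cache_lengths(cache_position: List[int]) -> Tuple[int, int]:
    """Hypothetical helper: (tokens cached before this step, tokens cached after it)."""
    past_len = cache_position[0]
    total_len = cache_position[-1] + 1  # indices are 0-based
    return past_len, total_len


# Prefill: a 5-token prompt, nothing cached yet.
prefill = positions_for_step(0, 5)
# First decoding step: one new token after the prompt.
decode = positions_for_step(5, 1)

print(cache_lengths(prefill))  # (0, 5)
print(cache_lengths(decode))   # (5, 6)
```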

@BenjaminBossan
Member Author

Thanks for the details. Also check out #869 for discussion and code snippets.

@YOUSEFNANIS

Is it OK to just convert past_key_values from a tuple to a list?

@BenjaminBossan
Member Author

Is it OK to just convert past_key_values from a tuple to a list?

Unfortunately not; it is expected to be a Cache object.
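The distinction matters because the legacy format is a nested tuple with one (key, value) pair per layer, each tensor shaped (batch, num_heads, seq_len, head_dim). Turning the outer tuple into a list keeps that same layout, whereas a Cache object exposes methods such as from_legacy_cache and get_seq_length. The toy class below is a hypothetical, dependency-free stand-in loosely modeled on transformers' DynamicCache; it represents tensors by their shapes only and is not the real implementation.

```python
from typing import List, Tuple

Shape = Tuple[int, int, int, int]  # (batch, num_heads, seq_len, head_dim)


class ToyDynamicCache:
    """Hypothetical stand-in for a Cache object: stores per-layer key/value
    entries in lists instead of a nested tuple (tensors reduced to shapes)."""

    def __init__(self) -> None:
        self.key_cache: List[Shape] = []
        self.value_cache: List[Shape] = []

    @classmethod
    def from_legacy_cache(cls, legacy: Tuple[Tuple[Shape, Shape], ...]) -> "ToyDynamicCache":
        # Unpack one (key, value) pair per layer from the legacy tuple format.
        cache = cls()
        for key, value in legacy:
            cache.key_cache.append(key)
            cache.value_cache.append(value)
        return cache

    def get_seq_length(self, layer_idx: int = 0) -> int:
        # An empty (or too short) cache has length 0.
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx][2]  # seq_len axis


# Legacy format: one (key, value) pair per layer.
num_layers, batch, heads, virtual_tokens, head_dim = 2, 1, 4, 10, 8
kv = (batch, heads, virtual_tokens, head_dim)
legacy = tuple((kv, kv) for _ in range(num_layers))

cache = ToyDynamicCache.from_legacy_cache(legacy)
print(len(cache.key_cache), cache.get_seq_length())  # 2 10
```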

@BenjaminBossan
Member Author

An update from the internal discussion: Cache shouldn't be used at all during training. At the same time, past_key_values will be migrated to mean "cache". This probably means that, going forward, we cannot use past_key_values at all for injecting "virtual tokens" (or rather, embeddings), as we currently do for prompt learning.

I had hoped this could be avoided, but it smells like this whole approach needs a rewrite. One idea would be to calculate the virtual embeddings as we do right now (PeftModel.get_prompt), but instead of handing them off to past_key_values, inject them via a pre-forward hook (or, if that's not feasible, monkey patch forward 🙈).
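A rough, framework-free sketch of the pre-forward hook idea, assuming the hook has the `(module, args, kwargs) -> (args, kwargs)` shape that PyTorch uses for `register_forward_pre_hook(hook, with_kwargs=True)`. Embeddings are plain nested lists ([batch][seq][hidden]) and `make_virtual_token_hook` is a hypothetical name. This only covers prepending input embeddings (the prompt tuning case); prefix tuning injects per-layer keys and values, and a real hook would presumably also need to adjust labels and position ids.

```python
from typing import Any, Callable, Dict, List, Tuple

Vector = List[float]


def make_virtual_token_hook(virtual_embeds: List[Vector]) -> Callable:
    """Hypothetical factory: returns a hook that prepends the same learned
    virtual token embeddings to every sequence and extends the attention mask."""
    n_virtual = len(virtual_embeds)

    def hook(module: Any, args: Tuple, kwargs: Dict[str, Any]) -> Tuple[Tuple, Dict[str, Any]]:
        kwargs = dict(kwargs)  # avoid mutating the caller's dict
        embeds = kwargs["inputs_embeds"]
        kwargs["inputs_embeds"] = [list(virtual_embeds) + list(seq) for seq in embeds]
        mask = kwargs.get("attention_mask")
        if mask is not None:
            # Virtual tokens are always attended to.
            kwargs["attention_mask"] = [[1] * n_virtual + list(row) for row in mask]
        return args, kwargs

    return hook


hook = make_virtual_token_hook([[0.1, 0.2], [0.3, 0.4]])
args, kwargs = hook(None, (), {"inputs_embeds": [[[1.0, 1.0]]], "attention_mask": [[1]]})
print(len(kwargs["inputs_embeds"][0]))  # 3
print(kwargs["attention_mask"])         # [[1, 1, 1]]
```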

For generating, we'll probably still need a separate code path, since we do want to use caching for generation.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Member Author

not stale

@JerryLife

I have encountered the same issue with prefix tuning. If fully addressing this issue will take time, is there an alternative way to work around it?

@BenjaminBossan
Member Author

@JerryLife Could you please check whether this suggestion fixes it for you:

#869 (comment)

Please let me know whether it works; the more data I have, the better I can plan the fix.

@JerryLife

JerryLife commented Sep 13, 2024

Thank you for your prompt help! It works on my side. Here are more details.
Modules: peft==0.12.0, transformers==4.44.2
Model: MistralForCausalLM

My workaround

1. Define a custom get_prompt function:

```python
from typing import Optional, Union

import torch
from peft import PeftType
from peft.utils import TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING
from transformers import DynamicCache


def custom_get_prompt(
    self, batch_size: int, task_ids: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, DynamicCache]:
    """
    Returns the virtual prompts to use for Peft. Only applicable when using a prompt learning method.
    For prefix tuning, the prompts are returned as a DynamicCache instead of a tuple.
    """
    peft_config = self.active_peft_config
    prompt_encoder = self.prompt_encoder[self.active_adapter]
    prompt_tokens = (
        self.prompt_tokens[self.active_adapter]
        .unsqueeze(0)
        .expand(batch_size, -1)
        .to(prompt_encoder.embedding.weight.device)
    )
    if peft_config.peft_type == PeftType.PREFIX_TUNING:
        prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens]
        if peft_config.inference_mode:
            past_key_values = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
        else:
            past_key_values = prompt_encoder(prompt_tokens)
        if self.base_model_torch_dtype is not None:
            past_key_values = past_key_values.to(self.base_model_torch_dtype)
        past_key_values = past_key_values.view(
            batch_size,
            peft_config.num_virtual_tokens,
            peft_config.num_layers * 2,
            peft_config.num_attention_heads,
            peft_config.token_dim // peft_config.num_attention_heads,
        )
        if peft_config.num_transformer_submodules == 2:
            past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
            peft_config.num_transformer_submodules * 2
        )
        if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
            post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
            past_key_values = post_process_fn(past_key_values)

        ######################################## Workaround ########################################
        # Wrap the legacy tuple-of-tuples format in a transformers Cache object.
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        ############################################################################################

        return past_key_values
    else:
        if peft_config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
            prompts = prompt_encoder(prompt_tokens, task_ids)
        else:
            if peft_config.inference_mode:
                prompts = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
            else:
                prompts = prompt_encoder(prompt_tokens)
        return prompts
```
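2. Monkey patch the function on PeftModel using gorilla:

```python
import gorilla
from peft import PeftModel

# Replace PeftModel.get_prompt with custom_get_prompt; allow_hit=True permits
# overwriting the attribute that already exists on the class.
patch = gorilla.Patch(PeftModel, "get_prompt", custom_get_prompt, settings=gorilla.Settings(allow_hit=True))
gorilla.apply(patch)
```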
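The gorilla dependency is not strictly required: assigning a function to a class attribute turns it into a method for every instance, so with peft installed the equivalent one-liner would be `PeftModel.get_prompt = custom_get_prompt`. The sketch below demonstrates the pattern on a hypothetical stand-in class (not the real PeftModel) and keeps a reference to the original method so the patch can be undone.

```python
class StandInPeftModel:
    """Hypothetical stand-in for peft.PeftModel, used only to show the pattern."""

    def get_prompt(self, batch_size, task_ids=None):
        return "original"


def custom_get_prompt(self, batch_size, task_ids=None):
    return "patched"


# Keep a reference so the patch can be reverted later.
original_get_prompt = StandInPeftModel.get_prompt
StandInPeftModel.get_prompt = custom_get_prompt

model = StandInPeftModel()
print(model.get_prompt(batch_size=2))  # patched

# Restore the original method.
StandInPeftModel.get_prompt = original_get_prompt
print(model.get_prompt(batch_size=2))  # original
```

Patching the class (rather than a single instance) affects all existing and future instances, which matches what the gorilla patch does.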

Then, training works on my side. Thank you again for your assistance!
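For anyone following the reshaping in custom_get_prompt above, here is a dependency-free sketch that only tracks tensor shapes through the view, optional cat, permute, and split steps. It ignores the optional model-specific post-processing, and the function name prefix_kv_shapes is hypothetical. For decoder-only models (num_transformer_submodules == 1), each per-layer chunk has size 2 (key and value) along its first axis, so indexing it yields tensors of shape (batch, num_heads, num_virtual_tokens, head_dim), the legacy per-layer layout that DynamicCache.from_legacy_cache consumes.

```python
from typing import List, Tuple


def prefix_kv_shapes(
    batch_size: int,
    num_virtual_tokens: int,
    num_layers: int,
    num_heads: int,
    token_dim: int,
    num_transformer_submodules: int = 1,
) -> List[Tuple[int, ...]]:
    """Hypothetical helper: shape of each per-layer chunk produced by .split()."""
    head_dim = token_dim // num_heads
    # .view(batch, virtual_tokens, layers * 2, heads, head_dim)
    shape = [batch_size, num_virtual_tokens, num_layers * 2, num_heads, head_dim]
    if num_transformer_submodules == 2:
        shape[2] *= 2  # torch.cat([...], dim=2) doubles this axis
    # .permute([2, 0, 3, 1, 4]) -> (layers * 2 * subs, batch, heads, tokens, head_dim)
    shape = [shape[i] for i in (2, 0, 3, 1, 4)]
    # .split(subs * 2) along dim 0; shape[0] is an exact multiple of chunk,
    # so this yields exactly one chunk per layer.
    chunk = num_transformer_submodules * 2
    n_chunks = shape[0] // chunk
    return [tuple([chunk] + shape[1:]) for _ in range(n_chunks)]


chunks = prefix_kv_shapes(batch_size=2, num_virtual_tokens=10, num_layers=4, num_heads=8, token_dim=64)
print(len(chunks), chunks[0])  # 4 (2, 2, 8, 10, 8)
```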

@BenjaminBossan
Member Author

Thanks a lot @JerryLife, this really helps us move forward.

Modules: peft==0.13

If you could give me access to PEFT version 0.13, that would make my life a lot easier ;-)

@JerryLife

JerryLife commented Sep 14, 2024

@BenjaminBossan Sorry for the typo; I have corrected it to peft==0.12.0. I am looking forward to 0.13 in the near future ;). Thank you for your efforts!


github-actions bot commented Oct 8, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Member Author

not stale

BenjaminBossan added a commit that referenced this issue Oct 24, 2024
See #869, #1962

Fix several issues caused by changes to cache in transformers. In
particular, past_key_values for prefix tuning is now converted to a
transformers Cache instance.

---------

Co-authored-by: Raushan Turganbay