Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add LoRA multihead attention module #1324

Open
wants to merge 47 commits into
base: main
Choose a base branch
from

Conversation

BenjaminBossan
Copy link
Member

@BenjaminBossan BenjaminBossan commented Jan 5, 2024

First stab at adding LoRA support for nn.MultiheadAttention. See #761.

Todos:

  • For now, only works with _qkv_same_embed_dim=True -- make it work with False too. _qkv_same_embed_dim=False is out of scope for this PR and can be added in a later PR if needed.
  • Show that it works in a real world test: See user feedback on the issue.
  • Unit tests
  • Docs Apart from docstrings, I don't think anything else needs to be added

Update: I now also included the out_proj to apply LoRA to.

This is a simple test that I ran successfully with the PR in its current state:

import open_clip
import requests
import torch
from torch import nn
from peft import LoraConfig, get_peft_model
from PIL import Image
from peft.tuners.lora.layer import MultiheadAttention as PeftMha

model, preprocess = open_clip.create_model_from_pretrained('hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K')
tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K')
peft_model = get_peft_model(model, config)
opt = torch.optim.SGD(peft_model.parameters(), 0.1)
print(len([m for m in peft_model.modules() if isinstance(m, PeftMha)]))  # 64 PEFT MHA layers
peft_model.print_trainable_parameters()  # trainable params: 2,588,672 || all params: 1,055,873,793 || trainable%: 0.24516869508096598

# text encoder
text = tokenizer(["a diagram", "a dog", "a cat"])
text_features = peft_model.encode_text(text)
loss = text_features.sum()
loss.backward()
opt.step()

# image encoder
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image = preprocess(image).unsqueeze(0)
image_features = model.encode_image(image)
image_features.sum().backward()
opt.step()

For now, only works with _qkv_same_embed_dim=True.
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

This is no longer necessary when unloading the model because the
base_layer is already the original layer. This is just a leftover
from before we adopted the base_layer pattern.
There was a bug because the removal of the parameter resulted in it no
longer appearing in the state_dict and named_parameters. This commit
fixes this bug.

The bug also exists in the referenced lora-torch library.
Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work ! I left few preliminary comments, I think we can go for the _restore_weights approach for now as I don't see any other alternative

lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
is_target_conv_1d_layer: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
is_target_conv_1d_layer: bool = False,

I don't think this is used?


self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.is_target_conv_1d_layer = is_target_conv_1d_layer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.is_target_conv_1d_layer = is_target_conv_1d_layer

We can also just hard-code it to False

self._restore_weights()
return super().state_dict(*args, **kwargs)

def named_modules(self, *args, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need also to over-write the modules() method?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed, as modules calls named_modules under the hood. I added a comment to that effect.

@@ -193,11 +193,6 @@ def _replace_module(self, parent, child_name, new_module, child):
if hasattr(child, "base_layer"):
child = child.base_layer

if not hasattr(new_module, "base_layer"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this has been removed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, forgot to put this into the description of the PR.

These lines are obsolete for some time now. They only apply when we unload the model (otherwise, the if does not match). Remember when we made the base_layer switch, we ensured that when unloading, we simply return the base_layer, no more need to create a new layer (say, a new nn.Linear when using lora.Linear) and replace the new layer's weight by the parent layer's weight. The base_layer already has the original weight. Therefore, these lines are unnecessary.

I removed them now because they were annoying with MultiheadAttention, because that layer has no weight attribute, so this line would fail.

Copy link
Contributor

@pacman100 pacman100 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you Benjamin for adding support for torch MHA layer in LoRA, interesting way to use merge, forward and unmerge logic!

@BenjaminBossan
Copy link
Member Author

@younesbelkada Could I address all your concerns?

I pinged the user who wanted to test it on their case. When it comes to docs, I didn't really find a place where we list all supported layers, so no update needed really.

Before, LoRA was applied only to the in_proj. Now it is also applied to
the out_proj.

Unfortunately, there is no easy way to just apply a normal lora.Linear
to the out_proj by targeting it with target_modules. If that worked, it
would be much nicer to do that, so that users can decide for themselves
if they want to apply LoRA to the out_proj or not.

The reason why it doesn't work is twofold:

1. We cannot really control the order in which LoRA is applied, so when
   the LoRA adapter is injected to out_proj, the whole MHA layer may
   already be wrapped by lora.MultiheadAttention.
2. Even if we successfully applied a normal lora.Linear to the out_proj,
   it would not work correctly. This is because the forward method of
   out_proj is not used at all by nn.MultiheadAttention. Instead, it
   just passes the weight and bias to F.multi_head_attention_forward.
   Therefore, we must ensure that the weights are merged and unmerged
   correctly, same as for in_proj, and we cannot do that if we use a
   normal lora.Linear.

Note that the test test_merge_layers for MHA fails. This is most likely
because of an existing bug in now merging is implemented, see PR huggingface#1355.
Once that is merged, the test should pass.
@BenjaminBossan
Copy link
Member Author

Note: The test test_merge_layers for MHA fails. This is most likely because of an existing bug in how merging is implemented, see PR #1355. Once that is merged, the test should pass.

@ambroser53
Copy link

Just want to bump a bunch of the issues I've mentioned in #761 but specifically the problem with requires_grad reproducable in this repo

@bghira
Copy link

bghira commented Feb 26, 2024

just wanted to bump this one because it's really the only way for tuning CLIP models after they are released.

@BenjaminBossan
Copy link
Member Author

@bghira Do you happen to have a use case where you could test if this PR works and is working well enough speed-wise? I think the implementation could be ready to be merged but ideally we'd have someone with a real use case give it a try.

@bghira
Copy link

bghira commented Feb 26, 2024

i do and i may be able to test it. stupid question but is the code example above complete? i dont see the hinge loss function

@BenjaminBossan
Copy link
Member Author

stupid question but is the code example above complete? i dont see the hinge loss function

You mean the code right at the top? No, it's not complete at all, just a quick test to show that MHA is applied and the backward pass does not fail. This is not proper nor complete training code.

@damian0815
Copy link

damian0815 commented Jul 26, 2024

it's only happening after calling .forward() on the model (restoring the state dict before that works fine). moreover if i put a breakpoint on the line where the failing restore happens and execute set(model.state_dict().keys()).symmetric_difference(restore_state_dict.keys()) in the debugger, the result is an empty set().

@BenjaminBossan
Copy link
Member Author

definitely useful, yes.

That's good to hear. Hopefully this PR can be merged some day so that we can have MHA support in PEFT proper, it's just that multihead attention is implemented in a way that makes applying LoRA very difficult and requires some hacks. To wit:

However, the restoring fails with

I think this is related to this part:

https://github.com/huggingface/peft/pull/1324/files#diff-24a141c266b7b714ae8fcc470f31bc283f7b0f5a671bbf6d5f092741fc374104R1290-R1294

Could you check if calling _restore_weights manually would solve the error?

The code would be something along these lines:

for module in model.modules():
    if isinstance(module, peft.tuners.lora.MultiheadAttention):
        module._restore_weights()

@damian0815
Copy link

damian0815 commented Jul 26, 2024

yes, that solved it - thanks (but i had to use peft.tuners.lora.layers.MultiheadAttention for the fully qualified module class)

@BenjaminBossan
Copy link
Member Author

Great, thanks for confirming @damian0815, and sorry for the wrong path.

I tried to create a unit test based on the description you provided, I think I could reproduce your error. Could you quickly check if the test captures your situation?

@pytest.mark.xfail(strict=True)
def test_mha_load_init_model_first():
    # this test fails as it currently requires a workaround to pass, see test below
    # https://github.com/huggingface/peft/pull/1324#issuecomment-2252473980
    inputs = torch.rand(10, 10, 10)
    model = ModelMha()
    config = LoraConfig(target_modules=["mha"], init_lora_weights=False)
    model = get_peft_model(model, config).eval()
    restore_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}

    del model

    model = ModelMha()
    # inferencing with PEFT model first is necessary to trigger the error in load_state_dict
    model = get_peft_model(model, config)
    model(inputs)
    model.load_state_dict(restore_state_dict)


def test_mha_load_init_model_first_with_workaround():
    import peft

    inputs = torch.rand(10, 10, 10)
    model = ModelMha()
    config = LoraConfig(target_modules=["mha"], init_lora_weights=False)
    model = get_peft_model(model, config).eval()
    with torch.inference_mode():
        output_before = model(inputs)
        restore_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}

    del model

    model = ModelMha()
    model = get_peft_model(model, config)
    model(inputs)

    # workaround, see test above
    for module in model.modules():
        if isinstance(module, peft.tuners.lora.layer.MultiheadAttention):
            module._restore_weights()

    model.load_state_dict(restore_state_dict)
    with torch.inference_mode():
        output_after = model(inputs)

    assert torch.allclose(output_before, output_after)

Unfortunately, I could not find a way to hook into load_state_dict to automatically call _restore_weights, since load_state_dict is not recursive, so the PEFT MultiheadAttention is never directly invoked :( I hope this is enough of an edge case that I can ignore it for now.

@damian0815
Copy link

damian0815 commented Jul 26, 2024

looks about right - we're not deleting/reloading the model in-between though, simply messing with the weights (doing a blend with the base model -- which is in fact disabled when LoRA training is active, but the save/restore logic runs anyway) and then restoring the weights by loading the restore_state_dict in place.

Copy link

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

BenjaminBossan and others added 14 commits September 3, 2024 16:54
There was a situation were loading the state dict would fail and require
a workaround. For this, there was an xfail-ing test with strict=True.
This test no longer fails, so the marker has been removed, as well as
the test with the workaround.
The buffer does not need to be part of the checkpoint, by making it
non-persistent, the file size can be greatly reduced.
Fix bug in parsing command line arguments in the PiSSA preprocess.py script from
the PiSSA example.
In docs and examples, use eval_strategy instead of evaluation_strategy, which is
deprecated.
Extend the functionality of having different adapters in the same batch to also
work with `modules_to_save`.
There was a bug in BOFT that made it impossible in some circumstances to
load more than one adapter (creating more than 1 adapter was possible
though). This was because a code path that adjusts
boft_n_butterfly_factor was only visited when creating a fresh adapter,
but not when updating with the 2nd adapter. This was fixed by moving
this code path from the BOFT layer's __init__ method to update_layer.

A test for loading multiple adapters was added. Since this was a gap in
our test suite, this test will be applied to all appropriate PEFT
methods, not only BOFT, but the others methods are all passing without
needing further changes.

For good measure, I also added BOFT to the test suite that checks
multiple active adapters. These tests would have also passed without the
fix in this PR, since these tests do not load multiple adapters but
instead create them, which always worked. Still it's better to have
these tests as well.
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@mashijie1028
Copy link

mashijie1028 commented Oct 20, 2024

@BenjaminBossan
Hi! I found that LoRA does not work for in_proj_weight in attn of open_clip. I was wondering how to fix this.
To be more specific, when I implement LoRA as follows:

lora_config = LoraConfig(
    r=16,
    target_modules=["in_proj_weight"],
    lora_alpha=32,
    lora_dropout=0.05
)

An error occurs as ValueError: Target modules {'in_proj_weight'} not found in the base model. Please check the target modules and try again.
But when I implement for out_proj, LoRA works fine!
Could you please tell me how to set target_modules in LoraConfig to implement LoRA on attn layers? Thanks!

By the way, I download peft as you mentioned before:

python -m pip install git+https://github.com/BenjaminBossan/peft.git@feat-add-lora-multihead-attention

(I report the same issue here)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.