[WIP] Add LoRA multihead attention module #1324

Open
BenjaminBossan wants to merge 47 commits into main from feat-add-lora-multihead-attention
Commits (47)
49fab86
[WIP] Add LoRA multihead attention module
BenjaminBossan Jan 5, 2024
d8e9589
Make style
BenjaminBossan Jan 5, 2024
0e188a3
Remove commented code
BenjaminBossan Jan 5, 2024
b409d81
Remove assignment of weight to new module
BenjaminBossan Jan 5, 2024
173062c
Make state_dict and named_parameters work
BenjaminBossan Jan 5, 2024
1e007f5
Extend test coverage a bit
BenjaminBossan Jan 8, 2024
557c4a1
Clean ups after reviewer feedback:
BenjaminBossan Jan 9, 2024
add1f51
Reviewer feedback: removed another unnecessary arg
BenjaminBossan Jan 9, 2024
e44e030
Make style
BenjaminBossan Jan 9, 2024
8d62579
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Jan 9, 2024
c5d8a6b
Apply LoRA also to the out_proj of MHA
BenjaminBossan Jan 12, 2024
9dc4a4d
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Feb 7, 2024
c3fb2ce
Fix bug with incorrectly set gradient
BenjaminBossan Feb 7, 2024
17d407b
Fix failing tests
BenjaminBossan Feb 7, 2024
4cbf6e9
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Feb 26, 2024
e0cae11
Move to pytest style asserts
BenjaminBossan Feb 26, 2024
52c8d9b
Fix safe merging code
BenjaminBossan Feb 26, 2024
977c84b
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Mar 11, 2024
96d376d
No need to set bias for MHA anymore, see #1530
BenjaminBossan Mar 11, 2024
0c17476
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Mar 26, 2024
4b8db0c
Fix style
BenjaminBossan Mar 26, 2024
7e91712
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan May 21, 2024
e12070b
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Jul 25, 2024
7b6c7cb
Remove duplicate merge
BenjaminBossan Jul 25, 2024
e6ab8ed
Raise error for multi adapter batch inference
BenjaminBossan Jul 25, 2024
8ec6c3c
Raise error for DoRA + MHA
BenjaminBossan Jul 25, 2024
f6ba465
Fix error when adding multiple adapters to MHA
BenjaminBossan Jul 25, 2024
fb18886
Better way of param initialization
BenjaminBossan Jul 26, 2024
4ff2ec3
Add tests for broken loading and workaround
BenjaminBossan Jul 26, 2024
d1f6ab2
make style
BenjaminBossan Jul 26, 2024
65363be
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Sep 3, 2024
7ba2e68
Fix wrong merge conflict resolution in test
BenjaminBossan Sep 4, 2024
6ef04b0
Ensure that base weights have requires_grad False
BenjaminBossan Sep 4, 2024
07c7240
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Sep 4, 2024
cc3ac3d
Remove xpass-ing test
BenjaminBossan Sep 4, 2024
03c466f
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Sep 12, 2024
e558caa
MAINT: Give stale bot permissions for PRs too (#2064)
BenjaminBossan Sep 12, 2024
38f4a98
ENH BOFT don't save boft_P buffer (#2050)
sywangyi Sep 13, 2024
7e5c61d
FIX Command line args in PiSSA preprocess (#2053)
keakon Sep 13, 2024
183bf52
MNT Update deprecated evaluation_strategy (#1664)
muellerzr Sep 13, 2024
b970607
ENH Multi adapters in same batch: modules_to_save (#1990)
saeid93 Sep 17, 2024
732e8e7
FIX Bug that prevents BOFT from loading 2 adapters (#2068)
BenjaminBossan Sep 18, 2024
79e2b38
TST Skip some quantization tests on XPU (#2074)
faaany Sep 18, 2024
61e6934
Improve test coverage for initialization of MHA
BenjaminBossan Sep 18, 2024
ced2f15
Merge branch 'main' into feat-add-lora-multihead-attention
BenjaminBossan Oct 14, 2024
4c31bbc
Fix bug with unloading multihead attention layer
BenjaminBossan Oct 21, 2024
1dbb9a5
Fix bug in unloading
BenjaminBossan Oct 22, 2024
310 changes: 310 additions & 0 deletions src/peft/tuners/lora/layer.py
@@ -71,6 +71,10 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
in_features, out_features = (
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
)
elif isinstance(base_layer, nn.MultiheadAttention):
if not base_layer._qkv_same_embed_dim:
raise ValueError(f"Only same dim for query/key/value is supported as of now for {self.__class__}.")
in_features, out_features = base_layer.embed_dim, 3 * base_layer.embed_dim
elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
# QuantLinear
in_features, out_features = base_layer.infeatures, base_layer.outfeatures
@@ -1154,6 +1158,309 @@ def _get_dora_layer_class(self):
return DoraConv3dLayer


class MultiheadAttention(nn.Module, LoraLayer):
"""LoRA implemented in a multihead attention layer

This is currently only implemented for the case of `_qkv_same_embed_dim = True`, i.e. query, key, and value having
the same dimension.

Note: LoRA is applied to both the in_proj (query/key/value) and the out_proj. There is currently no way to apply it to
only one of them. Don't apply LoRA to the out_proj of MultiheadAttention by targeting that layer directly: its forward
method is never called by MultiheadAttention, so the LoRA adapter would be ignored.

This is a little bit hacky because of the way that MultiheadAttention is implemented in PyTorch. It works by
merging the weights before the forward call and unmerging them after the forward call.
"""

def __init__(
self,
base_layer,
adapter_name: str,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
**kwargs,
) -> None:
# TODO work with separate weights
if not getattr(base_layer, "_qkv_same_embed_dim", True):
# default for this value appears to be True:
# https://github.com/pytorch/pytorch/blob/701ba5203fe68d55d655bd4d6c008be94cf34ea5/torch/nn/modules/activation.py#L1128-L1130
raise ValueError(
f"Only same embed for query/key/value is supported as of now for {self.__class__.__name__}."
)
if use_dora:
# TODO: probably not so hard to implement
raise ValueError(f"{self.__class__.__name__} does not support DoRA (yet), please set use_dora to False")

super().__init__()
LoraLayer.__init__(self, base_layer, **kwargs)

# Note: LoRA is applied to both in_proj and out_proj. There is currently no way to only specify one of them.
if isinstance(base_layer.out_proj, nn.Linear):
self.base_layer.out_proj = Linear(
base_layer.out_proj,
adapter_name,
r=r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
**kwargs,
)
else:
raise ValueError(f"out_proj must be an instance of nn.Linear for {self.__class__.__name__}.")

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

def update_layer(self, *args, **kwargs) -> None:
super().update_layer(*args, **kwargs)
# Note: LoRA is applied to both in_proj and out_proj. There is currently no way to only specify one of them.
self.base_layer.out_proj.update_layer(*args, **kwargs)

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights

Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return

# Implementation follows this:
# https://github.com/Baijiong-Lin/LoRA-Torch/blob/4bfed6820b64fcf47064c30f30606a190a4f0d2e/loratorch/layers.py#L73-L79
# Notably, instead of mutating the weight, we delete the original weight and replace it by the merged weight
# TODO: work with separate weights
for active_adapter in adapter_names:
if active_adapter in self.lora_A.keys():
base_layer = self.get_base_layer()
if safe_merge:
# TODO: work with separate weights
# merging in_proj
orig_weights_in = base_layer.in_proj_weight.data.detach().clone()
orig_weights_in += self.get_delta_weight(active_adapter)
if not torch.isfinite(orig_weights_in).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

# merging out_proj
orig_weights_out = base_layer.out_proj.weight.data.detach().clone()
orig_weights_out += base_layer.out_proj.get_delta_weight(active_adapter)
if not torch.isfinite(orig_weights_out).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

del base_layer.in_proj_weight
base_layer.in_proj_weight = orig_weights_in

del base_layer.out_proj.get_base_layer().weight
base_layer.out_proj.get_base_layer().weight = orig_weights_out
base_layer.out_proj.merge(adapter_names=[active_adapter])
else:
# merging in_proj
# TODO: work with separate weights
weight_merged = base_layer.in_proj_weight.data.detach() + self.get_delta_weight(active_adapter)
del base_layer.in_proj_weight
base_layer.in_proj_weight = weight_merged

# merging out_proj
weight_merged = base_layer.out_proj.weight.data.detach() + base_layer.out_proj.get_delta_weight(
active_adapter
)
del base_layer.out_proj.get_base_layer().weight
base_layer.out_proj.get_base_layer().weight = weight_merged
base_layer.out_proj.merge(adapter_names=[active_adapter])
self.merged_adapters.append(active_adapter)

def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return

# TODO work with separate weights
base_layer = self.get_base_layer()
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.lora_A.keys():
# Ensure that requires_grad=False for the base weights after unmerging. This may not matter since
# requires_grad was False when the optimizer was initialized, but still let's try to be correct here.

# in_proj
old_weight = base_layer.in_proj_weight.data - self.get_delta_weight(active_adapter)
del base_layer.in_proj_weight
base_layer.register_parameter("in_proj_weight", nn.Parameter(old_weight, requires_grad=False))

# out_proj
old_weight = base_layer.out_proj.base_layer.weight.data - base_layer.out_proj.get_delta_weight(
active_adapter
)
del base_layer.out_proj.base_layer.weight
base_layer.out_proj.base_layer.register_parameter(
"weight", nn.Parameter(old_weight, requires_grad=False)
)

self.get_base_layer().out_proj.unmerge()

def unload_and_optionally_merge_module(
self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
) -> nn.MultiheadAttention:
"""
Merging and unloading of the MultiheadAttention module

This requires an extra step for MultiheadAttention, which is why there is this special method instead of
relying on the normal merge_and_unload code path.
"""
if merge:
self.merge(safe_merge=safe_merge, adapter_names=adapter_names)
base_layer = self.get_base_layer()

# extra steps: re-register weights, take care of out_proj layer
# in_proj
weight = base_layer.in_proj_weight
del base_layer.in_proj_weight
base_layer.register_parameter("in_proj_weight", nn.Parameter(weight.data, requires_grad=weight.requires_grad))

# out_proj
out_proj_layer = base_layer.out_proj.get_base_layer()
weight = out_proj_layer.weight
del out_proj_layer.weight
out_proj_layer.register_parameter("weight", nn.Parameter(weight.data, requires_grad=weight.requires_grad))

base_layer.out_proj = out_proj_layer
return base_layer

def get_delta_weight(self, adapter) -> torch.Tensor:
"""
Compute the delta weight for the given adapter.

Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""
device = self.lora_B[adapter].weight.device
dtype = self.lora_B[adapter].weight.dtype

# In case users want to merge adapter weights that are in float16 while being on CPU, we need to cast the
# weights to float32, perform the merge, and then cast back to float16, because the `@`/matmul operation is
# generally not supported for torch + cpu + fp16.
cast_to_fp32 = device.type == "cpu" and dtype == torch.float16

weight_A = self.lora_A[adapter].weight
weight_B = self.lora_B[adapter].weight

if cast_to_fp32:
weight_A = weight_A.float()
weight_B = weight_B.float()

output_tensor = (weight_B @ weight_A) * self.scaling[adapter]

if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)

# cast back the weights
self.lora_A[adapter].weight.data = weight_A.to(dtype)
self.lora_B[adapter].weight.data = weight_B.to(dtype)

return output_tensor
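# (Illustration, not part of the diff: for an nn.MultiheadAttention with embed_dim=64 and r=8,
# lora_A.weight has shape (8, 64) and lora_B.weight has shape (192, 8), since in_features=embed_dim and
# out_features=3*embed_dim for this layer. The delta (weight_B @ weight_A) * scaling therefore has shape
# (192, 64), matching in_proj_weight, i.e. the stacked query/key/value projection.)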

def _check_forward_args(self, x, *args, **kwargs):
if "adapter_names" in kwargs:
raise TypeError(f"lora.{self.__class__.__name__} does not support mixed adapter batches.")
super()._check_forward_args(x, *args, **kwargs)

def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
previous_dtype = x.dtype
self._check_forward_args(x, *args, **kwargs)

if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
out_proj = self.get_base_layer().out_proj
if out_proj.active_adapters != self.active_adapters:
# The sets of active adapters for in_proj and out_proj have diverged. We cannot really deal with this
# correctly, so it's better to raise than to possibly create a hard-to-debug mess.
cls_name = self.get_base_layer().__class__.__name__
raise ValueError(
f"The out_proj layer of {cls_name} has merged layers but {cls_name} itself doesn't; please ensure "
"that either both or none have merged layers"
)

# Merge all adapters that are active for this module, i.e. the LoRA weights for in_proj and out_proj.
# in_proj is a plain nn.Parameter, so there is no forward method to hook into, and we have to merge
# explicitly for the LoRA weights to have an effect:
# https://github.com/pytorch/pytorch/blob/6ebb26d572d5fcdc6ac0d1297bdf8d1eb5d20722/torch/nn/modules/activation.py#L1020
# For out_proj, we have an nn.Linear (or rather: NonDynamicallyQuantizableLinear), but its forward method
# is not used:
# https://github.com/pytorch/pytorch/blob/6ebb26d572d5fcdc6ac0d1297bdf8d1eb5d20722/torch/nn/modules/activation.py#L1267-L1271
# Therefore, its LoRA weights also need to be merged to have an effect.
active_adapters = [a for a in self.active_adapters if a in self.lora_A]
try:
self.merge(adapter_names=active_adapters)
result = self.base_layer(x, *args, **kwargs)
finally:
# it's safe to call unmerge(), which unmerges all adapters, because we checked above that not self.merged,
# i.e. there was no merged layer before
self.unmerge()

result = (result[0].to(previous_dtype), result[1].to(previous_dtype) if result[1] is not None else result[1])
return result

def _restore_weights(self):
# Restore the weights as registered parameters on the base layer.
# This is necessary because of the way the weights are merged/unmerged (which forward needs to work
# correctly): the Module "forgets" these attributes, so we need to call register_parameter explicitly.
# We cannot call register_parameter for merging/unmerging because that cuts them off from the autograd graph.
# Note that this is hacky, since we need to ensure that _restore_weights is called by each method that needs it.

# in_proj
# TODO work with separate weights
base_layer = self.get_base_layer()
weight = base_layer.in_proj_weight
del base_layer.in_proj_weight
base_layer.register_parameter("in_proj_weight", nn.Parameter(weight.data, requires_grad=weight.requires_grad))

# out_proj
base_layer = base_layer.out_proj.get_base_layer()
weight = base_layer.weight
del base_layer.weight
base_layer.register_parameter("weight", nn.Parameter(weight.data, requires_grad=weight.requires_grad))

def state_dict(self, *args, **kwargs):
self._restore_weights()
return super().state_dict(*args, **kwargs)

def named_modules(self, *args, **kwargs):
Contributor: Do we also need to override the modules() method?

Member Author: Not needed, as modules() calls named_modules() under the hood. I added a comment to that effect.

# Note: no need to also implement modules(), as modules() calls named_modules() under the hood
self._restore_weights()
return super().named_modules(*args, **kwargs)
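# (Illustration, not part of the diff: overriding named_modules() also covers modules(), since
# torch.nn.Module.modules() is implemented roughly as
#
#     def modules(self):
#         for _, module in self.named_modules():
#             yield module
#
# so restoring the weights here is sufficient for both traversal methods.)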

def __repr__(self) -> str:
rep = super().__repr__()
return "lora." + rep


def dispatch_default(
target: torch.nn.Module,
adapter_name: str,
@@ -1178,6 +1485,9 @@ def dispatch_default(
elif isinstance(target_base_layer, torch.nn.Conv3d):
kwargs.update(lora_config.loftq_config)
new_module = Conv3d(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, torch.nn.MultiheadAttention):
kwargs.update(lora_config.loftq_config)
new_module = MultiheadAttention(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, torch.nn.Linear):
if kwargs["fan_in_fan_out"]:
warnings.warn(
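For context on how this layer gets used: dispatch_default above now routes nn.MultiheadAttention targets to the new class, so targeting the attention module in a LoraConfig is all that is needed to get the merge-before-forward/unmerge-after-forward behavior described in the class docstring. A minimal usage sketch, assuming this PR's branch of peft is installed; the ToyModel, the module name "attn", and the hyperparameters are illustrative, not taken from the PR:

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class ToyModel(nn.Module):
    def __init__(self, embed_dim=64, num_heads=4):
        super().__init__()
        # only the _qkv_same_embed_dim=True case is supported
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.lin = nn.Linear(embed_dim, 2)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)  # self-attention: query = key = value
        return self.lin(attn_out)

config = LoraConfig(r=8, lora_alpha=16, target_modules=["attn"])
peft_model = get_peft_model(ToyModel(), config)

x = torch.randn(10, 2, 64)  # (seq_len, batch, embed_dim); batch_first defaults to False
out = peft_model(x)  # LoRA weights are merged for the call and unmerged right afterwards

Note that targeting out_proj on its own is deliberately not supported, since its forward method is never called by MultiheadAttention.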
39 changes: 20 additions & 19 deletions src/peft/tuners/lora/model.py
@@ -246,14 +246,6 @@ def _replace_module(self, parent, child_name, new_module, child):
if hasattr(child, "base_layer"):
child = child.base_layer

if not hasattr(new_module, "base_layer"):
Contributor: Why has this been removed?

Member Author: Sorry, I forgot to put this into the description of the PR.

These lines have been obsolete for some time now. They only apply when we unload the model (otherwise, the if does not match). Remember, when we made the base_layer switch, we ensured that when unloading, we simply return the base_layer; there is no longer any need to create a new layer (say, a new nn.Linear when using lora.Linear) and replace the new layer's weight with the parent layer's weight. The base_layer already has the original weight. Therefore, these lines are unnecessary.

I removed them now because they were a problem for MultiheadAttention: that layer has no weight attribute, so this line would fail.

if hasattr(new_module, "W_q"): # HQQ
new_module.W_q = child.W_q
else:
new_module.weight = child.weight
if hasattr(child, "bias"):
new_module.bias = child.bias

if getattr(child, "state", None) is not None:
if hasattr(new_module, "base_layer"):
new_module.base_layer.state = child.state
@@ -265,15 +257,18 @@ def _replace_module(self, parent, child_name, new_module, child):
# dispatch to correct device
for name, module in new_module.named_modules():
if (self.prefix in name) or ("ranknum" in name):
weight = (
child.qweight
if hasattr(child, "qweight")
else child.W_q
if hasattr(child, "W_q")
else child.weight
if hasattr(child, "weight")
else next(child.parameters())
)
if hasattr(child, "qweight"):
weight = child.qweight
elif hasattr(child, "W_q"):
weight = child.W_q
elif hasattr(child, "weight"):
weight = child.weight
elif getattr(child, "in_proj_weight", None) is not None: # MHA
weight = child.in_proj_weight
elif getattr(child, "q_proj_weight", None) is not None: # MHA
weight = child.q_proj_weight
else:
weight = next(child.parameters())
if not any(p.device == meta for p in module.parameters()):
module.to(weight.device)

@@ -359,7 +354,7 @@ def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs):
raise ValueError(
f"Target module {target} is not supported. Currently, only the following modules are supported: "
"`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, "
"`transformers.pytorch_utils.Conv1D`."
"`transformers.pytorch_utils.Conv1D`, `torch.nn.MultiheadAttention.`."
)

return new_module
@@ -508,7 +503,13 @@ def _unload_and_optionally_merge(
except AttributeError:
continue
with onload_layer(target):
if hasattr(target, "base_layer"):
if hasattr(target, "unload_and_optionally_merge_module"):
# if layers have special unloading method, like MultiheadAttention, use that
unloaded_module = target.unload_and_optionally_merge_module(
merge=merge, safe_merge=safe_merge, adapter_names=adapter_names
)
self._replace_module(parent, target_name, unloaded_module, target)
elif hasattr(target, "base_layer"):
if merge:
target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
self._replace_module(parent, target_name, target.get_base_layer(), target)
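To round off the sketch from the layer.py section: the unload_and_optionally_merge_module hook above is what the existing merge_and_unload code path dispatches to for this layer. Continuing the illustrative example (same assumptions; merge_and_unload is existing PEFT API):

merged = peft_model.merge_and_unload()
# the LoRA wrapper is gone: attn is a plain nn.MultiheadAttention again, with the adapter
# deltas folded into in_proj_weight and out_proj.weight
assert isinstance(merged.attn, nn.MultiheadAttention)
out_merged = merged(x)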