[WIP] Add LoRA multihead attention module #1324
```diff
@@ -246,14 +246,6 @@ def _replace_module(self, parent, child_name, new_module, child):
         if hasattr(child, "base_layer"):
             child = child.base_layer

-        if not hasattr(new_module, "base_layer"):
-            if hasattr(new_module, "W_q"):  # HQQ
-                new_module.W_q = child.W_q
-            else:
-                new_module.weight = child.weight
-            if hasattr(child, "bias"):
-                new_module.bias = child.bias
-
         if getattr(child, "state", None) is not None:
             if hasattr(new_module, "base_layer"):
                 new_module.base_layer.state = child.state
```

**Reviewer:** Why has this been removed?

**Author:** Sorry, I forgot to put this into the description of the PR. These lines have been obsolete for some time now. They only apply when we unload the model (otherwise, the […]). I removed them now because they were annoying with […].
```diff
@@ -265,15 +257,18 @@ def _replace_module(self, parent, child_name, new_module, child):
         # dispatch to correct device
         for name, module in new_module.named_modules():
             if (self.prefix in name) or ("ranknum" in name):
-                weight = (
-                    child.qweight
-                    if hasattr(child, "qweight")
-                    else child.W_q
-                    if hasattr(child, "W_q")
-                    else child.weight
-                    if hasattr(child, "weight")
-                    else next(child.parameters())
-                )
+                if hasattr(child, "qweight"):
+                    weight = child.qweight
+                elif hasattr(child, "W_q"):
+                    weight = child.W_q
+                elif hasattr(child, "weight"):
+                    weight = child.weight
+                elif getattr(child, "in_proj_weight", None) is not None:  # MHA
+                    weight = child.in_proj_weight
+                elif getattr(child, "q_proj_weight", None) is not None:  # MHA
+                    weight = child.q_proj_weight
+                else:
+                    weight = next(child.parameters())
                 if not any(p.device == meta for p in module.parameters()):
                     module.to(weight.device)
```
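For context on the two new MHA branches: `torch.nn.MultiheadAttention` stores its projection weights under different attribute names depending on how it is constructed, so the device dispatch has to check both. A quick illustration (not part of the PR):

```python
import torch.nn as nn

# With a shared embedding dim, the q/k/v projections are packed into a
# single in_proj_weight, and q_proj_weight is registered as None.
packed = nn.MultiheadAttention(embed_dim=8, num_heads=2)
assert packed.in_proj_weight is not None
assert packed.q_proj_weight is None

# With distinct kdim/vdim, PyTorch instead registers separate
# q/k/v_proj_weight parameters, and in_proj_weight is None.
split = nn.MultiheadAttention(embed_dim=8, num_heads=2, kdim=4, vdim=4)
assert split.in_proj_weight is None
assert split.q_proj_weight is not None
```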
```diff
@@ -359,7 +354,7 @@ def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs):
         raise ValueError(
             f"Target module {target} is not supported. Currently, only the following modules are supported: "
             "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, "
-            "`transformers.pytorch_utils.Conv1D`."
+            "`transformers.pytorch_utils.Conv1D`, `torch.nn.MultiheadAttention`."
         )

     return new_module
```
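Since the error message now advertises `torch.nn.MultiheadAttention` as supported, here is a hedged usage sketch; the toy model, the module name `attn`, and the hyperparameters are made up for illustration and are not taken from the PR:

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=8, num_heads=2)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        out, _ = self.attn(x, x, x)  # self-attention
        return self.head(out)

# Target the MultiheadAttention module by name; with this PR, the dynamic
# dispatch should resolve it instead of raising the ValueError above.
config = LoraConfig(target_modules=["attn"], r=8)
peft_model = get_peft_model(ToyModel(), config)
```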
```diff
@@ -508,7 +503,13 @@ def _unload_and_optionally_merge(
             except AttributeError:
                 continue
             with onload_layer(target):
-                if hasattr(target, "base_layer"):
+                if hasattr(target, "unload_and_optionally_merge_module"):
+                    # if layers have special unloading method, like MultiheadAttention, use that
+                    unloaded_module = target.unload_and_optionally_merge_module(
+                        merge=merge, safe_merge=safe_merge, adapter_names=adapter_names
+                    )
+                    self._replace_module(parent, target_name, unloaded_module, target)
+                elif hasattr(target, "base_layer"):
                     if merge:
                         target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
                     self._replace_module(parent, target_name, target.get_base_layer(), target)
```
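The new branch is a duck-typed hook: any adapter layer that cannot simply be swapped for its `base_layer` (like the MultiheadAttention wrapper) can define `unload_and_optionally_merge_module` and return the replacement module itself. A rough sketch of the protocol, assuming a wrapper that keeps the original module in `self.base_layer` (illustration only, not the PR's actual implementation):

```python
import torch.nn as nn

class WrappedLayer(nn.Module):
    def __init__(self, base_layer: nn.Module):
        super().__init__()
        self.base_layer = base_layer

    def merge(self, safe_merge: bool = False, adapter_names=None) -> None:
        # Here the adapter deltas would be folded into base_layer's weights.
        ...

    def get_base_layer(self) -> nn.Module:
        return self.base_layer

    def unload_and_optionally_merge_module(self, *, merge, safe_merge, adapter_names):
        # Return the module that should replace this wrapper in its parent.
        if merge:
            self.merge(safe_merge=safe_merge, adapter_names=adapter_names)
        return self.get_base_layer()
```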
**Reviewer:** Do we also need to overwrite the `modules()` method?

**Author:** Not needed, as `modules` calls `named_modules` under the hood. I added a comment to that effect.
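The `modules()`/`named_modules()` relationship is easy to verify with a small probe (illustration, not PR code):

```python
import torch.nn as nn

class Probe(nn.Module):
    def named_modules(self, memo=None, prefix="", remove_duplicate=True):
        print("named_modules called")
        yield from super().named_modules(memo=memo, prefix=prefix, remove_duplicate=remove_duplicate)

# modules() delegates to named_modules() under the hood, so overriding
# named_modules() is sufficient:
list(Probe().modules())  # prints "named_modules called"
```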