forked from datamllab/LongLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
modify_utils.py
51 lines (45 loc) · 2.69 KB
/
modify_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from types import MethodType
def modify_method_of_instance(instance, target_class_name, target_method_name, new_method, visited_instances=None):
    """
    Replace a method on every object of a given class reachable from ``instance``.

    The object graph is walked starting at ``instance``. Attribute values found
    in an object's ``__dict__`` are followed, and lists, tuples, dicts (values
    only) and sets are descended into, including containers nested inside other
    containers. Every object whose class name equals ``target_class_name`` gets
    ``new_method`` bound to it and stored as an instance attribute named
    ``target_method_name``; the class itself is not modified. A matching object
    is patched but not searched any further.

    The walk uses an explicit stack rather than recursion, so deeply nested
    object graphs do not hit the interpreter's recursion limit.

    instance:
        root object to search, e.g. a model.
    target_class_name:
        name of the class to patch, compared with ``obj.__class__.__name__``.
        E.g. 'LlamaAttention', 'GPTNeoXAttention', etc.
    target_method_name:
        name of the attribute to overwrite on each matching object.
    new_method:
        function to install, e.g. 'self_extend_forward'. Its first parameter
        must be 'self', because it is bound to each matching object with
        ``types.MethodType``.
    visited_instances:
        optional set of ``id()`` values of objects to skip. It is updated in
        place with the id of every object (containers included) that the walk
        examines. A fresh set is created when it is omitted.

    Returns None.
    """
    if visited_instances is None:
        visited_instances = set()

    # Explicit stack instead of recursion. Children are pushed in reverse so
    # objects are still visited depth-first in attribute/element order.
    pending = [instance]
    while pending:
        current = pending.pop()

        # id() works for unhashable objects too, and the visited set protects
        # against reference cycles and objects shared by several parents.
        current_id = id(current)
        if current_id in visited_instances:
            continue
        visited_instances.add(current_id)

        # Check if this object is of the target class; if so, patch it and stop
        # descending along this path.
        if current.__class__.__name__ == target_class_name:
            setattr(current, target_method_name, MethodType(new_method, current))
            continue

        # Collect a snapshot of the children first, so that anything mutating
        # these containers later cannot disturb the walk.
        children = []
        if isinstance(current, dict):
            # E.g. for a ModuleList, its modules are stored in a dictionary: ._modules
            children.extend(current.values())
        elif isinstance(current, (list, tuple, set)):
            children.extend(current)
        if hasattr(current, '__dict__'):
            children.extend(current.__dict__.values())

        pending.extend(reversed(children))