@@ -108,7 +108,7 @@ def __init__(self, cfg: Any) -> None:
         }

     def split_qkv_matrix(
-        self, attention_bridge: JointQKVAttentionBridge
+        self, original_attention_component: Any
     ) -> tuple[torch.nn.Linear, torch.nn.Linear, torch.nn.Linear]:
         """Split the QKV matrix into separate linear transformations.

         Args:
@@ -118,54 +118,49 @@ def split_qkv_matrix(
         """

         # Keep mypy happy
-        assert attention_bridge.original_component is not None
-        assert isinstance(attention_bridge.original_component.query_key_value, LinearBridge)
-        assert attention_bridge.original_component.query_key_value.original_component is not None
+        assert original_attention_component is not None
+        assert original_attention_component.query_key_value is not None

-        qkv_weights = attention_bridge.original_component.query_key_value.original_component.weight
+        qkv_weights = original_attention_component.query_key_value.weight

         # Keep mypy happy
         assert isinstance(qkv_weights, torch.Tensor)

-        d_head = self.cfg.hidden_size // self.cfg.n_head
-
-        # Original qkv_weights shape: [3 * n_head * d_head, d_model]
-        # We want to split it into [d_model, n_head * d_head] for each of Q, K, V
-        W_split = qkv_weights.T.reshape(self.cfg.hidden_size, 3, self.cfg.n_head * d_head)
+        # We want to split weights into [d_model, n_heads * d_head] for each of Q, K, V
+        W_split = qkv_weights.T.reshape(self.cfg.d_model, 3, self.cfg.n_heads * self.cfg.d_head)

         W_Q, W_K, W_V = W_split[:, 0, :], W_split[:, 1, :], W_split[:, 2, :]

-        qkv_bias = attention_bridge.original_component.query_key_value.original_component.bias
+        qkv_bias = original_attention_component.query_key_value.bias

         # Keep mypy happy
         assert isinstance(qkv_bias, torch.Tensor)

-        # Original qkv_bias shape: [3 * n_head * d_head]
-        # Reshape to [3, n_head * d_head] to split by Q, K, V
-        qkv_bias = qkv_bias.reshape(3, self.cfg.n_head * d_head)
+        # Reshape to [3, n_heads * d_head] to split by Q, K, V
+        qkv_bias = qkv_bias.reshape(3, self.cfg.n_heads * self.cfg.d_head)

         b_Q, b_K, b_V = qkv_bias[0, :], qkv_bias[1, :], qkv_bias[2, :]

         # Create nn.Linear modules
-        # W_Q, W_K, W_V shapes are [d_model, n_head * d_head]
+        # W_Q, W_K, W_V shapes are [d_model, n_heads * d_head]
         # nn.Linear expects weight shape [out_features, in_features]
-        # So for Linear(d_model, n_head * d_head), weight should be [n_head * d_head, d_model]
+        # So for Linear(d_model, n_heads * d_head), weight should be [n_heads * d_head, d_model]
         W_Q_transformation = torch.nn.Linear(W_Q.shape[0], W_Q.shape[1], bias=True)
         W_Q_transformation.weight = torch.nn.Parameter(
             W_Q.T
-        )  # Transpose to [n_head * d_head, d_model]
+        )  # Transpose to [n_heads * d_head, d_model]
         W_Q_transformation.bias = torch.nn.Parameter(b_Q)

         W_K_transformation = torch.nn.Linear(W_K.shape[0], W_K.shape[1], bias=True)
         W_K_transformation.weight = torch.nn.Parameter(
             W_K.T
-        )  # Transpose to [n_head * d_head, d_model]
+        )  # Transpose to [n_heads * d_head, d_model]
         W_K_transformation.bias = torch.nn.Parameter(b_K)

         W_V_transformation = torch.nn.Linear(W_V.shape[0], W_V.shape[1], bias=True)
         W_V_transformation.weight = torch.nn.Parameter(
             W_V.T
-        )  # Transpose to [n_head * d_head, d_model]
+        )  # Transpose to [n_heads * d_head, d_model]
         W_V_transformation.bias = torch.nn.Parameter(b_V)

         return W_Q_transformation, W_K_transformation, W_V_transformation
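Note for reviewers: the transpose-and-reshape in split_qkv_matrix assumes the fused query_key_value weight stores the Q, K, and V blocks concatenated along the output dimension, i.e. shape [3 * n_heads * d_head, d_model]. Below is a minimal standalone sanity check of that trick under this assumption; the dimension values are hypothetical stand-ins for cfg.d_model, cfg.n_heads, and cfg.d_head, and tensors are detached so the sketch runs outside any model.

import torch

# Hypothetical dimensions standing in for cfg.d_model, cfg.n_heads, cfg.d_head.
d_model, n_heads, d_head = 16, 4, 4

# Fused QKV projection; assumes Q, K, V blocks are concatenated along the
# output dimension, so weight has shape [3 * n_heads * d_head, d_model].
fused = torch.nn.Linear(d_model, 3 * n_heads * d_head, bias=True)

# Same split as split_qkv_matrix:
# weight.T: [d_model, 3 * n_heads * d_head] -> [d_model, 3, n_heads * d_head]
W_split = fused.weight.T.reshape(d_model, 3, n_heads * d_head)
W_Q = W_split[:, 0, :]
b_Q = fused.bias.reshape(3, n_heads * d_head)[0, :]

# Rebuild a standalone Q projection; nn.Linear wants [out_features, in_features],
# so transpose back. detach() gives leaf tensors for this standalone check.
q_proj = torch.nn.Linear(d_model, n_heads * d_head, bias=True)
q_proj.weight = torch.nn.Parameter(W_Q.T.detach().clone())
q_proj.bias = torch.nn.Parameter(b_Q.detach().clone())

# The split projection must reproduce the Q block of the fused output.
x = torch.randn(2, d_model)
assert torch.allclose(q_proj(x), fused(x)[:, : n_heads * d_head], atol=1e-6)

The same check applies symmetrically to the K and V blocks at indices 1 and 2 of W_split.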