Commit bea73e1

Create bridge for every module in neox (#995)

* Create bridge for every module in neox
* Use transformer lens config in split_qkv_functions

Co-authored-by: Bryce Meyer <[email protected]>

1 parent fcfc684
3 files changed: +90 -29 lines

transformer_lens/model_bridge/supported_architectures/bloom.py

Lines changed: 14 additions & 19 deletions
```diff
@@ -108,7 +108,7 @@ def __init__(self, cfg: Any) -> None:
         }
 
     def split_qkv_matrix(
-        self, attention_bridge: JointQKVAttentionBridge
+        self, original_attention_component: Any
     ) -> tuple[torch.nn.Linear, torch.nn.Linear, torch.nn.Linear]:
         """Split the QKV matrix into separate linear transformations.
         Args:
@@ -118,54 +118,49 @@ def split_qkv_matrix(
         """
 
         # Keep mypy happy
-        assert attention_bridge.original_component is not None
-        assert isinstance(attention_bridge.original_component.query_key_value, LinearBridge)
-        assert attention_bridge.original_component.query_key_value.original_component is not None
+        assert original_attention_component is not None
+        assert original_attention_component.query_key_value is not None
 
-        qkv_weights = attention_bridge.original_component.query_key_value.original_component.weight
+        qkv_weights = original_attention_component.query_key_value.weight
 
         # Keep mypy happy
         assert isinstance(qkv_weights, torch.Tensor)
 
-        d_head = self.cfg.hidden_size // self.cfg.n_head
-
-        # Original qkv_weights shape: [3 * n_head * d_head, d_model]
-        # We want to split it into [d_model, n_head * d_head] for each of Q, K, V
-        W_split = qkv_weights.T.reshape(self.cfg.hidden_size, 3, self.cfg.n_head * d_head)
+        # We want to split weights into [d_model, n_heads * d_head] for each of Q, K, V
+        W_split = qkv_weights.T.reshape(self.cfg.d_model, 3, self.cfg.n_heads * self.cfg.d_head)
 
         W_Q, W_K, W_V = W_split[:, 0, :], W_split[:, 1, :], W_split[:, 2, :]
 
-        qkv_bias = attention_bridge.original_component.query_key_value.original_component.bias
+        qkv_bias = original_attention_component.query_key_value.bias
 
         # Keep mypy happy
         assert isinstance(qkv_bias, torch.Tensor)
 
-        # Original qkv_bias shape: [3 * n_head * d_head]
-        # Reshape to [3, n_head * d_head] to split by Q, K, V
-        qkv_bias = qkv_bias.reshape(3, self.cfg.n_head * d_head)
+        # Reshape to [3, n_heads * d_head] to split by Q, K, V
+        qkv_bias = qkv_bias.reshape(3, self.cfg.n_heads * self.cfg.d_head)
 
         b_Q, b_K, b_V = qkv_bias[0, :], qkv_bias[1, :], qkv_bias[2, :]
 
         # Create nn.Linear modules
-        # W_Q, W_K, W_V shapes are [d_model, n_head * d_head]
+        # W_Q, W_K, W_V shapes are [d_model, n_heads * d_head]
         # nn.Linear expects weight shape [out_features, in_features]
-        # So for Linear(d_model, n_head * d_head), weight should be [n_head * d_head, d_model]
+        # So for Linear(d_model, n_heads * d_head), weight should be [n_heads * d_head, d_model]
         W_Q_transformation = torch.nn.Linear(W_Q.shape[0], W_Q.shape[1], bias=True)
         W_Q_transformation.weight = torch.nn.Parameter(
             W_Q.T
-        )  # Transpose to [n_head * d_head, d_model]
+        )  # Transpose to [n_heads * d_head, d_model]
         W_Q_transformation.bias = torch.nn.Parameter(b_Q)
 
         W_K_transformation = torch.nn.Linear(W_K.shape[0], W_K.shape[1], bias=True)
         W_K_transformation.weight = torch.nn.Parameter(
             W_K.T
-        )  # Transpose to [n_head * d_head, d_model]
+        )  # Transpose to [n_heads * d_head, d_model]
         W_K_transformation.bias = torch.nn.Parameter(b_K)
 
         W_V_transformation = torch.nn.Linear(W_V.shape[0], W_V.shape[1], bias=True)
         W_V_transformation.weight = torch.nn.Parameter(
             W_V.T
-        )  # Transpose to [n_head * d_head, d_model]
+        )  # Transpose to [n_heads * d_head, d_model]
         W_V_transformation.bias = torch.nn.Parameter(b_V)
 
         return W_Q_transformation, W_K_transformation, W_V_transformation
```
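For intuition, here is a minimal, self-contained sanity check of the reshape-based split above (a sketch, not part of the commit). It assumes the fused projection concatenates the Q, K, and V blocks along the output dimension, which is the layout the reshape implies; `fused` and the toy dimensions are stand-ins for the real BLOOM module and config values.

```python
import torch

# Toy dimensions standing in for cfg.d_model, cfg.n_heads, cfg.d_head
d_model, n_heads, d_head = 16, 4, 4

# Stand-in for the fused query_key_value projection
fused = torch.nn.Linear(d_model, 3 * n_heads * d_head, bias=True)

# Mirror the bridge code: transpose to [d_model, 3 * n_heads * d_head],
# view as [d_model, 3, n_heads * d_head], then index the middle axis
W_split = fused.weight.T.reshape(d_model, 3, n_heads * d_head)
W_Q = W_split[:, 0, :]
b_Q = fused.bias.reshape(3, n_heads * d_head)[0]

# Wrap the Q slice in its own nn.Linear (weight is [out_features, in_features])
q_proj = torch.nn.Linear(d_model, n_heads * d_head, bias=True)
q_proj.weight = torch.nn.Parameter(W_Q.T)
q_proj.bias = torch.nn.Parameter(b_Q)

# The standalone Q projection reproduces the first third of the fused output
x = torch.randn(2, d_model)
assert torch.allclose(q_proj(x), fused(x)[:, : n_heads * d_head], atol=1e-6)
```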

transformer_lens/model_bridge/supported_architectures/gpt2.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -97,7 +97,6 @@ def __init__(self, cfg: Any) -> None:
                 },
                 config={
                     "split_qkv_matrix": self.split_qkv_matrix,
-                    "original_model_config": self.cfg,
                 },
             ),
             "ln2": NormalizationBridge(name="ln_2"),
@@ -133,8 +132,6 @@ def split_qkv_matrix(
         # Keep mypy happy
         assert isinstance(qkv_weights, torch.Tensor)
 
-        d_head = self.cfg.n_embd // self.cfg.n_head
-
         # Original qkv_weights shape: [d_model, 3 * d_model]
         # Split into three equal parts along dimension 1 to get Q, K, V weights
         W_Q, W_K, W_V = torch.tensor_split(qkv_weights, 3, dim=1)
@@ -144,9 +141,9 @@
         # Keep mypy happy
         assert isinstance(qkv_bias, torch.Tensor)
 
-        # Original qkv_bias shape: [3 * n_head * d_head]
-        # Reshape to [3, n_head * d_head] to split by Q, K, V
-        qkv_bias = qkv_bias.reshape(3, self.cfg.n_head * d_head)
+        # Original qkv_bias shape: [3 * n_heads * d_head]
+        # Reshape to [3, n_heads * d_head] to split by Q, K, V
+        qkv_bias = qkv_bias.reshape(3, self.cfg.n_heads * self.cfg.d_head)
         b_Q, b_K, b_V = qkv_bias[0, :], qkv_bias[1, :], qkv_bias[2, :]
 
         # Create nn.Linear modules
```
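GPT-2 can split with `torch.tensor_split` directly because Hugging Face's `c_attn` is a `Conv1D`, which stores its weight as `[in_features, out_features]`; chopping dimension 1 into three parts separates the Q, K, and V output blocks. A small sketch with toy sizes (the tensors are stand-ins for the real `c_attn` parameters, not the commit's code):

```python
import torch

d_model = 8  # toy size; here n_heads * d_head == d_model
qkv_weights = torch.randn(d_model, 3 * d_model)  # stand-in for c_attn.weight
qkv_bias = torch.randn(3 * d_model)              # stand-in for c_attn.bias

# Split the fused output dimension into the Q, K, V blocks
W_Q, W_K, W_V = torch.tensor_split(qkv_weights, 3, dim=1)
b_Q, b_K, b_V = qkv_bias.reshape(3, d_model)

# Conv1D computes x @ W + b, so the Q block of the fused output
# matches the standalone Q projection
x = torch.randn(2, d_model)
fused_out = x @ qkv_weights + qkv_bias
assert torch.allclose(x @ W_Q + b_Q, fused_out[:, :d_model], atol=1e-6)
```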

transformer_lens/model_bridge/supported_architectures/neox.py

Lines changed: 73 additions & 4 deletions
```diff
@@ -2,6 +2,8 @@
 
 from typing import Any
 
+import torch
+
 from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
 from transformer_lens.model_bridge.conversion_utils.conversion_steps import (
     RearrangeWeightConversion,
@@ -12,9 +14,10 @@
     ChainWeightConversion,
 )
 from transformer_lens.model_bridge.generalized_components import (
-    AttentionBridge,
     BlockBridge,
     EmbeddingBridge,
+    JointQKVAttentionBridge,
+    LinearBridge,
     MLPBridge,
     NormalizationBridge,
     UnembeddingBridge,
@@ -131,16 +134,82 @@ def __init__(self, cfg: Any) -> None:
 
         self.component_mapping = {
             "embed": EmbeddingBridge(name="gpt_neox.embed_in"),
-            "pos_embed": EmbeddingBridge(name="gpt_neox.embed_pos"),
+            "rotary_emb": EmbeddingBridge(name="gpt_neox.rotary_emb"),
             "blocks": BlockBridge(
                 name="gpt_neox.layers",
                 submodules={
                     "ln1": NormalizationBridge(name="input_layernorm"),
                     "ln2": NormalizationBridge(name="post_attention_layernorm"),
-                    "attn": AttentionBridge(name="attention"),
-                    "mlp": MLPBridge(name="mlp"),
+                    "attn": JointQKVAttentionBridge(
+                        name="attention",
+                        submodules={
+                            "W_QKV": LinearBridge(
+                                name="query_key_value",
+                            ),
+                            "W_O": LinearBridge(name="dense"),
+                        },
+                        config={"split_qkv_matrix": self.split_qkv_matrix},
+                    ),
+                    "mlp": MLPBridge(
+                        name="mlp",
+                        submodules={
+                            "W_in": LinearBridge(name="dense_h_to_4h"),
+                            "W_out": LinearBridge(name="dense_4h_to_h"),
+                        },
+                    ),
                 },
             ),
             "ln_final": NormalizationBridge(name="gpt_neox.final_layer_norm"),
             "unembed": UnembeddingBridge(name="embed_out"),
         }
+
+    def split_qkv_matrix(
+        self, original_attention_component: Any
+    ) -> tuple[torch.nn.Linear, torch.nn.Linear, torch.nn.Linear]:
+        """Split the QKV matrix into separate linear transformations.
+        Args:
+            original_attention_component: The original attention layer component
+        Returns:
+            Tuple of nn.Linear modules for Q, K, and V transformations
+        """
+
+        # Keep mypy happy
+        assert original_attention_component is not None
+        assert original_attention_component.query_key_value is not None
+
+        qkv_weights = original_attention_component.query_key_value.weight
+
+        # Keep mypy happy
+        assert isinstance(qkv_weights, torch.Tensor)
+
+        # Original qkv_weights shape: [d_model, 3 * d_model]
+        # Split into three equal parts along dimension 1 to get Q, K, V weights
+        W_Q, W_K, W_V = torch.tensor_split(qkv_weights, 3, dim=1)
+
+        qkv_bias = original_attention_component.query_key_value.bias
+
+        # Keep mypy happy
+        assert isinstance(qkv_bias, torch.Tensor)
+
+        # Original qkv_bias shape: [3 * n_heads * d_head]
+        # Reshape to [3, n_heads * d_head] to split by Q, K, V
+        qkv_bias = qkv_bias.reshape(3, self.cfg.n_heads * self.cfg.d_head)
+        b_Q, b_K, b_V = qkv_bias[0, :], qkv_bias[1, :], qkv_bias[2, :]
+
+        # Create nn.Linear modules
+        # After tensor_split, W_Q, W_K, W_V shapes are [d_model, d_model] ([in_features, out_features])
+        # nn.Linear expects weight shape [out_features, in_features]
+        # So we need to transpose the weights
+        W_Q_transformation = torch.nn.Linear(W_Q.shape[0], W_Q.shape[1], bias=True)
+        W_Q_transformation.weight = torch.nn.Parameter(W_Q.T)
+        W_Q_transformation.bias = torch.nn.Parameter(b_Q)
+
+        W_K_transformation = torch.nn.Linear(W_K.shape[0], W_K.shape[1], bias=True)
+        W_K_transformation.weight = torch.nn.Parameter(W_K.T)
+        W_K_transformation.bias = torch.nn.Parameter(b_K)
+
+        W_V_transformation = torch.nn.Linear(W_V.shape[0], W_V.shape[1], bias=True)
+        W_V_transformation.weight = torch.nn.Parameter(W_V.T)
+        W_V_transformation.bias = torch.nn.Parameter(b_V)
+
+        return W_Q_transformation, W_K_transformation, W_V_transformation
```
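To make the new mapping concrete, here is an illustrative walk over the Hugging Face GPT-NeoX module tree that these bridges wrap (plain `transformers` calls, not the bridge API; `pythia-70m` is just one example NeoX-architecture checkpoint):

```python
from transformers import AutoModelForCausalLM

# Any GPT-NeoX-architecture checkpoint exposes the module paths used above
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")

block = model.gpt_neox.layers[0]
print(block.input_layernorm)            # mapped as blocks.ln1
print(block.post_attention_layernorm)   # mapped as blocks.ln2
print(block.attention.query_key_value)  # fused QKV Linear -> attn.W_QKV
print(block.attention.dense)            # output projection -> attn.W_O
print(block.mlp.dense_h_to_4h)          # mlp.W_in
print(block.mlp.dense_4h_to_h)          # mlp.W_out
print(model.embed_out)                  # unembed
```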
