Skip to content

Commit 5699d73

Browse files
Create bridges for every module in neo (#987)
* Create bridge for every module in neo
* Use path to access subcomponents in get_component instead of Bridge name attribute values
* Handle remote names with multiple parts

---------

Co-authored-by: Bryce Meyer <[email protected]>
1 parent bea73e1 commit 5699d73

File tree

2 files changed

+26
-3
lines changed
  • transformer_lens/model_bridge

2 files changed

+26
-3
lines changed

transformer_lens/model_bridge/generalized_components/base.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -182,7 +182,15 @@ def _getattr_helper(self, name: str) -> Any:
182182
original_component = self._modules.get("_original_component", None)
183183
if original_component is not None:
184184
try:
185-
return getattr(original_component, name)
185+
name_split = name.split(".")
186+
187+
if len(name_split) > 1:
188+
current = getattr(original_component, name_split[0])
189+
for part in name_split[1:]:
190+
current = getattr(current, part)
191+
return current
192+
else:
193+
return getattr(original_component, name)
186194
except AttributeError:
187195
pass
188196

transformer_lens/model_bridge/supported_architectures/neo.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
AttentionBridge,
1212
BlockBridge,
1313
EmbeddingBridge,
14+
LinearBridge,
1415
MLPBridge,
1516
NormalizationBridge,
1617
UnembeddingBridge,
@@ -66,9 +67,23 @@ def __init__(self, cfg: Any) -> None:
6667
name="transformer.h",
6768
submodules={
6869
"ln1": NormalizationBridge(name="ln_1"),
69-
"attn": AttentionBridge(name="attn"),
70+
"attn": AttentionBridge(
71+
name="attn.attention",
72+
submodules={
73+
"W_Q": LinearBridge(name="q_proj"),
74+
"W_K": LinearBridge(name="k_proj"),
75+
"W_V": LinearBridge(name="v_proj"),
76+
"W_O": LinearBridge(name="out_proj"),
77+
},
78+
),
7079
"ln2": NormalizationBridge(name="ln_2"),
71-
"mlp": MLPBridge(name="mlp"),
80+
"mlp": MLPBridge(
81+
name="mlp",
82+
submodules={
83+
"W_in": LinearBridge(name="c_fc"),
84+
"W_out": LinearBridge(name="c_proj"),
85+
},
86+
),
7287
},
7388
),
7489
"ln_final": NormalizationBridge(name="transformer.ln_f"),

0 commit comments

Comments
 (0)