@@ -51,9 +51,10 @@ class TransformerBridge(nn.Module):
51
51
52
52
# Top-level hook aliases for legacy TransformerLens names
53
53
# Placing these on the main bridge ensures aliases like 'hook_embed' are available
54
- hook_aliases = {
54
+ hook_aliases : Dict [ str , Union [ str , List [ str ]]] = {
55
55
"hook_embed" : "embed.hook_out" ,
56
- "hook_pos_embed" : "pos_embed.hook_out" ,
56
+ # GPT-2-style models expose pos_embed.hook_out; rotary-style models expose rotary_emb.hook_out instead
57
+ "hook_pos_embed" : ["pos_embed.hook_out" , "rotary_emb.hook_out" ],
57
58
"hook_unembed" : "unembed.hook_out" ,
58
59
}
59
60
@@ -131,11 +132,25 @@ def _initialize_hook_registry(self) -> None:
131
132
132
133
# Add bridge aliases if compatibility mode is enabled
133
134
if self .compatibility_mode and self .hook_aliases :
134
- for alias_name , target_name in self .hook_aliases .items ():
135
+ for alias_name , target in self .hook_aliases .items ():
135
136
# Use the existing alias system to resolve the target hook
136
- target_hook = resolve_alias (self , alias_name , self .hook_aliases )
137
- if target_hook is not None :
138
- self ._hook_registry [alias_name ] = target_hook
137
+ # resolve_alias takes a Dict[str, str], so wrap each target in a single-entry dict (a list target is tried one entry at a time)
138
+ if isinstance (target , list ):
139
+ # For list targets, try each one until one works
140
+ for single_target in target :
141
+ try :
142
+ target_hook = resolve_alias (
143
+ self , alias_name , {alias_name : single_target }
144
+ )
145
+ if target_hook is not None :
146
+ self ._hook_registry [alias_name ] = target_hook
147
+ break
148
+ except AttributeError :
149
+ continue
150
+ else :
151
+ target_hook = resolve_alias (self , alias_name , {alias_name : target })
152
+ if target_hook is not None :
153
+ self ._hook_registry [alias_name ] = target_hook
139
154
140
155
self ._hook_registry_initialized = True
141
156
@@ -213,9 +228,17 @@ def hook_dict(self) -> dict[str, HookPoint]:
213
228
214
229
# Add aliases if compatibility mode is enabled
215
230
if self .compatibility_mode :
216
- for alias_name , target_name in self .hook_aliases .items ():
217
- if target_name in hooks :
218
- hooks [alias_name ] = hooks [target_name ]
231
+ for alias_name , target in self .hook_aliases .items ():
232
+ # Handle both string and list target names
233
+ if isinstance (target , list ):
234
+ # For list targets, find the first one that exists in hooks
235
+ for single_target in target :
236
+ if single_target in hooks :
237
+ hooks [alias_name ] = hooks [single_target ]
238
+ break
239
+ else :
240
+ if target in hooks :
241
+ hooks [alias_name ] = hooks [target ]
219
242
220
243
return hooks
221
244
@@ -239,9 +262,16 @@ def __getattr__(self, name: str) -> Any:
239
262
240
263
# Check if this is a hook alias when compatibility mode is enabled
241
264
if self .compatibility_mode and name in self .hook_aliases :
242
- target_name = self .hook_aliases [name ]
243
- if target_name in self ._hook_registry :
244
- return self ._hook_registry [target_name ]
265
+ target = self .hook_aliases [name ]
266
+ # Handle both string and list target names
267
+ if isinstance (target , list ):
268
+ # For list targets, find the first one that exists in the registry
269
+ for single_target in target :
270
+ if single_target in self ._hook_registry :
271
+ return self ._hook_registry [single_target ]
272
+ else :
273
+ if target in self ._hook_registry :
274
+ return self ._hook_registry [target ]
245
275
246
276
return super ().__getattr__ (name )
247
277
@@ -1040,7 +1070,15 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
1040
1070
# If compatibility mode is enabled, we need to handle aliases
1041
1071
# Create duplicate cache entries for TransformerLens compatibility
1042
1072
# Use the aliases collected from components (reverse mapping: new -> old)
1043
- reverse_aliases = {new_name : old_name for old_name , new_name in aliases .items ()}
1073
+ # Handle the case where some alias values might be lists
1074
+ reverse_aliases = {}
1075
+ for old_name , new_name in aliases .items ():
1076
+ if isinstance (new_name , list ):
1077
+ # For list values, create a mapping for each item in the list
1078
+ for single_new_name in new_name :
1079
+ reverse_aliases [single_new_name ] = old_name
1080
+ else :
1081
+ reverse_aliases [new_name ] = old_name
1044
1082
1045
1083
# Create duplicate entries in cache
1046
1084
cache_items_to_add = {}
@@ -1056,8 +1094,16 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
1056
1094
1057
1095
# Add cache entries for all aliases (both hook and cache aliases)
1058
1096
for alias_name , target_name in aliases .items ():
1059
- if target_name in cache and alias_name not in cache :
1060
- cache [alias_name ] = cache [target_name ]
1097
+ # Handle both string and list target names
1098
+ if isinstance (target_name , list ):
1099
+ # For list targets, find the first one that exists in cache
1100
+ for single_target in target_name :
1101
+ if single_target in cache and alias_name not in cache :
1102
+ cache [alias_name ] = cache [single_target ]
1103
+ break
1104
+ else :
1105
+ if target_name in cache and alias_name not in cache :
1106
+ cache [alias_name ] = cache [target_name ]
1061
1107
1062
1108
if return_cache_object :
1063
1109
cache_obj = ActivationCache (cache , self , has_batch_dim = not remove_batch_dim )
0 commit comments