Skip to content

Commit 446b9d0

Browse files
map hook_pos_embed to rotary_emb, allow hook_aliases to be a list (#1034)
* map hook_pos_embed to rotary_emb, allow hook_aliases to be a list * ran format * removed neuroscope * added full coverage for hook aliases change * updated cache to accept list --------- Co-authored-by: Bryce Meyer <[email protected]>
1 parent 285c039 commit 446b9d0

File tree

3 files changed

+103
-26
lines changed

3 files changed

+103
-26
lines changed

.github/workflows/checks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ jobs:
152152
- "Exploratory_Analysis_Demo"
153153
# - "Grokking_Demo"
154154
# - "Head_Detector_Demo"
155-
- "Interactive_Neuroscope"
155+
# - "Interactive_Neuroscope"
156156
# - "LLaMA"
157157
# - "LLaMA2_GPU_Quantized"
158158
- "Main_Demo"

transformer_lens/model_bridge/bridge.py

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ class TransformerBridge(nn.Module):
5151

5252
# Top-level hook aliases for legacy TransformerLens names
5353
# Placing these on the main bridge ensures aliases like 'hook_embed' are available
54-
hook_aliases = {
54+
hook_aliases: Dict[str, Union[str, List[str]]] = {
5555
"hook_embed": "embed.hook_out",
56-
"hook_pos_embed": "pos_embed.hook_out",
56+
# rotary style models use rotary_emb.hook_out, but gpt2-style models use pos_embed.hook_out
57+
"hook_pos_embed": ["pos_embed.hook_out", "rotary_emb.hook_out"],
5758
"hook_unembed": "unembed.hook_out",
5859
}
5960

@@ -131,11 +132,25 @@ def _initialize_hook_registry(self) -> None:
131132

132133
# Add bridge aliases if compatibility mode is enabled
133134
if self.compatibility_mode and self.hook_aliases:
134-
for alias_name, target_name in self.hook_aliases.items():
135+
for alias_name, target in self.hook_aliases.items():
135136
# Use the existing alias system to resolve the target hook
136-
target_hook = resolve_alias(self, alias_name, self.hook_aliases)
137-
if target_hook is not None:
138-
self._hook_registry[alias_name] = target_hook
137+
# Convert to Dict[str, str] for resolve_alias if target_name is a list
138+
if isinstance(target, list):
139+
# For list targets, try each one until one works
140+
for single_target in target:
141+
try:
142+
target_hook = resolve_alias(
143+
self, alias_name, {alias_name: single_target}
144+
)
145+
if target_hook is not None:
146+
self._hook_registry[alias_name] = target_hook
147+
break
148+
except AttributeError:
149+
continue
150+
else:
151+
target_hook = resolve_alias(self, alias_name, {alias_name: target})
152+
if target_hook is not None:
153+
self._hook_registry[alias_name] = target_hook
139154

140155
self._hook_registry_initialized = True
141156

@@ -213,9 +228,17 @@ def hook_dict(self) -> dict[str, HookPoint]:
213228

214229
# Add aliases if compatibility mode is enabled
215230
if self.compatibility_mode:
216-
for alias_name, target_name in self.hook_aliases.items():
217-
if target_name in hooks:
218-
hooks[alias_name] = hooks[target_name]
231+
for alias_name, target in self.hook_aliases.items():
232+
# Handle both string and list target names
233+
if isinstance(target, list):
234+
# For list targets, find the first one that exists in hooks
235+
for single_target in target:
236+
if single_target in hooks:
237+
hooks[alias_name] = hooks[single_target]
238+
break
239+
else:
240+
if target in hooks:
241+
hooks[alias_name] = hooks[target]
219242

220243
return hooks
221244

@@ -239,9 +262,16 @@ def __getattr__(self, name: str) -> Any:
239262

240263
# Check if this is a hook alias when compatibility mode is enabled
241264
if self.compatibility_mode and name in self.hook_aliases:
242-
target_name = self.hook_aliases[name]
243-
if target_name in self._hook_registry:
244-
return self._hook_registry[target_name]
265+
target = self.hook_aliases[name]
266+
# Handle both string and list target names
267+
if isinstance(target, list):
268+
# For list targets, find the first one that exists in the registry
269+
for single_target in target:
270+
if single_target in self._hook_registry:
271+
return self._hook_registry[single_target]
272+
else:
273+
if target in self._hook_registry:
274+
return self._hook_registry[target]
245275

246276
return super().__getattr__(name)
247277

@@ -1040,7 +1070,15 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
10401070
# If compatibility mode is enabled, we need to handle aliases
10411071
# Create duplicate cache entries for TransformerLens compatibility
10421072
# Use the aliases collected from components (reverse mapping: new -> old)
1043-
reverse_aliases = {new_name: old_name for old_name, new_name in aliases.items()}
1073+
# Handle the case where some alias values might be lists
1074+
reverse_aliases = {}
1075+
for old_name, new_name in aliases.items():
1076+
if isinstance(new_name, list):
1077+
# For list values, create a mapping for each item in the list
1078+
for single_new_name in new_name:
1079+
reverse_aliases[single_new_name] = old_name
1080+
else:
1081+
reverse_aliases[new_name] = old_name
10441082

10451083
# Create duplicate entries in cache
10461084
cache_items_to_add = {}
@@ -1056,8 +1094,16 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
10561094

10571095
# Add cache entries for all aliases (both hook and cache aliases)
10581096
for alias_name, target_name in aliases.items():
1059-
if target_name in cache and alias_name not in cache:
1060-
cache[alias_name] = cache[target_name]
1097+
# Handle both string and list target names
1098+
if isinstance(target_name, list):
1099+
# For list targets, find the first one that exists in cache
1100+
for single_target in target_name:
1101+
if single_target in cache and alias_name not in cache:
1102+
cache[alias_name] = cache[single_target]
1103+
break
1104+
else:
1105+
if target_name in cache and alias_name not in cache:
1106+
cache[alias_name] = cache[target_name]
10611107

10621108
if return_cache_object:
10631109
cache_obj = ActivationCache(cache, self, has_batch_dim=not remove_batch_dim)

transformer_lens/utilities/aliases.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,48 @@ def resolve_alias(
3030
stacklevel=3, # Adjusted for utility function call
3131
)
3232

33-
target_name_split = target_name.split(".")
34-
35-
if len(target_name_split) > 1:
36-
current_attr = target_object
37-
for i in range(len(target_name_split) - 1):
38-
current_attr = getattr(current_attr, target_name_split[i])
39-
next_attr = getattr(current_attr, target_name_split[i + 1])
40-
return next_attr
33+
def _resolve_single_target(target_name: str) -> Any:
34+
"""Helper function to resolve a single target name."""
35+
target_name_split = target_name.split(".")
36+
# there are multiple target names, so we need to check all of them
37+
# this is the case for hook_pos_embed, which can be either pos_embed.hook_out (gpt2-style) or rotary_emb.hook_out (gemma/etc-style)
38+
if len(target_name_split) > 1:
39+
current_attr = target_object
40+
for i in range(len(target_name_split) - 1):
41+
if not hasattr(current_attr, target_name_split[i]):
42+
continue
43+
current_attr = getattr(current_attr, target_name_split[i])
44+
45+
# Check if the final attribute exists
46+
if not hasattr(current_attr, target_name_split[-1]):
47+
raise AttributeError(
48+
f"'{type(current_attr).__name__}' object has no attribute '{target_name_split[-1]}'"
49+
)
50+
next_attr = getattr(current_attr, target_name_split[-1])
51+
return next_attr
52+
else:
53+
# Check if the target attribute exists before getting it
54+
if not hasattr(target_object, target_name):
55+
raise AttributeError(
56+
f"'{type(target_object).__name__}' object has no attribute '{target_name}'"
57+
)
58+
# Return the target hook
59+
return getattr(target_object, target_name)
60+
61+
# if the target_name is a list, we check all elements
62+
if isinstance(target_name, list):
63+
for target_name_item in target_name:
64+
try:
65+
result = _resolve_single_target(target_name_item)
66+
return result
67+
except AttributeError:
68+
continue
69+
# If we get here, none of the targets in the list were found
70+
raise AttributeError(
71+
f"None of the target names {target_name} could be resolved on '{type(target_object).__name__}' object"
72+
)
4173
else:
42-
# Return the target hook
43-
return getattr(target_object, target_name)
74+
return _resolve_single_target(target_name)
4475
return None
4576

4677

0 commit comments

Comments
 (0)