@@ -1,11 +1,12 @@
 import json
 from collections import defaultdict
 from pathlib import Path
+from typing import Callable
 
 import numpy as np
 import torch
 from safetensors.numpy import save_file
-from torch import Tensor, nn
+from torch import Tensor
 from torchtyping import TensorType
 from tqdm import tqdm
 
@@ -157,7 +158,7 @@ class LatentCache:
     def __init__(
         self,
         model,
-        hookpoint_to_sae_encode: dict[str, nn.Module],
+        hookpoint_to_sparse_encode: dict[str, Callable],
         batch_size: int,
         filters: dict[str, TensorType["indices"]] | None = None,
     ):
@@ -166,12 +167,12 @@ def __init__(
 
         Args:
             model: The model to cache latents for.
-            hookpoint_to_sae_encode: Dictionary of submodules to cache.
+            hookpoint_to_sparse_encode: Dictionary of sparse encoding functions.
            batch_size: Size of batches for processing.
            filters: Filters for selecting specific latents.
         """
         self.model = model
-        self.hookpoint_to_sae_encode = hookpoint_to_sae_encode
+        self.hookpoint_to_sparse_encode = hookpoint_to_sparse_encode
 
         self.batch_size = batch_size
         self.width = None
@@ -237,12 +238,12 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]):
 
             with torch.no_grad():
                 with collect_activations(
-                    self.model, list(self.hookpoint_to_sae_encode.keys())
+                    self.model, list(self.hookpoint_to_sparse_encode.keys())
                 ) as activations:
                     self.model(batch.to(self.model.device))
 
                 for hookpoint, latents in activations.items():
                     sae_latents = self.hookpoint_to_sparse_encode[hookpoint](
                         latents
                     )
                     self.cache.add(sae_latents, batch, batch_number, hookpoint)
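With this change, `LatentCache` accepts any callable that maps activations to sparse latents, not just an `nn.Module` SAE. A minimal sketch of how a caller might now build that mapping; the hookpoint names, dimensions, and the stand-in ReLU encoder below are illustrative assumptions, not part of this PR (in practice you would pass each trained SAE's or transcoder's `.encode` method):

```python
from typing import Callable

import torch


def make_relu_encode(W_enc: torch.Tensor, b_enc: torch.Tensor) -> Callable:
    """Return a plain callable mapping activations to sparse latent codes.

    A ReLU linear encoder stands in here for a real SAE/transcoder encode
    method; any function with this signature satisfies the new interface.
    """

    def encode(acts: torch.Tensor) -> torch.Tensor:
        return torch.relu(acts @ W_enc + b_enc)

    return encode


d_model, d_sae = 512, 4096  # illustrative sizes
hookpoint_to_sparse_encode: dict[str, Callable] = {
    # Hookpoint names are hypothetical placeholders.
    "layers.6.mlp": make_relu_encode(torch.randn(d_model, d_sae), torch.zeros(d_sae)),
    "layers.12.mlp": make_relu_encode(torch.randn(d_model, d_sae), torch.zeros(d_sae)),
}

# cache = LatentCache(model, hookpoint_to_sparse_encode, batch_size=8)
# cache.run(n_tokens=1_000_000, tokens=tokens)
```

The instantiation is commented out because `model` and `tokens` come from the caller's setup; the point is only that the second argument is now a dict of plain callables.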