@@ -54,9 +54,9 @@ def load_artifacts(run_cfg: RunConfig):
         token=run_cfg.hf_token,
     )
 
-    hookpoint_to_sae_encode = load_sparse_coders(model, run_cfg, compile=True)
+    hookpoint_to_sparse_encode = load_sparse_coders(model, run_cfg, compile=True)
 
-    return run_cfg.hookpoints, hookpoint_to_sae_encode, model
+    return run_cfg.hookpoints, hookpoint_to_sparse_encode, model
 
 
 async def process_cache(
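For context on the rename: the value returned by `load_artifacts` is the mapping typed in `populate_cache` below, a `dict[str, Callable]` from hookpoint names to sparse-encoder callables. A minimal sketch of that shape (the hookpoint name and encoder body are hypothetical stand-ins; the real callables come from `load_sparse_coders`):

```python
from typing import Callable

import torch


def toy_encode(acts: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a sparse coder's encode step.
    return torch.relu(acts)


# Same shape as the mapping returned by load_artifacts.
hookpoint_to_sparse_encode: dict[str, Callable] = {
    "layers.5.mlp": toy_encode,  # illustrative hookpoint name
}
```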
@@ -205,7 +205,7 @@ def populate_cache(
     run_cfg: RunConfig,
     cfg: CacheConfig,
     model: PreTrainedModel,
-    hookpoint_to_sae_encode: dict[str, Callable],
+    hookpoint_to_sparse_encode: dict[str, Callable],
     latents_path: Path,
     tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
 ):
@@ -239,7 +239,7 @@ def populate_cache(
 
     cache = LatentCache(
         model,
-        hookpoint_to_sae_encode,
+        hookpoint_to_sparse_encode,
         batch_size=cfg.batch_size,
     )
     cache.run(cfg.n_tokens, tokens)
@@ -275,7 +275,7 @@ async def run(
 
     latent_range = torch.arange(run_cfg.max_latents) if run_cfg.max_latents else None
 
-    hookpoints, hookpoint_to_sae_encode, model = load_artifacts(run_cfg)
+    hookpoints, hookpoint_to_sparse_encode, model = load_artifacts(run_cfg)
     tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token)
 
     if (
@@ -286,14 +286,14 @@ async def run(
             run_cfg,
             cache_cfg,
             model,
-            hookpoint_to_sae_encode,
+            hookpoint_to_sparse_encode,
             latents_path,
             tokenizer,
         )
     else:
         print(f"Files found in {latents_path}, skipping cache population...")
 
-    del model, hookpoint_to_sae_encode
+    del model, hookpoint_to_sparse_encode
 
     if (
         not glob(str(scores_path / ".*")) + glob(str(scores_path / "*"))
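Taken together, the rename threads consistently through `run`. A condensed restatement of the flow the hunks above touch, using only the calls visible in this diff (config objects and paths are assumed to be defined by the surrounding module, not shown here):

```python
# Condensed flow from run(), mirroring the hunks above; run_cfg, cache_cfg,
# and latents_path are placeholders for the module's own config objects.
hookpoints, hookpoint_to_sparse_encode, model = load_artifacts(run_cfg)
tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token)

populate_cache(
    run_cfg,
    cache_cfg,
    model,
    hookpoint_to_sparse_encode,
    latents_path,
    tokenizer,
)

# Drop the model and encoder mapping once the latent cache is on disk.
del model, hookpoint_to_sparse_encode
```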