@@ -54,9 +54,9 @@ def load_artifacts(run_cfg: RunConfig):
5454 token = run_cfg .hf_token ,
5555 )
5656
57- hookpoint_to_sae_encode = load_sparse_coders (model , run_cfg , compile = True )
57+ hookpoint_to_sparse_encode = load_sparse_coders (model , run_cfg , compile = True )
5858
59- return run_cfg .hookpoints , hookpoint_to_sae_encode , model
59+ return run_cfg .hookpoints , hookpoint_to_sparse_encode , model
6060
6161
6262async def process_cache (
@@ -205,10 +205,9 @@ def populate_cache(
205205 run_cfg : RunConfig ,
206206 cfg : CacheConfig ,
207207 model : PreTrainedModel ,
208- hookpoint_to_sae_encode : dict [str , Callable ],
208+ hookpoint_to_sparse_encode : dict [str , Callable ],
209209 latents_path : Path ,
210210 tokenizer : PreTrainedTokenizer | PreTrainedTokenizerFast ,
211- filter_bos : bool ,
212211):
213212 """
214213 Populates an on-disk cache in `latents_path` with SAE latent activations.
@@ -224,7 +223,7 @@ def populate_cache(
224223 )
225224 tokens = data ["input_ids" ]
226225
227- if filter_bos :
226+ if run_cfg . filter_bos :
228227 if tokenizer .bos_token_id is None :
229228 print ("Tokenizer does not have a BOS token, skipping BOS filtering" )
230229 else :
@@ -240,7 +239,7 @@ def populate_cache(
240239
241240 cache = LatentCache (
242241 model ,
243- hookpoint_to_sae_encode ,
242+ hookpoint_to_sparse_encode ,
244243 batch_size = cfg .batch_size ,
245244 )
246245 cache .run (cfg .n_tokens , tokens )
@@ -276,7 +275,7 @@ async def run(
276275
277276 latent_range = torch .arange (run_cfg .max_latents ) if run_cfg .max_latents else None
278277
279- hookpoints , hookpoint_to_sae_encode , model = load_artifacts (run_cfg )
278+ hookpoints , hookpoint_to_sparse_encode , model = load_artifacts (run_cfg )
280279 tokenizer = AutoTokenizer .from_pretrained (run_cfg .model , token = run_cfg .hf_token )
281280
282281 if (
@@ -287,15 +286,14 @@ async def run(
287286 run_cfg ,
288287 cache_cfg ,
289288 model ,
290- hookpoint_to_sae_encode ,
289+ hookpoint_to_sparse_encode ,
291290 latents_path ,
292291 tokenizer ,
293- filter_bos = run_cfg .filter_bos ,
294292 )
295293 else :
296294 print (f"Files found in { latents_path } , skipping cache population..." )
297295
298- del model , hookpoint_to_sae_encode
296+ del model , hookpoint_to_sparse_encode
299297
300298 if (
301299 not glob (str (scores_path / ".*" )) + glob (str (scores_path / "*" ))
0 commit comments