Skip to content

Commit 917661d

Browse files
committed
remove sae references
1 parent 6a74db7 commit 917661d

File tree

3 files changed

+13
-11
lines changed

3 files changed

+13
-11
lines changed

delphi/__main__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ def load_artifacts(run_cfg: RunConfig):
5454
token=run_cfg.hf_token,
5555
)
5656

57-
hookpoint_to_sae_encode = load_sparse_coders(model, run_cfg, compile=True)
57+
hookpoint_to_sparse_encode = load_sparse_coders(model, run_cfg, compile=True)
5858

59-
return run_cfg.hookpoints, hookpoint_to_sae_encode, model
59+
return run_cfg.hookpoints, hookpoint_to_sparse_encode, model
6060

6161

6262
async def process_cache(
@@ -205,7 +205,7 @@ def populate_cache(
205205
run_cfg: RunConfig,
206206
cfg: CacheConfig,
207207
model: PreTrainedModel,
208-
hookpoint_to_sae_encode: dict[str, Callable],
208+
hookpoint_to_sparse_encode: dict[str, Callable],
209209
latents_path: Path,
210210
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
211211
):
@@ -239,7 +239,7 @@ def populate_cache(
239239

240240
cache = LatentCache(
241241
model,
242-
hookpoint_to_sae_encode,
242+
hookpoint_to_sparse_encode,
243243
batch_size=cfg.batch_size,
244244
)
245245
cache.run(cfg.n_tokens, tokens)
@@ -275,7 +275,7 @@ async def run(
275275

276276
latent_range = torch.arange(run_cfg.max_latents) if run_cfg.max_latents else None
277277

278-
hookpoints, hookpoint_to_sae_encode, model = load_artifacts(run_cfg)
278+
hookpoints, hookpoint_to_sparse_encode, model = load_artifacts(run_cfg)
279279
tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token)
280280

281281
if (
@@ -286,14 +286,14 @@ async def run(
286286
run_cfg,
287287
cache_cfg,
288288
model,
289-
hookpoint_to_sae_encode,
289+
hookpoint_to_sparse_encode,
290290
latents_path,
291291
tokenizer,
292292
)
293293
else:
294294
print(f"Files found in {latents_path}, skipping cache population...")
295295

296-
del model, hookpoint_to_sae_encode
296+
del model, hookpoint_to_sparse_encode
297297

298298
if (
299299
not glob(str(scores_path / ".*")) + glob(str(scores_path / "*"))

delphi/tests/conftest.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,13 @@ def cache_setup(
6262
sparse_model="EleutherAI/sae-pythia-70m-32k",
6363
hookpoints=["layers.1"],
6464
)
65-
hookpoint_to_sae_encode = load_sparse_coders(model, run_cfg_gemma)
65+
hookpoint_to_sparse_encode = load_sparse_coders(model, run_cfg_gemma)
6666

6767
# Define cache config and initialize cache
6868
cache_cfg = CacheConfig(batch_size=1, ctx_len=16, n_tokens=100)
69-
cache = LatentCache(model, hookpoint_to_sae_encode, batch_size=cache_cfg.batch_size)
69+
cache = LatentCache(
70+
model, hookpoint_to_sparse_encode, batch_size=cache_cfg.batch_size
71+
)
7072

7173
# Generate mock tokens and run the cache
7274
tokens = mock_dataset

examples/caching_activations.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
" hookpoints=[\"layer_10/width_16k/average_l0_39\"],\n",
8888
")\n",
8989
"\n",
90-
"hookpoint_to_sae_encode = load_sparse_coders(model, run_cfg)"
90+
"hookpoint_to_sparse_encode = load_sparse_coders(model, run_cfg)"
9191
]
9292
},
9393
{
@@ -134,7 +134,7 @@
134134
"\n",
135135
"cache = LatentCache(\n",
136136
" model,\n",
137-
" hookpoint_to_sae_encode,\n",
137+
" hookpoint_to_sparse_encode,\n",
138138
" batch_size = cfg.batch_size,\n",
139139
")"
140140
]

0 commit comments

Comments (0)