Skip to content

Commit 8eb4a55

Browse files
authored
Merge pull request #77 from EleutherAI/cleanup
Add CI tests; remove SAE references
2 parents 88a4338 + a04bd43 commit 8eb4a55

File tree

6 files changed

+46
-17
lines changed

6 files changed

+46
-17
lines changed

.github/workflows/tests.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
name: Tests
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
9+
jobs:
10+
test:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- uses: actions/checkout@v3
14+
15+
- name: Set up Python
16+
uses: actions/setup-python@v3
17+
with:
18+
python-version: '3.10'
19+
20+
- name: Install dependencies
21+
run: |
22+
python -m pip install --upgrade pip
23+
pip install ".[dev]"
24+
25+
- name: Run tests
26+
run: pytest

delphi/__main__.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ def load_artifacts(run_cfg: RunConfig):
5454
token=run_cfg.hf_token,
5555
)
5656

57-
hookpoint_to_sae_encode = load_sparse_coders(model, run_cfg, compile=True)
57+
hookpoint_to_sparse_encode = load_sparse_coders(model, run_cfg, compile=True)
5858

59-
return run_cfg.hookpoints, hookpoint_to_sae_encode, model
59+
return run_cfg.hookpoints, hookpoint_to_sparse_encode, model
6060

6161

6262
async def process_cache(
@@ -205,10 +205,9 @@ def populate_cache(
205205
run_cfg: RunConfig,
206206
cfg: CacheConfig,
207207
model: PreTrainedModel,
208-
hookpoint_to_sae_encode: dict[str, Callable],
208+
hookpoint_to_sparse_encode: dict[str, Callable],
209209
latents_path: Path,
210210
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
211-
filter_bos: bool,
212211
):
213212
"""
214213
Populates an on-disk cache in `latents_path` with SAE latent activations.
@@ -224,7 +223,7 @@ def populate_cache(
224223
)
225224
tokens = data["input_ids"]
226225

227-
if filter_bos:
226+
if run_cfg.filter_bos:
228227
if tokenizer.bos_token_id is None:
229228
print("Tokenizer does not have a BOS token, skipping BOS filtering")
230229
else:
@@ -240,7 +239,7 @@ def populate_cache(
240239

241240
cache = LatentCache(
242241
model,
243-
hookpoint_to_sae_encode,
242+
hookpoint_to_sparse_encode,
244243
batch_size=cfg.batch_size,
245244
)
246245
cache.run(cfg.n_tokens, tokens)
@@ -276,7 +275,7 @@ async def run(
276275

277276
latent_range = torch.arange(run_cfg.max_latents) if run_cfg.max_latents else None
278277

279-
hookpoints, hookpoint_to_sae_encode, model = load_artifacts(run_cfg)
278+
hookpoints, hookpoint_to_sparse_encode, model = load_artifacts(run_cfg)
280279
tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token)
281280

282281
if (
@@ -287,15 +286,14 @@ async def run(
287286
run_cfg,
288287
cache_cfg,
289288
model,
290-
hookpoint_to_sae_encode,
289+
hookpoint_to_sparse_encode,
291290
latents_path,
292291
tokenizer,
293-
filter_bos=run_cfg.filter_bos,
294292
)
295293
else:
296294
print(f"Files found in {latents_path}, skipping cache population...")
297295

298-
del model, hookpoint_to_sae_encode
296+
del model, hookpoint_to_sparse_encode
299297

300298
if (
301299
not glob(str(scores_path / ".*")) + glob(str(scores_path / "*"))

delphi/sparse_coders/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def load_sparse_coders(
2828

2929
# Add SAE hooks to the model
3030
if "gemma" not in run_cfg.sparse_model:
31-
hookpoint_to_sae_encode = load_sparsify_sparse_coders(
31+
hookpoint_to_sparse_encode = load_sparsify_sparse_coders(
3232
model,
3333
run_cfg.sparse_model,
3434
run_cfg.hookpoints,
@@ -54,7 +54,7 @@ def load_sparse_coders(
5454
sae_sizes.append(sae_size)
5555
l0s.append(l0)
5656

57-
hookpoint_to_sae_encode = load_gemma_autoencoders(
57+
hookpoint_to_sparse_encode = load_gemma_autoencoders(
5858
model_path=model_path,
5959
ae_layers=layers,
6060
average_l0s=l0s,
@@ -64,4 +64,4 @@ def load_sparse_coders(
6464
device=model.device,
6565
)
6666

67-
return hookpoint_to_sae_encode
67+
return hookpoint_to_sparse_encode

delphi/tests/conftest.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,13 @@ def cache_setup(
6262
sparse_model="EleutherAI/sae-pythia-70m-32k",
6363
hookpoints=["layers.1"],
6464
)
65-
hookpoint_to_sae_encode = load_sparse_coders(model, run_cfg_gemma)
65+
hookpoint_to_sparse_encode = load_sparse_coders(model, run_cfg_gemma)
6666

6767
# Define cache config and initialize cache
6868
cache_cfg = CacheConfig(batch_size=1, ctx_len=16, n_tokens=100)
69-
cache = LatentCache(model, hookpoint_to_sae_encode, batch_size=cache_cfg.batch_size)
69+
cache = LatentCache(
70+
model, hookpoint_to_sparse_encode, batch_size=cache_cfg.batch_size
71+
)
7072

7173
# Generate mock tokens and run the cache
7274
tokens = mock_dataset

examples/caching_activations.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
" hookpoints=[\"layer_10/width_16k/average_l0_39\"],\n",
8888
")\n",
8989
"\n",
90-
"hookpoint_to_sae_encode = load_sparse_coders(model, run_cfg)"
90+
"hookpoint_to_sparse_encode = load_sparse_coders(model, run_cfg)"
9191
]
9292
},
9393
{
@@ -134,7 +134,7 @@
134134
"\n",
135135
"cache = LatentCache(\n",
136136
" model,\n",
137-
" hookpoint_to_sae_encode,\n",
137+
" hookpoint_to_sparse_encode,\n",
138138
" batch_size = cfg.batch_size,\n",
139139
")"
140140
]

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ dependencies = [
2525
"sentence_transformers"
2626
]
2727

28+
[project.optional-dependencies]
29+
dev = ["pytest"]
30+
2831
[tool.pyright]
2932
include = ["delphi*"]
3033
reportPrivateImportUsage = false

0 commit comments

Comments
 (0)