
Commit 6ab2123

Merge pull request #76 from EleutherAI/run-cfg
remove latent cfg from populate cache method
2 parents: 8c607b3 + 9160520

2 files changed: +7 −8 lines changed

delphi/__main__.py

Lines changed: 0 additions & 2 deletions
@@ -203,7 +203,6 @@ def scorer_postprocess(result, score_dir):
 
 def populate_cache(
     run_cfg: RunConfig,
-    latent_cfg: LatentConfig,
     cfg: CacheConfig,
     model: PreTrainedModel,
     hookpoint_to_sae_encode: dict[str, Callable],
@@ -286,7 +285,6 @@ async def run(
 ):
     populate_cache(
         run_cfg,
-        latent_cfg,
         cache_cfg,
         model,
         hookpoint_to_sae_encode,
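For context, a minimal, self-contained sketch of the call path after this change. RunConfig, CacheConfig, and the function bodies below are stand-ins invented for illustration; only the removal of the latent_cfg argument mirrors the actual diff.

from dataclasses import dataclass
from typing import Callable


@dataclass
class RunConfig:
    name: str = "demo"


@dataclass
class CacheConfig:
    batch_size: int = 8


def populate_cache(
    run_cfg: RunConfig,
    cfg: CacheConfig,
    model: object,
    hookpoint_to_sae_encode: dict[str, Callable],
) -> None:
    # latent_cfg is gone: the cache is populated from the run/cache
    # configs and the hookpoint-to-encoder mapping alone.
    for hookpoint in hookpoint_to_sae_encode:
        print(f"caching {hookpoint} with batch_size={cfg.batch_size}")


populate_cache(
    RunConfig(),
    CacheConfig(),
    model=None,
    hookpoint_to_sae_encode={"layers.0.mlp": lambda x: x},
)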

delphi/latents/cache.py

Lines changed: 7 additions & 6 deletions
@@ -1,11 +1,12 @@
 import json
 from collections import defaultdict
 from pathlib import Path
+from typing import Callable
 
 import numpy as np
 import torch
 from safetensors.numpy import save_file
-from torch import Tensor, nn
+from torch import Tensor
 from torchtyping import TensorType
 from tqdm import tqdm
 
@@ -157,7 +158,7 @@ class LatentCache:
     def __init__(
         self,
         model,
-        hookpoint_to_sae_encode: dict[str, nn.Module],
+        hookpoint_to_sparse_encode: dict[str, Callable],
         batch_size: int,
         filters: dict[str, TensorType["indices"]] | None = None,
     ):
@@ -166,12 +167,12 @@ def __init__(
 
         Args:
             model: The model to cache latents for.
-            hookpoint_to_sae_encode: Dictionary of submodules to cache.
+            hookpoint_to_sparse_encode: Dictionary of sparse encoding functions.
             batch_size: Size of batches for processing.
             filters: Filters for selecting specific latents.
         """
         self.model = model
-        self.hookpoint_to_sae_encode = hookpoint_to_sae_encode
+        self.hookpoint_to_sparse_encode = hookpoint_to_sparse_encode
 
         self.batch_size = batch_size
         self.width = None
@@ -237,12 +238,12 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]):
 
         with torch.no_grad():
             with collect_activations(
-                self.model, list(self.hookpoint_to_sae_encode.keys())
+                self.model, list(self.hookpoint_to_sparse_encode.keys())
             ) as activations:
                 self.model(batch.to(self.model.device))
 
             for hookpoint, latents in activations.items():
-                sae_latents = self.hookpoint_to_sae_encode[hookpoint](
+                sae_latents = self.hookpoint_to_sparse_encode[hookpoint](
                     latents
                 )
                 self.cache.add(sae_latents, batch, batch_number, hookpoint)
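The second change renames hookpoint_to_sae_encode: dict[str, nn.Module] to hookpoint_to_sparse_encode: dict[str, Callable], which widens what the cache accepts: any callable that maps raw activations to sparse latents, not just an nn.Module. A rough sketch of the idea, with an invented DemoSAE and hookpoint names; only the dict[str, Callable] shape comes from the diff.

import torch
from torch import Tensor, nn
from typing import Callable


class DemoSAE(nn.Module):
    # Illustrative stand-in for a sparse autoencoder.
    def __init__(self, d_in: int, d_sae: int):
        super().__init__()
        self.enc = nn.Linear(d_in, d_sae)

    def encode(self, x: Tensor) -> Tensor:
        return torch.relu(self.enc(x))


sae = DemoSAE(16, 64)
hookpoint_to_sparse_encode: dict[str, Callable] = {
    "layers.0.mlp": sae.encode,                 # a bound method works
    "layers.1.mlp": lambda x: (x > 0).float(),  # so does any plain callable
}

latents = hookpoint_to_sparse_encode["layers.0.mlp"](torch.randn(2, 16))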
