
Commit 30c90f4

Merge pull request #813 from TransformerLensOrg/dev
Release 2.10
2 parents 3267a43 + 358eba7

14 files changed: 235 additions & 73 deletions


.devcontainer/Dockerfile

Lines changed: 17 additions & 7 deletions
@@ -1,3 +1,5 @@
+# If .venv is already setup with python3.8, it will use python3.8. To use 3.11 remove it first.
+
 # Use Nvidia Ubuntu 20 base (includes CUDA if a supported GPU is present)
 # https://hub.docker.com/r/nvidia/cuda
 FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04@sha256:55211df43bf393d3393559d5ab53283d4ebc3943d802b04546a24f3345825bd9
@@ -17,18 +19,26 @@ RUN groupadd --gid $USER_GID $USERNAME \
     && chmod 0440 /etc/sudoers.d/$USERNAME

 # Install dependencies
-RUN sudo apt-get update && \
+RUN apt-get update && \
+    DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \
+    software-properties-common && \
+    add-apt-repository -y ppa:deadsnakes/ppa && \
+    apt-get update && \
     DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \
     build-essential \
-    python3.9 \
-    python3.9-dev \
-    python3.9-distutils \
-    python3.9-venv \
+    python3.11 \
+    python3.11-dev \
+    python3.11-distutils \
+    python3.11-venv \
     curl \
-    git
+    git && \
+    # Update python3 default to point to python3.11
+    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1 && \
+    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 2 && \
+    update-alternatives --set python3 /usr/bin/python3.11

 # User the new user
 USER $USERNAME

 # Install poetry
-RUN curl -sSL https://install.python-poetry.org | python3 -
+RUN curl -sSL https://install.python-poetry.org | python3.11 -

.github/workflows/checks.yml

Lines changed: 5 additions & 1 deletion
@@ -69,6 +69,8 @@ jobs:
         poetry install --with dev
     - name: Unit Test
       run: make unit-test
+      env:
+        HF_TOKEN: ${{ vars.HF_TOKEN }}
     - name: Acceptance Test
       run: make acceptance-test
     - name: Build check
@@ -106,6 +108,8 @@ jobs:
       run: poetry run mypy .
     - name: Test Suite with Coverage Report
       run: make coverage-report-test
+      env:
+        HF_TOKEN: ${{ vars.HF_TOKEN }}
     - name: Build check
       run: poetry build
     - name: Upload Coverage Report Artifact
@@ -195,7 +199,7 @@ jobs:
     - name: Build Docs
       run: poetry run build-docs
       env:
-        HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        HF_TOKEN: ${{ vars.HF_TOKEN }}
     - name: Upload Docs Artifact
       uses: actions/upload-artifact@v3
       with:
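The workflow now sources HF_TOKEN from repository variables (vars) instead of secrets and exposes it to the unit-test and coverage steps. As a rough illustration of how a test process might consume that variable, here is a minimal sketch; the huggingface_hub login call is an assumption about how authentication could be wired up, not something shown in this diff:

import os

from huggingface_hub import login

# HF_TOKEN is injected by the workflow's env: block; it may be absent
# in local runs, so treat authentication as optional (assumed pattern).
token = os.environ.get("HF_TOKEN")
if token:
    login(token=token)  # authenticate so gated model downloads succeed
else:
    print("HF_TOKEN not set; tests needing gated models may fail or be skipped.")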

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -19,3 +19,4 @@ docs/build
 .pylintrc
 docs/source/generated
 **.orig
+.venv

demos/Colab_Compatibility.ipynb

Lines changed: 70 additions & 19 deletions
@@ -16,9 +16,9 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_48359/2396058561.py:18: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
+"/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_57027/2944939757.py:18: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
 " ipython.magic(\"load_ext autoreload\")\n",
-"/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_48359/2396058561.py:19: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
+"/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_57027/2944939757.py:19: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
 " ipython.magic(\"autoreload 2\")\n"
 ]
 }
@@ -51,28 +51,28 @@
 " %pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8\n",
 " %pip install torch\n",
 " %pip install tiktoken\n",
-" %pip install transformer_lens\n",
+" # %pip install transformer_lens\n",
 " %pip install transformers_stream_generator\n",
 " # !huggingface-cli login --token NEEL'S TOKEN"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 2,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"TransformerLens currently supports 190 models out of the box.\n"
+"TransformerLens currently supports 205 models out of the box.\n"
 ]
 }
 ],
 "source": [
 "import torch\n",
 "\n",
-"from transformer_lens import HookedTransformer, HookedEncoderDecoder, loading\n",
+"from transformer_lens import HookedTransformer, HookedEncoderDecoder, HookedEncoder, loading\n",
 "from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer\n",
 "from typing import List\n",
 "import gc\n",
@@ -144,11 +144,11 @@
 " inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
 " input_ids = inputs[\"input_ids\"]\n",
 " attention_mask = inputs[\"attention_mask\"]\n",
-" decoder_input_ids = torch.tensor([[model.cfg.decoder_start_token_id]]).to(input_ids.device)\n",
+" decoder_input_ids = torch.tensor([[tl_model.cfg.decoder_start_token_id]]).to(input_ids.device)\n",
 "\n",
 "\n",
 " while True:\n",
-" logits = model.forward(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids)\n",
+" logits = tl_model.forward(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids)\n",
 " # logits.shape == (batch_size (1), predicted_pos, vocab_size)\n",
 "\n",
 " token_idx = torch.argmax(logits[0, -1, :]).item()\n",
@@ -160,7 +160,29 @@
 " # break if End-Of-Sequence token generated\n",
 " if token_idx == tokenizer.eos_token_id:\n",
 " break\n",
-" print(tl_model.generate(\"Hello my name is\"))\n",
+" del tl_model\n",
+" gc.collect()\n",
+" if IN_COLAB:\n",
+" %rm -rf /root/.cache/huggingface/hub/models*\n",
+"\n",
+"def run_encoder_only_set(model_set: List[str], device=\"cuda\") -> None:\n",
+" for model in model_set:\n",
+" print(\"Testing \" + model)\n",
+" tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n",
+" tl_model = HookedEncoder.from_pretrained(model, device=device)\n",
+"\n",
+" if GENERATE:\n",
+" # Slightly adapted version of the BERT demo\n",
+" prompt = \"The capital of France is [MASK].\"\n",
+"\n",
+" input_ids = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"]\n",
+"\n",
+" logprobs = tl_model(input_ids)[input_ids == tokenizer.mask_token_id].log_softmax(dim=-1)\n",
+" prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())\n",
+"\n",
+" print(f\"Prompt: {prompt}\")\n",
+" print(f'Prediction: \"{prediction}\"')\n",
+"\n",
 " del tl_model\n",
 " gc.collect()\n",
 " if IN_COLAB:\n",
@@ -169,15 +191,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": 17,
+"execution_count": 4,
 "metadata": {},
 "outputs": [],
 "source": [
 "# The following models can run in the T4 free environment\n",
 "free_compatible = [\n",
 " \"ai-forever/mGPT\",\n",
 " \"ArthurConmy/redwood_attn_2l\",\n",
-" \"bert-base-cased\",\n",
 " \"bigcode/santacoder\",\n",
 " \"bigscience/bloom-1b1\",\n",
 " \"bigscience/bloom-560m\",\n",
@@ -256,6 +277,10 @@
 " \"Qwen/Qwen2-0.5B-Instruct\",\n",
 " \"Qwen/Qwen2-1.5B\",\n",
 " \"Qwen/Qwen2-1.5B-Instruct\",\n",
+" \"Qwen/Qwen2.5-0.5B\",\n",
+" \"Qwen/Qwen2.5-0.5B-Instruct\",\n",
+" \"Qwen/Qwen2.5-1.5B\",\n",
+" \"Qwen/Qwen2.5-1.5B-Instruct\",\n",
 " \"roneneldan/TinyStories-1Layer-21M\",\n",
 " \"roneneldan/TinyStories-1M\",\n",
 " \"roneneldan/TinyStories-28M\",\n",
@@ -290,7 +315,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 18,
+"execution_count": 5,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -340,6 +365,10 @@
 " \"Qwen/Qwen1.5-7B-Chat\",\n",
 " \"Qwen/Qwen2-7B\",\n",
 " \"Qwen/Qwen2-7B-Instruct\",\n",
+" \"Qwen/Qwen2.5-3B\",\n",
+" \"Qwen/Qwen2.5-3B-Instruct\",\n",
+" \"Qwen/Qwen2.5-7B\",\n",
+" \"Qwen/Qwen2.5-7B-Instruct\",\n",
 " \"stabilityai/stablelm-base-alpha-3b\",\n",
 " \"stabilityai/stablelm-base-alpha-7b\",\n",
 " \"stabilityai/stablelm-tuned-alpha-3b\",\n",
@@ -354,7 +383,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 19,
+"execution_count": 6,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -374,6 +403,8 @@
 " \"Qwen/Qwen-14B-Chat\",\n",
 " \"Qwen/Qwen1.5-14B\",\n",
 " \"Qwen/Qwen1.5-14B-Chat\",\n",
+" \"Qwen/Qwen2.5-14B\",\n",
+" \"Qwen/Qwen2.5-14B-Instruct\",\n",
 "]\n",
 "\n",
 "if IN_COLAB:\n",
@@ -384,7 +415,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 20,
+"execution_count": 7,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -402,14 +433,19 @@
 " \"meta-llama/Meta-Llama-3-70B-Instruct\",\n",
 " \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
 " \"mistralai/Mixtral-8x7B-v0.1\",\n",
+" \"Qwen/Qwen2.5-32B\",\n",
+" \"Qwen/Qwen2.5-32B-Instruct\",\n",
+" \"Qwen/Qwen2.5-72B\",\n",
+" \"Qwen/Qwen2.5-72B-Instruct\",\n",
+" \"Qwen/QwQ-32B-Preview\",\n",
 "]\n",
 "\n",
 "mark_models_as_tested(incompatible_models)"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 21,
+"execution_count": 8,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -431,7 +467,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 22,
+"execution_count": 9,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -449,7 +485,22 @@
 },
 {
 "cell_type": "code",
-"execution_count": 23,
+"execution_count": 10,
+"metadata": {},
+"outputs": [],
+"source": [
+"# This model works on the free version of Colab\n",
+"encoder_only_models = [\"bert-base-cased\"]\n",
+"\n",
+"if IN_COLAB:\n",
+" run_encoder_only_set(encoder_only_models)\n",
+"\n",
+"mark_models_as_tested(encoder_only_models)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 11,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -460,7 +511,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 24,
+"execution_count": 12,
 "metadata": {},
 "outputs": [
 {
@@ -499,5 +550,5 @@
 }
 },
 "nbformat": 4,
-"nbformat_minor": 2
+"nbformat_minor": 4
 }
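The headline change in this notebook is the new run_encoder_only_set helper, which tests encoder-only models such as bert-base-cased through HookedEncoder. Unescaped from the notebook JSON above into a self-contained script, the masked-token workflow looks like this (the device argument is the only assumption here; the notebook defaults to "cuda"):

from transformers import AutoTokenizer
from transformer_lens import HookedEncoder

# Load the tokenizer and the TransformerLens wrapper around BERT.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
tl_model = HookedEncoder.from_pretrained("bert-base-cased", device="cpu")  # device assumed

# Ask the model to fill in the masked token.
prompt = "The capital of France is [MASK]."
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]

# Select the logits at the [MASK] position and decode the most likely token.
logprobs = tl_model(input_ids)[input_ids == tokenizer.mask_token_id].log_softmax(dim=-1)
prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())

print(f"Prompt: {prompt}")
print(f'Prediction: "{prediction}"')  # expected: "Paris"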

demos/Main_Demo.ipynb

Lines changed: 20 additions & 0 deletions
@@ -429,6 +429,26 @@
 "cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=attention_pattern)"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"In this case, we only wanted the layer 0 attention patterns, but we are storing the internal activations from all locations in the model. It's convenient to have access to all activations, but this can be prohibitively expensive for memory use with larger models, batch sizes, or sequence lengths. In addition, we don't need to do the full forward pass through the model to collect layer 0 attention patterns. The following cell will collect only the layer 0 attention patterns and stop the forward pass at layer 1, requiring far less memory and compute."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"attn_hook_name = \"blocks.0.attn.hook_pattern\"\n",
+"attn_layer = 0\n",
+"_, gpt2_attn_cache = model.run_with_cache(gpt2_tokens, remove_batch_dim=True, stop_at_layer=attn_layer + 1, names_filter=[attn_hook_name])\n",
+"gpt2_attn = gpt2_attn_cache[attn_hook_name]\n",
+"assert torch.equal(gpt2_attn, attention_pattern)"
+]
+},
 {
 "attachments": {},
 "cell_type": "markdown",

tests/unit/components/test_attention.py

Lines changed: 30 additions & 0 deletions
@@ -1,10 +1,12 @@
+import einops
 import pytest
 import torch
 import torch.nn as nn
 from transformers.utils import is_bitsandbytes_available

 from transformer_lens.components import Attention
 from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+from transformer_lens.utilities.attention import complex_attn_linear

 if is_bitsandbytes_available():
     from bitsandbytes.nn.modules import Params4bit
@@ -98,3 +100,31 @@ def test_attention_config_dict():
     assert attn.cfg.load_in_4bit == False
     assert attn.cfg.dtype == torch.float32
     assert attn.cfg.act_fn == "relu"
+
+
+def test_remove_einsum_from_complex_attn_linear():
+    batch = 64
+    pos = 128
+    head_index = 8
+    d_model = 512
+    d_head = 64
+    input = torch.randn(batch, pos, head_index, d_model)
+    w = torch.randn(head_index, d_model, d_head)
+    b = torch.randn(head_index, d_head)
+    result_new = complex_attn_linear(input, w, b)
+
+    # Check if new implementation without einsum produces correct shape
+    assert result_new.shape == (batch, pos, head_index, d_head)
+
+    # Old implementation used einsum
+    result_old = (
+        einops.einsum(
+            input,
+            w,
+            "batch pos head_index d_model, head_index d_model d_head -> batch pos head_index d_head",
+        )
+        + b
+    )
+
+    # Check if the results are the same
+    assert torch.allclose(result_new, result_old, atol=1e-4)
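The diff shows only the test, not complex_attn_linear itself; the test name and comments indicate that the new implementation computes the same per-head linear map without einsum. One way to do that, sketched under the assumption that a batched matmul is used (the actual library code may differ), is:

import torch


def complex_attn_linear_sketch(
    input: torch.Tensor,  # [batch, pos, head_index, d_model]
    w: torch.Tensor,      # [head_index, d_model, d_head]
    b: torch.Tensor,      # [head_index, d_head]
) -> torch.Tensor:
    """Einsum-free per-head linear map via one batched matmul per head."""
    batch, pos, n_heads, d_model = input.shape
    # Gather all (batch, pos) rows for each head: [head_index, batch * pos, d_model]
    x = input.permute(2, 0, 1, 3).reshape(n_heads, batch * pos, d_model)
    out = torch.bmm(x, w)  # [head_index, batch * pos, d_head]
    # Restore [batch, pos, head_index, d_head] and add the per-head bias
    return out.reshape(n_heads, batch, pos, -1).permute(1, 2, 0, 3) + b

Substituting complex_attn_linear_sketch for complex_attn_linear in the test above passes the same allclose check against the einsum formulation.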
