16 | 16 | "name": "stderr",
17 | 17 | "output_type": "stream",
18 | 18 | "text": [
19 |    | - "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_48359/2396058561.py:18: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
   | 19 | + "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_57027/2944939757.py:18: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
20 | 20 | " ipython.magic(\"load_ext autoreload\")\n",
21 |    | - "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_48359/2396058561.py:19: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
   | 21 | + "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_57027/2944939757.py:19: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
22 | 22 | " ipython.magic(\"autoreload 2\")\n"
23 | 23 | ]
24 | 24 | }

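Aside: the DeprecationWarnings captured in the stderr output above already name the replacement API, run_line_magic(magic_name, parameter_s). A minimal sketch of the migrated calls; the only assumption is how each magic string splits into name and argument:

    # Sketch: non-deprecated equivalents of the two ipython.magic(...) calls
    # flagged above, following the DeprecationWarning's own suggestion.
    from IPython import get_ipython

    ipython = get_ipython()
    if ipython is not None:
        ipython.run_line_magic("load_ext", "autoreload")  # was: ipython.magic("load_ext autoreload")
        ipython.run_line_magic("autoreload", "2")         # was: ipython.magic("autoreload 2")
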
51 | 51 | " %pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8\n",
52 | 52 | " %pip install torch\n",
53 | 53 | " %pip install tiktoken\n",
54 |    | - " %pip install transformer_lens\n",
   | 54 | + " # %pip install transformer_lens\n",
55 | 55 | " %pip install transformers_stream_generator\n",
56 | 56 | " # !huggingface-cli login --token NEEL'S TOKEN"
57 | 57 | ]
58 | 58 | },
59 | 59 | {
60 | 60 | "cell_type": "code",
61 |    | - "execution_count": 4,
   | 61 | + "execution_count": 2,
62 | 62 | "metadata": {},
63 | 63 | "outputs": [
64 | 64 | {
65 | 65 | "name": "stdout",
66 | 66 | "output_type": "stream",
67 | 67 | "text": [
68 |    | - "TransformerLens currently supports 190 models out of the box.\n"
   | 68 | + "TransformerLens currently supports 205 models out of the box.\n"
69 | 69 | ]
70 | 70 | }
71 | 71 | ],
72 | 72 | "source": [
73 | 73 | "import torch\n",
74 | 74 | "\n",
75 |    | - "from transformer_lens import HookedTransformer, HookedEncoderDecoder, loading\n",
   | 75 | + "from transformer_lens import HookedTransformer, HookedEncoderDecoder, HookedEncoder, loading\n",
76 | 76 | "from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer\n",
77 | 77 | "from typing import List\n",
78 | 78 | "import gc\n",

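The "205 models" stdout above is just the size of TransformerLens's model registry. A sketch of how such a count can be printed, assuming the registry list loading.OFFICIAL_MODEL_NAMES; the notebook's actual counting cell falls outside the hunks shown here:

    # Sketch: counting the checkpoints TransformerLens can load out of the box.
    # Assumption: loading.OFFICIAL_MODEL_NAMES is the registry list.
    from transformer_lens import loading

    print(f"TransformerLens currently supports {len(loading.OFFICIAL_MODEL_NAMES)} models out of the box.")
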
144 | 144 | " inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
145 | 145 | " input_ids = inputs[\"input_ids\"]\n",
146 | 146 | " attention_mask = inputs[\"attention_mask\"]\n",
147 |     | - " decoder_input_ids = torch.tensor([[model.cfg.decoder_start_token_id]]).to(input_ids.device)\n",
    | 147 | + " decoder_input_ids = torch.tensor([[tl_model.cfg.decoder_start_token_id]]).to(input_ids.device)\n",
148 | 148 | "\n",
149 | 149 | "\n",
150 | 150 | " while True:\n",
151 |     | - " logits = model.forward(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids)\n",
    | 151 | + " logits = tl_model.forward(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids)\n",
152 | 152 | " # logits.shape == (batch_size (1), predicted_pos, vocab_size)\n",
153 | 153 | "\n",
154 | 154 | " token_idx = torch.argmax(logits[0, -1, :]).item()\n",

160 | 160 | " # break if End-Of-Sequence token generated\n",
161 | 161 | " if token_idx == tokenizer.eos_token_id:\n",
162 | 162 | " break\n",
163 |     | - " print(tl_model.generate(\"Hello my name is\"))\n",
    | 163 | + " del tl_model\n",
    | 164 | + " gc.collect()\n",
    | 165 | + " if IN_COLAB:\n",
    | 166 | + " %rm -rf /root/.cache/huggingface/hub/models*\n",
    | 167 | + "\n",
    | 168 | + "def run_encoder_only_set(model_set: List[str], device=\"cuda\") -> None:\n",
    | 169 | + " for model in model_set:\n",
    | 170 | + " print(\"Testing \" + model)\n",
    | 171 | + " tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n",
    | 172 | + " tl_model = HookedEncoder.from_pretrained(model, device=device)\n",
    | 173 | + "\n",
    | 174 | + " if GENERATE:\n",
    | 175 | + " # Slightly adapted version of the BERT demo\n",
    | 176 | + " prompt = \"The capital of France is [MASK].\"\n",
    | 177 | + "\n",
    | 178 | + " input_ids = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"]\n",
    | 179 | + "\n",
    | 180 | + " logprobs = tl_model(input_ids)[input_ids == tokenizer.mask_token_id].log_softmax(dim=-1)\n",
    | 181 | + " prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())\n",
    | 182 | + "\n",
    | 183 | + " print(f\"Prompt: {prompt}\")\n",
    | 184 | + " print(f'Prediction: \"{prediction}\"')\n",
    | 185 | + "\n",
164 | 186 | " del tl_model\n",
165 | 187 | " gc.collect()\n",
166 | 188 | " if IN_COLAB:\n",

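For context, the fragments above belong to a greedy-decoding helper for encoder-decoder models; this hunk only renames model to tl_model and appends the new run_encoder_only_set function. A self-contained sketch of that decoding loop, under stated assumptions: the checkpoint name is illustrative, and the token-append step (notebook lines 155-159 are elided from this diff) is reconstructed:

    import torch
    from transformers import AutoTokenizer
    from transformer_lens import HookedEncoderDecoder

    model_name = "t5-small"  # assumption: any encoder-decoder checkpoint TransformerLens supports
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tl_model = HookedEncoderDecoder.from_pretrained(model_name)

    inputs = tokenizer("translate English to German: Hello, how are you?", return_tensors="pt")
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    # Decoding starts from the model's designated start token.
    decoder_input_ids = torch.tensor([[tl_model.cfg.decoder_start_token_id]]).to(input_ids.device)

    while True:
        # logits.shape == (batch_size (1), predicted_pos, vocab_size)
        logits = tl_model.forward(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids)
        token_idx = torch.argmax(logits[0, -1, :]).item()
        # Reconstructed append step (elided in the diff): feed the greedy pick back in.
        decoder_input_ids = torch.cat([decoder_input_ids, torch.tensor([[token_idx]], device=input_ids.device)], dim=-1)
        # break if End-Of-Sequence token generated
        if token_idx == tokenizer.eos_token_id:
            break

    print(tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True))
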
169 | 191 | },
170 | 192 | {
171 | 193 | "cell_type": "code",
172 |     | - "execution_count": 17,
    | 194 | + "execution_count": 4,
173 | 195 | "metadata": {},
174 | 196 | "outputs": [],
175 | 197 | "source": [
176 | 198 | "# The following models can run in the T4 free environment\n",
177 | 199 | "free_compatible = [\n",
178 | 200 | " \"ai-forever/mGPT\",\n",
179 | 201 | " \"ArthurConmy/redwood_attn_2l\",\n",
180 |     | - " \"bert-base-cased\",\n",
181 | 202 | " \"bigcode/santacoder\",\n",
182 | 203 | " \"bigscience/bloom-1b1\",\n",
183 | 204 | " \"bigscience/bloom-560m\",\n",

256 | 277 | " \"Qwen/Qwen2-0.5B-Instruct\",\n",
257 | 278 | " \"Qwen/Qwen2-1.5B\",\n",
258 | 279 | " \"Qwen/Qwen2-1.5B-Instruct\",\n",
    | 280 | + " \"Qwen/Qwen2.5-0.5B\",\n",
    | 281 | + " \"Qwen/Qwen2.5-0.5B-Instruct\",\n",
    | 282 | + " \"Qwen/Qwen2.5-1.5B\",\n",
    | 283 | + " \"Qwen/Qwen2.5-1.5B-Instruct\",\n",
259 | 284 | " \"roneneldan/TinyStories-1Layer-21M\",\n",
260 | 285 | " \"roneneldan/TinyStories-1M\",\n",
261 | 286 | " \"roneneldan/TinyStories-28M\",\n",

290 | 315 | },
291 | 316 | {
292 | 317 | "cell_type": "code",
293 |     | - "execution_count": 18,
    | 318 | + "execution_count": 5,
294 | 319 | "metadata": {},
295 | 320 | "outputs": [],
296 | 321 | "source": [

340 | 365 | " \"Qwen/Qwen1.5-7B-Chat\",\n",
341 | 366 | " \"Qwen/Qwen2-7B\",\n",
342 | 367 | " \"Qwen/Qwen2-7B-Instruct\",\n",
    | 368 | + " \"Qwen/Qwen2.5-3B\",\n",
    | 369 | + " \"Qwen/Qwen2.5-3B-Instruct\",\n",
    | 370 | + " \"Qwen/Qwen2.5-7B\",\n",
    | 371 | + " \"Qwen/Qwen2.5-7B-Instruct\",\n",
343 | 372 | " \"stabilityai/stablelm-base-alpha-3b\",\n",
344 | 373 | " \"stabilityai/stablelm-base-alpha-7b\",\n",
345 | 374 | " \"stabilityai/stablelm-tuned-alpha-3b\",\n",

354 | 383 | },
355 | 384 | {
356 | 385 | "cell_type": "code",
357 |     | - "execution_count": 19,
    | 386 | + "execution_count": 6,
358 | 387 | "metadata": {},
359 | 388 | "outputs": [],
360 | 389 | "source": [

374 | 403 | " \"Qwen/Qwen-14B-Chat\",\n",
375 | 404 | " \"Qwen/Qwen1.5-14B\",\n",
376 | 405 | " \"Qwen/Qwen1.5-14B-Chat\",\n",
    | 406 | + " \"Qwen/Qwen2.5-14B\",\n",
    | 407 | + " \"Qwen/Qwen2.5-14B-Instruct\",\n",
377 | 408 | "]\n",
378 | 409 | "\n",
379 | 410 | "if IN_COLAB:\n",

384 | 415 | },
385 | 416 | {
386 | 417 | "cell_type": "code",
387 |     | - "execution_count": 20,
    | 418 | + "execution_count": 7,
388 | 419 | "metadata": {},
389 | 420 | "outputs": [],
390 | 421 | "source": [

402 | 433 | " \"meta-llama/Meta-Llama-3-70B-Instruct\",\n",
403 | 434 | " \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
404 | 435 | " \"mistralai/Mixtral-8x7B-v0.1\",\n",
    | 436 | + " \"Qwen/Qwen2.5-32B\",\n",
    | 437 | + " \"Qwen/Qwen2.5-32B-Instruct\",\n",
    | 438 | + " \"Qwen/Qwen2.5-72B\",\n",
    | 439 | + " \"Qwen/Qwen2.5-72B-Instruct\",\n",
    | 440 | + " \"Qwen/QwQ-32B-Preview\",\n",
405 | 441 | "]\n",
406 | 442 | "\n",
407 | 443 | "mark_models_as_tested(incompatible_models)"
408 | 444 | ]
409 | 445 | },
410 | 446 | {
411 | 447 | "cell_type": "code",
412 |     | - "execution_count": 21,
    | 448 | + "execution_count": 8,
413 | 449 | "metadata": {},
414 | 450 | "outputs": [],
415 | 451 | "source": [

431 | 467 | },
432 | 468 | {
433 | 469 | "cell_type": "code",
434 |     | - "execution_count": 22,
    | 470 | + "execution_count": 9,
435 | 471 | "metadata": {},
436 | 472 | "outputs": [],
437 | 473 | "source": [

449 | 485 | },
450 | 486 | {
451 | 487 | "cell_type": "code",
452 |     | - "execution_count": 23,
    | 488 | + "execution_count": 10,
    | 489 | + "metadata": {},
    | 490 | + "outputs": [],
    | 491 | + "source": [
    | 492 | + "# This model works on the free version of Colab\n",
    | 493 | + "encoder_only_models = [\"bert-base-cased\"]\n",
    | 494 | + "\n",
    | 495 | + "if IN_COLAB:\n",
    | 496 | + " run_encoder_only_set(encoder_only_models)\n",
    | 497 | + "\n",
    | 498 | + "mark_models_as_tested(encoder_only_models)"
    | 499 | + ]
    | 500 | + },
    | 501 | + {
    | 502 | + "cell_type": "code",
    | 503 | + "execution_count": 11,
453 | 504 | "metadata": {},
454 | 505 | "outputs": [],
455 | 506 | "source": [

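The encoder_only_models cell added above drives run_encoder_only_set, defined earlier in this diff. Pulled out of the loop, the core mask-filling check is just the following sketch (same calls as the function body; device="cpu" is swapped in as an assumption so it runs without a GPU):

    from transformers import AutoTokenizer
    from transformer_lens import HookedEncoder

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    tl_model = HookedEncoder.from_pretrained("bert-base-cased", device="cpu")  # assumption: CPU for portability

    prompt = "The capital of France is [MASK]."
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]

    # Log-probabilities over the vocab at the [MASK] position only.
    logprobs = tl_model(input_ids)[input_ids == tokenizer.mask_token_id].log_softmax(dim=-1)
    prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())

    print(f"Prompt: {prompt}")
    print(f'Prediction: "{prediction}"')
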
460 | 511 | },
461 | 512 | {
462 | 513 | "cell_type": "code",
463 |     | - "execution_count": 24,
    | 514 | + "execution_count": 12,
464 | 515 | "metadata": {},
465 | 516 | "outputs": [
466 | 517 | {

499 | 550 | }
500 | 551 | },
501 | 552 | "nbformat": 4,
502 |     | - "nbformat_minor": 2
    | 553 | + "nbformat_minor": 4
503 | 554 | }