Commit: Clean up LoRA

vpj committed Aug 2, 2024
1 parent 957ade6 commit dc47621
Showing 9 changed files with 78 additions and 171 deletions.
File renamed without changes.
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from transformers import AutoTokenizer
-from labml_nn.transformers.LoRA import Linear, Embedding
+from labml_nn.lora import Linear, Embedding

tokenizer = AutoTokenizer.from_pretrained("gpt2")

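Aside: the Linear and Embedding imported above are the LoRA-augmented layers. As a reminder of the idea, below is a minimal sketch of a LoRA linear layer under the standard formulation, a frozen base weight plus a trainable low-rank update; the class name and defaults are illustrative, not the repository's exact implementation.

# Minimal LoRA linear sketch: y = x @ W.T + b + (alpha / r) * x @ A @ B,
# where W and b are frozen and only A, B train.
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, r: int = 32, alpha: int = 32):
        super().__init__()
        # Stand-ins for the frozen pretrained weight and bias
        self.weight = nn.Parameter(torch.randn(out_features, in_features), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False)
        # Trainable low-rank factors; B starts at zero so training begins
        # exactly at the pretrained model's behavior
        self.lora_a = nn.Parameter(torch.randn(in_features, r) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(r, out_features))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = nn.functional.linear(x, self.weight, self.bias)
        return base + (x @ self.lora_a @ self.lora_b) * self.scaling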
@@ -1,5 +1,22 @@
{
"cells": [
+{
+"metadata": {},
+"cell_type": "code",
+"outputs": [],
+"execution_count": null,
+"source": [
+"import torch\n",
+"from torch.optim import Adam\n",
+"from torch.utils.data import DataLoader, TensorDataset\n",
+"from torch.utils.data import random_split\n",
+"from transformers import AutoTokenizer\n",
+"\n",
+"from labml import tracker, experiment\n",
+"from labml_nn.lora.gpt2 import GPTModel"
+],
+"id": "f072832ec9d346e1"
+},
{
"cell_type": "code",
"id": "initial_id",
@@ -29,8 +46,6 @@
"id": "ac8e51ae5bbfcae7",
"metadata": {},
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
"\n",
"tokens = tokenizer.encode(text, add_special_tokens=False)"
@@ -64,11 +79,7 @@
"cell_type": "code",
"id": "5c4cc78ac1a02c1d",
"metadata": {},
"source": [
"import torch\n",
"\n",
"input_ids = torch.tensor(tokens).view(-1, context_length)"
],
"source": "input_ids = torch.tensor(tokens).view(-1, context_length)",
"outputs": [],
"execution_count": null
},
@@ -77,10 +88,6 @@
"id": "7037fd75e2161382",
"metadata": {},
"source": [
"from torch.utils.data import DataLoader, TensorDataset\n",
"from torch.optim import Adam\n",
"from torch.utils.data import random_split\n",
"\n",
"dataset = TensorDataset(input_ids)\n",
"\n",
"train_ratio = 0.8\n",
@@ -102,8 +109,6 @@
"id": "a98b7baa064b8494",
"metadata": {},
"source": [
"from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
"\n",
"model = GPTModel()\n",
"state_dict = torch.load('transformed.pth', weights_only=True)\n",
"\n",
@@ -128,8 +133,6 @@
"id": "e2f5076894770740",
"metadata": {},
"source": [
"from labml import tracker, experiment\n",
"\n",
"optimizer = Adam(model.parameters(), lr=5e-5)\n",
"criterion = torch.nn.CrossEntropyLoss()\n",
"\n",
@@ -143,39 +146,38 @@
" inputs = batch[0]\n",
" inputs = inputs.to(device)\n",
" labels = inputs.clone()\n",
" \n",
"\n",
" outputs = model(inputs)\n",
" \n",
"\n",
" shift_logits = outputs[..., :-1, :]\n",
" shift_labels = labels[..., 1:]\n",
" \n",
"\n",
" loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
" \n",
"\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
"\n",
" tracker.save(step, {'loss': loss})\n",
" step += 1\n",
" print(f'Epoch: {epoch + 1}, Loss: {loss.item()}')\n",
" \n",
"\n",
" test_loss = 0\n",
" for batch in test_dataloader:\n",
" inputs = batch[0]\n",
" inputs = inputs.to(device)\n",
" labels = inputs.clone()\n",
" \n",
"\n",
" outputs = model(inputs)\n",
" \n",
"\n",
" shift_logits = outputs[..., :-1, :]\n",
" shift_labels = labels[..., 1:]\n",
" \n",
"\n",
" loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
" \n",
"\n",
" test_loss += loss.item()\n",
" test_loss /= len(test_dataloader)\n",
" tracker.save(step, {'test_loss': test_loss})\n",
" \n",
"\n",
"print(\"Training complete.\")"
],
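Aside: the shifting above implements the standard next-token objective, where the prediction at position t is scored against token t + 1. A standalone sketch of the same computation with illustrative shapes:

import torch

logits = torch.randn(2, 8, 50257)          # model output: (batch, seq_len, vocab)
labels = torch.randint(0, 50257, (2, 8))   # the inputs, reused as labels

shift_logits = logits[..., :-1, :]   # drop the prediction after the last token
shift_labels = labels[..., 1:]       # drop the first token; targets now align

loss = torch.nn.functional.cross_entropy(
    shift_logits.reshape(-1, shift_logits.size(-1)),
    shift_labels.reshape(-1))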
46 changes: 46 additions & 0 deletions labml_nn/lora/transform_hf_model.py
@@ -0,0 +1,46 @@
import torch
from transformers import AutoModelForCausalLM


def transform_hf_model():
model = AutoModelForCausalLM.from_pretrained("gpt2")

state_dict = model.state_dict()

mapping = {
'transformer.wte.weight': 'token_embedding.weight',
'transformer.wpe.weight': 'position_embedding.weight',
'transformer.ln_f.weight': 'final_norm.weight',
'transformer.ln_f.bias': 'final_norm.bias',
'lm_head.weight': 'lm_head.weight'
}

for i in range(12):
mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'

new_state_dict = {}
for old_key, new_key in mapping.items():
if old_key in state_dict:
new_state_dict[new_key] = state_dict[old_key]

    # Transpose the weight matrices of the HF Conv1D layers so they can be used as linear layers
convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
[f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
[f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
[f'blocks.{i}.attn.c_proj.weight' for i in range(12)])

for layer in convo_layers:
new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

torch.save(new_state_dict, 'transformed.pth')
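The transpose above is needed because Hugging Face's GPT-2 uses Conv1D layers that store the weight as (in_features, out_features) and compute x @ W, whereas nn.Linear stores (out_features, in_features) and computes x @ W.T. A small sketch verifying the equivalence; it assumes the current transformers layout, where Conv1D lives in transformers.pytorch_utils.

import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D

conv = Conv1D(nf=8, nx=4)       # nx = in_features, nf = out_features
linear = nn.Linear(4, 8)
with torch.no_grad():
    linear.weight.copy_(conv.weight.T)   # the same transpose done above
    linear.bias.copy_(conv.bias)

x = torch.randn(2, 4)
assert torch.allclose(conv(x), linear(x), atol=1e-6)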
File renamed without changes.
File renamed without changes.
6 changes: 3 additions & 3 deletions labml_nn/RWKV/experiment.py → labml_nn/rwkv/experiment.py
@@ -3,10 +3,10 @@

import torch
import torch.nn as nn
-from labml_nn.RWKV.configs import RWKVConfigs
+from labml_nn.rwkv.configs import RWKVConfigs

-from labml_nn.RWKV import RWKV
-from labml_nn.RWKV import TimeMixing
+from labml_nn.rwkv import RWKV
+from labml_nn.rwkv import TimeMixing
from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
97 changes: 0 additions & 97 deletions labml_nn/transformers/LoRA/experiment.ipynb

This file was deleted.

44 changes: 0 additions & 44 deletions labml_nn/transformers/LoRA/load_hf.py

This file was deleted.
