# This example is a quick showcase of parameter-efficient fine-tuning of the Llama 3.1 8B
# model on the IMDB dataset using QLoRA with bitsandbytes.
# In order to run this example, you'll need to install peft, trl, and bitsandbytes:
# pip install peft trl bitsandbytes
import torch
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
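
# Base model and tokenizer. Note that the meta-llama checkpoints are gated on the
# Hugging Face Hub, so you need to have accepted the license and be logged in.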
model_id = "meta-llama/Meta-Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
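
# The IMDB train split provides a "text" column with raw movie reviews,
# which SFTTrainer consumes directly via dataset_text_field below.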
dataset = load_dataset("imdb", split="train")
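
# Basic training hyperparameters; adjust per_device_train_batch_size to fit your GPU memory.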
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir="./logs",
    logging_steps=10,
)
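
# QLoRA: load the frozen base model in 4-bit NF4 precision and train small
# low-rank LoRA adapters on top of every linear layer. Set QLoRA = False to
# fully fine-tune the unquantized model instead.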
QLoRA = True
if QLoRA:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
    )
    lora_config = LoraConfig(
        r=8,
        target_modules="all-linear",
        bias="none",
        task_type="CAUSAL_LM",
    )
else:
    quantization_config = None
    lora_config = None

# Load the base model, quantized to 4-bit when QLoRA is enabled.
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
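
# SFTTrainer applies the LoRA adapters via peft_config and tokenizes the raw
# "text" field of the dataset before training.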
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="text",
)
trainer.train()
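
# Optionally save the trained adapter weights for later inference;
# "./results/final_checkpoint" is just an illustrative output path.
trainer.save_model("./results/final_checkpoint")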