-
Notifications
You must be signed in to change notification settings - Fork 60
/
sft_vlm.py
128 lines (112 loc) · 4.63 KB
/
sft_vlm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
pip install deepspeed
pip install trl
pip install pillow
pip install "transformers>=4.45.1"
# Tested on 8x H100 GPUs
accelerate launch --config_file=deepspeed_zero3.yaml scripts/sft_vlm.py \
--model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 8 \
--output_dir llama-3.2-11b-vision-sft \
--bf16 \
--torch_dtype bfloat16 \
--gradient_checkpointing
"""
from trl.commands.cli_utils import SFTScriptArguments, TrlParser
import torch
from accelerate import Accelerator
from datasets import load_dataset
from transformers import AutoModelForVision2Seq, AutoProcessor, LlavaForConditionalGeneration
from trl import (
ModelConfig,
SFTConfig,
SFTTrainer,
get_peft_config,
get_quantization_config,
get_kbit_device_map,
)
def build_collate_fn(processor, is_llava=False):
    """Build a data collator that encodes chat-formatted text/image pairs.

    Args:
        processor: The model's processor. It must provide ``apply_chat_template``,
            a ``tokenizer`` with ``pad_token_id`` and ``convert_tokens_to_ids``, an
            ``image_token`` attribute, and be callable as
            ``processor(text=..., images=..., return_tensors="pt", padding=True)``.
        is_llava: If True, keep only the first image of each example, because
            LLaVA-1.5 does not support multiple images per sample.

    Returns:
        A ``collate_fn(examples)`` callable. Each example must have a ``"messages"``
        entry (passed to the chat template) and an ``"images"`` entry (a list of
        images). The returned batch is the processor output plus a ``"labels"``
        tensor: a copy of ``input_ids`` with padding and image tokens set to -100,
        so they are excluded from the loss.
    """
    # These values do not depend on the batch, so resolve them once instead of on
    # every call. Doing it here also fails fast (before training starts) if the
    # processor has no ``image_token`` attribute.
    pad_token_id = processor.tokenizer.pad_token_id
    # Ignore the image token index in the loss computation (model specific)
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)

    def collate_fn(examples):
        # Get the texts and images, and apply the chat template
        texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
        images = [example["images"] for example in examples]
        if is_llava:
            # LLaVA-1.5 does not support multiple images
            images = [image_list[0] for image_list in images]

        # Tokenize the texts and process the images
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

        # The labels are the input_ids, with padding and image tokens masked out of the loss
        labels = batch["input_ids"].clone()
        labels[labels == pad_token_id] = -100
        labels[labels == image_token_id] = -100
        batch["labels"] = labels
        return batch

    return collate_fn


if __name__ == "__main__":
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    sft_script_args, training_args, model_config = parser.parse_args_and_config()
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    # The custom collator reads the raw "messages"/"images" columns itself, so keep
    # every column and skip SFTTrainer's own dataset preparation.
    training_args.dataset_text_field = ""  # need a dummy field
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    ################
    # Model, Tokenizer & Processor
    ################
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
    )
    model = AutoModelForVision2Seq.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    )

    ################
    # Create a data collator to encode text and image pairs
    ################
    collate_fn = build_collate_fn(processor, is_llava=isinstance(model, LlavaForConditionalGeneration))

    ################
    # Dataset
    ################
    dataset = load_dataset(sft_script_args.dataset_name)

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        train_dataset=dataset[sft_script_args.dataset_train_split],
        # Only look up the test split when evaluation is enabled: many datasets ship
        # without one, and indexing a missing split raises KeyError.
        eval_dataset=dataset[sft_script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        tokenizer=processor.tokenizer,
        peft_config=get_peft_config(model_config),
    )
    trainer.train()

    # Save locally; push to the Hub only when --push_to_hub was requested.
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub()
        # Reuse the trainer's Accelerator instead of constructing a new one, and push
        # the processor to the repo the trainer resolved. That repo id is still
        # defined when --hub_model_id was not given.
        if trainer.accelerator.is_main_process:
            processor.push_to_hub(trainer.hub_model_id)