【PPMix No.3】support qwen2-vl-2b lora finetune (#847)
Support Qwen2-VL LoRA fine-tuning:
1. Update 'paddlemix/examples/qwen2_vl/qwen2vl_finetune.py' to support LoRA fine-tuning.
2. Add new file 'paddlemix/examples/qwen2_vl/shell/basline_2b_lora_bs32_1e8.sh'.
3. Update 'paddlemix/examples/qwen2_vl/README.md'.


![image](https://github.com/user-attachments/assets/a4882142-ec5f-4e79-936c-535641c0d29b)

---------

Co-authored-by: wcx <[email protected]>
Co-authored-by: nifeng <[email protected]>
3 people authored Dec 2, 2024
1 parent 9e196f8 commit 2ef9368
Showing 4 changed files with 139 additions and 3 deletions.
Binary file added paddlemix/demo_images/qwen2-vl-2b-lora-ft.png
7 changes: 6 additions & 1 deletion paddlemix/examples/qwen2_vl/README.md
@@ -51,15 +51,20 @@ opensource_json.tar needs to be downloaded and extracted into the playground/ directory; inside opensource_json

### 4.2 Fine-tuning commands

Note: this fine-tuning is full-parameter fine-tuning that freezes the vision encoder and trains the LLM; fine-tuning the 2B model needs about 30 GB of GPU memory, and the 7B model about 75 GB.
Note: this fine-tuning trains the language model, freezing the vision encoder and unfreezing the LLM; full-parameter fine-tuning of the 2B model needs about 30 GB of GPU memory, and the 7B model about 75 GB.

```bash
# 2B
sh paddlemix/examples/qwen2_vl/shell/basline_2b_bs32_1e8.sh

# 2B lora
sh paddlemix/examples/qwen2_vl/shell/basline_2b_lora_bs32_1e8.sh

# 7B
sh paddlemix/examples/qwen2_vl/shell/basline_7b_bs32_1e8.sh
```
Note: a sample run of fine-tuning the 2B model looks like this:
![sample run](../../demo_images/qwen2-vl-2b-lora-ft.png)

### 4.3 Usage after fine-tuning

49 changes: 47 additions & 2 deletions paddlemix/examples/qwen2_vl/qwen2vl_finetune.py
@@ -20,7 +20,7 @@
import sys
import traceback
from dataclasses import dataclass, field
from typing import Dict, Optional
from typing import Dict, Optional, List

import numpy as np
import paddle
@@ -31,6 +31,7 @@
from paddlenlp.trainer.trainer import Trainer
from paddlenlp.trainer.trainer_utils import get_last_checkpoint
from paddlenlp.transformers import Qwen2Tokenizer
from paddlenlp.peft import LoRAConfig, LoRAModel
from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError

from paddlemix.datasets.internvl_dataset import ConcatDataset, WeightedConcatDataset
@@ -125,7 +126,31 @@ class ModelArguments(ProcessorArguments):
        default=0.0,
        metadata={"help": "Set the drop path rate for the ViT model. Default is 0."},
    )

    lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether or not to use lora to train model."},
    )
    lora_path: Optional[str] = field(
        default=None,
        metadata={"help": "Initialize lora state dict."}
    )
    lora_rank: Optional[int] = field(
        default=128,
        metadata={"help": "Set the value of rank in lora. Default is 128."},
    )
    lora_alpha: Optional[int] = field(
        default=256,
        metadata={"help": "Set the value of alpha in lora. Default is 256."},
    )
    lora_dropout: Optional[float] = field(
        default=0.0,
        metadata={"help": "Set the value of dropout in lora. Default is 0.0."},
    )
    lora_target_modules: Optional[str] = field(
        default=None,
        metadata={"help": "Lora target modules."}
    )


@dataclass
class DataTrainingArguments:
@@ -528,6 +553,26 @@ def _freeze_params(module):
        model.lm_head = model.lm_head.eval()
        _freeze_params(model.model)
        _freeze_params(model.lm_head)

    # lora
    if model_args.lora:
        if model_args.lora_path is None:
            target_modules = model_args.lora_target_modules.split(",")
            lora_config = LoRAConfig(
                target_modules=target_modules,
                r=model_args.lora_rank,
                lora_alpha=model_args.lora_alpha,
                lora_dropout=model_args.lora_dropout,
                merge_weights=False,
                tensor_parallel_degree=training_args.tensor_parallel_degree,
                dtype=dtype,
            )
            model = LoRAModel(model, lora_config)
        else:
            model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)
        model.mark_only_lora_as_trainable()
        model.print_trainable_parameters()

    print_trainable_params(model)

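For orientation, the LoRA branch added above is shown below consolidated into one annotated sketch. It is not a standalone program: `model`, `model_args`, `training_args`, and `dtype` are assumed to be prepared earlier in `main()` of `qwen2vl_finetune.py`, exactly as in the diff.

```python
from paddlenlp.peft import LoRAConfig, LoRAModel

# Sketch only: `model`, `model_args`, `training_args`, and `dtype` come from main().
if model_args.lora:
    if model_args.lora_path is None:
        # Fresh run: build a LoRA config from the new CLI flags and wrap the base model.
        lora_config = LoRAConfig(
            target_modules=model_args.lora_target_modules.split(","),  # comma-separated name patterns
            r=model_args.lora_rank,
            lora_alpha=model_args.lora_alpha,
            lora_dropout=model_args.lora_dropout,
            merge_weights=False,
            tensor_parallel_degree=training_args.tensor_parallel_degree,
            dtype=dtype,
        )
        model = LoRAModel(model, lora_config)
    else:
        # Resume: re-attach previously saved LoRA weights via --lora_path.
        model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)
    model.mark_only_lora_as_trainable()  # freeze base weights; only LoRA parameters train
    model.print_trainable_parameters()
```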
86 changes: 86 additions & 0 deletions paddlemix/examples/qwen2_vl/shell/basline_2b_lora_bs32_1e8.sh
@@ -0,0 +1,86 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -x

GPUS=${GPUS:-8}
BATCH_SIZE=${BATCH_SIZE:-32}
PER_DEVICE_BATCH_SIZE=${PER_DEVICE_BATCH_SIZE:-1}

GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / GPUS))
tensor_parallel_degree=${tensor_parallel_degree:-1}
sharding_parallel_degree=$((GPUS / tensor_parallel_degree))
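# With the defaults above: GRADIENT_ACC = 32 / 1 / 8 = 4 accumulation steps, and
# sharding_parallel_degree = 8 / 1 = 8, i.e. all GPUs form a single sharding group.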

export PYTHONPATH="${PYTHONPATH}:$(pwd)"
export MASTER_PORT=34229
export TF_CPP_MIN_LOG_LEVEL=3

OUTPUT_DIR='work_dirs/basline_330k_2b_bs32_1e8'

if [ ! -d "$OUTPUT_DIR" ]; then
mkdir -p "$OUTPUT_DIR"
fi

TRAINING_MODEL_RESUME="None"
TRAINER_INSTANCES='127.0.0.1'
MASTER='127.0.0.1:8080'

meta_path="paddlemix/examples/qwen2_vl/configs/baseline_6data_330k.json"

TRAINING_PYTHON="python -m paddle.distributed.launch --master ${MASTER} --nnodes 1 --nproc_per_node ${GPUS} --rank 0 --ips ${TRAINER_INSTANCES} --run_mode=collective"
${TRAINING_PYTHON} --log_dir ${OUTPUT_DIR}/paddle_distributed_logs \
paddlemix/examples/qwen2_vl/qwen2vl_finetune.py \
--do_train \
--model_name_or_path "Qwen/Qwen2-VL-2B-Instruct" \
--output_dir ${OUTPUT_DIR} \
--logging_dir ${OUTPUT_DIR}/logs \
--meta_path ${meta_path} \
--overwrite_output_dir True \
--dataloader_num_workers 8 \
--bf16 True \
--fp16 False \
--fp16_opt_level "O2" \
--num_train_epochs 1 \
--per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \
--gradient_accumulation_steps ${GRADIENT_ACC} \
--freeze_vit True \
--max_seq_length 8192 \
--image_resolution 512 \
--recompute False \
--max_grad_norm 1.0 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 1 \
--learning_rate 1e-8 \
--warmup_ratio 0.1 \
--warmup_steps 100 \
--weight_decay 0.1 \
--optim "adamw" \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--report_to "visualdl" \
--tensor_parallel_degree=${tensor_parallel_degree} \
--sharding_parallel_degree=${sharding_parallel_degree} \
--pipeline_parallel_degree=1 \
--sep_parallel_degree=1 \
--sharding="stage1" \
--amp_master_grad=1 \
--hybrid_parallel_topo_order="sharding_first" \
--lora True \
--lora_rank=12 \
--lora_alpha=256 \
--lora_dropout=0.0 \
--lora_target_modules="model.layers.*q_proj.*,model.layers.*k_proj.*,model.layers.*v_proj.*,model.layers.*gate_proj.*,model.layers.*up_proj.*,model.layers.*down_proj.*,model.layers.*o_proj.*" \
2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt"
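As a side note (not part of the commit), the `--lora_target_modules` value passed above is just a comma-separated list of name patterns; `qwen2vl_finetune.py` splits it with `split(",")` before handing it to `LoRAConfig`. A tiny illustrative Python snippet:

```python
# Illustrative only: reproduce how the finetune script turns the flag into a list of patterns.
lora_target_modules = (
    "model.layers.*q_proj.*,model.layers.*k_proj.*,model.layers.*v_proj.*,"
    "model.layers.*gate_proj.*,model.layers.*up_proj.*,model.layers.*down_proj.*,"
    "model.layers.*o_proj.*"
)
target_modules = lora_target_modules.split(",")
print(len(target_modules))  # 7 patterns, one per projection layer targeted by LoRA
```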
