Merge pull request #6401 from Zeyi-Lin/hiyouga/swanlab
feat: add swanlab for experiment tracking and visualization.
hiyouga authored Dec 21, 2024
2 parents c6e3c14 + 82e5d75 commit 947e22a
Showing 11 changed files with 224 additions and 8 deletions.
32 changes: 31 additions & 1 deletion src/llamafactory/hparams/finetuning_args.py
@@ -305,7 +305,37 @@ class BAdamArgument:


@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
class SwanLabArguments:
use_swanlab: bool = field(
default=False,
metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
)
swanlab_project: str = field(
default="llamafactory",
metadata={"help": "The project name in SwanLab."},
)
swanlab_workspace: str = field(
default=None,
metadata={"help": "The workspace name in SwanLab."},
)
swanlab_experiment_name: str = field(
default=None,
metadata={"help": "The experiment name in SwanLab."},
)
swanlab_mode: Literal["cloud", "local"] = field(
default="cloud",
metadata={"help": "The mode of SwanLab."},
)
swanlab_api_key: str = field(
default=None,
metadata={"help": "The API key for SwanLab."},
)


@dataclass
class FinetuningArguments(
FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments
):
r"""
Arguments pertaining to which techniques we are going to fine-tune with.
"""
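Because SwanLabArguments is mixed into FinetuningArguments, the new fields are picked up by the same dataclass-based argument parsing as every other hyperparameter. A minimal sketch of that behavior, using a simplified stand-in dataclass rather than the project's actual class:

# Sketch only: a stand-in dataclass mirroring the SwanLab fields above,
# parsed with transformers.HfArgumentParser as LLaMA-Factory's hparams are.
from dataclasses import dataclass, field
from typing import Literal, Optional

from transformers import HfArgumentParser


@dataclass
class SwanLabArgsSketch:  # hypothetical name for illustration
    use_swanlab: bool = field(default=False)
    swanlab_project: str = field(default="llamafactory")
    swanlab_workspace: Optional[str] = field(default=None)
    swanlab_experiment_name: Optional[str] = field(default=None)
    swanlab_mode: Literal["cloud", "local"] = field(default="cloud")
    swanlab_api_key: Optional[str] = field(default=None)


parser = HfArgumentParser(SwanLabArgsSketch)
(args,) = parser.parse_args_into_dataclasses(["--use_swanlab", "true", "--swanlab_project", "my_project"])
print(args.use_swanlab, args.swanlab_project)  # True my_project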
5 changes: 4 additions & 1 deletion src/llamafactory/train/dpo/trainer.py
@@ -31,7 +31,7 @@
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback


if TYPE_CHECKING:
@@ -106,6 +106,9 @@ def __init__(
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
self.add_callback(get_swanlab_callback(finetuning_args))

@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
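The same two-line guard (check use_swanlab, then add_callback) recurs in the KTO, PPO, PT, RM, and SFT trainers below. For orientation, add_callback registers an ordinary transformers TrainerCallback whose hooks fire on training events; a tiny hand-rolled callback sketch (not part of this PR) showing the mechanism that SwanLabCallback plugs into:

# Sketch only: a minimal custom TrainerCallback illustrating the hook that
# logging integrations implement to receive training metrics.
from transformers import TrainerCallback


class PrintLossCallback(TrainerCallback):  # hypothetical example callback
    def on_log(self, args, state, control, logs=None, **kwargs):
        # Called each time the Trainer logs metrics; a tracking integration
        # would forward these to its backend instead of printing them.
        if logs and "loss" in logs:
            print(f"step {state.global_step}: loss={logs['loss']:.4f}")


# Registered the same way as the SwanLab callback above:
# trainer.add_callback(PrintLossCallback())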
5 changes: 4 additions & 1 deletion src/llamafactory/train/kto/trainer.py
@@ -30,7 +30,7 @@
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback


if TYPE_CHECKING:
@@ -101,6 +101,9 @@ def __init__(
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
self.add_callback(get_swanlab_callback(finetuning_args))

@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
5 changes: 4 additions & 1 deletion src/llamafactory/train/ppo/trainer.py
@@ -40,7 +40,7 @@
from ...extras import logging
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm


@@ -186,6 +186,9 @@ def __init__(
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
self.add_callback(get_swanlab_callback(finetuning_args))

def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
5 changes: 4 additions & 1 deletion src/llamafactory/train/pt/trainer.py
@@ -21,7 +21,7 @@

from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback


if TYPE_CHECKING:
@@ -56,6 +56,9 @@ def __init__(
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
self.add_callback(get_swanlab_callback(finetuning_args))

@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
5 changes: 4 additions & 1 deletion src/llamafactory/train/rm/trainer.py
@@ -27,7 +27,7 @@
from ...extras import logging
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback


if TYPE_CHECKING:
@@ -68,6 +68,9 @@ def __init__(
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
self.add_callback(get_swanlab_callback(finetuning_args))

@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
5 changes: 4 additions & 1 deletion src/llamafactory/train/sft/trainer.py
@@ -29,7 +29,7 @@
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback


if TYPE_CHECKING:
@@ -71,6 +71,9 @@ def __init__(
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
self.add_callback(get_swanlab_callback(finetuning_args))

@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
23 changes: 22 additions & 1 deletion src/llamafactory/train/trainer_utils.py
@@ -40,7 +40,7 @@


if TYPE_CHECKING:
from transformers import PreTrainedModel, Seq2SeqTrainingArguments
from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead

from ..hparams import DataArguments
@@ -457,3 +457,24 @@ def get_batch_logps(
labels[labels == label_pad_token_id] = 0 # dummy token
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)


def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
r"""
Gets the callback for logging to SwanLab.
"""
import swanlab
from swanlab.integration.transformers import SwanLabCallback

if finetuning_args.swanlab_api_key is not None:
swanlab.login(api_key=finetuning_args.swanlab_api_key)

swanlab_callback = SwanLabCallback(
project=finetuning_args.swanlab_project,
workspace=finetuning_args.swanlab_workspace,
experiment_name=finetuning_args.swanlab_experiment_name,
mode=finetuning_args.swanlab_mode,
config={"Framework": "🦙LLaMA Factory"},
)

return swanlab_callback
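A minimal usage sketch of this helper outside the trainers, assuming swanlab is installed (pip install swanlab); the experiment name below is made up, and the constructor arguments are the same ones the helper passes:

# Sketch only: building the SwanLab callback by hand and handing it to a
# transformers Trainer, mirroring what get_swanlab_callback returns.
from swanlab.integration.transformers import SwanLabCallback

swanlab_callback = SwanLabCallback(
    project="llamafactory",
    workspace=None,                  # None -> the personal workspace
    experiment_name="sft-demo",      # hypothetical experiment name
    mode="cloud",                    # or "local" for offline logging
    config={"Framework": "🦙LLaMA Factory"},
)

# trainer = Trainer(..., callbacks=[swanlab_callback])
# or, as the trainers in this PR do: trainer.add_callback(swanlab_callback)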
22 changes: 22 additions & 0 deletions src/llamafactory/webui/components/train.py
@@ -270,6 +270,28 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
)

with gr.Accordion(open=False) as swanlab_tab:
with gr.Row():
use_swanlab = gr.Checkbox()
swanlab_project = gr.Textbox(value="llamafactory", placeholder="Project name", interactive=True)
swanlab_experiment_name = gr.Textbox(value="", placeholder="Experiment name", interactive=True)
swanlab_workspace = gr.Textbox(value="", placeholder="Workspace name", interactive=True)
swanlab_api_key = gr.Textbox(value="", placeholder="API key", interactive=True)
swanlab_mode = gr.Dropdown(choices=["cloud", "local", "disabled"], value="cloud", interactive=True)

input_elems.update({use_swanlab, swanlab_api_key, swanlab_project, swanlab_workspace, swanlab_experiment_name, swanlab_mode})
elem_dict.update(
dict(
swanlab_tab=swanlab_tab,
use_swanlab=use_swanlab,
swanlab_api_key=swanlab_api_key,
swanlab_project=swanlab_project,
swanlab_workspace=swanlab_workspace,
swanlab_experiment_name=swanlab_experiment_name,
swanlab_mode=swanlab_mode,
)
)

with gr.Row():
cmd_preview_btn = gr.Button()
arg_save_btn = gr.Button()
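The components above are created without labels; the Web UI fills in labels and info text from the locales.py entries added below, keyed by the same names registered in elem_dict. A standalone Gradio sketch of an equivalent accordion with the English labels hard-coded, assuming gradio is installed:

# Sketch only: a self-contained Gradio accordion equivalent to the SwanLab
# tab, with English labels hard-coded instead of looked up from locales.py.
import gradio as gr

with gr.Blocks() as demo:
    with gr.Accordion("SwanLab configurations", open=False):
        with gr.Row():
            use_swanlab = gr.Checkbox(label="Use SwanLab")
            swanlab_project = gr.Textbox(label="Project (optional)", value="llamafactory")
            swanlab_experiment_name = gr.Textbox(label="Experiment name (optional)")
            swanlab_workspace = gr.Textbox(label="Workspace (optional)")
            swanlab_api_key = gr.Textbox(label="API key (optional)")
            swanlab_mode = gr.Dropdown(label="Mode", choices=["cloud", "local"], value="cloud")

# demo.launch()  # uncomment to preview the accordion in a browser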
115 changes: 115 additions & 0 deletions src/llamafactory/webui/locales.py
@@ -1353,6 +1353,121 @@
"info": "비율-BAdam의 업데이트 비율.",
},
},
"swanlab_tab": {
"en": {
"label": "SwanLab configurations",
},
"ru": {
"label": "Конфигурации SwanLab",
},
"zh": {
"label": "SwanLab 参数设置",
},
"ko": {
"label": "SwanLab 설정",
},
},
"use_swanlab": {
"en": {
"label": "Use SwanLab",
"info": "Enable SwanLab for experiment tracking and visualization.",
},
"ru": {
"label": "Использовать SwanLab",
"info": "Включить SwanLab для отслеживания и визуализации экспериментов.",
},
"zh": {
"label": "使用 SwanLab",
"info": "启用 SwanLab 进行实验跟踪和可视化。",
},
"ko": {
"label": "SwanLab 사용",
"info": "SwanLab를 사용하여 실험을 추적하고 시각화합니다.",
},
},
"swanlab_api_key": {
"en": {
"label": "API Key(optional)",
"info": "API key for SwanLab. Once logged in, no need to login again in the programming environment.",
},
"ru": {
"label": "API ключ(Необязательный)",
"info": "API ключ для SwanLab. После входа в программное окружение, нет необходимости входить снова.",
},
"zh": {
"label": "API密钥(选填)",
"info": "用于在编程环境登录SwanLab,已登录则无需填写。",
},
"ko": {
"label": "API 키(선택 사항)",
"info": "SwanLab의 API 키. 프로그래밍 환경에 로그인한 후 다시 로그인할 필요가 없습니다.",
},
},
"swanlab_project": {
"en": {
"label": "Project(optional)",
},
"ru": {
"label": "Проект(Необязательный)",
},
"zh": {
"label": "项目(选填)",
},
"ko": {
"label": "프로젝트(선택 사항)",
},
},
"swanlab_workspace": {
"en": {
"label": "Workspace(optional)",
"info": "Workspace for SwanLab. If not filled, it defaults to the personal workspace.",

},
"ru": {
"label": "Рабочая область(Необязательный)",
"info": "Рабочая область SwanLab, если не заполнено, то по умолчанию в личной рабочей области.",
},
"zh": {
"label": "Workspace(选填)",
"info": "SwanLab组织的工作区,如不填写则默认在个人工作区下",
},
"ko": {
"label": "작업 영역(선택 사항)",
"info": "SwanLab 조직의 작업 영역, 비어 있으면 기본적으로 개인 작업 영역에 있습니다.",
},
},
"swanlab_experiment_name": {
"en": {
"label": "Experiment name (optional)",
},
"ru": {
"label": "Имя эксперимента(Необязательный)",
},
"zh": {
"label": "实验名(选填) ",
},
"ko": {
"label": "실험 이름(선택 사항)",
},
},
"swanlab_mode": {
"en": {
"label": "Mode",
"info": "Cloud or offline version.",
},
"ru": {
"label": "Режим",
"info": "Версия в облаке или локальная версия.",
},
"zh": {
"label": "模式",
"info": "云端版或离线版",
},
"ko": {
"label": "모드",
"info": "클라우드 버전 또는 오프라인 버전.",
},
},
"cmd_preview_btn": {
"en": {
"value": "Preview command",
10 changes: 10 additions & 0 deletions src/llamafactory/webui/runner.py
@@ -147,6 +147,7 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
report_to="all" if get("train.report_to") else "none",
use_galore=get("train.use_galore"),
use_badam=get("train.use_badam"),
use_swanlab=get("train.use_swanlab"),
output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16"),
@@ -228,6 +229,15 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
args["badam_switch_interval"] = get("train.badam_switch_interval")
args["badam_update_ratio"] = get("train.badam_update_ratio")

# swanlab config
if get("train.use_swanlab"):
args["swanlab_api_key"] = get("train.swanlab_api_key")
args["swanlab_project"] = get("train.swanlab_project")
args["swanlab_workspace"] = get("train.swanlab_workspace")
args["swanlab_experiment_name"] = get("train.swanlab_experiment_name")
args["swanlab_mode"] = get("train.swanlab_mode")


# eval config
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
args["val_size"] = get("train.val_size")
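Illustratively, when "Use SwanLab" is checked in the Web UI, the runner extends the training-argument dict it builds with the keys above before launching; the values in this sketch are made up:

# Sketch only: the extra keys appended by the Web UI runner when
# train.use_swanlab is checked (all values here are illustrative).
args = dict(stage="sft", output_dir="saves/demo")  # hypothetical base args
args.update(
    use_swanlab=True,
    swanlab_api_key="",              # optional API key
    swanlab_project="llamafactory",
    swanlab_workspace="",            # optional organization workspace
    swanlab_experiment_name="",      # optional experiment name
    swanlab_mode="cloud",            # "cloud" or "local"
)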
