feat: add swanlab for experiment tracking and visualization. #6401

Merged · 10 commits · Dec 21, 2024
32 changes: 31 additions & 1 deletion src/llamafactory/hparams/finetuning_args.py
@@ -305,7 +305,37 @@ class BAdamArgument:


@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
class SwanLabArguments:
    use_swanlab: bool = field(
        default=False,
        metadata={"help": "Whether or not to use SwanLab (an experiment tracking and visualization tool)."},
    )
    swanlab_project: str = field(
        default="llamafactory",
        metadata={"help": "The project name in SwanLab."},
    )
    swanlab_workspace: Optional[str] = field(
        default=None,
        metadata={"help": "The workspace name in SwanLab."},
    )
    swanlab_experiment_name: Optional[str] = field(
        default=None,
        metadata={"help": "The experiment name in SwanLab."},
    )
    swanlab_mode: Literal["cloud", "local"] = field(
        default="cloud",
        metadata={"help": "The mode of SwanLab."},
    )
    swanlab_api_key: Optional[str] = field(
        default=None,
        metadata={"help": "The API key for SwanLab."},
    )


@dataclass
class FinetuningArguments(
    FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments
):
    r"""
    Arguments pertaining to which techniques we are going to fine-tune with.
    """
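For orientation (not part of the diff), a minimal sketch of how the new arguments surface after parsing; the import path follows the file shown above and the experiment name is a hypothetical value:

from llamafactory.hparams.finetuning_args import FinetuningArguments

# Hypothetical values; unset fields keep the defaults declared in SwanLabArguments.
args = FinetuningArguments(
    use_swanlab=True,
    swanlab_project="llamafactory",
    swanlab_experiment_name="lora-sft-demo",  # hypothetical experiment name
)
assert args.swanlab_mode == "cloud"
assert args.swanlab_workspace is None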
5 changes: 4 additions & 1 deletion src/llamafactory/train/dpo/trainer.py
@@ -31,7 +31,7 @@
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback


if TYPE_CHECKING:
@@ -106,6 +106,9 @@ def __init__(
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
    self.add_callback(get_swanlab_callback(finetuning_args))

@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
5 changes: 4 additions & 1 deletion src/llamafactory/train/kto/trainer.py
@@ -30,7 +30,7 @@
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback


if TYPE_CHECKING:
@@ -101,6 +101,9 @@ def __init__(
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
    self.add_callback(get_swanlab_callback(finetuning_args))

@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
5 changes: 4 additions & 1 deletion src/llamafactory/train/ppo/trainer.py
@@ -40,7 +40,7 @@
from ...extras import logging
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm


@@ -186,6 +186,9 @@ def __init__(
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
    self.add_callback(get_swanlab_callback(finetuning_args))

def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
Implements the training loop for the PPO stage, like _inner_training_loop() in Hugging Face's Trainer.
5 changes: 4 additions & 1 deletion src/llamafactory/train/pt/trainer.py
@@ -20,7 +20,7 @@

from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback


if TYPE_CHECKING:
@@ -56,6 +56,9 @@ def __init__(
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
    self.add_callback(get_swanlab_callback(finetuning_args))

@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
5 changes: 4 additions & 1 deletion src/llamafactory/train/rm/trainer.py
@@ -27,7 +27,7 @@
from ...extras import logging
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback


if TYPE_CHECKING:
@@ -68,6 +68,9 @@ def __init__(
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
    self.add_callback(get_swanlab_callback(finetuning_args))

@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
5 changes: 4 additions & 1 deletion src/llamafactory/train/sft/trainer.py
@@ -29,7 +29,7 @@
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback


if TYPE_CHECKING:
@@ -71,6 +71,9 @@ def __init__(
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
    self.add_callback(get_swanlab_callback(finetuning_args))

@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
23 changes: 22 additions & 1 deletion src/llamafactory/train/trainer_utils.py
@@ -40,7 +40,7 @@


if TYPE_CHECKING:
from transformers import PreTrainedModel, Seq2SeqTrainingArguments
from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead

from ..hparams import DataArguments
@@ -457,3 +457,24 @@ def get_batch_logps(
labels[labels == label_pad_token_id] = 0 # dummy token
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)


def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
    r"""
    Gets the callback for logging to SwanLab.
    """
    import swanlab
    from swanlab.integration.transformers import SwanLabCallback

    if finetuning_args.swanlab_api_key is not None:
        swanlab.login(api_key=finetuning_args.swanlab_api_key)

    swanlab_callback = SwanLabCallback(
        project=finetuning_args.swanlab_project,
        workspace=finetuning_args.swanlab_workspace,
        experiment_name=finetuning_args.swanlab_experiment_name,
        mode=finetuning_args.swanlab_mode,
        config={"Framework": "🦙LLaMA Factory"},
    )

    return swanlab_callback
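Outside of LLaMA-Factory's trainers, the same wiring can be reproduced with a plain transformers Trainer; a minimal sketch, assuming swanlab is installed and that model, training_args, and train_dataset are defined elsewhere (the project and experiment names are hypothetical):

from transformers import Trainer
from swanlab.integration.transformers import SwanLabCallback

# Same integration point the helper above relies on: SwanLabCallback plugs into the
# standard transformers callback mechanism and mirrors logged training metrics to SwanLab.
callback = SwanLabCallback(project="llamafactory", experiment_name="demo")
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, callbacks=[callback])
trainer.train()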
22 changes: 22 additions & 0 deletions src/llamafactory/webui/components/train.py
@@ -270,6 +270,28 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
)

with gr.Accordion(open=False) as swanlab_tab:
    with gr.Row():
        use_swanlab = gr.Checkbox()
        swanlab_project = gr.Textbox(value="llamafactory", placeholder="Project name", interactive=True)
        swanlab_experiment_name = gr.Textbox(value="", placeholder="Experiment name", interactive=True)
        swanlab_workspace = gr.Textbox(value="", placeholder="Workspace name", interactive=True)
        swanlab_api_key = gr.Textbox(value="", placeholder="API key", interactive=True)
        swanlab_mode = gr.Dropdown(choices=["cloud", "local", "disabled"], value="cloud", interactive=True)

input_elems.update(
    {use_swanlab, swanlab_api_key, swanlab_project, swanlab_workspace, swanlab_experiment_name, swanlab_mode}
)
elem_dict.update(
    dict(
        swanlab_tab=swanlab_tab,
        use_swanlab=use_swanlab,
        swanlab_api_key=swanlab_api_key,
        swanlab_project=swanlab_project,
        swanlab_workspace=swanlab_workspace,
        swanlab_experiment_name=swanlab_experiment_name,
        swanlab_mode=swanlab_mode,
    )
)

with gr.Row():
cmd_preview_btn = gr.Button()
arg_save_btn = gr.Button()
115 changes: 115 additions & 0 deletions src/llamafactory/webui/locales.py
@@ -1353,6 +1353,121 @@
"info": "비율-BAdam의 업데이트 비율.",
},
},
"swanlab_tab": {
"en": {
"label": "SwanLab configurations",
},
"ru": {
"label": "Конфигурации SwanLab",
},
"zh": {
"label": "SwanLab 参数设置",
},
"ko": {
"label": "SwanLab 설정",
},
},
"use_swanlab": {
"en": {
"label": "Use SwanLab",
"info": "Enable SwanLab for experiment tracking and visualization.",
},
"ru": {
"label": "Использовать SwanLab",
"info": "Включить SwanLab для отслеживания и визуализации экспериментов.",
},
"zh": {
"label": "使用 SwanLab",
"info": "启用 SwanLab 进行实验跟踪和可视化。",
},
"ko": {
"label": "SwanLab 사용",
"info": "SwanLab를 사용하여 실험을 추적하고 시각화합니다.",
},
},
"swanlab_api_key": {
"en": {
"label": "API Key(optional)",
"info": "API key for SwanLab. Once logged in, no need to login again in the programming environment.",
},
"ru": {
"label": "API ключ(Необязательный)",
"info": "API ключ для SwanLab. После входа в программное окружение, нет необходимости входить снова.",
},
"zh": {
"label": "API密钥(选填)",
"info": "用于在编程环境登录SwanLab,已登录则无需填写。",
},
"ko": {
"label": "API 키(선택 사항)",
"info": "SwanLab의 API 키. 프로그래밍 환경에 로그인한 후 다시 로그인할 필요가 없습니다.",
},
},
"swanlab_project": {
"en": {
"label": "Project(optional)",
},
"ru": {
"label": "Проект(Необязательный)",
},
"zh": {
"label": "项目(选填)",
},
"ko": {
"label": "프로젝트(선택 사항)",
},
},
"swanlab_workspace": {
"en": {
"label": "Workspace(optional)",
"info": "Workspace for SwanLab. If not filled, it defaults to the personal workspace.",

},
"ru": {
"label": "Рабочая область(Необязательный)",
"info": "Рабочая область SwanLab, если не заполнено, то по умолчанию в личной рабочей области.",
},
"zh": {
"label": "Workspace(选填)",
"info": "SwanLab组织的工作区,如不填写则默认在个人工作区下",
},
"ko": {
"label": "작업 영역(선택 사항)",
"info": "SwanLab 조직의 작업 영역, 비어 있으면 기본적으로 개인 작업 영역에 있습니다.",
},
},
"swanlab_experiment_name": {
"en": {
"label": "Experiment name (optional)",
},
"ru": {
"label": "Имя эксперимента(Необязательный)",
},
"zh": {
"label": "实验名(选填) ",
},
"ko": {
"label": "실험 이름(선택 사항)",
},
},
"swanlab_mode": {
"en": {
"label": "Mode",
"info": "Cloud or offline version.",
},
"ru": {
"label": "Режим",
"info": "Версия в облаке или локальная версия.",
},
"zh": {
"label": "模式",
"info": "云端版或离线版",
},
"ko": {
"label": "모드",
"info": "클라우드 버전 또는 오프라인 버전.",
},
},
"cmd_preview_btn": {
"en": {
"value": "Preview command",
10 changes: 10 additions & 0 deletions src/llamafactory/webui/runner.py
@@ -147,6 +147,7 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
report_to="all" if get("train.report_to") else "none",
use_galore=get("train.use_galore"),
use_badam=get("train.use_badam"),
use_swanlab=get("train.use_swanlab"),
output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16"),
@@ -228,6 +229,15 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
args["badam_switch_interval"] = get("train.badam_switch_interval")
args["badam_update_ratio"] = get("train.badam_update_ratio")

# swanlab config
if get("train.use_swanlab"):
    args["swanlab_api_key"] = get("train.swanlab_api_key")
    args["swanlab_project"] = get("train.swanlab_project")
    args["swanlab_workspace"] = get("train.swanlab_workspace")
    args["swanlab_experiment_name"] = get("train.swanlab_experiment_name")
    args["swanlab_mode"] = get("train.swanlab_mode")

# eval config
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
args["val_size"] = get("train.val_size")
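Putting the web UI and runner pieces together: when the SwanLab checkbox is ticked, the generated training arguments gain roughly the following subset (a sketch; the values are hypothetical and all unrelated keys are omitted):

# SwanLab-related subset of the args dict built by _parse_train_args (hypothetical values).
swanlab_args = {
    "use_swanlab": True,
    "swanlab_project": "llamafactory",
    "swanlab_workspace": "",            # empty -> personal workspace
    "swanlab_experiment_name": "demo",  # hypothetical experiment name
    "swanlab_api_key": "",              # may be left empty if already logged in
    "swanlab_mode": "cloud",
}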