diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 6d350a7321..44a3e36283 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -305,7 +305,37 @@ class BAdamArgument: @dataclass -class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument): +class SwanLabArguments: + use_swanlab: bool = field( + default=False, + metadata={"help": "Whether or not to use SwanLab (an experiment tracking and visualization tool)."}, + ) + swanlab_project: str = field( + default="llamafactory", + metadata={"help": "The project name in SwanLab."}, + ) + swanlab_workspace: Optional[str] = field( + default=None, + metadata={"help": "The workspace name in SwanLab."}, + ) + swanlab_experiment_name: Optional[str] = field( + default=None, + metadata={"help": "The experiment name in SwanLab."}, + ) + swanlab_mode: Literal["cloud", "local", "disabled"] = field( + default="cloud", + metadata={"help": "The mode of SwanLab."}, + ) + swanlab_api_key: Optional[str] = field( + default=None, + metadata={"help": "The API key for SwanLab."}, + ) + + +@dataclass +class FinetuningArguments( + FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments +): r""" Arguments pertaining to which techniques we are going to fine-tuning with. 
""" diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index 330de386b4..0ab177b42e 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -31,7 +31,7 @@ from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_equal_to_4_46 from ..callbacks import PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback if TYPE_CHECKING: @@ -106,6 +106,9 @@ def __init__( self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 3d007ae70b..1c6b5fe842 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -30,7 +30,7 @@ from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_equal_to_4_46 from ..callbacks import SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback if TYPE_CHECKING: @@ -101,6 +101,9 @@ def __init__( self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git 
a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py index 4ab7a11879..a60b7d7cda 100644 --- a/src/llamafactory/train/ppo/trainer.py +++ b/src/llamafactory/train/ppo/trainer.py @@ -40,7 +40,7 @@ from ...extras import logging from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm @@ -186,6 +186,9 @@ def __init__( self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None: r""" Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer. 
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 2e77ba4320..ab4b008171 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -21,7 +21,7 @@ from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than from ..callbacks import PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback if TYPE_CHECKING: @@ -56,6 +56,9 @@ def __init__( self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py index 4b7408377e..244f460e67 100644 --- a/src/llamafactory/train/rm/trainer.py +++ b/src/llamafactory/train/rm/trainer.py @@ -27,7 +27,7 @@ from ...extras import logging from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback if TYPE_CHECKING: @@ -68,6 +68,9 @@ def __init__( self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/sft/trainer.py 
b/src/llamafactory/train/sft/trainer.py index b10ebf6900..d1510bb2c4 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -29,7 +29,7 @@ from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than from ..callbacks import PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback if TYPE_CHECKING: @@ -71,6 +71,9 @@ def __init__( self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + if finetuning_args.use_swanlab: + self.add_callback(get_swanlab_callback(finetuning_args)) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 7d916ec156..5b8fb4037a 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -40,7 +40,7 @@ if TYPE_CHECKING: - from transformers import PreTrainedModel, Seq2SeqTrainingArguments + from transformers import PreTrainedModel, Seq2SeqTrainingArguments, TrainerCallback from trl import AutoModelForCausalLMWithValueHead from ..hparams import DataArguments @@ -457,3 +457,24 @@ def get_batch_logps( labels[labels == label_pad_token_id] = 0 # dummy token per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1) + + +def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback": + r""" + Gets the callback for logging to SwanLab. 
+ """
 + import swanlab
 + from swanlab.integration.transformers import SwanLabCallback
 +
 + if finetuning_args.swanlab_api_key is not None:
 + swanlab.login(api_key=finetuning_args.swanlab_api_key)
 +
 + swanlab_callback = SwanLabCallback(
 + project=finetuning_args.swanlab_project,
 + workspace=finetuning_args.swanlab_workspace,
 + experiment_name=finetuning_args.swanlab_experiment_name,
 + mode=finetuning_args.swanlab_mode,
 + config={"Framework": "🦙LLaMA Factory"},
 + )
 +
 + return swanlab_callback
diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index bd53d163e7..399823d84b 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -270,6 +270,28 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: ) ) + with gr.Accordion(open=False) as swanlab_tab: + with gr.Row(): + use_swanlab = gr.Checkbox() + swanlab_project = gr.Textbox(value="llamafactory", placeholder="Project name", interactive=True) + swanlab_experiment_name = gr.Textbox(value="", placeholder="Experiment name", interactive=True) + swanlab_workspace = gr.Textbox(value="", placeholder="Workspace name", interactive=True) + swanlab_api_key = gr.Textbox(value="", placeholder="API key", interactive=True) + swanlab_mode = gr.Dropdown(choices=["cloud", "local", "disabled"], value="cloud", interactive=True) + + input_elems.update({use_swanlab, swanlab_api_key, swanlab_project, swanlab_workspace, swanlab_experiment_name, swanlab_mode}) + elem_dict.update( + dict( + swanlab_tab=swanlab_tab, + use_swanlab=use_swanlab, + swanlab_api_key=swanlab_api_key, + swanlab_project=swanlab_project, + swanlab_workspace=swanlab_workspace, + swanlab_experiment_name=swanlab_experiment_name, + swanlab_mode=swanlab_mode, + ) + ) + with gr.Row(): cmd_preview_btn = gr.Button() arg_save_btn = gr.Button() diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index 
45b847b4cb..8b5baade55 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -1353,6 +1353,121 @@ "info": "비율-BAdam의 업데이트 비율.", }, }, + "swanlab_tab": { + "en": { + "label": "SwanLab configurations", + }, + "ru": { + "label": "Конфигурации SwanLab", + }, + "zh": { + "label": "SwanLab 参数设置", + }, + "ko": { + "label": "SwanLab 설정", + }, + }, + "use_swanlab": { + "en": { + "label": "Use SwanLab", + "info": "Enable SwanLab for experiment tracking and visualization.", + }, + "ru": { + "label": "Использовать SwanLab", + "info": "Включить SwanLab для отслеживания и визуализации экспериментов.", + }, + "zh": { + "label": "使用 SwanLab", + "info": "启用 SwanLab 进行实验跟踪和可视化。", + }, + "ko": { + "label": "SwanLab 사용", + "info": "SwanLab를 사용하여 실험을 추적하고 시각화합니다.", + }, + }, + "swanlab_api_key": { + "en": { + "label": "API Key(optional)", + "info": "API key for SwanLab. Once logged in, no need to login again in the programming environment.", + }, + "ru": { + "label": "API ключ(Необязательный)", + "info": "API ключ для SwanLab. После входа в программное окружение, нет необходимости входить снова.", + }, + "zh": { + "label": "API密钥(选填)", + "info": "用于在编程环境登录SwanLab,已登录则无需填写。", + }, + "ko": { + "label": "API 키(선택 사항)", + "info": "SwanLab의 API 키. 프로그래밍 환경에 로그인한 후 다시 로그인할 필요가 없습니다.", + }, + }, + "swanlab_project": { + "en": { + "label": "Project(optional)", + }, + "ru": { + "label": "Проект(Необязательный)", + }, + "zh": { + "label": "项目(选填)", + }, + "ko": { + "label": "프로젝트(선택 사항)", + }, + }, + "swanlab_workspace": { + "en": { + "label": "Workspace(optional)", + "info": "Workspace for SwanLab. 
If not filled, it defaults to the personal workspace.", + + }, + "ru": { + "label": "Рабочая область(Необязательный)", + "info": "Рабочая область SwanLab, если не заполнено, то по умолчанию в личной рабочей области.", + }, + "zh": { + "label": "Workspace(选填)", + "info": "SwanLab组织的工作区,如不填写则默认在个人工作区下", + }, + "ko": { + "label": "작업 영역(선택 사항)", + "info": "SwanLab 조직의 작업 영역, 비어 있으면 기본적으로 개인 작업 영역에 있습니다.", + }, + }, + "swanlab_experiment_name": { + "en": { + "label": "Experiment name (optional)", + }, + "ru": { + "label": "Имя эксперимента(Необязательный)", + }, + "zh": { + "label": "实验名(选填) ", + }, + "ko": { + "label": "실험 이름(선택 사항)", + }, + }, + "swanlab_mode": { + "en": { + "label": "Mode", + "info": "Cloud or offline version.", + }, + "ru": { + "label": "Режим", + "info": "Версия в облаке или локальная версия.", + }, + "zh": { + "label": "模式", + "info": "云端版或离线版", + }, + "ko": { + "label": "모드", + "info": "클라우드 버전 또는 오프라인 버전.", + }, + }, "cmd_preview_btn": { "en": { "value": "Preview command", diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index da0a9c7ebc..2b5c55b8a3 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -147,6 +147,7 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: report_to="all" if get("train.report_to") else "none", use_galore=get("train.use_galore"), use_badam=get("train.use_badam"), + use_swanlab=get("train.use_swanlab"), output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")), fp16=(get("train.compute_type") == "fp16"), bf16=(get("train.compute_type") == "bf16"), @@ -228,6 +229,15 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: args["badam_switch_interval"] = get("train.badam_switch_interval") args["badam_update_ratio"] = get("train.badam_update_ratio") + # swanlab config + if get("train.use_swanlab"): + args["swanlab_api_key"] = get("train.swanlab_api_key") + args["swanlab_project"] = 
get("train.swanlab_project") + args["swanlab_workspace"] = get("train.swanlab_workspace") + args["swanlab_experiment_name"] = get("train.swanlab_experiment_name") + args["swanlab_mode"] = get("train.swanlab_mode") + + # eval config if get("train.val_size") > 1e-6 and args["stage"] != "ppo": args["val_size"] = get("train.val_size")