From 84cd1188ac03c165e1a626db297936c2458627d6 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 19 Dec 2024 13:24:41 +0000 Subject: [PATCH 1/2] fix paligemma infer --- src/llamafactory/chat/hf_engine.py | 2 +- src/llamafactory/webui/locales.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 5a5c00c881..61f7f9a6a9 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -171,7 +171,7 @@ def _process_args( elif not isinstance(value, torch.Tensor): value = torch.tensor(value) - gen_kwargs[key] = value.to(model.device) + gen_kwargs[key] = value.to(dtype=model.dtype, device=model.device) return gen_kwargs, prompt_length diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index 8b5baade55..64cdf1f52c 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -1119,7 +1119,7 @@ "info": "Нормализация оценок в тренировке PPO.", }, "zh": { - "label": "奖励模型", + "label": "归一化分数", "info": "PPO 训练中归一化奖励分数。", }, "ko": { From 5111cac6f8e7b77ef1ca1ff967734cfe1d6785f4 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 19 Dec 2024 14:57:09 +0000 Subject: [PATCH 2/2] support report custom args --- .gitignore | 1 + README.md | 14 ++--- README_zh.md | 15 +++-- setup.py | 1 + src/llamafactory/chat/hf_engine.py | 5 +- src/llamafactory/hparams/data_args.py | 7 ++- src/llamafactory/hparams/finetuning_args.py | 11 +++- src/llamafactory/hparams/model_args.py | 7 ++- src/llamafactory/train/callbacks.py | 57 +++++++++++++++++-- src/llamafactory/train/dpo/trainer.py | 10 +--- src/llamafactory/train/kto/trainer.py | 5 +- src/llamafactory/train/ppo/trainer.py | 5 +- src/llamafactory/train/pt/trainer.py | 10 +--- src/llamafactory/train/rm/trainer.py | 10 +--- src/llamafactory/train/sft/trainer.py | 10 +--- src/llamafactory/train/trainer_utils.py | 7 +-- src/llamafactory/train/tuner.py | 11 +++- src/llamafactory/webui/components/train.py | 18 +++--- src/llamafactory/webui/locales.py | 63 ++++++++++----------- src/llamafactory/webui/runner.py | 5 +- 20 files changed, 156 insertions(+), 116 deletions(-) diff --git a/.gitignore b/.gitignore index 88c36ca2e0..3e1f97dd7c 100644 --- a/.gitignore +++ b/.gitignore @@ -171,4 +171,5 @@ config/ saves/ output/ wandb/ +swanlog/ generated_predictions.jsonl diff --git a/README.md b/README.md index 515ba5ffa7..5b8b258f2a 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE) [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main) [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/) -[![Citation](https://img.shields.io/badge/citation-93-green)](#projects-using-llama-factory) +[![Citation](https://img.shields.io/badge/citation-196-green)](https://scholar.google.com/scholar?cites=12620864006390196564) [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) @@ -13,6 +13,7 @@ [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board) [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board) [![SageMaker](https://img.shields.io/badge/SageMaker-Open%20in%20AWS-blue)](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) +[![GitCode](https://gitcode.com/zhengyaowei/LLaMA-Factory/star/badge.svg)](https://gitcode.com/zhengyaowei/LLaMA-Factory) [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535) @@ -87,18 +88,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog -[24/12/21] We supported **[SwanLab](https://github.com/SwanHubX/SwanLab)** experiment tracking and visualization. See [this section](#use-swanlab-logger) for details. +[24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details. [24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset. [24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage. +
Full Changelog + [24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models. [24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR. -
Full Changelog - [24/08/27] We supported **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training. [24/08/09] We supported **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR. @@ -388,7 +389,7 @@ cd LLaMA-Factory pip install -e ".[torch,metrics]" ``` -Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality +Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, swanlab, quality > [!TIP] > Use `pip install --no-deps -e .` to resolve package conflicts. @@ -642,8 +643,7 @@ To use [SwanLab](https://github.com/SwanHubX/SwanLab) for logging experimental r ```yaml use_swanlab: true -swanlab_project: test_project # optional -swanlab_experiment_name: test_experiment # optional +swanlab_run_name: test_run # optional ``` When launching training tasks, you can log in to SwanLab in three ways: diff --git a/README_zh.md b/README_zh.md index dcfb489c06..72f3be0a43 100644 --- a/README_zh.md +++ b/README_zh.md @@ -4,7 +4,7 @@ [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE) [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main) [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/) -[![Citation](https://img.shields.io/badge/citation-93-green)](#使用了-llama-factory-的项目) +[![Citation](https://img.shields.io/badge/citation-196-green)](https://scholar.google.com/scholar?cites=12620864006390196564) [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) @@ -13,6 +13,7 @@ [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board) [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board) [![SageMaker](https://img.shields.io/badge/SageMaker-Open%20in%20AWS-blue)](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) +[![GitCode](https://gitcode.com/zhengyaowei/LLaMA-Factory/star/badge.svg)](https://gitcode.com/zhengyaowei/LLaMA-Factory) [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535) @@ -88,18 +89,18 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 ## 更新日志 -[24/12/21] 我们支持了 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-wb-面板)。 +[24/12/21] 我们支持了使用 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-swanlab-面板)。 [24/11/27] 我们支持了 **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** 模型的微调和 **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** 数据集。 [24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。 +
展开日志 + [24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。 [24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。 -
展开日志 - [24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `enable_liger_kernel: true` 来加速训练。 [24/08/09] 我们支持了 **[Adam-mini](https://github.com/zyushun/Adam-mini)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。 @@ -389,7 +390,7 @@ cd LLaMA-Factory pip install -e ".[torch,metrics]" ``` -可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、quality +可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、swanlab、quality > [!TIP] > 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。 @@ -643,8 +644,7 @@ run_name: test_run # 可选 ```yaml use_swanlab: true -swanlab_project: test_run # 可选 -swanlab_experiment_name: test_experiment # 可选 +swanlab_run_name: test_run # 可选 ``` 在启动训练任务时,登录SwanLab账户有以下三种方式: @@ -653,7 +653,6 @@ swanlab_experiment_name: test_experiment # 可选 方式二:将环境变量 `SWANLAB_API_KEY` 设置为你的 [API 密钥](https://swanlab.cn/settings)。 方式三:启动前使用 `swanlab login` 命令完成登录。 - ## 使用了 LLaMA Factory 的项目 如果您有项目希望添加至下述列表,请通过邮件联系或者创建一个 PR。 diff --git a/setup.py b/setup.py index 862e9b943c..bf7662c844 100644 --- a/setup.py +++ b/setup.py @@ -61,6 +61,7 @@ def get_console_scripts() -> List[str]: "qwen": ["transformers_stream_generator"], "modelscope": ["modelscope"], "openmind": ["openmind"], + "swanlab": ["swanlab"], "dev": ["pre-commit", "ruff", "pytest"], } diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 61f7f9a6a9..d001386bc6 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -171,7 +171,10 @@ def _process_args( elif not isinstance(value, torch.Tensor): value = torch.tensor(value) - gen_kwargs[key] = value.to(dtype=model.dtype, device=model.device) + if torch.is_floating_point(value): + value = value.to(model.dtype) + + gen_kwargs[key] = value.to(model.device) return gen_kwargs, prompt_length diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 2d7e30df96..a33e626773 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -15,8 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, field -from typing import Literal, Optional +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, Literal, Optional @dataclass @@ -161,3 +161,6 @@ def split_arg(arg): if self.mask_history and self.train_on_prompt: raise ValueError("`mask_history` is incompatible with `train_on_prompt`.") + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 44a3e36283..29e91a27c3 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, field -from typing import List, Literal, Optional +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, List, Literal, Optional @dataclass @@ -318,7 +318,7 @@ class SwanLabArguments: default=None, metadata={"help": "The workspace name in SwanLab."}, ) - swanlab_experiment_name: str = field( + swanlab_run_name: str = field( default=None, metadata={"help": "The experiment name in SwanLab."}, ) @@ -440,3 +440,8 @@ def split_arg(arg): if self.pissa_init: raise ValueError("`pissa_init` is only valid for LoRA training.") + + def to_dict(self) -> Dict[str, Any]: + args = asdict(self) + args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()} + return args diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 6b25ea1688..bf3cf7f570 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -16,7 +16,7 @@ # limitations under the License. import json -from dataclasses import dataclass, field, fields +from dataclasses import asdict, dataclass, field, fields from typing import Any, Dict, Literal, Optional, Union import torch @@ -344,3 +344,8 @@ def copyfrom(cls, source: "Self", **kwargs) -> "Self": setattr(result, name, value) return result + + def to_dict(self) -> Dict[str, Any]: + args = asdict(self) + args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()} + return args diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index a2885cc611..189c753377 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -42,10 +42,13 @@ from safetensors import safe_open from safetensors.torch import save_file + if TYPE_CHECKING: from transformers import TrainerControl, TrainerState, TrainingArguments from trl import AutoModelForCausalLMWithValueHead + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + logger = logging.get_logger(__name__) @@ -101,9 +104,6 @@ class FixValueHeadModelCallback(TrainerCallback): @override def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): - r""" - Event called after a checkpoint save. - """ if args.should_save: output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}") fix_valuehead_checkpoint( @@ -138,9 +138,6 @@ class PissaConvertCallback(TrainerCallback): @override def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): - r""" - Event called at the beginning of training. - """ if args.should_save: model = kwargs.pop("model") pissa_init_dir = os.path.join(args.output_dir, "pissa_init") @@ -348,3 +345,51 @@ def on_prediction_step( remaining_time=self.remaining_time, ) self.thread_pool.submit(self._write_log, args.output_dir, logs) + + +class ReporterCallback(TrainerCallback): + r""" + A callback for reporting training status to external logger. + """ + + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + self.model_args = model_args + self.data_args = data_args + self.finetuning_args = finetuning_args + self.generating_args = generating_args + os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT", "llamafactory") + + @override + def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + if not state.is_world_process_zero: + return + + if "wandb" in args.report_to: + import wandb + + wandb.config.update( + { + "model_args": self.model_args.to_dict(), + "data_args": self.data_args.to_dict(), + "finetuning_args": self.finetuning_args.to_dict(), + "generating_args": self.generating_args.to_dict(), + } + ) + + if self.finetuning_args.use_swanlab: + import swanlab + + swanlab.config.update( + { + "model_args": self.model_args.to_dict(), + "data_args": self.data_args.to_dict(), + "finetuning_args": self.finetuning_args.to_dict(), + "generating_args": self.generating_args.to_dict(), + } + ) diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index 0ab177b42e..9d1e010477 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -30,8 +30,8 @@ from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_equal_to_4_46 -from ..callbacks import PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback +from ..callbacks import SaveProcessorCallback +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps if TYPE_CHECKING: @@ -97,18 +97,12 @@ def __init__( if processor is not None: self.add_callback(SaveProcessorCallback(processor)) - if finetuning_args.pissa_convert: - self.callback_handler.add_callback(PissaConvertCallback) - if finetuning_args.use_badam: from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) - if finetuning_args.use_swanlab: - self.add_callback(get_swanlab_callback(finetuning_args)) - @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 1c6b5fe842..3d007ae70b 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -30,7 +30,7 @@ from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_equal_to_4_46 from ..callbacks import SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps if TYPE_CHECKING: @@ -101,9 +101,6 @@ def __init__( self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) - if finetuning_args.use_swanlab: - self.add_callback(get_swanlab_callback(finetuning_args)) - @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py index a60b7d7cda..4ab7a11879 100644 --- a/src/llamafactory/train/ppo/trainer.py +++ b/src/llamafactory/train/ppo/trainer.py @@ -40,7 +40,7 @@ from ...extras import logging from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm @@ -186,9 +186,6 @@ def __init__( self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) - if finetuning_args.use_swanlab: - self.add_callback(get_swanlab_callback(finetuning_args)) - def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None: r""" Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer. diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index ab4b008171..445462b944 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -20,8 +20,8 @@ from typing_extensions import override from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than -from ..callbacks import PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback +from ..callbacks import SaveProcessorCallback +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler if TYPE_CHECKING: @@ -47,18 +47,12 @@ def __init__( if processor is not None: self.add_callback(SaveProcessorCallback(processor)) - if finetuning_args.pissa_convert: - self.add_callback(PissaConvertCallback) - if finetuning_args.use_badam: from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) - if finetuning_args.use_swanlab: - self.add_callback(get_swanlab_callback(finetuning_args)) - @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py index 244f460e67..347cae9ba4 100644 --- a/src/llamafactory/train/rm/trainer.py +++ b/src/llamafactory/train/rm/trainer.py @@ -26,8 +26,8 @@ from ...extras import logging from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than -from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback +from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler if TYPE_CHECKING: @@ -59,18 +59,12 @@ def __init__( if processor is not None: self.add_callback(SaveProcessorCallback(processor)) - if finetuning_args.pissa_convert: - self.add_callback(PissaConvertCallback) - if finetuning_args.use_badam: from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) - if finetuning_args.use_swanlab: - self.add_callback(get_swanlab_callback(finetuning_args)) - @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index d1510bb2c4..6ba758cd64 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -28,8 +28,8 @@ from ...extras import logging from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than -from ..callbacks import PissaConvertCallback, SaveProcessorCallback -from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback +from ..callbacks import SaveProcessorCallback +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler if TYPE_CHECKING: @@ -62,18 +62,12 @@ def __init__( if processor is not None: self.add_callback(SaveProcessorCallback(processor)) - if finetuning_args.pissa_convert: - self.add_callback(PissaConvertCallback) - if finetuning_args.use_badam: from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) - if finetuning_args.use_swanlab: - self.add_callback(get_swanlab_callback(finetuning_args)) - @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 5b8fb4037a..eb2421ce79 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -472,9 +472,8 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall swanlab_callback = SwanLabCallback( project=finetuning_args.swanlab_project, workspace=finetuning_args.swanlab_workspace, - experiment_name=finetuning_args.swanlab_experiment_name, + experiment_name=finetuning_args.swanlab_run_name, mode=finetuning_args.swanlab_mode, - config={"Framework": "🦙LLaMA Factory"}, + config={"Framework": "🦙LlamaFactory"}, ) - - return swanlab_callback \ No newline at end of file + return swanlab_callback diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 14cc206123..6c79320e7c 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -24,13 +24,14 @@ from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..hparams import get_infer_args, get_train_args from ..model import load_model, load_tokenizer -from .callbacks import LogCallback +from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback from .dpo import run_dpo from .kto import run_kto from .ppo import run_ppo from .pt import run_pt from .rm import run_rm from .sft import run_sft +from .trainer_utils import get_swanlab_callback if TYPE_CHECKING: @@ -44,6 +45,14 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb callbacks.append(LogCallback()) model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) + if finetuning_args.pissa_convert: + callbacks.append(PissaConvertCallback()) + + if finetuning_args.use_swanlab: + callbacks.append(get_swanlab_callback(finetuning_args)) + + callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last + if finetuning_args.stage == "pt": run_pt(model_args, data_args, training_args, finetuning_args, callbacks) elif finetuning_args.stage == "sft": diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index 399823d84b..05fa810b1a 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -273,21 +273,23 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Accordion(open=False) as swanlab_tab: with gr.Row(): use_swanlab = gr.Checkbox() - swanlab_project = gr.Textbox(value="llamafactory", placeholder="Project name", interactive=True) - swanlab_experiment_name = gr.Textbox(value="", placeholder="Experiment name", interactive=True) - swanlab_workspace = gr.Textbox(value="", placeholder="Workspace name", interactive=True) - swanlab_api_key = gr.Textbox(value="", placeholder="API key", interactive=True) - swanlab_mode = gr.Dropdown(choices=["cloud", "local", "disabled"], value="cloud", interactive=True) + swanlab_project = gr.Textbox(value="llamafactory") + swanlab_run_name = gr.Textbox() + swanlab_workspace = gr.Textbox() + swanlab_api_key = gr.Textbox() + swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud") - input_elems.update({use_swanlab, swanlab_api_key, swanlab_project, swanlab_workspace, swanlab_experiment_name, swanlab_mode}) + input_elems.update( + {use_swanlab, swanlab_project, swanlab_run_name, swanlab_workspace, swanlab_api_key, swanlab_mode} + ) elem_dict.update( dict( swanlab_tab=swanlab_tab, use_swanlab=use_swanlab, - swanlab_api_key=swanlab_api_key, swanlab_project=swanlab_project, + swanlab_run_name=swanlab_run_name, swanlab_workspace=swanlab_workspace, - swanlab_experiment_name=swanlab_experiment_name, + swanlab_api_key=swanlab_api_key, swanlab_mode=swanlab_mode, ) ) diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index 64cdf1f52c..9b78e6e9ca 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -1385,86 +1385,85 @@ "info": "SwanLab를 사용하여 실험을 추적하고 시각화합니다.", }, }, - "swanlab_api_key": { + "swanlab_project": { "en": { - "label": "API Key(optional)", - "info": "API key for SwanLab. Once logged in, no need to login again in the programming environment.", + "label": "SwanLab project", }, "ru": { - "label": "API ключ(Необязательный)", - "info": "API ключ для SwanLab. После входа в программное окружение, нет необходимости входить снова.", + "label": "SwanLab Проект", }, "zh": { - "label": "API密钥(选填)", - "info": "用于在编程环境登录SwanLab,已登录则无需填写。", + "label": "SwanLab 项目名", }, "ko": { - "label": "API 키(선택 사항)", - "info": "SwanLab의 API 키. 프로그래밍 환경에 로그인한 후 다시 로그인할 필요가 없습니다.", + "label": "SwanLab 프로젝트", }, }, - "swanlab_project": { + "swanlab_run_name": { "en": { - "label": "Project(optional)", + "label": "SwanLab experiment name (optional)", }, "ru": { - "label": "Проект(Необязательный)", + "label": "SwanLab Имя эксперимента (опционально)", }, "zh": { - "label": "项目(选填)", + "label": "SwanLab 实验名(非必填)", }, "ko": { - "label": "프로젝트(선택 사항)", + "label": "SwanLab 실험 이름 (선택 사항)", }, }, "swanlab_workspace": { "en": { - "label": "Workspace(optional)", - "info": "Workspace for SwanLab. If not filled, it defaults to the personal workspace.", - + "label": "SwanLab workspace (optional)", + "info": "Workspace for SwanLab. Defaults to the personal workspace.", }, "ru": { - "label": "Рабочая область(Необязательный)", + "label": "SwanLab Рабочая область (опционально)", "info": "Рабочая область SwanLab, если не заполнено, то по умолчанию в личной рабочей области.", }, "zh": { - "label": "Workspace(选填)", - "info": "SwanLab组织的工作区,如不填写则默认在个人工作区下", + "label": "SwanLab 工作区(非必填)", + "info": "SwanLab 的工作区,默认在个人工作区下。", }, "ko": { - "label": "작업 영역(선택 사항)", + "label": "SwanLab 작업 영역 (선택 사항)", "info": "SwanLab 조직의 작업 영역, 비어 있으면 기본적으로 개인 작업 영역에 있습니다.", }, }, - "swanlab_experiment_name": { + "swanlab_api_key": { "en": { - "label": "Experiment name (optional)", + "label": "SwanLab API key (optional)", + "info": "API key for SwanLab.", }, "ru": { - "label": "Имя эксперимента(Необязательный)", + "label": "SwanLab API ключ (опционально)", + "info": "API ключ для SwanLab.", }, "zh": { - "label": "实验名(选填) ", + "label": "SwanLab API密钥(非必填)", + "info": "用于在编程环境登录 SwanLab,已登录则无需填写。", }, "ko": { - "label": "실험 이름(선택 사항)", + "label": "SwanLab API 키 (선택 사항)", + "info": "SwanLab의 API 키.", }, }, "swanlab_mode": { "en": { - "label": "Mode", - "info": "Cloud or offline version.", + "label": "SwanLab mode", + "info": "Cloud or offline version.", }, "ru": { - "label": "Режим", + "label": "SwanLab Режим", "info": "Версия в облаке или локальная версия.", }, "zh": { - "label": "模式", - "info": "云端版或离线版", + "label": "SwanLab 模式", + "info": "使用云端版或离线版 SwanLab。", }, "ko": { - "label": "모드", + "label": "SwanLab 모드", "info": "클라우드 버전 또는 오프라인 버전.", }, }, diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 2b5c55b8a3..f5aecaeb32 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -231,12 +231,11 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: # swanlab config if get("train.use_swanlab"): - args["swanlab_api_key"] = get("train.swanlab_api_key") args["swanlab_project"] = get("train.swanlab_project") + args["swanlab_run_name"] = get("train.swanlab_run_name") args["swanlab_workspace"] = get("train.swanlab_workspace") - args["swanlab_experiment_name"] = get("train.swanlab_experiment_name") + args["swanlab_api_key"] = get("train.swanlab_api_key") args["swanlab_mode"] = get("train.swanlab_mode") - # eval config if get("train.val_size") > 1e-6 and args["stage"] != "ppo":