
Commit

support report custom args
hiyouga committed Dec 21, 2024
1 parent 84cd118 · commit e2e0cc5
Showing 19 changed files with 155 additions and 116 deletions.
14 changes: 7 additions & 7 deletions README.md
@@ -4,7 +4,7 @@
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
[![Citation](https://img.shields.io/badge/citation-93-green)](#projects-using-llama-factory)
[![Citation](https://img.shields.io/badge/citation-196-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -13,6 +13,7 @@
[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![SageMaker](https://img.shields.io/badge/SageMaker-Open%20in%20AWS-blue)](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
[![GitCode](https://gitcode.com/zhengyaowei/LLaMA-Factory/star/badge.svg)](https://gitcode.com/zhengyaowei/LLaMA-Factory)

[![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)

@@ -87,18 +88,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

## Changelog

[24/12/21] We supported **[SwanLab](https://github.com/SwanHubX/SwanLab)** experiment tracking and visualization. See [this section](#use-swanlab-logger) for details.
[24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details.

[24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset.

[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.

<details><summary>Full Changelog</summary>

[24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.

[24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.

<details><summary>Full Changelog</summary>

[24/08/27] We supported **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.

[24/08/09] We supported **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
@@ -388,7 +389,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```

Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, swanlab, quality

> [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts.
@@ -642,8 +643,7 @@ To use [SwanLab](https://github.com/SwanHubX/SwanLab) for logging experimental r

```yaml
use_swanlab: true
swanlab_project: test_project # optional
swanlab_experiment_name: test_experiment # optional
swanlab_run_name: test_run # optional
```

When launching training tasks, you can log in to SwanLab in three ways:
15 changes: 7 additions & 8 deletions README_zh.md
@@ -4,7 +4,7 @@
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
[![Citation](https://img.shields.io/badge/citation-93-green)](#使用了-llama-factory-的项目)
[![Citation](https://img.shields.io/badge/citation-196-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -13,6 +13,7 @@
[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![SageMaker](https://img.shields.io/badge/SageMaker-Open%20in%20AWS-blue)](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
[![GitCode](https://gitcode.com/zhengyaowei/LLaMA-Factory/star/badge.svg)](https://gitcode.com/zhengyaowei/LLaMA-Factory)

[![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)

@@ -88,18 +89,18 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

## 更新日志

[24/12/21] 我们支持了 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-wb-面板)
[24/12/21] 我们支持了使用 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-swanlab-面板)

[24/11/27] 我们支持了 **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** 模型的微调和 **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** 数据集。

[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)

<details><summary>展开日志</summary>

[24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。

[24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。

<details><summary>展开日志</summary>

[24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `enable_liger_kernel: true` 来加速训练。

[24/08/09] 我们支持了 **[Adam-mini](https://github.com/zyushun/Adam-mini)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。
@@ -389,7 +390,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```

可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、quality
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、swanlab、quality

> [!TIP]
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
@@ -643,8 +644,7 @@ run_name: test_run # 可选

```yaml
use_swanlab: true
swanlab_project: test_run # 可选
swanlab_experiment_name: test_experiment # 可选
swanlab_run_name: test_run # 可选
```

在启动训练任务时,登录SwanLab账户有以下三种方式:
@@ -653,7 +653,6 @@ swanlab_experiment_name: test_experiment # 可选
方式二:将环境变量 `SWANLAB_API_KEY` 设置为你的 [API 密钥](https://swanlab.cn/settings)。
方式三:启动前使用 `swanlab login` 命令完成登录。


## 使用了 LLaMA Factory 的项目

如果您有项目希望添加至下述列表,请通过邮件联系或者创建一个 PR。
1 change: 1 addition & 0 deletions setup.py
@@ -61,6 +61,7 @@ def get_console_scripts() -> List[str]:
"qwen": ["transformers_stream_generator"],
"modelscope": ["modelscope"],
"openmind": ["openmind"],
"swanlab": ["swanlab"],
"dev": ["pre-commit", "ruff", "pytest"],
}

5 changes: 4 additions & 1 deletion src/llamafactory/chat/hf_engine.py
@@ -171,7 +171,10 @@ def _process_args(
elif not isinstance(value, torch.Tensor):
value = torch.tensor(value)

gen_kwargs[key] = value.to(dtype=model.dtype, device=model.device)
if torch.is_floating_point(value):
value = value.to(model.dtype)

gen_kwargs[key] = value.to(model.device)

return gen_kwargs, prompt_length

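The hf_engine.py hunk above narrows the generation-kwargs handling so that only floating-point tensors are cast to the model dtype, while integer tensors (token ids, attention masks) keep their dtype and are only moved to the model device. A minimal standalone sketch of that rule, with illustrative tensor names and an assumed bfloat16 model dtype:

```python
import torch

# Sketch of the patched rule: cast floats to the model dtype, never integers.
model_dtype, device = torch.bfloat16, "cpu"  # assumed values for illustration

gen_kwargs = {
    "attention_mask": torch.ones(1, 8, dtype=torch.long),  # integer tensor: dtype must survive
    "pixel_values": torch.rand(1, 3, 224, 224),            # float tensor: safe to downcast
}

for key, value in gen_kwargs.items():
    if not isinstance(value, torch.Tensor):
        value = torch.tensor(value)
    if torch.is_floating_point(value):      # only floats inherit the model dtype
        value = value.to(model_dtype)
    gen_kwargs[key] = value.to(device)      # everything moves to the model device

print(gen_kwargs["attention_mask"].dtype)  # torch.int64
print(gen_kwargs["pixel_values"].dtype)    # torch.bfloat16
```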
7 changes: 5 additions & 2 deletions src/llamafactory/hparams/data_args.py
@@ -15,8 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Literal, Optional
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional


@dataclass
@@ -161,3 +161,6 @@ def split_arg(arg):

if self.mask_history and self.train_on_prompt:
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")

def to_dict(self) -> Dict[str, Any]:
return asdict(self)
11 changes: 8 additions & 3 deletions src/llamafactory/hparams/finetuning_args.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import List, Literal, Optional
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Literal, Optional


@dataclass
@@ -318,7 +318,7 @@ class SwanLabArguments:
default=None,
metadata={"help": "The workspace name in SwanLab."},
)
swanlab_experiment_name: str = field(
swanlab_run_name: str = field(
default=None,
metadata={"help": "The experiment name in SwanLab."},
)
@@ -440,3 +440,8 @@ def split_arg(arg):

if self.pissa_init:
raise ValueError("`pissa_init` is only valid for LoRA training.")

def to_dict(self) -> Dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
return args
7 changes: 6 additions & 1 deletion src/llamafactory/hparams/model_args.py
@@ -16,7 +16,7 @@
# limitations under the License.

import json
from dataclasses import dataclass, field, fields
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union

import torch
@@ -344,3 +344,8 @@ def copyfrom(cls, source: "Self", **kwargs) -> "Self":
setattr(result, name, value)

return result

def to_dict(self) -> Dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
return args
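The three hparams diffs above (data_args.py, finetuning_args.py, model_args.py) each add a `to_dict()` built on `dataclasses.asdict`, with the latter two redacting fields whose names end in `api_key` or `token` before the values reach an experiment tracker. A reduced, self-contained sketch of the redaction pattern (the dataclass and its fields are stand-ins, not the real argument classes):

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


# Stand-in dataclass; only the to_dict() redaction rule mirrors the diffs above.
@dataclass
class TrackedArguments:
    swanlab_run_name: Optional[str] = None
    swanlab_api_key: Optional[str] = None
    hf_hub_token: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        args = asdict(self)
        # replace secrets with a placeholder such as <SWANLAB_API_KEY> before logging
        return {
            k: f"<{k.upper()}>" if k.endswith(("api_key", "token")) else v
            for k, v in args.items()
        }


args = TrackedArguments(swanlab_run_name="test_run", swanlab_api_key="secret", hf_hub_token="hf_secret")
print(args.to_dict())
# {'swanlab_run_name': 'test_run', 'swanlab_api_key': '<SWANLAB_API_KEY>', 'hf_hub_token': '<HF_HUB_TOKEN>'}
```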
57 changes: 51 additions & 6 deletions src/llamafactory/train/callbacks.py
@@ -42,10 +42,13 @@
from safetensors import safe_open
from safetensors.torch import save_file


if TYPE_CHECKING:
from transformers import TrainerControl, TrainerState, TrainingArguments
from trl import AutoModelForCausalLMWithValueHead

from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)

@@ -101,9 +104,6 @@ class FixValueHeadModelCallback(TrainerCallback):

@override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
"""
if args.should_save:
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
fix_valuehead_checkpoint(
@@ -138,9 +138,6 @@ class PissaConvertCallback(TrainerCallback):

@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
if args.should_save:
model = kwargs.pop("model")
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
@@ -348,3 +345,51 @@ def on_prediction_step(
remaining_time=self.remaining_time,
)
self.thread_pool.submit(self._write_log, args.output_dir, logs)


class ReporterCallback(TrainerCallback):
r"""
A callback for reporting training status to external logger.
"""

def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.model_args = model_args
self.data_args = data_args
self.finetuning_args = finetuning_args
self.generating_args = generating_args
os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT", "llamafactory")

@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if not state.is_world_process_zero:
return

if "wandb" in args.report_to:
import wandb

wandb.config.update(
{
"model_args": self.model_args.to_dict(),
"data_args": self.data_args.to_dict(),
"finetuning_args": self.finetuning_args.to_dict(),
"generating_args": self.generating_args.to_dict(),
}
)

if self.finetuning_args.use_swanlab:
import swanlab

swanlab.config.update(
{
"model_args": self.model_args.to_dict(),
"data_args": self.data_args.to_dict(),
"finetuning_args": self.finetuning_args.to_dict(),
"generating_args": self.generating_args.to_dict(),
}
)
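The new ReporterCallback gathers all four argument groups once at train begin, from the main process only, and pushes them to whichever tracker is active (wandb via `report_to`, SwanLab via `use_swanlab`); this replaces the per-trainer `get_swanlab_callback` registrations removed in the trainer diffs below. A reduced, self-contained sketch of that control flow, using a stand-in tracker object instead of the real `wandb.config` / `swanlab.config`:

```python
from typing import Any, Dict, List


class FakeTrackerConfig:
    """Stand-in for wandb.config / swanlab.config."""

    def __init__(self) -> None:
        self.values: Dict[str, Any] = {}

    def update(self, new_values: Dict[str, Any]) -> None:
        self.values.update(new_values)


def report_run_config(
    is_world_process_zero: bool,
    report_to: List[str],
    use_swanlab: bool,
    run_config: Dict[str, Any],
    wandb_config: FakeTrackerConfig,
    swanlab_config: FakeTrackerConfig,
) -> None:
    if not is_world_process_zero:  # only rank 0 reports, as in on_train_begin above
        return
    if "wandb" in report_to:
        wandb_config.update(run_config)
    if use_swanlab:
        swanlab_config.update(run_config)


wandb_config, swanlab_config = FakeTrackerConfig(), FakeTrackerConfig()
run_config = {"finetuning_args": {"use_swanlab": True}, "data_args": {"cutoff_len": 2048}}
report_run_config(True, ["wandb"], True, run_config, wandb_config, swanlab_config)
print(wandb_config.values == swanlab_config.values == run_config)  # True
```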
10 changes: 2 additions & 8 deletions src/llamafactory/train/dpo/trainer.py
@@ -30,8 +30,8 @@

from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps


if TYPE_CHECKING:
@@ -97,18 +97,12 @@ def __init__(
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))

if finetuning_args.pissa_convert:
self.callback_handler.add_callback(PissaConvertCallback)

if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore

self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
self.add_callback(get_swanlab_callback(finetuning_args))

@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
5 changes: 1 addition & 4 deletions src/llamafactory/train/kto/trainer.py
@@ -30,7 +30,7 @@
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, get_swanlab_callback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps


if TYPE_CHECKING:
@@ -101,9 +101,6 @@ def __init__(
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
self.add_callback(get_swanlab_callback(finetuning_args))

@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
5 changes: 1 addition & 4 deletions src/llamafactory/train/ppo/trainer.py
@@ -40,7 +40,7 @@
from ...extras import logging
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm


@@ -186,9 +186,6 @@ def __init__(
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
self.add_callback(get_swanlab_callback(finetuning_args))

def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
10 changes: 2 additions & 8 deletions src/llamafactory/train/pt/trainer.py
@@ -20,8 +20,8 @@
from typing_extensions import override

from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_swanlab_callback
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
@@ -47,18 +47,12 @@ def __init__(
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))

if finetuning_args.pissa_convert:
self.add_callback(PissaConvertCallback)

if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore

self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

if finetuning_args.use_swanlab:
self.add_callback(get_swanlab_callback(finetuning_args))

@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
