[model] Support Audio #6701

Open
wants to merge 2 commits into base: main
1 change: 1 addition & 0 deletions README.md
@@ -244,6 +244,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL/QVQ](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
1 change: 1 addition & 0 deletions README_zh.md
@@ -246,6 +246,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL/QVQ](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
46 changes: 46 additions & 0 deletions data/README.md
@@ -24,6 +24,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
"tools": "the column name in the dataset containing the tool description. (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)",
"videos": "the column name in the dataset containing the video inputs. (default: None)",
"audios": "the column name in the dataset containing the audio inputs. (default: None)",
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
@@ -150,6 +151,10 @@ An additional column `images` is required. Please refer to the [sharegpt](#share

An additional column `videos` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.

### Multimodal Audio Dataset

An additional column `audios` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.

## Sharegpt Format

### Supervised Fine-Tuning Dataset
Expand Down Expand Up @@ -374,6 +379,47 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
}
```

### Multimodal Audio Dataset

- [Example dataset](mllm_audio_demo.json)

Multimodal audio datasets require an `audios` column containing the paths to the input audio files.

The number of audio files should be identical to the number of `<audio>` tokens in the conversations.

```json
[
  {
    "conversations": [
      {
        "from": "human",
        "value": "<audio>human instruction"
      },
      {
        "from": "gpt",
        "value": "model response"
      }
    ],
    "audios": [
      "audio path (required)"
    ]
  }
]
```

Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:

```json
"dataset_name": {
  "file_name": "data.json",
  "formatting": "sharegpt",
  "columns": {
    "messages": "conversations",
    "audios": "audios"
  }
}
```
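Because the loader pairs each `<audio>` token with one entry of the `audios` list, a mismatched example will fail during preprocessing. A small sanity check (our sketch, not part of LLaMA-Factory) can catch this before training starts:

```python
import json


def validate_audio_dataset(path: str) -> None:
    """Check every example: the number of <audio> tokens in the
    conversations must equal the number of entries in `audios`."""
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    for i, example in enumerate(data):
        token_count = sum(
            turn["value"].count("<audio>") for turn in example["conversations"]
        )
        audio_count = len(example.get("audios", []))
        if token_count != audio_count:
            raise ValueError(
                f"example {i}: {token_count} <audio> token(s) "
                f"but {audio_count} audio file(s)"
            )
```

On a well-formed file the function returns silently; on a mismatch it names the offending example.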

### OpenAI Format

The OpenAI format is simply a special case of the sharegpt format, where the first message may be a system prompt.
47 changes: 47 additions & 0 deletions data/README_zh.md
@@ -24,6 +24,7 @@
"tools": "the column name in the dataset containing the tool description. (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)",
"videos": "the column name in the dataset containing the video inputs. (default: None)",
"audios": "the column name in the dataset containing the audio inputs. (default: None)",
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
"kto_tag": "the column name in the dataset containing the KTO tags. (default: None)"
@@ -150,6 +151,10 @@ KTO datasets require an additional `kto_tag` column. Please refer to [sharegpt](#s

Multimodal video datasets require an additional `videos` column. Please refer to [sharegpt](#sharegpt-格式) for details.

### Multimodal Audio Dataset

Multimodal audio datasets require an additional `audios` column. Please refer to [sharegpt](#sharegpt-格式) for details.

## Sharegpt Format

### Supervised Fine-Tuning Dataset
@@ -374,6 +379,48 @@ KTO datasets require an additional `kto_tag` column containing bool-typed human
}
```

### Multimodal Audio Dataset

- [Example dataset](mllm_audio_demo.json)

Multimodal audio datasets require an additional `audios` column containing the paths to the input audio files.

Note that the number of audio files must exactly match the number of `<audio>` tokens in the text.

```json
[
  {
    "conversations": [
      {
        "from": "human",
        "value": "<audio>human instruction"
      },
      {
        "from": "gpt",
        "value": "model response"
      }
    ],
    "audios": [
      "audio path (required)"
    ]
  }
]
```

Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:

```json
"dataset_name": {
  "file_name": "data.json",
  "formatting": "sharegpt",
  "columns": {
    "messages": "conversations",
    "audios": "audios"
  }
}
```


### OpenAI Format

The OpenAI format is simply a special case of the sharegpt format, where the first message may be a system prompt.
14 changes: 14 additions & 0 deletions data/dataset_info.json
@@ -38,6 +38,20 @@
      "assistant_tag": "assistant"
    }
  },
  "mllm_audio_demo": {
    "file_name": "mllm_audio_demo.json",
    "formatting": "sharegpt",
    "columns": {
      "messages": "messages",
      "audios": "audios"
    },
    "tags": {
      "role_tag": "role",
      "content_tag": "content",
      "user_tag": "user",
      "assistant_tag": "assistant"
    }
  },
  "mllm_video_demo": {
    "file_name": "mllm_video_demo.json",
    "formatting": "sharegpt",
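Each entry in `dataset_info.json`, like `mllm_audio_demo` above, maps logical fields to the columns actually present in the data file. A minimal illustrative reader for one such entry (our sketch only; the project's real parser handles many more fields, and the `conversations` fallback is an assumption):

```python
import json
from typing import Any, Dict, Optional


def resolve_columns(info_path: str, name: str) -> Dict[str, Optional[str]]:
    """Return the column mapping for one dataset entry; multimodal
    columns that are not mapped come back as None."""
    with open(info_path, encoding="utf-8") as f:
        info: Dict[str, Any] = json.load(f)
    columns = info[name].get("columns", {})
    return {
        "messages": columns.get("messages", "conversations"),  # assumed default
        "images": columns.get("images"),
        "videos": columns.get("videos"),
        "audios": columns.get("audios"),
    }
```

For the `mllm_audio_demo` entry this yields `messages` and `audios` mapped, with `images` and `videos` left as `None`.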
47 changes: 47 additions & 0 deletions data/mllm_audio_demo.json
@@ -0,0 +1,47 @@
[
  {
    "messages": [
      {
        "content": "<audio>What's that sound?",
        "role": "user"
      },
      {
        "content": "It is the sound of glass shattering.",
        "role": "assistant"
      }
    ],
    "audios": [
      "mllm_demo_data/1.mp3"
    ]
  },
  {
    "messages": [
      {
        "content": "<audio>What can you hear?",
        "role": "user"
      },
      {
        "content": "A woman is coughing.",
        "role": "assistant"
      }
    ],
    "audios": [
      "mllm_demo_data/2.wav"
    ]
  },
  {
    "messages": [
      {
        "content": "<audio>What does the person say?",
        "role": "user"
      },
      {
        "content": "Mister Quiller is the apostle of the middle classes and we are glad to welcome his gospel.",
        "role": "assistant"
      }
    ],
    "audios": [
      "mllm_demo_data/3.flac"
    ]
  }
]
Binary file added data/mllm_demo_data/1.mp3
Binary file added data/mllm_demo_data/2.wav
Binary file added data/mllm_demo_data/3.flac
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
transformers>=4.41.2,<=4.46.1
transformers>=4.41.2
datasets>=2.16.0,<=3.1.0
accelerate>=0.34.0,<=1.0.1
peft>=0.11.1,<=0.12.0
1 change: 1 addition & 0 deletions scripts/stat_utils/cal_ppl.py
@@ -49,6 +49,7 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor
"labels": feature["chosen_input_ids"] if self.train_on_prompt else feature["chosen_labels"],
"images": feature["images"],
"videos": feature["videos"],
"audios": feature["audios"],
}
)

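The one-line change above forwards the new `audios` column through `cal_ppl.py`'s pairwise collator exactly as `images` and `videos` are forwarded. Stripped down to that pattern (field names follow the diff; the class itself is an illustrative stand-in, and the real collator additionally delegates padding):

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence


@dataclass
class PairwiseFeatureCollator:
    """Select the chosen-side tensors and pass multimodal columns
    (images / videos / audios) through untouched."""

    train_on_prompt: bool = False

    def __call__(self, features: Sequence[Dict[str, Any]]) -> List[Dict[str, Any]]:
        batch = []
        for feature in features:
            batch.append(
                {
                    "input_ids": feature["chosen_input_ids"],
                    "labels": (
                        feature["chosen_input_ids"]
                        if self.train_on_prompt
                        else feature["chosen_labels"]
                    ),
                    "images": feature["images"],
                    "videos": feature["videos"],
                    "audios": feature["audios"],  # column added by this PR
                }
            )
        return batch
```

The point is that multimodal columns are not tensorized here; they ride along until the model plugin consumes them.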
1 change: 1 addition & 0 deletions setup.py
@@ -71,6 +71,7 @@ def get_console_scripts() -> List[str]:
"jsonschema_specifications",
"librosa",
],
"audio": ["torchaudio", "librosa", "pyctcdecode", "phonemizer", "kenlm"],
"modelscope": ["modelscope"],
"openmind": ["openmind"],
"swanlab": ["swanlab"],
4 changes: 3 additions & 1 deletion src/llamafactory/chat/base_engine.py
@@ -22,7 +22,7 @@
from vllm import AsyncLLMEngine

from ..data import Template
from ..data.mm_plugin import ImageInput, VideoInput
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


@@ -68,6 +68,7 @@ async def chat(
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
@@ -83,6 +84,7 @@ async def stream_chat(
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""
16 changes: 11 additions & 5 deletions src/llamafactory/chat/chat_model.py
@@ -27,7 +27,7 @@


if TYPE_CHECKING:
from ..data.mm_plugin import ImageInput, VideoInput
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from .base_engine import BaseEngine, Response


@@ -66,13 +66,14 @@ def chat(
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
)
return task.result()

@@ -83,12 +84,13 @@ async def achat(
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Asynchronously gets a list of responses of the chat model.
"""
return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)

def stream_chat(
self,
@@ -97,12 +99,13 @@ def stream_chat(
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> Generator[str, None, None]:
r"""
Gets the response token-by-token of the chat model.
"""
generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
while True:
try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -117,12 +120,15 @@ async def astream_chat(
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""
Asynchronously gets the response token-by-token of the chat model.
"""
async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
async for new_token in self.engine.stream_chat(
messages, system, tools, images, videos, audios, **input_kwargs
):
yield new_token

def get_scores(
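`ChatModel.stream_chat` above turns the async engine's token stream into a plain generator by submitting each `__anext__()` call to a background event loop with `run_coroutine_threadsafe`. The pattern in isolation (a sketch; `_astream` stands in for `engine.stream_chat`):

```python
import asyncio
from threading import Thread
from typing import AsyncGenerator, Generator

# Background event loop running in a daemon thread, as ChatModel does.
_loop = asyncio.new_event_loop()
Thread(target=_loop.run_forever, daemon=True).start()


async def _astream(text: str) -> AsyncGenerator[str, None]:
    """Stand-in for the async engine's token stream."""
    for token in text.split():
        await asyncio.sleep(0)
        yield token


def stream(text: str) -> Generator[str, None, None]:
    """Synchronous wrapper: pump the async generator one item at a
    time across the thread boundary."""
    agen = _astream(text)
    while True:
        task = asyncio.run_coroutine_threadsafe(agen.__anext__(), _loop)
        try:
            yield task.result()  # blocks until the loop produces a token
        except StopAsyncIteration:
            break
```

`StopAsyncIteration` raised inside the coroutine is re-raised by `task.result()`, which is how the wrapper knows the stream is exhausted; this mirrors the loop in `ChatModel.stream_chat`.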