Skip to content

Commit

Permalink
Merge pull request #126 from StevenTang1998/main
Browse files Browse the repository at this point in the history
Update
  • Loading branch information
StevenTang1998 authored Apr 15, 2021
2 parents 3a69d3a + dc11775 commit d04499f
Show file tree
Hide file tree
Showing 12 changed files with 484 additions and 242 deletions.
117 changes: 111 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ We provide the support for 9 benchmark text generation datasets. A user can appl
<b>Figure</b>: The Overall Architecture of TextBox
</p>


## Feature

- **Unified and modularized framework.** TextBox is built upon PyTorch and designed to be highly modularized, by decoupling diverse models into a set of highly reusable modules.
Expand Down Expand Up @@ -131,7 +130,7 @@ python run_textbox.py --model=GPT2 --dataset=COCO \
--pretrained_model_path=pretrained_model/gpt2
```

### **Train with Distributed Data Parallel**
### **Train with Distributed Data Parallel (DDP)**

TextBox supports to train models with multiple GPUs conveniently. You don't need to modify the model, just run the following command:

Expand Down Expand Up @@ -419,7 +418,7 @@ NLL, BLEU and SBLEU on test dataset:
| **MaskGAN** | 509.58 | 56.61 | 21.41 | 4.49 | 0.86 | 92.09 | 77.88 | 59.62 | 42.36 |
| **GPT-2** | 348.67 | 72.52 | 41.75 | 15.40 | 4.22 | 86.21 | 58.26 | 30.03 | 12.56 |

Part of generated examples (with `max_length` 100):
Part of generated examples (with `max_length=100`):

<table align="center">
<thead>
Expand All @@ -446,7 +445,7 @@ Part of generated examples (with `max_length` 100):

#### GigaWord (Summarization)

ROUGE metric on test dataset using beam search (with `beam_size` 5):
ROUGE metric on test dataset using beam search (with `beam_size=5`):

<table align="center">
<thead>
Expand Down Expand Up @@ -486,8 +485,19 @@ ROUGE metric on test dataset using beam search (with `beam_size` 5):
<td align="center">40.06</td>
<td align="center">26.21</td>
</tr>
<td align="center"><strong>ProphetNet</strong></td>
<td align="center">38.49</td>
<td align="center">18.41</td>
<td align="center">39.84</td>
<td align="center">26.12</td>
</tr>
<td align="center"><strong>T5</strong></td>
<td align="center">38.83</td>
<td align="center">19.68</td>
<td align="center">40.76</td>
<td align="center">26.73</td>
</tr>
</tbody></table>

Part of generated examples:

<table align="center">
Expand All @@ -508,11 +518,23 @@ Part of generated examples:
<td align="center"><b>Transformer</b></td>
<td>nec computer to join forces in chip sales</td>
</tr>
<td align="center"><b>BART</b></td>
<td>nec computer corp.</td>
</tr>
<td align="center"><b>BERT2BERT</b></td>
<td>nec computer form alliance for supercomputer sales</td>
</tr>
<td align="center"><b>ProphetNet</b></td>
<td>nec computer to join forces in supercomputer sales</td>
</tr>
<td align="center"><b>T5</b></td>
<td>nec computer to join forces in supercomputer sales</td>
</tr>
</tbody></table>

#### IWSLT2014 German-English (Translation)

BLEU metric on test dataset with three decoding strategies: top-k sampling, greedy search and beam search (with `beam_size` 5):
BLEU metric on test dataset with three decoding strategies: top-k sampling, greedy search and beam search (with `beam_size=5`):

<table align="center">
<thead>
Expand Down Expand Up @@ -606,6 +628,87 @@ Part of generated examples:
</tr>
</tbody></table>

#### Persona Chat (Dialogue)

BLEU and distinct metrics on test dataset using beam search (with `beam_size=5`):

<table align="center">
<thead>
<tr>
<th align="center">Model</th>
<th align="center">Distinct-1</th>
<th align="center">Distinct-2</th>
<th align="center">BLEU-1</th>
<th align="center">BLEU-2</th>
<th align="center">BLEU-3</th>
<th align="center">BLEU-4</th>
</tr>
</thead>
<tbody><tr>
<td align="center"><strong>RNN with Attention</strong></td>
<td align="center">0.24</td>
<td align="center">0.72</td>
<td align="center">17.51</td>
<td align="center">4.65</td>
<td align="center">2.11</td>
<td align="center">1.47</td>
</tr>
<tr>
<td align="center"><strong>Transformer</strong></td>
<td align="center">0.38</td>
<td align="center">2.28</td>
<td align="center">17.29</td>
<td align="center">4.85</td>
<td align="center">2.32</td>
<td align="center">1.65</td>
</tr>
<tr>
<td align="center"><strong>HRED</strong></td>
<td align="center">0.22</td>
<td align="center">0.63</td>
<td align="center">17.29</td>
<td align="center">4.72</td>
<td align="center">2.20</td>
<td align="center">1.60</td>
</tr>
</tbody></table>

#### Amazon Electronic (Attribute to text)

BLEU and distinct metrics on test dataset using beam search (with `beam_size=5`):

<table align="center">
<thead>
<tr>
<th align="center">Model</th>
<th align="center">Distinct-1</th>
<th align="center">Distinct-2</th>
<th align="center">BLEU-1</th>
<th align="center">BLEU-2</th>
<th align="center">BLEU-3</th>
<th align="center">BLEU-4</th>
</tr>
</thead>
<tbody><tr>
<td align="center"><strong>Context2Seq</strong></td>
<td align="center">0.07</td>
<td align="center">0.39</td>
<td align="center">17.21</td>
<td align="center">2.80</td>
<td align="center">0.83</td>
<td align="center">0.43</td>
</tr>
<tr>
<td align="center"><strong>Attr2Seq</strong></td>
<td align="center">0.14</td>
<td align="center">2.81</td>
<td align="center">17.14</td>
<td align="center">2.81</td>
<td align="center">0.87</td>
<td align="center">0.48</td>
</tr>
</tbody></table>

## Releases

| Releases | Date | Features |
Expand All @@ -621,6 +724,8 @@ We welcome all contributions from bug fixes to new features and extensions.

We expect all contributions discussed in the issue tracker and going through PRs.

We thank [@LucasTsui0725](https://github.com/LucasTsui0725/) for contributing HRED model and [@Richar-Du](https://github.com/Richar-Du/) for CVAE model.

## Reference

If you find TextBox useful for your research or development, please cite the following [paper](https://arxiv.org/abs/2101.02046):
Expand Down
123 changes: 114 additions & 9 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,6 @@ python -m torch.distributed.launch --nproc_per_node=[gpu_num] \
</tr>
</tbody>
</table>


下载好的数据集需要放到 `dataset` 目录下面,和我们项目中的结构类似。

我们也支持用户在自己的数据集上训练模型,只需要按照下面三个步骤操作即可:
Expand All @@ -324,7 +322,7 @@ python -m torch.distributed.launch --nproc_per_node=[gpu_num] \

#### Image COCO Caption

测试集负对数似然 (NLL), BLEU and Self-BLEU (SBLEU)度量结果展示
测试集负对数似然 (NLL), BLEU and Self-BLEU (SBLEU)指标结果展示

| Model | NLL | BLEU-2 | BLEU-3 | BLEU-4 | BLEU-5 | SBLEU-2 | SBLEU-3 | SBLEU-4 | SBLEU-5 |
| :-----------: | :---: | :----: | :----: | :----: | :----: | :-----: | :-----: | :-----: | :-----: |
Expand Down Expand Up @@ -364,7 +362,7 @@ python -m torch.distributed.launch --nproc_per_node=[gpu_num] \

#### EMNLP2017 WMT News

测试集NLL, BLEU, SBLEU度量结果展示
测试集NLL, BLEU, SBLEU指标结果展示

| Model | NLL | BLEU-2 | BLEU-3 | BLEU-4 | BLEU-5 | SBLEU-2 | SBLEU-3 | SBLEU-4 | SBLEU-5 |
| :-----------: | :----: | :----: | :----: | :----: | :----: | :-----: | :-----: | :-----: | :-----: |
Expand Down Expand Up @@ -404,7 +402,7 @@ python -m torch.distributed.launch --nproc_per_node=[gpu_num] \

#### IMDB Movie Review

测试集NLL, BLEU, SBLEU度量结果展示
测试集NLL, BLEU, SBLEU指标结果展示

| Model | NLL | BLEU-2 | BLEU-3 | BLEU-4 | BLEU-5 | SBLEU-2 | SBLEU-3 | SBLEU-4 | SBLEU-5 |
| :-----------: | :----: | :----: | :----: | :----: | :----: | :-----: | :-----: | :-----: | :-----: |
Expand All @@ -419,7 +417,7 @@ python -m torch.distributed.launch --nproc_per_node=[gpu_num] \
| **MaskGAN** | 509.58 | 56.61 | 21.41 | 4.49 | 0.86 | 92.09 | 77.88 | 59.62 | 42.36 |
| **GPT-2** | 348.67 | 72.52 | 41.75 | 15.40 | 4.22 | 86.21 | 58.26 | 30.03 | 12.56 |

部分生成实例展示(最大长度 `max_length`100):
部分生成实例展示(最大长度 `max_length`设置为100 ):

<table align="center">
<thead>
Expand All @@ -446,7 +444,7 @@ python -m torch.distributed.launch --nproc_per_node=[gpu_num] \

#### GigaWord (摘要)

使用beam搜索在测试集上的ROUGE度量(beam搜索大小 `beam_size` 设置为5):
使用beam搜索在测试集上的ROUGE指标(beam搜索大小 `beam_size` 设置为5):

<table align="center">
<thead>
Expand Down Expand Up @@ -486,8 +484,19 @@ python -m torch.distributed.launch --nproc_per_node=[gpu_num] \
<td align="center">40.06</td>
<td align="center">26.21</td>
</tr>
<td align="center"><strong>ProphetNet</strong></td>
<td align="center">38.49</td>
<td align="center">18.41</td>
<td align="center">39.84</td>
<td align="center">26.12</td>
</tr>
<td align="center"><strong>T5</strong></td>
<td align="center">38.83</td>
<td align="center">19.68</td>
<td align="center">40.76</td>
<td align="center">26.73</td>
</tr>
</tbody></table>

部分生成实例展示:

<table align="center">
Expand All @@ -508,11 +517,24 @@ python -m torch.distributed.launch --nproc_per_node=[gpu_num] \
<td align="center"><b>Transformer</b></td>
<td>nec computer to join forces in chip sales</td>
</tr>
<td align="center"><b>BART</b></td>
<td>nec computer corp.</td>
</tr>
<td align="center"><b>BERT2BERT</b></td>
<td>nec computer form alliance for supercomputer sales</td>
</tr>
<td align="center"><b>ProphetNet</b></td>
<td>nec computer to join forces in supercomputer sales</td>
</tr>
<td align="center"><b>T5</b></td>
<td>nec computer to join forces in supercomputer sales</td>
</tr>
</tbody></table>


#### IWSLT2014 German-English(翻译)

测试集上的BLEU度量有三种解码策略:top-k采样、贪婪搜索和beam搜索(beam搜索大小 `beam_size` 设置为5):
测试集上的BLEU指标有三种解码策略:top-k采样、贪婪搜索和beam搜索(beam搜索大小 `beam_size` 设置为5):

<table align="center">
<thead>
Expand Down Expand Up @@ -606,6 +628,87 @@ python -m torch.distributed.launch --nproc_per_node=[gpu_num] \
</tr>
</tbody></table>

#### Persona Chat (对话)

使用beam搜索在测试集上的BLEU、Distinct指标(beam搜索大小 `beam_size` 设置为5):

<table align="center">
<thead>
<tr>
<th align="center">Model</th>
<th align="center">Distinct-1</th>
<th align="center">Distinct-2</th>
<th align="center">BLEU-1</th>
<th align="center">BLEU-2</th>
<th align="center">BLEU-3</th>
<th align="center">BLEU-4</th>
</tr>
</thead>
<tbody><tr>
<td align="center"><strong>RNN with Attention</strong></td>
<td align="center">0.24</td>
<td align="center">0.72</td>
<td align="center">17.51</td>
<td align="center">4.65</td>
<td align="center">2.11</td>
<td align="center">1.47</td>
</tr>
<tr>
<td align="center"><strong>Transformer</strong></td>
<td align="center">0.38</td>
<td align="center">2.28</td>
<td align="center">17.29</td>
<td align="center">4.85</td>
<td align="center">2.32</td>
<td align="center">1.65</td>
</tr>
<tr>
<td align="center"><strong>HRED</strong></td>
<td align="center">0.22</td>
<td align="center">0.63</td>
<td align="center">17.29</td>
<td align="center">4.72</td>
<td align="center">2.20</td>
<td align="center">1.60</td>
</tr>
</tbody></table>

#### Amazon Electronic (属性文本生成)

使用beam搜索在测试集上的BLEU、Distinct指标(beam搜索大小 `beam_size` 设置为5):

<table align="center">
<thead>
<tr>
<th align="center">Model</th>
<th align="center">Distinct-1</th>
<th align="center">Distinct-2</th>
<th align="center">BLEU-1</th>
<th align="center">BLEU-2</th>
<th align="center">BLEU-3</th>
<th align="center">BLEU-4</th>
</tr>
</thead>
<tbody><tr>
<td align="center"><strong>Context2Seq</strong></td>
<td align="center">0.07</td>
<td align="center">0.39</td>
<td align="center">17.21</td>
<td align="center">2.80</td>
<td align="center">0.83</td>
<td align="center">0.43</td>
</tr>
<tr>
<td align="center"><strong>Attr2Seq</strong></td>
<td align="center">0.14</td>
<td align="center">2.81</td>
<td align="center">17.14</td>
<td align="center">2.81</td>
<td align="center">0.87</td>
<td align="center">0.48</td>
</tr>
</tbody></table>

## TextBox重要发布

| 发行版本 | 日期 | 特点 |
Expand All @@ -621,6 +724,8 @@ python -m torch.distributed.launch --nproc_per_node=[gpu_num] \

我们希望所有的贡献者先在issue中提出问题,然后再提PR

我们感谢[@LucasTsui0725](https://github.com/LucasTsui0725/)实现了HRED模型,[@Richar-Du](https://github.com/Richar-Du/)实现了CVAE模型。

## 引用

如果你觉得TextBox对你的科研工作有帮助,请引用我们的 [论文](https://arxiv.org/abs/2101.02046):
Expand Down
Loading

0 comments on commit d04499f

Please sign in to comment.