🏃 [Chinese chatbot] tensorflow implementation of seq2seq model with bahdanau attention and Word2Vec pretrained embedding
This repo is based on this article, written by liuchongee.
$ git clone https://github.com/AdrianHsu/tensorflow-chatbot-chinese.git
# put your own training/eval data in the correct path, as shown above
$ ./run.sh
- You should download the pretrained model here and then put it into
save/
directory. - make sure your input is already put in the correct path, and also it is pre-processed by text segmentation APIs, for example, jieba.
$ ./hw2_seq2seq.sh
You have to download the frozen model first.
$ ./download_model.sh
And then just directly run the server file.
$ python3 server.py
You'll need aiohttp
, python-socketio
to run this. The result will be shown on https://localhost:8080
.
WARNING: do not install socketio
, this is not the same as python-socketio
!!
MAX_SENTENCE_LENGTH = 15 # abandon those sentences with > 15 words
special_tokens = {'<PAD>': 0, '<BOS>': 1, '<EOS>': 2, '<UNK>': 3}
emb_size = 300
# Model parameters
num_layers = 2
rnn_size = 2048 # GRU cell
keep_prob = 1.0
vocab_num = 10561
min_cnt = 100 # min counts, we only preserver those words appears >= 100 times
num_epochs = 50
batch_size = 250
lr = 0.5
lr = tf.train.exponential_decay(...
decay_steps=3250, decay_rate=0.95, staircase=True)
- optimizer:
GradientDescentOptimizer()
- schedule sampling:
inverse sigmoid(0.88 to 0.5)
- attention method: Bahdanau Attention
- random seed are all fixed to
0
.
- Attention Method: Bahdanau Attention
- RNN Cell : GRU
$ tree -L 1
.
├── clean.sh # clean up the `logs/` and `save/`
├── embeddings.npy # [vocab_num, emb_size] matrix, not used in the final version
├── handler.py # data handler
├── hw2_seq2seq.sh # script for testing data
├── logs # logs for tensorboard
├── model_seq2seq.py # training, validation, inference session & model buliding
├── output.txt # output default path for testing
├── perplexity.sh # not used
├── README.md
├── run.sh # re-train the model
├── save # model saver checkpoint
├── test_input.txt # input default path for training
├── util.py # print function, inverse sigmoid sampling function
├── idx2word.pkl
├── word2idx.pkl
└── word2vec.model # generated by gensim Word2Vec model
The data for training/validation is only a txt file named conversation.txt, and I put it in:
/home/data/mlds_hw2_2_data/conversation.txt
You could modify the path through passing an argument —data_dir
.
Also, if you want to use your own data, you should modify the model_seq2seq.py
, which is originally set to:
filename = '/conversation.txt' # text segmentation is already finished
total_line_num = 3599478
train_line_num = 3587000
eval_line_num = 12478
emb_size = 300 # embedding size
PKL_EXIST = True # if you want to re-train the GenSim model, you should set it False
這 不是 一時 起意 的 行刺
而是 有 政治 動機
上校 , 這種 事
+++$+++
他 的 口袋 是 空 的
沒有 皮夾 , 也 沒有 身分證
手錶 停 在 4 點 15 分
大概 是 墜機 的 時刻
他 的 降落傘 被 樹枝 纏住 了
- Every dialogs are split by
+++$+++
- the prior sentence will be input, and the sentence itself will be ground truth (pair format)
- for example (this is for training/validation):
[[['這', '不是', '一時', '起意', '的', '行刺'], ['而是', '有', '政治', '動機']],
[['而是', '有', '政治', '動機'], [‘上校’, ',', '這種', '事']], ...]
# train
original line num: 3587000
used data num: 1642549
# validation
original line num: 12478
used data num: 4762
min_count
: 100MAX_SENTENCE_LENGTH
: 15MIN_SENTENCE_LENGTH
: 2, therefore, those sentence with only 1 word will be abandonedunk_num / float(len(sent)) > 0.1
: for example, if there are 10 words in this sentence, with only 1 , then it will be preserved
raw_line.append(['<PAD>'] * 3999999) # to ensure <PAD> will be in index 0 (after sort)
raw_line.append(['<BOS>'] * 100)
raw_line.append(['<EOS>'] * 50)
raw_line.append(['<UNK>'] * 2999999) # to ensure <UNK> will be in index 3 (after sort)
...
(in for loop)
line = text_to_word_sequence(line, lower=True, split=" ") # Keras API
line.insert(0, '<BOS>')
line.append('<EOS>')
raw_line.append(line)
...
self.model = Word2Vec(raw_line, size=emb_size,
sorted_vocab=1, min_count=min_count, workers=4)
It takes about 30 hours (with one 1080Ti ) to make the training loss decrease from 5.8 to 2.5, (the generated sentence looks quite okay after training loss <= 2.5).
Questions: ['好', '了', '大家', '到此爲止', '了', '。']
Answers : ['請', '讓', '當事人', '離開', '讓', '他', '走', '吧', '。', '<EOS>']
Predict : ['請', '讓', '當事人', '離開', '我', '就', '走', '吧', '。', '<EOS>']
Questions: ['這下', '讓', '我', '看看', '啊', '。']
Answers : ['再給', '我', '一分鐘', '不行', '嗎', '。', '<EOS>']
Predict : ['再給', '我', '一分鐘', '不行', '。', '<EOS>']
Questions: ['高', '譚市', '盡', '在', '你', '掌握', '之中', '了', '。']
Answers : ['這樣', '做', '是', '不', '對', '的', '。', '<EOS>']
Predict : ['這樣', '做', '是', '不', '對', '的', '。', '<EOS>']
Questions: ['別', '激動', '別', '激動', '。']
Answers : ['把', '武器', '放下', '。', '<EOS>'] (len: 16, eos: 5)
Predict : ['我們', '沒事', '吧', '。', '<EOS>']
Questions: ['最糟', '的', '就是', '讓', '他', '知道', '。']
Answers : ['我們', '害怕', '。', '<EOS>']
Predict : ['他', '是', '我', '的', '兒子', '的', '。', '<EOS>']
Questions: ['起來', '你', '沒事', '吧', '。']
Answers : ['不好意思', '姐姐', '剛才', '真的', '沒', '看到', '你', '。', '<EOS>']
Predict : ['沒事', '。', '<EOS>']
Questions: ['我愛你', '。']
Answers : ['上星期', '我', '在', '教堂', '問', '上帝', '。', '<EOS>']
Predict : ['我', '也', '愛', '你', '。', '<EOS>']
It takes about 30 hours (with one 1080Ti ) to make the training loss decrease from 5.8 to 2.5, (the generated sentence looks quite okay after training loss <= 2.5).