-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy paththree_flood_bp_nn_tflearn.py
39 lines (32 loc) · 1.3 KB
/
three_flood_bp_nn_tflearn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# -- encoding: utf-8 --
'在titanic数据集上使用tflearn搭建一个简单的3层神经网络'
__author__ = 'zcy'
import numpy as np
import tflearn
from tflearn.data_utils import load_csv
data, labels = load_csv('titanic_dataset.csv', target_column=0,
categorical_labels=True, n_classes=2)
# 数据预处理
def preprocess(data, columns_to_ignore):
for id in sorted(columns_to_ignore, reverse=True):
[r.pop(id) for r in data]
for i in range(len(data)):
data[i][1] = 1. if data[i][1] == 'female' else 0.
return np.array(data, dtype=np.float32)
to_ignore=[1, 6]
data = preprocess(data, to_ignore)
net = tflearn.input_data(shape=[None, 6])
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 2, activation='softmax')
net = tflearn.regression(net)
model = tflearn.DNN(net)
# 训练模型
model.fit(data, labels, n_epoch=50, batch_size=128, show_metric=True)
# 预测新样本的类别
dicaprio = [3, 'Jack Dawson', 'male', 19, 0, 0, 'N/A', 5.0000]
winslet = [1, 'Rose DeWitt Bukater', 'female', 17, 1, 2, 'N/A', 100.0000]
dicaprio, winslet = preprocess([dicaprio, winslet], to_ignore)
pred = model.predict([dicaprio, winslet])
print("DiCaprio Surviving Rate:", pred[0][1])
print("Winslet Surviving Rate:", pred[1][1])