
Commit 6169d2c

Updated TODO, added comments
1 parent 57a04fa

7 files changed: +90 -48 lines

scripts/activations.py (+14 -14)
@@ -2,65 +2,65 @@
 
 
 # Relu
-def relu(y):
+def relu(y) -> np.ndarray:
     return np.maximum(y, 0)
 
 
-def relu_prime(y):
+def relu_prime(y) -> np.ndarray:
     return y > 0
 
 
 # Leaky relu
-def leaky_relu(y):
+def leaky_relu(y) -> np.ndarray:
     return np.where(y > 0, y, y * 0.01)
 
 
-def leaky_relu_prime(y):
+def leaky_relu_prime(y) -> np.ndarray:
     return (y >= 0) + (y < 0)*0.01
 
 
 # Linear
-def linear(y):
+def linear(y) -> np.ndarray:
     return y
 
 
-def linear_prime(y):
+def linear_prime(y) -> np.ndarray:
     return 1
 
 
 # Heavyside
-def heaviside(y):
+def heaviside(y) -> np.ndarray:
     return 1 * (y > 0)
 
 
-def heaviside_prime(y):
+def heaviside_prime(y) -> np.ndarray:
     return 0
 
 
 # Sigmoid
-def sigmoid(y):
+def sigmoid(y) -> np.ndarray:
     return 1 / (1 + np.exp(-y))
 
 
-def sigmoid_prime(y):
+def sigmoid_prime(y) -> np.ndarray:
     return y * (1 - y)
 
 
 # Tanh
-def tanh(y):
+def tanh(y) -> np.ndarray:
     return np.tanh(y)
 
 
-def tanh_prime(y):
+def tanh_prime(y) -> np.ndarray:
     return 1 - y**2
 
 
 # Arctan
-def arctan(y):
+def arctan(y) -> np.ndarray:
     return np.arctan(y)
 
 
-def arctan_prime(y):
+def arctan_prime(y) -> np.ndarray:
     return 1 / y**2 + 1
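
A minimal smoke test of these activations, assuming scripts/ is on the Python path and that activations.py imports numpy as np at the top (as the np. calls above imply):

    import numpy as np
    import activations

    y = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
    print(activations.relu(y))        # [0.  0.  0.  0.5 2. ]
    print(activations.leaky_relu(y))  # [-0.02  -0.005  0.     0.5    2.   ]
    print(activations.sigmoid(y))     # elementwise 1 / (1 + e^-y), all values in (0, 1)

Note that the _prime functions appear to take the activation's output rather than its input (sigmoid_prime returns y * (1 - y), tanh_prime returns 1 - y**2), which is consistent with backpropagating through the stored layer outputs in grads().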

scripts/loulou.py (+48 -21)
@@ -1,15 +1,15 @@
+from tqdm import tqdm
 import numpy as np
 import json
-import os
-import mnist
 import sys
-from tqdm import tqdm
+import os
 
 import activations
+import mnist
 import utils
 
 
-def feed_forward(X_input, weights, activation_fn):
+def feed_forward(X_input: np.ndarray, weights: list, activation_fn: list) -> np.ndarray:
     """Feed fordward the network
 
     X_input => input layer
@@ -34,7 +34,7 @@ def feed_forward(X_input, weights, activation_fn):
     return x
 
 
-def grads(x, y_expected, weights, activations_fn, activations_prime):
+def grads(x: np.ndarray, y_expected: np.ndarray, weights: list, activations_fn: list, activations_prime: list) -> np.ndarray:
     """Calculate errors corrections with backward propagation
 
     x => input layer
@@ -54,7 +54,7 @@ def grads(x, y_expected, weights, activations_fn, activations_prime):
     delta = y[-1] - y_expected
 
     # Calculate error of output weights layer
-    grads = np.empty_like(weights)
+    grads: np.ndarray = np.empty_like(weights)
     grads[-1] = y[-2].T.dot(delta)
 
     # Backward loop
@@ -69,64 +69,91 @@ def grads(x, y_expected, weights, activations_fn, activations_prime):
     return grads / len(x)
 
 
-def train(weights, trX, trY, teX, teY, filename, epochs, batch, learning_rate, save_timeout, reduce_output, activations_fn, activations_prime):
+def train(weights: list, trX: np.ndarray, trY: np.ndarray, teX: np.ndarray, teY: np.ndarray, activations_fn: list, activations_prime: list, filename: np.ndarray, epochs: int, batch: int, learning_rate: float, save_timeout: int, reduce_output: int) -> dict:
     path = os.path.dirname(__file__)
-    accuracy = {}
-    prediction = np.argmax(feed_forward(
+    accuracy = []
+
+    # Make prediction with the untrained network
+    prediction: np.ndarray = np.argmax(feed_forward(
         teX, weights, activations_fn)[-1], axis=1)
-    accuracy[0] = np.mean(prediction == np.argmax(teY, axis=1))
-    if reduce_output < 2:
+    accuracy.append(np.mean(prediction == np.argmax(teY, axis=1)))
+
+    if reduce_output <= 1:
         print('Accuracy of epoch 0 :', accuracy[0])
-    if reduce_output == 2:
+    elif reduce_output == 2:
         print(0, accuracy[0])
+
     if epochs < 0:
         epochs = 99999999999
+
+    # Epochs loop
     for i in range(epochs):
         if reduce_output < 1:
+
             pbar = tqdm(range(0, len(trX), batch))
         else:
            pbar = range(0, len(trX), batch)
+
+        # Batches loop
         for j in pbar:
             if reduce_output < 1:
                 pbar.set_description("Processing epoch %s" % (i+1))
 
+            # Select training data
             X, Y = trX[j:j+batch], trY[j:j+batch]
+
+            # Correct the network
             weights -= learning_rate * \
                 grads(X, Y, weights, activations_fn, activations_prime)
+
+        # Make prediction for epoch
         prediction = np.argmax(feed_forward(
             teX, weights, activations_fn)[-1], axis=1)
-        accuracy[i+1] = np.mean(prediction == np.argmax(teY, axis=1))
+        accuracy.append(np.mean(prediction == np.argmax(teY, axis=1)))
+
         if reduce_output < 2:
             print('Accuracy of epoch', i+1, ':', accuracy[i+1])
         if reduce_output == 2:
             print(i+1, accuracy[i+1])
+
+        # Save temp file if set so
         if filename:
             if save_timeout > 0:
                 if i % save_timeout == 0:
                     temp_filename = '../trains/temp/' + \
                         filename + '_epoch_' + str(i) + '.npy'
                     temp_filename = os.path.join(path, temp_filename)
                     utils.save(weights, temp_filename, reduce_output)
+
+    # Save final file
     if filename:
         filename = os.path.join(path, '../trains/' + filename + '.npy')
         utils.save(weights, filename, reduce_output)
+
     return accuracy
 
 
-def runTrain(params, architecture, file=None):
-    params = json.loads(params)
-    epochs = params['epochs']
-    batch = params['batch']
-    learning_rate = params['learning_rate']
-    save_timeout = params['save_timeout']
-    reduce_output = params['reduce_output']
+def runTrain(params: dict, architecture: list, file=None) -> dict:
+    params: dict = json.loads(params)
+    epochs: int = params['epochs']
+    batch: int = params['batch']
+    learning_rate: float = params['learning_rate']
+    save_timeout: int = params['save_timeout']
+    reduce_output: int = params['reduce_output']
     activations_arch, primes_arch = activations.listToActivations(
         params['activations'], architecture)
 
+    # Print network visualization
     if reduce_output < 1:
         utils.print_network_visualization(
             architecture, activations_arch, epochs, batch, learning_rate)
 
+    # Load data
+    # TODO do not load arbitrary data
     trX, trY, teX, teY = mnist.load_data()
+
+    # Init weights
     weights = [np.random.randn(*w) * 0.1 for w in architecture]
-    return train(weights, trX, trY, teX, teY, file, epochs, batch, learning_rate, save_timeout, reduce_output, activations_arch, primes_arch)
+
+    # Train network
+    return train(weights, trX, trY, teX, teY, activations_arch, primes_arch, file, epochs, batch, learning_rate, save_timeout, reduce_output)
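
For context, a hypothetical call matching the new runTrain(params, architecture, file=None) signature. The JSON keys mirror the lookups above; the architecture shapes and the 'sigmoid' activation name are illustrative assumptions, not values taken from this commit:

    import json
    from loulou import runTrain

    # runTrain expects a JSON string and calls json.loads() on it itself
    params = json.dumps({
        'epochs': 5,
        'batch': 128,
        'learning_rate': 0.03,
        'save_timeout': 0,           # assumed: 0 skips temp saves, per the "if save_timeout > 0" check
        'reduce_output': 1,
        'activations': ['sigmoid'],  # assumed format; forwarded to activations.listToActivations
    })
    architecture = [(784, 100), (100, 10)]  # assumed shapes; np.random.randn(*w) expects shape tuples

    accuracy = runTrain(params, architecture)  # file=None, so nothing is written to ../trains/
    print(accuracy[-1])                        # accuracy is now a plain list, one entry per epoch (plus epoch 0)

With file=None no weights are saved; passing a name (for example file='mnist_run', a hypothetical name) would write ../trains/<name>.npy via utils.save().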

scripts/mnist.py (+11 -5)
@@ -9,7 +9,9 @@
 # Download and import the MNIST dataset from Yann LeCun's website.
 # Reserve 10,000 examples from the training set for validation.
 # Each image is an array of 784 (28x28) float values from 0 (white) to 1 (black).
-def load_data(one_hot=True, reshape=None, validation_size=10000):
+
+
+def load_data(one_hot: bool = True, reshape: bool = None, validation_size: int = 10000) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
     x_tr = load_images('train-images-idx3-ubyte.gz')
     y_tr = load_labels('train-labels-idx1-ubyte.gz')
     x_te = load_images('t10k-images-idx3-ubyte.gz')
@@ -26,26 +28,30 @@ def load_data(one_hot=True, reshape=None, validation_size=10000):
 
     return x_tr, y_tr, x_te, y_te
 
-def load_images(filename):
+
+def load_images(filename: str) -> np.ndarray:
     maybe_download(filename)
     with gzip.open(path+filename, 'rb') as f:
         data = np.frombuffer(f.read(), np.uint8, offset=16)
     return data.reshape(-1, 28 * 28) / np.float32(256)
 
-def load_labels(filename):
+
+def load_labels(filename: str) -> np.ndarray:
     maybe_download(filename)
     with gzip.open(path+filename, 'rb') as f:
         data = np.frombuffer(f.read(), np.uint8, offset=8)
     return data
 
+
 # Download the file, unless it's already here.
-def maybe_download(filename):
+def maybe_download(filename: str) -> None:
     if not os.path.exists(path+filename):
         print('Please wait while downloading training dataset.')
         from urllib.request import urlretrieve
         print("Downloading %s" % filename)
         urlretrieve(DATA_URL + filename, path+filename)
 
+
 # Convert class labels from scalars to one-hot vectors.
-def to_one_hot(labels, num_classes=10):
+def to_one_hot(labels: np.ndarray, num_classes: int = 10) -> np.ndarray:
     return np.eye(num_classes)[labels]
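
As a quick illustration of the helpers above (load_data downloads the dataset on first use; the exact split sizes depend on how the reserved validation examples are handled, so they are not shown here):

    import numpy as np
    import mnist

    # to_one_hot indexes rows of an identity matrix
    print(mnist.to_one_hot(np.array([3, 0])))
    # [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
    #  [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

    trX, trY, teX, teY = mnist.load_data()  # one_hot=True by default
    print(trX.shape, teX.shape)             # second dimension is 784: each image is a flattened 28x28 array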

scripts/run.py (+2)
@@ -1,3 +1,5 @@
+#!/usr/bin/env python3
+
 import numpy as np
 import sys
 import matplotlib.image as image

scripts/train.py (+5 -2)
@@ -1,8 +1,11 @@
-from loulou import runTrain
-from utils import listToArch
+#!/usr/bin/env python3
+
 import json
 import argparse
 
+from loulou import runTrain
+from utils import listToArch
+
 if __name__ == '__main__':
 
     params = {}

scripts/utils.py (+4 -4)
@@ -2,20 +2,20 @@
 import json
 
 
-def save(weights, filename, reduce_output):
+def save(weights: list, filename: str, reduce_output: int) -> None:
     np.save(filename, weights)
     if reduce_output < 2:
         print('Data saved successfully into ', filename)
 
 
-def convertJson(pred):
+def convertJson(pred: np.ndarray) -> str:
     out = {}
     out['hot_prediction'] = list(pred)
     out['prediction'] = int(np.argmax(pred))
     return json.dumps(out)
 
 
-def listToArch(list):
+def listToArch(list: list) -> list:
     arch = []
     id = 0
     for hl in list:
@@ -33,7 +33,7 @@ def listToArch(list):
     return arch
 
 
-def print_network_visualization(architecture, activations_arch, epochs, batch, learning_rate):
+def print_network_visualization(architecture: list, activations_arch: list, epochs: int, batch: int, learning_rate: float) -> None:
     print('Network has', len(architecture) - 1, 'hidden layers :')
 
     print(' layer [0] --> 784 neurons, inputs')
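
A small usage sketch for convertJson, with a made-up prediction vector (float64, since json.dumps cannot serialize NumPy float32 values):

    import numpy as np
    import utils

    pred = np.array([0.05, 0.01, 0.90, 0.04])   # hypothetical network output for one input
    print(utils.convertJson(pred))
    # {"hot_prediction": [0.05, 0.01, 0.9, 0.04], "prediction": 2}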

todo.md (+6 -2)
@@ -1,4 +1,8 @@
 # To-do
 
-- Handle args in run.py just like in train.py
-- Handle others types of data
+- Update README
+- Handle args in `run.py` just like in `train.py`
+- Write descriptions for functions
+- Handle others types of data
+- Save trainings to an archive with datas : `weights matrix`, `activations arch`, [`hyperparameters`], [`accuracy`]
+- Rewrite as object ? (OOP)
