Skip to content

Commit

Permalink
Added some graph tools
Browse files Browse the repository at this point in the history
  • Loading branch information
aunetx committed Dec 23, 2019
1 parent fd3c0bb commit ff28c1f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
20 changes: 13 additions & 7 deletions scripts/loulou.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def feed_forward(X_input: np.ndarray, weights: list, activation_fn: list) -> np.

# Forward loop
for id, w in enumerate(weights):
# Weighted average `z = w^T · x`
# Weighted average `z = w · x`
z = x[-1].dot(w)
# Activation function `y = g(x)`
y = activation_fn[id](z)
Expand Down Expand Up @@ -85,10 +85,14 @@ def train(weights: list, trainX: np.ndarray, trainY: np.ndarray, testX: np.ndarr
accuracy = np.mean(prediction == np.argmax(testY, axis=1))
accuracy_table.append(accuracy)

initial_cost = 1/len(testY) * np.sum((prediction -
np.argmax(testY, axis=1)) ** 2)
average_cost_table.append(initial_cost)

if reduce_output <= 1:
print('Accuracy at epoch 0 :', accuracy)
print('Accuracy at epoch 0 :', accuracy, ' cost =', initial_cost)
elif reduce_output == 2:
print(0, accuracy)
print(0, accuracy, initial_cost)

if epochs < 0:
epochs = 99999999999
Expand Down Expand Up @@ -135,7 +139,7 @@ def train(weights: list, trainX: np.ndarray, trainY: np.ndarray, testX: np.ndarr
print('Accuracy at epoch', i+1, ':',
accuracy, ' cost =', average_cost)
if reduce_output == 2:
print(i+1, accuracy)
print(i+1, accuracy, average_cost)

# Save temp file if set so
if filename:
Expand Down Expand Up @@ -178,7 +182,7 @@ def train(weights: list, trainX: np.ndarray, trainY: np.ndarray, testX: np.ndarr
utils.save(weights, activations_fn, filename,
no_infos, infos, reduce_output)

return accuracy_table
return (accuracy_table, average_cost_table), weights


def runTrain(params: dict, architecture: list, file=None) -> dict:
Expand All @@ -199,11 +203,13 @@ def runTrain(params: dict, architecture: list, file=None) -> dict:
architecture, activations_arch, epochs, batch, learning_rate)

# Load data
# TODO do not load arbitrary data
trX, trY, teX, teY = mnist.load_data()

# Init weights
weights = [np.random.randn(*w) * 0.1 for w in architecture]

# Train network
return train(weights, trX, trY, teX, teY, activations_arch, primes_arch, file, epochs, batch, learning_rate, save_timeout, graph, no_infos, reduce_output)
tr, weights = train(weights, trX, trY, teX, teY, activations_arch, primes_arch, file,
epochs, batch, learning_rate, save_timeout, graph, no_infos, reduce_output)

return tr
3 changes: 2 additions & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,11 @@

params = json.dumps(params)
try:
accuracy = runTrain(params, architecture, file=filename)
accuracy, cost = runTrain(params, architecture, file=filename)
except KeyboardInterrupt:
print('\nTraining stopped by user')
exit()

if args.return_json:
print(accuracy)
print(cost)

0 comments on commit ff28c1f

Please sign in to comment.