From ff28c1fe4c2f1c2ac0cc18917954acfd2fa4e5ae Mon Sep 17 00:00:00 2001 From: aunetx Date: Mon, 23 Dec 2019 15:58:07 +0100 Subject: [PATCH] Added some graph tools --- scripts/loulou.py | 20 +++++++++++++------- scripts/train.py | 3 ++- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/scripts/loulou.py b/scripts/loulou.py index 28c735d..0a79c61 100644 --- a/scripts/loulou.py +++ b/scripts/loulou.py @@ -26,7 +26,7 @@ def feed_forward(X_input: np.ndarray, weights: list, activation_fn: list) -> np. # Forward loop for id, w in enumerate(weights): - # Weighted average `z = w^T · x` + # Weighted average `z = w · x` z = x[-1].dot(w) # Activation function `y = g(x)` y = activation_fn[id](z) @@ -85,10 +85,14 @@ def train(weights: list, trainX: np.ndarray, trainY: np.ndarray, testX: np.ndarr accuracy = np.mean(prediction == np.argmax(testY, axis=1)) accuracy_table.append(accuracy) + initial_cost = 1/len(testY) * np.sum((prediction - + np.argmax(testY, axis=1)) ** 2) + average_cost_table.append(initial_cost) + if reduce_output <= 1: - print('Accuracy at epoch 0 :', accuracy) + print('Accuracy at epoch 0 :', accuracy, ' cost =', initial_cost) elif reduce_output == 2: - print(0, accuracy) + print(0, accuracy, initial_cost) if epochs < 0: epochs = 99999999999 @@ -135,7 +139,7 @@ def train(weights: list, trainX: np.ndarray, trainY: np.ndarray, testX: np.ndarr print('Accuracy at epoch', i+1, ':', accuracy, ' cost =', average_cost) if reduce_output == 2: - print(i+1, accuracy) + print(i+1, accuracy, average_cost) # Save temp file if set so if filename: @@ -178,7 +182,7 @@ def train(weights: list, trainX: np.ndarray, trainY: np.ndarray, testX: np.ndarr utils.save(weights, activations_fn, filename, no_infos, infos, reduce_output) - return accuracy_table + return (accuracy_table, average_cost_table), weights def runTrain(params: dict, architecture: list, file=None) -> dict: @@ -199,11 +203,13 @@ def runTrain(params: dict, architecture: list, file=None) -> dict: architecture, activations_arch, epochs, batch, learning_rate) # Load data - # TODO do not load arbitrary data trX, trY, teX, teY = mnist.load_data() # Init weights weights = [np.random.randn(*w) * 0.1 for w in architecture] # Train network - return train(weights, trX, trY, teX, teY, activations_arch, primes_arch, file, epochs, batch, learning_rate, save_timeout, graph, no_infos, reduce_output) + tr, weights = train(weights, trX, trY, teX, teY, activations_arch, primes_arch, file, + epochs, batch, learning_rate, save_timeout, graph, no_infos, reduce_output) + + return tr diff --git a/scripts/train.py b/scripts/train.py index 9977169..d10a932 100755 --- a/scripts/train.py +++ b/scripts/train.py @@ -78,10 +78,11 @@ params = json.dumps(params) try: - accuracy = runTrain(params, architecture, file=filename) + accuracy, cost = runTrain(params, architecture, file=filename) except KeyboardInterrupt: print('\nTraining stopped by user') exit() if args.return_json: print(accuracy) + print(cost)