diff --git a/scripts/loulou.py b/scripts/loulou.py index 3c8d175..c284062 100644 --- a/scripts/loulou.py +++ b/scripts/loulou.py @@ -123,12 +123,13 @@ def train(weights: list, trX: np.ndarray, trY: np.ndarray, teX: np.ndarray, teY: temp_filename = '../trains/temp/' + \ filename + '_epoch_' + str(i) + '.npy' temp_filename = os.path.join(path, temp_filename) - utils.save(weights, temp_filename, reduce_output) + utils.save(weights, activations_fn, + temp_filename, reduce_output) # Save final file if filename: filename = os.path.join(path, '../trains/' + filename + '.npy') - utils.save(weights, filename, reduce_output) + utils.save(weights, activations_fn, filename, reduce_output) return accuracy diff --git a/scripts/run.py b/scripts/run.py index 0dd13f0..6fac7a6 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -5,7 +5,6 @@ import matplotlib.image as image from loulou import feed_forward -from activations import relu from utils import convertJson if __name__ == '__main__': @@ -19,7 +18,9 @@ # Loading weights matrix filename = sys.argv[1] try: - weights = np.load(filename) + file = np.load(filename) + weights = file['weights'] + activations = file['activations'] except FileNotFoundError: print( "Error ! Weights matrix file could not be opened, please check that it exists.") @@ -35,12 +36,9 @@ print("Image : ", img) exit() - # TODO pack activations list into `.npy` file - activation_fn = [activation_fn.relu] - # Shaping image onto matrix topred = 1 - img.reshape(784, 4).mean(axis=1) # Making prediction - prediction = feed_forward(topred, weights, activation_fn)[-1] + prediction = feed_forward(topred, weights, activations)[-1] # Printing json output print(convertJson(prediction)) diff --git a/scripts/utils.py b/scripts/utils.py index 5de82d7..7d591a5 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -2,8 +2,9 @@ import json -def save(weights: list, filename: str, reduce_output: int) -> None: - np.save(filename, weights) +def save(weights: list, activations_list: list, filename: str, reduce_output: int) -> None: + np.save(filename, {'weights': weights, + 'activations': activations_list}) if reduce_output < 2: print('Data saved successfully into ', filename)