Skip to content

Commit

Permalink
Training files now contain activations
Browse files Browse the repository at this point in the history
  • Loading branch information
aunetx committed Dec 3, 2019
1 parent 6169d2c commit 6bfc026
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
5 changes: 3 additions & 2 deletions scripts/loulou.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,13 @@ def train(weights: list, trX: np.ndarray, trY: np.ndarray, teX: np.ndarray, teY:
temp_filename = '../trains/temp/' + \
filename + '_epoch_' + str(i) + '.npy'
temp_filename = os.path.join(path, temp_filename)
utils.save(weights, temp_filename, reduce_output)
utils.save(weights, activations_fn,
temp_filename, reduce_output)

# Save final file
if filename:
filename = os.path.join(path, '../trains/' + filename + '.npy')
utils.save(weights, filename, reduce_output)
utils.save(weights, activations_fn, filename, reduce_output)

return accuracy

Expand Down
10 changes: 4 additions & 6 deletions scripts/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import matplotlib.image as image

from loulou import feed_forward
from activations import relu
from utils import convertJson

if __name__ == '__main__':
Expand All @@ -19,7 +18,9 @@
# Loading weights matrix
filename = sys.argv[1]
try:
weights = np.load(filename)
file = np.load(filename)
weights = file['weights']
activations = file['activations']
except FileNotFoundError:
print(
"Error ! Weights matrix file could not be opened, please check that it exists.")
Expand All @@ -35,12 +36,9 @@
print("Image : ", img)
exit()

# TODO pack activations list into `.npy` file
activation_fn = [activation_fn.relu]

# Shaping image onto matrix
topred = 1 - img.reshape(784, 4).mean(axis=1)
# Making prediction
prediction = feed_forward(topred, weights, activation_fn)[-1]
prediction = feed_forward(topred, weights, activations)[-1]
# Printing json output
print(convertJson(prediction))
5 changes: 3 additions & 2 deletions scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import json


def save(weights: list, filename: str, reduce_output: int) -> None:
np.save(filename, weights)
def save(weights: list, activations_list: list, filename: str, reduce_output: int) -> None:
np.save(filename, {'weights': weights,
'activations': activations_list})
if reduce_output < 2:
print('Data saved successfully into ', filename)

Expand Down

0 comments on commit 6bfc026

Please sign in to comment.