-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Handle activations in saved file, handle args in run.py
- Loading branch information
Showing
6 changed files
with
60 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,67 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import matplotlib.image as image | ||
import numpy as np | ||
import argparse | ||
import sys | ||
import matplotlib.image as image | ||
|
||
from loulou import feed_forward | ||
from utils import convertJson | ||
import activations | ||
|
||
if __name__ == '__main__': | ||
# Handling errors for bad arguments | ||
try: | ||
assert len(sys.argv) == 3 | ||
except AssertionError: | ||
print("Error ! Please give two arguments : path to weights file and to image to predict.") | ||
exit() | ||
verbosity = 0 | ||
|
||
parser = argparse.ArgumentParser( | ||
description='Utility to run a loulou-based neural network.') | ||
parser.add_argument('-f', '--file', dest='file', type=str, required=True, | ||
help='path to the `.npz` training file') | ||
parser.add_argument('-i', '--image', dest='image', type=str, required=True, | ||
help='path to the image to predict') | ||
parser.add_argument('-v', '--verbosity', dest='verbosity', action="count", | ||
help='add verbosity to the output (you can type several)') | ||
parser.add_argument('-j', '--return-json', dest='return_json', action="store_true", | ||
help='print the prediction and hot vector in a json format') | ||
args = parser.parse_args() | ||
|
||
# Loading weights matrix | ||
filename = sys.argv[1] | ||
if args.verbosity is not None: | ||
verbosity = args.verbosity | ||
|
||
# Load weights and activations from file | ||
try: | ||
file = np.load(filename) | ||
# TODO change behavior to allow_pickle=False for security | ||
file = np.load(args.file, allow_pickle=True) | ||
weights = file['weights'] | ||
activations = file['activations'] | ||
activations_names = file['activations'] | ||
|
||
activations_fn = activations.listToActivations( | ||
activations_names, weights)[0] | ||
if verbosity > 0: | ||
print('Activations used :', activations_names) | ||
|
||
except FileNotFoundError: | ||
print( | ||
"Error ! Weights matrix file could not be opened, please check that it exists.") | ||
print("Fichier : ", filename) | ||
'Error ! File ['+args.file+'] could not be opened, please check that it exists.') | ||
exit() | ||
|
||
# Loading image data | ||
img = sys.argv[2] | ||
# Load image data | ||
try: | ||
img = image.imread(img) | ||
img = image.imread(args.image) | ||
except FileNotFoundError: | ||
print("Error ! Image could not be opened, please check that it exists.") | ||
print("Image : ", img) | ||
print( | ||
'Error ! Image ['+args.image+'] could not be opened, please check that it exists.') | ||
exit() | ||
|
||
# Shaping image onto matrix | ||
# Convert image to matrix | ||
topred = 1 - img.reshape(784, 4).mean(axis=1) | ||
# Making prediction | ||
prediction = feed_forward(topred, weights, activations)[-1] | ||
# Printing json output | ||
print(convertJson(prediction)) | ||
# Make the prediction | ||
prediction = feed_forward(topred, weights, activations_fn)[-1] | ||
# Print final output | ||
if args.return_json: | ||
print(convertJson(prediction)) | ||
else: | ||
if verbosity > 0: | ||
print('Hot ones vector :', list(prediction)) | ||
print('Final prediction :', prediction.argmax()) | ||
else: | ||
print(prediction.argmax()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,6 @@ | ||
# To-do | ||
|
||
- Update README | ||
- Handle args in `run.py` just like in `train.py` | ||
- Write descriptions for functions | ||
- Handle others types of data | ||
- Save trainings to an archive with datas : `weights matrix`, `activations arch`, [`hyperparameters`], [`accuracy`] | ||
- Rewrite as object ? (OOP) |