From d4e1d10ce6fac3866ce3c18243c3b52b954a6459 Mon Sep 17 00:00:00 2001 From: RoundRonin Date: Sun, 31 Mar 2024 20:41:44 +0300 Subject: [PATCH] Added image tester class, provided testing example --- .gitignore | 4 ++- image_recognition/cli.py | 18 +++++++++- image_recognition/modules/importer.py | 3 +- image_recognition/modules/model.py | 4 +-- image_recognition/modules/visualization.py | 39 +++++++++++++++++++++- 5 files changed, 62 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index e0fc5ab..aca2c5e 100644 --- a/.gitignore +++ b/.gitignore @@ -136,4 +136,6 @@ dmypy.json Data* **/Data* -Save* \ No newline at end of file +Save* + +test_images \ No newline at end of file diff --git a/image_recognition/cli.py b/image_recognition/cli.py index 4588673..c881044 100644 --- a/image_recognition/cli.py +++ b/image_recognition/cli.py @@ -2,7 +2,8 @@ from image_recognition.modules.importer import importer from image_recognition.modules.model import model -from image_recognition.modules.visualization import plotter_evaluator +from image_recognition.modules.visualization import plotter_evaluator, tester +import os def main(): # pragma: no cover @@ -38,6 +39,7 @@ def main(): # pragma: no cover ## Обработка данных # Вносится рандомизация (ротация, зум, перемещение). Также приводится яркость к понятному нейросети формату (вместо 0-255, 0-1). + # По-умолчанию (0.2, 0.1, 0.08) i.generate_augmentation_layers(0.2, 0.1, 0.08) ### Применение слоёв обработки данных @@ -104,3 +106,17 @@ def main(): # pragma: no cover ## Матрица запутанности # Хорший способ понять, как именно нейросеть ошибается pe.plot_confusion_matrix() + + + # Тестирование избранных случаев + t = tester(model_instance) + + images = ["one", "two", "tree", "four", "five", "six", "seven", "eight", "nine", "zero"] + # Составной путь. Строится от корня программы + for i in range(len(images)): + + absolute_path = os.path.join(os.getcwd(), 'test_images', images[i] + '.jpg') + t.read_image(absolute_path, False) + + t.img_predict() + diff --git a/image_recognition/modules/importer.py b/image_recognition/modules/importer.py index b8fd568..871b9d7 100644 --- a/image_recognition/modules/importer.py +++ b/image_recognition/modules/importer.py @@ -71,4 +71,5 @@ def __get_stats(self): labels = np.concatenate([labels, np.argmax(y.numpy(), axis=-1)]) self.class_names = set(labels) - self.num_classes = len(self.class_names) \ No newline at end of file + self.num_classes = len(self.class_names) + diff --git a/image_recognition/modules/model.py b/image_recognition/modules/model.py index 9bf7086..94c39b0 100644 --- a/image_recognition/modules/model.py +++ b/image_recognition/modules/model.py @@ -45,7 +45,6 @@ def compile(self): # Компиляция модели optimizer = RMSprop(learning_rate=0.001, rho=0.9, epsilon=1e-08) self.model.compile(optimizer = optimizer , loss = "categorical_crossentropy", metrics=["accuracy"]) - # self.model.summary() def train(self, train_ds: tf_data.Dataset, epochs: int, validation_data: tf_data.Dataset): @@ -65,4 +64,5 @@ def init_learning_rate_reduction(self): def init_save_at_epoch(self): - self.callbacks.append(keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras")) \ No newline at end of file + self.callbacks.append(keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras")) + diff --git a/image_recognition/modules/visualization.py b/image_recognition/modules/visualization.py index 5c1cc23..de9fc62 100644 --- a/image_recognition/modules/visualization.py +++ b/image_recognition/modules/visualization.py @@ -9,6 +9,7 @@ from scikitplot.metrics import plot_confusion_matrix from tensorflow import data as tf_data + class plotter_evaluator: model: keras.Model @@ -62,4 +63,40 @@ def print_report(self): def plot_confusion_matrix(self): plot_confusion_matrix(self.labels, self.pred_labels,cmap= 'YlGnBu') - plt.show() \ No newline at end of file + plt.show() + + +import cv2 +import tensorflow as tf +import math + + +class tester: + + img: np.ndarray | None = None + model: keras.Model + path: str + + def __init__(self, model: keras.Model): + self.model = model + + def read_image(self, path: str, show: bool): + img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) + if(img is None): + print("Image not found") + return + + if(show): plt.imshow(img, cmap='gray', vmin=0, vmax=255) + self.path = path + self.img = img + + def img_predict(self): + if(self.img is None): + print("No image supplied") + return + + prediction = self.model.predict(np.expand_dims(self.img/255, 0)) + probabilities = list(map(lambda x: math.floor(x*1000)/1000, prediction[0])) + print(self.path) + print(probabilities) +