Skip to content

Commit

Permalink
Added image tester class, provided testing example
Browse files Browse the repository at this point in the history
  • Loading branch information
RoundRonin committed Mar 31, 2024
1 parent d279468 commit d4e1d10
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 6 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,6 @@ dmypy.json
Data*
**/Data*

Save*
Save*

test_images
18 changes: 17 additions & 1 deletion image_recognition/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from image_recognition.modules.importer import importer
from image_recognition.modules.model import model
from image_recognition.modules.visualization import plotter_evaluator
from image_recognition.modules.visualization import plotter_evaluator, tester
import os

def main(): # pragma: no cover

Expand Down Expand Up @@ -38,6 +39,7 @@ def main(): # pragma: no cover

## Обработка данных
# Вносится рандомизация (ротация, зум, перемещение). Также приводится яркость к понятному нейросети формату (вместо 0-255, 0-1).
# По-умолчанию (0.2, 0.1, 0.08)
i.generate_augmentation_layers(0.2, 0.1, 0.08)

### Применение слоёв обработки данных
Expand Down Expand Up @@ -104,3 +106,17 @@ def main(): # pragma: no cover
## Матрица запутанности
# Хорший способ понять, как именно нейросеть ошибается
pe.plot_confusion_matrix()


# Тестирование избранных случаев
t = tester(model_instance)

images = ["one", "two", "tree", "four", "five", "six", "seven", "eight", "nine", "zero"]
# Составной путь. Строится от корня программы
for i in range(len(images)):

absolute_path = os.path.join(os.getcwd(), 'test_images', images[i] + '.jpg')
t.read_image(absolute_path, False)

t.img_predict()

3 changes: 2 additions & 1 deletion image_recognition/modules/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,5 @@ def __get_stats(self):
labels = np.concatenate([labels, np.argmax(y.numpy(), axis=-1)])

self.class_names = set(labels)
self.num_classes = len(self.class_names)
self.num_classes = len(self.class_names)

4 changes: 2 additions & 2 deletions image_recognition/modules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def compile(self):
# Компиляция модели
optimizer = RMSprop(learning_rate=0.001, rho=0.9, epsilon=1e-08)
self.model.compile(optimizer = optimizer , loss = "categorical_crossentropy", metrics=["accuracy"])
# self.model.summary()

def train(self, train_ds: tf_data.Dataset, epochs: int, validation_data: tf_data.Dataset):

Expand All @@ -65,4 +64,5 @@ def init_learning_rate_reduction(self):

def init_save_at_epoch(self):

self.callbacks.append(keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"))
self.callbacks.append(keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"))

39 changes: 38 additions & 1 deletion image_recognition/modules/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from scikitplot.metrics import plot_confusion_matrix
from tensorflow import data as tf_data


class plotter_evaluator:

model: keras.Model
Expand Down Expand Up @@ -62,4 +63,40 @@ def print_report(self):
def plot_confusion_matrix(self):

plot_confusion_matrix(self.labels, self.pred_labels,cmap= 'YlGnBu')
plt.show()
plt.show()


import cv2
import tensorflow as tf
import math


class tester:

img: np.ndarray | None = None
model: keras.Model
path: str

def __init__(self, model: keras.Model):
self.model = model

def read_image(self, path: str, show: bool):
img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
if(img is None):
print("Image not found")
return

if(show): plt.imshow(img, cmap='gray', vmin=0, vmax=255)
self.path = path
self.img = img

def img_predict(self):
if(self.img is None):
print("No image supplied")
return

prediction = self.model.predict(np.expand_dims(self.img/255, 0))
probabilities = list(map(lambda x: math.floor(x*1000)/1000, prediction[0]))
print(self.path)
print(probabilities)

0 comments on commit d4e1d10

Please sign in to comment.