"""
image_retrieval.py (author: Anson Wong / git: ankonzoid)
We perform image retrieval using transfer learning on a pre-trained
VGG image classifier. We plot the k=5 most similar images to our
query images, as well as the t-SNE visualizations.
"""
import os
import numpy as np
import tensorflow as tf
from sklearn.neighbors import NearestNeighbors
from src.CV_IO_utils import read_imgs_dir
from src.CV_transform_utils import apply_transformer
from src.CV_transform_utils import resize_img, normalize_img
from src.CV_plot_utils import plot_query_retrieval, plot_tsne, plot_reconstructions
from src.autoencoder import AutoEncoder
# Run mode: (autoencoder -> simpleAE, convAE) or (transfer learning -> vgg19)
modelName = "convAE" # try: "simpleAE", "convAE", "vgg19"
trainModel = True
parallel = True # use multicore processing
# Make paths
dataTrainDir = os.path.join(os.getcwd(), "data", "train")
dataTestDir = os.path.join(os.getcwd(), "data", "test")
outDir = os.path.join(os.getcwd(), "output", modelName)
if not os.path.exists(outDir):
    os.makedirs(outDir)
# Read images
extensions = [".jpg", ".jpeg"]
print("Reading train images from '{}'...".format(dataTrainDir))
imgs_train = read_imgs_dir(dataTrainDir, extensions, parallel=parallel)
print("Reading test images from '{}'...".format(dataTestDir))
imgs_test = read_imgs_dir(dataTestDir, extensions, parallel=parallel)
shape_img = imgs_train[0].shape
print("Image shape = {}".format(shape_img))
# Build models
if modelName in ["simpleAE", "convAE"]:
    # Set up autoencoder
    info = {
        "shape_img": shape_img,
        "autoencoderFile": os.path.join(outDir, "{}_autoencoder.h5".format(modelName)),
        "encoderFile": os.path.join(outDir, "{}_encoder.h5".format(modelName)),
        "decoderFile": os.path.join(outDir, "{}_decoder.h5".format(modelName)),
    }
    model = AutoEncoder(modelName, info)
    model.set_arch()

    if modelName == "simpleAE":
        shape_img_resize = shape_img
        input_shape_model = (model.encoder.input.shape[1],)
        output_shape_model = (model.encoder.output.shape[1],)
        n_epochs = 300
    elif modelName == "convAE":
        shape_img_resize = shape_img
        input_shape_model = tuple([int(x) for x in model.encoder.input.shape[1:]])
        output_shape_model = tuple([int(x) for x in model.encoder.output.shape[1:]])
        n_epochs = 500
    else:
        raise Exception("Invalid modelName!")

elif modelName in ["vgg19"]:
    # Load pre-trained VGG19 model + higher level layers
    print("Loading VGG19 pre-trained model...")
    model = tf.keras.applications.VGG19(weights='imagenet', include_top=False,
                                        input_shape=shape_img)
    model.summary()

    shape_img_resize = tuple([int(x) for x in model.input.shape[1:]])
    input_shape_model = tuple([int(x) for x in model.input.shape[1:]])
    output_shape_model = tuple([int(x) for x in model.output.shape[1:]])
    n_epochs = None

else:
    raise Exception("Invalid modelName!")
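# (for "vgg19" there is no training step: the ImageNet weights act as a fixed feature extractor)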
# Print some model info
print("input_shape_model = {}".format(input_shape_model))
print("output_shape_model = {}".format(output_shape_model))
# Apply transformations to all images
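# Note: a top-level class (rather than a lambda) stays picklable, which the
# parallel=True multiprocessing path likely requires.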
class ImageTransformer(object):

    def __init__(self, shape_resize):
        self.shape_resize = shape_resize

    def __call__(self, img):
        img_transformed = resize_img(img, self.shape_resize)
        img_transformed = normalize_img(img_transformed)
        return img_transformed
transformer = ImageTransformer(shape_img_resize)
print("Applying image transformer to training images...")
imgs_train_transformed = apply_transformer(imgs_train, transformer, parallel=parallel)
print("Applying image transformer to test images...")
imgs_test_transformed = apply_transformer(imgs_test, transformer, parallel=parallel)
# Convert images to numpy array
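# Reshaping to (n_images,) + input_shape_model flattens each image for "simpleAE"
# and keeps the (H, W, C) layout for "convAE" / "vgg19".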
X_train = np.array(imgs_train_transformed).reshape((-1,) + input_shape_model)
X_test = np.array(imgs_test_transformed).reshape((-1,) + input_shape_model)
print(" -> X_train.shape = {}".format(X_train.shape))
print(" -> X_test.shape = {}".format(X_test.shape))
# Train (if necessary)
if modelName in ["simpleAE", "convAE"]:
    if trainModel:
        model.compile(loss="binary_crossentropy", optimizer="adam")
        model.fit(X_train, n_epochs=n_epochs, batch_size=256)
        model.save_models()
    else:
        model.load_models(loss="binary_crossentropy", optimizer="adam")
# Create embeddings using model
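# The *_flatten arrays below collapse each embedding into a 1-D vector so that kNN can index it.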
print("Inferencing embeddings using pre-trained model...")
E_train = model.predict(X_train)
E_train_flatten = E_train.reshape((-1, np.prod(output_shape_model)))
E_test = model.predict(X_test)
E_test_flatten = E_test.reshape((-1, np.prod(output_shape_model)))
print(" -> E_train.shape = {}".format(E_train.shape))
print(" -> E_test.shape = {}".format(E_test.shape))
print(" -> E_train_flatten.shape = {}".format(E_train_flatten.shape))
print(" -> E_test_flatten.shape = {}".format(E_test_flatten.shape))
# Make reconstruction visualizations
if modelName in ["simpleAE", "convAE"]:
    print("Visualizing database image reconstructions...")
    imgs_train_reconstruct = model.decoder.predict(E_train)
    if modelName == "simpleAE":
        imgs_train_reconstruct = imgs_train_reconstruct.reshape((-1,) + shape_img_resize)
    plot_reconstructions(imgs_train, imgs_train_reconstruct,
                         os.path.join(outDir, "{}_reconstruct.png".format(modelName)),
                         range_imgs=[0, 255],
                         range_imgs_reconstruct=[0, 1])
# Fit kNN model on training images
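# Cosine distance compares embedding direction rather than magnitude, a common choice for CNN features.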
print("Fitting k-nearest-neighbour model on training images...")
knn = NearestNeighbors(n_neighbors=5, metric="cosine")
knn.fit(E_train_flatten)
# Perform image retrieval on test images
print("Performing image retrieval on test images...")
for i, emb_flatten in enumerate(E_test_flatten):
    _, indices = knn.kneighbors([emb_flatten])  # find k nearest train neighbours
    img_query = imgs_test[i]  # query image
    imgs_retrieval = [imgs_train[idx] for idx in indices.flatten()]  # retrieval images
    outFile = os.path.join(outDir, "{}_retrieval_{}.png".format(modelName, i))
    plot_query_retrieval(img_query, imgs_retrieval, outFile)
# Plot t-SNE visualization
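# (plot_tsne presumably projects the high-dimensional training embeddings to 2-D; see src/CV_plot_utils.py)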
print("Visualizing t-SNE on training images...")
outFile = os.path.join(outDir, "{}_tsne.png".format(modelName))
plot_tsne(E_train_flatten, imgs_train, outFile)