-
Notifications
You must be signed in to change notification settings - Fork 5
/
main.py
71 lines (55 loc) · 1.99 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# TensorFlow and tf.keras
import tensorflow as tf
# Utils
import os
# Configuration
# IMAGE_SIZE is (height, width): loaded images are resized to it and it also
# fixes the spatial part of the model's input shape.
IMAGE_SIZE = 256, 256
BATCH_SIZE = 10  # images per batch produced by the dataset loader
EPOCHS_NUM = 40  # passes over the training data; also embedded in the saved model's filename

print(f"Using Tensorflow {tf.__version__}")
def create_model(num_classes, image_size=IMAGE_SIZE):
    """Build the (uncompiled) convolutional image classifier.

    Architecture: three Conv2D(3x3, ReLU, "same" padding) + MaxPooling2D
    stages with 16, 32 and 64 filters, 20% dropout, a 128-unit ReLU dense
    layer, and a final dense layer with one unit per class.

    Args:
        num_classes: Number of output units (one per class).
        image_size: ``(height, width)`` of the 3-channel input images.
            Defaults to the module-level ``IMAGE_SIZE``.

    Returns:
        A ``tf.keras.Sequential`` model. The final layer has no activation,
        so the model outputs raw logits; compile it with a loss created with
        ``from_logits=True``.
    """
    height, width = image_size
    model = tf.keras.Sequential()
    # Declare the input explicitly instead of passing `input_shape` to the
    # first layer; Keras 3 deprecates (and warns about) the latter form.
    model.add(tf.keras.Input(shape=(height, width, 3)))
    # Scale pixel values by 1/255 (maps [0, 255] to [0, 1]).
    model.add(tf.keras.layers.Rescaling(1.0 / 255))
    for filters in (16, 32, 64):
        model.add(tf.keras.layers.Conv2D(filters, 3, padding="same", activation="relu"))
        model.add(tf.keras.layers.MaxPooling2D())
    model.add(tf.keras.layers.Dropout(0.2))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(128, activation="relu"))
    # No activation: logits, paired with a from_logits=True loss.
    model.add(tf.keras.layers.Dense(num_classes))
    return model
def load_dataset(path, validation_split=None, subset=None, seed=None):
    """Load a labelled image dataset from a directory tree.

    Thin wrapper around ``image_dataset_from_directory``. Each subdirectory
    of *path* is one class. Images are resized to ``IMAGE_SIZE`` and grouped
    into batches of ``BATCH_SIZE``.

    Args:
        path: Root directory containing one subdirectory per class.
        validation_split: Optional fraction (0-1) of images to hold out.
            ``None`` (default) returns every image, as before.
        subset: ``"training"`` or ``"validation"``; only used when
            *validation_split* is set.
        seed: Shuffle seed. It is required together with *validation_split*.
            Use the same value for both subsets so they do not overlap.

    Returns:
        A ``tf.data.Dataset`` of ``(images, labels)`` batches. The dataset
        exposes the class names in label-index order as ``class_names``.
    """
    return tf.keras.preprocessing.image_dataset_from_directory(
        path,
        image_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        validation_split=validation_split,
        subset=subset,
        seed=seed,
    )


def _write_class_names(class_names, path):
    """Write one class name per line to *path*, in label-index order.

    The file is overwritten (not appended to) so that line ``i`` always
    names label ``i``, even after repeated runs.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(name + "\n" for name in class_names)


def main(data_dir="./data", output_dir="./output", validation_split=0.2, seed=123):
    """Train the classifier on *data_dir* and save its artifacts under *output_dir*.

    Writes ``descriptions.txt`` (class names) and
    ``keras/model_<EPOCHS_NUM>.h5``. Pass ``validation_split=None`` to train
    and evaluate on the full dataset, which was the original behavior.
    """
    # Both output files live under output_dir; create the tree up front so
    # the file writes below cannot fail on a missing directory.
    os.makedirs(os.path.join(output_dir, "keras"), exist_ok=True)

    if validation_split:
        # Disjoint subsets: the shared seed makes both calls agree on which
        # files belong to which split.
        train_data = load_dataset(
            data_dir, validation_split=validation_split, subset="training", seed=seed
        )
        val_data = load_dataset(
            data_dir, validation_split=validation_split, subset="validation", seed=seed
        )
    else:
        # No hold-out: metrics below measure fit to the training data only.
        train_data = val_data = load_dataset(data_dir)

    # Label i corresponds to class_names[i]. Record the names in that order
    # rather than in os.listdir order, which is arbitrary and can include
    # plain files.
    class_names = train_data.class_names
    for name in class_names:
        print(name)
    _write_class_names(class_names, os.path.join(output_dir, "descriptions.txt"))

    model = create_model(len(class_names))
    model.summary()
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    model.fit(train_data, epochs=EPOCHS_NUM, validation_data=val_data)
    model.save(os.path.join(output_dir, "keras", "model_" + str(EPOCHS_NUM) + ".h5"))

    # Evaluate on held-out data so the numbers reflect generalization.
    val_loss, val_acc = model.evaluate(val_data)
    print("Validation Loss:", val_loss)
    print("Validation Accuracy:", val_acc)
    print("Dataset loaded and model trained")


if __name__ == "__main__":
    main()