-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
69 lines (61 loc) · 2.05 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
# Make TensorFlow logs less verbose
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
from tensorflow.keras.models import Sequential, model_from_json
from tensorflow.keras.layers import Conv2D, MaxPool2D,Flatten,Dense
from tensorflow.keras.optimizers import SGD
import numpy as np
import pandas as pd
from joblib import load, dump
import base64
import dill
import tempfile
from io import BytesIO
def define_model(INPUT_SHAPE, NUM_CLASSES, *, learning_rate=0.01, momentum=0.9) -> Sequential:
    """
    Define and compile a small CNN classifier.

    Architecture:
        Conv2D(32, 3x3, ReLU, he_uniform) -> MaxPool2D(2x2) -> Flatten
        -> Dense(100, ReLU, he_uniform) -> Dense(NUM_CLASSES, softmax)
    compiled with SGD (with momentum), categorical cross-entropy loss and
    an accuracy metric.

    Parameters
    ------------
    INPUT_SHAPE: tuple
        Shape of a single input sample, excluding the batch dimension.
        With Conv2D's default channels-last format this is
        (height, width, channels).
    NUM_CLASSES: int
        Number of output classes (width of the final softmax layer).
    learning_rate: float, optional (keyword-only)
        SGD learning rate. Defaults to 0.01.
    momentum: float, optional (keyword-only)
        SGD momentum. Defaults to 0.9.

    Returns
    ------------
    model: Sequential
        The compiled, untrained model.
    """
    model = Sequential([
        # An explicit Input layer replaces passing `input_shape=` to the first
        # layer, which Keras 3 deprecates for Sequential models.
        tf.keras.Input(shape=INPUT_SHAPE),
        Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform'),
        MaxPool2D((2, 2)),
        Flatten(),
        Dense(100, activation='relu', kernel_initializer='he_uniform'),
        Dense(NUM_CLASSES, activation='softmax'),
    ])
    # 'categorical_crossentropy' expects one-hot encoded labels.
    opt = SGD(learning_rate=learning_rate, momentum=momentum)
    model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
    return model
def ModelBase64Encoder(model_weights):
    """
    Serialize `model_weights` with dill and return it base64-encoded.

    Adapted from:
    https://stackoverflow.com/questions/60567679/save-keras-model-weights-directly-to-bytes-memory

    Parameters
    ------------
    model_weights: object
        Any dill-picklable object (presumably the list returned by
        `model.get_weights()` -- confirm against callers).

    Returns
    ------------
    bytes
        Base64-encoded bytes of the dill pickle stream.
    """
    # dill.dumps pickles directly into a bytes object; this is the same
    # stream obtained by dumping into a BytesIO buffer and reading it back.
    pickled = dill.dumps(model_weights)
    return base64.b64encode(pickled)
def ModelBase64Decoder(model_bytes):
    """
    Decode a base64 payload and deserialize it with joblib.

    Adapted from:
    https://stackoverflow.com/questions/60567679/save-keras-model-weights-directly-to-bytes-memory

    Parameters
    ------------
    model_bytes: bytes or str
        Base64-encoded serialized object (for example, the output of
        ModelBase64Encoder).

    Returns
    ------------
    object
        The deserialized object.

    Warning
    ------------
    joblib.load unpickles the data, and unpickling can execute arbitrary
    code. Only call this on payloads that come from a trusted source.
    """
    loaded_binary = base64.b64decode(model_bytes)
    # joblib.load accepts any seekable file object, so an in-memory buffer
    # avoids a round-trip through a temporary file on disk. The `with` block
    # guarantees the buffer is closed even if deserialization raises.
    with BytesIO(loaded_binary) as buffer:
        return load(buffer)