Skip to content

Commit

Permalink
Refactorization
Browse files Browse the repository at this point in the history
  • Loading branch information
Carlota de la Vega committed Jun 18, 2024
1 parent 9c5c67c commit f851c33
Show file tree
Hide file tree
Showing 14 changed files with 69 additions and 141 deletions.
Binary file added data/nums/drawn_number_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/nums/drawn_number_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/nums/drawn_number_4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/nums/drawn_number_5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/nums/drawn_number_6.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/nums/drawn_number_7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/nums/drawn_number_8.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/nums/drawn_number_9.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
24 changes: 23 additions & 1 deletion src/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import numpy as np

from src import utils
from src.nums import builders


def draw_number(number, cond_gan):
number = int(number)
noise = np.random.normal(size=(1, utils.latent_dim))
label = keras.utils.to_categorical([number], utils.num_classes)
label = keras.utils.to_categorical([number], builders.num_classes)
label = label.astype("float32")

noise_and_label = np.concatenate([noise, label], 1)
Expand All @@ -26,3 +27,24 @@ def draw_number(number, cond_gan):
imageio.imwrite(filename, generated_image)

return generated_image

def draw_image(description, cond_gan):
noise = np.random.normal(size=(1, utils.latent_dim))
label = keras.utils.to_categorical([description], builders.num_classes)
label = label.astype("float32")

noise_and_label = np.concatenate([noise, label], 1)
generated_image = cond_gan.generator.predict(noise_and_label)

generated_image = np.squeeze(generated_image)

if len(generated_image.shape) > 2:
generated_image = np.mean(generated_image, axis=-1)

generated_image = np.clip(generated_image * 255, 0, 255)
generated_image = generated_image.astype(np.uint8)

filename = f"./data/images/drawn_image_{description}.png"
imageio.imwrite(filename, generated_image)

return generated_image
70 changes: 18 additions & 52 deletions src/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,25 @@
import logging

from src import drawing, utils
from src.nums import builders


# pylint: disable=too-few-public-methods
class Handler:
"""Class that handles client requests.
This class handles the requests sent by the client and executes the
corresponding actions.
Attributes:
socket (socket): Client socket.
"""

def __init__(self, socket):
"""Class constructor.
Args:
socket: Client socket.
"""
self.socket = socket

def handle(self):
"""Method that handles the client's request."""
with self.socket:
data = self._receive_data()
request = self._process_request(data)
response = self._execute_command(request)
self._send_response(response)

def _receive_data(self):
"""Method that receives the data sent by the client.
Returns:
str: Data sent by the client.
"""
return self.socket.recv(1024).decode()

def _process_request(self, data):
"""Method that processes the data received from the client.
Args:
data (str): Data sent by the client.
Returns:
dict: Request data.
"""
try:
return json.loads(data)

Expand All @@ -62,19 +35,20 @@ def _process_request(self, data):
return response

def _execute_command(self, request):
"""Method that executes the command sent by the client.
Args:
request (dict): Request data.
Returns:
dict: Response data.
"""
command = request.get("command")
text = request.get("text", "")

if command == "generate_number":
response = self._generate_number(text)

elif command == "generate_image":
response = {
"status": "error",
"message": "Command not implemented",
}

logging.error("Command not implemented: %s", command)

else:
response = {
"status": "error",
Expand All @@ -86,16 +60,12 @@ def _execute_command(self, request):
return response

def _generate_number(self, text):
"""Method that generates an image of a number.
try:
generator, discriminator = builders.build_models()
cond_gan = builders.build_conditional_gan(generator, discriminator)

Args:
text (str): Number to generate.
cond_gan = utils.load_model_with_weights("models/cgan_nums.weights.h5", cond_gan)

Returns:
dict: Response data.
"""
try:
cond_gan = utils.load_model_with_weights("models/cgan_nums.weights.h5")
except FileNotFoundError as e:
response = {
"status": "error",
Expand Down Expand Up @@ -129,13 +99,9 @@ def _generate_number(self, text):

return response

def _send_response(self, response):
"""Method that sends the response to the client.
def _generate_image(self, text):
# TODO: Implement this method
pass

Args:
response (dict): Response data.
Returns:
dict: Response data.
"""
def _send_response(self, response):
self.socket.sendall(json.dumps(response).encode())
22 changes: 8 additions & 14 deletions src/nums/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from src import utils
from src.nums import cgan

num_channels = 1
num_classes = 10
image_size = 28


def build_models():
"""
Expand All @@ -13,8 +17,8 @@ def build_models():
keras.Model: Discriminator model.
"""
# - - - - - - - Calculate the number of input channels - - - - - - -
gen_channels = utils.latent_dim + utils.num_classes
dis_channels = utils.num_channels + utils.num_classes
gen_channels = utils.latent_dim + num_classes
dis_channels = num_channels + num_classes

# - - - - - - - Generator - - - - - - -
generator = keras.Sequential(
Expand Down Expand Up @@ -50,22 +54,12 @@ def build_models():


def build_conditional_gan(generator, discriminator):
"""
Builds the conditional GAN (cGAN) model.
Args:
generator (keras.Model): Generator model.
discriminator (keras.Model): Discriminator model.
Returns:
conditionalGAN: Compiled cGAN model.
"""
config = cgan.GANConfig(
discriminator=discriminator,
generator=generator,
latent_dim=utils.latent_dim,
image_size=utils.image_size,
num_classes=utils.num_classes,
image_size=image_size,
num_classes=num_classes,
)

cond_gan = cgan.ConditionalGAN(
Expand Down
File renamed without changes.
27 changes: 1 addition & 26 deletions src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,44 +5,19 @@


class Server:
"""Class that represents a server that listens on a port and handles
client connections.
This class is responsible for starting a server that listens on a specific
port and handles incoming client connections.
Attributes:
host (str): IP address or host name where the server will listen.
port (int): Port where the server will listen.
"""

def __init__(self, host, port):
"""Class constructor.
Args:
host (str): IP address or host name where the server will listen.
port (int): Port where the server will listen.
"""
self.host = host
self.port = port

def start(self):
"""Method that starts the server."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
server_socket.bind((self.host, self.port))
server_socket.listen()
print(f"Server listening on port {self.port}")
while True:
client_socket, _ = server_socket.accept()
threading.Thread(
target=self.client_handler, args=(client_socket,)
).start()
threading.Thread(target=self.client_handler, args=(client_socket,)).start()

def client_handler(self, client_socket):
"""Method that handles a connection with a client.
Args:
client_socket (socket): Client socket.
"""
handler = Handler(client_socket)
handler.handle()
67 changes: 19 additions & 48 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,73 +4,44 @@
import numpy as np
import tensorflow as tf

from src.nums import builders

batch_size = 64
num_channels = 1
num_classes = 10
image_size = 28
latent_dim = 128


def load_dataset():
"""
Loads the MNIST dataset, preprocesses it, and returns it as a TensorFlow dataset.
def load_dataset(dataset_name):
dataset_dict = {
"mnist": (keras.datasets.mnist, 10, 1, 28),
"cifar10": (keras.datasets.cifar10, 10, 3, 32),
"cifar100": (keras.datasets.cifar100, 100, 3, 32),
}

if dataset_name not in dataset_dict:
raise ValueError("Invalid dataset name")

dataset, num_classes, num_channels, image_size = dataset_dict[dataset_name]
(x_train, y_train), (x_test, y_test) = dataset.load_data()

Returns:
tf.data.Dataset: Dataset containing the preprocessed MNIST images and labels.
"""
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_labels = np.concatenate([y_train, y_test])
all_images = np.concatenate([x_train, x_test]).astype("float32") / 255.0
all_labels = keras.utils.to_categorical(np.concatenate([y_train, y_test]), num_classes)

all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
all_labels = keras.utils.to_categorical(all_labels, 10)
all_images = np.reshape(all_images, (-1, image_size, image_size, num_channels))

dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels))
dataset = tf.data.Dataset.from_tensor_slices((all_images, all_labels))
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

return dataset


def train_model(dataset, cond_gan):
"""
Trains the conditional GAN (cGAN) model.
Args:
dataset (tf.data.Dataset): Dataset containing the training images and labels.
cond_gan (conditionalGAN): Compiled cGAN model.
"""
cond_gan.fit(dataset, epochs=50)


def save_model_weights(cond_gan, filename):
"""
Saves the weights of the conditional GAN (cGAN) model.
Args:
cond_gan (conditionalGAN): Compiled cGAN model.
filename (str): Filepath to save the model weights.
"""
if os.path.exists(filename):
os.remove(filename)
cond_gan.save_weights(filename)


def load_model_with_weights(filename):
"""
Loads the conditional GAN (cGAN) model with saved weights.
Args:
filename (str): Filepath to the saved model weights.
Returns:
conditionalGAN: cGAN model with loaded weights.
"""
generator, discriminator = builders.build_models()

new_cond_gan = builders.build_conditional_gan(generator, discriminator)

new_cond_gan.load_weights(filename)
return new_cond_gan
def load_model_with_weights(filename, cond_gan):
cond_gan.load_weights(filename)
return cond_gan

0 comments on commit f851c33

Please sign in to comment.