From f851c33abff0f45b2986e3fb3cd42b51a13aeb50 Mon Sep 17 00:00:00 2001 From: Carlota de la Vega Date: Tue, 18 Jun 2024 11:33:04 +0200 Subject: [PATCH] Refactorization --- data/nums/drawn_number_0.png | Bin 0 -> 329 bytes data/nums/drawn_number_3.png | Bin 0 -> 378 bytes data/nums/drawn_number_4.png | Bin 0 -> 319 bytes data/nums/drawn_number_5.png | Bin 0 -> 349 bytes data/nums/drawn_number_6.png | Bin 0 -> 299 bytes data/nums/drawn_number_7.png | Bin 0 -> 322 bytes data/nums/drawn_number_8.png | Bin 0 -> 349 bytes data/nums/drawn_number_9.png | Bin 0 -> 314 bytes src/drawing.py | 24 +++++++++++- src/handler.py | 70 +++++++++-------------------------- src/nums/builders.py | 22 ++++------- src/{ => nums}/training.py | 0 src/server.py | 27 +------------- src/utils.py | 67 ++++++++++----------------------- 14 files changed, 69 insertions(+), 141 deletions(-) create mode 100644 data/nums/drawn_number_0.png create mode 100644 data/nums/drawn_number_3.png create mode 100644 data/nums/drawn_number_4.png create mode 100644 data/nums/drawn_number_5.png create mode 100644 data/nums/drawn_number_6.png create mode 100644 data/nums/drawn_number_7.png create mode 100644 data/nums/drawn_number_8.png create mode 100644 data/nums/drawn_number_9.png rename src/{ => nums}/training.py (100%) diff --git a/data/nums/drawn_number_0.png b/data/nums/drawn_number_0.png new file mode 100644 index 0000000000000000000000000000000000000000..b171750d5385024b44d687c4db39f112d559c12a GIT binary patch literal 329 zcmV-P0k-~$P)~UthHT*v60^vPd&hGTJ+a(!3-+LoQ;w zkJ$BgNqzS)8#d<>z6K>JW5Ek!76@EPsUj>mAzL(H4<-3`P<6w_bKta1n!Ext9=7fZ zm@y>=uEC}wr;Y-ZP52a-Aez9Bre^4Gt=urOd@U|?`z;7GXt?574_=vn4?YvzCH z*N`mn{QK|UpFe+K5!QobC<6n7^X1R*xcoDX8!isDIM?dkhEV=p%mmjGG Y0GYpjQ{`u literal 0 HcmV?d00001 diff --git a/data/nums/drawn_number_4.png b/data/nums/drawn_number_4.png new file mode 100644 index 0000000000000000000000000000000000000000..c33027ecb765a7bd11e557c6862dc31fcdefbce4 GIT binary patch literal 319 zcmV-F0l@x=P)NO_r718E*>B>WhLWPu<5*G)Ti(34 zu-w=Pq_6rfnCe2>|Jmi(Z~z}+j&ZiHQEqB0Qjo{EmyRccz4I>aUgbLfNZ&s`P3|?e R={^7e002ovPDHLkV1jowjLiT5 literal 0 HcmV?d00001 diff --git a/data/nums/drawn_number_5.png b/data/nums/drawn_number_5.png new file mode 100644 index 0000000000000000000000000000000000000000..ef7ace83759149fda25d8411e6bb116727c76a96 GIT binary patch literal 349 zcmeAS@N?(olHy`uVBq!ia0vp^G9b(WBpAZe8ax;n7*#x7978JRBqvCuGzfN{U8=5` z_C4iXZqBQadH;TYe}8|!@|_zh>LBLtR)@$nW;%bqzb|JrW@ZK=<>fN_Etr|PWBQKw zJ=}hJ>3*Jr@2BQGIeTNK#5^r!mB&vcW`u9v{Qt?ny{d2BPuf-*C+oJXjQ;<(`umUL z`|J15%X#LoJ@Ci#{qa>bg1w);?CsV&O}F{+{{C|VKh|E6iQk^PFG%tUusSuZv2k<$ zL#M<@XTfFXGo(`QgPe1Ie*Epcy?@KXn6)^H-zGOlRGLNp<-W`;Tz9s*MUumZ2MA=W z|LwhUSjr8^%)9?T_n6|*M+Lv`ulHCJ)W^+V_xHDXujp)<$J6Wo|NFaBwKr7x^&|EB s^|n(S!}^YhoqI65__@8_@=Kfy3|5b`qkFYI?Lgt?>FVdQ&MBb@05IgBr~m)} literal 0 HcmV?d00001 diff --git a/data/nums/drawn_number_6.png b/data/nums/drawn_number_6.png new file mode 100644 index 0000000000000000000000000000000000000000..823e332900fd693e6bd25439f4f868956991c3bb GIT binary patch literal 299 zcmV+`0o4A9P)m zpD#+dAA!z} x{{8#+??V9-h8zQz_wHAZdP0O*Fr|hD001SQUY`N9iK74j002ovPDHLkV1l_Mkj?-A literal 0 HcmV?d00001 diff --git a/data/nums/drawn_number_7.png b/data/nums/drawn_number_7.png new file mode 100644 index 0000000000000000000000000000000000000000..4af54d539d0535ef6ab059c0aff2aaf3d510a709 GIT binary patch literal 322 zcmV-I0lof-P)>ox@iPItNBmZvLlbv zi*C#{d;~U5G_E!zNtF)#y|sJ(zFP1l$t(IgJj)ItE=Q6y{>*ulwclDbo!DjJEwv#R zt{b*{w$_}q=8?D>M_zbxMjkn+a literal 0 HcmV?d00001 diff --git a/data/nums/drawn_number_8.png b/data/nums/drawn_number_8.png new file mode 100644 index 0000000000000000000000000000000000000000..87ab98f870c8d1840e520ff70551eb45d275c310 GIT binary patch literal 349 zcmV-j0iyniP)M#@+I$#Fz`dI(CTxrs6~-JvMuAE0J2 zXV<`fzs?Ne*ny+orv>*Gq}GcCXB_E$)AHtG0IeMh>h!TbZnHn-3$y=*Xhg=u1poj5 M07*qoM6N<$f?xiL2><{9 literal 0 HcmV?d00001 diff --git a/src/drawing.py b/src/drawing.py index 4f2810b..6d1aad7 100644 --- a/src/drawing.py +++ b/src/drawing.py @@ -3,12 +3,13 @@ import numpy as np from src import utils +from src.nums import builders def draw_number(number, cond_gan): number = int(number) noise = np.random.normal(size=(1, utils.latent_dim)) - label = keras.utils.to_categorical([number], utils.num_classes) + label = keras.utils.to_categorical([number], builders.num_classes) label = label.astype("float32") noise_and_label = np.concatenate([noise, label], 1) @@ -26,3 +27,24 @@ def draw_number(number, cond_gan): imageio.imwrite(filename, generated_image) return generated_image + +def draw_image(description, cond_gan): + noise = np.random.normal(size=(1, utils.latent_dim)) + label = keras.utils.to_categorical([description], builders.num_classes) + label = label.astype("float32") + + noise_and_label = np.concatenate([noise, label], 1) + generated_image = cond_gan.generator.predict(noise_and_label) + + generated_image = np.squeeze(generated_image) + + if len(generated_image.shape) > 2: + generated_image = np.mean(generated_image, axis=-1) + + generated_image = np.clip(generated_image * 255, 0, 255) + generated_image = generated_image.astype(np.uint8) + + filename = f"./data/images/drawn_image_{description}.png" + imageio.imwrite(filename, generated_image) + + return generated_image \ No newline at end of file diff --git a/src/handler.py b/src/handler.py index 1819230..4875d74 100644 --- a/src/handler.py +++ b/src/handler.py @@ -2,29 +2,15 @@ import logging from src import drawing, utils +from src.nums import builders # pylint: disable=too-few-public-methods class Handler: - """Class that handles client requests. - - This class handles the requests sent by the client and executes the - corresponding actions. - - Attributes: - socket (socket): Client socket. - """ - def __init__(self, socket): - """Class constructor. - - Args: - socket: Client socket. - """ self.socket = socket def handle(self): - """Method that handles the client's request.""" with self.socket: data = self._receive_data() request = self._process_request(data) @@ -32,22 +18,9 @@ def handle(self): self._send_response(response) def _receive_data(self): - """Method that receives the data sent by the client. - - Returns: - str: Data sent by the client. - """ return self.socket.recv(1024).decode() def _process_request(self, data): - """Method that processes the data received from the client. - - Args: - data (str): Data sent by the client. - - Returns: - dict: Request data. - """ try: return json.loads(data) @@ -62,19 +35,20 @@ def _process_request(self, data): return response def _execute_command(self, request): - """Method that executes the command sent by the client. - - Args: - request (dict): Request data. - - Returns: - dict: Response data. - """ command = request.get("command") text = request.get("text", "") if command == "generate_number": response = self._generate_number(text) + + elif command == "generate_image": + response = { + "status": "error", + "message": "Command not implemented", + } + + logging.error("Command not implemented: %s", command) + else: response = { "status": "error", @@ -86,16 +60,12 @@ def _execute_command(self, request): return response def _generate_number(self, text): - """Method that generates an image of a number. + try: + generator, discriminator = builders.build_models() + cond_gan = builders.build_conditional_gan(generator, discriminator) - Args: - text (str): Number to generate. + cond_gan = utils.load_model_with_weights("models/cgan_nums.weights.h5", cond_gan) - Returns: - dict: Response data. - """ - try: - cond_gan = utils.load_model_with_weights("models/cgan_nums.weights.h5") except FileNotFoundError as e: response = { "status": "error", @@ -129,13 +99,9 @@ def _generate_number(self, text): return response - def _send_response(self, response): - """Method that sends the response to the client. + def _generate_image(self, text): + # TODO: Implement this method + pass - Args: - response (dict): Response data. - - Returns: - dict: Response data. - """ + def _send_response(self, response): self.socket.sendall(json.dumps(response).encode()) diff --git a/src/nums/builders.py b/src/nums/builders.py index 5a0ee7a..de1951b 100644 --- a/src/nums/builders.py +++ b/src/nums/builders.py @@ -3,6 +3,10 @@ from src import utils from src.nums import cgan +num_channels = 1 +num_classes = 10 +image_size = 28 + def build_models(): """ @@ -13,8 +17,8 @@ def build_models(): keras.Model: Discriminator model. """ # - - - - - - - Calculate the number of input channels - - - - - - - - gen_channels = utils.latent_dim + utils.num_classes - dis_channels = utils.num_channels + utils.num_classes + gen_channels = utils.latent_dim + num_classes + dis_channels = num_channels + num_classes # - - - - - - - Generator - - - - - - - generator = keras.Sequential( @@ -50,22 +54,12 @@ def build_models(): def build_conditional_gan(generator, discriminator): - """ - Builds the conditional GAN (cGAN) model. - - Args: - generator (keras.Model): Generator model. - discriminator (keras.Model): Discriminator model. - - Returns: - conditionalGAN: Compiled cGAN model. - """ config = cgan.GANConfig( discriminator=discriminator, generator=generator, latent_dim=utils.latent_dim, - image_size=utils.image_size, - num_classes=utils.num_classes, + image_size=image_size, + num_classes=num_classes, ) cond_gan = cgan.ConditionalGAN( diff --git a/src/training.py b/src/nums/training.py similarity index 100% rename from src/training.py rename to src/nums/training.py diff --git a/src/server.py b/src/server.py index 6c5fb8c..64bad55 100644 --- a/src/server.py +++ b/src/server.py @@ -5,44 +5,19 @@ class Server: - """Class that represents a server that listens on a port and handles - client connections. - - This class is responsible for starting a server that listens on a specific - port and handles incoming client connections. - - Attributes: - host (str): IP address or host name where the server will listen. - port (int): Port where the server will listen. - """ - def __init__(self, host, port): - """Class constructor. - - Args: - host (str): IP address or host name where the server will listen. - port (int): Port where the server will listen. - """ self.host = host self.port = port def start(self): - """Method that starts the server.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket: server_socket.bind((self.host, self.port)) server_socket.listen() print(f"Server listening on port {self.port}") while True: client_socket, _ = server_socket.accept() - threading.Thread( - target=self.client_handler, args=(client_socket,) - ).start() + threading.Thread(target=self.client_handler, args=(client_socket,)).start() def client_handler(self, client_socket): - """Method that handles a connection with a client. - - Args: - client_socket (socket): Client socket. - """ handler = Handler(client_socket) handler.handle() diff --git a/src/utils.py b/src/utils.py index 81dfd6a..b52d6f5 100644 --- a/src/utils.py +++ b/src/utils.py @@ -4,73 +4,44 @@ import numpy as np import tensorflow as tf -from src.nums import builders - batch_size = 64 -num_channels = 1 -num_classes = 10 -image_size = 28 latent_dim = 128 -def load_dataset(): - """ - Loads the MNIST dataset, preprocesses it, and returns it as a TensorFlow dataset. +def load_dataset(dataset_name): + dataset_dict = { + "mnist": (keras.datasets.mnist, 10, 1, 28), + "cifar10": (keras.datasets.cifar10, 10, 3, 32), + "cifar100": (keras.datasets.cifar100, 100, 3, 32), + } + + if dataset_name not in dataset_dict: + raise ValueError("Invalid dataset name") + + dataset, num_classes, num_channels, image_size = dataset_dict[dataset_name] + (x_train, y_train), (x_test, y_test) = dataset.load_data() - Returns: - tf.data.Dataset: Dataset containing the preprocessed MNIST images and labels. - """ - (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() - all_digits = np.concatenate([x_train, x_test]) - all_labels = np.concatenate([y_train, y_test]) + all_images = np.concatenate([x_train, x_test]).astype("float32") / 255.0 + all_labels = keras.utils.to_categorical(np.concatenate([y_train, y_test]), num_classes) - all_digits = all_digits.astype("float32") / 255.0 - all_digits = np.reshape(all_digits, (-1, 28, 28, 1)) - all_labels = keras.utils.to_categorical(all_labels, 10) + all_images = np.reshape(all_images, (-1, image_size, image_size, num_channels)) - dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels)) + dataset = tf.data.Dataset.from_tensor_slices((all_images, all_labels)) dataset = dataset.shuffle(buffer_size=1024).batch(batch_size) return dataset def train_model(dataset, cond_gan): - """ - Trains the conditional GAN (cGAN) model. - - Args: - dataset (tf.data.Dataset): Dataset containing the training images and labels. - cond_gan (conditionalGAN): Compiled cGAN model. - """ cond_gan.fit(dataset, epochs=50) def save_model_weights(cond_gan, filename): - """ - Saves the weights of the conditional GAN (cGAN) model. - - Args: - cond_gan (conditionalGAN): Compiled cGAN model. - filename (str): Filepath to save the model weights. - """ if os.path.exists(filename): os.remove(filename) cond_gan.save_weights(filename) -def load_model_with_weights(filename): - """ - Loads the conditional GAN (cGAN) model with saved weights. - - Args: - filename (str): Filepath to the saved model weights. - - Returns: - conditionalGAN: cGAN model with loaded weights. - """ - generator, discriminator = builders.build_models() - - new_cond_gan = builders.build_conditional_gan(generator, discriminator) - - new_cond_gan.load_weights(filename) - return new_cond_gan +def load_model_with_weights(filename, cond_gan): + cond_gan.load_weights(filename) + return cond_gan