Skip to content

Commit

Permalink
Another refactorization
Browse files Browse the repository at this point in the history
  • Loading branch information
Carlota de la Vega committed Jun 18, 2024
1 parent 19e0369 commit be6814d
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 29 deletions.
Binary file removed data/nums/drawn_number_5.png
Binary file not shown.
71 changes: 47 additions & 24 deletions src/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,7 @@ def _execute_command(self, request):
response = self._generate_number(text)

elif command == "generate_image":
response = {
"status": "error",
"message": "Command not implemented",
}

logging.error("Command not implemented: %s", command)
response = self._generate_image(text)

else:
response = {
Expand All @@ -60,11 +55,46 @@ def _execute_command(self, request):
return response

def _generate_number(self, text):
try:
generator, discriminator = builders.build_models()
cond_gan = builders.build_conditional_gan(generator, discriminator)
cond_gan = self.build_and_load("nums")
img = self.draw("nums", cond_gan)

response = {
"status": "success",
"message": "Image generated successfully",
"image": img,
}

cond_gan = utils.load_model_with_weights("models/cgan_nums.weights.h5", cond_gan)
logging.info("Image generated successfully: %s", text)

return response

def _generate_image(self, text):
response = {
"status": "error",
"message": "Command not implemented",
}

logging.error("Command not implemented: %s", text)

return response

def build_and_load(self, model_type):
try:
if model_type == "nums":
dataset = utils.load_dataset("mnist")
generator, discriminator = builders.build_models()
cond_gan = builders.build_conditional_gan(generator, discriminator)
model_path = "cgan_nums.weights.h5"

elif model_type == "images":
dataset = utils.load_dataset("cifar10")
generator, discriminator = builders.build_models()
cond_gan = builders.build_conditional_gan(generator, discriminator)
model_path = "cgan_images.weights.h5"

utils.train_model(dataset, cond_gan)
utils.save_model_weights(cond_gan, model_path)
return cond_gan

except FileNotFoundError as e:
response = {
Expand All @@ -75,19 +105,16 @@ def _generate_number(self, text):

return response

def draw(self, draw_type, cond_gan):
try:
img = drawing.draw_number(text, cond_gan)
img = img.tolist()
if draw_type == "nums":
img = drawing.draw_number("5", cond_gan)

response = {
"status": "success",
"message": "Image generated successfully",
"image": img,
}
elif draw_type == "images":
img = drawing.draw_image("cat", cond_gan)

logging.info("Image generated successfully: %s", text)

return response
img = img.tolist()
return img

except Exception as e:
response = {
Expand All @@ -99,9 +126,5 @@ def _generate_number(self, text):

return response

def _generate_image(self, text):
# TODO: Implement this method
pass

def _send_response(self, response):
self.socket.sendall(json.dumps(response).encode())
8 changes: 8 additions & 0 deletions src/images/training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from src import utils
from src.nums import builders

dataset = utils.load_dataset("cifar10")
generator, discriminator = builders.build_models()
cond_gan = builders.build_conditional_gan(generator, discriminator)
utils.train_model(dataset, cond_gan)
utils.save_model_weights(cond_gan, "cgan_images.weights.h5")
2 changes: 1 addition & 1 deletion src/nums/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
generator, discriminator = builders.build_models()
cond_gan = builders.build_conditional_gan(generator, discriminator)
utils.train_model(dataset, cond_gan)
utils.save_model_weights(cond_gan, "cond_weights.weights.h5")
utils.save_model_weights(cond_gan, "cgan_nums.weights.h5")
9 changes: 5 additions & 4 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import MagicMock

from src import utils
from src.nums import builders


class TestTraining(unittest.TestCase):
Expand All @@ -12,15 +13,15 @@ def test_load_dataset(self):
self.assertEqual(dataset, "mocked_dataset")

def test_build_models(self):
utils.build_models = MagicMock(return_value=("mocked_generator", "mocked_discriminator"))
generator, discriminator = utils.build_models()
builders.build_models = MagicMock(return_value=("mocked_generator", "mocked_discriminator"))
generator, discriminator = builders.build_models()

self.assertEqual(generator, "mocked_generator")
self.assertEqual(discriminator, "mocked_discriminator")

def test_build_conditional_gan(self):
utils.build_conditional_gan = MagicMock(return_value="mocked_cond_gan")
cond_gan = utils.build_conditional_gan("mocked_generator", "mocked_discriminator")
builders.build_conditional_gan = MagicMock(return_value="mocked_cond_gan")
cond_gan = builders.build_conditional_gan("mocked_generator", "mocked_discriminator")

self.assertEqual(cond_gan, "mocked_cond_gan")

Expand Down

0 comments on commit be6814d

Please sign in to comment.