diff --git a/src/main.py b/src/main.py index ce57ac4..edf4064 100644 --- a/src/main.py +++ b/src/main.py @@ -1,4 +1,5 @@ from src import server -server = server.Server("localhost", 12345) -server.start() +if __name__ == "__main__": + server = server.Server("localhost", 12345) + server.start() diff --git a/src/utils.py b/src/utils.py index 457c12d..d4ec62e 100644 --- a/src/utils.py +++ b/src/utils.py @@ -90,19 +90,23 @@ def build_conditional_gan(generator, discriminator): Returns: conditionalGAN: Compiled cGAN model. """ - cond_gan = cgan.ConditionalGAN( + config = cgan.GANConfig( discriminator=discriminator, generator=generator, latent_dim=latent_dim, image_size=image_size, num_classes=num_classes, ) - cond_gan.compile( + + cond_gan = cgan.ConditionalGAN( + config=config, d_optimizer=keras.optimizers.Adam(learning_rate=0.0003), g_optimizer=keras.optimizers.Adam(learning_rate=0.0003), loss_fn=keras.losses.BinaryCrossentropy(from_logits=True), ) + cond_gan.compile() + return cond_gan diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_handler.py b/tests/test_handler.py new file mode 100644 index 0000000..17304ce --- /dev/null +++ b/tests/test_handler.py @@ -0,0 +1,65 @@ +import os +import unittest +from unittest.mock import MagicMock + +from src.handler import Handler + + +class TestHandler(unittest.TestCase): + def test_execute_command_generate_number_success(self): + # Arrange + request = { + "command": "generate_number", + "text": "5" + } + expected_file_path = "./data/nums/drawn_number_5.png" + handler = Handler(MagicMock()) + + # Act + handler._execute_command(request) + + # Assert + self.assertTrue(os.path.exists(expected_file_path)) + + # Clean up + os.remove(expected_file_path) + + + def test_execute_command_generate_number_error(self): + # Arrange + request = { + "command": "generate_number", + "text": "abc" + } + expected_response = { + "status": "error", + "message": "Error generating the image: invalid literal for int() with base 10: 'abc'" + } + handler = Handler(MagicMock()) + + # Act + response = handler._execute_command(request) + + # Assert + self.assertEqual(response, expected_response) + + def test_execute_command_unknown_command(self): + # Arrange + request = { + "command": "unknown_command", + "text": "5" + } + expected_response = { + "status": "error", + "message": "Unknown command: unknown_command" + } + handler = Handler(MagicMock()) + + # Act + response = handler._execute_command(request) + + # Assert + self.assertEqual(response, expected_response) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file