From 415d607fcada25b6e55a5ea28f0cf69839689d83 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Mon, 2 Dec 2024 15:52:02 -0800 Subject: [PATCH] Add RMQ unit test coverage Include `routing_key` in LLM responses to associate inputs/responses --- neon_llm_core/rmq.py | 21 +++++--- tests/test_rmq.py | 124 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 130 insertions(+), 15 deletions(-) diff --git a/neon_llm_core/rmq.py b/neon_llm_core/rmq.py index 7a31df8..6d87951 100644 --- a/neon_llm_core/rmq.py +++ b/neon_llm_core/rmq.py @@ -107,15 +107,17 @@ def model(self) -> NeonLLM: pass @create_mq_callback() - def handle_request(self, body: dict): + def handle_request(self, body: dict) -> Thread: """ Handles ask requests (response to prompt) from MQ to LLM :param body: request body (dict) """ # Handle this asynchronously so multiple subminds can be handled # concurrently - Thread(target=self._handle_request_async, args=(body,), - daemon=True).start() + t = Thread(target=self._handle_request_async, args=(body,), + daemon=True) + t.start() + return t def _handle_request_async(self, request: dict): message_id = request["message_id"] @@ -133,7 +135,8 @@ def _handle_request_async(self, request: dict): response = ('Sorry, but I cannot respond to your message at the ' 'moment, please try again later') api_response = LLMProposeResponse(message_id=message_id, - response=response) + response=response, + routing_key=routing_key) LOG.info(f"Sending response: {response}") self.send_message(request_data=api_response.model_dump(), queue=routing_key) @@ -154,17 +157,18 @@ def handle_score_request(self, body: dict): persona = body.get("persona", {}) if not responses: - sorted_answer_indexes = [] + sorted_answer_idx = [] else: try: - sorted_answer_indexes = self.model.get_sorted_answer_indexes( + sorted_answer_idx = self.model.get_sorted_answer_indexes( question=query, answers=responses, persona=persona) except ValueError as err: LOG.error(f'ValueError={err}') - sorted_answer_indexes = [] + sorted_answer_idx = [] api_response = LLMVoteResponse(message_id=message_id, - sorted_answer_indexes=sorted_answer_indexes) + routing_key=routing_key, + sorted_answer_indexes=sorted_answer_idx) self.send_message(request_data=api_response.model_dump(), queue=routing_key) LOG.info(f"Handled score request for message_id={message_id}") @@ -200,6 +204,7 @@ def handle_opinion_request(self, body: dict): "an opinion on this topic") api_response = LLMDiscussResponse(message_id=message_id, + routing_key=routing_key, opinion=opinion) self.send_message(request_data=api_response.model_dump(), queue=routing_key) diff --git a/tests/test_rmq.py b/tests/test_rmq.py index 2b3758e..bcdb4f3 100644 --- a/tests/test_rmq.py +++ b/tests/test_rmq.py @@ -30,6 +30,7 @@ from unittest.mock import Mock from mirakuru import ProcessExitedWithError +from neon_mq_connector.utils.network_utils import dict_to_b64 from port_for import get_port from pytest_rabbitmq.factories.executor import RabbitMqExecutor from pytest_rabbitmq.factories.process import get_config @@ -47,6 +48,11 @@ def __init__(self, rmq_port: int): "neon_llm_mock_mq": {"user": "test_llm_user", "password": "test_llm_password"}}}} NeonLLMMQConnector.__init__(self, config=config) + self._model = Mock() + self._model.ask.return_value = "Mock response" + self._model.get_sorted_answer_indexes.return_value = [0, 1] + self.send_message = Mock() + self._compose_opinion_prompt = Mock(return_value="Mock opinion prompt") @property def name(self): @@ -54,13 +60,12 @@ def name(self): @property def model(self) -> NeonLLM: - return Mock() + return self._model - @staticmethod - def compose_opinion_prompt(respondent_nick: str, + def compose_opinion_prompt(self, respondent_nick: str, question: str, answer: str) -> str: - return "opinion prompt" + return self._compose_opinion_prompt(respondent_nick, question, answer) @pytest.fixture(scope="class") @@ -112,6 +117,8 @@ def rmq_instance(request, tmp_path_factory): @pytest.mark.usefixtures("rmq_instance") class TestNeonLLMMQConnector(TestCase): + mq_llm: NeonMockLlm = None + rmq_instance: RabbitMqExecutor = None @classmethod def tearDownClass(cls): @@ -120,12 +127,115 @@ def tearDownClass(cls): except ProcessExitedWithError: pass - def test_00_init(self): - self.mq_llm = NeonMockLlm(self.rmq_instance.port) + def setUp(self): + if self.mq_llm is None: + self.mq_llm = NeonMockLlm(self.rmq_instance.port) + def test_00_init(self): self.assertIn(self.mq_llm.name, self.mq_llm.service_name) self.assertIsInstance(self.mq_llm.ovos_config, dict) self.assertEqual(self.mq_llm.vhost, "/llm") - self.assertIsNotNone(self.mq_llm.model) + self.assertIsNotNone(self.mq_llm.model, self.mq_llm.model) self.assertEqual(self.mq_llm._personas_provider.service_name, self.mq_llm.name) + + def test_handle_request(self): + from neon_data_models.models.api.mq import (LLMProposeRequest, + LLMProposeResponse) + # Valid Request + request = LLMProposeRequest(message_id="mock_message_id", + routing_key="mock_routing_key", + query="Mock Query", history=[]) + self.mq_llm.handle_request(None, None, None, + dict_to_b64(request.model_dump())).join() + self.mq_llm.model.ask.assert_called_with(message=request.query, + chat_history=request.history, + persona=request.persona) + response = self.mq_llm.send_message.call_args.kwargs + self.assertEqual(response['queue'], request.routing_key) + response = LLMProposeResponse(**response['request_data']) + self.assertIsInstance(response, LLMProposeResponse) + self.assertEqual(request.routing_key, response.routing_key) + self.assertEqual(request.message_id, response.message_id) + + self.assertEqual(response.response, self.mq_llm.model.ask()) + + def test_handle_opinion_request(self): + from neon_data_models.models.api.mq import (LLMDiscussRequest, + LLMDiscussResponse) + # Valid Request + request = LLMDiscussRequest(message_id="mock_message_id", + routing_key="mock_routing_key", + query="Mock Discuss", history=[], + options={"bot 1": "resp 1", + "bot 2": "resp 2"}) + self.mq_llm.handle_opinion_request(None, None, None, + dict_to_b64(request.model_dump())) + + self.mq_llm._compose_opinion_prompt.assert_called_with( + list(request.options.keys())[0], request.query, + list(request.options.values())[0]) + + response = self.mq_llm.send_message.call_args.kwargs + self.assertEqual(response['queue'], request.routing_key) + response = LLMDiscussResponse(**response['request_data']) + self.assertIsInstance(response, LLMDiscussResponse) + self.assertEqual(request.routing_key, response.routing_key) + self.assertEqual(request.message_id, response.message_id) + + self.assertEqual(response.opinion, self.mq_llm.model.ask()) + + # No input options + request = LLMDiscussRequest(message_id="mock_message_id1", + routing_key="mock_routing_key1", + query="Mock Discuss 1", history=[], + options={}) + self.mq_llm.handle_opinion_request(None, None, None, + dict_to_b64(request.model_dump())) + response = self.mq_llm.send_message.call_args.kwargs + self.assertEqual(response['queue'], request.routing_key) + response = LLMDiscussResponse(**response['request_data']) + self.assertIsInstance(response, LLMDiscussResponse) + self.assertEqual(request.routing_key, response.routing_key) + self.assertEqual(request.message_id, response.message_id) + self.assertNotEqual(response.opinion, self.mq_llm.model.ask()) + + # TODO: Test with invalid sorted answer indexes + + def test_handle_score_request(self): + from neon_data_models.models.api.mq import (LLMVoteRequest, + LLMVoteResponse) + + # Valid Request + request = LLMVoteRequest(message_id="mock_message_id", + routing_key="mock_routing_key", + query="Mock Score", history=[], + responses=["one", "two"]) + self.mq_llm.handle_score_request(None, None, None, + dict_to_b64(request.model_dump())) + + response = self.mq_llm.send_message.call_args.kwargs + self.assertEqual(response['queue'], request.routing_key) + response = LLMVoteResponse(**response['request_data']) + self.assertIsInstance(response, LLMVoteResponse) + self.assertEqual(request.routing_key, response.routing_key) + self.assertEqual(request.message_id, response.message_id) + + self.assertEqual(response.sorted_answer_indexes, + self.mq_llm.model.get_sorted_answer_indexes()) + + # No response options + request = LLMVoteRequest(message_id="mock_message_id", + routing_key="mock_routing_key", + query="Mock Score", history=[], responses=[]) + self.mq_llm.handle_score_request(None, None, None, + dict_to_b64(request.model_dump())) + + response = self.mq_llm.send_message.call_args.kwargs + self.assertEqual(response['queue'], request.routing_key) + response = LLMVoteResponse(**response['request_data']) + self.assertIsInstance(response, LLMVoteResponse) + self.assertEqual(request.routing_key, response.routing_key) + self.assertEqual(request.message_id, response.message_id) + + self.assertEqual(response.sorted_answer_indexes, [])