Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle Klat persona update events #10

Open
wants to merge 11 commits into
base: dev
Choose a base branch
from
11 changes: 8 additions & 3 deletions neon_llm_core/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,14 @@

from typing import List, Optional
from chatbot_core.v2 import ChatBot
from neon_data_models.models.api.mq import (LLMProposeRequest,
LLMDiscussRequest, LLMVoteRequest, LLMProposeResponse, LLMDiscussResponse,
LLMVoteResponse)
from neon_data_models.models.api.mq import (
LLMProposeRequest,
LLMDiscussRequest,
LLMVoteRequest,
LLMProposeResponse,
LLMDiscussResponse,
LLMVoteResponse,
)
from neon_mq_connector.utils.client_utils import send_mq_request
from neon_utils.logger import LOG
from neon_data_models.models.api.llm import LLMPersona
Expand Down
52 changes: 48 additions & 4 deletions neon_llm_core/rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,19 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from abc import abstractmethod, ABC
from threading import Thread
from threading import Thread, Lock
from time import time
from typing import Optional

from neon_mq_connector.connector import MQConnector
from neon_mq_connector.utils.rabbit_utils import create_mq_callback
from neon_utils.logger import LOG

from neon_data_models.models.api.mq import (LLMProposeResponse,
LLMDiscussResponse, LLMVoteResponse)
from neon_data_models.models.api.mq import (
LLMProposeResponse,
LLMDiscussResponse,
LLMVoteResponse,
)

from neon_llm_core.utils.config import load_config
from neon_llm_core.llm import NeonLLM
Expand All @@ -59,6 +63,8 @@ def __init__(self, config: Optional[dict] = None):
self.register_consumers()
self._model = None
self._bots = list()
self._persona_update_lock = Lock()
self._last_persona_update = time()
self._personas_provider = PersonasProvider(service_name=self.name,
ovos_config=self.ovos_config)

Expand All @@ -79,7 +85,17 @@ def register_consumers(self):
queue=self.queue_opinion,
callback=self.handle_opinion_request,
on_error=self.default_error_handler,)

self.register_subscriber(name=f'neon_llm_{self.name}_update_persona',
vhost=self.vhost,
exchange=self.exchange_persona_updated,
callback=self.handle_persona_update,
on_error=self.default_error_handler)
self.register_subscriber(name=f'neon_llm_{self.name}_delete_persona',
vhost=self.vhost,
exchange=self.exchange_persona_deleted,
callback=self.handle_persona_delete,
on_error=self.default_error_handler)

@property
@abstractmethod
def name(self):
Expand All @@ -104,6 +120,14 @@ def queue_score(self):
def queue_opinion(self):
return f"{self.name}_discussion_input"

@property
def exchange_persona_updated(self):
return f"{self.name}_persona_updated"

@property
def exchange_persona_deleted(self):
return f"{self.name}_persona_deleted"

@property
@abstractmethod
def model(self) -> NeonLLM:
Expand All @@ -122,6 +146,26 @@ def handle_request(self, body: dict) -> Thread:
t.start()
return t

@create_mq_callback()
def handle_persona_update(self, body: dict):
"""
Handles an emitted message from the server containing updated persona data
for this LLM
:param body: MQ message body containing persona data for update
"""
with self._persona_update_lock:
self._personas_provider.apply_incoming_persona(body)

@create_mq_callback()
def handle_persona_delete(self, body: dict):
"""
Handles an emitted message from the server containing deleted persona data
for this LLM
:param body: MQ message body containing persona data for deletion
"""
with self._persona_update_lock:
self._personas_provider.remove_persona(body)

def _handle_request_async(self, request: dict):
message_id = request["message_id"]
routing_key = request["routing_key"]
Expand Down
188 changes: 162 additions & 26 deletions neon_llm_core/utils/personas/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,30 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
from time import time
from typing import List, Optional

from neon_data_models.models.api.llm import (
LLMPersona,
LLMPersonaIdentity,
)

from neon_mq_connector.utils import RepeatingTimer
from neon_mq_connector.utils.client_utils import send_mq_request
from neon_utils.logger import LOG
from pydantic import ValidationError

from neon_llm_core.utils.constants import LLM_VHOST
from neon_llm_core.utils.personas.models import PersonaModel
from neon_llm_core.utils.personas.state import PersonaHandlersState


class PersonasProvider:
"""
Manages personas defined via Klat. Each LLM that connects to the MQ bus will
include an instance of this object to track changes to personas.
"""

PERSONA_STATE_TTL = int(os.getenv("PERSONA_STATE_TTL", 15 * 60))
PERSONA_SYNC_INTERVAL = int(os.getenv("PERSONA_SYNC_INTERVAL", 5 * 60))
PERSONA_SYNC_INTERVAL = int(os.getenv("PERSONA_SYNC_INTERVAL", 0))
GET_CONFIGURED_PERSONAS_QUEUE = "get_configured_personas"

def __init__(self, service_name: str, ovos_config: dict):
Expand All @@ -50,7 +60,7 @@ def __init__(self, service_name: str, ovos_config: dict):
self._persona_sync_thread = None

@property
def persona_sync_thread(self):
def persona_sync_thread(self) -> RepeatingTimer:
"""Creates new synchronization thread which fetches Klat personas"""
if not (isinstance(self._persona_sync_thread, RepeatingTimer) and
self._persona_sync_thread.is_alive()):
Expand All @@ -60,48 +70,174 @@ def persona_sync_thread(self):
return self._persona_sync_thread

@property
def personas(self):
def personas(self) -> List[LLMPersona]:
return self._personas

@personas.setter
def personas(self, data):
def personas(self, data: List[LLMPersona]):
LOG.debug(f'Setting personas={data}')
if self._should_reset_personas(data=data):
LOG.warning(f'Persona state TTL expired, resetting personas config')
LOG.warning(f'Persona state expired, setting default personas')
self._personas = []
self._persona_handlers_state.init_default_handlers()
elif not data:
self._persona_handlers_state.init_default_handlers()
return
self._persona_handlers_state.init_default_personas()
else:
self._personas = data
self._persona_handlers_state.clean_up_personas(ignore_items=self._personas)
self._persona_handlers_state.clean_up_personas(ignore_items=self._personas)

def _should_reset_personas(self, data: List[LLMPersona]) -> bool:
"""
Checks if personas should be re-initialized after setting a new value
for personas.

If PERSONA_SYNC_INTERVAL is enabled - verifies based on TTL, otherwise

:param data: requested list of personas
:return: True if requested `data` should be ignored and personas
reloaded from config, False if requested `data` should be used
directly
"""
return self._should_reset_personas_based_on_ttl(data) if self.PERSONA_SYNC_INTERVAL > 0 else not data

def _should_reset_personas_based_on_ttl(self, data: dict) -> bool:
"""
Determines whether personas should be reset based on Time-to-Live (TTL) and
the synchronization timestamp.

Examines the time elapsed since the last persona synchronization in relation
to a predefined TTL. Also considers whether the state of the personas and the
incoming data indicate a need for resetting.

def _should_reset_personas(self, data) -> bool:
:param data: provided persona data

returns: True if personas need to be reset based on TTL False otherwise.
"""
return (not (self._persona_last_sync == 0 and data)
and int(time()) - self._persona_last_sync > self.PERSONA_STATE_TTL)

def _fetch_persona_config(self):
response = send_mq_request(vhost=LLM_VHOST,
request_data={"service_name": self.service_name},
target_queue=PersonasProvider.GET_CONFIGURED_PERSONAS_QUEUE,
timeout=60)
if 'items' in response:
"""
Get personas from a provider on the MQ bus and update the internal
`personas` reference.
"""
response = send_mq_request(
vhost=LLM_VHOST,
request_data={"service_name": self.service_name},
target_queue=PersonasProvider.GET_CONFIGURED_PERSONAS_QUEUE,
timeout=60)
self.parse_persona_response(response)

def parse_persona_response(self, persona_response: dict):
"""
Parses and processes a response containing persona data, updates internal state,
and manages personas accordingly.

:param persona_response: A dictionary containing the response data with
persona information.
Expected to contain a key 'items' holding a list of
persona details.
"""
if 'items' in persona_response:
self._persona_last_sync = int(time())
response_data = response.get('items', [])
personas = []
for item in response_data:
item.setdefault('name', item.pop('persona_name', None))
persona = PersonaModel.parse_obj(obj=item)
self._persona_handlers_state.add_persona_handler(persona=persona)
personas.append(persona)
self.personas = personas
response_data = persona_response.get('items', [])
validated_personas = [self._validate_persona_data(persona_data) for persona_data in response_data]
active_personas = []
for persona in validated_personas:
persona_applied = self.apply_incoming_persona(persona=persona)
if persona_applied:
active_personas.append(persona)
self.personas = active_personas

@staticmethod
def _validate_persona_data(persona_data: dict) -> Optional[LLMPersona]:
"""
Validates of the persona data and returns the resulting `PersonaModel`.
If validation fails - logs error and returns None

:param persona_data : A dictionary containing details of the persona, where
specific key-value mappings are applied for validation.

returns: A validated and updated `PersonaModel` instance based on the provided data if
validation was successful, None otherwise
"""
persona_data.setdefault('name', persona_data.pop('persona_name', None))

try:
persona = LLMPersona.model_validate(obj=persona_data)
except ValidationError as err:
LOG.error(f"Failed to apply persona data from {persona_data} - {str(err)}")
return
return persona

def apply_incoming_persona(self, persona: LLMPersona) -> bool:
"""
Attempts to add incoming persona and return an updated PersonaModel instance if successful.
If default personas are running upon adding persona - removes default personas.
If state container has only one running persona & it was disabled by this method -
triggers initialisation of the default personas.

:param persona: `LLMPersona` instance to add to the `PersonaHandlersState` container

:returns: True if incoming persona object in case of successful addition or if identical persona already exists
False otherwise
"""

new_persona = self._persona_handlers_state.add_persona_handler(persona=persona)

if new_persona:
LOG.info(f"Persona {persona.id} updated successfully")

# Once first manually configured persona added - pruning default personas
if self._persona_handlers_state.default_personas_running:
LOG.info("Starting to remove default personas")
self._persona_handlers_state.clean_up_personas(ignore_items=[persona])
self._persona_handlers_state.default_personas_running = False
LOG.info("Completed removing of default personas")

elif persona.id not in self._persona_handlers_state.connected_persona_ids:
# May occur if the last updated persona was set to be disabled
if not self._persona_handlers_state.has_connected_personas():
LOG.info("No personas connected after the last update - setting default personas")
self._persona_handlers_state.init_default_personas()
return False

return True

def remove_persona(self, persona_data: dict):
"""
Removes a persona from the active persona handlers state.

This method handles the removal of a persona based on persona data.
It ensures that the default personas are initialized if no other connected personas
remain after the removal.

:param persona_data: A dictionary containing details of the persona to be removed.
"""
if (self._persona_handlers_state.has_connected_personas() and
not self._persona_handlers_state.default_personas_running):
persona_data.setdefault('name', persona_data.pop('persona_name', None))
persona = LLMPersonaIdentity.model_validate(obj=persona_data)
self._persona_handlers_state.remove_persona(persona_id=persona.id)
LOG.info(f"Persona {persona.id} removed successfully")

if not self._persona_handlers_state.has_connected_personas():
LOG.info("No personas connected after the last removal - setting default personas")
self._persona_handlers_state.init_default_personas()
else:
LOG.warning("No running personas detected - skipping persona removal")

def start_sync(self):
"""
Update personas and start thread to periodically update from a service
on the MQ bus.
"""
self._fetch_persona_config()
self.persona_sync_thread.start()
if self.PERSONA_SYNC_INTERVAL > 0:
self.persona_sync_thread.start()

def stop_sync(self):
"""
Stop persona updates from the MQ bus.
"""
if self._persona_sync_thread:
self._persona_sync_thread.cancel()
self._persona_sync_thread = None
Loading
Loading