Skip to content

Commit

Permalink
Updated persona handler logic to better handle the default personas a…
Browse files Browse the repository at this point in the history
…nd new events for persona update
  • Loading branch information
kirgrim committed Jan 26, 2025
1 parent 44653e5 commit a0be34c
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 40 deletions.
48 changes: 32 additions & 16 deletions neon_llm_core/rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,17 @@ def register_consumers(self):
queue=self.queue_opinion,
callback=self.handle_opinion_request,
on_error=self.default_error_handler,)
self.register_consumer(name=f'neon_llm_{self.name}_personas',
vhost=self.vhost,
queue=self.queue_personas,
callback=self.handle_new_personas,
on_error=self.default_error_handler)

self.register_subscriber(name=f'neon_llm_{self.name}_update_persona',
vhost=self.vhost,
exchange=self.exchange_persona_updated,
callback=self.handle_persona_update,
on_error=self.default_error_handler)
self.register_subscriber(name=f'neon_llm_{self.name}_delete_persona',
vhost=self.vhost,
exchange=self.exchange_persona_deleted,
callback=self.handle_persona_delete,
on_error=self.default_error_handler)

@property
@abstractmethod
def name(self):
Expand All @@ -106,8 +111,12 @@ def queue_opinion(self):
return f"{self.name}_discussion_input"

@property
def queue_personas(self):
return f"{self.name}_personas_input"
def exchange_persona_updated(self):
return f"{self.name}_persona_updated"

@property
def exchange_persona_deleted(self):
return f"{self.name}_persona_deleted"

@property
@abstractmethod
Expand All @@ -126,18 +135,25 @@ def handle_request(self, body: dict):
daemon=True).start()

@create_mq_callback()
def handle_new_personas(self, body: dict):
def handle_persona_update(self, body: dict):
"""
Handles an emitted message from the server containing updated persona data
for this LLM
:param body: MQ message body containing persona data for update
"""
with self._persona_update_lock:
self._personas_provider.stop_default_personas()
self._personas_provider.apply_incoming_persona_data(body)

@create_mq_callback()
def handle_persona_delete(self, body: dict):
"""
Handles an emitted message from the server containing personas defined
Handles an emitted message from the server containing deleted persona data
for this LLM
:param body: MQ message body containing persona definitions
:param body: MQ message body containing persona data for deletion
"""
if body.get("update_time", time()) <= self._last_persona_update:
LOG.info("Skipping update that is older than last update")
return
with self._persona_update_lock:
self._personas_provider.parse_persona_response(body)
self._last_persona_update = body.get("update_time") or time()
self._personas_provider.remove_persona(body)

def _handle_request_async(self, request: dict):
message_id = request["message_id"]
Expand Down
14 changes: 11 additions & 3 deletions neon_llm_core/utils/personas/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from abc import ABC
from typing import Optional

from pydantic import BaseModel, computed_field


class PersonaModel(BaseModel):
class PersonaBaseModel(BaseModel, ABC):
name: str
description: str
enabled: bool = True
user_id: Optional[str] = None

@computed_field
Expand All @@ -41,3 +40,12 @@ def id(self) -> str:
if self.user_id:
persona_id += f"_{self.user_id}"
return persona_id


class PersonaModel(PersonaBaseModel):
description: str
enabled: bool = True


class PersonaDeleteModel(PersonaBaseModel):
pass
97 changes: 90 additions & 7 deletions neon_llm_core/utils/personas/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from neon_utils.logger import LOG

from neon_llm_core.utils.constants import LLM_VHOST
from neon_llm_core.utils.personas.models import PersonaModel
from neon_llm_core.utils.personas.models import PersonaModel, PersonaDeleteModel
from neon_llm_core.utils.personas.state import PersonaHandlersState


Expand Down Expand Up @@ -72,25 +72,56 @@ def personas(self) -> List[PersonaModel]:
def personas(self, data: List[PersonaModel]):
LOG.debug(f'Setting personas={data}')
if self._should_reset_personas(data=data):
LOG.warning(f'Persona state TTL expired, resetting personas config')
LOG.warning(f'Persona state expired, setting default personas')
self._personas = []
self._persona_handlers_state.init_default_handlers()
self._persona_handlers_state.init_default_personas()
else:
self._personas = data
self._persona_handlers_state.clean_up_personas(ignore_items=self._personas)
self._persona_handlers_state.clean_up_personas(ignore_items=self._personas)

def _should_reset_personas(self, data: List[PersonaModel]) -> bool:
"""
Checks if personas should be re-initialized after setting a new value
for personas.
If PERSONA_SYNC_INTERVAL is enabled - verifies based on TTL, otherwise
:param data: requested list of personas
:return: True if requested `data` should be ignored and personas
reloaded from config, False if requested `data` should be used
directly
"""
return self._should_reset_personas_based_on_ttl(data) if self.PERSONA_SYNC_INTERVAL > 0 else not data

def _should_reset_personas_based_on_ttl(self, data: dict) -> bool:
"""
Determines whether personas should be reset based on Time-to-Live (TTL) and
the synchronization timestamp.
Examines the time elapsed since the last persona synchronization in relation
to a predefined TTL. Also considers whether the state of the personas and the
incoming data indicate a need for resetting.
:param data: provided persona data
returns: True if personas need to be reset based on TTL False otherwise.
"""
return (not (self._persona_last_sync == 0 and data)
and int(time()) - self._persona_last_sync > self.PERSONA_STATE_TTL)

def stop_default_personas(self):
"""
Stops all default personas that are currently running.
This method checks whether there are any default personas actively running
and stops each one by invoking the respective removal method. It ensures
that all default personas listed in the current persona handler state are
terminated properly.
"""
if self._persona_handlers_state.default_personas_running:
LOG.info("Stopping default personas")
self._persona_handlers_state.clean_up_personas()

def _fetch_persona_config(self):
"""
Get personas from a provider on the MQ bus and update the internal
Expand All @@ -104,17 +135,69 @@ def _fetch_persona_config(self):
self.parse_persona_response(response)

def parse_persona_response(self, persona_response: dict):
"""
Parses and processes a response containing persona data, updates internal state,
and manages personas accordingly.
:param persona_response: A dictionary containing the response data with
persona information.
Expected to contain a key 'items' holding a list of
persona details.
"""
if 'items' in persona_response:
self._persona_last_sync = int(time())
response_data = persona_response.get('items', [])
personas = []
self.stop_default_personas()
for item in response_data:
item.setdefault('name', item.pop('persona_name', None))
persona = PersonaModel.parse_obj(obj=item)
self._persona_handlers_state.add_persona_handler(persona=persona)
persona = self.apply_incoming_persona_data(item)
personas.append(persona)
self.personas = personas

def apply_incoming_persona_data(self, persona_data: dict) -> PersonaModel:
"""
Apply and update incoming persona data and return an updated PersonaModel instance.
This method is responsible for processing incoming persona data, applying necessary
updates to it, and adding the updated persona instance to the state management system.
It ensures proper validation of the persona data, logs a successful update message,
and returns the resulting `PersonaModel`.
:param persona_data : A dictionary containing details of the persona, where
specific key-value mappings are applied for validation.
returns: A validated and updated `PersonaModel` instance based on the provided input data.
"""
persona_data.setdefault('name', persona_data.pop('persona_name', None))
persona = PersonaModel.model_validate(obj=persona_data)
self._persona_handlers_state.add_persona_handler(persona=persona)
LOG.info(f"Persona {persona.id} updated successfully")
return persona

def remove_persona(self, persona_data: dict):
"""
Removes a persona from the active persona handlers state.
This method handles the removal of a persona based on persona data.
It ensures that the default personas are initialized if no other connected personas
remain after the removal.
:param persona_data: A dictionary containing details of the persona to be removed.
"""
if (self._persona_handlers_state.has_connected_personas() and
not self._persona_handlers_state.default_personas_running):
persona_data.setdefault('name', persona_data.pop('persona_name', None))
persona = PersonaDeleteModel.model_validate(obj=persona_data)
self._persona_handlers_state.remove_persona(persona_id=persona.id)
LOG.info(f"Persona {persona.id} removed successfully")

if not self._persona_handlers_state.has_connected_personas():
LOG.info("No personas connected after the last removal - setting default personas")
self._persona_handlers_state.init_default_personas()
else:
LOG.warning("No running personas detected - skipping persona removal")

def start_sync(self):
"""
Update personas and start thread to periodically update from a service
Expand Down
50 changes: 36 additions & 14 deletions neon_llm_core/utils/personas/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import time
from functools import cached_property
from threading import Lock
from typing import Dict, List, Optional

from neon_utils.logger import LOG
Expand All @@ -46,17 +48,30 @@ def __init__(self, service_name: str, ovos_config: dict):
self.service_name = service_name
self.ovos_config = ovos_config
self.mq_config = ovos_config.get('MQ', {})
self.default_personas_running = False

def init_default_handlers(self):
self.personas_clean_up_lock = Lock()
self.personas_remove_lock = Lock()

@cached_property
def default_personas(self):
return self.ovos_config.get("llm_bots", {}).get(self.service_name, [])

def has_connected_personas(self) -> bool:
return bool(self._created_items)

def init_default_personas(self):
"""
Initializes LLMBot instances for all personas defined in configuration.
"""
self._created_items = {}
if self.ovos_config.get("llm_bots", {}).get(self.service_name):
LOG.info(f"Chatbot(s) configured for: {self.service_name}")
for persona in self.ovos_config['llm_bots'][self.service_name]:
if self.default_personas and not self.default_personas_running:
self.clean_up_personas()
LOG.info(f"Initializing default personas for: {self.service_name}")
for persona in self.default_personas:
self.add_persona_handler(
persona=PersonaModel.parse_obj(obj=persona))
persona=PersonaModel.model_validate(obj=persona)
)
self.default_personas_running = True

def add_persona_handler(self, persona: PersonaModel) -> Optional[LLMBot]:
"""
Expand Down Expand Up @@ -95,13 +110,20 @@ def add_persona_handler(self, persona: PersonaModel) -> Optional[LLMBot]:
return bot

def clean_up_personas(self, ignore_items: List[PersonaModel] = None):
connected_personas = set(self._created_items)
ignored_persona_ids = set(persona.id for persona in ignore_items or [])
personas_to_remove = connected_personas - ignored_persona_ids
for persona_id in personas_to_remove:
self.remove_persona(persona_id=persona_id)
with self.personas_clean_up_lock:
connected_personas = set(self._created_items)
ignored_persona_ids = set(persona.id for persona in ignore_items or [])
personas_to_remove = connected_personas - ignored_persona_ids
for persona_id in personas_to_remove:
self.remove_persona(persona_id=persona_id)

def remove_persona(self, persona_id: str):
LOG.info(f'Removing persona_id = {persona_id}')
self._created_items[persona_id].stop()
self._created_items.pop(persona_id, None)
with self.personas_remove_lock:
if persona_id in self._created_items:
LOG.info(f'Removing persona_id = {persona_id}')
self._created_items[persona_id].stop()
self._created_items.pop(persona_id, None)

if not self.has_connected_personas() and self.default_personas_running:
LOG.info("All default personas stopped")
self.default_personas_running = False

0 comments on commit a0be34c

Please sign in to comment.