diff --git a/keybert/llm/_litellm.py b/keybert/llm/_litellm.py index f0e55469..7063a93f 100644 --- a/keybert/llm/_litellm.py +++ b/keybert/llm/_litellm.py @@ -68,6 +68,7 @@ class LiteLLM(BaseLLM): def __init__(self, model: str = "gpt-3.5-turbo", prompt: str = None, + system_content: str = "You are a helpful assistant.", generator_kwargs: Mapping[str, Any] = {}, delay_in_seconds: float = None, verbose: bool = False @@ -79,6 +80,7 @@ def __init__(self, else: self.prompt = prompt + self.system_content = system_content self.default_prompt_ = DEFAULT_PROMPT self.delay_in_seconds = delay_in_seconds self.verbose = verbose @@ -116,7 +118,7 @@ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[s # Use a chat model messages = [ - {"role": "system", "content": "You are a helpful assistant."}, + {"role": "system", "content": self.system_content}, {"role": "user", "content": prompt} ] kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs} diff --git a/keybert/llm/_openai.py b/keybert/llm/_openai.py index 5c8c078c..70671afa 100644 --- a/keybert/llm/_openai.py +++ b/keybert/llm/_openai.py @@ -114,6 +114,7 @@ def __init__(self, client, model: str = "gpt-3.5-turbo-instruct", prompt: str = None, + system_content: str = "You are a helpful assistant.", generator_kwargs: Mapping[str, Any] = {}, delay_in_seconds: float = None, exponential_backoff: bool = False, @@ -128,6 +129,7 @@ def __init__(self, else: self.prompt = prompt + self.system_content = system_content self.default_prompt_ = DEFAULT_CHAT_PROMPT if chat else DEFAULT_PROMPT self.delay_in_seconds = delay_in_seconds self.exponential_backoff = exponential_backoff @@ -170,7 +172,7 @@ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[s # Use a chat model if self.chat: messages = [ - {"role": "system", "content": "You are a helpful assistant."}, + {"role": "system", "content": self.system_content}, {"role": "user", "content": prompt} ] kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs}