KeyLLM to extract keywords from text with LLMs
Showing 11 changed files with 873 additions and 0 deletions.
@@ -1,3 +1,4 @@
from keybert._model import KeyBERT
from keybert._llm import KeyLLM

__version__ = "0.7.0"
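
Since `__init__.py` now re-exports the class, the import used in the examples further down should resolve from the package root; a one-line sketch:

```python
# KeyLLM is now importable alongside KeyBERT from the package root.
from keybert import KeyBERT, KeyLLM
```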

@@ -0,0 +1,70 @@
from typing import List, Union


class KeyLLM:
    """
    A minimal method for keyword extraction with Large Language Models (LLM).

    The keyword extraction is done by simply asking the LLM to extract a
    number of keywords from a single piece of text.
    """

    def __init__(self, llm):
        """KeyLLM initialization

        Arguments:
            llm: The Large Language Model to use
        """
        self.llm = llm

    def extract_keywords(
        self,
        docs: Union[str, List[str]],
    ) -> Union[List[str], List[List[str]]]:
        """Extract keywords and/or keyphrases

        To get the biggest speed-up, make sure to pass multiple documents
        at once instead of iterating over a single document.

        NOTE: The resulting keywords are expected to be separated by commas so
        any changes to the prompt will have to make sure that the resulting
        keywords are comma-separated.

        Arguments:
            docs: The document(s) for which to extract keywords/keyphrases

        Returns:
            keywords: The keywords extracted for each document

        Usage:

        To extract keywords from a single document:

        ```python
        import openai
        from keybert.llm import OpenAI
        from keybert import KeyLLM

        # Create your LLM
        openai.api_key = "sk-..."
        llm = OpenAI()

        # Load it in KeyLLM
        kw_model = KeyLLM(llm)

        # Extract keywords
        document = "The website mentions that it only takes a couple of days to deliver but I still have not received mine."
        keywords = kw_model.extract_keywords(document)
        ```
        """
        # Check for a single, empty document
        if isinstance(docs, str):
            if docs:
                docs = [docs]
            else:
                return []

        keywords = self.llm.extract_keywords(docs)
        return keywords
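
As the docstring notes, passing all documents at once is the faster route. A minimal sketch of batch extraction, assuming the OpenAI backend is configured exactly as in the docstring above (the API key is a placeholder and the documents are the illustrative ones used elsewhere in this commit):

```python
import openai
from keybert.llm import OpenAI
from keybert import KeyLLM

openai.api_key = "sk-..."  # placeholder key
kw_model = KeyLLM(OpenAI())

# Passing a list returns one keyword list per document.
documents = [
    "The website mentions that it only takes a couple of days to deliver but I still have not received mine.",
    "Traditional diets in most cultures were primarily plant-based with a little meat on top.",
]
keyword_lists = kw_model.extract_keywords(documents)
```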

@@ -0,0 +1,22 @@
class NotInstalled:
    """
    This object is used to notify the user that additional dependencies need to be
    installed in order to use the requested model.
    """

    def __init__(self, tool, dep, custom_msg=None):
        self.tool = tool
        self.dep = dep

        msg = f"In order to use {self.tool} you will need to install it via:\n\n"
        if custom_msg is not None:
            msg += custom_msg
        else:
            msg += f"pip install keybert[{self.dep}]\n\n"
        self.msg = msg

    def __getattr__(self, *args, **kwargs):
        raise ModuleNotFoundError(self.msg)

    def __call__(self, *args, **kwargs):
        raise ModuleNotFoundError(self.msg)
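
To make the behaviour concrete, here is a small sketch of how the placeholder reacts when an optional dependency is missing; the `FancyLLM`/`fancy` names are made up for illustration:

```python
from keybert._utils import NotInstalled

# Stand-in object registered when an optional backend fails to import.
FancyLLM = NotInstalled("FancyLLM", "fancy", custom_msg="`pip install fancy-llm` \n\n")

try:
    FancyLLM()  # any call or attribute access raises with the install hint
except ModuleNotFoundError as e:
    print(e)
```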

@@ -0,0 +1,42 @@
from keybert._utils import NotInstalled
from keybert.llm._base import BaseRepresentation
from keybert.llm._textgeneration import TextGeneration


# OpenAI Generator
try:
    from keybert.llm._openai import OpenAI
except ModuleNotFoundError:
    msg = "`pip install openai` \n\n"
    OpenAI = NotInstalled("OpenAI", "openai", custom_msg=msg)

# Cohere Generator
try:
    from keybert.llm._cohere import Cohere
except ModuleNotFoundError:
    msg = "`pip install cohere` \n\n"
    Cohere = NotInstalled("Cohere", "cohere", custom_msg=msg)

# LangChain Generator
try:
    from keybert.llm._langchain import LangChain
except ModuleNotFoundError:
    msg = "`pip install langchain` \n\n"
    LangChain = NotInstalled("LangChain", "langchain", custom_msg=msg)

# LiteLLM
try:
    from keybert.llm._litellm import LiteLLM
except ModuleNotFoundError:
    msg = "`pip install litellm` \n\n"
    LiteLLM = NotInstalled("LiteLLM", "litellm", custom_msg=msg)


__all__ = [
    "BaseRepresentation",
    "Cohere",
    "OpenAI",
    "TextGeneration",
    "LangChain",
    "LiteLLM"
]

@@ -0,0 +1,38 @@
import pandas as pd
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator
from typing import Mapping, List, Tuple


class BaseRepresentation(BaseEstimator):
    """ The base representation model for fine-tuning topic representations """
    def extract_topics(self,
                       topic_model,
                       documents: pd.DataFrame,
                       c_tf_idf: csr_matrix,
                       topics: Mapping[str, List[Tuple[str, float]]]
                       ) -> Mapping[str, List[Tuple[str, float]]]:
        """ Extract topics

        Each representation model that inherits this class will have
        its arguments (topic_model, documents, c_tf_idf, topics)
        automatically passed. Therefore, the representation model
        will only have access to the information about topics related
        to those arguments.

        Arguments:
            topic_model: The BERTopic model that is fitted until topic
                         representations are calculated.
            documents: A dataframe with columns "Document" and "Topic"
                       that contains all documents with each corresponding
                       topic.
            c_tf_idf: A c-TF-IDF representation that is typically
                      identical to `topic_model.c_tf_idf_` except for
                      dynamic, class-based, and hierarchical topic modeling
                      where it is calculated on a subset of the documents.
            topics: A dictionary with topic (key) and tuple of word and
                    weight (value) as calculated by c-TF-IDF. These are the
                    default topics that are returned if no representation
                    model is used.
        """
        return topic_model.topic_representations_
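
The LLM backends in this commit (e.g. the Cohere wrapper below) subclass `BaseRepresentation` and expose an `extract_keywords(documents)` method, which is all that `KeyLLM.extract_keywords` calls on the wrapped object. A minimal sketch of a custom backend under that assumption; the `EchoLLM` name and its trivial logic are made up for illustration:

```python
from typing import List

from keybert.llm._base import BaseRepresentation


class EchoLLM(BaseRepresentation):
    """Toy backend: 'extracts' keywords by splitting each document on whitespace."""

    def extract_keywords(self, documents: List[str]) -> List[List[str]]:
        all_keywords = []
        for document in documents:
            # A real backend would prompt an LLM here and parse its
            # comma-separated response, as the Cohere wrapper below does.
            all_keywords.append(document.split()[:5])
        return all_keywords
```

Wrapping it as `KeyLLM(EchoLLM())` would then behave the same way as the real backends.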

@@ -0,0 +1,123 @@
import time
from tqdm import tqdm
from typing import List
from keybert.llm._base import BaseRepresentation


DEFAULT_PROMPT = """
The following is a list of documents. Please extract the top keywords, separated by a comma, that describe the topic of the texts.

Document:
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
Keywords: Traditional diets, Plant-based, Meat, Industrial style meat production, Factory farming, Staple food, Cultural dietary practices

Document:
- The website mentions that it only takes a couple of days to deliver but I still have not received mine.
Keywords: Website, Delivery, Mention, Timeframe, Not received, Waiting, Order fulfillment

Document:
- [DOCUMENT]
Keywords:"""


class Cohere(BaseRepresentation):
    """ Use the Cohere API to extract keywords based on their
    generative model.

    Find more about their models here:
    https://docs.cohere.ai/docs

    NOTE: The resulting keywords are expected to be separated by commas so
    any changes to the prompt will have to make sure that the resulting
    keywords are comma-separated.

    Arguments:
        client: A cohere.Client
        model: Model to use within Cohere, defaults to `"xlarge"`.
        prompt: The prompt to be used in the model. If no prompt is given,
                `self.default_prompt_` is used instead.
                NOTE: Use `"[DOCUMENT]"` in the prompt to decide where the
                document needs to be inserted.
        delay_in_seconds: The delay in seconds between consecutive prompts
                          in order to prevent RateLimitErrors.
        verbose: Set this to True if you want to see a progress bar for the
                 keyword extraction.

    Usage:

    To use this, you will need to install cohere first:

    `pip install cohere`

    Then, get yourself an API key and use Cohere's API as follows:

    ```python
    import cohere
    from keybert.llm import Cohere
    from keybert import KeyLLM

    # Create your LLM
    co = cohere.Client(my_api_key)
    llm = Cohere(co)

    # Load it in KeyLLM
    kw_model = KeyLLM(llm)

    # Extract keywords
    document = "The website mentions that it only takes a couple of days to deliver but I still have not received mine."
    keywords = kw_model.extract_keywords(document)
    ```

    You can also use a custom prompt:

    ```python
    prompt = "I have the following document: [DOCUMENT]. What keywords does it contain? Make sure to separate the keywords with commas."
    llm = Cohere(co, prompt=prompt)
    ```
    """
    def __init__(self,
                 client,
                 model: str = "xlarge",
                 prompt: str = None,
                 delay_in_seconds: float = None,
                 verbose: bool = False
                 ):
        self.client = client
        self.model = model
        self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
        self.default_prompt_ = DEFAULT_PROMPT
        self.delay_in_seconds = delay_in_seconds
        self.verbose = verbose

    def extract_keywords(self, documents: List[str]):
        """ Extract keywords from documents

        Arguments:
            documents: The documents to extract keywords from

        Returns:
            all_keywords: All keywords for each document
        """
        all_keywords = []

        for document in tqdm(documents, disable=not self.verbose):
            prompt = self.prompt.replace("[DOCUMENT]", document)

            # Delay between requests to avoid rate limits
            if self.delay_in_seconds:
                time.sleep(self.delay_in_seconds)

            request = self.client.generate(model=self.model,
                                           prompt=prompt,
                                           max_tokens=50,
                                           num_generations=1,
                                           stop_sequences=["\n"])
            keywords = request.generations[0].text.strip()
            keywords = [keyword.strip() for keyword in keywords.split(",")]
            all_keywords.append(keywords)

        return all_keywords
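
Since the wrapper can sleep between prompts and show a progress bar, here is a hedged sketch of a rate-limit-friendly configuration; the API key value is a placeholder:

```python
import cohere
from keybert.llm import Cohere
from keybert import KeyLLM

co = cohere.Client("my-api-key")  # placeholder key

# Wait one second between prompts and show a tqdm progress bar.
llm = Cohere(co, delay_in_seconds=1.0, verbose=True)

kw_model = KeyLLM(llm)
keywords = kw_model.extract_keywords([
    "The website mentions that it only takes a couple of days to deliver but I still have not received mine.",
])
```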