Skip to content

Commit

Permalink
Integrated KeyLLM into KeyBERT
Browse files Browse the repository at this point in the history
  • Loading branch information
MaartenGr committed Sep 26, 2023
1 parent ad98da8 commit 779eb5a
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 50 deletions.
2 changes: 1 addition & 1 deletion keybert/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from keybert._model import KeyBERT
from keybert._llm import KeyLLM
from keybert._model import KeyBERT

__version__ = "0.7.0"
78 changes: 76 additions & 2 deletions keybert/_llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from typing import List, Union

try:
from sentence_transformers import util
HAS_SBERT = True
except ModuleNotFoundError:
HAS_SBERT = False


class KeyLLM:
"""
Expand All @@ -20,6 +26,10 @@ def __init__(self, llm):
def extract_keywords(
self,
docs: Union[str, List[str]],
check_vocab: bool = False,
candidate_keywords: List[List[str]] = None,
threshold: float = None,
embeddings = None
) -> Union[List[str], List[List[str]]]:
"""Extract keywords and/or keyphrases
Expand All @@ -32,7 +42,8 @@ def extract_keywords(
Arguments:
docs: The document(s) for which to extract keywords/keyphrases
top_n: Return the top n keywords/keyphrases
check_vocab: Only return keywords that appear exactly in the documents
candidate_keywords: Candidate keywords for each document
Returns:
keywords: The top n keywords for a document with their respective distances
Expand Down Expand Up @@ -66,5 +77,68 @@ def extract_keywords(
else:
return []

keywords = self.llm.extract_keywords(docs)
if HAS_SBERT and threshold is not None and embeddings is not None:

# Find similar documents
clusters = util.community_detection(embeddings, min_community_size=2, threshold=threshold)
in_cluster = set([cluster for cluster_set in clusters for cluster in cluster_set])
out_cluster = set(list(range(len(docs)))).difference(in_cluster)

# Extract keywords for all documents not in a cluster
if out_cluster:
selected_docs = [docs[index] for index in out_cluster]
print(out_cluster, selected_docs)
if candidate_keywords is not None:
selected_keywords = [candidate_keywords[index] for index in out_cluster]
else:
selected_keywords = None
print(f"Call LLM with {len(selected_docs)} docs; out-cluster")
out_cluster_keywords = self.llm.extract_keywords(
selected_docs,
selected_keywords,
)
out_cluster_keywords = {index: words for words, index in zip(out_cluster_keywords, out_cluster)}

# Extract keywords for only the first document in a cluster
if in_cluster:
selected_docs = [docs[cluster[0]] for cluster in clusters]
print(in_cluster, selected_docs)
if candidate_keywords is not None:
selected_keywords = [candidate_keywords[cluster[0]] for cluster in in_cluster]
else:
selected_keywords = None
print(f"Call LLM with {len(selected_docs)} docs; in-cluster")
in_cluster_keywords = self.llm.extract_keywords(
selected_docs,
selected_keywords
)
in_cluster_keywords = {
doc_id: in_cluster_keywords[index]
for index, cluster in enumerate(clusters)
for doc_id in cluster
}

# Update out cluster keywords with in cluster keywords
if out_cluster:
if in_cluster:
out_cluster_keywords.update(in_cluster_keywords)
print(out_cluster_keywords)
keywords = [out_cluster_keywords[index] for index in range(len(docs))]
else:
keywords = [in_cluster_keywords[index] for index in range(len(docs))]
else:
# Extract keywords using a Large Language Model (LLM)
keywords = self.llm.extract_keywords(docs, candidate_keywords)

# Only extract keywords that appear in the input document
if check_vocab:
updated_keywords = []
for keyword_set, document in zip(keywords, docs):
updated_keyword_set = []
for keyword in keyword_set:
if keyword in document:
updated_keyword_set.append(keyword)
updated_keywords.append(updated_keyword_set)
return updated_keywords

return keywords
23 changes: 22 additions & 1 deletion keybert/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from keybert._maxsum import max_sum_distance
from keybert._highlight import highlight_document
from keybert.backend._utils import select_backend
from keybert.llm._base import BaseLLM
from keybert import KeyLLM


class KeyBERT:
Expand All @@ -36,7 +38,7 @@ class KeyBERT:
</div>
"""

def __init__(self, model="all-MiniLM-L6-v2"):
def __init__(self, model="all-MiniLM-L6-v2", llm: BaseLLM = None):
"""KeyBERT initialization
Arguments:
Expand All @@ -54,6 +56,11 @@ def __init__(self, model="all-MiniLM-L6-v2"):
"""
self.model = select_backend(model)

if isinstance(llm, BaseLLM):
self.llm = KeyLLM(llm)
else:
self.llm = llm

def extract_keywords(
self,
docs: Union[str, List[str]],
Expand All @@ -71,6 +78,7 @@ def extract_keywords(
seed_keywords: Union[List[str], List[List[str]]] = None,
doc_embeddings: np.array = None,
word_embeddings: np.array = None,
threshold: float = None
) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
"""Extract keywords and/or keyphrases
Expand Down Expand Up @@ -245,6 +253,19 @@ def extract_keywords(
highlight_document(docs[0], all_keywords[0], count)
all_keywords = all_keywords[0]

# Fine-tune keywords using an LLM
if self.llm is not None:
if isinstance(all_keywords[0], tuple):
candidate_keywords = [[keyword[0] for keyword in all_keywords]]
else:
candidate_keywords = [[keyword[0] for keyword in keywords] for keywords in all_keywords]
keywords = self.llm.extract_keywords(
docs,
embeddings=doc_embeddings,
candidate_keywords=candidate_keywords,
threshold=threshold
)
return keywords
return all_keywords

def extract_embeddings(
Expand Down
4 changes: 2 additions & 2 deletions keybert/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from keybert._utils import NotInstalled
from keybert.llm._base import BaseRepresentation
from keybert.llm._base import BaseLLM
from keybert.llm._textgeneration import TextGeneration


Expand Down Expand Up @@ -33,7 +33,7 @@


__all__ = [
"BaseRepresentation",
"BaseLLM",
"Cohere",
"OpenAI",
"TextGeneration",
Expand Down
42 changes: 12 additions & 30 deletions keybert/llm/_base.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,20 @@
import pandas as pd
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator
from typing import Mapping, List, Tuple
from typing import List


class BaseRepresentation(BaseEstimator):
class BaseLLM(BaseEstimator):
""" The base representation model for fine-tuning topic representations """
def extract_topics(self,
topic_model,
documents: pd.DataFrame,
c_tf_idf: csr_matrix,
topics: Mapping[str, List[Tuple[str, float]]]
) -> Mapping[str, List[Tuple[str, float]]]:
def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
""" Extract topics
Each representation model that inherits this class will have
its arguments (topic_model, documents, c_tf_idf, topics)
automatically passed. Therefore, the representation model
will only have access to the information about topics related
to those arguments.
Arguments:
topic_model: The BERTopic model that is fitted until topic
representations are calculated.
documents: A dataframe with columns "Document" and "Topic"
that contains all documents with each corresponding
topic.
c_tf_idf: A c-TF-IDF representation that is typically
identical to `topic_model.c_tf_idf_` except for
dynamic, class-based, and hierarchical topic modeling
where it is calculated on a subset of the documents.
topics: A dictionary with topic (key) and tuple of word and
weight (value) as calculated by c-TF-IDF. This is the
default topics that are returned if no representation
model is used.
documents: The documents to extract keywords from
candidate_keywords: A list of candidate keywords that the LLM will fine-tune
For example, it will create a nicer representation of
the candidate keywords, remove redundant keywords, or
shorten them depending on the input prompt.
Returns:
all_keywords: All keywords for each document
"""
return topic_model.topic_representations_
return [None for document in documents]
4 changes: 2 additions & 2 deletions keybert/llm/_cohere.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
from tqdm import tqdm
from typing import List
from keybert.llm._base import BaseRepresentation
from keybert.llm._base import BaseLLM


DEFAULT_PROMPT = """
Expand All @@ -23,7 +23,7 @@
Keywords:"""


class Cohere(BaseRepresentation):
class Cohere(BaseLLM):
""" Use the Cohere API to generate topic labels based on their
generative model.
Expand Down
4 changes: 2 additions & 2 deletions keybert/llm/_langchain.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from tqdm import tqdm
from typing import List
from langchain.docstore.document import Document
from keybert.llm._base import BaseRepresentation
from keybert.llm._base import BaseLLM


DEFAULT_PROMPT = "What is this document about? Please provide keywords separated by commas."


class LangChain(BaseRepresentation):
class LangChain(BaseLLM):
""" Using chains in langchain to generate keywords.
Currently, only chains from question answering is implemented. See:
Expand Down
4 changes: 2 additions & 2 deletions keybert/llm/_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from tqdm import tqdm
from litellm import completion
from typing import Mapping, Any, List
from keybert.llm._base import BaseRepresentation
from keybert.llm._base import BaseLLM


DEFAULT_PROMPT = """
Expand All @@ -15,7 +15,7 @@
"""


class LiteLLM(BaseRepresentation):
class LiteLLM(BaseLLM):
""" Extract keywords using LiteLLM to call any LLM API using OpenAI format
such as Anthropic, Huggingface, Cohere, TogetherAI, Azure, OpenAI, etc.
Expand Down
25 changes: 19 additions & 6 deletions keybert/llm/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import openai
from tqdm import tqdm
from typing import Mapping, Any, List
from keybert.llm._base import BaseRepresentation
from keybert.llm._base import BaseLLM
from keybert.llm._utils import retry_with_exponential_backoff


Expand Down Expand Up @@ -34,7 +34,7 @@
"""


class OpenAI(BaseRepresentation):
class OpenAI(BaseLLM):
""" Using the OpenAI API to extract keywords
The default method is `openai.Completion` if `chat=False`.
Expand Down Expand Up @@ -110,7 +110,7 @@ class OpenAI(BaseRepresentation):
```
"""
def __init__(self,
model: str = "text-ada-001",
model: str = "gpt-3.5-turbo-instruct",
prompt: str = None,
generator_kwargs: Mapping[str, Any] = {},
delay_in_seconds: float = None,
Expand Down Expand Up @@ -139,19 +139,32 @@ def __init__(self,
if not self.generator_kwargs.get("stop") and not chat:
self.generator_kwargs["stop"] = "\n"

def extract_keywords(self, documents: List[str]):
def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
""" Extract topics
Arguments:
documents: The documents to extract keywords from
candidate_keywords: A list of candidate keywords that the LLM will fine-tune
For example, it will create a nicer representation of
the candidate keywords, remove redundant keywords, or
shorten them depending on the input prompt.
Returns:
all_keywords: All keywords for each document
"""
all_keywords = []

for document in tqdm(documents, disable=not self.verbose):
if candidate_keywords is None:
candidate_keywords = [None for _ in documents]
elif isinstance(candidate_keywords[0][0], str) and not isinstance(candidate_keywords[0], list):
candidate_keywords = [[keyword for keyword, _ in candidate_keywords]]
elif isinstance(candidate_keywords[0][0], tuple):
candidate_keywords = [[keyword for keyword, _ in keywords] for keywords in candidate_keywords]
print(candidate_keywords)

for document, candidates in tqdm(zip(documents, candidate_keywords), disable=not self.verbose):
prompt = self.prompt.replace("[DOCUMENT]", document)
if candidates is not None:
prompt = prompt.replace("[CANDIDATES]", ", ".join(candidates))

# Delay
if self.delay_in_seconds:
Expand Down
4 changes: 2 additions & 2 deletions keybert/llm/_textgeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from transformers import pipeline, set_seed
from transformers.pipelines.base import Pipeline
from typing import Mapping, List, Any, Union
from keybert.llm._base import BaseRepresentation
from keybert.llm._base import BaseLLM


DEFAULT_PROMPT = """
Expand All @@ -13,7 +13,7 @@
"""


class TextGeneration(BaseRepresentation):
class TextGeneration(BaseLLM):
""" Text2Text or text generation with transformers
NOTE: The resulting keywords are expected to be separated by commas so
Expand Down

0 comments on commit 779eb5a

Please sign in to comment.