Opens sentence transformer backend to edit batch_size param #210

Merged (2 commits) on Feb 28, 2024
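The net effect of the change: `SentenceTransformerBackend` now forwards arbitrary keyword arguments to `SentenceTransformer.encode()`, so the encoding batch size can be tuned when building a `KeyBERT` model. A minimal usage sketch (model name and batch size are only examples):

```python
from keybert import KeyBERT
from keybert.backend import SentenceTransformerBackend

# Any extra keyword argument is forwarded to SentenceTransformer.encode();
# here the batch size is raised from the sentence-transformers default of 32.
backend = SentenceTransformerBackend("all-MiniLM-L6-v2", batch_size=128)
kw_model = KeyBERT(model=backend)

keywords = kw_model.extract_keywords(
    "KeyBERT extracts keywords and keyphrases using BERT embeddings."
)
print(keywords)
```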
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -11,8 +11,8 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.4
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: flake8
- repo: https://github.com/psf/black
54 changes: 35 additions & 19 deletions keybert/_model.py
@@ -13,6 +13,7 @@
from keybert._mmr import mmr
from keybert._maxsum import max_sum_distance
from keybert._highlight import highlight_document
from keybert.backend._base import BaseEmbedder
from keybert.backend._utils import select_backend
from keybert.llm._base import BaseLLM
from keybert import KeyLLM
@@ -38,11 +39,15 @@ class KeyBERT:
</div>
"""

def __init__(self, model="all-MiniLM-L6-v2", llm: BaseLLM = None):
def __init__(
self,
model="all-MiniLM-L6-v2",
llm: BaseLLM = None,
):
"""KeyBERT initialization

Arguments:
model: Use a custom embedding model.
model: Use a custom embedding model or a specific KeyBERT Backend.
The following backends are currently supported:
* SentenceTransformers
* 🤗 Transformers
@@ -78,7 +83,7 @@ def extract_keywords(
seed_keywords: Union[List[str], List[List[str]]] = None,
doc_embeddings: np.array = None,
word_embeddings: np.array = None,
threshold: float = None
threshold: float = None,
) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
"""Extract keywords and/or keyphrases

@@ -111,9 +116,9 @@ def extract_keywords(
NOTE: This does not work if multiple documents are passed.
seed_keywords: Seed keywords that may guide the extraction of keywords by
steering the similarities towards the seeded keywords.
NOTE: when multiple documents are passed,
NOTE: when multiple documents are passed,
`seed_keywords` functions in either of the two ways below:
- globally: when a flat list of str is passed, keywords are shared by all documents,
- globally: when a flat list of str is passed, keywords are shared by all documents,
- locally: when a nested list of str is passed, keywords differ among documents.
doc_embeddings: The embeddings of each document.
word_embeddings: The embeddings of each potential keyword/keyphrase across
@@ -178,10 +183,12 @@ def extract_keywords(
# Check if the right number of word embeddings are generated compared with the vectorizer
if word_embeddings is not None:
if word_embeddings.shape[0] != len(words):
raise ValueError("Make sure that the `word_embeddings` are generated from the function "
"`.extract_embeddings`. \nMoreover, the `candidates`, `keyphrase_ngram_range`,"
"`stop_words`, and `min_df` parameters need to have the same values in both "
"`.extract_embeddings` and `.extract_keywords`.")
raise ValueError(
"Make sure that the `word_embeddings` are generated from the function "
"`.extract_embeddings`. \nMoreover, the `candidates`, `keyphrase_ngram_range`,"
"`stop_words`, and `min_df` parameters need to have the same values in both "
"`.extract_embeddings` and `.extract_keywords`."
)

# Extract embeddings
if doc_embeddings is None:
@@ -192,15 +199,21 @@
# Guided KeyBERT either local (keywords shared among documents) or global (keywords per document)
if seed_keywords is not None:
if isinstance(seed_keywords[0], str):
seed_embeddings = self.model.embed(seed_keywords).mean(axis=0, keepdims=True)
seed_embeddings = self.model.embed(seed_keywords).mean(
axis=0, keepdims=True
)
elif len(docs) != len(seed_keywords):
raise ValueError("The length of docs must match the length of seed_keywords")
raise ValueError(
"The length of docs must match the length of seed_keywords"
)
else:
seed_embeddings = np.vstack([
self.model.embed(keywords).mean(axis=0, keepdims=True)
for keywords in seed_keywords
])
doc_embeddings = ((doc_embeddings * 3 + seed_embeddings) / 4)
seed_embeddings = np.vstack(
[
self.model.embed(keywords).mean(axis=0, keepdims=True)
for keywords in seed_keywords
]
)
doc_embeddings = (doc_embeddings * 3 + seed_embeddings) / 4

# Find keywords
all_keywords = []
@@ -256,18 +269,21 @@ def extract_keywords(
# Fine-tune keywords using an LLM
if self.llm is not None:
import torch

doc_embeddings = torch.from_numpy(doc_embeddings).float()
if torch.cuda.is_available():
doc_embeddings = doc_embeddings.to("cuda")
if isinstance(all_keywords[0], tuple):
candidate_keywords = [[keyword[0] for keyword in all_keywords]]
else:
candidate_keywords = [[keyword[0] for keyword in keywords] for keywords in all_keywords]
candidate_keywords = [
[keyword[0] for keyword in keywords] for keywords in all_keywords
]
keywords = self.llm.extract_keywords(
docs,
embeddings=doc_embeddings,
candidate_keywords=candidate_keywords,
threshold=threshold
threshold=threshold,
)
return keywords
return all_keywords
@@ -279,7 +295,7 @@ def extract_embeddings(
keyphrase_ngram_range: Tuple[int, int] = (1, 1),
stop_words: Union[str, List[str]] = "english",
min_df: int = 1,
vectorizer: CountVectorizer = None
vectorizer: CountVectorizer = None,
) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
"""Extract document and word embeddings for the input documents and the
generated candidate keywords/keyphrases respectively.
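Since the hunks above tighten the coupling between `.extract_embeddings` and `.extract_keywords`, here is a short sketch of the intended workflow; it assumes `extract_embeddings` returns `(doc_embeddings, word_embeddings)` and keeps the vectorizer-related settings identical in both calls, as the `ValueError` above requires:

```python
from keybert import KeyBERT

docs = [
    "Supervised learning is the machine learning task of learning a function.",
    "Keyword extraction identifies the most salient phrases in a document.",
]
kw_model = KeyBERT()

# Embeddings are computed once and reused; candidates, keyphrase_ngram_range,
# stop_words, and min_df must be identical in both calls.
doc_embeddings, word_embeddings = kw_model.extract_embeddings(docs, min_df=1)

# Global seeding: a flat list of seed keywords is shared across all documents;
# a nested list (one inner list per document) would seed each document locally.
keywords = kw_model.extract_keywords(
    docs,
    doc_embeddings=doc_embeddings,
    word_embeddings=word_embeddings,
    min_df=1,
    seed_keywords=["machine learning"],
)
```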
3 changes: 2 additions & 1 deletion keybert/backend/__init__.py
@@ -1,3 +1,4 @@
from ._base import BaseEmbedder
from ._sentencetransformers import SentenceTransformerBackend

__all__ = ["BaseEmbedder"]
__all__ = ["BaseEmbedder", "SentenceTransformerBackend"]
9 changes: 7 additions & 2 deletions keybert/backend/_sentencetransformers.py
@@ -12,6 +12,7 @@ class SentenceTransformerBackend(BaseEmbedder):

Arguments:
embedding_model: A sentence-transformers embedding model
encode_kwargs: Additional parameters for the SentenceTransformer.encode() method

Usage:

@@ -33,7 +34,9 @@ class SentenceTransformerBackend(BaseEmbedder):
```
"""

def __init__(self, embedding_model: Union[str, SentenceTransformer]):
def __init__(
self, embedding_model: Union[str, SentenceTransformer], **encode_kwargs
):
super().__init__()

if isinstance(embedding_model, SentenceTransformer):
@@ -46,6 +49,7 @@ def __init__(self, embedding_model: Union[str, SentenceTransformer]):
"`from sentence_transformers import SentenceTransformer` \n"
"`model = SentenceTransformer('all-MiniLM-L6-v2')`"
)
self.encode_kwargs = encode_kwargs

def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
"""Embed a list of n documents/words into an n-dimensional
@@ -59,5 +63,6 @@ def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
Document/words embeddings with shape (n, m) with `n` documents/words
that each have an embeddings size of `m`
"""
embeddings = self.embedding_model.encode(documents, show_progress_bar=verbose)
self.encode_kwargs.update({"show_progress_bar": verbose})
embeddings = self.embedding_model.encode(documents, **self.encode_kwargs)
return embeddings
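For reference, a sketch of how the stored `encode_kwargs` reach `encode()` at embedding time (the model choice is arbitrary):

```python
from sentence_transformers import SentenceTransformer
from keybert.backend import SentenceTransformerBackend

st_model = SentenceTransformer("all-MiniLM-L6-v2")

# batch_size (or any other encode() argument, e.g. normalize_embeddings)
# is kept in encode_kwargs and passed through on every embed() call.
backend = SentenceTransformerBackend(st_model, batch_size=64)
embeddings = backend.embed(["a short document", "another short document"])
print(embeddings.shape)  # (2, embedding dimension)
```

Note that `embed()` always overwrites `show_progress_bar` with its own `verbose` flag, so passing `show_progress_bar` through `encode_kwargs` has no effect.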
41 changes: 41 additions & 0 deletions tests/test_backend.py
@@ -0,0 +1,41 @@
import pytest
from keybert import KeyBERT
from keybert.backend import SentenceTransformerBackend
import sentence_transformers

from sklearn.feature_extraction.text import CountVectorizer
from .utils import get_test_data


doc_one, doc_two = get_test_data()


@pytest.mark.parametrize("keyphrase_length", [(1, i + 1) for i in range(5)])
@pytest.mark.parametrize(
"vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")]
)
def test_single_doc_sentence_transformer_backend(keyphrase_length, vectorizer):
"""Test whether the keywords are correctly extracted"""
top_n = 5

model_name = "paraphrase-MiniLM-L6-v2"
st_model = sentence_transformers.SentenceTransformer(model_name, device="cpu")

kb_model = KeyBERT(model=SentenceTransformerBackend(st_model, batch_size=128))

keywords = kb_model.extract_keywords(
doc_one,
keyphrase_ngram_range=keyphrase_length,
min_df=1,
top_n=top_n,
vectorizer=vectorizer,
)

assert model_name in kb_model.model.embedding_model.tokenizer.name_or_path
assert isinstance(keywords, list)
assert isinstance(keywords[0], tuple)
assert isinstance(keywords[0][0], str)
assert isinstance(keywords[0][1], float)
assert len(keywords) == top_n
for keyword in keywords:
assert len(keyword[0].split(" ")) <= keyphrase_length[1]
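Beyond the full extraction test above, a lighter sanity check (hypothetical, not part of this PR) could simply assert that the constructor records the keyword argument:

```python
from keybert.backend import SentenceTransformerBackend

backend = SentenceTransformerBackend("paraphrase-MiniLM-L6-v2", batch_size=128)
assert backend.encode_kwargs["batch_size"] == 128
```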