diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ec64e326..61d45e96 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -11,8 +11,8 @@ repos:
   - id: end-of-file-fixer
   - id: check-yaml
   - id: check-added-large-files
-- repo: https://gitlab.com/pycqa/flake8
-  rev: 3.8.4
+- repo: https://github.com/PyCQA/flake8
+  rev: 7.0.0
   hooks:
   - id: flake8
 - repo: https://github.com/psf/black
diff --git a/keybert/_model.py b/keybert/_model.py
index c442e416..4e9990ad 100644
--- a/keybert/_model.py
+++ b/keybert/_model.py
@@ -13,6 +13,7 @@
 from keybert._mmr import mmr
 from keybert._maxsum import max_sum_distance
 from keybert._highlight import highlight_document
+from keybert.backend._base import BaseEmbedder
 from keybert.backend._utils import select_backend
 from keybert.llm._base import BaseLLM
 from keybert import KeyLLM
@@ -38,11 +39,15 @@ class KeyBERT:
     """

-    def __init__(self, model="all-MiniLM-L6-v2", llm: BaseLLM = None):
+    def __init__(
+        self,
+        model="all-MiniLM-L6-v2",
+        llm: BaseLLM = None,
+    ):
         """KeyBERT initialization

         Arguments:
-            model: Use a custom embedding model.
+            model: Use a custom embedding model or a specific KeyBERT Backend.
                    The following backends are currently supported:
                    * SentenceTransformers
                    * 🤗 Transformers
@@ -78,7 +83,7 @@ def extract_keywords(
         seed_keywords: Union[List[str], List[List[str]]] = None,
         doc_embeddings: np.array = None,
         word_embeddings: np.array = None,
-        threshold: float = None
+        threshold: float = None,
     ) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
         """Extract keywords and/or keyphrases
@@ -111,9 +116,9 @@ def extract_keywords(
                        NOTE: This does not work if multiple documents are passed.
             seed_keywords: Seed keywords that may guide the extraction of keywords by
                            steering the similarities towards the seeded keywords.
-                           NOTE: when multiple documents are passed, 
+                           NOTE: when multiple documents are passed,
                            `seed_keywords`funtions in either of the two ways below:
-                           - globally: when a flat list of str is passed, keywords are shared by all documents, 
+                           - globally: when a flat list of str is passed, keywords are shared by all documents,
                            - locally: when a nested list of str is passed, keywords differs among documents.
             doc_embeddings: The embeddings of each document.
             word_embeddings: The embeddings of each potential keyword/keyphrase across
@@ -178,10 +183,12 @@ def extract_keywords(
         # Check if the right number of word embeddings are generated compared with the vectorizer
         if word_embeddings is not None:
             if word_embeddings.shape[0] != len(words):
-                raise ValueError("Make sure that the `word_embeddings` are generated from the function "
-                                 "`.extract_embeddings`. \nMoreover, the `candidates`, `keyphrase_ngram_range`,"
-                                 "`stop_words`, and `min_df` parameters need to have the same values in both "
-                                 "`.extract_embeddings` and `.extract_keywords`.")
+                raise ValueError(
+                    "Make sure that the `word_embeddings` are generated from the function "
+                    "`.extract_embeddings`. \nMoreover, the `candidates`, `keyphrase_ngram_range`,"
+                    "`stop_words`, and `min_df` parameters need to have the same values in both "
+                    "`.extract_embeddings` and `.extract_keywords`."
+                )

         # Extract embeddings
         if doc_embeddings is None:
@@ -192,15 +199,21 @@ def extract_keywords(
         # Guided KeyBERT either local (keywords shared among documents) or global (keywords per document)
         if seed_keywords is not None:
             if isinstance(seed_keywords[0], str):
-                seed_embeddings = self.model.embed(seed_keywords).mean(axis=0, keepdims=True)
+                seed_embeddings = self.model.embed(seed_keywords).mean(
+                    axis=0, keepdims=True
+                )
             elif len(docs) != len(seed_keywords):
-                raise ValueError("The length of docs must match the length of seed_keywords")
+                raise ValueError(
+                    "The length of docs must match the length of seed_keywords"
+                )
             else:
-                seed_embeddings = np.vstack([
-                    self.model.embed(keywords).mean(axis=0, keepdims=True)
-                    for keywords in seed_keywords
-                ])
-            doc_embeddings = ((doc_embeddings * 3 + seed_embeddings) / 4)
+                seed_embeddings = np.vstack(
+                    [
+                        self.model.embed(keywords).mean(axis=0, keepdims=True)
+                        for keywords in seed_keywords
+                    ]
+                )
+            doc_embeddings = (doc_embeddings * 3 + seed_embeddings) / 4

         # Find keywords
         all_keywords = []
@@ -256,18 +269,21 @@ def extract_keywords(
         # Fine-tune keywords using an LLM
         if self.llm is not None:
             import torch
+
             doc_embeddings = torch.from_numpy(doc_embeddings).float()
             if torch.cuda.is_available():
                 doc_embeddings = doc_embeddings.to("cuda")
             if isinstance(all_keywords[0], tuple):
                 candidate_keywords = [[keyword[0] for keyword in all_keywords]]
             else:
-                candidate_keywords = [[keyword[0] for keyword in keywords] for keywords in all_keywords]
+                candidate_keywords = [
+                    [keyword[0] for keyword in keywords] for keywords in all_keywords
+                ]
             keywords = self.llm.extract_keywords(
                 docs,
                 embeddings=doc_embeddings,
                 candidate_keywords=candidate_keywords,
-                threshold=threshold
+                threshold=threshold,
             )
             return keywords
         return all_keywords
@@ -279,7 +295,7 @@ def extract_embeddings(
         keyphrase_ngram_range: Tuple[int, int] = (1, 1),
         stop_words: Union[str, List[str]] = "english",
         min_df: int = 1,
-        vectorizer: CountVectorizer = None
+        vectorizer: CountVectorizer = None,
     ) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
         """Extract document and word embeddings for the input documents and the
         generated candidate keywords/keyphrases respectively.
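As an aside on the guided-extraction paths touched above: the docstring distinguishes global seeding (a flat list shared by all documents) from local seeding (one list per document). A minimal sketch of both calls against the `extract_keywords` signature in this diff; the documents and seed terms below are made up for illustration:

```python
from keybert import KeyBERT

docs = [
    "Supervised learning maps labeled inputs to outputs.",
    "Transformers use self-attention to model long-range dependencies.",
]

kw_model = KeyBERT(model="all-MiniLM-L6-v2")

# Global seeding: a flat list of seed keywords shared by every document.
global_kw = kw_model.extract_keywords(docs, seed_keywords=["machine learning"])

# Local seeding: one list of seed keywords per document;
# the outer list must have the same length as `docs`.
local_kw = kw_model.extract_keywords(
    docs,
    seed_keywords=[["supervised", "labels"], ["attention", "transformers"]],
)
```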
diff --git a/keybert/backend/__init__.py b/keybert/backend/__init__.py
index a6001558..1c6c690a 100644
--- a/keybert/backend/__init__.py
+++ b/keybert/backend/__init__.py
@@ -1,3 +1,4 @@
 from ._base import BaseEmbedder
+from ._sentencetransformers import SentenceTransformerBackend

-__all__ = ["BaseEmbedder"]
+__all__ = ["BaseEmbedder", "SentenceTransformerBackend"]
diff --git a/keybert/backend/_sentencetransformers.py b/keybert/backend/_sentencetransformers.py
index 977281c1..47fd1e73 100644
--- a/keybert/backend/_sentencetransformers.py
+++ b/keybert/backend/_sentencetransformers.py
@@ -12,6 +12,7 @@ class SentenceTransformerBackend(BaseEmbedder):

     Arguments:
         embedding_model: A sentence-transformers embedding model
+        encode_kwargs: Additional parameters for the SentenceTransformers.encode() method

     Usage:

@@ -33,7 +34,9 @@ class SentenceTransformerBackend(BaseEmbedder):
     ```
     """

-    def __init__(self, embedding_model: Union[str, SentenceTransformer]):
+    def __init__(
+        self, embedding_model: Union[str, SentenceTransformer], **encode_kwargs
+    ):
         super().__init__()

         if isinstance(embedding_model, SentenceTransformer):
@@ -46,6 +49,7 @@ def __init__(self, embedding_model: Union[str, SentenceTransformer]):
                 "`from sentence_transformers import SentenceTransformer` \n"
                 "`model = SentenceTransformer('all-MiniLM-L6-v2')`"
             )
+        self.encode_kwargs = encode_kwargs

     def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
         """Embed a list of n documents/words into an n-dimensional
@@ -59,5 +63,6 @@ def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
             Document/words embeddings with shape (n, m) with `n` documents/words
             that each have an embeddings size of `m`
         """
-        embeddings = self.embedding_model.encode(documents, show_progress_bar=verbose)
+        self.encode_kwargs.update({"show_progress_bar": verbose})
+        embeddings = self.embedding_model.encode(documents, **self.encode_kwargs)
         return embeddings
diff --git a/tests/test_backend.py b/tests/test_backend.py
new file mode 100644
index 00000000..7ef86174
--- /dev/null
+++ b/tests/test_backend.py
@@ -0,0 +1,41 @@
+import pytest
+from keybert import KeyBERT
+from keybert.backend import SentenceTransformerBackend
+import sentence_transformers
+
+from sklearn.feature_extraction.text import CountVectorizer
+from .utils import get_test_data
+
+
+doc_one, doc_two = get_test_data()
+
+
+@pytest.mark.parametrize("keyphrase_length", [(1, i + 1) for i in range(5)])
+@pytest.mark.parametrize(
+    "vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")]
+)
+def test_single_doc_sentence_transformer_backend(keyphrase_length, vectorizer):
+    """Test whether the keywords are correctly extracted"""
+    top_n = 5
+
+    model_name = "paraphrase-MiniLM-L6-v2"
+    st_model = sentence_transformers.SentenceTransformer(model_name, device="cpu")
+
+    kb_model = KeyBERT(model=SentenceTransformerBackend(st_model, batch_size=128))
+
+    keywords = kb_model.extract_keywords(
+        doc_one,
+        keyphrase_ngram_range=keyphrase_length,
+        min_df=1,
+        top_n=top_n,
+        vectorizer=vectorizer,
+    )
+
+    assert model_name in kb_model.model.embedding_model.tokenizer.name_or_path
+    assert isinstance(keywords, list)
+    assert isinstance(keywords[0], tuple)
+    assert isinstance(keywords[0][0], str)
+    assert isinstance(keywords[0][1], float)
+    assert len(keywords) == top_n
+    for keyword in keywords:
+        assert len(keyword[0].split(" ")) <= keyphrase_length[1]
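For reference, a minimal usage sketch of the `encode_kwargs` pass-through introduced above, mirroring the new test: any extra keyword arguments handed to `SentenceTransformerBackend` are stored and forwarded to `SentenceTransformer.encode()`. The model name and `batch_size` come from the test; the sample document is made up for illustration:

```python
import sentence_transformers
from keybert import KeyBERT
from keybert.backend import SentenceTransformerBackend

# Load a sentence-transformers model explicitly so extra encode() options can be attached.
st_model = sentence_transformers.SentenceTransformer("paraphrase-MiniLM-L6-v2", device="cpu")

# batch_size is kept in encode_kwargs and forwarded to SentenceTransformer.encode();
# show_progress_bar is still controlled by the backend's `verbose` flag.
backend = SentenceTransformerBackend(st_model, batch_size=128)

kw_model = KeyBERT(model=backend)
keywords = kw_model.extract_keywords(
    "KeyBERT uses BERT embeddings to extract keywords from documents.",
    top_n=5,
)
```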