Add Text Generation Inference with JSON output #235

Merged · 8 commits · Jun 23, 2024
Changes from 6 commits
7 changes: 7 additions & 0 deletions keybert/llm/__init__.py
@@ -1,6 +1,12 @@
from keybert._utils import NotInstalled
from keybert.llm._base import BaseLLM

# TextGenerationInference
try:
    from keybert.llm._textgenerationinference import TextGenerationInference
except ModuleNotFoundError:
    msg = "`pip install huggingface-hub pydantic` \n\n"
    TextGenerationInference = NotInstalled("TextGenerationInference", "huggingface-hub", custom_msg=msg)

# TextGeneration
try:
Expand Down Expand Up @@ -43,6 +49,7 @@
"Cohere",
"OpenAI",
"TextGeneration",
"TextGenerationInference",
"LangChain",
"LiteLLM"
]
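The guard follows the repository's lazy-import pattern for optional backends: when `huggingface-hub` or `pydantic` is missing, the name resolves to a `NotInstalled` placeholder that raises an informative error on use. A minimal sketch of the expected behavior, assuming the placeholder raises `ModuleNotFoundError` when instantiated:

```python
# Hypothetical session in an environment without huggingface-hub/pydantic
from keybert.llm import TextGenerationInference

try:
    TextGenerationInference("http://127.0.0.1:8080")  # illustrative URL
except ModuleNotFoundError as e:
    print(e)  # should surface the hint: `pip install huggingface-hub pydantic`
```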
119 changes: 119 additions & 0 deletions keybert/llm/_textgenerationinference.py
@@ -0,0 +1,119 @@
from tqdm import tqdm
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from typing import Mapping, List, Any
from keybert.llm._base import BaseLLM
from keybert.llm._utils import process_candidate_keywords
import json


DEFAULT_PROMPT = """
I have the following document:
* [DOCUMENT]

Please give me the keywords that are present in this document and separate them with commas:
"""


class Keywords(BaseModel):
    keywords: List[str]


class TextGenerationInference(BaseLLM):
""" Text2Text or text generation with transformers

NOTE: The resulting keywords are expected to be separated by commas so
any changes to the prompt will have to make sure that the resulting
keywords are comma-separated.

Arguments:
model: A transformers pipeline that should be initialized as "text-generation"
for gpt-like models or "text2text-generation" for T5-like models.
For example, `pipeline('text-generation', model='gpt2')`. If a string
is passed, "text-generation" will be selected by default.
prompt: The prompt to be used in the model. If no prompt is given,
`self.default_prompt_` is used instead.
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
to decide where the keywords and documents need to be
inserted.
pipeline_kwargs: Kwargs that you can pass to the transformers.pipeline
when it is called.
random_state: A random state to be passed to `transformers.set_seed`
verbose: Set this to True if you want to see a progress bar for the
keyword extraction.

Usage:

To use a gpt-like model:

```python
from keybert.llm import TextGeneration
from keybert import KeyLLM

# Create your LLM
generator = pipeline('text-generation', model='gpt2')
llm = TextGeneration(generator)

# Load it in KeyLLM
kw_model = KeyLLM(llm)

# Extract keywords
document = "The website mentions that it only takes a couple of days to deliver but I still have not received mine."
keywords = kw_model.extract_keywords(document)
```

You can use a custom prompt and decide where the document should
be inserted with the `[DOCUMENT]` tag:

```python
from keybert.llm import TextGeneration

prompt = "I have the following documents '[DOCUMENT]'. Please give me the keywords that are present in this document and separate them with commas:"

# Create your representation model
generator = pipeline('text2text-generation', model='google/flan-t5-base')
llm = TextGeneration(generator)
```
"""
    def __init__(self,
                 url: str,
Owner (review comment):
I propose to pass the entire InferenceClient rather than just the URL since not all its parameters are exposed at the moment. Moreover, it would then follow the same structure as is done with OpenAI in this repo.

Contributor Author (reply):
I updated the code. TextGenerationInference now accepts an InferenceClient when constructed. I also added json_schema in case someone is looking for a different output structure.
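Illustratively, the updated construction described in the reply might look as follows; this is a sketch only, since the diff shown on this page still reflects the earlier URL-based constructor:

```python
# Illustrative only: the InferenceClient-based API described above
# is not part of the diff shown on this page.
from huggingface_hub import InferenceClient
from keybert.llm import TextGenerationInference

client = InferenceClient(model="http://127.0.0.1:8080")  # assumed local TGI server
llm = TextGenerationInference(client)
```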

                 prompt: str = None,
                 client_kwargs: Mapping[str, Any] = {},
                 verbose: bool = False
                 ):
        self.client = InferenceClient(model=url)
        self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
        self.default_prompt_ = DEFAULT_PROMPT
        self.client_kwargs = client_kwargs
        self.verbose = verbose

    def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
        """ Extract keywords from documents

        Arguments:
            documents: The documents to extract keywords from
            candidate_keywords: A list of candidate keywords that the LLM will fine-tune
                        For example, it will create a nicer representation of
                        the candidate keywords, remove redundant keywords, or
                        shorten them depending on the input prompt.

        Returns:
            all_keywords: All keywords for each document
        """
        all_keywords = []
        candidate_keywords = process_candidate_keywords(documents, candidate_keywords)

        for document, candidates in tqdm(zip(documents, candidate_keywords), disable=not self.verbose):
            prompt = self.prompt.replace("[DOCUMENT]", document)
            if candidates is not None:
                prompt = prompt.replace("[CANDIDATES]", ", ".join(candidates))

            # Generate a response constrained to JSON matching the Keywords schema
            response = self.client.text_generation(
                prompt=prompt,
                grammar={"type": "json", "value": Keywords.schema()},
                **self.client_kwargs
            )
            # Append per document so earlier results are not overwritten
            all_keywords.append(json.loads(response)["keywords"])

        return all_keywords
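Because the backend boils down to a single grammar-constrained `text_generation` call, it can be exercised directly against a TGI server with `InferenceClient`; a minimal standalone sketch, assuming a server is reachable at the illustrative URL below:

```python
import json
from typing import List

from huggingface_hub import InferenceClient
from pydantic import BaseModel


class Keywords(BaseModel):
    keywords: List[str]


client = InferenceClient(model="http://127.0.0.1:8080")  # illustrative URL

# Constrain the response to JSON matching the Keywords schema
response = client.text_generation(
    prompt="Give me the keywords in this document: 'Delivery takes a couple of days.'",
    grammar={"type": "json", "value": Keywords.schema()},
    max_new_tokens=100,
)
print(json.loads(response)["keywords"])
```

With pydantic v2, as pinned below in pyproject.toml, `Keywords.model_json_schema()` is the non-deprecated spelling of `Keywords.schema()`.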
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -40,7 +40,7 @@ dependencies = [
"numpy>=1.18.5",
"rich>=10.4.0",
"scikit-learn>=0.22.2",
"sentence-transformers>=0.3.8",
"sentence-transformers>=0.3.8"
]

[project.optional-dependencies]
@@ -70,6 +70,10 @@ test = [
"pytest-cov>=2.6.1",
"pytest>=5.4.3",
]
tgi = [
"huggingface-hub>=0.23.3",
"pydantic>=2.7.4"
]
use = [
"tensorflow",
"tensorflow_hub",
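With the new `tgi` extra in place, the optional dependencies can also be installed as `pip install keybert[tgi]`, which is equivalent to the `pip install huggingface-hub pydantic` hint raised by the `NotInstalled` guard above.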