Add Text Generation Inference with JSON output #235
Merged
Changes from 6 commits (8 commits total), all by joaomsimoes:

- 239fc43: added tgi huggingface
- 02be03f: added tgi huggingface
- 7f525c4: added tgi huggingface
- cdbc7ae: added tgi huggingface
- 420b4c9: added tgi huggingface
- d246fe0: added tgi huggingface
- c5ec185: passing InferenceClient and json schema when constructing TextGenerat…
- dfb8ba9: added inference kwargs to control temperature, max new tokens, etc...
@@ -0,0 +1,119 @@
from tqdm import tqdm
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from typing import Mapping, List, Any
from keybert.llm._base import BaseLLM
from keybert.llm._utils import process_candidate_keywords
import json


DEFAULT_PROMPT = """
I have the following document:
* [DOCUMENT]

Please give me the keywords that are present in this document and separate them with commas:
"""


class Keywords(BaseModel):
    keywords: List[str]


class TextGenerationInference(BaseLLM):
    """ Keyword extraction with a Text Generation Inference (TGI) server

    The endpoint is queried through `huggingface_hub.InferenceClient` and the
    output is constrained to JSON with TGI's grammar feature, so the keywords
    come back as a parsable list rather than free-form text.

    Arguments:
        url: The URL of a running Text Generation Inference server.
        prompt: The prompt to be used in the model. If no prompt is given,
                `self.default_prompt_` is used instead.
                NOTE: Use `"[DOCUMENT]"` in the prompt to decide where the
                document needs to be inserted, and `"[CANDIDATES]"` for
                candidate keywords.
        client_kwargs: Kwargs passed to `InferenceClient.text_generation`,
                       such as `temperature` or `max_new_tokens`.
        verbose: Set this to True if you want to see a progress bar for the
                 keyword extraction.

    Usage:

    ```python
    from keybert.llm import TextGenerationInference
    from keybert import KeyLLM

    # Create your LLM, pointing at a running TGI server
    llm = TextGenerationInference("http://localhost:8080")

    # Load it in KeyLLM
    kw_model = KeyLLM(llm)

    # Extract keywords
    document = "The website mentions that it only takes a couple of days to deliver but I still have not received mine."
    keywords = kw_model.extract_keywords(document)
    ```

    You can use a custom prompt and decide where the document should
    be inserted with the `[DOCUMENT]` tag:

    ```python
    from keybert.llm import TextGenerationInference

    prompt = "I have the following document '[DOCUMENT]'. Please give me the keywords that are present in this document and separate them with commas:"
    llm = TextGenerationInference("http://localhost:8080", prompt=prompt)
    ```
    """
    def __init__(self,
                 url: str,
                 prompt: str = None,
                 client_kwargs: Mapping[str, Any] = {},
                 verbose: bool = False
                 ):
        self.client = InferenceClient(model=url)
        self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
        self.default_prompt_ = DEFAULT_PROMPT
        self.client_kwargs = client_kwargs
        self.verbose = verbose

    def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
        """ Extract keywords from documents

        Arguments:
            documents: The documents to extract keywords from
            candidate_keywords: A list of candidate keywords that the LLM will fine-tune
                        For example, it will create a nicer representation of
                        the candidate keywords, remove redundant keywords, or
                        shorten them depending on the input prompt.

        Returns:
            all_keywords: All keywords for each document
        """
        all_keywords = []
        candidate_keywords = process_candidate_keywords(documents, candidate_keywords)

        for document, candidates in tqdm(zip(documents, candidate_keywords), disable=not self.verbose):
            prompt = self.prompt.replace("[DOCUMENT]", document)
            if candidates is not None:
                prompt = prompt.replace("[CANDIDATES]", ", ".join(candidates))

            # Ask the server for JSON matching the Keywords schema, then
            # parse it and collect this document's keywords
            response = self.client.text_generation(
                prompt=prompt,
                grammar={"type": "json", "value": Keywords.schema()},
                **self.client_kwargs
            )
            all_keywords.append(json.loads(response)["keywords"])

        return all_keywords
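The `grammar` argument passed to `text_generation` above pairs a type tag with a JSON Schema. As a sketch of the payload shape, the schema that pydantic derives from the `Keywords` model is written out by hand below (the field names match the model; the schema pydantic actually emits may carry extra metadata such as `$defs`):

```python
import json

# Hand-written equivalent of the JSON Schema pydantic derives from the
# `Keywords` model: an object with one required field, "keywords",
# holding a list of strings.
keywords_schema = {
    "title": "Keywords",
    "type": "object",
    "properties": {
        "keywords": {"type": "array", "items": {"type": "string"}},
    },
    "required": ["keywords"],
}

# The payload TGI expects for grammar-constrained generation
grammar = {"type": "json", "value": keywords_schema}

# A response constrained by this grammar is guaranteed to parse
response = '{"keywords": ["delivery", "website"]}'
parsed = json.loads(response)["keywords"]
```

Constraining the output this way is what lets `extract_keywords` call `json.loads` on the raw response without any retry or repair logic.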
I propose to pass the entire InferenceClient rather than just the URL, since not all of its parameters are exposed at the moment. Moreover, it would then follow the same structure as the OpenAI integration in this repo.
I updated the code. TextGenerationInference now accepts an InferenceClient when constructed. I also added json_schema in case someone wants a different output structure.
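A sketch of what that revised constructor could look like: the class takes a ready-made client plus an optional schema instead of building the client from a URL. This is a sketch of the idea, not the merged code — the `json_schema` parameter name and the default schema dict are assumptions, and a stub client stands in for `huggingface_hub.InferenceClient` so the flow can be shown without a running TGI server.

```python
import json
from typing import Any, List, Mapping

# Assumed default; the real code derives this from a pydantic model
DEFAULT_SCHEMA = {
    "type": "object",
    "properties": {"keywords": {"type": "array", "items": {"type": "string"}}},
    "required": ["keywords"],
}


class TextGenerationInference:
    """Sketch: accept a constructed client, e.g. InferenceClient(model=url)."""

    def __init__(self, client: Any, prompt: str = None,
                 json_schema: Mapping[str, Any] = None,
                 client_kwargs: Mapping[str, Any] = None):
        self.client = client
        self.prompt = prompt or "Give me the keywords in this document: [DOCUMENT]"
        self.json_schema = json_schema or DEFAULT_SCHEMA
        self.client_kwargs = dict(client_kwargs or {})

    def extract_keywords(self, documents: List[str]) -> List[List[str]]:
        all_keywords = []
        for document in documents:
            prompt = self.prompt.replace("[DOCUMENT]", document)
            # Same grammar-constrained call as the diff, but the schema
            # is now whatever was passed at construction time
            response = self.client.text_generation(
                prompt=prompt,
                grammar={"type": "json", "value": self.json_schema},
                **self.client_kwargs,
            )
            all_keywords.append(json.loads(response)["keywords"])
        return all_keywords


class _StubClient:
    """Stands in for InferenceClient; always returns schema-shaped JSON."""

    def text_generation(self, prompt, grammar=None, **kwargs):
        return '{"keywords": ["delivery", "website"]}'


results = TextGenerationInference(_StubClient()).extract_keywords(["Some document"])
```

Because only an object with a `text_generation` method is required, the class stays easy to test and any client configuration (auth token, timeout) is handled where the client is built.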