Add LLM 4-bit quantization & update LLM pipeline #32
Conversation
Left some minor comments, otherwise LGTM
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16 if self.device == "cuda" else "auto",
    device_map=self.device,
)
print("Done!")
return model, tokenizer, pipe
@HonzaCuhel what is the reason for using a pipeline here?
Since I changed the code to use the pipeline in all the other LM classes, I thought it should be consistent everywhere. But I can change it, no problem.
Ok, makes sense
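For context on the loading path discussed above, here is a minimal sketch of how the 4-bit quantization this PR adds could plug into the same pipeline-based setup, assuming Hugging Face transformers with bitsandbytes; the checkpoint name and config values are illustrative, not necessarily what the PR uses:

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # placeholder checkpoint

# NF4 4-bit quantization with fp16 compute (illustrative settings)
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant_config,
    device_map="auto",  # quantized weights must be placed on GPU
)

# The pipeline wraps the already-quantized model, so no dtype/device args are needed here
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)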
def _test_prompt(self, prompt: str, selected_objects: List[str]) -> bool:
    """Tests if the generated prompt is valid based on selected objects.

    Args:
        prompt (str): The generated prompt.
        selected_objects (List[str]): Objects to check in the prompt.

    Returns:
        bool: True if the prompt is valid, False otherwise.
    """
    return prompt.lower().startswith(
        "a photo of"
    )  # and all(obj.lower() in prompt.lower() for obj in selected_objects)
@HonzaCuhel do you have any other ideas on how to ensure the prompt follows the template?
Well, when I was experimenting with TinyLlama I tried several approaches. First, I instructed the LM through the prompt to generate a response following this template, but the generated responses weren't satisfactory. I also tried editing the prompt so that the words "A photo of" appeared explicitly after the instruction-end token, forcing the generated responses to start with the template; however, with TinyLlama (I haven't tried it with Mistral) this confused the LM and it generated gibberish. There are other ways to restrict the generated tokens, but when I tried them some time ago they didn't work very well, though that could have changed since.
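For reference, the prefix-seeding variant described above might look roughly like this: a sketch assuming a chat-template tokenizer and the `pipe` object from earlier; the instruction text and generation settings are placeholders:

messages = [
    {"role": "user", "content": "Describe a scene containing: " + ", ".join(selected_objects)}
]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
prompt += "A photo of"  # seed the template right after the instruction-end token
output = pipe(prompt, max_new_tokens=60, return_full_text=False)[0]["generated_text"]
caption = "A photo of" + output  # re-attach the seeded prefix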
When I was experimenting with TinyLlama, I noticed that sometimes the generated prompts were still usable even though they didn't pass the test (e.g. "A picture of ..." or "Picture showing ..."). Of course, sometimes the prompts weren't good, so this test ensures at least some quality. Maybe we could extend the test to accept these cases ("A picture of", etc.), but it would be worth first measuring how often they occur to see whether the change is worth it.
I could look more into it, but I generally think we should keep simple regex tests to ensure the quality of the generated prompts: they are simple, fast, and relatively powerful.
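A sketch of what such an extended regex check could look like; the accepted prefixes beyond "a photo of" are hypothetical, pending the frequency measurement mentioned above:

import re
from typing import List

# Hypothetical accepted prefixes; only "a photo of" is in the current test
_PREFIX_RE = re.compile(r"^(a photo of|a picture of|picture showing)\b", re.IGNORECASE)

def _test_prompt(prompt: str, selected_objects: List[str]) -> bool:
    """Return True if the prompt starts with an accepted prefix and mentions every object."""
    if _PREFIX_RE.match(prompt.strip()) is None:
        return False
    lowered = prompt.lower()
    return all(obj.lower() in lowered for obj in selected_objects)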
This PR includes:
* Add quantization of Mistral
* Add quantization flag and add 8bit quantization
* Better printing
* Update quantization
* Add pipeline
* Update Mistral LM generation
* Update LM quantization
* Add unittests & update pipeline & prompt generation
* Correct tests
* Update version of Mistral, update docstrings & README.md
* Format code