Skip to content

Commit

Permalink
Merge branch 'dev' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
sokovninn authored Mar 6, 2024
2 parents d55eef2 + 3804ca0 commit 294467c
Show file tree
Hide file tree
Showing 11 changed files with 352 additions and 163 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ datadreamer --save_dir <directory> --class_names <objects> --prompts_number <num
- `--image_annotator`: Specify the image annotator, like `owlv2`. Default is `owlv2`.
- `--conf_threshold`: Confidence threshold for object detection. Default is 0.15.
- `--use_tta`: Toggle test time augmentation for object detection. Default is True.
- `--enhance_class_names`: Enhance class names with synonyms. Default is False.
- `--synonym_generator`: Synonym generator used to enhance class names with synonyms. Choose between `none`, `llm` and `wordnet`. Default is `none`.
- `--use_image_tester`: Use image tester for image generation. Default is False.
- `--image_tester_patience`: Patience level for image tester. Default is 1.
- `--lm_quantization`: Quantization to use for Mistral language model. Choose between `none` and `4bit`. Default is `none`.
Expand Down
28 changes: 18 additions & 10 deletions datadreamer/pipelines/generate_dataset_from_scratch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
)
from datadreamer.prompt_generation import (
LMPromptGenerator,
LMSynonymGenerator,
SimplePromptGenerator,
SynonymGenerator,
TinyLlamaLMPromptGenerator,
WordNetSynonymGenerator,
)

prompt_generators = {
Expand All @@ -30,6 +31,11 @@
"tiny": TinyLlamaLMPromptGenerator,
}

# Maps each `--synonym_generator` CLI choice to the class implementing it.
# The "none" choice intentionally has no entry: main() checks for it and
# skips synonym generation before this table is consulted.
synonym_generators = {
    "llm": LMSynonymGenerator,
    "wordnet": WordNetSynonymGenerator,
}

image_generators = {
"sdxl": StableDiffusionImageGenerator,
"sdxl-turbo": StableDiffusionTurboImageGenerator,
Expand Down Expand Up @@ -99,6 +105,14 @@ def parse_args():
help="Image annotator to use",
)

# Selects how class names are enriched with synonyms; "none" disables
# synonym generation, the other values are keys of `synonym_generators`.
parser.add_argument(
    "--synonym_generator",
    type=str,
    default="none",
    choices=["none", "llm", "wordnet"],
    help="Synonym generator to use",
)

parser.add_argument(
"--conf_threshold",
type=float,
Expand All @@ -113,13 +127,6 @@ def parse_args():
help="Whether to use test time augmentation for object detection",
)

parser.add_argument(
"--enhance_class_names",
default=False,
action="store_true",
help="Whether to enhance class names with synonyms",
)

parser.add_argument(
"--use_image_tester",
default=False,
Expand Down Expand Up @@ -326,8 +333,9 @@ def main():

# Synonym generation
synonym_dict = None
if args.enhance_class_names:
synonym_generator = SynonymGenerator(device=args.device)
if args.synonym_generator != "none":
synonym_generator_class = synonym_generators[args.synonym_generator]
synonym_generator = synonym_generator_class(device=args.device)
synonym_dict = synonym_generator.generate_synonyms_for_list(args.class_names)
synonym_generator.release(empty_cuda_cache=True)
synonym_generator.save_synonyms(
Expand Down
6 changes: 4 additions & 2 deletions datadreamer/prompt_generation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

from .lm_prompt_generator import LMPromptGenerator
from .lm_synonym_generator import LMSynonymGenerator
from .simple_prompt_generator import SimplePromptGenerator
from .synonym_generator import SynonymGenerator
from .tinyllama_lm_prompt_generator import TinyLlamaLMPromptGenerator
from .wordnet_synonym_generator import WordNetSynonymGenerator

__all__ = [
"SimplePromptGenerator",
"LMPromptGenerator",
"SynonymGenerator",
"LMSynonymGenerator",
"TinyLlamaLMPromptGenerator",
"WordNetSynonymGenerator",
]
166 changes: 166 additions & 0 deletions datadreamer/prompt_generation/lm_synonym_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import re
from typing import List, Optional

import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Pipeline,
pipeline,
)

from datadreamer.prompt_generation.synonym_generator import SynonymGenerator


class LMSynonymGenerator(SynonymGenerator):
    """Synonym generator that produces synonyms for words using a language model.

    The model (Mistral-7B-Instruct-v0.1) is loaded eagerly in the constructor.

    Args:
        synonyms_number (int): Number of synonyms to generate for each word.
        seed (Optional[float]): Seed for randomization.
        device (str): Device for model inference (default is "cuda").

    Methods:
        _init_lang_model(): Initializes the language model, tokenizer and pipeline.
        _generate_synonyms(prompt_text): Generates synonyms based on a given prompt text.
        _extract_synonyms(text): Extracts synonyms from a comma-separated text.
        _create_prompt_text(word): Creates a prompt text for generating synonyms for a given word.
        generate_synonyms(word): Generates synonyms for a single word and returns them in a list.
        release(empty_cuda_cache): Moves the model to the CPU and optionally empties the CUDA cache.
    """

    def __init__(
        self,
        synonyms_number: int = 5,
        seed: Optional[float] = 42,
        device: str = "cuda",
    ) -> None:
        """Initializes the LMSynonymGenerator and loads the language model."""
        super().__init__(synonyms_number, seed, device)
        self.model, self.tokenizer, self.pipeline = self._init_lang_model()

    def _init_lang_model(self) -> tuple[AutoModelForCausalLM, AutoTokenizer, Pipeline]:
        """Initializes the language model, tokenizer and pipeline for synonym generation.

        Returns:
            tuple: The initialized language model, tokenizer and pipeline.
        """
        if self.device == "cpu":
            print("Loading language model on CPU...")
            model = AutoModelForCausalLM.from_pretrained(
                "mistralai/Mistral-7B-Instruct-v0.1",
                torch_dtype="auto",
                device_map="cpu",
                low_cpu_mem_usage=True,
            )
        else:
            print("Loading FP16 language model on GPU...")
            model = AutoModelForCausalLM.from_pretrained(
                "mistralai/Mistral-7B-Instruct-v0.1",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                device_map=self.device,
            )

        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            torch_dtype=torch.float16 if self.device == "cuda" else "auto",
            device_map=self.device,
        )
        print("Done!")
        return model, tokenizer, pipe

    def _generate_synonyms(self, prompt_text: str) -> List[str]:
        """Generates synonyms based on a given prompt text.

        Args:
            prompt_text (str): The prompt text for generating synonyms.

        Returns:
            List[str]: A list of generated synonyms.
        """
        sequences = self.pipeline(
            prompt_text,
            max_new_tokens=50,
            do_sample=True,
            num_return_sequences=1,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        generated_text = sequences[0]["generated_text"]

        instructional_pattern = r"\[INST].*?\[/INST\]\s*"
        # Remove the echoed instruction block so only the model's answer is
        # left, then drop quotes and periods that would pollute the synonyms.
        generated_text = (
            re.sub(instructional_pattern, "", generated_text)
            .replace('"', "")
            .replace("'", "")
            .replace(".", "")
        )

        # Process the generated text to extract synonyms
        return self._extract_synonyms(generated_text)

    def _extract_synonyms(self, text: str) -> List[str]:
        """Extracts synonyms from a comma-separated text.

        Empty entries (from an empty answer, a trailing comma or consecutive
        commas) are discarded so that no blank string is reported as a synonym.

        Args:
            text (str): The text containing comma-separated synonyms.

        Returns:
            List[str]: At most `synonyms_number` non-empty, stripped synonyms.
        """
        stripped_parts = (part.strip() for part in text.split(","))
        synonyms = [part for part in stripped_parts if part]
        return synonyms[: self.synonyms_number]

    def _create_prompt_text(self, word: str) -> str:
        """Creates a prompt text for generating synonyms for a given word.

        Args:
            word (str): The word for which synonyms are generated.

        Returns:
            str: The prompt text for generating synonyms.
        """
        return f"[INST] List {self.synonyms_number} most common synonyms for the word '{word}'. Write only synonyms separated by commas. [/INST]"

    def generate_synonyms(self, word: str) -> List[str]:
        """Generates synonyms for a single word and returns them in a list.

        Args:
            word (str): The word for which synonyms are generated.

        Returns:
            List[str]: A list of generated synonyms for the word.
        """
        prompt_text = self._create_prompt_text(word)
        return self._generate_synonyms(prompt_text)

    def release(self, empty_cuda_cache: bool = False) -> None:
        """Moves the model to the CPU and optionally empties the CUDA cache.

        Args:
            empty_cuda_cache (bool): Whether to empty the CUDA cache (default is False).
        """
        self.model = self.model.to("cpu")
        if empty_cuda_cache:
            # Emptying the cache does not involve autograd, so no no_grad
            # context is needed around it.
            torch.cuda.empty_cache()


if __name__ == "__main__":
    # Example usage: generate three synonyms per word on the CPU and print
    # the result returned by generate_synonyms_for_list.
    example_words = ["astronaut", "cat", "dog", "person", "horse"]
    synonym_generator = LMSynonymGenerator(synonyms_number=3, device="cpu")
    generated = synonym_generator.generate_synonyms_for_list(example_words)
    print(generated)
    # Optional follow-ups, left disabled as in the original example:
    # synonym_generator.save_synonyms(generated, "synonyms.json")
    # synonym_generator.release()
Loading

0 comments on commit 294467c

Please sign in to comment.