Commit

Feature/advanced configuration update (#47)
* Add annotator size & IoU threshold

* Add negative prompts, prompt prefix and suffix args & unittests for new args

* Fix image size for large OWLv2

* Update README.md & check args

* Correct clf unittests

* Fix unittests for negative prompt & prompt suffix

* [Automated] Updated coverage badge

* fix: change default value of generated synonyms to 3

* Update negative prompt description

* Change --negative_prompt & --prompt_suffix args to string

---------

Co-authored-by: Jan Cuhel <[email protected]>
Co-authored-by: GitHub Actions <[email protected]>
Co-authored-by: Nikita Sokovnin <[email protected]>
4 people authored Mar 26, 2024
1 parent 40e25ed commit 45e741d
Showing 11 changed files with 202 additions and 44 deletions.
29 changes: 17 additions & 12 deletions README.md
@@ -101,30 +101,35 @@ datadreamer --save_dir <directory> --class_names <objects> --prompts_number <num
### Main Parameters

- `--save_dir` (required): Path to the directory for saving generated images and annotations.
- - `--class_names` (required): Space-separated list of object names for image generation and annotation. Example: person moon robot.
- - `--prompts_number` (optional): Number of prompts to generate for each object. Defaults to 10.
- - `--annotate_only` (optional): Only annotate the images without generating new ones, prompt and image generator will be skipped. Defaults to False.
+ - `--class_names` (required): Space-separated list of object names for image generation and annotation. Example: `person moon robot`.
+ - `--prompts_number` (optional): Number of prompts to generate for each object. Defaults to `10`.
+ - `--annotate_only` (optional): Only annotate the images without generating new ones; the prompt and image generation steps are skipped. Defaults to `False`.

<a name="additional-parameters"></a>

### Additional Parameters

- - `--task`: Choose between `detection` and `classification`. Default is `detection`.
+ - `--task`: Choose between detection and classification. Default is `detection`.
- `--num_objects_range`: Range of objects in a prompt. Default is 1 to 3.
- `--prompt_generator`: Choose between `simple`, `lm` (language model) and `tiny` (tiny LM). Default is `simple`.
- `--image_generator`: Choose image generator, e.g., `sdxl`, `sdxl-turbo` or `sdxl-lightning`. Default is `sdxl-turbo`.
- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `clip` for image classification. Default is `owlv2`.
- - `--conf_threshold`: Confidence threshold for annotation. Default is 0.15.
- - `--use_tta`: Toggle test time augmentation for object detection. Default is True.
+ - `--conf_threshold`: Confidence threshold for annotation. Default is `0.15`.
+ - `--annotation_iou_threshold`: Intersection over Union (IoU) threshold for annotation. Default is `0.2`.
+ - `--prompt_prefix`: Prefix to add to every image generation prompt. Default is `""`.
+ - `--prompt_suffix`: Suffix to add to every image generation prompt, e.g., for adding details like resolution. Default is `", hd, 8k, highly detailed"`.
+ - `--negative_prompt`: Negative prompts to guide the generation away from certain features. Default is `"cartoon, blue skin, painting, scrispture, golden, illustration, worst quality, low quality, normal quality:2, unrealistic dream, low resolution, static, sd character, low quality, low resolution, greyscale, monochrome, nose, cropped, lowres, jpeg artifacts, deformed iris, deformed pupils, bad eyes, semi-realistic worst quality, bad lips, deformed mouth, deformed face, deformed fingers, bad anatomy"`.
+ - `--use_tta`: Toggle test time augmentation for object detection. Default is `True`.
- `--synonym_generator`: Enhance class names with synonyms. Default is `none`. Other options are `llm`, `wordnet`.
- - `--use_image_tester`: Use image tester for image generation. Default is False.
- - `--image_tester_patience`: Patience level for image tester. Default is 1.
+ - `--use_image_tester`: Use image tester for image generation. Default is `False`.
+ - `--image_tester_patience`: Patience level for image tester. Default is `1`.
- `--lm_quantization`: Quantization to use for Mistral language model. Choose between `none` and `4bit`. Default is `none`.
+ - `--annotator_size`: Size of the annotator model to use. Choose between `base` and `large`. Default is `base`.
- `--batch_size_prompt`: Batch size for prompt generation. Default is 64.
- - `--batch_size_annotation`: Batch size for annotation. Default is 8.
- - `--batch_size_image`: Batch size for image generation. Default is 1.
- - `--device`: Choose between `cuda` and `cpu`. Default is cuda.
- - `--seed`: Set a random seed for image and prompt generation. Default is 42.
+ - `--batch_size_annotation`: Batch size for annotation. Default is `8`.
+ - `--batch_size_image`: Batch size for image generation. Default is `1`.
+ - `--device`: Choose between `cuda` and `cpu`. Default is `cuda`.
+ - `--seed`: Set a random seed for image and prompt generation. Default is `42`.
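
As a quick illustration, the new flags introduced by this commit can be combined in one invocation; the directory, class names, and threshold values below are placeholders rather than recommendations:

```bash
datadreamer --save_dir generated_dataset --class_names person moon robot \
  --image_annotator owlv2 --annotator_size large \
  --conf_threshold 0.15 --annotation_iou_threshold 0.25 \
  --prompt_prefix "a photo of " --prompt_suffix ", hd, 8k, highly detailed" \
  --negative_prompt "cartoon, painting, low quality"
```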

<a name="available-models"></a>

33 changes: 29 additions & 4 deletions datadreamer/dataset_annotation/clip_annotator.py
@@ -8,7 +8,7 @@
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

-from datadreamer.dataset_annotation.image_annotator import BaseAnnotator
+from datadreamer.dataset_annotation.image_annotator import BaseAnnotator, TaskList


class CLIPAnnotator(BaseAnnotator):
@@ -19,8 +19,11 @@ class CLIPAnnotator(BaseAnnotator):
        model (CLIPModel): The CLIP model for image-text similarity evaluation.
        processor (CLIPProcessor): The processor for preparing inputs to the CLIP model.
        device (str): The device on which the model will run ('cuda' for GPU, 'cpu' for CPU).
+       size (str): The size of the CLIP model to use ('base' or 'large').

    Methods:
+       _init_processor(): Initializes the CLIP processor.
+       _init_model(): Initializes the CLIP model.
        annotate_batch(images, prompts, conf_threshold, use_tta, synonym_dict): Annotates the given images with classification labels.
        release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
    """
@@ -29,19 +32,41 @@ def __init__(
        self,
        seed: float = 42,
        device: str = "cuda",
+       size: str = "base",
    ) -> None:
        """Initializes the CLIPAnnotator with a specific seed and device.

        Args:
            seed (float): Seed for reproducibility. Defaults to 42.
            device (str): The device to run the model on. Defaults to 'cuda'.
        """
-       super().__init__(seed)
-       self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-       self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+       super().__init__(seed, task_definition=TaskList.CLASSIFICATION)
+       self.size = size
+       self.model = self._init_model()
+       self.processor = self._init_processor()
        self.device = device
        self.model.to(self.device)

+   def _init_processor(self):
+       """Initializes the CLIP processor.
+
+       Returns:
+           CLIPProcessor: The initialized CLIP processor.
+       """
+       if self.size == "large":
+           return CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+       return CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+   def _init_model(self):
+       """Initializes the CLIP model.
+
+       Returns:
+           CLIPModel: The initialized CLIP model.
+       """
+       if self.size == "large":
+           return CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+       return CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

    def annotate_batch(
        self,
        images: List[PIL.Image.Image],
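
Taken together, `size` now selects between the `openai/clip-vit-base-patch32` and `openai/clip-vit-large-patch14` checkpoints. A minimal usage sketch, assuming the remaining `annotate_batch` parameters keep the defaults implied by the class docstring (`example.jpg` is a placeholder path):

```python
from PIL import Image

from datadreamer.dataset_annotation.clip_annotator import CLIPAnnotator

# Classify one local image against three candidate labels.
annotator = CLIPAnnotator(device="cpu", size="large")
image = Image.open("example.jpg")
labels = annotator.annotate_batch([image], ["person", "moon", "robot"])
annotator.release()
```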
33 changes: 31 additions & 2 deletions datadreamer/dataset_annotation/owlv2_annotator.py
@@ -20,6 +20,7 @@ class OWLv2Annotator(BaseAnnotator):
        model (Owlv2ForObjectDetection): The OWLv2 model for object detection.
        processor (Owlv2Processor): The processor for the OWLv2 model.
        device (str): The device on which the model will run ('cuda' for GPU, 'cpu' for CPU).
+       size (str): The size of the OWLv2 model to use ('base' or 'large').

    Methods:
        _init_model(): Initializes the OWLv2 model.
@@ -32,6 +33,7 @@ def __init__(
        self,
        seed: float = 42,
        device: str = "cuda",
+       size: str = "base",
    ) -> None:
        """Initializes the OWLv2Annotator with a specific seed and device.
@@ -40,6 +42,7 @@
            device (str): The device to run the model on. Defaults to 'cuda'.
        """
        super().__init__(seed)
+       self.size = size
        self.model = self._init_model()
        self.processor = self._init_processor()
        self.device = device
@@ -51,6 +54,10 @@ def _init_model(self):
        Returns:
            Owlv2ForObjectDetection: The initialized OWLv2 model.
        """
+       if self.size == "large":
+           return Owlv2ForObjectDetection.from_pretrained(
+               "google/owlv2-large-patch14-ensemble"
+           )
        return Owlv2ForObjectDetection.from_pretrained(
            "google/owlv2-base-patch16-ensemble"
        )
@@ -61,6 +68,10 @@ def _init_processor(self):
        Returns:
            Owlv2Processor: The initialized processor.
        """
+       if self.size == "large":
+           return Owlv2Processor.from_pretrained(
+               "google/owlv2-large-patch14-ensemble", do_pad=False, do_resize=False
+           )
        return Owlv2Processor.from_pretrained(
            "google/owlv2-base-patch16-ensemble", do_pad=False, do_resize=False
        )
@@ -86,7 +97,8 @@ def _generate_annotations(
        target_sizes = torch.Tensor(images[0].size[::-1]).repeat((n, 1)).to(self.device)

        # resize the images to the model's input size
-       images = [images[i].resize((960, 960)) for i in range(n)]
+       img_size = (1008, 1008) if self.size == "large" else (960, 960)
+       images = [images[i].resize(img_size) for i in range(n)]
        inputs = self.processor(
            text=batched_prompts,
            images=images,
@@ -145,6 +157,7 @@ def annotate_batch(
        images: List[PIL.Image.Image],
        prompts: List[str],
        conf_threshold: float = 0.1,
+       iou_threshold: float = 0.2,
        use_tta: bool = False,
        synonym_dict: dict[str, List[str]] | None = None,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
@@ -154,6 +167,7 @@
            images: The images to be annotated.
            prompts: Prompts to guide the annotation.
            conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.
+           iou_threshold (float, optional): Intersection over Union (IoU) threshold for non-maximum suppression. Defaults to 0.2.
            use_tta (bool, optional): Flag to apply test-time augmentation. Defaults to False.
            synonym_dict (dict, optional): Dictionary for handling synonyms in labels. Defaults to None.
@@ -233,7 +247,9 @@ def annotate_batch(

        # output is a list of detections; each item is a tensor of shape (num_boxes, 6), where 6 is [xyxy, conf, cls]
        output = non_max_suppression(
-           all_boxes_cat.unsqueeze(0), conf_thres=conf_threshold, iou_thres=0.2
+           all_boxes_cat.unsqueeze(0),
+           conf_thres=conf_threshold,
+           iou_thres=iou_threshold,
        )

        output_boxes = output[0][:, :4]
@@ -268,3 +284,16 @@ def release(self, empty_cuda_cache: bool = False) -> None:
        if empty_cuda_cache:
            with torch.no_grad():
                torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+   import requests
+   from PIL import Image
+
+   url = "https://ultralytics.com/images/bus.jpg"
+   im = Image.open(requests.get(url, stream=True).raw)
+   annotator = OWLv2Annotator(device="cpu", size="large")
+   final_boxes, final_scores, final_labels = annotator.annotate_batch(
+       [im], ["robot", "horse"]
+   )
+   annotator.release()
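
The resize fix lines up each checkpoint's input side with its ViT patch size (an assumption based on the checkpoint names): the processor is built with `do_resize=False`, so the annotator resizes images itself, and 960 divides evenly by 16 for the base patch16 model while 1008 divides evenly by 14 for the large patch14 model. A small check of that arithmetic:

```python
# Assumed rationale: input sides should be divisible by the ViT patch size
# implied by the checkpoint name (patch16 for base, patch14 for large).
sizes = {"base": (960, 16), "large": (1008, 14)}  # (side, patch)
for name, (side, patch) in sizes.items():
    assert side % patch == 0
    print(f"{name}: {side // patch} patches per side")
```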
48 changes: 46 additions & 2 deletions datadreamer/pipelines/generate_dataset_from_scratch.py
@@ -120,13 +120,41 @@ def parse_args():
help="Image annotator to use",
)

+   parser.add_argument(
+       "--negative_prompt",
+       type=str,
+       default="cartoon, blue skin, painting, scrispture, golden, illustration, worst quality, low quality, normal quality:2, unrealistic dream, low resolution, static, sd character, low quality, low resolution, greyscale, monochrome, nose, cropped, lowres, jpeg artifacts, deformed iris, deformed pupils, bad eyes, semi-realistic worst quality, bad lips, deformed mouth, deformed face, deformed fingers, bad anatomy",
+       help="Negative prompt to guide the generation away from certain features",
+   )
+
+   parser.add_argument(
+       "--prompt_suffix",
+       type=str,
+       default=", hd, 8k, highly detailed",
+       help="Suffix to add to every image generation prompt, e.g., for adding details like resolution",
+   )
+
+   parser.add_argument(
+       "--prompt_prefix",
+       type=str,
+       default="",
+       help="Prefix to add to every image generation prompt",
+   )

    parser.add_argument(
        "--conf_threshold",
        type=float,
        default=0.15,
        help="Confidence threshold for annotation",
    )

+   parser.add_argument(
+       "--annotation_iou_threshold",
+       type=float,
+       default=0.2,
+       help="Intersection over Union (IoU) threshold for annotation",
+   )

    parser.add_argument(
        "--use_tta",
        default=False,

@@ -156,6 +184,14 @@ def parse_args():
help="Quantization to use for Mistral language model",
)

+   parser.add_argument(
+       "--annotator_size",
+       type=str,
+       default="base",
+       choices=["base", "large"],
+       help="Size of the annotator model to use",
+   )

    parser.add_argument(
        "--batch_size_prompt",
        type=int,
@@ -233,6 +269,10 @@ def check_args(args):
    if not 0 <= args.conf_threshold <= 1:
        raise ValueError("--conf_threshold must be between 0 and 1")

+   # Check annotation_iou_threshold
+   if not 0 <= args.annotation_iou_threshold <= 1:
+       raise ValueError("--annotation_iou_threshold must be between 0 and 1")

    # Check image_tester_patience
    if args.image_tester_patience < 0:
        raise ValueError("--image_tester_patience must be a non-negative integer")
@@ -359,6 +399,9 @@ def main():
    # Image generation
    image_generator_class = image_generators[args.image_generator]
    image_generator = image_generator_class(
+       prompt_prefix=args.prompt_prefix,
+       prompt_suffix=args.prompt_suffix,
+       negative_prompt=args.negative_prompt,
        seed=args.seed,
        use_clip_image_tester=args.use_image_tester,
        image_tester_patience=args.image_tester_patience,
@@ -402,7 +445,7 @@ def main():
    if args.task == "classification":
        # Classification annotation
        annotator_class = clf_annotators[args.image_annotator]
-       annotator = annotator_class(device=args.device)
+       annotator = annotator_class(device=args.device, size=args.annotator_size)

        labels_list = []
        # Split image_paths into batches
@@ -431,7 +474,7 @@ def main():
    else:
        # Annotation
        annotator_class = det_annotators[args.image_annotator]
-       annotator = annotator_class(device=args.device)
+       annotator = annotator_class(device=args.device, size=args.annotator_size)

        boxes_list = []
        scores_list = []
@@ -453,6 +496,7 @@
            images,
            args.class_names,
            conf_threshold=args.conf_threshold,
+           iou_threshold=args.annotation_iou_threshold,
            use_tta=args.use_tta,
            synonym_dict=synonym_dict,
        )
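
End to end, the new flags travel from `parse_args()` through `check_args()` into the generator and annotator constructors shown above. A hypothetical smoke test of just the argument-handling layer (this assumes `parse_args` reads `sys.argv` and that `check_args` performs only the range checks visible in this diff):

```python
import sys

from datadreamer.pipelines.generate_dataset_from_scratch import check_args, parse_args

# Simulate a command line exercising the new arguments.
sys.argv = [
    "datadreamer",
    "--save_dir", "generated_dataset",
    "--class_names", "person", "moon", "robot",
    "--annotator_size", "large",
    "--annotation_iou_threshold", "0.25",
]
args = parse_args()
check_args(args)  # raises ValueError if e.g. a threshold falls outside [0, 1]
print(args.annotator_size, args.annotation_iou_threshold)
```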
2 changes: 1 addition & 1 deletion datadreamer/prompt_generation/lm_synonym_generator.py
@@ -34,7 +34,7 @@ class LMSynonymGenerator(SynonymGenerator):

    def __init__(
        self,
-       synonyms_number: int = 5,
+       synonyms_number: int = 3,
        seed: Optional[float] = 42,
        device: str = "cuda",
    ) -> None:
2 changes: 1 addition & 1 deletion datadreamer/prompt_generation/synonym_generator.py
@@ -29,7 +29,7 @@ class SynonymGenerator(ABC):

    def __init__(
        self,
-       synonyms_number: int = 5,
+       synonyms_number: int = 3,
        seed: Optional[float] = 42,
        device: str = "cuda",
    ) -> None:
2 changes: 1 addition & 1 deletion datadreamer/prompt_generation/wordnet_synonym_generator.py
@@ -25,7 +25,7 @@ class WordNetSynonymGenerator(SynonymGenerator):

    def __init__(
        self,
-       synonyms_number: int = 5,
+       synonyms_number: int = 3,
        seed: Optional[float] = 42,
        device: str = "cuda",
    ) -> None:
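
The only change across all three synonym generators is the default `synonyms_number` dropping from 5 to 3, matching the commit message above. A short sketch of the new default in effect (only the constructors are shown in this diff, so the calls below stick to them):

```python
from datadreamer.prompt_generation.wordnet_synonym_generator import WordNetSynonymGenerator

generator = WordNetSynonymGenerator()                       # now 3 synonyms per class
generator_old = WordNetSynonymGenerator(synonyms_number=5)  # restore the old behavior
```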