diff --git a/README.md b/README.md index c59f83d..0f5c9c9 100644 --- a/README.md +++ b/README.md @@ -101,30 +101,35 @@ datadreamer --save_dir --class_names --prompts_number ### Additional Parameters - `--task`: Choose between `detection` and `classification`. Default is `detection`. - `--num_objects_range`: Range of objects in a prompt. Default is 1 to 3. - `--prompt_generator`: Choose between `simple`, `lm` (language model) and `tiny` (tiny LM). Default is `simple`. - `--image_generator`: Choose image generator, e.g., `sdxl`, `sdxl-turbo` or `sdxl-lightning`. Default is `sdxl-turbo`. - `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `clip` for image classification. Default is `owlv2`. -- `--conf_threshold`: Confidence threshold for annotation. Default is 0.15. -- `--use_tta`: Toggle test time augmentation for object detection. Default is True. +- `--conf_threshold`: Confidence threshold for annotation. Default is `0.15`. +- `--annotation_iou_threshold`: Intersection over Union (IoU) threshold for annotation. Default is `0.2`. +- `--prompt_prefix`: Prefix to add to every image generation prompt. Default is `""`. +- `--prompt_suffix`: Suffix to add to every image generation prompt, e.g., for adding details like resolution. Default is `", hd, 8k, highly detailed"`. +- `--negative_prompt`: Negative prompt to guide the generation away from certain features. Default is `"cartoon, blue skin, painting, scrispture, golden, illustration, worst quality, low quality, normal quality:2, unrealistic dream, low resolution, static, sd character, low quality, low resolution, greyscale, monochrome, nose, cropped, lowres, jpeg artifacts, deformed iris, deformed pupils, bad eyes, semi-realistic worst quality, bad lips, deformed mouth, deformed face, deformed fingers, bad anatomy"`. +- `--use_tta`: Toggle test time augmentation for object detection. Default is `False`. - `--synonym_generator`: Enhance class names with synonyms. Default is `none`. Other options are `llm`, `wordnet`. -- `--use_image_tester`: Use image tester for image generation. Default is False. -- `--image_tester_patience`: Patience level for image tester. Default is 1. +- `--use_image_tester`: Use image tester for image generation. Default is `False`. +- `--image_tester_patience`: Patience level for image tester. Default is `1`. - `--lm_quantization`: Quantization to use for Mistral language model. Choose between `none` and `4bit`. Default is `none`. +- `--annotator_size`: Size of the annotator model to use. Choose between `base` and `large`. Default is `base`. - `--batch_size_prompt`: Batch size for prompt generation. Default is 64. -- `--batch_size_annotation`: Batch size for annotation. Default is 8. -- `--batch_size_image`: Batch size for image generation. Default is 1. -- `--device`: Choose between `cuda` and `cpu`. Default is cuda. -- `--seed`: Set a random seed for image and prompt generation. Default is 42. +- `--batch_size_annotation`: Batch size for annotation. Default is `8`. +- `--batch_size_image`: Batch size for image generation. Default is `1`. +- `--device`: Choose between `cuda` and `cpu`. Default is `cuda`. +- `--seed`: Set a random seed for image and prompt generation. Default is `42`.
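For reference, the new flags added in this PR compose with the existing CLI. A sketch with illustrative values (the save directory and class names are placeholders, not recommended settings):

```bash
datadreamer --save_dir generated_dataset --class_names person moon robot \
    --task detection --image_annotator owlv2 --annotator_size large \
    --conf_threshold 0.15 --annotation_iou_threshold 0.2 \
    --prompt_prefix "a photo of " --prompt_suffix ", hd, highly detailed"
```

Here `--annotator_size large` selects the larger OWLv2/CLIP checkpoints introduced below, and `--annotation_iou_threshold` exposes the NMS IoU cutoff that was previously hard-coded to 0.2.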
diff --git a/datadreamer/dataset_annotation/clip_annotator.py b/datadreamer/dataset_annotation/clip_annotator.py index 38a778c..ff7b9aa 100644 --- a/datadreamer/dataset_annotation/clip_annotator.py +++ b/datadreamer/dataset_annotation/clip_annotator.py @@ -8,7 +8,7 @@ from PIL import Image from transformers import CLIPModel, CLIPProcessor -from datadreamer.dataset_annotation.image_annotator import BaseAnnotator +from datadreamer.dataset_annotation.image_annotator import BaseAnnotator, TaskList class CLIPAnnotator(BaseAnnotator): @@ -19,8 +19,11 @@ class CLIPAnnotator(BaseAnnotator): model (CLIPModel): The CLIP model for image-text similarity evaluation. processor (CLIPProcessor): The processor for preparing inputs to the CLIP model. device (str): The device on which the model will run ('cuda' for GPU, 'cpu' for CPU). + size (str): The size of the CLIP model to use ('base' or 'large'). Methods: + _init_processor(): Initializes the CLIP processor. + _init_model(): Initializes the CLIP model. - annotate_batch(image, prompts, conf_threshold, use_tta, synonym_dict): Annotates the given image with bounding boxes and labels. + annotate_batch(image, prompts, conf_threshold, use_tta, synonym_dict): Annotates the given images with classification labels. release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache. """ @@ -29,6 +32,7 @@ def __init__( self, seed: float = 42, device: str = "cuda", + size: str = "base", ) -> None: """Initializes the CLIPAnnotator with a specific seed and device. @@ -36,12 +40,33 @@ seed (float): Seed for reproducibility. Defaults to 42. device (str): The device to run the model on. Defaults to 'cuda'. """ - super().__init__(seed) - self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") - self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") + super().__init__(seed, task_definition=TaskList.CLASSIFICATION) + self.size = size + self.model = self._init_model() + self.processor = self._init_processor() self.device = device self.model.to(self.device) + def _init_processor(self): + """Initializes the CLIP processor. + + Returns: + CLIPProcessor: The initialized CLIP processor. + """ + if self.size == "large": + return CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") + return CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") + + def _init_model(self): + """Initializes the CLIP model. + + Returns: + CLIPModel: The initialized CLIP model. + """ + if self.size == "large": + return CLIPModel.from_pretrained("openai/clip-vit-large-patch14") + return CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + def annotate_batch( self, images: List[PIL.Image.Image], diff --git a/datadreamer/dataset_annotation/owlv2_annotator.py b/datadreamer/dataset_annotation/owlv2_annotator.py index 2a84875..1d4243c 100644 --- a/datadreamer/dataset_annotation/owlv2_annotator.py +++ b/datadreamer/dataset_annotation/owlv2_annotator.py @@ -20,6 +20,7 @@ class OWLv2Annotator(BaseAnnotator): model (Owlv2ForObjectDetection): The OWLv2 model for object detection. processor (Owlv2Processor): The processor for the OWLv2 model. device (str): The device on which the model will run ('cuda' for GPU, 'cpu' for CPU). + size (str): The size of the OWLv2 model to use ('base' or 'large'). Methods: _init_model(): Initializes the OWLv2 model. @@ -32,6 +33,7 @@ def __init__( self, seed: float = 42, device: str = "cuda", + size: str = "base", ) -> None: """Initializes the OWLv2Annotator with a specific seed and device. @@ -40,6 +42,7 @@ seed (float): Seed for reproducibility. Defaults to 42. device (str): The device to run the model on. Defaults to 'cuda'.
""" super().__init__(seed) + self.size = size self.model = self._init_model() self.processor = self._init_processor() self.device = device @@ -51,6 +54,10 @@ def _init_model(self): Returns: Owlv2ForObjectDetection: The initialized OWLv2 model. """ + if self.size == "large": + return Owlv2ForObjectDetection.from_pretrained( + "google/owlv2-large-patch14-ensemble" + ) return Owlv2ForObjectDetection.from_pretrained( "google/owlv2-base-patch16-ensemble" ) @@ -61,6 +68,10 @@ def _init_processor(self): Returns: Owlv2Processor: The initialized processor. """ + if self.size == "large": + return Owlv2Processor.from_pretrained( + "google/owlv2-large-patch14-ensemble", do_pad=False, do_resize=False + ) return Owlv2Processor.from_pretrained( "google/owlv2-base-patch16-ensemble", do_pad=False, do_resize=False ) @@ -86,7 +97,8 @@ def _generate_annotations( target_sizes = torch.Tensor(images[0].size[::-1]).repeat((n, 1)).to(self.device) # resize the images to the model's input size - images = [images[i].resize((960, 960)) for i in range(n)] + img_size = (1008, 1008) if self.size == "large" else (960, 960) + images = [images[i].resize(img_size) for i in range(n)] inputs = self.processor( text=batched_prompts, images=images, @@ -145,6 +157,7 @@ def annotate_batch( images: List[PIL.Image.Image], prompts: List[str], conf_threshold: float = 0.1, + iou_threshold: float = 0.2, use_tta: bool = False, synonym_dict: dict[str, List[str]] | None = None, ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: @@ -154,6 +167,7 @@ def annotate_batch( images: The images to be annotated. prompts: Prompts to guide the annotation. conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1. + iou_threshold (float, optional): Intersection over union threshold for non-maximum suppression. Defaults to 0.2. use_tta (bool, optional): Flag to apply test-time augmentation. Defaults to False. synonym_dict (dict, optional): Dictionary for handling synonyms in labels. Defaults to None. @@ -233,7 +247,9 @@ def annotate_batch( # output is a list of detections, each item is one tensor with shape (num_boxes, 6), 6 is for [xyxy, conf, cls]. 
output = non_max_suppression( - all_boxes_cat.unsqueeze(0), conf_thres=conf_threshold, iou_thres=0.2 + all_boxes_cat.unsqueeze(0), + conf_thres=conf_threshold, + iou_thres=iou_threshold, ) output_boxes = output[0][:, :4] @@ -268,3 +284,16 @@ def release(self, empty_cuda_cache: bool = False) -> None: if empty_cuda_cache: with torch.no_grad(): torch.cuda.empty_cache() + + +if __name__ == "__main__": + import requests + from PIL import Image + + url = "https://ultralytics.com/images/bus.jpg" + im = Image.open(requests.get(url, stream=True).raw) + annotator = OWLv2Annotator(device="cpu", size="large") + final_boxes, final_scores, final_labels = annotator.annotate_batch( + [im], ["robot", "horse"] + ) + annotator.release() diff --git a/datadreamer/pipelines/generate_dataset_from_scratch.py b/datadreamer/pipelines/generate_dataset_from_scratch.py index 7fcb87c..f1d3ee2 100644 --- a/datadreamer/pipelines/generate_dataset_from_scratch.py +++ b/datadreamer/pipelines/generate_dataset_from_scratch.py @@ -120,6 +120,27 @@ def parse_args(): help="Image annotator to use", ) + parser.add_argument( + "--negative_prompt", + type=str, + default="cartoon, blue skin, painting, scrispture, golden, illustration, worst quality, low quality, normal quality:2, unrealistic dream, low resolution, static, sd character, low quality, low resolution, greyscale, monochrome, nose, cropped, lowres, jpeg artifacts, deformed iris, deformed pupils, bad eyes, semi-realistic worst quality, bad lips, deformed mouth, deformed face, deformed fingers, bad anatomy", + help="Negative prompt to guide the generation away from certain features", + ) + + parser.add_argument( + "--prompt_suffix", + type=str, + default=", hd, 8k, highly detailed", + help="Suffix to add to every image generation prompt, e.g., for adding details like resolution", + ) + + parser.add_argument( + "--prompt_prefix", + type=str, + default="", + help="Prefix to add to every image generation prompt", + ) + parser.add_argument( "--conf_threshold", type=float, @@ -127,6 +148,13 @@ def parse_args(): help="Confidence threshold for annotation", ) + parser.add_argument( + "--annotation_iou_threshold", + type=float, + default=0.2, + help="Intersection over Union (IoU) threshold for annotation", + ) + parser.add_argument( "--use_tta", default=False, @@ -156,6 +184,14 @@ def parse_args(): help="Quantization to use for Mistral language model", ) + parser.add_argument( + "--annotator_size", + type=str, + default="base", + choices=["base", "large"], + help="Size of the annotator model to use", + ) + parser.add_argument( "--batch_size_prompt", type=int, @@ -233,6 +269,10 @@ def check_args(args): if not 0 <= args.conf_threshold <= 1: raise ValueError("--conf_threshold must be between 0 and 1") + # Check annotation_iou_threshold + if not 0 <= args.annotation_iou_threshold <= 1: + raise ValueError("--annotation_iou_threshold must be between 0 and 1") + # Check image_tester_patience if args.image_tester_patience < 0: raise ValueError("--image_tester_patience must be a non-negative integer") @@ -359,6 +399,9 @@ def main(): # Image generation image_generator_class = image_generators[args.image_generator] image_generator = image_generator_class( + prompt_prefix=args.prompt_prefix, + prompt_suffix=args.prompt_suffix, + negative_prompt=args.negative_prompt, seed=args.seed, use_clip_image_tester=args.use_image_tester, image_tester_patience=args.image_tester_patience, @@ -402,7 +445,7 @@ def main(): if args.task == "classification": # Classification annotation annotator_class = 
clf_annotators[args.image_annotator] - annotator = annotator_class(device=args.device) + annotator = annotator_class(device=args.device, size=args.annotator_size) labels_list = [] # Split image_paths into batches @@ -431,7 +474,7 @@ else: # Annotation annotator_class = det_annotators[args.image_annotator] - annotator = annotator_class(device=args.device) + annotator = annotator_class(device=args.device, size=args.annotator_size) boxes_list = [] scores_list = [] @@ -453,6 +496,7 @@ images, args.class_names, conf_threshold=args.conf_threshold, + iou_threshold=args.annotation_iou_threshold, use_tta=args.use_tta, synonym_dict=synonym_dict, ) diff --git a/datadreamer/prompt_generation/lm_synonym_generator.py b/datadreamer/prompt_generation/lm_synonym_generator.py index e1c4401..fc86db8 100644 --- a/datadreamer/prompt_generation/lm_synonym_generator.py +++ b/datadreamer/prompt_generation/lm_synonym_generator.py @@ -34,7 +34,7 @@ class LMSynonymGenerator(SynonymGenerator): def __init__( self, - synonyms_number: int = 5, + synonyms_number: int = 3, seed: Optional[float] = 42, device: str = "cuda", ) -> None: diff --git a/datadreamer/prompt_generation/synonym_generator.py b/datadreamer/prompt_generation/synonym_generator.py index b6a95e1..ec3f306 100644 --- a/datadreamer/prompt_generation/synonym_generator.py +++ b/datadreamer/prompt_generation/synonym_generator.py @@ -29,7 +29,7 @@ class SynonymGenerator(ABC): def __init__( self, - synonyms_number: int = 5, + synonyms_number: int = 3, seed: Optional[float] = 42, device: str = "cuda", ) -> None: diff --git a/datadreamer/prompt_generation/wordnet_synonym_generator.py b/datadreamer/prompt_generation/wordnet_synonym_generator.py index 4f8f6f8..994e74d 100644 --- a/datadreamer/prompt_generation/wordnet_synonym_generator.py +++ b/datadreamer/prompt_generation/wordnet_synonym_generator.py @@ -25,7 +25,7 @@ class WordNetSynonymGenerator(SynonymGenerator): def __init__( self, - synonyms_number: int = 5, + synonyms_number: int = 3, seed: Optional[float] = 42, device: str = "cuda", ) -> None: diff --git a/examples/generate_dataset_and_train_yolo.ipynb b/examples/generate_dataset_and_train_yolo.ipynb index 5749dfb..b113b7e 100644 --- a/examples/generate_dataset_and_train_yolo.ipynb +++ b/examples/generate_dataset_and_train_yolo.ipynb @@ -74,25 +74,30 @@ "source": [ "### Parameters\n", "- `--save_dir` (required): Path to the directory for saving generated images and annotations.\n", - "- `--class_names` (required): Space-separated list of object names for image generation and annotation. Example: person moon robot.\n", - "- `--prompts_number` (optional): Number of prompts to generate for each object. Defaults to 10.\n", - "- `--annotate_only` (optional): Only annotate the images without generating new ones, prompt and image generator will be skipped. Defaults to False.\n", + "- `--class_names` (required): Space-separated list of object names for image generation and annotation. Example: `person moon robot`.\n", + "- `--prompts_number` (optional): Number of prompts to generate for each object. Defaults to `10`.\n", + "- `--annotate_only` (optional): Only annotate the images without generating new ones; the prompt and image generators will be skipped. Defaults to `False`.\n", "- `--task`: Choose between `detection` and `classification`. Default is `detection`.\n", + "- `--num_objects_range`: Range of objects in a prompt.
Default is 1 to 3.\n", "- `--prompt_generator`: Choose between `simple`, `lm` (language model) and `tiny` (tiny LM). Default is `simple`.\n", "- `--image_generator`: Choose image generator, e.g., `sdxl`, `sdxl-turbo` or `sdxl-lightning`. Default is `sdxl-turbo`.\n", "- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `clip` for image classification. Default is `owlv2`.\n", - "- `--conf_threshold`: Confidence threshold for annotation. Default is 0.15.\n", - "- `--use_tta`: Toggle test time augmentation for object detection. Default is True.\n", + "- `--conf_threshold`: Confidence threshold for annotation. Default is `0.15`.\n", + "- `--annotation_iou_threshold`: Intersection over Union (IoU) threshold for annotation. Default is `0.2`.\n", + "- `--prompt_prefix`: Prefix to add to every image generation prompt. Default is `\"\"`.\n", + "- `--prompt_suffix`: Suffix to add to every image generation prompt, e.g., for adding details like resolution. Default is `\", hd, 8k, highly detailed\"`.\n", + "- `--negative_prompt`: Negative prompt to guide the generation away from certain features. Default is `\"cartoon, blue skin, painting, scrispture, golden, illustration, worst quality, low quality, normal quality:2, unrealistic dream, low resolution, static, sd character, low quality, low resolution, greyscale, monochrome, nose, cropped, lowres, jpeg artifacts, deformed iris, deformed pupils, bad eyes, semi-realistic worst quality, bad lips, deformed mouth, deformed face, deformed fingers, bad anatomy\"`.\n", + "- `--use_tta`: Toggle test time augmentation for object detection. Default is `False`.\n", "- `--synonym_generator`: Enhance class names with synonyms. Default is `none`. Other options are `llm`, `wordnet`.\n", - "- `--use_image_tester`: Use image tester for image generation. Default is False.\n", - "- `--image_tester_patience`: Patience level for image tester. Default is 1.\n", + "- `--use_image_tester`: Use image tester for image generation. Default is `False`.\n", + "- `--image_tester_patience`: Patience level for image tester. Default is `1`.\n", "- `--lm_quantization`: Quantization to use for Mistral language model. Choose between `none` and `4bit`. Default is `none`.\n", + "- `--annotator_size`: Size of the annotator model to use. Choose between `base` and `large`. Default is `base`.\n", "- `--batch_size_prompt`: Batch size for prompt generation. Default is 64.\n", - "- `--batch_size_annotation`: Batch size for annotation. Default is 8.\n", - "- `--batch_size_image`: Batch size for image generation. Default is 1.\n", - "- `--device`: Choose between `cuda` and `cpu`. Default is cuda.\n", - "- `--seed`: Set a random seed for image and prompt generation. Default is 42.\n" + "- `--batch_size_annotation`: Batch size for annotation. Default is `8`.\n", + "- `--batch_size_image`: Batch size for image generation. Default is `1`.\n", + "- `--device`: Choose between `cuda` and `cpu`. Default is `cuda`.\n", + "- `--seed`: Set a random seed for image and prompt generation.
Default is `42`.\n" ] }, { diff --git a/media/coverage_badge.svg b/media/coverage_badge.svg index 9d027c7..4f8c185 100644 --- a/media/coverage_badge.svg +++ b/media/coverage_badge.svg @@ -15,7 +15,7 @@ coverage coverage - 49% - 49% + 50% + 50% diff --git a/tests/integration/test_pipeline.py b/tests/integration/test_pipeline.py index 5fbe926..c665e29 100644 --- a/tests/integration/test_pipeline.py +++ b/tests/integration/test_pipeline.py @@ -122,6 +122,12 @@ def test_invalid_device(): _check_wrong_argument_choice(cmd) +def test_invalid_annotator_size(): + # Define the cmd + cmd = "datadreamer --annotator_size invalide_value" + _check_wrong_argument_choice(cmd) + + def test_empty_class_names(): # Define the cmd cmd = "datadreamer --class_names []" @@ -152,6 +158,18 @@ def test_big_conf_threshold(): _check_wrong_value(cmd) +def test_negative_annotation_iou_threshold(): + # Define the cmd + cmd = "datadreamer --annotation_iou_threshold -1" + _check_wrong_value(cmd) + + +def test_big_annotation_iou_threshold(): + # Define the cmd + cmd = "datadreamer --annotation_iou_threshold 10" + _check_wrong_value(cmd) + + def test_invalid_image_tester_patience(): # Define the cmd cmd = "datadreamer --image_tester_patience -1" @@ -651,6 +669,7 @@ def test_cpu_simple_sdxl_turbo_classification_pipeline(): f"--prompts_number 1 " f"--prompt_generator simple " f"--num_objects_range 1 2 " + f"--image_annotator clip " f"--image_generator sdxl-turbo " f"--use_image_tester " f"--device cpu" @@ -674,6 +693,7 @@ def test_cuda_simple_sdxl_turbo_classification_pipeline(): f"--prompts_number 1 " f"--prompt_generator simple " f"--num_objects_range 1 2 " + f"--image_annotator clip " f"--image_generator sdxl-turbo " f"--use_image_tester " f"--device cuda" @@ -698,6 +718,7 @@ def test_cuda_simple_llm_synonym_sdxl_turbo_classification_pipeline(): f"--prompt_generator simple " f"--num_objects_range 1 2 " f"--image_generator sdxl-turbo " + f"--image_annotator clip " f"--use_image_tester " f"--synonym_generator llm " f"--device cuda" @@ -721,6 +742,7 @@ def test_cuda_simple_wordnet_synonym_sdxl_turbo_classification_pipeline(): f"--prompts_number 1 " f"--prompt_generator simple " f"--num_objects_range 1 2 " + f"--image_annotator clip " f"--image_generator sdxl-turbo " f"--use_image_tester " f"--synonym_generator wordnet " @@ -744,6 +766,7 @@ def test_cpu_simple_sdxl_classification_pipeline(): f"--class_names alien mars cat " f"--prompts_number 1 " f"--prompt_generator simple " + f"--image_annotator clip " f"--num_objects_range 1 2 " f"--image_generator sdxl " f"--use_image_tester " @@ -767,6 +790,7 @@ def test_cuda_simple_sdxl_classification_pipeline(): f"--class_names alien mars cat " f"--prompts_number 1 " f"--prompt_generator simple " + f"--image_annotator clip " f"--num_objects_range 1 2 " f"--image_generator sdxl " f"--use_image_tester " @@ -794,6 +818,7 @@ def test_cpu_lm_sdxl_turbo_classification_pipeline(): f"--prompts_number 1 " f"--prompt_generator lm " f"--num_objects_range 1 2 " + f"--image_annotator clip " f"--image_generator sdxl-turbo " f"--use_image_tester " f"--device cpu" @@ -817,6 +842,7 @@ def test_cuda_lm_sdxl_turbo_classification_pipeline(): f"--prompts_number 1 " f"--prompt_generator lm " f"--num_objects_range 1 2 " + f"--image_annotator clip " f"--image_generator sdxl-turbo " f"--use_image_tester " f"--device cuda" @@ -840,6 +866,7 @@ def test_cuda_4bit_lm_sdxl_turbo_classification_pipeline(): f"--prompts_number 1 " f"--prompt_generator lm " f"--num_objects_range 1 2 " + f"--image_annotator clip " 
f"--image_generator sdxl-turbo " f"--use_image_tester " f"--lm_quantization 4bit " @@ -863,6 +890,7 @@ def test_cpu_lm_sdxl_classification_pipeline(): f"--class_names alien mars cat " f"--prompts_number 1 " f"--prompt_generator lm " + f"--image_annotator clip " f"--num_objects_range 1 2 " f"--image_generator sdxl " f"--use_image_tester " @@ -886,6 +914,7 @@ def test_cuda_lm_sdxl_classification_pipeline(): f"--class_names alien mars cat " f"--prompts_number 1 " f"--prompt_generator lm " + f"--image_annotator clip " f"--num_objects_range 1 2 " f"--image_generator sdxl " f"--use_image_tester " @@ -910,6 +939,7 @@ def test_cuda_4bit_lm_sdxl_classification_pipeline(): f"--prompts_number 1 " f"--prompt_generator lm " f"--num_objects_range 1 2 " + f"--image_annotator clip " f"--image_generator sdxl " f"--use_image_tester " f"--lm_quantization 4bit " @@ -936,6 +966,7 @@ def test_cpu_tiny_sdxl_turbo_classification_pipeline(): f"--class_names alien mars cat " f"--prompts_number 1 " f"--prompt_generator tiny " + f"--image_annotator clip " f"--num_objects_range 1 2 " f"--image_generator sdxl-turbo " f"--use_image_tester " @@ -960,6 +991,7 @@ def test_cuda_tiny_sdxl_turbo_classification_pipeline(): f"--prompts_number 1 " f"--prompt_generator tiny " f"--num_objects_range 1 2 " + f"--image_annotator clip " f"--image_generator sdxl-turbo " f"--use_image_tester " f"--device cuda" @@ -983,6 +1015,7 @@ def test_cpu_tiny_sdxl_classification_pipeline(): f"--prompts_number 1 " f"--prompt_generator tiny " f"--num_objects_range 1 2 " + f"--image_annotator clip " f"--image_generator sdxl " f"--use_image_tester " f"--device cpu" @@ -1006,6 +1039,7 @@ def test_cuda_tiny_sdxl_classification_pipeline(): f"--prompts_number 1 " f"--prompt_generator tiny " f"--num_objects_range 1 2 " + f"--image_annotator clip " f"--image_generator sdxl " f"--use_image_tester " f"--device cuda" diff --git a/tests/unittests/test_annotators.py b/tests/unittests/test_annotators.py index 9f66867..698ed3d 100644 --- a/tests/unittests/test_annotators.py +++ b/tests/unittests/test_annotators.py @@ -14,10 +14,10 @@ total_disk_space = psutil.disk_usage("/").total / (1024**3) -def _check_owlv2_annotator(device: str): +def _check_owlv2_annotator(device: str, size: str = "base"): url = "https://ultralytics.com/images/bus.jpg" im = Image.open(requests.get(url, stream=True).raw) - annotator = OWLv2Annotator(device=device) + annotator = OWLv2Annotator(device=device, size=size) final_boxes, final_scores, final_labels = annotator.annotate_batch( [im], ["bus", "people"] ) @@ -51,14 +51,14 @@ def test_cuda_owlv2_annotator(): total_disk_space < 15, reason="Test requires at least 15GB of HDD", ) -def test_cou_owlv2_annotator(): +def test_cpu_owlv2_annotator(): _check_owlv2_annotator("cpu") -def _check_clip_annotator(device: str): +def _check_clip_annotator(device: str, size: str = "base"): url = "https://ultralytics.com/images/bus.jpg" im = Image.open(requests.get(url, stream=True).raw) - annotator = CLIPAnnotator(device=device) + annotator = CLIPAnnotator(device=device, size=size) labels = annotator.annotate_batch([im], ["bus", "people"]) # Check that the labels are lists assert isinstance(labels, list) and len(labels) == 1 @@ -70,7 +70,7 @@ def _check_clip_annotator(device: str): not torch.cuda.is_available() or total_disk_space < 15, reason="Test requires GPU and 15GB of HDD", ) -def test_cuda_clip_annotator(): +def test_cuda_clip_base_annotator(): _check_clip_annotator("cuda") @@ -78,5 +78,21 @@ def test_cuda_clip_annotator(): total_disk_space < 15, 
reason="Test requires at least 15GB of HDD", ) -def test_cpu_clip_annotator(): +def test_cpu_clip_base_annotator(): + _check_clip_annotator("cpu") + + +@pytest.mark.skipif( + not torch.cuda.is_available() or total_disk_space < 15, + reason="Test requires GPU and 15GB of HDD", +) +def test_cuda_clip_large_annotator(): + _check_clip_annotator("cuda") + + +@pytest.mark.skipif( + total_disk_space < 15, + reason="Test requires at least 15GB of HDD", +) +def test_cpu_clip_large_annotator(): _check_clip_annotator("cpu")