@@ -1,20 +1,17 @@
 from __future__ import annotations
 
 import logging
-from typing import Dict, List
 
-import numpy as np
-import PIL
 import torch
 from PIL import Image
 from transformers import CLIPModel, CLIPProcessor
 
-from datadreamer.dataset_annotation.image_annotator import BaseAnnotator, TaskList
+from datadreamer.dataset_annotation.cls_annotator import ImgClassificationAnnotator
 
 logger = logging.getLogger(__name__)
 
 
-class CLIPAnnotator(BaseAnnotator):
+class CLIPAnnotator(ImgClassificationAnnotator):
     """A class for image annotation using the CLIP model, specializing in image
     classification.
 
@@ -31,25 +28,6 @@ class CLIPAnnotator(BaseAnnotator):
         release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
     """
 
-    def __init__(
-        self,
-        seed: float = 42,
-        device: str = "cuda",
-        size: str = "base",
-    ) -> None:
-        """Initializes the CLIPAnnotator with a specific seed and device.
-
-        Args:
-            seed (float): Seed for reproducibility. Defaults to 42.
-            device (str): The device to run the model on. Defaults to 'cuda'.
-        """
-        super().__init__(seed, task_definition=TaskList.CLASSIFICATION)
-        self.size = size
-        self.model = self._init_model()
-        self.processor = self._init_processor()
-        self.device = device
-        self.model.to(self.device)
-
     def _init_processor(self) -> CLIPProcessor:
         """Initializes the CLIP processor.
 
@@ -71,82 +49,6 @@ def _init_model(self) -> CLIPModel:
             return CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
         return CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
 
-    def annotate_batch(
-        self,
-        images: List[PIL.Image.Image],
-        objects: List[str],
-        conf_threshold: float = 0.1,
-        synonym_dict: Dict[str, List[str]] | None = None,
-    ) -> List[np.ndarray]:
-        """Annotates images using the CLIP model.
-
-        Args:
-            images: The images to be annotated.
-            objects: A list of objects (text) to test against the images.
-            conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.
-            synonym_dict (dict, optional): Dictionary for handling synonyms in labels. Defaults to None.
-
-        Returns:
-            List[np.ndarray]: A list of the annotations for each image.
-        """
-        if synonym_dict is not None:
-            objs_syn = set()
-            for obj in objects:
-                objs_syn.add(obj)
-                for syn in synonym_dict[obj]:
-                    objs_syn.add(syn)
-            objs_syn = list(objs_syn)
-            # Make a dict to transform synonym ids to original ids
-            synonym_dict_rev = {}
-            for key, value in synonym_dict.items():
-                if key in objects:
-                    synonym_dict_rev[objs_syn.index(key)] = objects.index(key)
-                    for v in value:
-                        synonym_dict_rev[objs_syn.index(v)] = objects.index(key)
-            objects = objs_syn
-
-        inputs = self.processor(
-            text=objects, images=images, return_tensors="pt", padding=True
-        ).to(self.device)
-
-        outputs = self.model(**inputs)
-
-        logits_per_image = outputs.logits_per_image  # image-text similarity score
-        probs = logits_per_image.softmax(dim=1).cpu()  # label probabilities
-
-        labels = []
-        # Get the labels for each image
-        if synonym_dict is not None:
-            for prob in probs:
-                labels.append(
-                    np.unique(
-                        np.array(
-                            [
-                                synonym_dict_rev[label.item()]
-                                for label in torch.where(prob > conf_threshold)[
-                                    0
-                                ].numpy()
-                            ]
-                        )
-                    )
-                )
-        else:
-            for prob in probs:
-                labels.append(torch.where(prob > conf_threshold)[0].numpy())
-
-        return labels
-
-    def release(self, empty_cuda_cache: bool = False) -> None:
-        """Releases the model and optionally empties the CUDA cache.
-
-        Args:
-            empty_cuda_cache (bool, optional): Whether to empty the CUDA cache. Defaults to False.
-        """
-        self.model = self.model.to("cpu")
-        if empty_cuda_cache:
-            with torch.no_grad():
-                torch.cuda.empty_cache()
-
 
 if __name__ == "__main__":
     import requests
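
A minimal usage sketch of the refactored annotator, assuming ImgClassificationAnnotator now supplies the removed __init__(seed, device, size), annotate_batch, and release methods with unchanged signatures; the module path, image URL, and label list below are illustrative placeholders:

    import requests
    from PIL import Image

    # Assumed module path for the file shown in this diff.
    from datadreamer.dataset_annotation.clip_annotator import CLIPAnnotator

    # Placeholder test image; any RGB image works here.
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

    annotator = CLIPAnnotator(device="cpu", size="base")
    # One np.ndarray per image: indices into the objects list whose
    # softmax probability exceeds conf_threshold.
    labels = annotator.annotate_batch([image], ["cat", "dog", "remote control"])
    print(labels)
    annotator.release()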