Commit 9e4990b

Add Img Cls Annotator

1 parent 5833c62

File tree: 4 files changed, +136 -200 lines

datadreamer/dataset_annotation/__init__.py (+2)

@@ -2,6 +2,7 @@
 
 from .aimv2_annotator import AIMv2Annotator
 from .clip_annotator import CLIPAnnotator
+from .cls_annotator import ImgClassificationAnnotator
 from .image_annotator import BaseAnnotator, TaskList
 from .owlv2_annotator import OWLv2Annotator
 from .slimsam_annotator import SlimSAMAnnotator
@@ -11,6 +12,7 @@
     "BaseAnnotator",
     "TaskList",
     "OWLv2Annotator",
+    "ImgClassificationAnnotator",
     "CLIPAnnotator",
     "SlimSAMAnnotator",
 ]
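
These two lines re-export the new base class at the package level, so downstream code can import it the same way as the concrete annotators. A minimal sketch (assuming the package is installed as datadreamer):

# The new base class is now importable alongside the concrete annotators;
# the subclass relationship follows from the class-declaration changes below.
from datadreamer.dataset_annotation import CLIPAnnotator, ImgClassificationAnnotator

assert issubclass(CLIPAnnotator, ImgClassificationAnnotator)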

datadreamer/dataset_annotation/aimv2_annotator.py (+2, -100)

@@ -9,20 +9,17 @@
 from __future__ import annotations
 
 import logging
-from typing import Dict, List
 
-import numpy as np
-import PIL
 import torch
 from PIL import Image
 from transformers import AutoModel, AutoProcessor
 
-from datadreamer.dataset_annotation.image_annotator import BaseAnnotator, TaskList
+from datadreamer.dataset_annotation.cls_annotator import ImgClassificationAnnotator
 
 logger = logging.getLogger(__name__)
 
 
-class AIMv2Annotator(BaseAnnotator):
+class AIMv2Annotator(ImgClassificationAnnotator):
     """A class for image annotation using the AIMv2 model, specializing in image
     classification.
 
@@ -39,25 +36,6 @@ class AIMv2Annotator(BaseAnnotator):
         release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
     """
 
-    def __init__(
-        self,
-        seed: float = 42,
-        device: str = "cuda",
-        size: str = "base",
-    ) -> None:
-        """Initializes the AIMv2Annotator with a specific seed and device.
-
-        Args:
-            seed (float): Seed for reproducibility. Defaults to 42.
-            device (str): The device to run the model on. Defaults to 'cuda'.
-        """
-        super().__init__(seed, task_definition=TaskList.CLASSIFICATION)
-        self.size = size
-        self.model = self._init_model()
-        self.processor = self._init_processor()
-        self.device = device
-        self.model.to(self.device)
-
     def _init_processor(self) -> AutoProcessor:
         """Initializes the AIMv2 processor.
 
@@ -77,82 +55,6 @@ def _init_model(self) -> AutoModel:
             "apple/aimv2-large-patch14-224-lit", trust_remote_code=True
         )
 
-    def annotate_batch(
-        self,
-        images: List[PIL.Image.Image],
-        objects: List[str],
-        conf_threshold: float = 0.1,
-        synonym_dict: Dict[str, List[str]] | None = None,
-    ) -> List[np.ndarray]:
-        """Annotates images using the AIMv2 model.
-
-        Args:
-            images: The images to be annotated.
-            objects: A list of objects (text) to test against the images.
-            conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.
-            synonym_dict (dict, optional): Dictionary for handling synonyms in labels. Defaults to None.
-
-        Returns:
-            List[np.ndarray]: A list of the annotations for each image.
-        """
-        if synonym_dict is not None:
-            objs_syn = set()
-            for obj in objects:
-                objs_syn.add(obj)
-                for syn in synonym_dict[obj]:
-                    objs_syn.add(syn)
-            objs_syn = list(objs_syn)
-            # Make a dict to transform synonym ids to original ids
-            synonym_dict_rev = {}
-            for key, value in synonym_dict.items():
-                if key in objects:
-                    synonym_dict_rev[objs_syn.index(key)] = objects.index(key)
-                    for v in value:
-                        synonym_dict_rev[objs_syn.index(v)] = objects.index(key)
-            objects = objs_syn
-
-        inputs = self.processor(
-            text=objects, images=images, return_tensors="pt", padding=True
-        ).to(self.device)
-
-        outputs = self.model(**inputs)
-
-        logits_per_image = outputs.logits_per_image  # image-text similarity score
-        probs = logits_per_image.softmax(dim=1).cpu()  # label probabilities
-
-        labels = []
-        # Get the labels for each image
-        if synonym_dict is not None:
-            for prob in probs:
-                labels.append(
-                    np.unique(
-                        np.array(
-                            [
-                                synonym_dict_rev[label.item()]
-                                for label in torch.where(prob > conf_threshold)[
-                                    0
-                                ].numpy()
-                            ]
-                        )
-                    )
-                )
-        else:
-            for prob in probs:
-                labels.append(torch.where(prob > conf_threshold)[0].numpy())
-
-        return labels
-
-    def release(self, empty_cuda_cache: bool = False) -> None:
-        """Releases the model and optionally empties the CUDA cache.
-
-        Args:
-            empty_cuda_cache (bool, optional): Whether to empty the CUDA cache. Defaults to False.
-        """
-        self.model = self.model.to("cpu")
-        if empty_cuda_cache:
-            with torch.no_grad():
-                torch.cuda.empty_cache()
-
 
 if __name__ == "__main__":
     import requests
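
Only three of the four changed files appear in this view; the new datadreamer/dataset_annotation/cls_annotator.py (the remaining +130 lines) is not shown. Because the __init__, annotate_batch, and release methods deleted here and in clip_annotator.py below are line-for-line identical, the new ImgClassificationAnnotator presumably hoists exactly that shared logic, with subclasses supplying only _init_model and _init_processor. A condensed sketch under that assumption, not the verbatim file:

# Hypothetical reconstruction of ImgClassificationAnnotator, inferred from the
# duplicated code this commit removes; the actual cls_annotator.py may differ.
from __future__ import annotations

from typing import Dict, List

import numpy as np
import PIL
import torch

from datadreamer.dataset_annotation.image_annotator import BaseAnnotator, TaskList


class ImgClassificationAnnotator(BaseAnnotator):
    """Shared base for zero-shot image-classification annotators (CLIP, AIMv2)."""

    def __init__(
        self, seed: float = 42, device: str = "cuda", size: str = "base"
    ) -> None:
        super().__init__(seed, task_definition=TaskList.CLASSIFICATION)
        self.size = size
        self.model = self._init_model()
        self.processor = self._init_processor()
        self.device = device
        self.model.to(self.device)

    def _init_model(self):
        raise NotImplementedError  # subclass loads its own backbone

    def _init_processor(self):
        raise NotImplementedError  # subclass loads its own processor

    def annotate_batch(
        self,
        images: List[PIL.Image.Image],
        objects: List[str],
        conf_threshold: float = 0.1,
        synonym_dict: Dict[str, List[str]] | None = None,
    ) -> List[np.ndarray]:
        """Returns, per image, the indices of labels scoring above conf_threshold."""
        if synonym_dict is not None:
            # Expand the label list with synonyms, remembering how to map each
            # synonym's position back to the index of its original label.
            objs_syn = set(objects)
            for obj in objects:
                objs_syn.update(synonym_dict[obj])
            objs_syn = list(objs_syn)
            synonym_dict_rev = {}
            for key, value in synonym_dict.items():
                if key in objects:
                    synonym_dict_rev[objs_syn.index(key)] = objects.index(key)
                    for v in value:
                        synonym_dict_rev[objs_syn.index(v)] = objects.index(key)
            objects = objs_syn

        inputs = self.processor(
            text=objects, images=images, return_tensors="pt", padding=True
        ).to(self.device)
        outputs = self.model(**inputs)

        probs = outputs.logits_per_image.softmax(dim=1).cpu()  # label probabilities

        labels = []
        for prob in probs:
            idxs = torch.where(prob > conf_threshold)[0].numpy()
            if synonym_dict is not None:
                # Collapse synonym hits back onto their original label indices.
                idxs = np.unique(np.array([synonym_dict_rev[i.item()] for i in idxs]))
            labels.append(idxs)
        return labels

    def release(self, empty_cuda_cache: bool = False) -> None:
        """Moves the model to CPU and optionally frees cached CUDA memory."""
        self.model = self.model.to("cpu")
        if empty_cuda_cache:
            with torch.no_grad():
                torch.cuda.empty_cache()

One worked example of the synonym remapping: with objects=["cat", "dog"] and synonym_dict={"cat": ["kitten"], "dog": ["puppy"]}, the expanded list might be ["kitten", "dog", "cat", "puppy"] (set order varies), and synonym_dict_rev maps each of those positions back to index 0 or 1, so a high score on "kitten" still yields label 0.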

datadreamer/dataset_annotation/clip_annotator.py (+2, -100)

@@ -1,20 +1,17 @@
 from __future__ import annotations
 
 import logging
-from typing import Dict, List
 
-import numpy as np
-import PIL
 import torch
 from PIL import Image
 from transformers import CLIPModel, CLIPProcessor
 
-from datadreamer.dataset_annotation.image_annotator import BaseAnnotator, TaskList
+from datadreamer.dataset_annotation.cls_annotator import ImgClassificationAnnotator
 
 logger = logging.getLogger(__name__)
 
 
-class CLIPAnnotator(BaseAnnotator):
+class CLIPAnnotator(ImgClassificationAnnotator):
     """A class for image annotation using the CLIP model, specializing in image
     classification.
 
@@ -31,25 +28,6 @@ class CLIPAnnotator(BaseAnnotator):
         release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
     """
 
-    def __init__(
-        self,
-        seed: float = 42,
-        device: str = "cuda",
-        size: str = "base",
-    ) -> None:
-        """Initializes the CLIPAnnotator with a specific seed and device.
-
-        Args:
-            seed (float): Seed for reproducibility. Defaults to 42.
-            device (str): The device to run the model on. Defaults to 'cuda'.
-        """
-        super().__init__(seed, task_definition=TaskList.CLASSIFICATION)
-        self.size = size
-        self.model = self._init_model()
-        self.processor = self._init_processor()
-        self.device = device
-        self.model.to(self.device)
-
     def _init_processor(self) -> CLIPProcessor:
         """Initializes the CLIP processor.
 
@@ -71,82 +49,6 @@ def _init_model(self) -> CLIPModel:
             return CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
         return CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
 
-    def annotate_batch(
-        self,
-        images: List[PIL.Image.Image],
-        objects: List[str],
-        conf_threshold: float = 0.1,
-        synonym_dict: Dict[str, List[str]] | None = None,
-    ) -> List[np.ndarray]:
-        """Annotates images using the CLIP model.
-
-        Args:
-            images: The images to be annotated.
-            objects: A list of objects (text) to test against the images.
-            conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.
-            synonym_dict (dict, optional): Dictionary for handling synonyms in labels. Defaults to None.
-
-        Returns:
-            List[np.ndarray]: A list of the annotations for each image.
-        """
-        if synonym_dict is not None:
-            objs_syn = set()
-            for obj in objects:
-                objs_syn.add(obj)
-                for syn in synonym_dict[obj]:
-                    objs_syn.add(syn)
-            objs_syn = list(objs_syn)
-            # Make a dict to transform synonym ids to original ids
-            synonym_dict_rev = {}
-            for key, value in synonym_dict.items():
-                if key in objects:
-                    synonym_dict_rev[objs_syn.index(key)] = objects.index(key)
-                    for v in value:
-                        synonym_dict_rev[objs_syn.index(v)] = objects.index(key)
-            objects = objs_syn
-
-        inputs = self.processor(
-            text=objects, images=images, return_tensors="pt", padding=True
-        ).to(self.device)
-
-        outputs = self.model(**inputs)
-
-        logits_per_image = outputs.logits_per_image  # image-text similarity score
-        probs = logits_per_image.softmax(dim=1).cpu()  # label probabilities
-
-        labels = []
-        # Get the labels for each image
-        if synonym_dict is not None:
-            for prob in probs:
-                labels.append(
-                    np.unique(
-                        np.array(
-                            [
-                                synonym_dict_rev[label.item()]
-                                for label in torch.where(prob > conf_threshold)[
-                                    0
-                                ].numpy()
-                            ]
-                        )
-                    )
-                )
-        else:
-            for prob in probs:
-                labels.append(torch.where(prob > conf_threshold)[0].numpy())
-
-        return labels
-
-    def release(self, empty_cuda_cache: bool = False) -> None:
-        """Releases the model and optionally empties the CUDA cache.
-
-        Args:
-            empty_cuda_cache (bool, optional): Whether to empty the CUDA cache. Defaults to False.
-        """
-        self.model = self.model.to("cpu")
-        if empty_cuda_cache:
-            with torch.no_grad():
-                torch.cuda.empty_cache()
-
 
 if __name__ == "__main__":
     import requests
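
After the refactor, both annotators expose the same interface inherited from ImgClassificationAnnotator, so existing call sites are unchanged. A usage sketch (the image path is hypothetical; model weights download on first use):

from PIL import Image

from datadreamer.dataset_annotation import CLIPAnnotator

annotator = CLIPAnnotator(device="cpu", size="base")  # inherited __init__
image = Image.open("example.jpg").convert("RGB")  # hypothetical local file
labels = annotator.annotate_batch(
    images=[image],
    objects=["cat", "dog"],
    conf_threshold=0.1,
)
print(labels)  # e.g. [array([0])] -- indices into the objects list
annotator.release(empty_cuda_cache=False)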
