Skip to content

Commit

Permalink
Change the names of the model & processor objects
Browse files Browse the repository at this point in the history
  • Loading branch information
HonzaCuhel committed Mar 17, 2024
1 parent b2e7f1f commit 7f3fe69
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions datadreamer/dataset_annotation/clip_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class CLIPAnnotator(BaseAnnotator):
classification.
Attributes:
clip (CLIPModel): The CLIP model for image-text similarity evaluation.
clip_processor (CLIPProcessor): The processor for preparing inputs to the CLIP model.
model (CLIPModel): The CLIP model for image-text similarity evaluation.
processor (CLIPProcessor): The processor for preparing inputs to the CLIP model.
device (str): The device on which the model will run ('cuda' for GPU, 'cpu' for CPU).
Methods:
Expand All @@ -37,12 +37,10 @@ def __init__(
device (str): The device to run the model on. Defaults to 'cuda'.
"""
super().__init__(seed)
self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained(
"openai/clip-vit-base-patch32"
)
self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
self.device = device
self.clip.to(self.device)
self.model.to(self.device)

def annotate_batch(
self,
Expand Down Expand Up @@ -78,11 +76,11 @@ def annotate_batch(
synonym_dict_rev[objs_syn.index(v)] = objects.index(key)
objects = objs_syn

inputs = self.clip_processor(
inputs = self.processor(
text=objects, images=images, return_tensors="pt", padding=True
).to(self.device)

outputs = self.clip(**inputs)
outputs = self.model(**inputs)

logits_per_image = outputs.logits_per_image # image-text similarity score
probs = logits_per_image.softmax(dim=1).cpu() # label probabilities
Expand Down Expand Up @@ -130,3 +128,4 @@ def release(self, empty_cuda_cache: bool = False) -> None:
annotator = CLIPAnnotator(device=device)
labels = annotator.annotate_batch([im], ["bus", "people"])
print(labels)
annotator.release()

0 comments on commit 7f3fe69

Please sign in to comment.