fixed evaclip
Daniel Kaplan committed Apr 3, 2024
1 parent 6cbc71c commit 19ba4ca
46 changes: 32 additions & 14 deletions mytests/test_vision_encoder_evaclip.py
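
The core of the fix: the shapes test now loads the model with open_clip.create_model_from_pretrained, which returns the model plus a single inference transform, instead of open_clip.create_model_and_transforms, which returns separate train and val transforms. A minimal sketch of the two call shapes (model tag taken from the diff below):

    import open_clip

    # Three return values: model, train transform, val transform.
    model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
        'hf-hub:timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k')

    # Two return values: model and a single inference-time transform.
    model, preprocess = open_clip.create_model_from_pretrained(
        'hf-hub:timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k')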
@@ -5,31 +5,46 @@
 import torch
 from PIL import Image
 import open_clip
-from conda.common._logic import TRUE

 class Args:
     def __init__(self):
         pass

-def notest_eva_clip_image_transformer():
+def test_eva_clip_image_transformer_shapes():

     device = "cuda"
-    model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k')
-    tokenizer = open_clip.get_tokenizer('hf-hub:timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k')
+    # model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k')
+    model, preprocess = open_clip.create_model_from_pretrained('hf-hub:timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k')
+    tokenizer = open_clip.get_tokenizer('hf-hub:timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k')
+    # print(type(preprocess))
+    # print(preprocess.transforms)

     print("Loaded model")

-    vision = model.visual
-    vision = vision.to(device)
+    for transform in preprocess.transforms:
+        print(transform)
+    # preprocess.transforms = preprocess.transforms[3:]
+    # print(preprocess.transforms)

-    vision.output_tokens = True
-    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
+    visual = model.visual
+    visual.trunk.output_tokens = True

-    pooled, tokens = vision(image)
+    visual = visual.to(device)
+    del model

+    image_path = "dog.jpg"
+    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
+    print(image.shape)

+    pooled = visual(image)

     print(pooled.shape)
-    print(tokens.shape)
+    # print(tokens.shape)


 def test_eva_clip_image_transformer():
@@ -38,7 +53,7 @@ def test_eva_clip_image_transformer():
     caption = ["a diagram", "a dog", "a cat"]

     device = "cuda"
-    model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k')
+    model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k')
     tokenizer = open_clip.get_tokenizer('hf-hub:timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k')

     print("Loaded model")
@@ -60,11 +75,14 @@

 def main():

-    test_eva_clip_image_transformer()
-
+    test_eva_clip_image_transformer_shapes()
+    # test_eva_clip_image_transformer()
     # test_dino_image_frozen_transformer()
     # test_dino_image_frozen_lora_transformer()
     # test_dino_video_transformer_basic()

 if __name__ == "__main__":
     main()
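
For reference, the end-to-end pattern the updated test exercises, reconstructed from the diff above as a hedged sketch: the "dog.jpg" path and the CUDA requirement are the test's own assumptions, and the exact pooled shape depends on the model's embedding width.

    import torch
    from PIL import Image
    import open_clip

    # Load EVA02-L/14 CLIP and its inference transform from the HF hub.
    model, preprocess = open_clip.create_model_from_pretrained(
        'hf-hub:timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k')

    visual = model.visual.to("cuda")
    del model  # the test keeps only the vision tower

    image = preprocess(Image.open("dog.jpg")).unsqueeze(0).to("cuda")
    with torch.no_grad():  # not in the test; avoids accumulating grad buffers
        pooled = visual(image)

    print(pooled.shape)  # expected: torch.Size([1, embed_dim])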
