About the CLS token for the llama3_2_vision_encoder #2268

Open
dfloreaa opened this issue Jan 15, 2025 · 4 comments
Comments

@dfloreaa

From what I have seen, llama3_2_vision_encoder uses a CLIP encoder for the images, which is itself an instance of VisionTransformer and has a cls_token_embedding module that takes append_cls_token as an argument.

Since llama3_2_vision_encoder sets this argument to False, the CLS token is added "to the beginning of the sequence", i.e. to the transformer input before self-attention. Does this CLS token (the first element of the [1601, 4096] transformer output) even encode the image's information?

I do not know which attention mask the model works with. I would guess it is all 1s on the vision side, and that the CLS token therefore does capture the contents of the image; but if so, why does it matter whether it is placed at the front or the back of the input?

I am looking to get a single (1, 4096) vector that encodes the image's content (for a research project), and I am curious whether this vector will do the trick.

@RdoubleA
Contributor

I do not know which attention mask the model works with. I would guess it is all 1s on the vision side, and that the CLS token therefore does capture the contents of the image; but if so, why does it matter whether it is placed at the front or the back of the input?

Yes, you are correct: no attention mask is used, so all patches can attend to all others, including the class token. The exact position of the class token probably makes only a marginal difference; you just need to be consistent with how the model was pretrained. For ViT, the class token was prepended to the sequence.

I am looking to get a single (1, 4096) vector that encodes the image's content (for a research project), and I am curious whether this vector will do the trick.

This is exactly what the class token was used for. It learns global information about the image since it can attend to all other patches, and it is a lot more efficient to add a linear classification layer on top of a single token than to aggregate all the patches. The original ViT paper does exactly this, adding an MLP head on the class token for classification.
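
A minimal sketch of pulling that vector out, assuming the transformer output is a (1, 1601, 4096) tensor with the class token at index 0; the shapes and the linear head below are illustrative, not torchtune's API:

import torch

# Assumed output shape: (batch, num_tokens, hidden_dim) with the class token
# prepended, as in the original ViT (this matches append_cls_token=False).
encoder_output = torch.randn(1, 1601, 4096)

# Single (1, 4096) vector summarizing the image.
cls_embedding = encoder_output[:, 0, :]
print(cls_embedding.shape)  # torch.Size([1, 4096])

# ViT-style use: a classification head on top of the class token (num_classes is made up).
head = torch.nn.Linear(4096, 1000)
logits = head(cls_embedding)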

@felipemello1
Contributor

felipemello1 commented Jan 15, 2025

I would just double-check whether the CLIP model was finetuned or kept frozen in llama3_2. If I recall correctly, it was finetuned. In that case, this CLS token probably does not hold the same meaning it did for ViT, since I don't think llama3_2 had any classification objective tied to this token.

@joecummings added the discussion and triaged labels Jan 15, 2025
@dfloreaa
Author

I do not really want to perform classification with this token, but rather to do something similar to this paper:

[Figure from the referenced paper]

Basically, they create a temporal learner that aligns a video embedding vector with the text embeddings. To do this, I need a (1, 4096) vector per image that contains essentially all of its information. Hoping that the CLS token still retains part of that property, I was looking into using its embedding for this process and then replacing the image embeddings fed to the Llama decoder with the new, temporally updated vector.

So far things have been going fine, except that the model seems to eat up a lot of GPU memory with each generation step. Just for demonstration purposes, this code is similar to what I use for generation:

import torch
from torchtune.models.llama3_2_vision import llama3_2_vision_11b

model = llama3_2_vision_11b()
weights_path = f"{checkpoint_dir}/{weights_file}"
state_dict = torch.load(weights_path)
model.load_state_dict(state_dict)
model.eval()

# (...)
img_embedding = model.encoder(img_tensor)
tokenized_input = torch.IntTensor(tokenizer.encode(text="Describe this scene?\n", add_bos=True, add_eos=True)).unsqueeze(0)
# (...)

# greedy decoding: append the argmax token at the last position each step
for i in range(400):
    out = model.decoder(tokens=tokenized_input, encoder_input=img_embedding)
    tokenized_input = torch.cat((tokenized_input, out.argmax(-1)[:, -1:]), dim=-1)

Each loop iteration seems to increase VRAM usage by quite a bit, even though I am practically overwriting the same vectors. The same goes for model initialization: whenever I run it, it uses up ~100 GB to 140 GB of RAM, which seems quite excessive for a model this size. Is there something I am missing or doing wrong?

@felipemello1
Contributor

felipemello1 commented Jan 15, 2025

  • You can cast the model to bf16; this should reduce memory by ~50%.
  • You can set expandable_segments:True. Basically, every time you call tokenized_input = torch.cat(...), new memory has to be allocated, and the previous allocation is not necessarily released. Expandable segments helps with that (see the sketch after this list).
  • As a sanity check, print the shapes to make sure they are what you think they are.
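
A minimal sketch of the first two points, reusing weights_path and img_tensor from the snippet above; torch.inference_mode() is an extra assumption on my part, since a generation loop that runs without no_grad also keeps autograd state alive:

import os

# Assumption: this must be set before the first CUDA allocation so the allocator picks it up.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
from torchtune.models.llama3_2_vision import llama3_2_vision_11b

model = llama3_2_vision_11b()
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
# bf16 stores 2 bytes per parameter instead of 4, roughly halving weight memory.
model = model.to(dtype=torch.bfloat16, device="cuda")
model.eval()

with torch.inference_mode():
    img_embedding = model.encoder(img_tensor.to(dtype=torch.bfloat16, device="cuda"))
    print(img_embedding.shape)  # sanity-check the shape before decoding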

Regarding the 100 GB of RAM, you can do a sanity check: the model is 11B parameters * 4 bytes (fp32) = ~44 GB. Then you need to add the sizes of your tokenized_input, out, and img_embedding, plus some extra for the model's activations.
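
A back-of-the-envelope version of that arithmetic (weights only; activations and the growing tokenized_input come on top):

n_params = 11e9
print(f"fp32 weights: ~{n_params * 4 / 1e9:.0f} GB")  # ~44 GB
print(f"bf16 weights: ~{n_params * 2 / 1e9:.0f} GB")  # ~22 GB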
