Reproduce Video-Llava #14

Open
liyucheng09 opened this issue Jan 23, 2025 · 2 comments

Comments

@liyucheng09

Hi Senqiao @Yangsenqiao, how can we reproduce the Video-LLaVA results reported in the paper?

@liyucheng09
Author

Also, which versions of lmms-eval and llava are needed to reproduce VisionZip?

def restore_image_features_sorted(self, image_feature, cur_keep_idx, width, height):
   
    num_img, total_patches, feature_dim = image_feature.shape
    num_keep = cur_keep_idx.shape[1]  
    num_extra = total_patches - num_keep  


    cur_keep_idx_sorted, _ = cur_keep_idx.sort(dim=1)  # [num_img, num_keep]
    cur_keep_idx_sorted_restore = cur_keep_idx_sorted[:, 1:]-1

    restored_features = torch.zeros((num_img, 576, feature_dim), device=image_feature.device, dtype=image_feature.dtype)  # [num_img, total_patches, feature_dim]

    mask = torch.zeros(num_img, 576, dtype=torch.bool, device=image_feature.device)
    mask.scatter_(1, cur_keep_idx_sorted_restore, True)  

    kept_features = image_feature[:, 1:num_keep, :]  
    restored_features[mask] = kept_features.reshape(-1, feature_dim)  
    

    assert width * height == restored_features.shape[0], "width * height must equal num_img"
    restored_features = restored_features.view(height, width, 24, 24, feature_dim)  # [height, width, 24, 24, feature_dim]
    restored_features = restored_features.permute(0, 2, 1, 3, 4).contiguous()  # [height, 24, width, 24, feature_dim]
    restored_features = restored_features.view(height, 24, width * 24, feature_dim)  # [height, 24, width*24, feature_dim]
    restored_features = restored_features.view(height * 24, width * 24, feature_dim)  # [height*24, width*24, feature_dim]
    image_newline_expanded = self.model.image_newline.view(1, 1, feature_dim).expand(height * 24, 1, feature_dim).to(restored_features.device)  # [height*24, 1, feature_dim]
    grid_with_newline = restored_features

    mask = mask.view(height, width, 24, 24)  # [height, width, 24, 24]
    mask = mask.permute(0, 2, 1, 3).contiguous()  # [height, 24, width, 24]
    mask = mask.view(height * 24, width * 24)  # [height*24, width*24]

    mask_all = mask

    image_feature_select = grid_with_newline[mask_all]
    raw_img_feature_merge = image_feature[:,-num_extra:,].reshape(-1, feature_dim)
    cls_img_feature_merge = image_feature[:,0,]

    image_feature_select = torch.cat([image_feature_select, cls_img_feature_merge, raw_img_feature_merge])
    return image_feature_select

I am especially confused by this function. What is 576 here, and why 24 * 24?
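One possible reading, not confirmed by the maintainers: 576 and 24 could be the patch count and grid side of the vision encoder. The sketch below only does the arithmetic, assuming a ViT with 14-pixel patches on a 336-pixel input (as in CLIP ViT-L/14-336) for the 576 case and on a 384-pixel input (as in SigLIP patch14-384) for the 729 that appears in the second version of the function further down. The helper name patch_grid and both encoder configs are assumptions for illustration and are not taken from the VisionZip code.

```python
def patch_grid(image_size: int, patch_size: int) -> tuple[int, int]:
    """Return (patches per side, total patches) for a square ViT input."""
    side = image_size // patch_size  # partial patches at the border are dropped
    return side, side * side


# Assumed encoder configs (not stated anywhere in this issue):
print(patch_grid(336, 14))  # (24, 576)
print(patch_grid(384, 14))  # (27, 729)
```

If that reading is right, the hard-coded 576 and 24 would only match a 24 x 24 patch grid, and an encoder with a different grid (for example 27 x 27) would need both constants changed together.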

@liyucheng09
Author

@Yangsenqiao Hi Senqiao, can you help me understand what the expected shape of images is in the code below when we are dealing with video?

def prepare_inputs_labels_for_multimodal_visionzip(
    self, input_ids, position_ids, attention_mask, past_key_values, labels,
    images, image_sizes=None
):
    vision_tower = self.get_vision_tower()
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        return input_ids, position_ids, attention_mask, past_key_values, None, labels

    if type(images) is list or images.ndim == 5:
        if type(images) is list:

And why must width * height equal num_img? How should we understand this part of the code?

def restore_image_features_sorted(self, image_feature, cur_keep_idx, width, height):
   
    num_img, total_patches, feature_dim = image_feature.shape
    num_keep = cur_keep_idx.shape[1]  
    num_extra = total_patches - num_keep  


    cur_keep_idx_sorted, _ = cur_keep_idx.sort(dim=1)  # [num_img, num_keep]
    cur_keep_idx_sorted_restore = cur_keep_idx_sorted[:, 1:]

    restored_features = torch.zeros((num_img, 729, feature_dim), device=image_feature.device, dtype=image_feature.dtype)  # [num_img, total_patches, feature_dim]

    mask = torch.zeros(num_img, 729, dtype=torch.bool, device=image_feature.device)
    mask.scatter_(1, cur_keep_idx_sorted_restore, True)  

    kept_features = image_feature[:, 1:num_keep, :]  
    restored_features[mask] = kept_features.reshape(-1, feature_dim)  
    

    assert width * height == restored_features.shape[0], "width * height must equal num_img"
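To make the question concrete, here is a minimal sketch of what the view/permute sequence in the first version of the function appears to do, under the assumption that num_img counts the tiles of one image laid out as a height x width grid in row-major order. It uses NumPy's reshape and transpose as stand-ins for torch's view and permute, and the name stitch_tiles and the toy sizes are hypothetical.

```python
import numpy as np


def stitch_tiles(tiles: np.ndarray, height: int, width: int, side: int = 24) -> np.ndarray:
    """Arrange [height*width, side*side, dim] tile features into a single
    [height*side, width*side, dim] feature map (tiles in row-major order)."""
    num_img, patches, dim = tiles.shape
    assert height * width == num_img, "one feature block per tile"
    assert patches == side * side, "each tile must be a side x side patch grid"
    x = tiles.reshape(height, width, side, side, dim)
    x = x.transpose(0, 2, 1, 3, 4)  # [height, side, width, side, dim]
    return x.reshape(height * side, width * side, dim)


# Toy check: 2 x 3 tiles, each with 2 x 2 patches, feature = tile index.
tiles = np.repeat(np.arange(6), 4).reshape(6, 4, 1)
grid = stitch_tiles(tiles, height=2, width=3, side=2)
print(grid[..., 0])
```

Under that reading, the assertion only holds when the first dimension is exactly a height x width tile grid. A video passed as a flat list of frames would satisfy it only if height * width happened to equal the frame count, which is why the expected input shape for video matters here.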
