@AmitMY AmitMY commented Dec 7, 2025

Summary

When using group_images=True with small group_max_seq_len, the image_ids tensor creation becomes a significant bottleneck. This PR adds LRU caching for image_ids tensors, similar to the existing posemb_grid cache.

The Problem

With group_images=True and group_max_seq_len=5, processing 512 images creates ~440 groups. Each group calls:

image_ids = torch.repeat_interleave(
    torch.arange(len(images), device=device),
    torch.tensor(patch_counts, device=device),
)

The torch.tensor(patch_counts, device=device) call has significant overhead when called 440 times per forward pass.
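The overhead is easy to reproduce in isolation (a standalone sketch, not the PR's benchmark; the pattern `(2, 3)` and the 440-call count mirror the scenario above):

```python
import timeit
import torch

patch_counts = (2, 3)  # a typical small per-group patch-count pattern

def build_ids():
    # The same construction as in the forward pass: a fresh
    # torch.tensor + repeat_interleave on every call.
    return torch.repeat_interleave(
        torch.arange(len(patch_counts)),
        torch.tensor(patch_counts),
    )

# ~440 groups per forward pass in the scenario above
per_call_us = timeit.timeit(build_ids, number=440) / 440 * 1e6
print(f"{per_call_us:.1f} us per call")
```

Even on CPU, each call pays fixed tensor-construction and dispatch costs that dwarf the actual work for sequences this short.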

The Solution

Cache the image_ids tensors by patch count pattern using lru_cache. Since there are only a few unique patch count patterns (e.g., (3,), (2, 3), (5,)), the cache is highly effective.

Benchmark Results

512 variable-width images (16px tall, 32-80px wide), group_max_seq_len=5:

| Config | Before | After | Speedup |
| --- | --- | --- | --- |
| NaViT (group_images=True) | 97.9 ms | 24.1 ms | ~4x faster |

🤖 Generated with Claude Code

@AmitMY AmitMY force-pushed the optimize-group-images branch from 88714a6 to 43726b7 Compare December 7, 2025 14:07
@AmitMY AmitMY marked this pull request as draft December 7, 2025 14:08
@AmitMY AmitMY force-pushed the optimize-group-images branch from 43726b7 to 417d571 Compare December 7, 2025 14:13

AmitMY commented Dec 7, 2025

This would not really work in real scenarios, but it does show that repeat_interleave is very costly for short sequences.

@AmitMY AmitMY closed this Dec 9, 2025