Allow changing the window size of pretrained models #594
After digging into this excellent model, I found that with an input size of 1024 and a patch size of 16, we're left with a 64x64 feature map. SAM, however, uses a window size of 14, which does not evenly divide 64, so padding is necessary around every windowed attention operation.
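To make the arithmetic concrete, here is a minimal sketch (the helper below is illustrative, not part of this PR) of how much padding each window size requires on SAM's 64x64 feature map:

```python
import math

def window_padding(feat_size: int, window_size: int):
    """Padded extent and amount of padding needed to tile a square
    feature map with non-overlapping windows (illustrative helper)."""
    padded = math.ceil(feat_size / window_size) * window_size
    return padded, padded - feat_size

# SAM-Large default: 1024 input / 16 patch = 64x64 feature map
print(window_padding(64, 14))  # (70, 6)  -> 64x64 is padded up to 70x70
print(window_padding(64, 16))  # (64, 0)  -> no padding needed
print(window_padding(64, 8))   # (64, 0)  -> no padding needed
```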
It turns out that the pretrained SAM (Large, at least) is robust to changes in the window size, as long as you handle the relative position embeddings in the attention layers. This PR lerps the position embeddings from the state dict to match the size of the embeddings in the model as constructed. Fortunately, since these embeddings are indexed by relative (L1) distance, they seem perfectly fine with this interpolation.
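A minimal sketch of that interpolation, assuming SAM's decomposed relative position tables of shape `(2 * window_size - 1, head_dim)`; the names here are illustrative rather than the PR's exact code:

```python
import torch
import torch.nn.functional as F

def resize_rel_pos(rel_pos: torch.Tensor, new_window: int) -> torch.Tensor:
    """Lerp a relative position table to a new window size. A window of
    size W needs one embedding per relative offset, i.e. 2*W - 1 rows."""
    new_len = 2 * new_window - 1
    if rel_pos.shape[0] == new_len:
        return rel_pos
    # (L, C) -> (1, C, L) so we can use 1D linear interpolation, then back.
    resized = F.interpolate(
        rel_pos.permute(1, 0).unsqueeze(0),
        size=new_len,
        mode="linear",
        align_corners=False,
    )
    return resized.squeeze(0).permute(1, 0)

# Example: adapt a table trained with window 14 to window 16.
pretrained = torch.randn(2 * 14 - 1, 64)   # (27, head_dim)
adapted = resize_rel_pos(pretrained, 16)   # (31, head_dim)
```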
Once I had the model loading, I ran the resulting model through the COCO instance segmentation evaluation script implemented by the EfficientViT researchers (https://github.com/mit-han-lab/efficientvit/blob/master/eval_sam_coco.py), with a few different window sizes:
As we can see, changing the window size to 16 not only improves the mIOU metrics slightly across all object sizes, it also increases throughput (measured on an A100 with batch size 16, 100 forward passes, reported as im/sec). I suspect throughput improves for sizes 8 and 16 for a couple of reasons: (A) GPUs prefer powers of two, and those two window size choices result in GEMMs of size 16^2 or 64^2, and (B) padding in `Attention` is no longer necessary around every windowed attention operation.

So, this PR optionally allows API consumers to specify a different window size during model construction, and implements the weight lerping during state_dict loading so that existing model weights may be used.
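For illustration, here is a hedged sketch of how the lerp can be wired into state_dict loading via PyTorch's `_load_from_state_dict` hook, reusing the `resize_rel_pos` helper sketched above; the class and parameter names are hypothetical stand-ins, not the PR's actual code:

```python
import torch
import torch.nn as nn

class WindowedAttention(nn.Module):
    """Toy stand-in for a SAM attention block (illustrative only)."""
    def __init__(self, head_dim: int = 64, window_size: int = 16):
        super().__init__()
        # One embedding per relative offset along the height axis.
        self.rel_pos_h = nn.Parameter(torch.zeros(2 * window_size - 1, head_dim))

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # If the checkpoint was trained with a different window size,
        # lerp its table to our shape before the default loader runs.
        key = prefix + "rel_pos_h"
        if key in state_dict and state_dict[key].shape != self.rel_pos_h.shape:
            new_window = (self.rel_pos_h.shape[0] + 1) // 2
            state_dict[key] = resize_rel_pos(state_dict[key], new_window)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

# A checkpoint saved with window 14 now loads into a model built with 16.
ckpt = {"rel_pos_h": torch.randn(2 * 14 - 1, 64)}
model = WindowedAttention(window_size=16)
model.load_state_dict(ckpt)
```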