
Checkered Predictions (Checkered Artifacts?) #778

Open
Vishawjeet-rmsl opened this issue Oct 15, 2024 · 6 comments

Comments

@Vishawjeet-rmsl

Hi,
I have been modifying the scripts of vanilla SAM, mainly to come up with my own training script.
I was somewhat successful: training runs and the loss gradually decreases. But I noticed something. When I save the predictions made by the model at every epoch, there are checkered lines all over them.
For example, in the image below, the left prediction is from epoch 1 and the right one is from epoch 117. Although the grid is fading, it is still clearly visible.
image

Does anyone know what is causing this? Or is it just because the model has not been trained for enough epochs?
I'm using just 15 image/mask pairs for training (keeping the image encoder weights, but training the prompt encoder and mask decoder from scratch).

I would be grateful if someone can give me some clue if not a proper solution. Thanks in advance!

@heyoeyo

heyoeyo commented Oct 16, 2024

The image you posted seems to have about 64 'tiles' horizontally. By default the image encoder outputs a 64x64 token 'image' that is processed and eventually upscaled by the mask decoder. So it seems likely that the upscaling isn't 'mixing' the pixels together enough and therefore the original 64x64 grid is still visible in the result.

There can also be other, much more subtle, lower-resolution grid artifacts that can appear due to the windowing the model uses.
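For intuition, here is a minimal PyTorch sketch of this kind of upscaling path (the channel sizes are illustrative assumptions, not the actual SAM decoder). Each stride-2 transposed convolution expands every token into its own 2x2 block of output pixels, so after two stages each of the 64x64 tokens owns a 4x4 patch of the 256x256 output, and the token grid can remain visible:

```python
import torch
import torch.nn as nn

# Illustrative sketch (not the actual SAM code): the 64x64 token grid is
# upscaled by two stride-2 transposed convolutions, giving a 256x256 output.
upscaler = nn.Sequential(
    nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2),  # 64x64 -> 128x128
    nn.GELU(),
    nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),   # 128x128 -> 256x256
)

# Image-encoder output reshaped into a 2D token grid (assumed shape)
tokens = torch.randn(1, 256, 64, 64)
out = upscaler(tokens)
print(out.shape)  # torch.Size([1, 32, 256, 256])
```

With kernel size equal to stride, each output pixel depends on exactly one input token, so there is no mixing between neighboring tokens during upscaling.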

@Vishawjeet-rmsl
Author

> The image you posted seems to have about 64 'tiles' horizontally. By default the image encoder outputs a 64x64 token 'image' that is processed and eventually upscaled by the mask decoder. So it seems likely that the upscaling isn't 'mixing' the pixels together enough and therefore the original 64x64 grid is still visible in the result.
>
> There can also be other, much more subtle, lower-resolution grid artifacts that can appear due to the windowing the model uses.

Wow! It does seem to have 64 tiles. If this is the reason, then there may be an issue with the upsampling method.
BTW, the above images are raw predictions (upscaled to the original image size) without applying a threshold. So I was wondering whether another cause could be the interpolation, i.e. the raw prediction from the mask decoder has spatial dimensions 256x256, and we interpolate it during postprocessing. Maybe I should visualize the raw predictions as well.
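One quick way to check this (a sketch with assumed tensor shapes, not the actual postprocessing code) is to keep the raw low-resolution logits around and compare them with the interpolated version; if the grid is already present at 256x256, the interpolation is not the cause:

```python
import torch
import torch.nn.functional as F

# Hypothetical helper: SAM-style postprocessing upscales the low-res mask
# logits to the input resolution with bilinear interpolation.
def upscale_logits(low_res_logits: torch.Tensor,
                   out_size=(1024, 1024)) -> torch.Tensor:
    return F.interpolate(low_res_logits, size=out_size,
                         mode="bilinear", align_corners=False)

# Raw mask-decoder output (assumed shape: batch, masks, 256, 256)
low_res = torch.randn(1, 1, 256, 256)
full = upscale_logits(low_res)
print(low_res.shape, full.shape)
```

Saving and viewing `low_res` directly (before thresholding) would show whether the checkered pattern exists prior to the resize.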

@heyoeyo

heyoeyo commented Oct 18, 2024

> So, I was wondering maybe another reason for this could be the interpolation?

Since the interpolation is bilinear, it shouldn't introduce any artifacts beyond a blurring effect as the smaller pattern is scaled up.

However, the fact that the model doesn't upscale all the way back to the original input size may be part of the problem (there was some discussion of this on the samv2 issue board): the model gets less chance to process the original tokens, and any artifacts are interpolated up and become more visible.
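This can be sanity-checked directly (a small sketch, nothing SAM-specific): bilinear upsampling of a flat image stays flat, i.e. it adds no pattern of its own, while an existing grid pattern simply survives at the larger size:

```python
import torch
import torch.nn.functional as F

# Bilinear output is a convex combination of input pixels, so a constant
# image stays exactly constant: no artifacts are introduced by the resize.
flat = torch.ones(1, 1, 64, 64)
up_flat = F.interpolate(flat, size=(256, 256),
                        mode="bilinear", align_corners=False)
assert torch.allclose(up_flat, torch.ones_like(up_flat))

# But an existing regular grid pattern is only blurred, not removed.
grid = torch.zeros(1, 1, 64, 64)
grid[..., ::2, ::2] = 1.0
up_grid = F.interpolate(grid, size=(256, 256),
                        mode="bilinear", align_corners=False)
print(up_grid.std())  # nonzero: the pattern survives the upscale
```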

@bhack

bhack commented Oct 19, 2024

I've tried to make the decoder upscale more smoothly with some extra layers (512 and 1024), going up to 1024x1024 instead of the original 256x256 plus pure interpolation, and I still see similar artifacts.

I think there is still something else impacting the resolution.

@heyoeyo

heyoeyo commented Oct 20, 2024

> I've tried to upscale the decoder more smoothly with some extra layers (512 and 1024) up to 1024x1024 instead of the original 256x256 + pure interpolation and I have seen similar artifacts

That's interesting! Maybe the decoder model is just too small/simple to avoid these kinds of artifacts entirely. It's probably hard to improve it without breaking the original 'real-time on cpu' design constraint. Maybe a few regular convolutions in between the upscaling steps could help blend things better spatially?
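As a sketch of that idea (an assumption about the design, not the actual SAM code), a plain 3x3 convolution after each transposed-convolution stage mixes neighboring pixels, so output pixels no longer depend on a single source token:

```python
import torch
import torch.nn as nn

# Hypothetical variant of the upscaling path: a 3x3 convolution after each
# stride-2 transposed convolution blends adjacent pixels spatially, which
# may soften the per-token grid. Channel sizes are illustrative.
class SmoothedUpscaler(nn.Module):
    def __init__(self, c_in: int = 256):
        super().__init__()
        self.stage1 = nn.Sequential(
            nn.ConvTranspose2d(c_in, c_in // 4, kernel_size=2, stride=2),
            nn.Conv2d(c_in // 4, c_in // 4, kernel_size=3, padding=1),
            nn.GELU(),
        )
        self.stage2 = nn.Sequential(
            nn.ConvTranspose2d(c_in // 4, c_in // 8, kernel_size=2, stride=2),
            nn.Conv2d(c_in // 8, c_in // 8, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.stage2(self.stage1(x))

x = torch.randn(1, 256, 64, 64)  # assumed token-grid shape
y = SmoothedUpscaler()(x)
print(y.shape)  # torch.Size([1, 32, 256, 256])
```

The extra convolutions add parameters and latency, which is the tension with the original "real-time on CPU" constraint mentioned above.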

@bhack

bhack commented Oct 20, 2024

> Maybe a few regular convolutions in between the upscaling steps could help blend things better spatially

That is what I tried, so as not to invalidate the pretrained part of the decoder checkpoint.

We probably need a better design for these extra layers.

Otherwise you are strictly interpolating from 256x256 (for 1024x1024 inputs) at:

https://github.com/facebookresearch/sam2/blob/main/sam2%2Fmodeling%2Fsam2_base.py#L373
