MambaViT

ViT architecture with Mamba instead of a transformer backbone. The ViT code is based on https://github.com/lucidrains/vit-pytorch. I have only tested the installation on a Linux server with an NVIDIA A40 GPU and CUDA 12.2, where it worked fine. A short MNIST training example can be found in mamba_vit_MNIST_example.ipynb; it trains to ~90% validation accuracy within a couple of epochs, which is comparable to small ViTs.
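
Conceptually, the model keeps the ViT patch-embedding front end but swaps self-attention for a stack of Mamba blocks. The sketch below only illustrates that idea and is not the code in this repository; the strided-convolution patchify, the residual wiring, and the use of mamba_ssm's Mamba block are assumptions:

import torch
from torch import nn
from mamba_ssm import Mamba  # pip install mamba-ssm (requires CUDA)

class MambaViTSketch(nn.Module):
    def __init__(self, image_size=28, patch_size=4, num_classes=10,
                 channels=1, n_layer=8, dim=32):
        super().__init__()
        # Patchify via a strided convolution, as in many ViT implementations
        self.patchify = nn.Conv2d(channels, dim, patch_size, stride=patch_size)
        self.blocks = nn.ModuleList([Mamba(d_model=dim) for _ in range(n_layer)])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.patchify(x)              # B x dim x H' x W'
        x = x.flatten(2).transpose(1, 2)  # B x n_patches x dim, a token sequence
        for block in self.blocks:
            x = x + block(x)              # residual Mamba block instead of attention
        x = self.norm(x).mean(dim=1)      # "mean" pooling over patch tokens
        return self.head(x)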

The model can be used like this:

import torch

from mamba_vit import MambaViT  # assumed module name; adjust to the repo's layout

m = MambaViT(
    image_size=28,
    patch_size=4,
    num_classes=10,
    channels=1,
    n_layer=8,
    dim=32,
    pool="mean",  # "mean" or "cls"
)

img = torch.rand(16, 1, 28, 28)  # B x C x H x W
pred = m(img)  # 16 x num_classes
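
To reproduce the MNIST result outside the notebook, a standard PyTorch training loop is enough. The following is a minimal sketch rather than the notebook's exact code; the optimizer, learning rate, batch sizes, and epoch count are assumptions:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from mamba_vit import MambaViT  # assumed module name, as above

device = "cuda" if torch.cuda.is_available() else "cpu"

# MNIST images are 1 x 28 x 28, matching the model configuration above
tfm = transforms.ToTensor()
train_dl = DataLoader(datasets.MNIST("data", train=True, download=True, transform=tfm),
                      batch_size=128, shuffle=True)
val_ds = datasets.MNIST("data", train=False, download=True, transform=tfm)
val_dl = DataLoader(val_ds, batch_size=256)

model = MambaViT(image_size=28, patch_size=4, num_classes=10,
                 channels=1, n_layer=8, dim=32, pool="mean").to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)  # assumed hyperparameters
loss_fn = nn.CrossEntropyLoss()

for epoch in range(3):
    model.train()
    for x, y in train_dl:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        loss_fn(model(x), y).backward()
        opt.step()

    # report validation accuracy after each epoch
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in val_dl:
            correct += (model(x.to(device)).argmax(dim=-1) == y.to(device)).sum().item()
    print(f"epoch {epoch}: val acc {correct / len(val_ds):.3f}")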
