MedVAE can be run using either:
- A PyTorch model (programmatic use)
- A command-line interface (CLI) (recommended for beginners)
Please see the demo for programmatic examples.
If you are new to MedVAE and want to downsize your medical images, the CLI approach is recommended.
MedVAE provides six pre-trained models for 2D and 3D medical images, each with different compression settings:
Model Name | Compression | Latent Channels | Total Compression |
---|---|---|---|
medvae_4_1_2d | 4× per dim | 1 | 16× total |
medvae_4_3_2d | 4× per dim | 3 | 16× total |
medvae_8_1_2d | 8× per dim | 1 | 64× total |
medvae_8_4_2d | 8× per dim | 4 | 64× total |

Model Name | Compression | Latent Channels | Total Compression |
---|---|---|---|
medvae_4_1_3d | 4× per dim | 1 | 64× total |
medvae_8_1_3d | 8× per dim | 1 | 512× total |
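
The "Total Compression" column follows from the per-dimension factor raised to the number of spatial dimensions (for example, 4× per dim in 2D gives 4² = 16×). The sketch below reproduces that arithmetic. The helper names are hypothetical and not part of MedVAE, and the latent-shape helper assumes each spatial dimension is evenly divisible by the factor and shrinks by exactly that factor.

```python
def total_compression(per_dim_factor: int, n_dims: int) -> int:
    """Total spatial compression: the per-dimension factor applied to every dimension."""
    return per_dim_factor ** n_dims


def latent_spatial_shape(input_shape, per_dim_factor: int):
    """Assumed latent spatial size: each dimension divided by the per-dim factor."""
    return tuple(size // per_dim_factor for size in input_shape)


print(total_compression(4, 2))              # 16  (e.g. medvae_4_1_2d)
print(total_compression(8, 3))              # 512 (e.g. medvae_8_1_3d)
print(latent_spatial_shape((512, 512), 4))  # (128, 128)
```

Note that the total compression in the tables counts spatial dimensions only; the number of latent channels is listed separately.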
If you are integrating MedVAE into an existing PyTorch workflow, using it as a PyTorch model is recommended. The MVAE class provides an easy way to load and use MedVAE models programmatically.
To create an `MVAE` model object, provide the following parameters:
- `model_name` – Specifies which of the six available MedVAE models to use.
- `modality` – Defines the medical imaging modality (`"xray"`, `"ct"`, or `"mri"`).
- `gpu_dim` (optional) – Sets the largest volumetric dimension the GPU can handle.
  - Default: `160`, optimized for a 48GB Nvidia A6000 GPU.
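
As a minimal sketch, the parameters above can be checked before a model is constructed. `build_mvae_kwargs` is a hypothetical helper, not part of MedVAE; it only validates values against the options listed in this document and returns them as keyword arguments.

```python
VALID_MODELS = {
    "medvae_4_1_2d", "medvae_4_3_2d", "medvae_8_1_2d", "medvae_8_4_2d",
    "medvae_4_1_3d", "medvae_8_1_3d",
}
VALID_MODALITIES = {"xray", "ct", "mri"}


def build_mvae_kwargs(model_name: str, modality: str, gpu_dim: int = 160) -> dict:
    """Validate the three parameters and return them as keyword arguments."""
    if model_name not in VALID_MODELS:
        raise ValueError(f"Unknown model_name: {model_name!r}")
    if modality not in VALID_MODALITIES:
        raise ValueError(f"Unknown modality: {modality!r}")
    if gpu_dim <= 0:
        raise ValueError("gpu_dim must be a positive integer")
    return {"model_name": model_name, "modality": modality, "gpu_dim": gpu_dim}


kwargs = build_mvae_kwargs("medvae_4_1_3d", "ct")
print(kwargs)

# With medvae installed, the model could then be created as:
#   from medvae import MVAE
#   model = MVAE(**kwargs)
```

Lowering `gpu_dim` below the default may help on GPUs with less memory than the 48GB card mentioned above, although a suitable value depends on the hardware.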
The `MVAE` class provides an `apply_transform` method, which automatically applies the appropriate transformation based on the input file type and modality.
- 2D MedVAE models → Input should be a 2D `.png` file.
- 3D MedVAE models → Input should be a compressed 3D NIfTI (`*.nii.gz`) file.
For more details, the transforms file is located here.
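
```python
import torch
from medvae import MVAE

fpath = "documentation/data/mmg_data/isJV8hQ2hhJsvEP5rdQNiy.png"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the pre-trained 2D model and move it to the selected device
model = MVAE(model_name='medvae_4_3_2d', modality='xray').to(device)
# Load the image and apply the transform matching its file type and modality
img = model.apply_transform(fpath).to(device)

# Inference only: freeze the weights and switch to evaluation mode
model.requires_grad_(False)
model.eval()

with torch.no_grad():
    latent = model(img)  # latent representation of the input image
```
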
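
Because 2D models expect `.png` input and 3D models expect `.nii.gz` input, it can be useful to check a path against the chosen model before loading it. The following is a sketch only: both helpers are hypothetical, and they infer the dimensionality from the `_2d` / `_3d` suffix of the model names listed above. The path `scan.nii.gz` is a made-up example.

```python
def expected_suffix(model_name: str) -> str:
    """Return the input file suffix implied by the model's dimensionality."""
    if model_name.endswith("_2d"):
        return ".png"
    if model_name.endswith("_3d"):
        return ".nii.gz"
    raise ValueError(f"Cannot infer dimensionality from {model_name!r}")


def matches_model(fpath: str, model_name: str) -> bool:
    """True if the file suffix is the one the model's input format calls for."""
    return fpath.endswith(expected_suffix(model_name))


print(matches_model("documentation/data/mmg_data/isJV8hQ2hhJsvEP5rdQNiy.png", "medvae_4_3_2d"))  # True
print(matches_model("scan.nii.gz", "medvae_4_3_2d"))  # False
```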
The CLI script runs inference using MedVAE, processing 2D or 3D medical images to generate latent representations. It allows users to specify a pretrained MedVAE model and input modalities (X-ray, CT, MRI). Given an input directory, it will process all the medical images into latent representations and save them in the specified folder.

```bash
medvae_inference -i INPUT_FOLDER -o OUTPUT_FOLDER -model_name MED_VAE_MODEL -modality MODALITY
```

Argument | Type | Required | Description |
---|---|---|---|
-i | str | ✅ Yes | Path to the input folder containing images (*.png for 2D, *.nii.gz for 3D). The filenames must not contain multiple dots. |
-o | str | ✅ Yes | Path to the output folder where latent representations will be saved. If the folder does not exist, it will be created. |
-model_name | str | ✅ Yes | Specifies the MedVAE model to use. See available options above. |
-modality | str | ✅ Yes | Specifies the image modality: "xray", "ct", or "mri". |
-roi_size | int | ❌ No (Default: 160) | Sets the region of interest (ROI) size for 3D models (used to manage GPU memory). |
-device | str | ❌ No (Default: "cuda") | Specifies the device to run inference on: "cuda" (GPU), "cpu" (CPU), "mps" (Apple M1/M2). Do not specify GPU ID here! Use CUDA_VISIBLE_DEVICES=X instead. |
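
For scripted use, the arguments above can be assembled from Python. This is a minimal sketch: `build_inference_command` and `env_for_gpu` are hypothetical helpers, the folder names `scans/` and `latents/` are placeholders, and only the flags documented in the table are used. The GPU is selected through `CUDA_VISIBLE_DEVICES` rather than through `-device`, as the table advises.

```python
import os
import shlex


def build_inference_command(input_dir, output_dir, model_name, modality,
                            roi_size=None, device=None):
    """Assemble the medvae_inference argument list from the documented flags."""
    cmd = [
        "medvae_inference",
        "-i", str(input_dir),
        "-o", str(output_dir),
        "-model_name", model_name,
        "-modality", modality,
    ]
    # Optional flags are added only when set, so the CLI defaults apply otherwise
    if roi_size is not None:
        cmd += ["-roi_size", str(roi_size)]
    if device is not None:
        cmd += ["-device", device]
    return cmd


def env_for_gpu(gpu_id: int) -> dict:
    """Copy the current environment and pin the run to one GPU."""
    env = dict(os.environ)
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    return env


cmd = build_inference_command("scans/", "latents/", "medvae_4_1_3d", "ct", roi_size=160)
print(shlex.join(cmd))

# To actually run it (requires medvae to be installed):
#   import subprocess
#   subprocess.run(cmd, env=env_for_gpu(0), check=True)
```

Passing the command as a list avoids shell-quoting problems with folder names that contain spaces.
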
Total Compression Factor | Channels | Dimensions | Modalities | Anatomies | Config File | Model File |
---|---|---|---|---|---|---|
16 | 1 | 2D | X-ray | Chest, Breast (FFDM) | medvae_4x1.yaml | vae_4x_1c_2D.ckpt |
16 | 3 | 2D | X-ray | Chest, Breast (FFDM) | medvae_4x3.yaml | vae_4x_3c_2D.ckpt |
64 | 1 | 2D | X-ray | Chest, Breast (FFDM) | medvae_8x1.yaml | vae_8x_1c_2D.ckpt |
64 | 4 | 2D | X-ray | Chest, Breast (FFDM) | medvae_8x4.yaml | vae_8x_4c_2D.ckpt |
64 | 1 | 3D | MRI, CT | Whole-Body | medvae_4x1.yaml | vae_4x_1c_3D.ckpt |
512 | 1 | 3D | MRI, CT | Whole-Body | medvae_8x1.yaml | vae_8x_1c_3D.ckpt |