## Training

Some settings follow those in the AlphaFold 3 paper. The table below shows the training settings for the different fine-tuning stages:

| Arguments | Initial training | Fine tuning 1 | Fine tuning 2 | Fine tuning 3 |
| --- | --- | --- | --- | --- |
| `train_crop_size` | 384 | 640 | 768 | 768 |
| `diffusion_batch_size` | 48 | 32 | 32 | 32 |
| `loss.weight.alpha_pae` | 0 | 0 | 0 | 1.0 |
| `loss.weight.alpha_bond` | 0 | 1.0 | 1.0 | 0 |
| `loss.weight.smooth_lddt` | 1.0 | 0 | 0 | 0 |
| `loss.weight.alpha_confidence` | 1e-4 | 1e-4 | 1e-4 | 1e-4 |
| `loss.weight.alpha_diffusion` | 4.0 | 4.0 | 4.0 | 0 |
| `loss.weight.alpha_distogram` | 0.03 | 0.03 | 0.03 | 0 |
| `train_confidence_only` | False | False | False | True |
| Full BF16-mixed speed (A100, s/step) | ~12 | ~30 | ~44 | ~13 |
| Full BF16-mixed peak memory (GB) | ~34 | ~35 | ~48 | ~24 |
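For orientation, one plausible reading of how these weights combine, following the overall loss in the AlphaFold 3 paper (an assumption based on the naming above, not a statement of this repo's exact code):

$$
\mathcal{L} = \alpha_{\text{confidence}}\,\mathcal{L}_{\text{confidence}} + \alpha_{\text{diffusion}}\,\mathcal{L}_{\text{diffusion}} + \alpha_{\text{distogram}}\,\mathcal{L}_{\text{distogram}}
$$

with $\alpha_{\text{bond}}$ and the smooth lDDT weight applied to terms inside $\mathcal{L}_{\text{diffusion}}$, and $\alpha_{\text{pae}}$ weighting the PAE term of the confidence loss.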

We recommend carrying out the training on A100-80G or H20/H100 GPUs. With full BF16-mixed precision training, the initial training stage can also be performed on A800-40G GPUs. On GPUs with less memory, such as the A30, you will need to reduce the model size, for example by decreasing `model.pairformer.nblocks` and `diffusion_batch_size`.
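As a concrete illustration, here is a minimal sketch of such a reduction. The attribute paths mirror the argument names above, but the repo's real config object is built from its config layer, so the schema and values here are illustrative only:

```python
from types import SimpleNamespace

# Hypothetical memory-saving overrides for a smaller GPU such as an A30.
configs = SimpleNamespace(
    model=SimpleNamespace(pairformer=SimpleNamespace(nblocks=16)),  # fewer Pairformer blocks
    diffusion_batch_size=16,  # fewer diffusion samples per training step
    train_crop_size=256,      # smaller crops also cut activation memory
)
```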

## Inference

By default, inference runs in BF16 mixed precision, while the SampleDiffusion and ConfidenceHead parts still run in FP32 precision.
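The mechanism behind such a split is plain PyTorch autocast. The sketch below is our illustration of the idea, not the repo's actual code:

```python
import torch

def run_module(module: torch.nn.Module, x: torch.Tensor, skip_amp: bool) -> torch.Tensor:
    """Run `module` in FP32 when skip_amp is True, else under BF16 autocast."""
    if skip_amp:
        # Disable autocast so this module computes entirely in FP32.
        with torch.autocast(device_type="cuda", enabled=False):
            return module(x.float())
    # Everything else stays in BF16 mixed precision.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        return module(x)
```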

Below are reference examples of CUDA memory usage (GB):

| Ntoken | Natom | Default | Full BF16 mixed |
| --- | --- | --- | --- |
| 500 | 10000 | 5.6 | 5.1 |
| 1500 | 30000 | 24.8 | 19.2 |
| 2500 | 25000 | 52.2 | 34.8 |
| 3500 | 35000 | 67.6 | 38.2 |
| 4500 | 45000 | 77.0 | 59.2 |
| 5000 | 50000 | OOM | 72.8 |

The script `runner/inference.py` automatically adjusts the precision used for SampleDiffusion and ConfidenceHead to avoid OOM, as follows:

```python
from typing import Any


def update_inference_configs(configs: Any, N_token: int):
    # Set the inference precision for different N_token sizes: when N_token is
    # large, the default (FP32 for these modules) may OOM even on an A100-80G
    # GPU, so BF16 autocast is enabled for them instead.
    # skip_amp=True keeps a module in FP32; False lets it run under BF16.
    if N_token > 3840:
        configs.skip_amp.confidence_head = False
        configs.skip_amp.sample_diffusion = False
    elif N_token > 2560:
        configs.skip_amp.confidence_head = False
        configs.skip_amp.sample_diffusion = True
    else:
        configs.skip_amp.confidence_head = True
        configs.skip_amp.sample_diffusion = True
    return configs
```
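For example, with a minimal stand-in config (the real `configs` object is built by the repo's config layer):

```python
from types import SimpleNamespace

configs = SimpleNamespace(
    skip_amp=SimpleNamespace(confidence_head=True, sample_diffusion=True)
)
configs = update_inference_configs(configs, N_token=3000)
print(configs.skip_amp.confidence_head)   # False -> ConfidenceHead runs in BF16
print(configs.skip_amp.sample_diffusion)  # True  -> SampleDiffusion stays in FP32
```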