Some settings follow those in the AlphaFold 3 paper. The table below shows the training settings for the different fine-tuning stages:
Arguments | Initial training | Fine tuning 1 | Fine tuning 2 | Fine tuning 3 |
---|---|---|---|---|
`train_crop_size` | 384 | 640 | 768 | 768 |
`diffusion_batch_size` | 48 | 32 | 32 | 32 |
`loss.weight.alpha_pae` | 0 | 0 | 0 | 1.0 |
`loss.weight.alpha_bond` | 0 | 1.0 | 1.0 | 0 |
`loss.weight.smooth_lddt` | 1.0 | 0 | 0 | 0 |
`loss.weight.alpha_confidence` | 1e-4 | 1e-4 | 1e-4 | 1e-4 |
`loss.weight.alpha_diffusion` | 4.0 | 4.0 | 4.0 | 0 |
`loss.weight.alpha_distogram` | 0.03 | 0.03 | 0.03 | 0 |
`train_confidence_only` | False | False | False | True |
Full BF16-mixed speed (A100, s/step) | ~12 | ~30 | ~44 | ~13 |
Full BF16-mixed peak memory (GB) | ~34 | ~35 | ~48 | ~24 |
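For quick experimentation, the schedule above can be mirrored in plain Python; the sketch below copies the values straight from the table (the key strings follow the argument paths, but this is illustrative, not the project's actual config API):

```python
# Sketch: the stage schedule above as per-argument tuples, one entry per stage
# (initial, fine-tune 1, fine-tune 2, fine-tune 3). Values are copied from the
# table; key strings mirror the argument paths for readability only.
STAGE_SETTINGS = {
    "train_crop_size":              (384, 640, 768, 768),
    "diffusion_batch_size":         (48, 32, 32, 32),
    "loss.weight.alpha_pae":        (0.0, 0.0, 0.0, 1.0),
    "loss.weight.alpha_bond":       (0.0, 1.0, 1.0, 0.0),
    "loss.weight.smooth_lddt":      (1.0, 0.0, 0.0, 0.0),
    "loss.weight.alpha_confidence": (1e-4, 1e-4, 1e-4, 1e-4),
    "loss.weight.alpha_diffusion":  (4.0, 4.0, 4.0, 0.0),
    "loss.weight.alpha_distogram":  (0.03, 0.03, 0.03, 0.0),
    "train_confidence_only":        (False, False, False, True),
}

def overrides_for(stage: int) -> dict:
    """Flat override dict for one stage (0 = initial training, 3 = fine-tune 3)."""
    return {key: values[stage] for key, values in STAGE_SETTINGS.items()}
```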
We recommend carrying out the training on A100-80G or H20/H100 GPUs. When using full BF16-mixed precision training, the initial training stage can also be performed on A800-40G GPUs. On GPUs with smaller memory, such as the A30, you will need to reduce the model size, for example by decreasing `model.pairformer.nblocks` and `diffusion_batch_size`.
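As an illustration, here is a hypothetical way to apply such overrides, with a bare `SimpleNamespace` standing in for the real configs object; the reduced values are arbitrary examples, not tuned recommendations:

```python
from types import SimpleNamespace

# Hypothetical stand-in for the real configs object, initialized with the
# initial-training value of diffusion_batch_size from the table; the default
# nblocks value here is only an assumption for illustration.
configs = SimpleNamespace(
    model=SimpleNamespace(pairformer=SimpleNamespace(nblocks=48)),
    diffusion_batch_size=48,
)

# Example memory-saving overrides for a smaller GPU (values are illustrative).
configs.model.pairformer.nblocks = 24  # fewer Pairformer blocks
configs.diffusion_batch_size = 16      # smaller diffusion batch per step
```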
By default, the model is inferred in BF16 mixed precision, while the `SampleDiffusion` and `ConfidenceHead` parts are still inferred in FP32 precision.
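Conceptually, this layout corresponds to running the trunk under a BF16 autocast context and locally disabling autocast for the two FP32 parts. A minimal PyTorch sketch with placeholder module names:

```python
import torch

# Minimal sketch of the default precision layout (module names are
# placeholders, not the project's actual classes): the trunk runs under BF16
# autocast, while SampleDiffusion and ConfidenceHead locally disable autocast
# ("skip AMP") and compute in FP32.
def run_inference(trunk, sample_diffusion, confidence_head, feats):
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        trunk_out = trunk(feats)  # BF16 mixed precision
        with torch.autocast(device_type="cuda", enabled=False):
            coords = sample_diffusion(trunk_out.float())              # FP32
            confidence = confidence_head(trunk_out.float(), coords)  # FP32
    return coords, confidence
```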
Below are reference examples of CUDA memory usage (GB).
Ntoken | Natom | Default | Full BF16 Mixed |
---|---|---|---|
500 | 10000 | 5.6 | 5.1 |
1500 | 30000 | 24.8 | 19.2 |
2500 | 25000 | 52.2 | 34.8 |
3500 | 35000 | 67.6 | 38.2 |
4500 | 45000 | 77.0 | 59.2 |
5000 | 50000 | OOM | 72.8 |
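Numbers like these can be collected with PyTorch's built-in memory statistics; a small sketch, where `model` and `batch` stand in for the actual inference objects:

```python
import torch

# Sketch: record peak CUDA memory (GB) for one inference call.
# `model` and `batch` are placeholders for the actual inference objects.
@torch.no_grad()
def peak_memory_gb(model, batch) -> float:
    torch.cuda.reset_peak_memory_stats()
    model(batch)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**3
```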
The script `runner/inference.py` will automatically adjust the precision used to compute `SampleDiffusion` and `ConfidenceHead` to avoid OOM, as follows:
```python
from typing import Any


def update_inference_configs(configs: Any, N_token: int):
    # Set the default inference configs for different N_token and N_atom.
    # When N_token is larger than ~3000, the default config (FP32
    # SampleDiffusion and ConfidenceHead) might OOM even on an A100-80G GPU,
    # so autocast is progressively enabled for those parts as N_token grows.
    if N_token > 3840:
        configs.skip_amp.confidence_head = False
        configs.skip_amp.sample_diffusion = False
    elif N_token > 2560:
        configs.skip_amp.confidence_head = False
        configs.skip_amp.sample_diffusion = True
    else:
        configs.skip_amp.confidence_head = True
        configs.skip_amp.sample_diffusion = True
    return configs
```
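For example, calling it with a minimal `SimpleNamespace` stand-in for the configs object:

```python
from types import SimpleNamespace

# Illustrative call with a minimal stand-in for the real configs object.
configs = SimpleNamespace(
    skip_amp=SimpleNamespace(confidence_head=True, sample_diffusion=True)
)
configs = update_inference_configs(configs, N_token=3000)
print(configs.skip_amp.confidence_head)   # False -> ConfidenceHead runs in BF16
print(configs.skip_amp.sample_diffusion)  # True  -> SampleDiffusion stays in FP32
```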