|
24 | 24 | # TODO add back support for slurm resilience. |
25 | 25 | # import nvidia_resiliency_ext.ptl_resiliency as res_module |
26 | 26 | import torch |
27 | | -from lightning.pytorch.callbacks import Callback, LearningRateMonitor, RichModelSummary |
| 27 | +from lightning.pytorch.callbacks import LearningRateMonitor, RichModelSummary |
28 | 28 | from megatron.core.distributed import DistributedDataParallelConfig |
29 | 29 | from megatron.core.enums import Fp8Recipe |
30 | 30 | from megatron.core.optimizer import OptimizerConfig |
|
53 | 53 | from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel, mamba_no_weight_decay_cond_with_embeddings |
54 | 54 | from bionemo.evo2.models.peft import Evo2LoRA |
55 | 55 | from bionemo.evo2.run.utils import infer_model_type, lookup_activation_func, patch_eden_tokenizer |
56 | | -from bionemo.evo2.utils.callbacks import GarbageCollectAtInferenceTime |
| 56 | +from bionemo.evo2.utils.callbacks import GarbageCollectAtInferenceTime, _FirstBatchCudaSync |
57 | 57 | from bionemo.evo2.utils.config import hyena_no_weight_decay_cond_with_embeddings |
58 | 58 | from bionemo.evo2.utils.logging.callbacks import TEVCallback |
59 | 59 | from bionemo.llm.utils.datamodule_utils import infer_global_batch_size |
@@ -864,27 +864,6 @@ def train(args: argparse.Namespace) -> nl.Trainer: |
864 | 864 | TEVCallback(), |
865 | 865 | ] |
866 | 866 |
|
867 | | - # First batch CUDA sync callback: adds barriers for the first training batch to avoid race condition |
868 | | - # See https://github.com/NVIDIA/bionemo-framework/issues/1301 for more details. |
869 | | - class _FirstBatchCudaSync(Callback): |
870 | | - def __init__(self): |
871 | | - self._done = False |
872 | | - |
873 | | - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): |
874 | | - if not self._done and torch.cuda.is_available(): |
875 | | - torch.cuda.synchronize() |
876 | | - |
877 | | - def on_after_backward(self, trainer, pl_module): |
878 | | - if not self._done and torch.cuda.is_available(): |
879 | | - torch.cuda.synchronize() |
880 | | - |
881 | | - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): |
882 | | - if not self._done and torch.cuda.is_available(): |
883 | | - torch.cuda.synchronize() |
884 | | - # Unset blocking for subsequent batches |
885 | | - os.environ.pop("CUDA_LAUNCH_BLOCKING", None) |
886 | | - self._done = True |
887 | | - |
888 | 867 | callbacks.append(_FirstBatchCudaSync()) |
889 | 868 |
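
The hunk above deletes the inline `_FirstBatchCudaSync` class and instead imports it from `bionemo.evo2.utils.callbacks` (see the import change at the top of the diff). For reference, a minimal sketch of what the relocated callback presumably looks like in `bionemo/evo2/utils/callbacks.py`, reconstructed from the removed inline definition (the module path and exact docstring are assumptions, not confirmed by this diff):

```python
# Hypothetical sketch of the relocated callback, reconstructed from the
# inline class removed in this hunk; not the verified module contents.
import os

import torch
from lightning.pytorch.callbacks import Callback


class _FirstBatchCudaSync(Callback):
    """Adds CUDA barriers around the first training batch to avoid a startup race.

    See https://github.com/NVIDIA/bionemo-framework/issues/1301 for details.
    """

    def __init__(self):
        self._done = False

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # Barrier before the first forward pass.
        if not self._done and torch.cuda.is_available():
            torch.cuda.synchronize()

    def on_after_backward(self, trainer, pl_module):
        # Barrier after the first backward pass.
        if not self._done and torch.cuda.is_available():
            torch.cuda.synchronize()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Final barrier, then unset blocking so subsequent batches run asynchronously.
        if not self._done and torch.cuda.is_available():
            torch.cuda.synchronize()
            os.environ.pop("CUDA_LAUNCH_BLOCKING", None)
            self._done = True
```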
|
890 | 869 | if args.garbage_collect_at_inference: |
@@ -1115,15 +1094,6 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): |
1115 | 1094 | enable_checkpointing=args.create_checkpoint_callback, |
1116 | 1095 | ) |
1117 | 1096 |
|
1118 | | - # Logger setup |
1119 | | - nemo_logger.setup( |
1120 | | - trainer, |
1121 | | - resume_if_exists=True, |
1122 | | - ) |
1123 | | - |
1124 | | - if auto_resume is not None: |
1125 | | - auto_resume.setup(trainer, model) |
1126 | | - |
1127 | 1097 | # Optimizer and scheduler setup |
1128 | 1098 | opt_config = OptimizerConfig( |
1129 | 1099 | optimizer="adam", |
@@ -1151,12 +1121,8 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): |
1151 | 1121 | opt = MegatronOptimizerModule( |
1152 | 1122 | opt_config, sched, no_weight_decay_cond=getattr(model_config, "hyena_no_weight_decay_cond_fn", None) |
1153 | 1123 | ) |
1154 | | - opt.connect(model) |
1155 | | - |
1156 | | - # Remove earlier warmup and hook logic; first-batch blocking is sufficient. |
| 1124 | + llm.train(model, data_module, trainer, log=nemo_logger, resume=auto_resume, optim=opt, tokenizer="data") |
1157 | 1125 |
|
1158 | | - # Start training |
1159 | | - trainer.fit(model, data_module) |
1160 | 1126 | return trainer |
1161 | 1127 |
|
1162 | 1128 |
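
For context on the last two hunks: the explicit `nemo_logger.setup(...)`, `auto_resume.setup(...)`, `opt.connect(model)`, and `trainer.fit(...)` calls are dropped because the single `llm.train(...)` entry point performs that wiring itself. A rough before/after sketch, using names already defined inside `train()` in this script (this is a simplified equivalence, not the exact internals of `nemo.collections.llm.train`):

```python
# Before this PR: explicit wiring inside train(), as shown in the removed lines.
nemo_logger.setup(trainer, resume_if_exists=True)  # logger setup
if auto_resume is not None:
    auto_resume.setup(trainer, model)              # checkpoint auto-resume
opt.connect(model)                                 # attach the Megatron optimizer module
trainer.fit(model, data_module)                    # start training

# After this PR: one call that handles logger, resume, and optimizer wiring internally.
llm.train(
    model,
    data_module,
    trainer,
    log=nemo_logger,
    resume=auto_resume,
    optim=opt,
    tokenizer="data",  # take the tokenizer from the data module
)
```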
|
|