
Commit d2efa75

Adding update that solves one logger issue
Signed-off-by: John St John <jstjohn@nvidia.com>
1 parent: eba4382

File tree: 2 files changed, 27 additions & 37 deletions


sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py

Lines changed: 3 additions & 37 deletions
@@ -24,7 +24,7 @@
 # TODO add back support for slurm resilience.
 # import nvidia_resiliency_ext.ptl_resiliency as res_module
 import torch
-from lightning.pytorch.callbacks import Callback, LearningRateMonitor, RichModelSummary
+from lightning.pytorch.callbacks import LearningRateMonitor, RichModelSummary
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.enums import Fp8Recipe
 from megatron.core.optimizer import OptimizerConfig
@@ -53,7 +53,7 @@
 from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel, mamba_no_weight_decay_cond_with_embeddings
 from bionemo.evo2.models.peft import Evo2LoRA
 from bionemo.evo2.run.utils import infer_model_type, lookup_activation_func, patch_eden_tokenizer
-from bionemo.evo2.utils.callbacks import GarbageCollectAtInferenceTime
+from bionemo.evo2.utils.callbacks import GarbageCollectAtInferenceTime, _FirstBatchCudaSync
 from bionemo.evo2.utils.config import hyena_no_weight_decay_cond_with_embeddings
 from bionemo.evo2.utils.logging.callbacks import TEVCallback
 from bionemo.llm.utils.datamodule_utils import infer_global_batch_size
@@ -864,27 +864,6 @@ def train(args: argparse.Namespace) -> nl.Trainer:
         TEVCallback(),
     ]
 
-    # First batch CUDA sync callback: adds barriers for the first training batch to avoid race condition
-    # See https://github.com/NVIDIA/bionemo-framework/issues/1301 for more details.
-    class _FirstBatchCudaSync(Callback):
-        def __init__(self):
-            self._done = False
-
-        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
-            if not self._done and torch.cuda.is_available():
-                torch.cuda.synchronize()
-
-        def on_after_backward(self, trainer, pl_module):
-            if not self._done and torch.cuda.is_available():
-                torch.cuda.synchronize()
-
-        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
-            if not self._done and torch.cuda.is_available():
-                torch.cuda.synchronize()
-                # Unset blocking for subsequent batches
-                os.environ.pop("CUDA_LAUNCH_BLOCKING", None)
-                self._done = True
-
     callbacks.append(_FirstBatchCudaSync())
 
     if args.garbage_collect_at_inference:
@@ -1115,15 +1094,6 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         enable_checkpointing=args.create_checkpoint_callback,
     )
 
-    # Logger setup
-    nemo_logger.setup(
-        trainer,
-        resume_if_exists=True,
-    )
-
-    if auto_resume is not None:
-        auto_resume.setup(trainer, model)
-
     # Optimizer and scheduler setup
     opt_config = OptimizerConfig(
         optimizer="adam",
@@ -1151,12 +1121,8 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
     opt = MegatronOptimizerModule(
         opt_config, sched, no_weight_decay_cond=getattr(model_config, "hyena_no_weight_decay_cond_fn", None)
     )
-    opt.connect(model)
-
-    # Remove earlier warmup and hook logic; first-batch blocking is sufficient.
+    llm.train(model, data_module, trainer, log=nemo_logger, resume=auto_resume, optim=opt, tokenizer="data")
 
-
-    trainer.fit(model, data_module)
     return trainer
 
 

sub-packages/bionemo-evo2/src/bionemo/evo2/utils/callbacks.py

Lines changed: 24 additions & 0 deletions
@@ -14,11 +14,35 @@
 # limitations under the License.
 
 import gc
+import os
 
 import torch
 from lightning.pytorch import Callback
 
 
+class _FirstBatchCudaSync(Callback):
+    # TEMPORARY CALLBACK. Remove once bug is fixed.
+    # First batch CUDA sync callback: adds barriers for the first training batch to avoid race condition
+    # See https://github.com/NVIDIA/bionemo-framework/issues/1301 for more details.
+    def __init__(self):
+        self._done = False
+
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+        if not self._done and torch.cuda.is_available():
+            torch.cuda.synchronize()
+
+    def on_after_backward(self, trainer, pl_module):
+        if not self._done and torch.cuda.is_available():
+            torch.cuda.synchronize()
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        if not self._done and torch.cuda.is_available():
+            torch.cuda.synchronize()
+            # Unset blocking for subsequent batches
+            os.environ.pop("CUDA_LAUNCH_BLOCKING", None)
+            self._done = True
+
+
 class GarbageCollectAtInferenceTime(Callback):
     """Callback to clean up CUDA memory before validation to prevent initialization errors."""
 
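For context, the sketch below shows one way the relocated _FirstBatchCudaSync callback could be attached to a plain Lightning Trainer. It is an illustration, not code from this commit: ToyModel and RandomDataset are hypothetical stand-ins, and it assumes lightning, torch, and the bionemo-evo2 package built from this commit are installed. In train.py itself the callback is simply appended to the existing callbacks list, as shown in the diff above.

# Minimal usage sketch (illustration only, not part of this commit).
# Assumes lightning, torch, and bionemo-evo2 (from this commit) are installed;
# ToyModel and RandomDataset are hypothetical stand-ins, not bionemo code.
import torch
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch import LightningModule, Trainer

from bionemo.evo2.utils.callbacks import _FirstBatchCudaSync


class RandomDataset(Dataset):
    """Tiny random-feature dataset so the example runs end to end."""

    def __init__(self, n: int = 64, dim: int = 32):
        self.data = torch.randn(n, dim)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class ToyModel(LightningModule):
    """One linear layer; only the callback behavior is of interest here."""

    def __init__(self, dim: int = 32):
        super().__init__()
        self.layer = torch.nn.Linear(dim, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).mean()  # dummy loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    trainer = Trainer(
        max_steps=2,
        callbacks=[_FirstBatchCudaSync()],  # syncs around the first batch, then becomes a no-op
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(ToyModel(), DataLoader(RandomDataset(), batch_size=8))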
