
Commit 294ddff

WanZzzzzz and qiyuw authored
Avoid host-device sync in PTL logging (#14489)
* remove sync in logging
  Signed-off-by: qiyuw <[email protected]>
* Apply isort and black reformatting
  Signed-off-by: WanZzzzzz <[email protected]>
* add class and func docstrings in data_sampler.py for pylint
  Signed-off-by: qiyuw <[email protected]>
* Apply isort and black reformatting
  Signed-off-by: WanZzzzzz <[email protected]>

---------

Signed-off-by: qiyuw <[email protected]>
Signed-off-by: WanZzzzzz <[email protected]>
Co-authored-by: qiyuw <[email protected]>
Co-authored-by: WanZzzzzz <[email protected]>
1 parent 0256c61 commit 294ddff
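
The whole commit applies a single pattern: before a scalar reaches a Lightning ``log()`` call, it is converted to a CUDA tensor via a pinned-memory staging tensor and a non-blocking copy, so the logging path no longer issues a blocking, pageable host-to-device memcpy. A minimal, self-contained sketch of that pattern (the helper name ``to_cuda_nonblocking`` is illustrative and not part of the commit):

import torch


def to_cuda_nonblocking(value):
    """Return ``value`` as a CUDA tensor without stalling the CPU on a pageable copy."""
    if torch.is_tensor(value) and value.is_cuda:
        # Already on the GPU: nothing to copy.
        return value
    # Staging the scalar in pinned (page-locked) host memory lets the
    # host-to-device copy run asynchronously instead of blocking the host thread.
    return torch.tensor(value, pin_memory=True).to("cuda", non_blocking=True)


if torch.cuda.is_available():
    global_step = to_cuda_nonblocking(1234)
    # A LightningModule would then log the GPU tensor directly, e.g.
    # self.log("global_step", global_step, prog_bar=True, batch_size=1)

Handing Lightning a tensor that is already on the GPU means the logger does not have to move the value itself, which is what the "pageable H2D Memcpy which stalls CPU" comment in the diff below refers to.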

File tree: 2 files changed (+88, -4 lines)


nemo/lightning/pytorch/plugins/data_sampler.py

File mode changed: 100644 → 100755
Lines changed: 55 additions & 1 deletion
@@ -17,23 +17,40 @@
 from typing import List, Literal, Optional
 
 import lightning.pytorch as pl
+import torch
 from torch.utils.data import DataLoader
 
 from nemo.lightning.megatron_parallel import MegatronStep
 
 
 class DataSampler:
+    """Abstract interface for data sampling and dataloader transformation.
+
+    Implementations can prepare state in ``setup`` and wrap/transform a
+    ``torch.utils.data.DataLoader`` in ``transform_dataloader`` to inject the
+    appropriate sampler for the active strategy.
+    """
+
     def connect(self, trainer: pl.Trainer):
+        """Attach the Lightning ``trainer`` to this sampler instance."""
         self.trainer = trainer
 
     def setup(self, global_rank: int) -> None:
+        """Initialize any sampler-related state for the given ``global_rank``."""
         raise NotImplementedError()
 
     def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0) -> DataLoader:
+        """Transform the dataloader."""
         raise NotImplementedError()
 
 
 class MegatronDataSampler(DataSampler):
+    """Megatron-LM data sampler.
+
+    Handles batch ramp-up, logging of consumed samples, and wiring Megatron's
+    microbatch/global-batch calculations into NeMo Lightning training.
+    """
+
     def __init__(
         self,
         seq_len: int,

@@ -60,11 +77,17 @@ def __init__(
         self.init_global_step = init_global_step
 
     def setup(self, global_rank: int) -> None:
+        """Initialize Megatron microbatch calculator for this process."""
         from nemo.lightning.data import setup_microbatch_calculator
 
         setup_microbatch_calculator(global_rank, self.micro_batch_size, self.global_batch_size, self.rampup_batch_size)
 
     def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0) -> DataLoader:
+        """Wrap the dataloader with a Megatron-aware sampler.
+
+        The sampler accounts for data-parallel rank/size, ramp-up schedule, and
+        train/validation/test modes.
+        """
         from megatron.core import parallel_state
 
         from nemo.lightning.data import add_megatron_sampler

@@ -87,6 +110,13 @@ def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0
         )
 
     def compute_consumed_samples(self, steps_since_resume=0) -> int:
+        """Compute the number of consumed samples since training start or resume.
+
+        If a ramp-up schedule is active, the value uses the previous and current
+        global batch sizes. Otherwise it is derived from
+        ``data_parallel_size * micro_batch_size * num_microbatches`` times the
+        number of steps since resume.
+        """
         from nemo.lightning.pytorch.strategies import MegatronStrategy
         from nemo.utils import AppState
 

@@ -107,6 +137,7 @@ def compute_consumed_samples(self, steps_since_resume=0) -> int:
     # Megatron callbacks
 
     def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
+        """Inject Megatron step configuration such as sequence length and batch sizes."""
         return dataclasses.replace(
             step,
             seq_length=self.seq_len,

@@ -116,6 +147,11 @@ def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
         )
 
     def on_megatron_microbatches_start(self, step: MegatronStep) -> None:
+        """Trigger a validation/checkpoint boundary when global batch size changes.
+
+        During batch-size ramp-up we stop the trainer at the boundary so that a
+        checkpoint can be saved and validation can run with the new batch size.
+        """
         if not step.trainer:
             return
 

@@ -128,6 +164,11 @@ def on_megatron_microbatches_start(self, step: MegatronStep) -> None:
             step.trainer.should_stop = True
 
     def on_megatron_step_end(self, step: MegatronStep) -> None:
+        """Log training metrics and update Megatron's microbatch calculator.
+
+        Logs ``consumed_samples`` and ``global_batch_size`` (GPU-friendly) and
+        updates Megatron's internal number of microbatches for the next step.
+        """
         trainer = step.trainer
         pl_module = step.pl_module
 

@@ -144,6 +185,12 @@ def on_megatron_step_end(self, step: MegatronStep) -> None:
         consumed_samples = self.compute_consumed_samples(step.step_i + 1 - self.init_global_step)
         if self.output_log and trainer and getattr(trainer, "training", False):
             # You may need to turn off logging, for example when doing trainer.predict(model, data)
+            # pl_module.log () will trigger pageable H2D Memcpy which stalls CPU. Use pin_memory=True to avoid it
+            consumed_samples = (
+                consumed_samples
+                if (torch.is_tensor(consumed_samples) and consumed_samples.is_cuda)
+                else torch.tensor(consumed_samples, pin_memory=True).to("cuda", non_blocking=True)
+            )
             pl_module.log(
                 'consumed_samples',
                 consumed_samples,

@@ -159,16 +206,22 @@ def on_megatron_step_end(self, step: MegatronStep) -> None:
             )
         if self.output_log and trainer:
             # You may need to turn off logging, for example when doing trainer.predict(model, data)
+            current_global_batch_size = (
+                self.current_global_batch_size
+                if (torch.is_tensor(self.current_global_batch_size) and self.current_global_batch_size.is_cuda)
+                else torch.tensor(self.current_global_batch_size, pin_memory=True).to("cuda", non_blocking=True)
+            )
             pl_module.log(
                 "global_batch_size",
-                self.current_global_batch_size,
+                current_global_batch_size,
                 prog_bar=True,
                 batch_size=1,
             )
         self.if_first_step = 1
 
     @property
     def num_microbatches(self) -> int:
+        """Return the current number of microbatches from Megatron."""
         try:
             from megatron.core.num_microbatches_calculator import get_num_microbatches
 

@@ -180,6 +233,7 @@ def num_microbatches(self) -> int:
 
     @property
     def current_global_batch_size(self) -> int:
+        """Return the current effective global batch size (fallback to 1)."""
         try:
             from megatron.core.num_microbatches_calculator import get_current_global_batch_size
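
As a quick illustration of the ``compute_consumed_samples`` docstring above, the no-ramp-up case is plain multiplication; the numbers below are assumed, illustrative values only:

# Illustrative values only; none of these numbers come from the commit.
data_parallel_size = 8     # data-parallel replicas
micro_batch_size = 4       # samples per microbatch per replica
num_microbatches = 16      # microbatches per global step
steps_since_resume = 100   # optimizer steps taken since the last resume

# Without a ramp-up schedule, consumed samples grow linearly with steps:
consumed_samples = data_parallel_size * micro_batch_size * num_microbatches * steps_since_resume
print(consumed_samples)  # 51200, i.e. a global batch size of 512 times 100 steps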

nemo/lightning/pytorch/strategies/megatron_strategy.py

File mode changed: 100644 → 100755
Lines changed: 33 additions & 3 deletions
@@ -752,24 +752,39 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
                 raise ValueError(f"Expected 'loss' in output dict, got {out.keys()}")
 
             reduced_train_loss = out["loss"]
-
+            # pl_module.log () will trigger pageable H2D Memcpy which stalls CPU. Use pin_memory=True to avoid it
+            global_step = (
+                self.trainer.global_step
+                if (torch.is_tensor(self.trainer.global_step) and self.trainer.global_step.is_cuda)
+                else torch.tensor(self.trainer.global_step, pin_memory=True).to("cuda", non_blocking=True)
+            )
             self.lightning_module.log(
                 "global_step",
-                self.trainer.global_step,
+                global_step,
                 prog_bar=True,
                 batch_size=1,
             )
 
             self.lightning_module.log(
                 "step",
-                self.trainer.global_step,
+                global_step,
             )
 
             if self.log_memory_usage:
                 # maximum GPU memory that has been managed by the caching allocator
                 max_memory_reserved = torch.cuda.max_memory_reserved()
+                max_memory_reserved = (
+                    max_memory_reserved
+                    if (torch.is_tensor(max_memory_reserved) and max_memory_reserved.is_cuda)
+                    else torch.tensor(max_memory_reserved, pin_memory=True).to("cuda", non_blocking=True)
+                )
                 # maximum GPU memory that has been occupied by active tensors
                 max_memory_allocated = torch.cuda.max_memory_allocated()
+                max_memory_allocated = (
+                    max_memory_allocated
+                    if (torch.is_tensor(max_memory_allocated) and max_memory_allocated.is_cuda)
+                    else torch.tensor(max_memory_allocated, pin_memory=True).to("cuda", non_blocking=True)
+                )
                 self.lightning_module.log(
                     "max_memory_reserved",
                     max_memory_reserved,

@@ -787,6 +802,11 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
                 # p2p now, broadcast later at ckpt. only with pp, some ranks will log 0.0
                 # WHICH IS OK because we broadcast later at checkpoint time
                 _strategy_lib._sync_from_last_pipeline_stage(reduced_train_loss, broadcast=False)
+                reduced_train_loss = (
+                    reduced_train_loss
+                    if (torch.is_tensor(reduced_train_loss) and reduced_train_loss.is_cuda)
+                    else torch.tensor(reduced_train_loss, pin_memory=True).to("cuda", non_blocking=True)
+                )
                 self.lightning_module.log(
                     "reduced_train_loss", reduced_train_loss, prog_bar=True, batch_size=1, sync_dist=False
                 )

@@ -813,8 +833,18 @@ def optimizer_step(
         if isinstance(optimizer, McoreDistributedOptimizer):
             optimizer_output, grad_norm, num_zeros_in_grad = optimizer_output
             if grad_norm is not None:
+                grad_norm = (
+                    grad_norm
+                    if (torch.is_tensor(grad_norm) and grad_norm.is_cuda)
+                    else torch.tensor(grad_norm, pin_memory=True).to("cuda", non_blocking=True)
+                )
                 self.lightning_module.log('grad_norm', grad_norm, batch_size=1)
             if num_zeros_in_grad is not None:
+                num_zeros_in_grad = (
+                    num_zeros_in_grad
+                    if (torch.is_tensor(num_zeros_in_grad) and num_zeros_in_grad.is_cuda)
+                    else torch.tensor(num_zeros_in_grad, pin_memory=True).to("cuda", non_blocking=True)
+                )
                 self.lightning_module.log('num_zeros_in_grad', num_zeros_in_grad, batch_size=1)
 
         return optimizer_output
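
One way to sanity-check the effect of this change is to profile both conversion variants and compare the host-to-device memcpy entries in the trace. A rough sketch using PyTorch's built-in profiler follows; it assumes a CUDA device, and the exact event names in the table depend on the PyTorch version:

import torch
from torch.profiler import ProfilerActivity, profile

assert torch.cuda.is_available(), "this check needs a CUDA device"


def pageable_copy(value):
    # Old behaviour: a plain CPU tensor moved to the GPU with a blocking, pageable copy.
    return torch.tensor(value).to("cuda")


def pinned_copy(value):
    # New behaviour: pinned staging tensor plus non-blocking copy.
    return torch.tensor(value, pin_memory=True).to("cuda", non_blocking=True)


with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for step in range(100):
        pinned_copy(float(step))  # swap in pageable_copy to compare the two traces
    torch.cuda.synchronize()

# Compare pageable vs. pinned "Memcpy HtoD" entries and their CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))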
