from typing import List, Literal, Optional

import lightning.pytorch as pl
+ import torch
from torch.utils.data import DataLoader

from nemo.lightning.megatron_parallel import MegatronStep


class DataSampler:
+     """Abstract interface for data sampling and dataloader transformation.
+
+     Implementations can prepare state in ``setup`` and wrap/transform a
+     ``torch.utils.data.DataLoader`` in ``transform_dataloader`` to inject the
+     appropriate sampler for the active strategy.
+     """
+
    def connect(self, trainer: pl.Trainer):
+         """Attach the Lightning ``trainer`` to this sampler instance."""
        self.trainer = trainer

    def setup(self, global_rank: int) -> None:
+         """Initialize any sampler-related state for the given ``global_rank``."""
        raise NotImplementedError()

    def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0) -> DataLoader:
+         """Transform the given dataloader, e.g. by injecting the sampler appropriate for the active strategy."""
        raise NotImplementedError()
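For illustration, a minimal concrete implementation of this interface could look like the sketch below (not part of this change; the class name and behavior are made up):

class NoOpDataSampler(DataSampler):
    """Hypothetical sampler that leaves the dataloader untouched."""

    def setup(self, global_rank: int) -> None:
        # Record the rank; a real implementation would build per-rank sampler state here.
        self.global_rank = global_rank

    def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0) -> DataLoader:
        # Return the dataloader unchanged; a real implementation would attach a
        # distributed-aware sampler that resumes from `consumed_samples`.
        return dataloader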


class MegatronDataSampler(DataSampler):
+     """Megatron-LM data sampler.
+
+     Handles batch ramp-up, logging of consumed samples, and wiring Megatron's
+     microbatch/global-batch calculations into NeMo Lightning training.
+     """
+
    def __init__(
        self,
        seq_len: int,
@@ -60,11 +77,17 @@ def __init__(
        self.init_global_step = init_global_step

    def setup(self, global_rank: int) -> None:
+         """Initialize the Megatron microbatch calculator for this process."""
        from nemo.lightning.data import setup_microbatch_calculator

        setup_microbatch_calculator(global_rank, self.micro_batch_size, self.global_batch_size, self.rampup_batch_size)
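As a usage sketch (not part of this change; the argument values are illustrative, the keyword names mirror the attributes used above, and `trainer`/`global_rank` are assumed to already exist), the sampler is typically constructed and wired up like this:

data_sampler = MegatronDataSampler(
    seq_len=2048,
    micro_batch_size=2,
    global_batch_size=16,
    rampup_batch_size=None,  # optionally a ramp-up schedule for the global batch size
)
data_sampler.connect(trainer)    # trainer: an existing lightning.pytorch.Trainer
data_sampler.setup(global_rank)  # this process's rank; initializes the microbatch calculator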

    def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0) -> DataLoader:
+         """Wrap the dataloader with a Megatron-aware sampler.
+
+         The sampler accounts for data-parallel rank/size, ramp-up schedule, and
+         train/validation/test modes.
+         """
        from megatron.core import parallel_state

        from nemo.lightning.data import add_megatron_sampler
@@ -87,6 +110,13 @@ def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0
        )
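In use, the owning strategy passes each of its dataloaders through this hook; a hedged sketch (not part of this change; `my_dataset` and `data_sampler` are placeholders assumed to exist):

plain_loader = DataLoader(my_dataset, num_workers=2)  # ordinary, non-distributed loader
sharded_loader = data_sampler.transform_dataloader(plain_loader, consumed_samples=0)
# `sharded_loader` now yields only this data-parallel rank's shard, following the
# micro/global batch schedule configured above.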

    def compute_consumed_samples(self, steps_since_resume=0) -> int:
+         """Compute the number of consumed samples since training start or resume.
+
+         If a ramp-up schedule is active, the value uses the previous and current
+         global batch sizes. Otherwise it is derived from
+         ``data_parallel_size * micro_batch_size * num_microbatches`` times the
+         number of steps since resume.
+         """
        from nemo.lightning.pytorch.strategies import MegatronStrategy
        from nemo.utils import AppState

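To make the non-ramp-up branch concrete, a small worked example (the numbers are arbitrary, not from this change):

# 8 data-parallel replicas, each running 4 microbatches of 2 samples per step:
data_parallel_size, micro_batch_size, num_microbatches = 8, 2, 4
steps_since_resume = 100
consumed = data_parallel_size * micro_batch_size * num_microbatches * steps_since_resume
print(consumed)  # 6400 samples consumed over the 100 steps since resume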
@@ -107,6 +137,7 @@ def compute_consumed_samples(self, steps_since_resume=0) -> int:
    # Megatron callbacks

    def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
+         """Inject Megatron step configuration such as sequence length and batch sizes."""
        return dataclasses.replace(
            step,
            seq_length=self.seq_len,
@@ -116,6 +147,11 @@ def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
        )

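Since this hook relies on dataclasses.replace, which returns a copy of a dataclass with the named fields overridden, here is a minimal standalone illustration (not part of this change; the Step type is made up):

import dataclasses

@dataclasses.dataclass
class Step:
    seq_length: int = 0
    micro_batch_size: int = 1

step = Step()
new_step = dataclasses.replace(step, seq_length=2048)
print(new_step.seq_length, step.seq_length)  # 2048 0 -- the original is left untouched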
    def on_megatron_microbatches_start(self, step: MegatronStep) -> None:
+         """Trigger a validation/checkpoint boundary when global batch size changes.
+
+         During batch-size ramp-up we stop the trainer at the boundary so that a
+         checkpoint can be saved and validation can run with the new batch size.
+         """
        if not step.trainer:
            return

@@ -128,6 +164,11 @@ def on_megatron_microbatches_start(self, step: MegatronStep) -> None:
            step.trainer.should_stop = True

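Setting should_stop is Lightning's standard way to request a graceful stop: the trainer exits the fit loop cleanly at the next boundary. A standalone sketch of the same mechanism (not part of this change; the callback name and threshold are made up):

class StopAtStep(pl.Callback):
    """Hypothetical callback that stops training once a step budget is reached."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if trainer.global_step >= 100:  # illustrative boundary condition
            trainer.should_stop = True  # Lightning leaves fit() cleanly after this step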
    def on_megatron_step_end(self, step: MegatronStep) -> None:
+         """Log training metrics and update Megatron's microbatch calculator.
+
+         Logs ``consumed_samples`` and ``global_batch_size`` as GPU tensors and
+         updates Megatron's internal number of microbatches for the next step.
+         """
        trainer = step.trainer
        pl_module = step.pl_module

@@ -144,6 +185,12 @@ def on_megatron_step_end(self, step: MegatronStep) -> None:
        consumed_samples = self.compute_consumed_samples(step.step_i + 1 - self.init_global_step)
        if self.output_log and trainer and getattr(trainer, "training", False):
            # You may need to turn off logging, for example when doing trainer.predict(model, data)
+             # pl_module.log() triggers a pageable H2D memcpy that stalls the CPU; stage the value in pinned memory instead
+             consumed_samples = (
+                 consumed_samples
+                 if (torch.is_tensor(consumed_samples) and consumed_samples.is_cuda)
+                 else torch.tensor(consumed_samples, pin_memory=True).to("cuda", non_blocking=True)
+             )
            pl_module.log(
                'consumed_samples',
                consumed_samples,
@@ -159,16 +206,22 @@ def on_megatron_step_end(self, step: MegatronStep) -> None:
            )
        if self.output_log and trainer:
            # You may need to turn off logging, for example when doing trainer.predict(model, data)
+             current_global_batch_size = (
+                 self.current_global_batch_size
+                 if (torch.is_tensor(self.current_global_batch_size) and self.current_global_batch_size.is_cuda)
+                 else torch.tensor(self.current_global_batch_size, pin_memory=True).to("cuda", non_blocking=True)
+             )
            pl_module.log(
                "global_batch_size",
-                 self.current_global_batch_size,
+                 current_global_batch_size,
                prog_bar=True,
                batch_size=1,
            )
        self.if_first_step = 1
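The pinned-memory staging used twice above can be summarized in isolation; non_blocking only has an effect when the source host tensor lives in pinned (page-locked) memory (a sketch, not part of this change; the value is arbitrary):

value = 123                                    # e.g. a Python scalar metric
staged = torch.tensor(value, pin_memory=True)  # host tensor allocated in page-locked memory
on_gpu = staged.to("cuda", non_blocking=True)  # asynchronous H2D copy; the CPU is not stalled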

    @property
    def num_microbatches(self) -> int:
+         """Return the current number of microbatches from Megatron."""
        try:
            from megatron.core.num_microbatches_calculator import get_num_microbatches

@@ -180,6 +233,7 @@ def num_microbatches(self) -> int:

    @property
    def current_global_batch_size(self) -> int:
+         """Return the current effective global batch size (falls back to 1)."""
        try:
            from megatron.core.num_microbatches_calculator import get_current_global_batch_size