
Commit 85ec396

Pu Chen authored and facebook-github-bot committed
Enhance TrainPipelineSparseDist logging to help differentiate data loading patterns in train pipeline (#3350)
Summary: Pull Request resolved: #3350

Observed inconsistent data loading behavior in APS train_module_train_step: 3 batches are expected to load on the first invocation of the train loop, but sometimes only 1 batch load shows up in the trace ([link](https://www.internalfb.com/intern/sbdive/?id=tree%2Fttfb%2Fttfb_ai_lab_APS_mtml_ctr_cmf_rc1_baseline_gpu-f788555024-fbd033a0-89b2-4b72-b540-346901657b25-treatment-1&bucket=sbdive)), even after increasing the trace frequency from 500ms to 50ms. This change adds logs to differentiate the data loading patterns.

Perf impact: the added logs fire only when the data loader is exhausted.

Reviewed By: andywag

Differential Revision: D81418443

fbshipit-source-id: 98ccb5cb480bf31572e99b9796cf375ef676d125
1 parent 60f7f87 commit 85ec396
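
The new messages are emitted at INFO level through the module's logger, so they only show up if that logger is enabled. Below is a minimal sketch for surfacing them in a local run, assuming the module follows the usual logging.getLogger(__name__) pattern; the logger name in the snippet is that assumption, not something introduced by this change.

```python
import logging

# Send INFO-level records to the console so the new pipeline messages are visible.
logging.basicConfig(level=logging.INFO)

# Assumed logger name, derived from the module path via getLogger(__name__);
# adjust it if the module names its logger differently.
logging.getLogger(
    "torchrec.distributed.train_pipeline.train_pipelines"
).setLevel(logging.INFO)
```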

File tree

1 file changed: +17 -0 lines changed


torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 17 additions & 0 deletions
@@ -446,6 +446,11 @@ def __init__(
         self._apply_jit = apply_jit
         self._enqueue_batch_after_forward = enqueue_batch_after_forward
 
+        logger.info(
+            f"enqueue_batch_after_forward: {self._enqueue_batch_after_forward} "
+            f"execute_all_batches: {self._execute_all_batches}"
+        )
+
         if device.type == "cuda":
             # use two data streams to support two concurrent batches
             # Dynamo does not support cuda stream specificaiton,
@@ -624,6 +629,7 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
 
         # batch i, data (batch) and context
         if not self.enqueue_batch(dataloader_iter):
+            logger.info("fill_pipeline: failed to load batch i")
             return
 
         # modify the (sharded) sparse module forward, and invoke the first part of input_dist
@@ -637,6 +643,7 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
 
         # batch i+1
         if not self.enqueue_batch(dataloader_iter):
+            logger.info("fill_pipeline: failed to load batch i+1")
             return
 
     def _wait_for_batch(self) -> None:
@@ -801,7 +808,14 @@ def copy_batch_to_gpu(
         if batch is not None:
             batch = _to_device(batch, self._device, non_blocking=True)
         elif not self._execute_all_batches:
+            logger.info(
+                "copy_batch_to_gpu: raising StopIteration for None Batch (execute_all_batches=False)"
+            )
             raise StopIteration
+        else:
+            logger.info(
+                "copy_batch_to_gpu: returning None batch (execute_all_batches=True)"
+            )
         return batch, context
 
     def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
@@ -820,6 +834,9 @@ def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
         batch = next(dataloader_iter, None)
         if batch is None:
             self._dataloader_exhausted = True
+
+        if batch is None:
+            logger.info("_next_batch: dataloader exhausted")
         return batch
 
     def start_sparse_data_dist(
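
Taken together, the new messages make the patterns from the summary distinguishable after the fact: a run that only manages to load one batch should show "_next_batch: dataloader exhausted" followed by "fill_pipeline: failed to load batch i+1", while copy_batch_to_gpu now records whether a None batch raises StopIteration (execute_all_batches=False) or is returned as-is (execute_all_batches=True). The following is a simplified, standalone sketch of the fill_pipeline logging pattern, not the actual TrainPipelineSparseDist code; enqueue_batch, device copies, and the pipeline context are omitted.

```python
import logging
from typing import Iterator, Optional, TypeVar

In = TypeVar("In")
logger = logging.getLogger(__name__)


def _next_batch(dataloader_iter: Iterator[In]) -> Optional[In]:
    # Mirrors the patched _next_batch: log once when the iterator runs dry.
    batch = next(dataloader_iter, None)
    if batch is None:
        logger.info("_next_batch: dataloader exhausted")
    return batch


def fill_pipeline(dataloader_iter: Iterator[In]) -> None:
    # Mirrors the patched fill_pipeline: every early return is logged, so a
    # trace with a single loaded batch can be told apart from one where the
    # pipeline was never filled at all.
    if _next_batch(dataloader_iter) is None:
        logger.info("fill_pipeline: failed to load batch i")
        return
    if _next_batch(dataloader_iter) is None:
        logger.info("fill_pipeline: failed to load batch i+1")
        return


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    fill_pipeline(iter([object()]))  # only one batch available
    # Expected messages:
    #   _next_batch: dataloader exhausted
    #   fill_pipeline: failed to load batch i+1
```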
