@@ -446,6 +446,11 @@ def __init__(
         self._apply_jit = apply_jit
         self._enqueue_batch_after_forward = enqueue_batch_after_forward
 
+        logger.info(
+            f"enqueue_batch_after_forward: {self._enqueue_batch_after_forward} "
+            f"execute_all_batches: {self._execute_all_batches}"
+        )
+
         if device.type == "cuda":
             # use two data streams to support two concurrent batches
             # Dynamo does not support cuda stream specification,
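These new messages are emitted at INFO level, so they only show up if the hosting application configures Python logging to surface them. A minimal sketch of one way to do that (the logger name below is an assumption, based on the usual `logging.getLogger(__name__)` pattern in this module):

    import logging

    # Surface INFO-level records so the new pipeline logs become visible.
    # The logger name is an assumption, not confirmed by this diff.
    logging.basicConfig(level=logging.INFO)
    logging.getLogger("torchrec.distributed.train_pipeline").setLevel(logging.INFO)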
@@ -624,6 +629,7 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
 
         # batch i, data (batch) and context
         if not self.enqueue_batch(dataloader_iter):
+            logger.info("fill_pipeline: failed to load batch i")
             return
 
         # modify the (sharded) sparse module forward, and invoke the first part of input_dist
@@ -637,6 +643,7 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
 
         # batch i+1
         if not self.enqueue_batch(dataloader_iter):
+            logger.info("fill_pipeline: failed to load batch i+1")
             return
 
     def _wait_for_batch(self) -> None:
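Both `fill_pipeline` hunks log when `enqueue_batch` returns False. Judging from the rest of the diff, `enqueue_batch` presumably wraps `copy_batch_to_gpu` and reports whether a batch was actually staged; a hedged sketch of that relationship (the body below is an assumption for illustration, not the real implementation):

    def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool:
        # Sketch only: copy_batch_to_gpu returns (None, context) when
        # execute_all_batches=True and the data is gone; its StopIteration
        # (execute_all_batches=False case) propagates to the caller.
        batch, context = self.copy_batch_to_gpu(dataloader_iter)
        if batch is None:
            return False  # nothing staged; fill_pipeline logs and returns
        self.batches.append(batch)
        self.contexts.append(context)
        return True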
@@ -801,7 +808,14 @@ def copy_batch_to_gpu(
         if batch is not None:
             batch = _to_device(batch, self._device, non_blocking=True)
         elif not self._execute_all_batches:
+            logger.info(
+                "copy_batch_to_gpu: raising StopIteration for None Batch (execute_all_batches=False)"
+            )
             raise StopIteration
+        else:
+            logger.info(
+                "copy_batch_to_gpu: returning None batch (execute_all_batches=True)"
+            )
         return batch, context
 
     def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
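The two branches encode the two end-of-data behaviors controlled by the flag logged in `__init__`: `execute_all_batches=False` aborts immediately via `StopIteration`, while `execute_all_batches=True` hands back a `None` batch so batches already in flight can drain. A standalone illustration of the same pattern (hypothetical names, not torchrec API):

    from typing import Iterator, Optional

    def copy_one(it: Iterator[int], execute_all_batches: bool) -> Optional[int]:
        batch = next(it, None)
        if batch is not None:
            return batch          # normal case: pass the batch onward
        if not execute_all_batches:
            raise StopIteration   # stop the whole pipeline right away
        return None               # drain: let already-enqueued batches finish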
@@ -820,6 +834,9 @@ def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
         batch = next(dataloader_iter, None)
         if batch is None:
             self._dataloader_exhausted = True
+
+        if batch is None:
+            logger.info("_next_batch: dataloader exhausted")
         return batch
 
     def start_sparse_data_dist(
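`_next_batch` relies on the two-argument form of `next`, which returns the default instead of raising `StopIteration` when the iterator runs out; the `_dataloader_exhausted` flag then records that state, and the new log line fires on the same condition. The sentinel behavior in isolation:

    it = iter([1, 2])
    assert next(it, None) == 1
    assert next(it, None) == 2
    assert next(it, None) is None  # exhausted: default returned, no StopIteration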