diff --git a/torchglyph/data/sampler.py b/torchglyph/data/sampler.py index d2ee2bf..2499061 100644 --- a/torchglyph/data/sampler.py +++ b/torchglyph/data/sampler.py @@ -39,7 +39,12 @@ def __iter__(self): class RandomSortishSampler(_SortishSampler): def __iter__(self): - idx, key, reverse = [], [], True + idx, key = [], [] + + if distributed.is_initialized(): + reverse = distributed.get_rank() % 2 == 0 + else: + reverse = True while True: for batch in self.ds.shuffle().iter(batch_size=self.section_size, drop_last_batch=False):