added test code
wonchul committed Jan 10, 2024
1 parent eca9d8b commit f6461f6
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions _posts/pytorch/2024-01-09-torchrun.md
@@ -135,7 +135,7 @@ torch.distributed.barrier()

`BatchSampler` has no `shuffle` option of its own, so shuffling has to be enabled on the `DistributedSampler` that feeds it.
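
As a minimal sketch of that point (not part of the commit): the snippet below wraps a `DistributedSampler` in a `BatchSampler` and toggles `shuffle` only on the inner sampler. The toy `TensorDataset`, the constants `WORLD_SIZE`, `RANK`, and `GLOBAL_BATCH`, and the helper `rank_batches` are illustrative assumptions; `num_replicas` and `rank` are passed explicitly so the snippet runs in a single process without `torchrun`.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset of 16 samples; each sample's value equals its index.
dataset = TensorDataset(torch.arange(16))

# Hypothetical values. Under torchrun they would come from the process group.
WORLD_SIZE, RANK, GLOBAL_BATCH = 2, 0, 8


def rank_batches(shuffle, epoch=0):
    """Return the batches that rank RANK sees for the given epoch."""
    # Shuffling is decided here: BatchSampler has no shuffle argument.
    sampler = DistributedSampler(
        dataset, num_replicas=WORLD_SIZE, rank=RANK, shuffle=shuffle, seed=0
    )
    # With shuffle=True the permutation is derived from seed + epoch.
    sampler.set_epoch(epoch)
    # BatchSampler only groups the indices it receives into fixed-size lists.
    batch_sampler = BatchSampler(
        sampler, batch_size=GLOBAL_BATCH // WORLD_SIZE, drop_last=True
    )
    loader = DataLoader(dataset, batch_sampler=batch_sampler)
    return [batch[0].tolist() for batch in loader]


ordered = rank_batches(shuffle=False)
shuffled = rank_batches(shuffle=True)
print(ordered)   # [[0, 2, 4, 6], [8, 10, 12, 14]]
print(shuffled)  # same shape, order taken from a seeded permutation
```

With `shuffle=False`, rank 0 simply takes every second index in order. With `shuffle=True`, calling `set_epoch` with a new epoch number before each epoch changes the permutation; without that call every epoch reuses the same order.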



#### On Linux, the number of processors can be checked as follows; `num_workers` can be set to at most that number.
```cmd
@@ -184,18 +184,17 @@ if __name__ == '__main__':

train_dataset = SimpleDataset()
val_dataset = SimpleDataset()
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset=train_dataset, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(dataset=val_dataset, shuffle=False)

# case 1)
# train_sampler = torch.utils.data.distributed.DistributedSampler(dataset=train_dataset, shuffle=True)
# train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset,
#                                                batch_size=int(batch_size/args.world_size),
#                                                shuffle=False,
#                                                num_workers=int(num_workers/args.world_size),
#                                                sampler=train_sampler)

# case 2)
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset=train_dataset, shuffle=True)
train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, int(batch_size/args.world_size), drop_last=True)
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_sampler=train_batch_sampler,