Description
Hi!
I'm training on a large-scale dataset (100+ Parquet files) with lazy loading, and I'm struggling to understand and optimize the `num_shards` setting in the lerobot repo's `streaming_datasets.py`:
```python
from datasets import load_dataset

self.hf_dataset: datasets.IterableDataset = load_dataset(
    self.repo_id if not self.streaming_from_local else str(self.root),
    split="train",
    streaming=self.streaming,
    data_files="data/*/*.parquet",
    revision=self.revision,
)
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
```
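If I understand the worker/shard interaction correctly, an `IterableDataset` distributes whole shards across `DataLoader` workers, so any worker beyond the shard count sits idle. Here is a minimal sketch of that assignment as I picture it (`shards_per_worker` is my own illustration, not a `datasets` API):

```python
def shards_per_worker(num_shards: int, num_workers: int) -> list[list[int]]:
    """Round-robin sketch of shard-to-worker assignment: worker w gets
    shards w, w + num_workers, w + 2 * num_workers, ...

    Workers with index >= num_shards receive no shards at all, which is
    (I assume) what triggers the "Too many dataloader workers" warning.
    """
    return [list(range(w, num_shards, num_workers)) for w in range(num_workers)]


# 100 parquet files, 8 workers: every worker has shards to read.
print(shards_per_worker(100, 8)[0][:3])  # → [0, 8, 16]

# 3 shards, 4 workers: worker 3 is left empty-handed.
print(shards_per_worker(3, 4))  # → [[0], [1], [2], []]
```

Is this roughly the right mental model for how `num_shards` limits useful parallelism?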
```python
dataloader = torch.utils.data.DataLoader(
    datasets[sub_idx],
    num_workers=datasets[sub_idx].num_shards,  # cfg.num_workers
    batch_size=cfg.batch_size,
    shuffle=shuffle and not cfg.dataset.streaming,
    sampler=sampler,
    collate_fn=FlowerDataCollator(),
    pin_memory=device.type == "cuda",
    drop_last=True,
    prefetch_factor=2 if cfg.num_workers > 0 else None,
)
```
What exactly does `hf_dataset.num_shards` represent? Is it safe to manually override it?
My batch loading is slower than expected (2-3 s per batch), and I can't raise `num_workers` because of this warning: `Too many dataloader workers: 4 (max is dataset.num_shards=3). Stopping 1 dataloader workers.`
Even when I pass `num_workers=datasets[sub_idx].num_shards`, the warning still appears! (My `num_workers` is 4 and `hf_dataset.num_shards` is 100+, so I'd expect `datasets[sub_idx].num_shards` to be 4.)
Why does the "too many workers" warning persist even when `num_workers` equals `dataset.num_shards`, and how do I fix this?
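For context, the workaround I'm currently considering is clamping the worker count to the shard count before building the `DataLoader` (`pick_num_workers` is my own helper, not from lerobot), though I'd rather fix the underlying shard count:

```python
def pick_num_workers(requested: int, num_shards: int) -> int:
    """Clamp the DataLoader worker count so no worker is left without a
    shard to read, while keeping at least one worker."""
    return max(1, min(requested, num_shards))


print(pick_num_workers(4, 3))    # → 3 (avoids the warning, but loses a worker)
print(pick_num_workers(4, 100))  # → 4 (plenty of shards, use all workers)
```

Is this the intended way to handle it, or should I instead avoid whatever is collapsing the per-sub-dataset shard count down to 3 in the first place?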
Thanks so much for any insights or help with this! Really appreciate your time and expertise 😊