Skip to content

Commit

Permalink
Preserve split order in DataFilesDict (#6198)
Browse files Browse the repository at this point in the history
* Test split order in DataFilesDict

* Remove key sorting in DataFilesDict

* Fix test_cache_dir_for_data_files
  • Loading branch information
albertvillanova authored Aug 31, 2023
1 parent 0f8c580 commit 00cb5cc
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 15 deletions.
11 changes: 0 additions & 11 deletions src/datasets/data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,17 +682,6 @@ def from_patterns(
)
return out

def __reduce__(self):
    """
    Rebuild the object from a plain dict whose items are sorted, so that the
    order of the keys doesn't matter when pickling and hashing:
    >>> from datasets.data_files import DataFilesDict
    >>> from datasets.fingerprint import Hasher
    >>> assert Hasher.hash(DataFilesDict(a=[], b=[])) == Hasher.hash(DataFilesDict(b=[], a=[]))
    """
    # Sorting the (key, value) pairs makes the reduced form independent of
    # the order in which the keys were inserted.
    ordered_items = sorted(self.items())
    reconstruct_args = (dict(ordered_items),)
    return DataFilesDict, reconstruct_args

def filter_extensions(self, extensions: List[str]) -> "DataFilesDict":
out = type(self)()
for key, data_files_list in self.items():
Expand Down
4 changes: 0 additions & 4 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,10 +759,6 @@ def test_cache_dir_for_data_files(self):
cache_dir=tmp_dir, data_files={"train": [dummy_data1], "test": dummy_data2}
)
self.assertEqual(builder.cache_dir, other_builder.cache_dir)
other_builder = DummyGeneratorBasedBuilder(
cache_dir=tmp_dir, data_files={"test": dummy_data2, "train": dummy_data1}
)
self.assertEqual(builder.cache_dir, other_builder.cache_dir)
other_builder = DummyGeneratorBasedBuilder(
cache_dir=tmp_dir, data_files={"train": dummy_data1, "validation": dummy_data2}
)
Expand Down
8 changes: 8 additions & 0 deletions tests/test_data_files.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import os
from pathlib import Path, PurePath
from typing import List
Expand Down Expand Up @@ -385,6 +386,13 @@ def test_DataFilesList_from_patterns_raises_FileNotFoundError(complex_data_dir):
DataFilesList.from_patterns(["file_that_doesnt_exist.txt"], complex_data_dir)


class TestDataFilesDict:
    """Tests for ``DataFilesDict``."""

    def test_key_order_after_copy(self):
        """Check that a deep copy keeps the keys in their original order."""
        original = DataFilesDict({"train": "train.csv", "test": "test.csv"})
        duplicate = copy.deepcopy(original)
        # Compare lists rather than key views: key views compare like sets,
        # which would ignore ordering.
        expected_order = list(original.keys())
        copied_order = list(duplicate.keys())
        assert copied_order == expected_order


@pytest.mark.parametrize("pattern", _TEST_PATTERNS)
def test_DataFilesDict_from_patterns_in_dataset_repository(
hub_dataset_repo_path, hub_dataset_repo_patterns_results, pattern
Expand Down

0 comments on commit 00cb5cc

Please sign in to comment.