Skip to content

Commit

Permalink
Fix regex get_data_files formatting for base paths (#6322)
Browse files Browse the repository at this point in the history
* Fix regex formatting for URL base_path

* Test test_get_data_patterns from Hub

* Simply match the basename instead

* more tests

* minor

* remove comment

---------

Co-authored-by: Albert Villanova del Moral <[email protected]>
Co-authored-by: Quentin Lhoest <[email protected]>
Co-authored-by: Quentin Lhoest <[email protected]>
  • Loading branch information
4 people authored Oct 23, 2023
1 parent d82f3c2 commit 02ecc84
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 9 deletions.
6 changes: 4 additions & 2 deletions src/datasets/data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,10 @@ def _get_data_files_patterns(
except FileNotFoundError:
continue
if len(data_files) > 0:
pattern = base_path + ("/" if base_path else "") + split_pattern
splits: Set[str] = {string_to_dict(p, glob_pattern_to_regex(pattern))["split"] for p in data_files}
splits: Set[str] = {
string_to_dict(xbasename(p), glob_pattern_to_regex(xbasename(split_pattern)))["split"]
for p in data_files
}
sorted_splits = [str(split) for split in DEFAULT_SPLITS if split in splits] + sorted(
splits - set(DEFAULT_SPLITS)
)
Expand Down
33 changes: 26 additions & 7 deletions tests/test_data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,10 @@ def mock_fs(file_paths: List[str]):
["data", "data/train.txt", "data.test.txt"]
```
"""

dir_paths = {file_path.rsplit("/", 1)[0] for file_path in file_paths if "/" in file_path}
file_paths = [file_path.split("://")[-1] for file_path in file_paths]
dir_paths = {
"/".join(file_path.split("/")[: i + 1]) for file_path in file_paths for i in range(file_path.count("/"))
}
fs_contents = [{"name": dir_path, "type": "directory"} for dir_path in dir_paths] + [
{"name": file_path, "type": "file", "size": 10} for file_path in file_paths
]
Expand All @@ -529,6 +531,7 @@ def ls(self, path, detail=True, refresh=True, **kwargs):
return DummyTestFS


@pytest.mark.parametrize("base_path", ["", "mock://", "my_dir"])
@pytest.mark.parametrize(
"data_file_per_split",
[
Expand Down Expand Up @@ -598,20 +601,36 @@ def ls(self, path, detail=True, refresh=True, **kwargs):
{"test": "test00001.txt"},
],
)
def test_get_data_files_patterns(data_file_per_split):
def test_get_data_files_patterns(base_path, data_file_per_split):
data_file_per_split = {k: v if isinstance(v, list) else [v] for k, v in data_file_per_split.items()}
file_paths = [file_path for split_file_paths in data_file_per_split.values() for file_path in split_file_paths]
data_file_per_split = {
split: [
base_path + ("/" if base_path and base_path[-1] != "/" else "") + file_path
for file_path in data_file_per_split[split]
]
for split in data_file_per_split
}
file_paths = sum(data_file_per_split.values(), [])
DummyTestFS = mock_fs(file_paths)
fs = DummyTestFS()

def resolver(pattern):
return [file_path for file_path in fs.glob(pattern) if fs.isfile(file_path)]
pattern = base_path + ("/" if base_path and base_path[-1] != "/" else "") + pattern
return [
file_path[len(fs._strip_protocol(base_path)) :].lstrip("/")
for file_path in fs.glob(pattern)
if fs.isfile(file_path)
]

patterns_per_split = _get_data_files_patterns(resolver)
patterns_per_split = _get_data_files_patterns(resolver, base_path=base_path)
assert list(patterns_per_split.keys()) == list(data_file_per_split.keys()) # Test split order with list()
for split, patterns in patterns_per_split.items():
matched = [file_path for pattern in patterns for file_path in resolver(pattern)]
assert matched == data_file_per_split[split]
expected = [
fs._strip_protocol(file_path)[len(fs._strip_protocol(base_path)) :].lstrip("/")
for file_path in data_file_per_split[split]
]
assert matched == expected


@pytest.mark.parametrize(
Expand Down
20 changes: 20 additions & 0 deletions tests/test_upstream_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
load_dataset_builder,
)
from datasets.config import METADATA_CONFIGS_FIELD
from datasets.data_files import get_data_patterns
from datasets.packaged_modules.folder_based_builder.folder_based_builder import (
FolderBasedBuilder,
FolderBasedBuilderConfig,
Expand Down Expand Up @@ -884,3 +885,22 @@ def test_load_dataset_with_metadata_file(self, temporary_repo, text_file_with_me
generator = builder._generate_examples(**gen_kwargs)
result = [example for _, example in generator]
assert len(result) == 1

def test_get_data_patterns(self, temporary_repo, tmp_path):
"""Check that get_data_patterns infers the 'train' split pattern for a sharded file hosted on the Hub."""
# Build a local repo layout: data/train-00001-of-00009.parquet (empty file — only the name matters).
repo_dir = tmp_path / "test_get_data_patterns"
data_dir = repo_dir / "data"
data_dir.mkdir(parents=True)
data_file = data_dir / "train-00001-of-00009.parquet"
data_file.touch()
with temporary_repo() as repo_id:
# NOTE(review): requires Hub access — self._api/self._token presumably come from the test class setup; confirm.
self._api.create_repo(repo_id, token=self._token, repo_type="dataset")
self._api.upload_folder(
folder_path=str(repo_dir),
repo_id=repo_id,
repo_type="dataset",
token=self._token,
)
# Resolve patterns against the hf:// base path; the shard numbers must be
# generalized to [0-9] character classes rather than kept literal.
data_file_patterns = get_data_patterns(f"hf://datasets/{repo_id}")
assert data_file_patterns == {
"train": ["data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"]
}

0 comments on commit 02ecc84

Please sign in to comment.