Skip to content

Commit

Permalink
Use yaml instead of get data patterns when possible (#6154)
Browse files Browse the repository at this point in the history
* use yaml data_files instead of get_data_patterns when possible

* minor fix docstring

* update comment
  • Loading branch information
lhoestq authored Aug 17, 2023
1 parent 546c7bb commit 5ca2ba0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/datasets/data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def resolve_pattern(
allowed_extensions (Optional[list], optional): White-list of file extensions to use. Defaults to None (all extensions).
For example: allowed_extensions=[".csv", ".json", ".txt", ".parquet"]
Returns:
List[Union[Path, Url]]: List of paths or URLs to the local or remote files that match the patterns.
List[str]: List of paths or URLs to the local or remote files that match the patterns.
"""
if is_relative_path(pattern):
pattern = xjoin(base_path, pattern)
Expand Down
24 changes: 16 additions & 8 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,10 +848,15 @@ def get_module(self) -> DatasetModule:
dataset_card_data = DatasetCard.load(readme_path).data if os.path.isfile(readme_path) else DatasetCardData()
metadata_configs = MetadataConfigs.from_dataset_card_data(dataset_card_data)
dataset_infos = DatasetInfosDict.from_dataset_card_data(dataset_card_data)
# even if metadata_configs_dict is not None (which means that we will resolve files for each config later)
# we cannot skip resolving all files because we need to infer module name by files extensions
# we need a set of data files to find which dataset builder to use
# because we need to infer module name by files extensions
base_path = Path(self.path, self.data_dir or "").expanduser().resolve().as_posix()
patterns = sanitize_patterns(self.data_files) if self.data_files is not None else get_data_patterns(base_path)
if self.data_files is not None:
patterns = sanitize_patterns(self.data_files)
if metadata_configs and "data_files" in next(iter(metadata_configs.values())):
patterns = sanitize_patterns(next(iter(metadata_configs.values()))["data_files"])
else:
patterns = get_data_patterns(base_path)
data_files = DataFilesDict.from_patterns(
patterns,
base_path=base_path,
Expand Down Expand Up @@ -1027,11 +1032,14 @@ def get_module(self) -> DatasetModule:
dataset_card_data = DatasetCardData()
metadata_configs = MetadataConfigs.from_dataset_card_data(dataset_card_data)
dataset_infos = DatasetInfosDict.from_dataset_card_data(dataset_card_data)
patterns = (
sanitize_patterns(self.data_files)
if self.data_files is not None
else get_data_patterns(base_path, download_config=self.download_config)
)
# we need a set of data files to find which dataset builder to use
# because we need to infer module name by files extensions
if self.data_files is not None:
patterns = sanitize_patterns(self.data_files)
if metadata_configs and "data_files" in next(iter(metadata_configs.values())):
patterns = sanitize_patterns(next(iter(metadata_configs.values()))["data_files"])
else:
patterns = get_data_patterns(base_path, download_config=self.download_config)
data_files = DataFilesDict.from_patterns(
patterns,
base_path=base_path,
Expand Down

0 comments on commit 5ca2ba0

Please sign in to comment.