diff --git a/tensorflow_datasets/datasets/robonet/robonet_dataset_builder.py b/tensorflow_datasets/datasets/robonet/robonet_dataset_builder.py index 43ff94c8344..4cd448fbe40 100644 --- a/tensorflow_datasets/datasets/robonet/robonet_dataset_builder.py +++ b/tensorflow_datasets/datasets/robonet/robonet_dataset_builder.py @@ -165,28 +165,29 @@ def _build_pcollection(self, pipeline, filedir): """Generate examples as dicts.""" beam = tfds.core.lazy_imports.apache_beam - def _process_example(filename): - """Converts one video from hdf5 format.""" - h5py = tfds.core.lazy_imports.h5py - with h5py.File(filename) as hf: - video_bytes = hf['env']['cam0_video']['frames'][:].tobytes() - states = hf['env']['state'][:].astype(np.float32) - states = np.pad( - states, ((0, 0), (0, STATES_DIM - states.shape[1])), 'constant' - ) - actions = hf['policy']['actions'][:].astype(np.float32) - actions = np.pad( - actions, ((0, 0), (0, ACTIONS_DIM - actions.shape[1])), 'constant' - ) - - basename = os.path.basename(filename) - features = { - 'video': video_bytes, - 'actions': actions, - 'states': states, - 'filename': basename, - } - return basename, features - filenames = tf.io.gfile.glob(os.path.join(filedir, '*.hdf5')) return pipeline | beam.Create(filenames) | beam.Map(_process_example) + + +def _process_example(filename): + """Converts one video from hdf5 format.""" + h5py = tfds.core.lazy_imports.h5py + with h5py.File(filename) as hf: + video_bytes = hf['env']['cam0_video']['frames'][:].tobytes() + states = hf['env']['state'][:].astype(np.float32) + states = np.pad( + states, ((0, 0), (0, STATES_DIM - states.shape[1])), 'constant' + ) + actions = hf['policy']['actions'][:].astype(np.float32) + actions = np.pad( + actions, ((0, 0), (0, ACTIONS_DIM - actions.shape[1])), 'constant' + ) + + basename = os.path.basename(filename) + features = { + 'video': video_bytes, + 'actions': actions, + 'states': states, + 'filename': basename, + } + return basename, features diff --git a/tensorflow_datasets/rl_unplugged/rlu_rwrl/rlu_rwrl.py b/tensorflow_datasets/rl_unplugged/rlu_rwrl/rlu_rwrl.py index 44a39d9cc69..fc40178f3a0 100644 --- a/tensorflow_datasets/rl_unplugged/rlu_rwrl/rlu_rwrl.py +++ b/tensorflow_datasets/rl_unplugged/rlu_rwrl/rlu_rwrl.py @@ -245,6 +245,31 @@ def tf_feature_to_tfds_feature( raise ValueError(f'Unsupported type {type(nested)}') +def _generate_examples_one_file_fn( + path, + feature_description, + tf_example_to_step_ds_fn, +) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + """Yields examples from one file.""" + counter = 0 + key_prefix = os.path.basename(path) + # Dataset of tf.Examples containing full episodes. + example_ds = tf.data.TFRecordDataset(filenames=str(path)) + # Dataset of episodes, each represented as a dataset of steps. + episode_ds = example_ds.map( + functools.partial( + tf_example_to_step_ds_fn, + feature_description=feature_description, + ), + num_parallel_calls=tf.data.experimental.AUTOTUNE, + ) + episode_ds = tfds.as_numpy(episode_ds) + for e in episode_ds: + episode_id = counter + yield f'{key_prefix}/{episode_id}', e + counter += 1 + + class RluRwrl(rlu_common.RLUBuilder): """DatasetBuilder for rlu_rwrl dataset.""" @@ -368,26 +393,8 @@ def _generate_examples(self, paths): feature_description = tf_example_to_feature_description(example_item) - def _generate_examples_one_file( - path, - ) -> Generator[Tuple[str, Dict[str, Any]], None, None]: - """Yields examples from one file.""" - counter = 0 - key_prefix = os.path.basename(path) - # Dataset of tf.Examples containing full episodes. - example_ds = tf.data.TFRecordDataset(filenames=str(path)) - # Dataset of episodes, each represented as a dataset of steps. - episode_ds = example_ds.map( - functools.partial( - self.tf_example_to_step_ds, - feature_description=feature_description, - ), - num_parallel_calls=tf.data.experimental.AUTOTUNE, - ) - episode_ds = tfds.as_numpy(episode_ds) - for e in episode_ds: - episode_id = counter - yield f'{key_prefix}/{episode_id}', e - counter += 1 - - return beam.Create(file_paths) | beam.FlatMap(_generate_examples_one_file) + return beam.Create(file_paths) | beam.FlatMap( + _generate_examples_one_file_fn, + feature_description=feature_description, + tf_example_to_step_ds_fn=self.tf_example_to_step_ds, + ) diff --git a/tensorflow_datasets/robotics/dataset_importer_builder.py b/tensorflow_datasets/robotics/dataset_importer_builder.py index 957ed11cbb9..e20c9972484 100644 --- a/tensorflow_datasets/robotics/dataset_importer_builder.py +++ b/tensorflow_datasets/robotics/dataset_importer_builder.py @@ -18,6 +18,7 @@ from __future__ import annotations import abc +import functools import os from typing import Any @@ -32,6 +33,24 @@ +def _dataset_importer_converter_fn(example, decode_fn, keys_to_strip): + """Beam converter function for DatasetImporterBuilder.""" + # Decode the RLDS Episode and transform it to numpy. + example_out = dict(example) + example_out['steps'] = tf.data.Dataset.from_tensor_slices( + example_out['steps'] + ).map(decode_fn) + steps = list(iter(example_out['steps'].take(-1))) + example_out['steps'] = steps + example_out = dataset_utils.as_numpy(example_out) + example_id = example_out['tfds_id'].decode('utf-8') + del example_out['tfds_id'] + for key in keys_to_strip: + if key in example_out: + del example_out[key] + yield example_id, example_out + + class DatasetImporterBuilder( tfds.core.GeneratorBasedBuilder, skip_registration=True ): @@ -118,24 +137,11 @@ def _generate_examples( decode_fn = builder.info.features['steps'].feature.decode_example - def converter_fn(example): - # Decode the RLDS Episode and transform it to numpy. - example_out = dict(example) - example_out['steps'] = tf.data.Dataset.from_tensor_slices( - example_out['steps'] - ).map(decode_fn) - steps = list(iter(example_out['steps'].take(-1))) - example_out['steps'] = steps - - example_out = dataset_utils.as_numpy(example_out) - - example_id = example_out['tfds_id'].decode('utf-8') - del example_out['tfds_id'] - for key in self.KEYS_TO_STRIP: - if key in example_out: - del example_out[key] - - yield example_id, example_out + converter_fn = functools.partial( + _dataset_importer_converter_fn, + decode_fn=decode_fn, + keys_to_strip=self.KEYS_TO_STRIP, + ) return f'read_tfds_dataset@{split}' >> beam_utils.ReadFromTFDS( builder=builder, diff --git a/tensorflow_datasets/structured/covid19/covid19.py b/tensorflow_datasets/structured/covid19/covid19.py index 806f08ca07a..5e3a5a4df7c 100644 --- a/tensorflow_datasets/structured/covid19/covid19.py +++ b/tensorflow_datasets/structured/covid19/covid19.py @@ -20,6 +20,7 @@ response, weather, and more. """ +import functools import numpy as np from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf import tensorflow_datasets.public_api as tfds @@ -48,6 +49,29 @@ _BATCH_SIZE = 10000 +def _cast_according_to_column(feature_type, v): + if feature_type == tf.string and isinstance(v, (float, int)): + return str(v) + return v + + +def _load_shard(index: int, dl_manager, archive_path, columns, features): + """Load a shard of the dataset.""" + pd = tfds.core.lazy_imports.pandas + # There is only one file so by using the for we guarantee that the file + # will be closed. + for _, file in dl_manager.iter_archive(archive_path): + df = pd.read_csv(file, skiprows=index, nrows=_BATCH_SIZE) + result = [] + for i, row in df.iterrows(): + example = { + k: _cast_according_to_column(features[k].dtype, v) + for k, v in zip(columns, row.values) + } + result.append((index + i, example)) + return result + + class Covid19(tfds.core.GeneratorBasedBuilder): """DatasetBuilder for covid19 dataset.""" @@ -787,31 +811,18 @@ def _generate_examples( pd = tfds.core.lazy_imports.pandas beam = tfds.core.lazy_imports.apache_beam - def cast_according_to_column(feature_type, v): - if feature_type == tf.string and isinstance(v, (float, int)): - return str(v) - return v - file_handles = dl_manager.iter_archive(archive_path) _, file = next(file_handles) columns = pd.read_csv(file, nrows=1).columns - - def load_shard(index: int): - # There is only one file so by using the for we guarantee that the file - # will be closed. - for _, file in dl_manager.iter_archive(archive_path): - df = pd.read_csv(file, skiprows=index, nrows=_BATCH_SIZE) - features = self.info.features - result = [] - for i, row in df.iterrows(): - example = { - k: cast_according_to_column(features[k].dtype, v) - for k, v in zip(columns, row.values) - } - result.append((index + i, example)) - return result + features = self.info.features return beam.Create(list(range(0, _N_RECORDS, _BATCH_SIZE))) | beam.FlatMap( - load_shard + functools.partial( + _load_shard, + dl_manager=dl_manager, + archive_path=archive_path, + columns=columns, + features=features, + ) ) diff --git a/tensorflow_datasets/structured/web_graph/web_graph.py b/tensorflow_datasets/structured/web_graph/web_graph.py index 892ee0b5d37..8808b049098 100644 --- a/tensorflow_datasets/structured/web_graph/web_graph.py +++ b/tensorflow_datasets/structured/web_graph/web_graph.py @@ -82,6 +82,24 @@ """ +def _get_int_feature(example: tf.train.Example, feature_name: str) -> List[int]: + return example.features.feature[feature_name].int64_list.value + + +def _process_example(example: bytes, is_test=False): + """Process a single example.""" + example = tf.train.Example.FromString(example) + row_tag = _get_int_feature(example, 'row_tag')[0] + col_tag = np.array(_get_int_feature(example, 'col_tag'), dtype=np.int64) + if is_test: + gt_tag = _get_int_feature(example, 'gt_tag') + else: + gt_tag = [] + gt_tag = np.array(gt_tag, dtype=np.int64) + return_dict = {'row_tag': row_tag, 'col_tag': col_tag, 'gt_tag': gt_tag} + return row_tag, return_dict + + @dataclasses.dataclass class WebGraphConfig(tfds.core.BuilderConfig): """Palmer Penguins dataset builder config.""" @@ -225,23 +243,6 @@ def _generate_examples(self, pipeline, files, split: str): """Yields examples.""" beam = tfds.core.lazy_imports.apache_beam - def _get_int_feature( - example: tf.train.Example, feature_name: str - ) -> List[int]: - return example.features.feature[feature_name].int64_list.value - - def _process_example(example: bytes, is_test=False): - example = tf.train.Example.FromString(example) - row_tag = _get_int_feature(example, 'row_tag')[0] - col_tag = np.array(_get_int_feature(example, 'col_tag'), dtype=np.int64) - if is_test: - gt_tag = _get_int_feature(example, 'gt_tag') - else: - gt_tag = [] - gt_tag = np.array(gt_tag, dtype=np.int64) - return_dict = {'row_tag': row_tag, 'col_tag': col_tag, 'gt_tag': gt_tag} - return row_tag, return_dict - return ( pipeline | f'{split}_create' >> beam.Create(files) diff --git a/tensorflow_datasets/text/c4.py b/tensorflow_datasets/text/c4.py index 1622012da77..3dc1300b926 100644 --- a/tensorflow_datasets/text/c4.py +++ b/tensorflow_datasets/text/c4.py @@ -349,6 +349,28 @@ ] +def _download_wet_file(path, dl_dir): + """Download WET file if it doesn't already exist.""" + url = f"{_DOWNLOAD_HOST}/{path}" + out_path = epath.Path(dl_dir) / path + if out_path.exists(): + c4_utils.get_counter_inc_fn("download_wet_url")("exists") + return out_path + tmp_dir = epath.Path(f"{os.fspath(out_path)}.incomplete{uuid.uuid4().hex}") + try: + tmp_dir.mkdir(parents=True, exist_ok=True) + downloader = tfds.download.download_manager.get_downloader() + with downloader.tqdm(): + # TODO(slebedev): Investigate why pytype infers Promise[Future[...]]. + dl_path = downloader.download(url, tmp_dir).get().path # type: ignore + dl_path = epath.Path(dl_path) + dl_path.rename(out_path) + finally: + tmp_dir.rmtree(missing_ok=True) + c4_utils.get_counter_inc_fn("download_wet_url")("downloaded") + return out_path + + class C4Config(tfds.core.BuilderConfig): """BuilderConfig for C4 dataset.""" @@ -605,30 +627,6 @@ def _get_pages_pcollection(self, pipeline, file_paths, dl_manager): """Build PCollection of un-split page content.""" beam = tfds.core.lazy_imports.apache_beam - def download_wet_file(path, dl_dir): - url = f"{_DOWNLOAD_HOST}/{path}" - out_path = epath.Path(dl_dir) / path - - if out_path.exists(): - c4_utils.get_counter_inc_fn("download_wet_url")("exists") - return out_path - - tmp_dir = epath.Path( - f"{os.fspath(out_path)}.incomplete{uuid.uuid4().hex}" - ) - try: - tmp_dir.mkdir(parents=True, exist_ok=True) - downloader = tfds.download.download_manager.get_downloader() - with downloader.tqdm(): - # TODO(slebedev): Investigate why pytype infers Promise[Future[...]]. - dl_path = downloader.download(url, tmp_dir).get().path # type: ignore - dl_path = epath.Path(dl_path) - dl_path.rename(out_path) - finally: - tmp_dir.rmtree(missing_ok=True) - c4_utils.get_counter_inc_fn("download_wet_url")("downloaded") - return out_path - wet_file_paths = ( pipeline | "create_wet_path_urls" >> beam.Create(file_paths["wet_path_urls"]) @@ -640,7 +638,7 @@ def download_wet_file(path, dl_dir): | "filter_corrupt_wet_files" >> beam.Filter(lambda p: p not in _KNOWN_CORRUPT_WET_FILES) | beam.Map( - download_wet_file, + _download_wet_file, dl_dir=os.path.join(dl_manager.download_dir, "c4_wet_files"), ) ) diff --git a/tensorflow_datasets/text/c4_utils.py b/tensorflow_datasets/text/c4_utils.py index 47e7e4cc846..a5863b39566 100644 --- a/tensorflow_datasets/text/c4_utils.py +++ b/tensorflow_datasets/text/c4_utils.py @@ -78,15 +78,14 @@ def counter_inc_fn(counter, amt=1): return counter_inc_fn -def get_hashed_url_filter_fn(predicate_fn): - def filter_fn(page): - url = page.normalized_url - val = int( - hashlib.md5(tf.compat.as_text(url).encode("utf-8")).hexdigest(), 16 - ) - return predicate_fn(val) +def _hashed_url_filter_fn(page, predicate_fn): + url = page.normalized_url + val = int(hashlib.md5(tf.compat.as_text(url).encode("utf-8")).hexdigest(), 16) + return predicate_fn(val) - return filter_fn + +def get_hashed_url_filter_fn(predicate_fn): + return functools.partial(_hashed_url_filter_fn, predicate_fn=predicate_fn) _nltk_lock = threading.Lock() @@ -506,6 +505,30 @@ def normalize_url(url): return url +def _badwords_predicate(val, num, den): + return val % den >= num + + +def _badwords_filter(page, badwords_regex, keep_badword_page): + """Filter pages that contain bad words.""" + lang = page.language.split("-")[0] # remove suffix if present + if lang in badwords_regex: + text = page.text + badwords_found = badwords_regex[lang].search(text.lower()) + if badwords_found is not None: + if keep_badword_page(page): + get_counter_inc_fn("badwords-filter")("soft-passed") + get_counter_inc_fn("badwords-filter-%s" % lang)("soft-passed") + return True + get_counter_inc_fn("badwords-filter")("filtered") + get_counter_inc_fn("badwords-filter-%s" % lang)("filtered") + return False + get_counter_inc_fn("badwords-filter-%s" % lang)("passed") + + get_counter_inc_fn("badwords-filter")("passed") + return True + + def get_badwords_filter_fn( badwords: Mapping[str, Sequence[str]], filter_fraction: float = 1.0 ): @@ -523,29 +546,16 @@ def get_badwords_filter_fn( filter_ratio = float.as_integer_ratio(filter_fraction) keep_badword_page = get_hashed_url_filter_fn( - lambda x: x % filter_ratio[1] >= filter_ratio[0] + functools.partial( + _badwords_predicate, num=filter_ratio[0], den=filter_ratio[1] + ) ) - def badwords_filter(page): - lang = page.language.split("-")[0] # remove suffix if present - - if lang in badwords_regex: - text = page.text - badwords_found = badwords_regex[lang].search(text.lower()) - if badwords_found is not None: - if keep_badword_page(page): - get_counter_inc_fn("badwords-filter")("soft-passed") - get_counter_inc_fn("badwords-filter-%s" % lang)("soft-passed") - return True - get_counter_inc_fn("badwords-filter")("filtered") - get_counter_inc_fn("badwords-filter-%s" % lang)("filtered") - return False - get_counter_inc_fn("badwords-filter-%s" % lang)("passed") - - get_counter_inc_fn("badwords-filter")("passed") - return True - - return badwords_filter + return functools.partial( + _badwords_filter, + badwords_regex=badwords_regex, + keep_badword_page=keep_badword_page, + ) def paragraph_filter( diff --git a/tensorflow_datasets/video/youtube_vis/youtube_vis.py b/tensorflow_datasets/video/youtube_vis/youtube_vis.py index a50076b8d66..50718c4f5a6 100644 --- a/tensorflow_datasets/video/youtube_vis/youtube_vis.py +++ b/tensorflow_datasets/video/youtube_vis/youtube_vis.py @@ -18,6 +18,7 @@ from __future__ import annotations import collections +import functools import json import os from typing import Any, Dict, List, Optional, Tuple, Union @@ -210,6 +211,46 @@ def _build_annotations_index( return video_id_to_annos, videos +def _frame_index(frame_filename): + """Convert a video frame filename into a numerical index.""" + basename = os.path.basename(os.fspath(frame_filename)) + return int(basename.split('.')[0]) + + +def _process_example( + video_id, + *, + videos, + only_frames_with_labels, + all_frames, + height, + width, + video_id_to_tracks, + maybe_resize_video, +): + """Process a single video into a data example.""" + data_example = {} + video = videos[video_id] + if only_frames_with_labels: + frames_list = [all_frames / file for file in video['file_names']] + else: + video_dir = os.path.dirname(video['file_names'][0]) + video_directory = all_frames / video_dir + frames_list = list(video_directory.glob('*')) + frames_list = sorted(frames_list, key=_frame_index) + data_example['metadata'] = _create_metadata( + video, height, width, len(frames_list) + ) + data_example['tracks'] = [] + track_annotations = video_id_to_tracks[video_id] + for track in track_annotations: + data_example['tracks'].append( + _create_per_track_annotation(video, frames_list, track, height, width) + ) + data_example['video'] = maybe_resize_video(frames_list) + return data_example['metadata']['video_name'], data_example + + class YoutubeVisConfig(tfds.core.BuilderConfig): """ "Configuration for Youtube-vis video instance segmentation dataset. @@ -508,39 +549,20 @@ def _generate_examples( height = self._builder_config.height width = self._builder_config.width only_frames_with_labels = self._builder_config.only_frames_with_labels - data_example = {} - - def _frame_index(frame_filename): - """Convert a video frame filename into a numerical index.""" - basename = os.path.basename(os.fspath(frame_filename)) - return int(basename.split('.')[0]) - - def _process_example(video_id): - """Process a single video into a data example.""" - video = videos[video_id] - if only_frames_with_labels: - frames_list = [all_frames / file for file in video['file_names']] - else: - video_dir = os.path.dirname(video['file_names'][0]) - video_directory = all_frames / video_dir - frames_list = list(video_directory.glob('*')) - frames_list = sorted(frames_list, key=_frame_index) - data_example['metadata'] = _create_metadata( - video, height, width, len(frames_list) - ) - data_example['tracks'] = [] - track_annotations = video_id_to_tracks[video_id] - for track in track_annotations: - data_example['tracks'].append( - _create_per_track_annotation( - video, frames_list, track, height, width - ) - ) - data_example['video'] = self._maybe_resize_video(frames_list) - return data_example['metadata']['video_name'], data_example video_keys = list(videos.keys()) if video_range_to_use is not None: video_keys = video_keys[video_range_to_use[0] : video_range_to_use[1]] - return beam.Create(video_keys) | beam.Map(_process_example) + return beam.Create(video_keys) | beam.Map( + functools.partial( + _process_example, + videos=videos, + only_frames_with_labels=only_frames_with_labels, + all_frames=all_frames, + height=height, + width=width, + video_id_to_tracks=video_id_to_tracks, + maybe_resize_video=self._maybe_resize_video, + ) + ) diff --git a/tensorflow_datasets/vision_language/grounded_scan/grounded_scan.py b/tensorflow_datasets/vision_language/grounded_scan/grounded_scan.py index c08c23239ca..d9a9222c367 100644 --- a/tensorflow_datasets/vision_language/grounded_scan/grounded_scan.py +++ b/tensorflow_datasets/vision_language/grounded_scan/grounded_scan.py @@ -76,6 +76,60 @@ _SPATIAL_DATA_PATH = 'https://storage.googleapis.com/gresearch/gscan/' +def _get_position_feature(raw_position): + return { + 'row': int(raw_position['row']), + 'column': int(raw_position['column']), + } + + +def _get_object_feature(raw_object): + return { + 'vector': raw_object['vector'].strip(), + 'position': _get_position_feature(raw_object['position']), + 'object': { + 'shape': raw_object['object']['shape'], + 'color': raw_object['object']['color'], + 'size': int(raw_object['object']['size']), + }, + } + + +def _parse_sparse_situation_to_feature(situation): + return { + 'grid_size': int(situation['grid_size']), + 'agent_direction': int(situation['agent_direction']), + 'distance_to_target': int(situation['grid_size']), + 'direction_to_target': situation['direction_to_target'], + 'agent_position': _get_position_feature(situation['agent_position']), + 'target_object': _get_object_feature(situation['target_object']), + 'placed_objects': [ + _get_object_feature(obj) + for obj in situation['placed_objects'].values() + ], + } + + +def _preprocess(example): + return { + 'command': example['command'].split(','), + 'target_commands': example['target_commands'].split(','), + 'meaning': example['meaning'].split(','), + 'manner': example['manner'], + 'verb_in_command': example['verb_in_command'], + 'referred_target': example['referred_target'], + 'situation': _parse_sparse_situation_to_feature(example['situation']), + } + + +def _yield_examples(path, split_name): + dataset_path = os.path.join(path, 'dataset.txt') + with tf.io.gfile.GFile(dataset_path, 'r') as f: + dataset = json.load(f) + for i, example in enumerate(dataset['examples'][split_name]): + yield f'{split_name}_{i}', _preprocess(example) + + class GroundedScanConfig(tfds.core.BuilderConfig): """BuilderConfig for groundedSCAN.""" @@ -199,56 +253,8 @@ def _generate_examples(self, path, split_name): """Yields examples.""" beam = tfds.core.lazy_imports.apache_beam - - def _get_position_feature(raw_position): - return { - 'row': int(raw_position['row']), - 'column': int(raw_position['column']), - } - - def _get_object_feature(raw_object): - return { - 'vector': raw_object['vector'].strip(), - 'position': _get_position_feature(raw_object['position']), - 'object': { - 'shape': raw_object['object']['shape'], - 'color': raw_object['object']['color'], - 'size': int(raw_object['object']['size']), - }, - } - - def _parse_sparse_situation_to_feature(situation): - return { - 'grid_size': int(situation['grid_size']), - 'agent_direction': int(situation['agent_direction']), - 'distance_to_target': int(situation['grid_size']), - 'direction_to_target': situation['direction_to_target'], - 'agent_position': _get_position_feature(situation['agent_position']), - 'target_object': _get_object_feature(situation['target_object']), - 'placed_objects': [ - _get_object_feature(obj) - for obj in situation['placed_objects'].values() - ], - } - - def _preprocess(example): - return { - 'command': example['command'].split(','), - 'target_commands': example['target_commands'].split(','), - 'meaning': example['meaning'].split(','), - 'manner': example['manner'], - 'verb_in_command': example['verb_in_command'], - 'referred_target': example['referred_target'], - 'situation': _parse_sparse_situation_to_feature(example['situation']), - } - - def _yield_examples(path): - dataset_path = os.path.join(path, 'dataset.txt') - with tf.io.gfile.GFile(dataset_path, 'r') as f: - dataset = json.load(f) - for i, example in enumerate(dataset['examples'][split_name]): - yield f'{split_name}_{i}', _preprocess(example) - return 'Create pipeline' >> beam.Create( [path] - ) | 'Process samples' >> beam.FlatMap(_yield_examples) + ) | 'Process samples' >> beam.FlatMap( + _yield_examples, split_name=split_name + ) diff --git a/tensorflow_datasets/vision_language/wit_kaggle/wit_kaggle.py b/tensorflow_datasets/vision_language/wit_kaggle/wit_kaggle.py index 71a6a36f832..2dd602d3401 100644 --- a/tensorflow_datasets/vision_language/wit_kaggle/wit_kaggle.py +++ b/tensorflow_datasets/vision_language/wit_kaggle/wit_kaggle.py @@ -65,6 +65,117 @@ _BEAM_NAMESPACE = "TFDS_WIT_KAGGLE" +def _get_csv_reader(filename, *, counter): + if filename.suffix == ".gz": + counter("gz_csv_files").inc() + g = tf.io.gfile.GFile(filename, "rb") + f = gzip.open(g, "rt", newline="") + else: + counter("normal_csv_files").inc() + f = tf.io.gfile.GFile(filename, "r") + # Limit to 100 MB. Value must be smaller than the C long maximum value. + csv.field_size_limit(sys.maxsize) + return csv.reader(f, delimiter="\t") + + +def _read_pixel_rows(filename, *, counter): + r"""Contains image_url \t image_pixel \t metadata_url.""" + reader = _get_csv_reader(filename, counter=counter) + for row in reader: + counter("pixel_rows").inc() + image_url, image_representation, metadata_url = row + if image_url: + yield [image_url, (image_representation, metadata_url)] + else: + counter("pixel_rows_no_image_url").inc() + + +def _read_resnet_rows(filename, *, counter): + r"""Contains image_url \t resnet_embedding.""" + reader = _get_csv_reader(filename, counter=counter) + for row in reader: + counter("resnet_rows").inc() + image_url, image_representation = row + if image_url: + yield [image_url, image_representation] + else: + counter("resnet_rows_no_image_url").inc() + + +def _read_samples_rows(folder_path, *, builder_config, counter): + """Contains samples: train and test have different fields.""" + for filename in tf.io.gfile.listdir(folder_path): + file_path = folder_path / filename + f = tf.io.gfile.GFile(file_path, "r") + # Limit to 100 MB. Value must be smaller than the C long maximum value. + csv.field_size_limit(sys.maxsize) + csv_reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_ALL) + for row in csv_reader: + counter("samples_rows").inc() + sample = { + feature_key: row[feature_key] + for feature_key in builder_config.split_specific_features.keys() + } + image_url = row["image_url"] + if image_url: + yield [image_url, sample] + else: + counter("samples_rows_no_image_url").inc() + + +def _process_examples(el, *, builder_config, counter): + """Process examples.""" + sample_url, sample_fields = el + # Each image_url can be associated with multiple samples (e.g., multiple + # languages). + for i, sample_info in enumerate(sample_fields["sample_info"]): + sample_id = f"{i}_{sample_url}" + sample = {"image_url": sample_url} + for feature_key in builder_config.split_specific_features.keys(): + sample[feature_key] = sample_info[feature_key] + is_boolean_feature = ( + builder_config.split_specific_features[feature_key].np_dtype + == np.bool_ + ) + if is_boolean_feature: + sample[feature_key] = bool_utils.parse_bool(sample[feature_key]) + # Test samples don't have gold captions. + if "caption_title_and_reference_description" not in sample_info: + sample["caption_title_and_reference_description"] = "" + + # We output image data only if there is at least one image + # representation per image_url. + # Not all of the samples in the competition have corresponding image + # data. In case multiple different image representations are associated + # with the same image_url, we don't know which one is correct and don't + # output any. + if len(set(sample_fields["image_pixels"])) == 1: + sample_image, sample_metadata = sample_fields["image_pixels"][0] + sample["image"] = io.BytesIO(base64.b64decode(sample_image)) + sample["metadata_url"] = sample_metadata + else: + if len(set(sample_fields["image_pixels"])) > 1: + counter("image_pixels_multiple").inc() + else: + counter("image_pixels_missing").inc() + sample["image"] = io.BytesIO(base64.b64decode(_EMPTY_IMAGE_BYTES)) + sample["metadata_url"] = "" + + if len(set(sample_fields["image_resnet"])) == 1: + image_resnet = [ + float(x) for x in sample_fields["image_resnet"][0].split(",") + ] + sample["embedding"] = image_resnet + else: + if len(set(sample_fields["image_resnet"])) > 1: + counter("image_resnet_multiple").inc() + else: + counter("image_resnet_missing").inc() + sample["embedding"] = builder_config.empty_resnet_embedding + + yield sample_id, sample + + class WitKaggleConfig(tfds.core.BuilderConfig): """BuilderConfig for WitKaggle.""" @@ -285,111 +396,6 @@ def _generate_examples( beam = tfds.core.lazy_imports.apache_beam counter = functools.partial(beam.metrics.Metrics.counter, _BEAM_NAMESPACE) - def _get_csv_reader(filename): - if filename.suffix == ".gz": - counter("gz_csv_files").inc() - g = tf.io.gfile.GFile(filename, "rb") - f = gzip.open(g, "rt", newline="") - else: - counter("normal_csv_files").inc() - f = tf.io.gfile.GFile(filename, "r") - # Limit to 100 MB. Value must be smaller than the C long maximum value. - csv.field_size_limit(sys.maxsize) - return csv.reader(f, delimiter="\t") - - def _read_pixel_rows(filename): - r"""Contains image_url \t image_pixel \t metadata_url.""" - reader = _get_csv_reader(filename) - for row in reader: - counter("pixel_rows").inc() - image_url, image_representation, metadata_url = row - if image_url: - yield [image_url, (image_representation, metadata_url)] - else: - counter("pixel_rows_no_image_url").inc() - - def _read_resnet_rows(filename): - r"""Contains image_url \t resnet_embedding.""" - reader = _get_csv_reader(filename) - for row in reader: - counter("resnet_rows").inc() - image_url, image_representation = row - if image_url: - yield [image_url, image_representation] - else: - counter("resnet_rows_no_image_url").inc() - - def _read_samples_rows(folder_path): - """Contains samples: train and test have different fields.""" - for filename in tf.io.gfile.listdir(folder_path): - file_path = folder_path / filename - f = tf.io.gfile.GFile(file_path, "r") - # Limit to 100 MB. Value must be smaller than the C long maximum value. - csv.field_size_limit(sys.maxsize) - csv_reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_ALL) - for row in csv_reader: - counter("samples_rows").inc() - sample = { - feature_key: row[feature_key] - for feature_key in self.builder_config.split_specific_features.keys() - } - image_url = row["image_url"] - if image_url: - yield [image_url, sample] - else: - counter("samples_rows_no_image_url").inc() - - def _process_examples(el): - sample_url, sample_fields = el - # Each image_url can be associated with multiple samples (e.g., multiple - # languages). - for i, sample_info in enumerate(sample_fields["sample_info"]): - sample_id = f"{i}_{sample_url}" - sample = {"image_url": sample_url} - for feature_key in self.builder_config.split_specific_features.keys(): - sample[feature_key] = sample_info[feature_key] - is_boolean_feature = ( - self.builder_config.split_specific_features[feature_key].np_dtype - == np.bool_ - ) - if is_boolean_feature: - sample[feature_key] = bool_utils.parse_bool(sample[feature_key]) - # Test samples don't have gold captions. - if "caption_title_and_reference_description" not in sample_info: - sample["caption_title_and_reference_description"] = "" - - # We output image data only if there is at least one image - # representation per image_url. - # Not all of the samples in the competition have corresponding image - # data. In case multiple different image representations are associated - # with the same image_url, we don't know which one is correct and don't - # output any. - if len(set(sample_fields["image_pixels"])) == 1: - sample_image, sample_metadata = sample_fields["image_pixels"][0] - sample["image"] = io.BytesIO(base64.b64decode(sample_image)) - sample["metadata_url"] = sample_metadata - else: - if len(set(sample_fields["image_pixels"])) > 1: - counter("image_pixels_multiple").inc() - else: - counter("image_pixels_missing").inc() - sample["image"] = io.BytesIO(base64.b64decode(_EMPTY_IMAGE_BYTES)) - sample["metadata_url"] = "" - - if len(set(sample_fields["image_resnet"])) == 1: - image_resnet = [ - float(x) for x in sample_fields["image_resnet"][0].split(",") - ] - sample["embedding"] = image_resnet - else: - if len(set(sample_fields["image_resnet"])) > 1: - counter("image_resnet_multiple").inc() - else: - counter("image_resnet_missing").inc() - sample["embedding"] = self.builder_config.empty_resnet_embedding - - yield sample_id, sample - # Read embeddings and bytes representations from (possibly compressed) csv. image_resnet_files = [ image_resnet_path / f for f in tf.io.gfile.listdir(image_resnet_path) @@ -397,7 +403,8 @@ def _process_examples(el): resnet_collection = ( pipeline | "Collection from resnet files" >> beam.Create(image_resnet_files) - | "Get embeddings per image" >> beam.FlatMap(_read_resnet_rows) + | "Get embeddings per image" + >> beam.FlatMap(functools.partial(_read_resnet_rows, counter=counter)) ) image_pixel_files = [ @@ -406,14 +413,22 @@ def _process_examples(el): pixel_collection = ( pipeline | "Collection from pixel files" >> beam.Create(image_pixel_files) - | "Get pixels per image" >> beam.FlatMap(_read_pixel_rows) + | "Get pixels per image" + >> beam.FlatMap(functools.partial(_read_pixel_rows, counter=counter)) ) # Read samples from tsv files. sample_collection = ( pipeline | "Collection from sample files" >> beam.Create(samples_path) - | "Get samples" >> beam.FlatMap(_read_samples_rows) + | "Get samples" + >> beam.FlatMap( + functools.partial( + _read_samples_rows, + builder_config=self.builder_config, + counter=counter, + ) + ) ) # Combine the features and yield examples. @@ -425,5 +440,12 @@ def _process_examples(el): } | "Group by image_url" >> beam.CoGroupByKey() | "Reshuffle" >> beam.Reshuffle() - | "Process and yield examples" >> beam.FlatMap(_process_examples) + | "Process and yield examples" + >> beam.FlatMap( + functools.partial( + _process_examples, + builder_config=self.builder_config, + counter=counter, + ) + ) )