
Refactor parquet dataloader #867

Draft: zyaoj wants to merge 20 commits into main

Conversation

@zyaoj (Contributor) commented Dec 3, 2024

What does this PR do? Please describe:
A first attempt to extract and migrate the generic Parquet dataloader from MERES to fairseq2.

Does your PR introduce any breaking changes? If yes, please list them:
N/A

Check list:

  • Was the content of this PR discussed and approved via a GitHub issue? (no need for typos or documentation improvements)
  • Did you read the contributor guideline?
  • Did you make sure that your PR does only one thing instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (no need for typos, documentation, or minor internal changes)

@zyaoj zyaoj self-assigned this Dec 3, 2024
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 3, 2024
@zyaoj zyaoj marked this pull request as ready for review January 3, 2025 13:46
@zyaoj zyaoj removed request for artemru and cbalioglu January 3, 2025 13:47
@zyaoj zyaoj marked this pull request as draft January 3, 2025 13:47
Comment on lines +7 to +8
pandas~=2.0.0
pandas-stubs~=2.2.3
Contributor:
Could we remove these explicit deps here and keep only pyarrow as the main dependency?
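
For illustration, a minimal sketch of what the "arrow" extra could look like with pyarrow as the only hard dependency; the dict layout and version pin simply mirror the snippet below and are not the PR's final form:

extras_require = {
    "arrow": [
        "pyarrow>=13.0.0",
        # pandas / pandas-stubs dropped; conversions can go through
        # pyarrow's own utilities instead of DataFrame round-trips.
    ],
}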

"arrow": ["pyarrow>=13.0.0", "pandas~=2.0.0"],
"arrow": [
"pyarrow>=13.0.0",
"joblib~=1.4.2",
Contributor:
Let's see if we could go without it (maybe ThreadPool from multiprocessing will do the job).
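
A rough sketch of that idea, using the standard library's thread pool in place of joblib; the function name and signature here are placeholders, not the PR's actual API:

from multiprocessing.pool import ThreadPool
from typing import Callable, List, Sequence

import pyarrow as pa

def load_fragments(
    fragments: Sequence, load_one: Callable[[object], pa.Table], num_workers: int = 4
) -> List[pa.Table]:
    # Threads are usually enough here: pyarrow releases the GIL during
    # Parquet I/O, so a plain ThreadPool can overlap fragment reads.
    with ThreadPool(num_workers) as pool:
        return pool.map(load_one, fragments)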

"pyarrow_to_torch_tensor",
"pyarrow_column_to_array",
"split_fragment_in_row_groups",
"table_func_wrap",
Contributor:
I think this is not needed at all.

Comment on lines +62 to +63
"_TableWrapper",
"_to_real_object",
Contributor:
fs2.data now handles pyarrow.Table directly, so there is no more need for these.

Contains different options that allow loading only a part of the provided dataset.
"""

columns: Optional[List[str]] = None
Contributor:
I'm not yet sure whether this belongs here.

"""If ``True``, uses Parquet row groups instead of simple partitions which
are generally smaller. Highly recommended for non-partitioned parquet files."""

nb_parallel_fragments: Optional[int] = 5
Contributor:
maybe we could keep it with default=None and add an extra config arg (max_tokens) to be used with dynamic bucketing.
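
Something along these lines, perhaps; apart from nb_parallel_fragments, the class and field names below are suggestions, not existing config keys:

from dataclasses import dataclass
from typing import Optional

@dataclass
class FragmentStreamingConfig:  # illustrative name
    nb_parallel_fragments: Optional[int] = None
    """Number of fragments loaded in parallel; None lets the loader decide."""

    max_tokens: Optional[int] = None
    """Token budget per batch, used when dynamic bucketing is enabled."""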



@dataclass
class DataLoadingConfig:
Contributor:
This is a more specific dataloading case; I would call it Seq2SeqDataLoading.
But we could also get something like ClassifierDataloading or PairDataloading ...
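
Illustrative only, to make the naming suggestion concrete; none of these classes exist in the PR as written:

from dataclasses import dataclass

@dataclass
class BaseDataLoadingConfig:
    """Task-agnostic options shared by all loaders (batching, shuffling, ...)."""

@dataclass
class Seq2SeqDataLoadingConfig(BaseDataLoadingConfig):
    source_column: str = "source"
    target_column: str = "target"

@dataclass
class ClassifierDataLoadingConfig(BaseDataLoadingConfig):
    label_column: str = "label"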

Comment on lines +226 to +240
shuffle: bool = True
"""If ``True``, shuffles the dataset samples during the iteration. If ``False``
and ``order_by_length`` is ``None``, the batch samples will be produced in
natural Parquet dataset reading order."""

drop_null: bool = True
"""If ``True``, drops rows containing any null value."""

seed: int = 123
"""The RNG seed value for deterministic behavior."""

nb_epochs: int = 100
"""
Number of passes over the data before iterations stop
"""
Contributor:
These probably should go into the basic dataset config (frontend pipeline).
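
A sketch of that split; the fields and defaults are copied from the snippet above, while the class name is hypothetical:

from dataclasses import dataclass

@dataclass
class BasicDatasetConfig:  # i.e. the frontend-pipeline config
    shuffle: bool = True
    drop_null: bool = True
    seed: int = 123
    nb_epochs: int = 100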

Comment on lines +112 to +113
# XXX: this will reinit default aws creds if they were not provided explicitly
# tested only on aws cluster only !
Contributor:
remove comments about aws

)
self.fragment = loads(dumps(self.fragment))
fragment_table = self.fragment.to_table(
columns=fragment_columns, use_threads=False
Contributor:
use_threads should be a parameter here
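
A minimal sketch of threading the flag through, assuming the surrounding method has access to self.fragment as in the snippet above; the method name is illustrative:

def load(self, fragment_columns=None, use_threads: bool = False):
    # Expose pyarrow's use_threads instead of hard-coding it to False.
    return self.fragment.to_table(
        columns=fragment_columns, use_threads=use_threads
    )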

return fragment_table # type: ignore


def parquet_fragments_to_pipeline_builder(
Contributor:
Could we remove this one in favor of list_parquet_fragments?

# Apply filters if specified
if dataset_config.filters is not None or dataloader_config.drop_null:
pipeline_builder = pipeline_builder.map(
table_func_wrap(
Contributor:
No need for this wrapper any more.
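
Since the pipeline can map over pyarrow.Table directly, the call could look roughly like this; apply_filters is an illustrative stand-in for the PR's filtering logic, and pipeline_builder is the builder from the surrounding code:

import pyarrow as pa

def apply_filters(table: pa.Table) -> pa.Table:
    # Drop rows with null values; dataset-specific filters would go here too.
    return table.drop_null()

pipeline_builder = pipeline_builder.map(apply_filters)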

return replace_table_column(table, column, new_array)


def correct_paragraph_length(
Contributor:
This function is too LCM-specific; there is no reason to keep it here.

# and take the 'list.min()' and 'list.max()' as needed.
filter_series = df_pl.with_columns(
(
(pl.col(column).list.eval(pl.col("").str.len_bytes()).list.min() >= min_len)
Contributor:
Here we used len_bytes, but len_chars is probably more relevant.
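
A sketch of the character-based variant, reusing df_pl, column, and min_len from the surrounding code and keeping the rest of the expression as in the original:

import polars as pl

filter_series = df_pl.with_columns(
    (
        # Count Unicode characters rather than UTF-8 bytes, which differ
        # for non-ASCII text.
        pl.col(column).list.eval(pl.col("").str.len_chars()).list.min() >= min_len
    )
    # ... remaining max-length condition unchanged from the original
)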

return table


def load_one_fragment(
Contributor:
We have the SafeFragment interface for this now.

return np.asarray(length_col, dtype=np.int32)


class _TableWrapper:
Contributor:
to remove

self.table: pa.Table = table


def _to_real_object(x: Union[_TableWrapper, NestedDict]) -> BatchOutputType:
Contributor:
to remove

return x


def table_func_wrap(func: Callable[..., Any]) -> Callable[..., Any]:
Contributor:
to remove

3 participants