Refactor parquet dataloader #867
base: main
Conversation
pandas~=2.0.0
pandas-stubs~=2.2.3
could we remove these explicit deps here and only keep pyarrow as the main dep?
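Roughly, the trimmed extra could then look like this (a sketch of the setup extras, not the final list; whether pandas stays optional is up for discussion):

```python
# hypothetical trimmed "arrow" extra: pyarrow stays the only hard dependency,
# pandas/pandas-stubs are left to the user's environment
extras_require = {
    "arrow": ["pyarrow>=13.0.0"],
}
```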
"arrow": ["pyarrow>=13.0.0", "pandas~=2.0.0"], | ||
"arrow": [ | ||
"pyarrow>=13.0.0", | ||
"joblib~=1.4.2", |
let's see if we could go without it (maybe ThreadPool from multiprocessing will do the job)
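A minimal sketch of what dropping joblib could look like, assuming the parallel step just maps a loading function over fragments (the helper name and arguments below are illustrative):

```python
from multiprocessing.pool import ThreadPool

# hypothetical replacement for a joblib.Parallel call: map a fragment-loading
# function over fragments with a small thread pool (I/O-bound work releases the GIL)
def load_fragments_in_parallel(fragments, load_fn, num_threads=4):
    with ThreadPool(num_threads) as pool:
        # pool.map preserves the input order, like joblib's default behaviour
        return pool.map(load_fn, fragments)
```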
"pyarrow_to_torch_tensor", | ||
"pyarrow_column_to_array", | ||
"split_fragment_in_row_groups", | ||
"table_func_wrap", |
I think this is not needed at all
"_TableWrapper", | ||
"_to_real_object", |
fs2.data now handles pyarrow.Table directly, so there is no more need for that
Contains different options that allow loading only a part of the provided dataset.
"""

columns: Optional[List[str]] = None
I'm not yet sure if it belongs here
"""If ``True``, uses Parquet row groups instead of simple partitions which | ||
are generally smaller. Highly recommended for non-partitioned parquet files.""" | ||
|
||
nb_parallel_fragments: Optional[int] = 5 |
maybe we could keep it with default=None and add an extra config arg (max_tokens) to be used with dynamic bucketing.
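A rough sketch of that shape (the class and field names here are illustrative, not the final API):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FragmentLoadingConfig:
    # None lets the loader pick a sensible degree of parallelism;
    # an explicit value still pins the number of fragments read concurrently
    nb_parallel_fragments: Optional[int] = None

    # hypothetical companion knob for dynamic bucketing: when set, batches
    # are built against a token budget rather than a fixed fragment count
    max_tokens: Optional[int] = None
```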

@dataclass
class DataLoadingConfig:
this is a more specific data loading case; I would call it Seq2SeqDataLoading.
But we could also get something like ClassifierDataloading or PairDataloading ...
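One possible layout, purely illustrative (the subclass fields are placeholders):

```python
from dataclasses import dataclass

@dataclass
class DataLoadingConfig:
    # task-agnostic options (shuffle, seed, nb_epochs, ...) stay in the base class
    shuffle: bool = True
    seed: int = 123

@dataclass
class Seq2SeqDataLoadingConfig(DataLoadingConfig):
    # seq2seq-specific options would live in the subclass;
    # the column names below are only examples
    source_column: str = "source"
    target_column: str = "target"
```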
shuffle: bool = True
"""If ``True``, shuffles the dataset samples during the iteration. If ``False``
and ``order_by_length`` is ``None``, the batch samples will be produced in
natural Parquet dataset reading order."""

drop_null: bool = True
"""If ``True``, drops rows containing any null value."""

seed: int = 123
"""The RNG seed value for deterministic behavior."""

nb_epochs: int = 100
"""
Number of passes over the data before iterations stop.
"""
this probably should go to Basic Dataset config (frontend pipeline)
# XXX: this will reinit default aws creds if they were not provided explicitly
# tested on aws cluster only!
remove comments about aws
)
self.fragment = loads(dumps(self.fragment))
fragment_table = self.fragment.to_table(
    columns=fragment_columns, use_threads=False
use_threads should be a parameter here
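For example, something along these lines (a sketch; the function name is made up, but `Fragment.to_table` does accept `columns` and `use_threads`):

```python
from typing import List, Optional

import pyarrow as pa
import pyarrow.dataset as ds

def load_fragment_table(
    fragment: ds.Fragment,
    fragment_columns: Optional[List[str]] = None,
    use_threads: bool = False,
) -> pa.Table:
    # use_threads is caller-controlled instead of being hard-coded to False
    return fragment.to_table(columns=fragment_columns, use_threads=use_threads)
```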
return fragment_table  # type: ignore


def parquet_fragments_to_pipeline_builder(
we could remove this one in favor of list_parquet_fragments?
# Apply filters if specified
if dataset_config.filters is not None or dataloader_config.drop_null:
    pipeline_builder = pipeline_builder.map(
        table_func_wrap(
no need for these wrappers anymore
return replace_table_column(table, column, new_array)


def correct_paragraph_length(
this is too LCM-specific a function, no reason to keep it here
# and take the 'list.min()' and 'list.max()' as needed.
filter_series = df_pl.with_columns(
    (
        (pl.col(column).list.eval(pl.col("").str.len_bytes()).list.min() >= min_len)
here we used len_bytes, but len_chars is probably more relevant
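The change itself would be small; a sketch with polars, assuming `column` holds a list of strings per row (the helper name is illustrative):

```python
import polars as pl

# character-based length check instead of byte-based: len_chars counts Unicode
# code points, so multi-byte characters no longer inflate the measured lengths
def min_length_mask(df_pl: pl.DataFrame, column: str, min_len: int) -> pl.Series:
    return df_pl.with_columns(
        (
            pl.col(column).list.eval(pl.col("").str.len_chars()).list.min() >= min_len
        ).alias("keep")
    )["keep"]
```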
return table


def load_one_fragment(
we have the SafeFragment interface for this now
return np.asarray(length_col, dtype=np.int32)


class _TableWrapper:
to remove
self.table: pa.Table = table


def _to_real_object(x: Union[_TableWrapper, NestedDict]) -> BatchOutputType:
to remove
return x


def table_func_wrap(func: Callable[..., Any]) -> Callable[..., Any]:
to remove
What does this PR do? Please describe:
A first attempt to extract and migrate the generic parquet dataloader from MERES to fairseq2.
Does your PR introduce any breaking changes? If yes, please list them:
N/A
Check list: