
Commit

added tqdm
guipenedo committed Jul 13, 2023
1 parent f238a91 commit 9af9499
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/tools/check_dataset.py
@@ -5,6 +5,7 @@
 
 import numpy as np
 from tokenizers import Tokenizer
+from tqdm import tqdm
 
 from datatrove.io import BaseInputDataFolder, InputDataFile
@@ -48,12 +49,15 @@ def load_input_mmap(file: InputDataFile):
     doc_ends = [load_doc_ends(file) for file in datafiles_index]
     token_inputs = list(map(load_input_mmap, datafiles))
     loss_inputs = list(map(load_input_mmap, datafiles_loss)) if check_loss else None
-    for file_doc_ends, file_token_inputs, file_loss_inputs in zip(doc_ends, token_inputs, loss_inputs):
+    for filei, (file_doc_ends, file_token_inputs, file_loss_inputs) in enumerate(
+        zip(doc_ends, token_inputs, loss_inputs)
+    ):
+        print(f"Processing file {filei + 1}/{len(datafiles)}")
         assert (
             not check_loss or file_token_inputs.size() == file_loss_inputs.size() * 2
         ), "Mismatch between loss and tokens file sizes"
         assert file_token_inputs.size() == file_doc_ends[-1] * 2, "Size of .ds does not match last doc_end"
-        for doci, doc_end in enumerate(file_doc_ends):
+        for doci, doc_end in tqdm(enumerate(file_doc_ends), total=len(file_doc_ends)):
             last_token = struct.unpack("<H", file_token_inputs[(doc_end.item() - 1) * 2 : doc_end.item() * 2])[0]
             assert last_token == eos_token, f"no EOS at doc end of doc {doci}"
         file_token_inputs.close()
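For context, a minimal sketch (not part of the commit) of the pattern this change applies: tqdm cannot infer a length from enumerate(), so total= must be passed explicitly for the bar to show a percentage and ETA. The doc_ends list below is a hypothetical stand-in for the per-document end offsets loaded by check_dataset.py.

# Sketch of the tqdm-over-enumerate pattern used in the diff above;
# `doc_ends` is a hypothetical stand-in for per-document end offsets.
from tqdm import tqdm

doc_ends = [128, 256, 512]  # hypothetical document end offsets

# enumerate() has no __len__, so pass total= explicitly; otherwise
# tqdm falls back to a plain counter without percentage or ETA.
for doci, doc_end in tqdm(enumerate(doc_ends), total=len(doc_ends)):
    ...  # per-document validation would go here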
