added stage 4 - filtering to minhash
guipenedo committed Jul 19, 2023
1 parent 66169dc commit b480195
Showing 1 changed file with 47 additions and 2 deletions.
49 changes: 47 additions & 2 deletions src/datatrove/pipeline/dedup/minhash.py
@@ -119,11 +119,13 @@ def __init__(
         input_folder: BaseInputDataFolder,
         output_folder: BaseOutputDataFolder,
         hashes_per_bucket: int = DEFAULT_PER_BUCKET,
+        num_buckets: int = DEFAULT_NR_BUCKETS,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.input_folder = input_folder
         self.output_folder = output_folder
+        self.num_buckets = num_buckets
         self.hashes_per_bucket = hashes_per_bucket
 
     def read_sigs(self, file: InputDataFile, file_id: int) -> Generator:
@@ -141,6 +143,7 @@ def set_up_dl_locks(self, dl_lock, up_lock):
         self.output_folder.set_lock(up_lock)
 
     def __call__(self, data: DocumentsPipeline, bucket: int = 0, world_size: int = 1):
+        assert world_size == self.num_buckets, "You must run exactly one task per bucket"
         sig_files = self.input_folder.list_files(suffix=f"bucket_{bucket:03d}")
         sig_readers = [self.read_sigs(file, file_i) for file_i, file in enumerate(sig_files)]
 
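Note: the new assert encodes stage 2's contract that bucket b is processed by task b, so the task count must equal num_buckets. A minimal sketch of the file-selection invariant behind it (the file-naming scheme below is hypothetical; only the suffix filter mirrors the code above):

num_buckets = 3
# two stage-1 tasks each wrote one signature file per bucket (names are made up)
files = [f"{task:05d}_bucket_{b:03d}" for task in range(2) for b in range(num_buckets)]

for bucket in range(num_buckets):  # world_size == num_buckets, one task per bucket
    mine = [f for f in files if f.endswith(f"bucket_{bucket:03d}")]
    assert len(mine) == 2  # exactly one signature file per stage-1 task
    print(f"bucket {bucket}: {mine}")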
@@ -169,19 +172,21 @@ def __init__(
         self,
         input_folder: BaseInputDataFolder,
         output_folder: BaseOutputDataFolder,
-        buckets: int = DEFAULT_NR_BUCKETS,
+        num_buckets: int = DEFAULT_NR_BUCKETS,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.input_folder = input_folder
         self.output_folder = output_folder
+        self.num_buckets = num_buckets
 
     def set_up_dl_locks(self, dl_lock, up_lock):
         self.input_folder.set_lock(dl_lock)
         self.output_folder.set_lock(up_lock)
 
     def __call__(self, data: DocumentsPipeline, bucket: int = 0, world_size: int = 1):
         dup_files = self.input_folder.list_files(extension=".dups")
+        assert len(dup_files) == self.num_buckets, "There should be exactly one .dups file per bucket"
         union_set = {}
 
         def parent(x):
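Note: union_set backs a disjoint-set (union-find) structure over (file_id, doc_id) nodes; the diff truncates parent(x). A self-contained sketch of the standard structure it implies — the union order, path compression, and keeping the smallest pair as cluster root are assumptions beyond the visible lines:

union_set = {}  # maps each (file_id, doc_id) node to its parent node

def parent(x):
    # find the root of x, creating singletons lazily and compressing the path
    if x not in union_set or union_set[x] == x:
        union_set[x] = x
        return x
    union_set[x] = parent(union_set[x])
    return union_set[x]

def union(a, b):
    # attach the larger root under the smaller one, so each cluster keeps its
    # smallest (file, doc) pair and every other member gets marked for removal
    ra, rb = parent(a), parent(b)
    if ra != rb:
        union_set[max(ra, rb)] = min(ra, rb)

for a, b in [((0, 1), (0, 7)), ((0, 7), (2, 3))]:  # made-up duplicate pairs
    union(a, b)

print([n for n in sorted(union_set) if parent(n) != n])  # [(0, 7), (2, 3)] removed; (0, 1) kept as root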
Expand All @@ -200,5 +205,45 @@ def parent(x):
         for node, p in sorted(union_set.items()):
             if node != p:
                 file, doc = node
-                self.output_folder.open(f"{file:06d}", mode="wb").write(struct.pack("<I", doc))
+                self.output_folder.open(f"{file:06d}.remove", mode="wb").write(struct.pack("<I", doc))
         self.output_folder.close()
+
+
+class MinhashDedupFilter(PipelineStep):
+    type = "🫂 - DEDUP"
+    name = "🎯 MinHash stage 4"
+
+    def __init__(
+        self,
+        input_folder: BaseInputDataFolder,
+        output_folder: BaseOutputDataFolder,
+        num_buckets: int = DEFAULT_NR_BUCKETS,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.data_folder = input_folder
+        self.output_folder = output_folder
+        self.num_buckets = num_buckets
+
+    def set_up_dl_locks(self, dl_lock, up_lock):
+        self.data_folder.set_lock(dl_lock)
+
+    def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
+        remove_data = self.data_folder.get_files_shard(rank, world_size)
+        assert len(remove_data) == 1, f"Must have exactly one .remove file per task. Found {len(remove_data)} files."
+
+        with remove_data[0].open_binary() as f:
+
+            def get_next():
+                data = f.read(4)
+                if data:
+                    return struct.unpack("<I", data)[0]
+
+            next_removal = get_next()
+            for idx, doc in enumerate(data):
+                self.stat_update(StatHints.total)
+                if next_removal == idx:
+                    # to remove
+                    next_removal = get_next()
+                    continue
+                yield doc
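Note: each .remove file is a flat run of little-endian uint32 document indices, ascending because stage 3 iterates sorted(union_set.items()). The filter above streams it in lockstep with the document stream; the sketch below decodes the same format eagerly into a set for illustration (BytesIO stands in for a real file; the indices are made up):

import struct
from io import BytesIO

f = BytesIO(struct.pack("<3I", 4, 9, 11))  # hypothetical .remove payload

def read_removals(f):
    # yield uint32 indices until the file is exhausted
    while True:
        chunk = f.read(4)
        if not chunk:
            return
        yield struct.unpack("<I", chunk)[0]

removals = set(read_removals(f))
docs = [f"doc{i}" for i in range(13)]
kept = [d for i, d in enumerate(docs) if i not in removals]
print(len(kept))  # 10 of 13 documents survive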

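Note: the filter's assert also pins down its sharding: world_size must be chosen so that get_files_shard(rank, world_size) lands exactly one .remove file on each task. An illustrative round-robin version (the real helper's strategy is an assumption):

files = [f"{i:06d}.remove" for i in range(8)]  # hypothetical stage-3 outputs

def get_files_shard(files, rank, world_size):
    # round-robin sharding; datatrove's actual helper may differ
    return files[rank::world_size]

for rank in range(8):  # world_size == number of .remove files
    assert len(get_files_shard(files, rank, world_size=8)) == 1
print("each of the 8 tasks sees exactly one .remove file")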