Merge pull request #12 from huggingface/deduplication
Deduplication
alexchapeaux authored Jul 17, 2023
2 parents 0b51530 + 4fa0562 commit 980d85e
Showing 4 changed files with 307 additions and 0 deletions.
70 changes: 70 additions & 0 deletions examples/sentence_deduplication_example.py
@@ -0,0 +1,70 @@
import os

from datatrove.executor.base import PipelineExecutor
from datatrove.executor.local import LocalPipelineExecutor
from datatrove.io import LocalInputDataFolder, LocalOutputDataFolder
from datatrove.pipeline.dedup import SentenceDedupFilter, SentenceDedupSignature, SentenceFindDedups
from datatrove.pipeline.extractors import Trafilatura
from datatrove.pipeline.filters import GopherQualityFilter, LanguageFilter
from datatrove.pipeline.readers import JsonlReader, WarcReader
from datatrove.pipeline.writers.jsonl import JsonlWriter
from datatrove.utils.typeshelper import Languages


"""
example on how to use sentence-deduplication. sentence-deduplication implements deduplication as in:
https://jmlr.org/papers/v21/20-074.html
'To deduplicate the data set, we discarded all but one of any three-sentence span
occurring more than once in the data set.'
to run deduplication we need to run three different pipelines,
pipeline 1:
implements usual extraction + quality filtering, it ends with SentenceDedupSignature, preprended by a writer.
pipeline 2:
implements only SentenceFindDedups
pipeline 3:
implements SentenceDedupFilter prepended by a reader of the same writer-kind used during stage 1. after the
SentenceDedupFilter.
"""


def run_example():
    pipeline_1 = [
        WarcReader(data_folder=LocalInputDataFolder(path=f"{os.getcwd()}/warc/"), limit=1000),
        Trafilatura(),
        GopherQualityFilter(min_stop_words=0),
        LanguageFilter(language_threshold=0.5, languages=(Languages.english,)),
        JsonlWriter(LocalOutputDataFolder(path=f"{os.getcwd()}/intermediate/")),
        SentenceDedupSignature(output_folder=LocalOutputDataFolder(path=f"{os.getcwd()}/c4/")),
    ]

    pipeline_2 = [
        SentenceFindDedups(
            data_folder=LocalInputDataFolder(path=f"{os.getcwd()}/c4/", extension="c4_dup"),
            output_folder=LocalOutputDataFolder(path=f"{os.getcwd()}/c4/"),
        )
    ]

    pipeline_3 = [
        JsonlReader(data_folder=LocalInputDataFolder(path=f"{os.getcwd()}/intermediate/")),
        SentenceDedupFilter(data_folder=LocalInputDataFolder(path=f"{os.getcwd()}/c4/", extension=".c4_dup")),
    ]

    executor_1: PipelineExecutor = LocalPipelineExecutor(
        pipeline=pipeline_1, workers=4, max_concurrent_uploads=1, tasks=4
    )

    executor_2: PipelineExecutor = LocalPipelineExecutor(
        pipeline=pipeline_2, workers=1, max_concurrent_uploads=1, tasks=1
    )

    executor_3: PipelineExecutor = LocalPipelineExecutor(
        pipeline=pipeline_3, workers=4, max_concurrent_uploads=1, tasks=4
    )

    print(executor_1.run())
    print(executor_2.run())
    print(executor_3.run())


run_example()
1 change: 1 addition & 0 deletions src/datatrove/pipeline/dedup/__init__.py
@@ -0,0 +1 @@
from .sentence_dedup import SentenceDedupFilter, SentenceDedupSignature, SentenceFindDedups
194 changes: 194 additions & 0 deletions src/datatrove/pipeline/dedup/sentence_dedup.py
@@ -0,0 +1,194 @@
"""
'To deduplicate the data set, we discarded all but one of any three-sentence span
occurring more than once in the data set.'
from: https://jmlr.org/papers/volume21/20-074/20-074.pdf (C4)
# get hashes for each doc and write them down
"""
import heapq
import struct
from dataclasses import dataclass
from typing import Generator

from nltk.tokenize import sent_tokenize, word_tokenize

from datatrove.data import Document, DocumentsPipeline
from datatrove.io import BaseInputDataFolder, BaseOutputDataFolder, InputDataFile
from datatrove.pipeline.base import PipelineStep
from datatrove.utils.typeshelper import StatHints

from .utils import merge_docs, simplify_content, str_hash


@dataclass
class HashSig:
    hash_value: int
    doc_id: int
    sent_id: int
    file_id: int = None

    # priority queue accepts anything that is sortable
    def __lt__(self, other):
        return (self.hash_value, self.file_id, self.doc_id, self.sent_id) < (
            other.hash_value,
            other.file_id,
            other.doc_id,
            other.sent_id,
        )
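
    # Ordering example with made-up values: signatures compare by hash first, so equal hashes from
    # different files become adjacent when merged in stage 2, e.g.
    # HashSig(10, doc_id=4, sent_id=0, file_id=1) < HashSig(10, doc_id=0, sent_id=2, file_id=2) < HashSig(11, 0, 0, 0)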


class SentenceDedupSignature(PipelineStep):
type = "🫂 - DEDUP"
name = "💥 sentence-deduplication stage 1"

def __init__(self, output_folder: BaseOutputDataFolder, n_sentences: int = 3, stage_2_workers: int = 1, **kwargs):
super().__init__(**kwargs)
self.output_folder = output_folder
self.n_sentences = n_sentences
self.stage_2_workers = stage_2_workers
self.signatures = []

def set_up_dl_locks(self, dl_lock, up_lock):
self.output_folder.set_lock(up_lock)

def save_hashes(self, rank: int):
self.signatures.sort()

f = self.output_folder.open(f"{rank:05d}.c4_sig", mode="wb")
for hs in self.signatures:
f.file_handler.write(struct.pack("<Q", hs.hash_value))
f.file_handler.write(struct.pack("<I", hs.doc_id))
f.file_handler.write(struct.pack("<H", hs.sent_id))
self.output_folder.close()

def get_hashes(self, doc: Document, doc_idx: int) -> list[None] | list[HashSig]:
# todo use language id metadata in sent_tokenize
sentences = sent_tokenize(doc.content)
if len(sentences) < self.n_sentences:
return []

sentences_tokens = [simplify_content(sent) for sent in sentences]
n_sent_grams: list = [
" ".join(sentences_tokens[i : i + self.n_sentences])
for i in range(len(sentences_tokens) - self.n_sentences + 1)
]
hashes = [
HashSig(
hash_value=str_hash(n_sent_gram),
doc_id=doc_idx,
sent_id=sentence_idx,
)
for sentence_idx, n_sent_gram in enumerate(n_sent_grams)
]

return hashes

def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
for doc_idx, doc in enumerate(data):
self.stat_update(StatHints.total)
self.signatures.extend(self.get_hashes(doc, doc_idx))
self.save_hashes(rank)
self.output_folder.close()
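

# Informal note on the .c4_sig files (not a formal spec): each record written by save_hashes above is
# 14 bytes, little-endian -- a uint64 hash_value, a uint32 doc_id and a uint16 sent_id -- and read_sigs
# below decodes them back in that same order with struct.unpack("<Q" / "<I" / "<H", ...).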


def read_sigs(file: InputDataFile, file_id: int) -> Generator[HashSig, None, None]:
    with file.open(binary=True) as f:
        while True:
            x = {}
            for t, b, k in [("Q", 8, "hash_value"), ("I", 4, "doc_id"), ("H", 2, "sent_id")]:
                by = f.read(b)
                if not by:
                    return
                x[k] = struct.unpack(f"<{t}", by)[0]
            yield HashSig(file_id=file_id, **x)


class SentenceFindDedups(PipelineStep):
type = "🫂 - DEDUP"
name = "💥 sentence-deduplication stage 2"

def __init__(self, data_folder: BaseInputDataFolder, output_folder: BaseOutputDataFolder, **kwargs):
super().__init__(**kwargs)
self.data_folder = data_folder
self.output_folder = output_folder

def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
sig_files = self.data_folder.list_files(".c4_sig")
sig_readers = [read_sigs(file, file_i) for file_i, file in enumerate(sig_files)]

pq = [next(sig_reader) for sig_reader in sig_readers]
heapq.heapify(pq)

last = None
while pq:
v: HashSig = heapq.heappop(pq)
if last == v.hash_value:
f = self.output_folder.open(f"{v.file_id:05d}.c4_dup", mode="wb")
f.file_handler.write(struct.pack("<I", v.doc_id))
f.file_handler.write(struct.pack("<H", v.sent_id))
last = v.hash_value
try:
new_v = next(sig_readers[v.file_id])
except StopIteration:
new_v = None
if new_v:
heapq.heappush(pq, new_v)
self.output_folder.close()
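

# The loop in SentenceFindDedups.__call__ is essentially a k-way merge of the per-task sorted .c4_sig files:
# since every file is sorted by hash, popping from the heap yields hashes in globally sorted order, so
# duplicate hashes are always adjacent and a single `last` value suffices to detect them. A conceptually
# similar (but less explicit) formulation, shown purely as an illustration:
#
#   last = None
#   for v in heapq.merge(*sig_readers):
#       if v.hash_value == last:
#           ...  # record (v.doc_id, v.sent_id) as a duplicated span
#       last = v.hash_value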


def read_dups(file: InputDataFile) -> Generator[tuple, None, None]:
    with file.open(binary=True) as f:
        while True:
            x = []
            for t, b in [("I", 4), ("H", 2)]:
                by = f.read(b)
                if not by:
                    return
                x.append(struct.unpack(f"<{t}", by)[0])
            yield tuple(x)


class SentenceDedupFilter(PipelineStep):
type = "🫂 - DEDUP"
name = "💥 sentence-deduplication stage 3"

def __init__(
self,
data_folder: BaseInputDataFolder,
min_doc_words: int = 50,
**kwargs,
):
super().__init__(**kwargs)
self.data_folder = data_folder
self.min_doc_words = min_doc_words

def filter(self, doc: Document, du_lines: set = None):
sentences = sent_tokenize(doc.content)
# todo find a way to keep skip lines as in the original text
doc.content = " ".join([sent for idx, sent in enumerate(sentences) if not du_lines or idx not in du_lines])
if len(word_tokenize(doc.content)) > self.min_doc_words:
return True
return False

def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline:
"""
step method for Filters.
Drops documents that if .filter() is False
@param datapipe: input DocumentsPipeline
@return: DocumentsPipeline
"""
files = self.data_folder.get_files_shard(rank, world_size)
assert len(files) == 1
du_file = merge_docs(sorted(read_dups(files[0])))
for idx, doc in enumerate(data):
self.stat_update(StatHints.total)
with self.time_stats_manager:
is_kept = self.filter(doc, du_lines=du_file.get(idx))
if is_kept:
yield doc
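

# Worked example for SentenceDedupFilter.filter with made-up values: if sent_tokenize splits a document into
# ["A.", "B.", "C.", "D."] and du_lines == {1, 2}, the rewritten content becomes "A. D.", and the document is
# kept only if word_tokenize still finds more than min_doc_words tokens in it.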
42 changes: 42 additions & 0 deletions src/datatrove/pipeline/dedup/utils.py
@@ -0,0 +1,42 @@
import hashlib
import re
import string

import numpy as np


# taken from
# https://github.com/Cerebras/modelzoo/blob/main/modelzoo/transformers/data_processing/slimpajama/dedup/to_hash.py
def simplify_content(text: str):
    # TODO replace special chars: e' -> e
    # lower cased
    text = text.lower()
    # remove punctuation
    text = text.translate(str.maketrans("", "", string.punctuation))
    # remove consecutive spaces, newlines, tabs in the middle and in the beginning / end
    text = re.sub(r"\s+", " ", text.strip())
    return text
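

# Example of the normalisation above: simplify_content("Hello,   World!\n") == "hello world"
# (lower-cased, punctuation removed, whitespace collapsed).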


def _b2i(b: bytes) -> int:
    return np.frombuffer(b, dtype=np.uint64, count=1, offset=0).item(0)


def str_hash(s: str) -> int:
    h = hashlib.sha1(bytes(s, encoding="utf-8"))
    return _b2i(h.digest())


def merge_docs(sen_list, n_sentences: int = 3) -> dict:
    # TODO IMPROVE!
    def to_sentences(idx: int):
        return (idx + i for i in range(n_sentences))

    # every duplicated match covers n_sentences consecutive sentences, so expand each (doc_id, sent_id)
    # pair into the full span of sentence indices to remove
    new_l = [[sen_list[0][0], set(to_sentences(sen_list[0][1]))]]
    for x in sen_list[1:]:
        if x[0] == new_l[-1][0]:
            new_l[-1][1].update(to_sentences(x[1]))
        else:
            new_l.append([x[0], set(to_sentences(x[1]))])

    return {x[0]: x[1] for x in new_l}
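

# Example (with the default n_sentences=3 and the span expansion above, using made-up indices):
#   merge_docs([(0, 2), (0, 6), (3, 1)]) == {0: {2, 3, 4, 6, 7, 8}, 3: {1, 2, 3}}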
