diff --git a/src/datatrove/pipeline/dedup/minhash.py b/src/datatrove/pipeline/dedup/minhash.py index ca682e8f..2f0a6f82 100644 --- a/src/datatrove/pipeline/dedup/minhash.py +++ b/src/datatrove/pipeline/dedup/minhash.py @@ -35,6 +35,52 @@ SENTINEL = (1 << 32) - 1 +def affine_61(a, b, x): + """Computes the affine transform mod mersenne prime: (a*x + b) % (2^61 - 1) + + Args: + a: int in range [0, (1<<61) - 1) + b: int in range [0, (1<<61) - 1) + x: int in range [0, (1<<61) - 1) + """ + # This expects 64-bit uints less than (1<> H + x1, x2 = x & Hm, x >> H + + # ret = (p11 + p12 + p22 + b) % Pm + # p11 = a1 * x1 + # ~ a1 * x1 % Pm + # p12 = (a2 * x1 + x2 * a1)*(1<<32) + # ~ (a2 * x1 + x2 * a1)*(1<<32) % Pm + # p22 = a2*x2*(1<<64) + # ~ (a2*x2 % Pm) * ((1<<64) % Pm) + # ~ (a2*x2*8) % Pm + + # Multiply low bits with low bits (No uint overflow) + # Take modulus Pm (and add b-term here due to broadcasting reason) + p11 = a1 * x1 + ret = b + (p11 & Pm) + (p11 >> P) + + # Multiply low bits with high bits + # Take "modulus" Pm accounting for 32-bit shift. + p12 = a1 * x2 + a2 * x1 + ret += ((p12 << H) & Pm) + (p12 >> (P - H)) + + # Multiply high bits with high bits + p22 = a2 * x2 + # Take "modulus" Pm accounting for 64-bit shift. + ret += ((p22 << 3) & Pm) + (p22 >> (P - 3)) + return ret % Pm + + @dataclass class MinhashConfig: """Configuration for Min-Hash deduplication @@ -170,7 +216,7 @@ def get_signature(self, shingles: np.ndarray) -> list[list[int]]: list (num buckets) of lists of integers (hashes) """ a, b = self.parameters - phv = (shingles * a + b) % _mersenne_prime + phv = affine_61(a, b, shingles) if self.config.hash_config.precision == 32: phv = np.bitwise_and(phv, self.config.hash_config.max) return [