Skip to content

Commit d401333

Browse files
Support the new minhash 25.02 api (#445)
Signed-off-by: Praateek <[email protected]>
1 parent 4fb7f54 commit d401333

File tree

2 files changed

+33
-16
lines changed

2 files changed

+33
-16
lines changed

nemo_curator/_compat.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,15 @@
3939
except (ImportError, TypeError):
4040
CURRENT_CUDF_VERSION = parse_version("24.10.0")
4141

42-
# TODO remove this once 24.12.0 becomes the base version of cudf in nemo-curator
43-
MINHASH_PERMUTED_AVAILABLE = CURRENT_CUDF_VERSION >= parse_version("24.12.0") or (
44-
CURRENT_CUDF_VERSION.is_prerelease
45-
and CURRENT_CUDF_VERSION.base_version >= "24.12.0"
42+
# TODO remove this once 25.02 becomes the base version of cudf in nemo-curator
43+
44+
# cuDF < 24.12 only has the old minhash(txt) API, which was deprecated in favor of
45+
# minhash(a, b) in 25.02 (24.12 introduced that API under the name minhash_permuted(a, b)).
46+
# True when the installed cuDF predates 24.12 and so only has the old API.
# Compare the integer release tuples, not the `base_version` strings: strings
# compare lexicographically, so "24.8.0" would sort *after* "24.12" and an old
# cuDF would be misdetected as having the new API. `release` also ignores
# pre-release/dev suffixes, just like `base_version` did.
MINHASH_DEPRECATED_API = CURRENT_CUDF_VERSION.release < (24, 12)
49+
# True only for the cuDF 24.12 series (major 24, minor 12).
MINHASH_PERMUTED_AVAILABLE = (
    CURRENT_CUDF_VERSION.major,
    CURRENT_CUDF_VERSION.minor,
) == (24, 12)
4752

4853
# TODO: remove when dask min version gets bumped

nemo_curator/modules/fuzzy_dedup.py

+24-12
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from dask.utils import M
3636
from tqdm import tqdm
3737

38-
from nemo_curator._compat import MINHASH_PERMUTED_AVAILABLE
38+
from nemo_curator._compat import MINHASH_DEPRECATED_API, MINHASH_PERMUTED_AVAILABLE
3939
from nemo_curator.datasets import DocumentDataset
4040
from nemo_curator.log import create_logger
4141
from nemo_curator.modules.config import FuzzyDuplicatesConfig
@@ -98,15 +98,17 @@ def __init__(
9898
"""
9999
self.num_hashes = num_hashes
100100
self.char_ngram = char_ngrams
101-
if MINHASH_PERMUTED_AVAILABLE:
101+
if MINHASH_DEPRECATED_API:
102+
self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)
103+
else:
102104
self.seeds = self.generate_hash_permutation_seeds(
103105
bit_width=64 if use_64bit_hash else 32,
104106
n_permutations=self.num_hashes,
105107
seed=seed,
106108
)
107-
else:
108-
self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)
109+
109110
self.minhash_method = self.minhash64 if use_64bit_hash else self.minhash32
111+
110112
self.id_field = id_field
111113
self.text_field = text_field
112114

@@ -171,7 +173,7 @@ def minhash32(
171173
if not isinstance(ser, cudf.Series):
172174
raise TypeError("Expected data of type cudf.Series")
173175

174-
if not MINHASH_PERMUTED_AVAILABLE:
176+
if MINHASH_DEPRECATED_API:
175177
warnings.warn(
176178
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
177179
"or later for improved performance. "
@@ -184,9 +186,14 @@ def minhash32(
184186
seeds_a = cudf.Series(seeds[:, 0], dtype="uint32")
185187
seeds_b = cudf.Series(seeds[:, 1], dtype="uint32")
186188

187-
return ser.str.minhash_permuted(
188-
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
189-
)
189+
if MINHASH_PERMUTED_AVAILABLE:
190+
return ser.str.minhash_permuted(
191+
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
192+
)
193+
else:
194+
return ser.str.minhash(
195+
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
196+
)
190197

191198
def minhash64(
192199
self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int
@@ -196,7 +203,7 @@ def minhash64(
196203
"""
197204
if not isinstance(ser, cudf.Series):
198205
raise TypeError("Expected data of type cudf.Series")
199-
if not MINHASH_PERMUTED_AVAILABLE:
206+
if MINHASH_DEPRECATED_API:
200207
warnings.warn(
201208
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
202209
"or later for improved performance. "
@@ -209,9 +216,14 @@ def minhash64(
209216
seeds_a = cudf.Series(seeds[:, 0], dtype="uint64")
210217
seeds_b = cudf.Series(seeds[:, 1], dtype="uint64")
211218

212-
return ser.str.minhash64_permuted(
213-
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
214-
)
219+
if MINHASH_PERMUTED_AVAILABLE:
220+
return ser.str.minhash64_permuted(
221+
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
222+
)
223+
else:
224+
return ser.str.minhash64(
225+
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
226+
)
215227

216228
def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]:
217229
"""

0 commit comments

Comments
 (0)