Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support the new minhash 25.02 api #445

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions nemo_curator/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,15 @@
except (ImportError, TypeError):
CURRENT_CUDF_VERSION = parse_version("24.10.0")

# TODO remove this once 24.12.0 becomes the base version of cudf in nemo-curator
MINHASH_PERMUTED_AVAILABLE = CURRENT_CUDF_VERSION >= parse_version("24.12.0") or (
CURRENT_CUDF_VERSION.is_prerelease
and CURRENT_CUDF_VERSION.base_version >= "24.12.0"
# TODO remove this once 25.02 becomes the base version of cudf in nemo-curator

# minhash in < 24.12 used to have a minhash(txt) api which was deprecated in favor of
# minhash(a, b) in 25.02 (in 24.12, minhash_permuted(a,b) was introduced)
MINHASH_DEPRECATED_API = (
CURRENT_CUDF_VERSION.base_version < parse_version("24.12").base_version
)
MINHASH_PERMUTED_AVAILABLE = (CURRENT_CUDF_VERSION.major == 24) & (
CURRENT_CUDF_VERSION.minor == 12
Comment on lines -42 to +50
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait I'm confused, why are we doing it for 24.10 and 24.12 instead of for 24.12 and 25.02?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the code block here helps

if MINHASH_DEPRECATED_API:
    # this is 24.10; the variable is named MINHASH_DEPRECATED_API
    # because cudf's minhash now needs four args rather than two
    return ser.str.minhash(seeds=seeds, width=char_ngram)
else:
    if MINHASH_PERMUTED_AVAILABLE: 
        # this is 24.12 because in this version the four arg function is called minhash_permuted
        return ser.str.minhash_permuted(
            a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
        )
    else:
        # this is the 25.02 case, where it's neither the deprecated case nor is minhash_permuted available
        return ser.str.minhash(a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok I see, thanks! So are we waiting until the next NeMo Curator release (which will have 24.12 as the stable RAPIDS version) before removing MINHASH_DEPRECATED_API?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's correct! And with 25.02+ we wouldn't need any of these variables.

)

# TODO: remove when dask min version gets bumped
Expand Down
36 changes: 24 additions & 12 deletions nemo_curator/modules/fuzzy_dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from dask.utils import M
from tqdm import tqdm

from nemo_curator._compat import MINHASH_PERMUTED_AVAILABLE
from nemo_curator._compat import MINHASH_DEPRECATED_API, MINHASH_PERMUTED_AVAILABLE
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.config import FuzzyDuplicatesConfig
Expand Down Expand Up @@ -98,15 +98,17 @@ def __init__(
"""
self.num_hashes = num_hashes
self.char_ngram = char_ngrams
if MINHASH_PERMUTED_AVAILABLE:
if MINHASH_DEPRECATED_API:
self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)
else:
self.seeds = self.generate_hash_permutation_seeds(
bit_width=64 if use_64bit_hash else 32,
n_permutations=self.num_hashes,
seed=seed,
)
else:
self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)

self.minhash_method = self.minhash64 if use_64bit_hash else self.minhash32

self.id_field = id_field
self.text_field = text_field

Expand Down Expand Up @@ -171,7 +173,7 @@ def minhash32(
if not isinstance(ser, cudf.Series):
raise TypeError("Expected data of type cudf.Series")

if not MINHASH_PERMUTED_AVAILABLE:
if MINHASH_DEPRECATED_API:
warnings.warn(
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
"or later for improved performance. "
Expand All @@ -184,9 +186,14 @@ def minhash32(
seeds_a = cudf.Series(seeds[:, 0], dtype="uint32")
seeds_b = cudf.Series(seeds[:, 1], dtype="uint32")

return ser.str.minhash_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
if MINHASH_PERMUTED_AVAILABLE:
return ser.str.minhash_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
else:
return ser.str.minhash(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)

def minhash64(
self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int
Expand All @@ -196,7 +203,7 @@ def minhash64(
"""
if not isinstance(ser, cudf.Series):
raise TypeError("Expected data of type cudf.Series")
if not MINHASH_PERMUTED_AVAILABLE:
if MINHASH_DEPRECATED_API:
warnings.warn(
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
"or later for improved performance. "
Expand All @@ -209,9 +216,14 @@ def minhash64(
seeds_a = cudf.Series(seeds[:, 0], dtype="uint64")
seeds_b = cudf.Series(seeds[:, 1], dtype="uint64")

return ser.str.minhash64_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
if MINHASH_PERMUTED_AVAILABLE:
return ser.str.minhash64_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
else:
return ser.str.minhash64(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)

def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]:
"""
Expand Down
14 changes: 2 additions & 12 deletions tests/test_fuzzy_dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def gpu_client(self, request):
[
(5, 0.5, [[4, -1]]),
(10, 0.39, [[4, -1], [1, 2]]),
(3, 0.3, [[4, -1], [1, 2, 300]]),
(15, 0.3, [[4, -1], [1, 2, 300]]),
],
)
def test_fuzzy_dedup(
Expand All @@ -329,11 +329,6 @@ def test_fuzzy_dedup(
duplicate_docs,
tmpdir,
):
if not use_64_bit_hash and jaccard_threshold == 0.3:
pytest.xfail(
"TODO: RAPIDS 24.12 fails with parameters 3-0.3-duplicate_docs2-False"
)

print(self.client)
# Dedup might fail when indices per partition do not start from 0
fuzzy_dedup_data.df = fuzzy_dedup_data.df.reset_index(drop=True)
Expand Down Expand Up @@ -477,17 +472,12 @@ def test_num_anchors(self, large_fuzzy_dedup_data, num_anchors, tmpdir):
# Duplicated docs estimated from true_jaccard values
[
(10, [[4, -1], [1, 2, 300]]),
(3, [[4, -1], [1, 2, 300]]),
(5, [[4, -1], [1, 2, 300]]),
],
)
def test_no_fp_check(
self, fuzzy_dedup_data, use_64_bit_hash, num_buckets, duplicate_docs, tmpdir
):
if not use_64_bit_hash and num_buckets == 3:
pytest.xfail(
"TODO: RAPIDS 24.12 fails with parameters 3-duplicate_docs1-False"
)

config = FuzzyDuplicatesConfig(
cache_dir=tmpdir,
id_field="id",
Expand Down
Loading