Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support the new minhash 25.02 api #445

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions nemo_curator/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,15 @@
except (ImportError, TypeError):
CURRENT_CUDF_VERSION = parse_version("24.10.0")

# TODO remove this once 24.12.0 becomes the base version of cudf in nemo-curator
MINHASH_PERMUTED_AVAILABLE = CURRENT_CUDF_VERSION >= parse_version("24.12.0") or (
CURRENT_CUDF_VERSION.is_prerelease
and CURRENT_CUDF_VERSION.base_version >= "24.12.0"
# TODO remove this once 25.02 becomes the base version of cudf in nemo-curator

# minhash in < 24.12 used to have a minhash(txt) api which was deprecated in favor of
# minhash(a, b) in 25.02 (in 24.12, minhash_permuted(a,b) was introduced)
MINHASH_DEPRECATED_API = (
CURRENT_CUDF_VERSION.base_version < parse_version("24.12").base_version
)
MINHASH_PERMUTED_AVAILABLE = (CURRENT_CUDF_VERSION.major == 24) & (
CURRENT_CUDF_VERSION.minor == 12
Comment on lines -42 to +50
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait I'm confused, why are we doing it for 24.10 and 24.12 instead of for 24.12 and 25.02?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the code block here helps

if MINHASH_DEPRECATED_API:
     # this is 24.10, the variable is named minhash deprecated api, 
     # because now cudf only needs four args, rather than two args
    return ser.str.minhash(seeds=seeds, width=char_ngram)
else:
    if MINHASH_PERMUTED_AVAILABLE: 
        # this is 24.12 because in this version the four arg function is called minhash_permuted
        return ser.str.minhash_permuted(
            a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
        )
    else:
        # this is the 25.02 case where it's neither the deprecated case neither permuted is available
        return ser.str.minhash(a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok I see, thanks! So are we waiting until the next NeMo Curator release (which will have 24.12 as the stable RAPIDS version) before removing MINHASH_DEPRECATED_API?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that's correct! And 25.02+ we wouldn't need any of these variables

)

# TODO: remove when dask min version gets bumped
Expand Down
36 changes: 24 additions & 12 deletions nemo_curator/modules/fuzzy_dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from dask.utils import M
from tqdm import tqdm

from nemo_curator._compat import MINHASH_PERMUTED_AVAILABLE
from nemo_curator._compat import MINHASH_DEPRECATED_API, MINHASH_PERMUTED_AVAILABLE
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.config import FuzzyDuplicatesConfig
Expand Down Expand Up @@ -98,15 +98,17 @@ def __init__(
"""
self.num_hashes = num_hashes
self.char_ngram = char_ngrams
if MINHASH_PERMUTED_AVAILABLE:
if MINHASH_DEPRECATED_API:
self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)
else:
self.seeds = self.generate_hash_permutation_seeds(
bit_width=64 if use_64bit_hash else 32,
n_permutations=self.num_hashes,
seed=seed,
)
else:
self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)

self.minhash_method = self.minhash64 if use_64bit_hash else self.minhash32

self.id_field = id_field
self.text_field = text_field

Expand Down Expand Up @@ -171,7 +173,7 @@ def minhash32(
if not isinstance(ser, cudf.Series):
raise TypeError("Expected data of type cudf.Series")

if not MINHASH_PERMUTED_AVAILABLE:
if MINHASH_DEPRECATED_API:
warnings.warn(
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
"or later for improved performance. "
Expand All @@ -184,9 +186,14 @@ def minhash32(
seeds_a = cudf.Series(seeds[:, 0], dtype="uint32")
seeds_b = cudf.Series(seeds[:, 1], dtype="uint32")

return ser.str.minhash_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
if MINHASH_PERMUTED_AVAILABLE:
return ser.str.minhash_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
else:
return ser.str.minhash(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)

def minhash64(
self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int
Expand All @@ -196,7 +203,7 @@ def minhash64(
"""
if not isinstance(ser, cudf.Series):
raise TypeError("Expected data of type cudf.Series")
if not MINHASH_PERMUTED_AVAILABLE:
if MINHASH_DEPRECATED_API:
warnings.warn(
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
"or later for improved performance. "
Expand All @@ -209,9 +216,14 @@ def minhash64(
seeds_a = cudf.Series(seeds[:, 0], dtype="uint64")
seeds_b = cudf.Series(seeds[:, 1], dtype="uint64")

return ser.str.minhash64_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
if MINHASH_PERMUTED_AVAILABLE:
return ser.str.minhash64_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
else:
return ser.str.minhash64(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)

def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]:
"""
Expand Down
Loading