From 0f32e99e3f6e31d4165bac7fdbd7c305cd67b014 Mon Sep 17 00:00:00 2001 From: Mike Sarahan Date: Fri, 15 Nov 2024 16:48:51 -0600 Subject: [PATCH 1/3] add telemetry (#6126) Enables telemetry during cuml's build process. This is currently done by parsing Github Actions run log metadata, and should have no impact on build/test times Implement OpenTelemetry, as described in https://github.com/rapidsai/build-infra/issues/139 Authors: - Mike Sarahan (https://github.com/msarahan) Approvers: - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cuml/pull/6126 --- .github/workflows/pr.yaml | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 1ca4589500..f9e1f066b5 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -10,6 +10,7 @@ concurrency: cancel-in-progress: true jobs: + # Please keep pr-builder as the top job here pr-builder: needs: - changed-files @@ -23,6 +24,7 @@ jobs: - conda-python-tests-dask - conda-notebook-tests - docs-build + - telemetry-setup - wheel-build-cuml - wheel-tests-cuml - devcontainer @@ -31,8 +33,17 @@ jobs: if: always() with: needs: ${{ toJSON(needs) }} + telemetry-setup: + runs-on: ubuntu-latest + continue-on-error: true + env: + OTEL_SERVICE_NAME: "pr-cuml" + steps: + - name: Telemetry setup + uses: rapidsai/shared-actions/telemetry-dispatch-stash-base-env-vars@main changed-files: secrets: inherit + needs: telemetry-setup uses: rapidsai/shared-workflows/.github/workflows/changed-files.yaml@branch-24.12 with: files_yaml: | @@ -66,11 +77,12 @@ jobs: - '!thirdparty/LICENSES/**' checks: secrets: inherit + needs: telemetry-setup uses: rapidsai/shared-workflows/.github/workflows/checks.yaml@branch-24.12 with: enable_check_generated_files: false ignored_pr_jobs: >- - optional-job-conda-python-tests-cudf-pandas-integration + optional-job-conda-python-tests-cudf-pandas-integration telemetry-summarize 
clang-tidy: needs: checks secrets: inherit @@ -173,6 +185,7 @@ jobs: build_type: pull-request script: ci/test_wheel.sh devcontainer: + needs: telemetry-setup secrets: inherit uses: rapidsai/shared-workflows/.github/workflows/build-in-devcontainer.yaml@branch-24.12 with: @@ -183,3 +196,18 @@ jobs: sccache -z; build-all --verbose; sccache -s; + + telemetry-summarize: + runs-on: ubuntu-latest + needs: pr-builder + if: always() + continue-on-error: true + steps: + - name: Load stashed telemetry env vars + uses: rapidsai/shared-actions/telemetry-dispatch-load-base-env-vars@main + with: + load_service_name: true + - name: Telemetry summarize + uses: rapidsai/shared-actions/telemetry-dispatch-write-summary@main + with: + cert_concat: "${{ secrets.OTEL_EXPORTER_OTLP_CA_CERTIFICATE }};${{ secrets.OTEL_EXPORTER_OTLP_CLIENT_CERTIFICATE }};${{ secrets.OTEL_EXPORTER_OTLP_CLIENT_KEY }}" From 06958c4db7edbde19a9ea7939d36645930b17e07 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Thu, 21 Nov 2024 11:25:25 -0600 Subject: [PATCH 2/3] Experimental command line interface UX (#6135) PR adds a first version of a command line user experience that covers the following estimators: - Linear Regression, Ridge, Lasso and ElastiNet - Logistic Regression - PCA and tSVD - DBSCAN, KMeans and HDBSCAN - UMAP and TSNE - Nearest Neighbors --------- Co-authored-by: divyegala --- ci/run_cuml_singlegpu_accel_pytests.sh | 7 + ci/test_python_singlegpu.sh | 10 + python/cuml/cuml/cluster/dbscan.pyx | 21 +- python/cuml/cuml/cluster/hdbscan/hdbscan.pyx | 13 + python/cuml/cuml/cluster/kmeans.pyx | 6 +- python/cuml/cuml/decomposition/pca.pyx | 10 + python/cuml/cuml/decomposition/tsvd.pyx | 7 + .../cuml/cuml/experimental/accel/__init__.py | 68 ++++ .../cuml/cuml/experimental/accel/__main__.py | 70 ++++ .../experimental/accel/_wrappers/__init__.py | 34 ++ .../experimental/accel/_wrappers/hdbscan.py | 24 ++ .../experimental/accel/_wrappers/sklearn.py | 129 +++++++ 
.../cuml/experimental/accel/_wrappers/umap.py | 24 ++ .../experimental/accel/estimator_proxy.py | 353 ++++++++++++++++++ python/cuml/cuml/experimental/accel/magics.py | 29 ++ python/cuml/cuml/internals/base.pyx | 71 +++- python/cuml/cuml/internals/global_settings.py | 16 +- python/cuml/cuml/linear_model/elastic_net.pyx | 9 + .../cuml/linear_model/linear_regression.pyx | 6 + .../cuml/linear_model/logistic_regression.pyx | 11 + python/cuml/cuml/linear_model/ridge.pyx | 27 +- python/cuml/cuml/linear_model/ridge_mg.pyx | 3 +- python/cuml/cuml/manifold/t_sne.pyx | 18 +- python/cuml/cuml/manifold/umap.pyx | 2 +- .../cuml/neighbors/kneighbors_classifier.pyx | 11 + .../cuml/neighbors/kneighbors_regressor.pyx | 17 +- .../cuml/cuml/neighbors/nearest_neighbors.pyx | 18 +- .../test_accel_dbscan.py | 91 +++++ .../test_accel_elastic_net.py | 209 +++++++++++ .../test_accel_hdbscan_core.py | 317 ++++++++++++++++ .../test_accel_hdbscan_extended.py | 214 +++++++++++ .../test_accel_kmeans.py | 105 ++++++ .../test_accel_kneighbors_classifier.py | 189 ++++++++++ .../test_accel_kneighbors_regressor.py | 163 ++++++++ .../test_accel_lasso.py | 193 ++++++++++ .../test_accel_linear_regression.py | 59 +++ .../test_accel_logistic_regression.py | 195 ++++++++++ .../test_accel_nearest_neighbors.py | 193 ++++++++++ .../estimators_hyperparams/test_accel_pca.py | 160 ++++++++ .../test_accel_ridge.py | 154 ++++++++ .../estimators_hyperparams/test_accel_tsne.py | 193 ++++++++++ .../estimators_hyperparams/test_accel_tsvd.py | 177 +++++++++ .../estimators_hyperparams/test_accel_umap.py | 172 +++++++++ .../accel/test_basic_estimators.py | 142 +++++++ .../tests/experimental/accel/test_pipeline.py | 147 ++++++++ python/cuml/cuml/tests/test_tsne.py | 2 +- 46 files changed, 4056 insertions(+), 33 deletions(-) create mode 100755 ci/run_cuml_singlegpu_accel_pytests.sh create mode 100644 python/cuml/cuml/experimental/accel/__init__.py create mode 100644 python/cuml/cuml/experimental/accel/__main__.py create 
mode 100644 python/cuml/cuml/experimental/accel/_wrappers/__init__.py create mode 100644 python/cuml/cuml/experimental/accel/_wrappers/hdbscan.py create mode 100644 python/cuml/cuml/experimental/accel/_wrappers/sklearn.py create mode 100644 python/cuml/cuml/experimental/accel/_wrappers/umap.py create mode 100644 python/cuml/cuml/experimental/accel/estimator_proxy.py create mode 100644 python/cuml/cuml/experimental/accel/magics.py create mode 100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_dbscan.py create mode 100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_elastic_net.py create mode 100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_hdbscan_core.py create mode 100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_hdbscan_extended.py create mode 100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_kmeans.py create mode 100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_kneighbors_classifier.py create mode 100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_kneighbors_regressor.py create mode 100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_lasso.py create mode 100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_linear_regression.py create mode 100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_logistic_regression.py create mode 100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_nearest_neighbors.py create mode 100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_pca.py create mode 100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_ridge.py create mode 100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_tsne.py create mode 
100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_tsvd.py create mode 100644 python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_umap.py create mode 100644 python/cuml/cuml/tests/experimental/accel/test_basic_estimators.py create mode 100644 python/cuml/cuml/tests/experimental/accel/test_pipeline.py diff --git a/ci/run_cuml_singlegpu_accel_pytests.sh b/ci/run_cuml_singlegpu_accel_pytests.sh new file mode 100755 index 0000000000..3b13a53ef1 --- /dev/null +++ b/ci/run_cuml_singlegpu_accel_pytests.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# Copyright (c) 2024, NVIDIA CORPORATION. + +# Support invoking run_cuml_singlegpu_pytests.sh outside the script directory +cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../python/cuml/cuml/tests/experimental/accel + +python -m pytest -p cuml.experimental.accel --cache-clear "$@" . diff --git a/ci/test_python_singlegpu.sh b/ci/test_python_singlegpu.sh index 011dffec40..e6bba181e1 100755 --- a/ci/test_python_singlegpu.sh +++ b/ci/test_python_singlegpu.sh @@ -21,6 +21,16 @@ rapids-logger "pytest cuml single GPU" --cov-report=xml:"${RAPIDS_COVERAGE_DIR}/cuml-coverage.xml" \ --cov-report=term + rapids-logger "pytest cuml accelerator" +./ci/run_cuml_singlegpu_accel_pytests.sh \ + --numprocesses=8 \ + --dist=worksteal \ + --junitxml="${RAPIDS_TESTS_DIR}/junit-cuml-accel.xml" \ + --cov-config=../../../../.coveragerc \ + --cov=cuml \ + --cov-report=xml:"${RAPIDS_COVERAGE_DIR}/cuml-accel-coverage.xml" \ + --cov-report=term + rapids-logger "memory leak pytests" ./ci/run_cuml_singlegpu_memleak_pytests.sh \ diff --git a/python/cuml/cuml/cluster/dbscan.pyx b/python/cuml/cuml/cluster/dbscan.pyx index 9480ccccf6..07af1d142c 100644 --- a/python/cuml/cuml/cluster/dbscan.pyx +++ b/python/cuml/cuml/cluster/dbscan.pyx @@ -225,6 +225,19 @@ class DBSCAN(UniversalBase, core_sample_indices_ = CumlArrayDescriptor(order="C") labels_ = CumlArrayDescriptor(order="C") + _hyperparam_interop_translator = { + 
"metric": { + "manhattan": "NotImplemented", + "chebyshev": "NotImplemented", + "minkowski": "NotImplemented", + }, + "algorithm": { + "auto": "brute", + "ball_tree": "NotImplemented", + "kd_tree": "NotImplemented", + }, + } + @device_interop_preparation def __init__(self, *, eps=0.5, @@ -263,7 +276,7 @@ class DBSCAN(UniversalBase, opg that is set to `False` for SG, `True` for OPG (multi-GPU) """ if out_dtype not in ["int32", np.int32, "int64", np.int64]: - raise ValueError("Invalid value for out_dtype. " + raise ValueError(f"Invalid value for out_dtype: {out_dtype}. " "Valid values are {'int32', 'int64', " "np.int32, np.int64}") @@ -422,7 +435,7 @@ class DBSCAN(UniversalBase, @generate_docstring(skip_parameters_heading=True) @enable_device_interop - def fit(self, X, out_dtype="int32", sample_weight=None, + def fit(self, X, y=None, out_dtype="int32", sample_weight=None, convert_dtype=True) -> "DBSCAN": """ Perform DBSCAN clustering from features. @@ -447,7 +460,7 @@ class DBSCAN(UniversalBase, 'description': 'Cluster labels', 'shape': '(n_samples, 1)'}) @enable_device_interop - def fit_predict(self, X, out_dtype="int32", sample_weight=None) -> CumlArray: + def fit_predict(self, X, y=None, out_dtype="int32", sample_weight=None) -> CumlArray: """ Performs clustering on X and returns cluster labels. @@ -463,7 +476,7 @@ class DBSCAN(UniversalBase, negative weight may inhibit its eps-neighbor from being core. default: None (which is equivalent to weight 1 for all samples). 
""" - self.fit(X, out_dtype, sample_weight) + self.fit(X, out_dtype=out_dtype, sample_weight=sample_weight) return self.labels_ @classmethod diff --git a/python/cuml/cuml/cluster/hdbscan/hdbscan.pyx b/python/cuml/cuml/cluster/hdbscan/hdbscan.pyx index 0078b145c7..9edd4f302f 100644 --- a/python/cuml/cuml/cluster/hdbscan/hdbscan.pyx +++ b/python/cuml/cuml/cluster/hdbscan/hdbscan.pyx @@ -485,6 +485,19 @@ class HDBSCAN(UniversalBase, ClusterMixin, CMajorInputTagMixin): mst_dst_ = CumlArrayDescriptor() mst_weights_ = CumlArrayDescriptor() + _hyperparam_interop_translator = { + "metric": { + "manhattan": "NotImplemented", + "chebyshev": "NotImplemented", + "minkowski": "NotImplemented", + }, + "algorithm": { + "auto": "brute", + "ball_tree": "NotImplemented", + "kd_tree": "NotImplemented", + }, + } + @device_interop_preparation def __init__(self, *, min_cluster_size=5, diff --git a/python/cuml/cuml/cluster/kmeans.pyx b/python/cuml/cuml/cluster/kmeans.pyx index 9ba1cb710a..c36765fe34 100644 --- a/python/cuml/cuml/cluster/kmeans.pyx +++ b/python/cuml/cuml/cluster/kmeans.pyx @@ -564,7 +564,7 @@ class KMeans(UniversalBase, 'description': 'Cluster indexes', 'shape': '(n_samples, 1)'}) @enable_device_interop - def predict(self, X, convert_dtype=True, sample_weight=None, + def predict(self, X, y=None, convert_dtype=True, sample_weight=None, normalize_weights=True) -> CumlArray: """ Predict the closest cluster each sample in X belongs to. @@ -583,7 +583,7 @@ class KMeans(UniversalBase, 'description': 'Transformed data', 'shape': '(n_samples, n_clusters)'}) @enable_device_interop - def transform(self, X, convert_dtype=True) -> CumlArray: + def transform(self, X, y=None, convert_dtype=True) -> CumlArray: """ Transform X to a cluster-distance space. 
@@ -687,7 +687,7 @@ class KMeans(UniversalBase, 'description': 'Transformed data', 'shape': '(n_samples, n_clusters)'}) @enable_device_interop - def fit_transform(self, X, convert_dtype=False, + def fit_transform(self, X, y=None, convert_dtype=False, sample_weight=None) -> CumlArray: """ Compute clustering and transform X to cluster-distance space. diff --git a/python/cuml/cuml/decomposition/pca.pyx b/python/cuml/cuml/decomposition/pca.pyx index 9433f724b9..db2f0f62c8 100644 --- a/python/cuml/cuml/decomposition/pca.pyx +++ b/python/cuml/cuml/decomposition/pca.pyx @@ -280,6 +280,16 @@ class PCA(UniversalBase, noise_variance_ = CumlArrayDescriptor(order='F') trans_input_ = CumlArrayDescriptor(order='F') + _hyperparam_interop_translator = { + "svd_solver": { + "arpack": "full", + "randomized": "full" + }, + "iterated_power": { + "auto": 15, + }, + } + @device_interop_preparation def __init__(self, *, copy=True, handle=None, iterated_power=15, n_components=None, random_state=None, svd_solver='auto', diff --git a/python/cuml/cuml/decomposition/tsvd.pyx b/python/cuml/cuml/decomposition/tsvd.pyx index 55caa84c9b..b495d3d239 100644 --- a/python/cuml/cuml/decomposition/tsvd.pyx +++ b/python/cuml/cuml/decomposition/tsvd.pyx @@ -240,6 +240,13 @@ class TruncatedSVD(UniversalBase, explained_variance_ratio_ = CumlArrayDescriptor(order='F') singular_values_ = CumlArrayDescriptor(order='F') + _hyperparam_interop_translator = { + "algorithm": { + "randomized": "full", + "arpack": "full", + }, + } + @device_interop_preparation def __init__(self, *, algorithm='full', handle=None, n_components=1, n_iter=15, random_state=None, tol=1e-7, diff --git a/python/cuml/cuml/experimental/accel/__init__.py b/python/cuml/cuml/experimental/accel/__init__.py new file mode 100644 index 0000000000..cd3c6abf51 --- /dev/null +++ b/python/cuml/cuml/experimental/accel/__init__.py @@ -0,0 +1,68 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import importlib + +from .magics import load_ipython_extension + +from cuml.internals import logger +from cuml.internals.global_settings import GlobalSettings +from cuml.internals.memory_utils import set_global_output_type + +__all__ = ["load_ipython_extension", "install"] + + +def _install_for_library(library_name): + importlib.import_module(f"._wrappers.{library_name}", __name__) + return True + + +def install(): + """Enable cuML Accelerator Mode.""" + logger.set_level(logger.level_info) + logger.set_pattern("%v") + + logger.info("cuML: Installing experimental accelerator...") + loader_sklearn = _install_for_library(library_name="sklearn") + loader_umap = _install_for_library(library_name="umap") + loader_hdbscan = _install_for_library(library_name="hdbscan") + + GlobalSettings().accelerator_loaded = all( + [loader_sklearn, loader_umap, loader_hdbscan] + ) + + GlobalSettings().accelerator_active = True + + if GlobalSettings().accelerator_loaded: + logger.info( + "cuML: experimental accelerator successfully initialized..." 
+ ) + else: + logger.info("cuML: experimental accelerator failed to initialize...") + + set_global_output_type("numpy") + + +def pytest_load_initial_conftests(early_config, parser, args): + # https://docs.pytest.org/en/7.1.x/reference/\ + # reference.html#pytest.hookspec.pytest_load_initial_conftests + try: + install() + except RuntimeError: + raise RuntimeError( + "An existing plugin has already loaded sklearn. Interposing failed." + ) diff --git a/python/cuml/cuml/experimental/accel/__main__.py b/python/cuml/cuml/experimental/accel/__main__.py new file mode 100644 index 0000000000..e4c4af576b --- /dev/null +++ b/python/cuml/cuml/experimental/accel/__main__.py @@ -0,0 +1,70 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import click +import code +import os +import runpy +import sys + +from . import install + + +@click.command() +@click.option("-m", "module", required=False, help="Module to run") +@click.option( + "--strict", + is_flag=True, + default=False, + help="Turn strict mode for hyperparameters on.", +) +@click.argument("args", nargs=-1) +def main(module, strict, args): + + if strict: + os.environ["CUML_ACCEL_STRICT_MODE"] = "ON" + + install() + + if module: + (module,) = module + # run the module passing the remaining arguments + # as if it were run with python -m + sys.argv[:] = [module] + args # not thread safe? 
+ runpy.run_module(module, run_name="__main__") + elif len(args) >= 1: + # Remove ourself from argv and continue + sys.argv[:] = args + runpy.run_path(args[0], run_name="__main__") + else: + if sys.stdin.isatty(): + banner = f"Python {sys.version} on {sys.platform}" + site_import = not sys.flags.no_site + if site_import: + cprt = 'Type "help", "copyright", "credits" or "license" for more information.' + banner += "\n" + cprt + else: + # Don't show prompts or banners if stdin is not a TTY + sys.ps1 = "" + sys.ps2 = "" + banner = "" + + # Launch an interactive interpreter + code.interact(banner=banner, exitmsg="") + + +if __name__ == "__main__": + main() diff --git a/python/cuml/cuml/experimental/accel/_wrappers/__init__.py b/python/cuml/cuml/experimental/accel/_wrappers/__init__.py new file mode 100644 index 0000000000..32ea7c7bee --- /dev/null +++ b/python/cuml/cuml/experimental/accel/_wrappers/__init__.py @@ -0,0 +1,34 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +wrapped_estimators = { + "KMeans": ("cuml.cluster", "KMeans"), + "DBSCAN": ("cuml.cluster", "DBSCAN"), + "PCA": ("cuml.decomposition", "PCA"), + "TruncatedSVD": ("cuml.decomposition", "TruncatedSVD"), + "KernelRidge": ("cuml.kernel_ridge", "KernelRidge"), + "LinearRegression": ("cuml.linear_model", "LinearRegression"), + "LogisticRegression": ("cuml.linear_model", "LogisticRegression"), + "ElasticNet": ("cuml.linear_model", "ElasticNet"), + "Ridge": ("cuml.linear_model", "Ridge"), + "Lasso": ("cuml.linear_model", "Lasso"), + "TSNE": ("cuml.manifold", "TSNE"), + "NearestNeighbors": ("cuml.neighbors", "NearestNeighbors"), + "KNeighborsClassifier": ("cuml.neighbors", "KNeighborsClassifier"), + "KNeighborsRegressor": ("cuml.neighbors", "KNeighborsRegressor"), + "UMAP": ("cuml.manifold", "UMAP"), + "HDBSCAN": ("cuml.cluster", "HDBSCAN"), +} diff --git a/python/cuml/cuml/experimental/accel/_wrappers/hdbscan.py b/python/cuml/cuml/experimental/accel/_wrappers/hdbscan.py new file mode 100644 index 0000000000..daeaa7b8c2 --- /dev/null +++ b/python/cuml/cuml/experimental/accel/_wrappers/hdbscan.py @@ -0,0 +1,24 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from ..estimator_proxy import intercept + + +HDBSCAN = intercept( + original_module="hdbscan", + accelerated_module="cuml.cluster", + original_class_name="HDBSCAN", +) diff --git a/python/cuml/cuml/experimental/accel/_wrappers/sklearn.py b/python/cuml/cuml/experimental/accel/_wrappers/sklearn.py new file mode 100644 index 0000000000..9b7a09b887 --- /dev/null +++ b/python/cuml/cuml/experimental/accel/_wrappers/sklearn.py @@ -0,0 +1,129 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from ..estimator_proxy import intercept + + +############################################################################### +# Clustering Estimators # +############################################################################### + +KMeans = intercept( + original_module="sklearn.cluster", + accelerated_module="cuml.cluster", + original_class_name="KMeans", +) + +DBSCAN = intercept( + original_module="sklearn.cluster", + accelerated_module="cuml.cluster", + original_class_name="DBSCAN", +) + + +############################################################################### +# Decomposition Estimators # +############################################################################### + + +PCA = intercept( + original_module="sklearn.decomposition", + accelerated_module="cuml.decomposition", + original_class_name="PCA", +) + + +TruncatedSVD = intercept( + original_module="sklearn.decomposition", + accelerated_module="cuml.decomposition", + original_class_name="TruncatedSVD", +) + + +############################################################################### +# Linear Estimators # +############################################################################### + +KernelRidge = intercept( + original_module="sklearn.kernel_ridge", + accelerated_module="cuml.kernel_ridge", + original_class_name="KernelRidge", +) + +LinearRegression = intercept( + original_module="sklearn.linear_model", + accelerated_module="cuml.linear_model", + original_class_name="LinearRegression", +) + +LogisticRegression = intercept( + original_module="sklearn.linear_model", + accelerated_module="cuml.linear_model", + original_class_name="LogisticRegression", +) + +ElasticNet = intercept( + original_module="sklearn.linear_model", + accelerated_module="cuml.linear_model", + original_class_name="ElasticNet", +) + +Ridge = intercept( + original_module="sklearn.linear_model", + accelerated_module="cuml.linear_model", + original_class_name="Ridge", +) + +Lasso = intercept( + 
original_module="sklearn.linear_model", + accelerated_module="cuml.linear_model", + original_class_name="Lasso", +) + + +############################################################################### +# Manifold Estimators # +############################################################################### + +TSNE = intercept( + original_module="sklearn.manifold", + accelerated_module="cuml.manifold", + original_class_name="TSNE", +) + + +############################################################################### +# Neighbors Estimators # +############################################################################### + + +NearestNeighbors = intercept( + original_module="sklearn.neighbors", + accelerated_module="cuml.neighbors", + original_class_name="NearestNeighbors", +) + +KNeighborsClassifier = intercept( + original_module="sklearn.neighbors", + accelerated_module="cuml.neighbors", + original_class_name="KNeighborsClassifier", +) + +KNeighborsRegressor = intercept( + original_module="sklearn.neighbors", + accelerated_module="cuml.neighbors", + original_class_name="KNeighborsRegressor", +) diff --git a/python/cuml/cuml/experimental/accel/_wrappers/umap.py b/python/cuml/cuml/experimental/accel/_wrappers/umap.py new file mode 100644 index 0000000000..dd8b6864b0 --- /dev/null +++ b/python/cuml/cuml/experimental/accel/_wrappers/umap.py @@ -0,0 +1,24 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from ..estimator_proxy import intercept + + +UMAP = intercept( + original_module="umap", + accelerated_module="cuml.manifold", + original_class_name="UMAP", +) diff --git a/python/cuml/cuml/experimental/accel/estimator_proxy.py b/python/cuml/cuml/experimental/accel/estimator_proxy.py new file mode 100644 index 0000000000..fcee2f5b37 --- /dev/null +++ b/python/cuml/cuml/experimental/accel/estimator_proxy.py @@ -0,0 +1,353 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import inspect +import sys +import types + +from cuml.internals.mem_type import MemoryType +from cuml.internals import logger +from cuml.internals.global_settings import GlobalSettings +from cuml.internals.safe_imports import gpu_only_import, cpu_only_import +from typing import Optional, Tuple, Dict, Any, Type, List + + +class ProxyModule: + """ + A proxy module that dynamically replaces specified classes with proxy estimators + based on GlobalSettings. + + Parameters + ---------- + original_module : module + The module to be proxied. + Attributes + ---------- + _original_module : module + The original module being proxied. + _proxy_estimators : dict of str to type + A dictionary mapping accelerated class names to their proxy estimators. 
+ """ + + def __init__(self, original_module: types.ModuleType) -> None: + """Initialize the ProxyModule with the original module.""" + self._original_module = original_module + self._proxy_estimators: Dict[str, Type[Any]] = {} + + def add_estimator( + self, class_name: str, proxy_estimator: Type[Any] + ) -> None: + """ + Add a proxy estimator for a specified class name. + Parameters + ---------- + class_name : str + The name of the class in the original module to be replaced. + proxy_estimator : type + The proxy estimator class to use as a replacement. + """ + self._proxy_estimators[class_name] = proxy_estimator + + def __getattr__(self, name: str) -> Any: + """ + Intercept attribute access on the proxy module. + If the attribute name is in the proxy estimators and the accelerator is active, + return the proxy estimator; otherwise, return the attribute from the original module. + Parameters + ---------- + name : str + The name of the attribute being accessed. + Returns + ------- + Any + The attribute from the proxy estimator or the original module. + """ + if name in self._proxy_estimators: + use_proxy = getattr(GlobalSettings(), "accelerator_active", False) + if use_proxy: + return self._proxy_estimators[name] + else: + return getattr(self._original_module, name) + else: + return getattr(self._original_module, name) + + def __dir__(self) -> List[str]: + """ + Provide a list of attributes available in the proxy module. + Returns + ------- + list of str + A list of attribute names from the original module. + """ + return dir(self._original_module) + + +def intercept( + original_module: str, + accelerated_module: str, + original_class_name: str, + accelerated_class_name: Optional[str] = None, +): + """ + Factory function that creates class definitions of ProxyEstimators that + accelerate estimators of the original class. 
+ + This function dynamically creates a new class called `ProxyEstimator` that + inherits from the GPU-accelerated class in the `accelerated_module` + (e.g., cuML) and acts as a drop-in replacement for the original class in + `original_module` (e.g., scikit-learn). Then, this class can be used to + create instances of ProxyEstimators that dispatch to either library. + + **Design of the ProxyEstimator Class Inside** + + **`ProxyEstimator` Class:** + - The `ProxyEstimator` class inherits from the GPU-accelerated + class (`class_b`) obtained from the `accelerated_module`. + - It serves as a wrapper that adds additional functionality + to maintain compatibility with the original CPU-based estimator. + Key methods and attributes: + - `__init__`: Initializes the proxy estimator, stores a + reference to the original class before ModuleAccelerator + replaces the original module, translates hyperparameters, + and initializes the parent (cuML) class. + - `__repr__` and `__str__`: Provide string representations + that reference the original CPU-based class. + - Attribute `_cpu_model_class`: Stores a reference to the + original CPU-based estimator class. + - Attribute `_gpuaccel`: Indicates whether GPU acceleration + is enabled. + - By designing the `ProxyEstimator` in this way, we can + seamlessly replace the original CPU-based estimator with a + GPU-accelerated version without altering the existing codebase. + The metaclass ensures that the class behaves and appears + like the original estimator, while the proxy class manages + the underlying acceleration and compatibility. + + **Serialization/Pickling of ProxyEstimators** + + Since pickle has strict rules about serializing classes, we cannot + (reasonably) create a method that just pickles and unpickles a + ProxyEstimator as if it was just an instance of the original module. + + Therefore, doing a pickling of ProxyEstimator will make it serialize to + a file that can be opened in systems with cuML installed (CPU or GPU). 
+ To serialize for non cuML systems, the to_sklearn and from_sklearn APIs + are being introduced in + + https://github.com/rapidsai/cuml/pull/6102 + + Parameters + ---------- + original_module : str + Original module that is being accelerated + accelerated_module : str + Acceleration module + original_class_name : str + Name of the class being accelerated + accelerated_class_name : str, optional + Name of accelerator class. If None, then it is assumed it is the same + name as original_class_name (i.e. the class in the original module). + + Returns + ------- + A class definition of ProxyEstimator that inherits from + the accelerated library class (cuML). + + Examples + -------- + >>> from module_accelerator import intercept + >>> ProxyEstimator = intercept('sklearn.linear_model', + ... 'cuml.linear_model', 'LinearRegression') + >>> model = ProxyEstimator() + + """ + + if accelerated_class_name is None: + accelerated_class_name = original_class_name + + # Import the original host module and cuML + module_a = cpu_only_import(original_module) + module_b = gpu_only_import(accelerated_module) + + # Store a reference to the original (CPU) class + original_class_a = getattr(module_a, original_class_name) + + # Get the class from cuML so ProxyEstimator inherits from it + class_b = getattr(module_b, accelerated_class_name) + + class ProxyEstimator(class_b): + """ + A proxy estimator class that wraps the accelerated estimator and provides + compatibility with the original estimator interface. + + The ProxyEstimator inherits from the accelerated estimator class and + wraps additional functionality to maintain compatibility with the original + CPU-based estimator. + + It handles the translation of hyperparameters and the transfer of models + between CPU and GPU.
+ + """ + + def __init__(self, *args, **kwargs): + self._cpu_model_class = ( + original_class_a # Store a reference to the original class + ) + kwargs, self._gpuaccel = self._hyperparam_translator(**kwargs) + super().__init__(*args, **kwargs) + + self._cpu_hyperparams = list( + inspect.signature( + self._cpu_model_class.__init__ + ).parameters.keys() + ) + + def __repr__(self): + """ + Return a formal string representation of the object. + + Returns + ------- + str + A string representation indicating that this is a wrapped + version of the original CPU-based estimator. + """ + return f"wrapped {self._cpu_model_class}" + + def __str__(self): + """ + Return an informal string representation of the object. + + Returns + ------- + str + A string representation indicating that this is a wrapped + version of the original CPU-based estimator. + """ + return f"ProxyEstimator of {self._cpu_model_class}" + + def _check_cpu_model(self): + """ + Checks if an estimator already has created a _cpu_model, + and creates one if necessary. + """ + if not hasattr(self, "_cpu_model"): + self.import_cpu_model() + self.build_cpu_model() + + self.gpu_to_cpu() + + def __getstate__(self): + """ + Prepare the object state for pickling. We need it since + we have a custom function in __reduce__. + + Returns + ------- + dict + The state of the Estimator. + """ + return self.__dict__.copy() + + def __reduce__(self): + """ + Helper for pickle. + + Returns + ------- + tuple + A tuple containing the callable to reconstruct the object + and the arguments for reconstruction. + + Notes + ----- + Disables the module accelerator during pickling to ensure correct serialization. 
+ """ + return ( + reconstruct_proxy, + ( + original_module, + accelerated_module, + original_class_name, + (), + self.__getstate__(), + ), + ) + + logger.debug( + f"Created proxy estimator: ({module_b}, {original_class_name}, {ProxyEstimator})" + ) + setattr(module_b, original_class_name, ProxyEstimator) + accelerated_modules = GlobalSettings().accelerated_modules + + if original_module in accelerated_modules: + proxy_module = accelerated_modules[original_module] + else: + proxy_module = ProxyModule(original_module=module_a) + GlobalSettings().accelerated_modules[original_module] = proxy_module + + proxy_module.add_estimator( + class_name=original_class_name, proxy_estimator=ProxyEstimator + ) + + sys.modules[original_module] = proxy_module + + return ProxyEstimator + + +def reconstruct_proxy( + original_module: str, + accelerated_module: str, + class_name: str, + args: Tuple, + kwargs: Dict, +): + """ + Function to enable pickling of ProxyEstimators since they are defined inside + a function, which Pickle doesn't like without a function or something + that has an absolute import path like this function. + + Parameters + ---------- + original_module : str + Original module that is being accelerated + accelerated_module : str + Acceleration module + class_name : str + Name of the class being accelerated + args : Tuple + Args of class to be deserialized (typically empty for ProxyEstimators) + kwargs : Dict + Keyword arguments to reconstruct the ProxyEstimator instance, typically + the state dict produced by the __getstate__ method. + + Returns + ------- + Instance of ProxyEstimator constructed with the kwargs passed to the function.
+ + """ + # We probably don't need to intercept again here, since we already stored + # the variables in _wrappers + cls = intercept( + original_module=original_module, + accelerated_module=accelerated_module, + original_class_name=class_name, + ) + + estimator = cls() + estimator.__dict__.update(kwargs) + return estimator diff --git a/python/cuml/cuml/experimental/accel/magics.py b/python/cuml/cuml/experimental/accel/magics.py new file mode 100644 index 0000000000..da9c4494ae --- /dev/null +++ b/python/cuml/cuml/experimental/accel/magics.py @@ -0,0 +1,29 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +try: + from IPython.core.magic import Magics, cell_magic, magics_class + + def load_ipython_extension(ip): + from . 
import install + + install() + +except ImportError: + + def load_ipython_extension(ip): + pass diff --git a/python/cuml/cuml/internals/base.pyx b/python/cuml/cuml/internals/base.pyx index 9813acbba4..dedbeaf644 100644 --- a/python/cuml/cuml/internals/base.pyx +++ b/python/cuml/cuml/internals/base.pyx @@ -20,6 +20,7 @@ import os import inspect import numbers from importlib import import_module +from cuml.internals.device_support import GPU_ENABLED from cuml.internals.safe_imports import ( cpu_only_import, gpu_only_import_from, @@ -41,6 +42,7 @@ import cuml.internals import cuml.internals.input_utils from cuml.internals.available_devices import is_cuda_available from cuml.internals.device_type import DeviceType +from cuml.internals.global_settings import GlobalSettings from cuml.internals.input_utils import ( determine_array_type, input_to_cuml_array, @@ -202,6 +204,8 @@ class Base(TagsMixin, del base # optional! """ + _hyperparam_interop_translator = {} + def __init__(self, *, handle=None, verbose=False, @@ -471,6 +475,31 @@ class Base(TagsMixin, func = nvtx_annotate(message=msg, domain="cuml_python")(func) setattr(self, func_name, func) + @classmethod + def _hyperparam_translator(cls, **kwargs): + """ + This method is meant to do checks and translations of hyperparameters + at estimator creation time. + Each child estimator can override the method, returning either + modified **kwargs with equivalent options, or setting gpuaccel to False + for hyperparameters not supported by cuML yet.
+ """ + gpuaccel = True + # Copy it so we can modify it + translations = dict(cls.__bases__[0]._hyperparam_interop_translator) + # Allow the derived class to overwrite the base class + translations.update(cls._hyperparam_interop_translator) + for parameter_name, value in kwargs.items(): + + if parameter_name in translations: + if value in translations[parameter_name]: + if translations[parameter_name][value] == "NotImplemented": + gpuaccel = False + else: + kwargs[parameter_name] = translations[parameter_name][value] + + return kwargs, gpuaccel + # Internal, non class owned helper functions def _check_output_type_str(output_str): @@ -681,11 +710,13 @@ class UniversalBase(Base): keyword arguments to be passed to the function for the call """ # look for current device_type - device_type = cuml.global_settings.device_type + # device_type = cuml.global_settings.device_type + device_type = self._dispatch_selector(func_name, *args, **kwargs) - # GPU case if device_type == DeviceType.device: # call the function from the GPU estimator + if GlobalSettings().accelerator_active: + logger.info(f"cuML: Performing {func_name} in GPU") return gpu_func(self, *args, **kwargs) # CPU case @@ -708,6 +739,7 @@ class UniversalBase(Base): # get the function from the GPU estimator cpu_func = getattr(self._cpu_model, func_name) # call the function from the GPU estimator + logger.info(f"cuML: Performing {func_name} in CPU") res = cpu_func(*args, **kwargs) # CPU training @@ -725,3 +757,38 @@ class UniversalBase(Base): # return function result return res + + def _dispatch_selector(self, func_name, *args, **kwargs): + """ + """ + # if not using accelerator, then return global device + if not hasattr(self, "_gpuaccel"): + return cuml.global_settings.device_type + + # if using accelerator and doing inference, always use GPU + elif func_name not in ['fit', 'fit_transform', 'fit_predict']: + device_type = DeviceType.device + + # otherwise we select CPU when _gpuaccel is off + elif not 
self._gpuaccel: + device_type = DeviceType.host + else: + if not self._should_dispatch_cpu(func_name, *args, **kwargs): + device_type = DeviceType.device + else: + device_type = DeviceType.host + + return device_type + + def _should_dispatch_cpu(self, func_name, *args, **kwargs): + """ + This method is meant to do checks of data sizes and other things + at fit and other method call time, to decide where to dispatch + a function. For hyperparameters of the estimator, + see the method _hyperparam_translator. + Each estimator inheriting from UniversalBase can override this + method to have custom rules of when to dispatch to CPU depending + on the data passed to fit/predict... + """ + + return False diff --git a/python/cuml/cuml/internals/global_settings.py b/python/cuml/cuml/internals/global_settings.py index ea899d91b1..9dae3ceac1 100644 --- a/python/cuml/cuml/internals/global_settings.py +++ b/python/cuml/cuml/internals/global_settings.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# Copyright (c) 2021-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
@@ -40,13 +40,21 @@ def __init__(self): default_device_type = DeviceType.host default_memory_type = MemoryType.host self.shared_state = { - "_output_type": None, "_device_type": default_device_type, "_memory_type": default_memory_type, - "root_cm": None, } else: - self.shared_state = {"_output_type": None, "root_cm": None} + self.shared_state = {} + + self.shared_state.update( + { + "_output_type": None, + "root_cm": None, + "accelerator_active": False, + "accelerator_loaded": False, + "accelerated_modules": {}, + } + ) _global_settings_data = _GlobalSettingsData() diff --git a/python/cuml/cuml/linear_model/elastic_net.pyx b/python/cuml/cuml/linear_model/elastic_net.pyx index a8e6b75a3d..7b212b21c9 100644 --- a/python/cuml/cuml/linear_model/elastic_net.pyx +++ b/python/cuml/cuml/linear_model/elastic_net.pyx @@ -150,6 +150,15 @@ class ElasticNet(UniversalBase, _cpu_estimator_import_path = 'sklearn.linear_model.ElasticNet' coef_ = CumlArrayDescriptor(order='F') + _hyperparam_interop_translator = { + "positive": { + True: "NotImplemented", + }, + "warm_start": { + True: "NotImplemented", + }, + } + @device_interop_preparation def __init__(self, *, alpha=1.0, l1_ratio=0.5, fit_intercept=True, normalize=False, max_iter=1000, tol=1e-3, diff --git a/python/cuml/cuml/linear_model/linear_regression.pyx b/python/cuml/cuml/linear_model/linear_regression.pyx index 35a73c111f..f1b64602b3 100644 --- a/python/cuml/cuml/linear_model/linear_regression.pyx +++ b/python/cuml/cuml/linear_model/linear_regression.pyx @@ -268,6 +268,12 @@ class LinearRegression(LinearPredictMixin, coef_ = CumlArrayDescriptor(order='F') intercept_ = CumlArrayDescriptor(order='F') + _hyperparam_interop_translator = { + "positive": { + True: "NotImplemented", + }, + } + @device_interop_preparation def __init__(self, *, algorithm='eig', fit_intercept=True, copy_X=None, normalize=False, diff --git a/python/cuml/cuml/linear_model/logistic_regression.pyx b/python/cuml/cuml/linear_model/logistic_regression.pyx 
index aa5283fef7..c9ad443750 100644 --- a/python/cuml/cuml/linear_model/logistic_regression.pyx +++ b/python/cuml/cuml/linear_model/logistic_regression.pyx @@ -189,6 +189,17 @@ class LogisticRegression(UniversalBase, class_weight = CumlArrayDescriptor(order='F') expl_spec_weights_ = CumlArrayDescriptor(order='F') + _hyperparam_interop_translator = { + "solver": { + "lbfgs": "qn", + "liblinear": "qn", + "newton-cg": "qn", + "newton-cholesky": "qn", + "sag": "qn", + "saga": "qn" + }, + } + @device_interop_preparation def __init__( self, diff --git a/python/cuml/cuml/linear_model/ridge.pyx b/python/cuml/cuml/linear_model/ridge.pyx index ae84f1002a..bd039867f3 100644 --- a/python/cuml/cuml/linear_model/ridge.pyx +++ b/python/cuml/cuml/linear_model/ridge.pyx @@ -192,6 +192,21 @@ class Ridge(UniversalBase, coef_ = CumlArrayDescriptor(order='F') intercept_ = CumlArrayDescriptor(order='F') + _hyperparam_interop_translator = { + "positive": { + True: "NotImplemented" + }, + "solver": { + "auto": "eig", + "cholesky": "eig", + "lsqr": "eig", + "sag": "eig", + "saga": "eig", + "lbfgs": "NotImplemented", + "sparse_cg": "eig" + }, + } + @device_interop_preparation def __init__(self, *, alpha=1.0, solver='eig', fit_intercept=True, normalize=False, handle=None, output_type=None, @@ -221,13 +236,8 @@ class Ridge(UniversalBase, self.alpha = alpha self.fit_intercept = fit_intercept self.normalize = normalize + self.solver = solver - if solver in ['svd', 'eig', 'cd']: - self.solver = solver - self.algo = self._get_algorithm_int(solver) - else: - msg = "solver {!r} is not supported" - raise TypeError(msg.format(solver)) self.intercept_value = 0.0 def _check_alpha(self, alpha): @@ -236,6 +246,9 @@ class Ridge(UniversalBase, raise TypeError(msg.format(alpha)) def _get_algorithm_int(self, algorithm): + if self.solver not in ['svd', 'eig', 'cd']: + msg = "solver {!r} is not supported" + raise TypeError(msg.format(self.solver)) return { 'svd': 0, 'eig': 1, @@ -249,6 +262,8 @@ class 
Ridge(UniversalBase, Fit the model with X and y. """ + self.algo = self._get_algorithm_int(self.solver) + cdef uintptr_t _X_ptr, _y_ptr, _sample_weight_ptr X_m, n_rows, self.n_features_in_, self.dtype = \ input_to_cuml_array(X, deepcopy=True, diff --git a/python/cuml/cuml/linear_model/ridge_mg.pyx b/python/cuml/cuml/linear_model/ridge_mg.pyx index e5c3273088..7f10a79157 100644 --- a/python/cuml/cuml/linear_model/ridge_mg.pyx +++ b/python/cuml/cuml/linear_model/ridge_mg.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -69,6 +69,7 @@ class RidgeMG(MGFitMixin, Ridge): @cuml.internals.api_base_return_any_skipall def _fit(self, X, y, coef_ptr, input_desc): + self.algo = self._get_algorithm_int(self.solver) cdef float float_intercept cdef double double_intercept diff --git a/python/cuml/cuml/manifold/t_sne.pyx b/python/cuml/cuml/manifold/t_sne.pyx index 11ade2ffe2..248a25b933 100644 --- a/python/cuml/cuml/manifold/t_sne.pyx +++ b/python/cuml/cuml/manifold/t_sne.pyx @@ -270,6 +270,12 @@ class TSNE(UniversalBase, X_m = CumlArrayDescriptor() embedding_ = CumlArrayDescriptor() + _hyperparam_interop_translator = { + "n_components": { + 3 : "NotImplemented", + } + } + @device_interop_preparation def __init__(self, *, n_components=2, @@ -302,12 +308,6 @@ class TSNE(UniversalBase, verbose=verbose, output_type=output_type) - if n_components < 0: - raise ValueError("n_components = {} should be more " - "than 0.".format(n_components)) - if n_components != 2: - raise ValueError("Currently TSNE supports n_components = 2; " - "but got n_components = {}".format(n_components)) if perplexity < 0: raise ValueError("perplexity = {} should be more than 0.".format( perplexity)) @@ -427,6 +427,12 @@ class TSNE(UniversalBase, should match the metric used to train the TSNE embeedings. 
Takes precedence over the precomputed_knn parameter. """ + if self.n_components < 0: + raise ValueError("n_components = {} should be more " + "than 0.".format(self.n_components)) + if self.n_components != 2: + raise ValueError("Currently TSNE supports n_components = 2; " + "but got n_components = {}".format(self.n_components)) cdef int n, p cdef handle_t* handle_ = self.handle.getHandle() if handle_ == NULL: diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index c873461a95..7000850872 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -234,7 +234,7 @@ class UMAP(UniversalBase, are returned when transform is called on the same data upon which the model was trained. This enables consistent behavior between calling ``model.fit_transform(X)`` and - calling ``model.fit(X).transform(X)``. Not that the CPU-based + calling ``model.fit(X).transform(X)``. Note that the CPU-based UMAP reference implementation does this by default. 
This feature is made optional in the GPU version due to the significant overhead in copying memory to the host for diff --git a/python/cuml/cuml/neighbors/kneighbors_classifier.pyx b/python/cuml/cuml/neighbors/kneighbors_classifier.pyx index ed72909203..17f8628c95 100644 --- a/python/cuml/cuml/neighbors/kneighbors_classifier.pyx +++ b/python/cuml/cuml/neighbors/kneighbors_classifier.pyx @@ -140,6 +140,17 @@ class KNeighborsClassifier(ClassifierMixin, y = CumlArrayDescriptor() classes_ = CumlArrayDescriptor() + _hyperparam_interop_translator = { + "weights": { + "distance": "NotImplemented", + }, + "algorithm": { + "auto": "brute", + "ball_tree": "brute", + "kd_tree": "brute", + }, + } + def __init__(self, *, weights="uniform", handle=None, verbose=False, output_type=None, **kwargs): super().__init__( diff --git a/python/cuml/cuml/neighbors/kneighbors_regressor.pyx b/python/cuml/cuml/neighbors/kneighbors_regressor.pyx index 75fcfce5c6..78c1525f27 100644 --- a/python/cuml/cuml/neighbors/kneighbors_regressor.pyx +++ b/python/cuml/cuml/neighbors/kneighbors_regressor.pyx @@ -150,6 +150,17 @@ class KNeighborsRegressor(RegressorMixin, y = CumlArrayDescriptor() + _hyperparam_interop_translator = { + "weights": { + "distance": "NotImplemented", + }, + "algorithm": { + "auto": "brute", + "ball_tree": "brute", + "kd_tree": "brute", + }, + } + def __init__(self, *, weights="uniform", handle=None, verbose=False, output_type=None, **kwargs): super().__init__( @@ -159,9 +170,6 @@ class KNeighborsRegressor(RegressorMixin, **kwargs) self.y = None self.weights = weights - if weights != "uniform": - raise ValueError("Only uniform weighting strategy " - "is supported currently.") @generate_docstring(convert_dtype_cast='np.float32') def fit(self, X, y, convert_dtype=True) -> "KNeighborsRegressor": @@ -169,6 +177,9 @@ class KNeighborsRegressor(RegressorMixin, Fit a GPU index for k-nearest neighbors regression model. 
""" + if self.weights != "uniform": + raise ValueError("Only uniform weighting strategy " + "is supported currently.") self._set_target_dtype(y) super(KNeighborsRegressor, self).fit(X, convert_dtype=convert_dtype) diff --git a/python/cuml/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/cuml/neighbors/nearest_neighbors.pyx index 202d65ca8b..186142ab63 100644 --- a/python/cuml/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/cuml/neighbors/nearest_neighbors.pyx @@ -289,6 +289,20 @@ class NearestNeighbors(UniversalBase, _cpu_estimator_import_path = 'sklearn.neighbors.NearestNeighbors' _fit_X = CumlArrayDescriptor(order='C') + _hyperparam_interop_translator = { + "weights": { + "distance": "NotImplemented", + }, + "algorithm": { + "auto": "brute", + "ball_tree": "brute", + "kd_tree": "brute", + }, + "metric": { + "mahalanobis": "NotImplemented" + } + } + @device_interop_preparation def __init__(self, *, n_neighbors=5, @@ -320,7 +334,7 @@ class NearestNeighbors(UniversalBase, @generate_docstring(X='dense_sparse') @enable_device_interop - def fit(self, X, convert_dtype=True) -> "NearestNeighbors": + def fit(self, X, y=None, convert_dtype=True) -> "NearestNeighbors": """ Fit GPU index for performing nearest neighbor queries. @@ -681,7 +695,7 @@ class NearestNeighbors(UniversalBase, return (D_ndarr, I_ndarr) if return_distance else I_ndarr - def _kneighbors_dense(self, X, n_neighbors, convert_dtype=None): + def _kneighbors_dense(self, X, n_neighbors, convert_dtype=True): if not is_dense(X): raise ValueError("A NearestNeighbors model trained on dense " diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_dbscan.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_dbscan.py new file mode 100644 index 0000000000..bf571fe6ec --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_dbscan.py @@ -0,0 +1,91 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +import numpy as np +from sklearn.datasets import make_blobs +from sklearn.cluster import DBSCAN +from sklearn.metrics import adjusted_rand_score + + +@pytest.fixture(scope="module") +def clustering_data(): + X, y = make_blobs( + n_samples=300, + centers=3, + cluster_std=[1.0, 2.5, 0.5], + random_state=42, + ) + return X, y + + +@pytest.mark.parametrize("eps", [0.1, 0.5, 1.0, 2.0]) +def test_dbscan_eps(clustering_data, eps): + X, y_true = clustering_data + dbscan = DBSCAN(eps=eps).fit(X) + y_pred = dbscan.labels_ + adjusted_rand_score(y_true, y_pred) + + +@pytest.mark.parametrize("min_samples", [1, 5, 10, 20]) +def test_dbscan_min_samples(clustering_data, min_samples): + X, y_true = clustering_data + dbscan = DBSCAN(eps=0.5, min_samples=min_samples).fit(X) + y_pred = dbscan.labels_ + adjusted_rand_score(y_true, y_pred) + + +@pytest.mark.parametrize("metric", ["euclidean", "manhattan", "chebyshev"]) +def test_dbscan_metric(clustering_data, metric): + X, y_true = clustering_data + dbscan = DBSCAN(eps=0.5, metric=metric).fit(X) + y_pred = dbscan.labels_ + adjusted_rand_score(y_true, y_pred) + + +@pytest.mark.parametrize( + "algorithm", ["auto", "ball_tree", "kd_tree", "brute"] +) +def test_dbscan_algorithm(clustering_data, algorithm): + X, y_true = clustering_data + dbscan = DBSCAN(eps=0.5, algorithm=algorithm).fit(X) + y_pred = dbscan.labels_ + adjusted_rand_score(y_true, y_pred) + + 
+@pytest.mark.parametrize("leaf_size", [10, 30, 50]) +def test_dbscan_leaf_size(clustering_data, leaf_size): + X, y_true = clustering_data + dbscan = DBSCAN(eps=0.5, leaf_size=leaf_size).fit(X) + y_pred = dbscan.labels_ + adjusted_rand_score(y_true, y_pred) + + +@pytest.mark.parametrize("p", [1, 2, 3]) +def test_dbscan_p(clustering_data, p): + X, y_true = clustering_data + dbscan = DBSCAN(eps=0.5, metric="minkowski", p=p).fit(X) + y_pred = dbscan.labels_ + adjusted_rand_score(y_true, y_pred) + + +def test_dbscan_consistency(clustering_data): + X, y_true = clustering_data + dbscan1 = DBSCAN(eps=0.5).fit(X) + dbscan2 = DBSCAN(eps=0.5).fit(X) + assert np.array_equal( + dbscan1.labels_, dbscan2.labels_ + ), "Results should be consistent across runs" diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_elastic_net.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_elastic_net.py new file mode 100644 index 0000000000..eaedc6446f --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_elastic_net.py @@ -0,0 +1,209 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest +import numpy as np +from sklearn.datasets import make_regression +from sklearn.linear_model import ElasticNet +from sklearn.metrics import mean_squared_error, r2_score +from sklearn.preprocessing import StandardScaler + + +@pytest.fixture(scope="module") +def regression_data(): + X, y = make_regression( + n_samples=500, + n_features=20, + n_informative=10, + noise=0.1, + random_state=42, + ) + # Standardize features + X = StandardScaler().fit_transform(X) + return X, y + + +@pytest.mark.parametrize("alpha", [0.1, 0.5, 1.0, 2.0]) +def test_elasticnet_alpha(regression_data, alpha): + X, y = regression_data + model = ElasticNet(alpha=alpha, random_state=42) + model.fit(X, y) + y_pred = model.predict(X) + # Compute R^2 score + r2 = r2_score(y, y_pred) + assert r2 > 0.5, f"R^2 score should be reasonable for alpha={alpha}" + + +@pytest.mark.parametrize("l1_ratio", [0.0, 0.5, 0.7, 1.0]) +def test_elasticnet_l1_ratio(regression_data, l1_ratio): + X, y = regression_data + model = ElasticNet(alpha=1.0, l1_ratio=l1_ratio, random_state=42) + model.fit(X, y) + y_pred = model.predict(X) + # Compute R^2 score + r2 = r2_score(y, y_pred) + assert r2 > 0.5, f"R^2 score should be reasonable for l1_ratio={l1_ratio}" + # Check sparsity of coefficients when l1_ratio=1 (equivalent to Lasso) + if l1_ratio == 1.0: + num_nonzero = np.sum(model.coef_ != 0) + assert ( + num_nonzero < X.shape[1] + ), "Some coefficients should be zero when l1_ratio=1.0" + + +@pytest.mark.parametrize("max_iter", [100]) +def test_elasticnet_max_iter(regression_data, max_iter): + X, y = regression_data + model = ElasticNet(max_iter=max_iter, random_state=42) + model.fit(X, y) + + +@pytest.mark.parametrize("tol", [1e-3]) +def test_elasticnet_tol(regression_data, tol): + X, y = regression_data + model = ElasticNet(tol=tol, random_state=42) + model.fit(X, y) + y_pred = model.predict(X) + # Compute R^2 score + r2 = r2_score(y, y_pred) + assert r2 > 0.5, f"R^2 score should be reasonable for 
tol={tol}" + + +@pytest.mark.parametrize("fit_intercept", [True, False]) +def test_elasticnet_fit_intercept(regression_data, fit_intercept): + X, y = regression_data + model = ElasticNet(fit_intercept=fit_intercept, random_state=42) + model.fit(X, y) + y_pred = model.predict(X) + # Compute R^2 score + r2 = r2_score(y, y_pred) + assert ( + r2 > 0.5 + ), f"R^2 score should be reasonable with fit_intercept={fit_intercept}" + + +@pytest.mark.parametrize("precompute", [True, False]) +def test_elasticnet_precompute(regression_data, precompute): + X, y = regression_data + model = ElasticNet(precompute=precompute, random_state=42) + model.fit(X, y) + y_pred = model.predict(X) + # Compute R^2 score + r2 = r2_score(y, y_pred) + assert ( + r2 > 0.5 + ), f"R^2 score should be reasonable with precompute={precompute}" + + +@pytest.mark.parametrize("selection", ["cyclic", "random"]) +def test_elasticnet_selection(regression_data, selection): + X, y = regression_data + model = ElasticNet(selection=selection, random_state=42) + model.fit(X, y) + y_pred = model.predict(X) + # Compute R^2 score + r2 = r2_score(y, y_pred) + assert ( + r2 > 0.5 + ), f"R^2 score should be reasonable with selection={selection}" + + +def test_elasticnet_random_state(regression_data): + X, y = regression_data + model1 = ElasticNet(selection="random", random_state=42) + model1.fit(X, y) + model2 = ElasticNet(selection="random", random_state=42) + model2.fit(X, y) + # Coefficients should be the same when random_state is fixed + np.testing.assert_allclose( + model1.coef_, + model2.coef_, + err_msg="Coefficients should be the same with the same random_state", + ) + model3 = ElasticNet(selection="random", random_state=24) + model3.fit(X, y) + + +@pytest.mark.xfail(reason="cuML does not emit ConvergenceWarning yet.") +def test_elasticnet_convergence_warning(regression_data): + X, y = regression_data + from sklearn.exceptions import ConvergenceWarning + + with pytest.warns(ConvergenceWarning): + model = 
ElasticNet(max_iter=1, random_state=42) + model.fit(X, y) + + +def test_elasticnet_coefficients(regression_data): + X, y = regression_data + model = ElasticNet(alpha=0.1, l1_ratio=0.5, random_state=42) + model.fit(X, y) + coef_nonzero = np.sum(model.coef_ != 0) + assert coef_nonzero > 0, "There should be non-zero coefficients" + + +def test_elasticnet_l1_ratio_effect(regression_data): + X, y = regression_data + model_l1 = ElasticNet(alpha=0.1, l1_ratio=1.0, random_state=42) + model_l1.fit(X, y) + model_l2 = ElasticNet(alpha=0.1, l1_ratio=0.0, random_state=42) + model_l2.fit(X, y) + num_nonzero_l1 = np.sum(model_l1.coef_ != 0) + num_nonzero_l2 = np.sum(model_l2.coef_ != 0) + assert ( + num_nonzero_l1 <= num_nonzero_l2 + ), "L1 regularization should produce sparser coefficients than L2" + + +@pytest.mark.parametrize("copy_X", [True, False]) +def test_elasticnet_copy_X(regression_data, copy_X): + X, y = regression_data + X_original = X.copy() + model = ElasticNet(copy_X=copy_X, random_state=42) + model.fit(X, y) + if copy_X: + # X should remain unchanged + assert np.allclose( + X, X_original + ), "X has been modified when copy_X=True" + else: + # X might be modified when copy_X=False + pass # We cannot guarantee X remains unchanged + + +def test_elasticnet_positive(regression_data): + X, y = regression_data + model = ElasticNet(positive=True, random_state=42) + model.fit(X, y) + # All coefficients should be non-negative + assert np.all( + model.coef_ >= 0 + ), "All coefficients should be non-negative when positive=True" + + +def test_elasticnet_warm_start(regression_data): + X, y = regression_data + model = ElasticNet(warm_start=True, random_state=42) + model.fit(X, y) + coef_old = model.coef_.copy() + # Fit again with more iterations + model.set_params(max_iter=2000) + model.fit(X, y) + coef_new = model.coef_ + # Coefficients should change after more iterations + assert not np.allclose( + coef_old, coef_new + ), "Coefficients should update when warm_start=True" diff 
--git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_hdbscan_core.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_hdbscan_core.py new file mode 100644 index 0000000000..f233ebd21e --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_hdbscan_core.py @@ -0,0 +1,317 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +import numpy as np +from sklearn.datasets import make_blobs, make_moons +from sklearn.preprocessing import StandardScaler +import hdbscan + + +@pytest.fixture(scope="module") +def synthetic_data(): + X, y = make_blobs( + n_samples=500, + n_features=2, + centers=5, + cluster_std=0.5, + random_state=42, + ) + # Standardize features + X = StandardScaler().fit_transform(X) + return X, y + + +@pytest.mark.parametrize("min_cluster_size", [5, 15, 30]) +def test_hdbscan_min_cluster_size(synthetic_data, min_cluster_size): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size) + cluster_labels = clusterer.fit_predict(X) + # Check that clusters are formed + n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0) + assert ( + n_clusters > 0 + ), f"Should find clusters with min_cluster_size={min_cluster_size}" + + +@pytest.mark.parametrize("min_samples", [1, 5, 15]) +def test_hdbscan_min_samples(synthetic_data, min_samples): + X, _ = synthetic_data + 
clusterer = hdbscan.HDBSCAN(min_samples=min_samples) + cluster_labels = clusterer.fit_predict(X) + # Check that clusters are formed + n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0) + assert ( + n_clusters > 0 + ), f"Should find clusters with min_samples={min_samples}" + + +@pytest.mark.parametrize( + "metric", ["euclidean", "manhattan", "chebyshev", "minkowski"] +) +def test_hdbscan_metric(synthetic_data, metric): + X, _ = synthetic_data + p = 0.5 if metric == "minkowski" else None + clusterer = hdbscan.HDBSCAN(metric=metric, p=p) + cluster_labels = clusterer.fit_predict(X) + # Check that clusters are formed + n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0) + assert n_clusters > 0, f"Should find clusters with metric={metric}" + + +@pytest.mark.parametrize("method", ["eom", "leaf"]) +def test_hdbscan_cluster_selection_method(synthetic_data, method): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(cluster_selection_method=method) + cluster_labels = clusterer.fit_predict(X) + # Check that clusters are formed + n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0) + assert ( + n_clusters > 0 + ), f"Should find clusters with cluster_selection_method={method}" + + +def test_hdbscan_prediction_data(synthetic_data): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(prediction_data=True) + clusterer.fit(X) + # Check that prediction data is available + assert hasattr( + clusterer, "prediction_data_" + ), "Prediction data should be available when prediction_data=True" + + +@pytest.mark.parametrize("algorithm", ["best", "generic"]) +def test_hdbscan_algorithm(synthetic_data, algorithm): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(algorithm=algorithm) + cluster_labels = clusterer.fit_predict(X) + # Check that clusters are formed + n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0) + assert n_clusters > 0, f"Should find clusters with algorithm={algorithm}" + + 
+@pytest.mark.parametrize("leaf_size", [10, 30, 50]) +def test_hdbscan_leaf_size(synthetic_data, leaf_size): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(leaf_size=leaf_size) + cluster_labels = clusterer.fit_predict(X) + # Check that clusters are formed + n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0) + assert n_clusters > 0, f"Should find clusters with leaf_size={leaf_size}" + + +def test_hdbscan_gen_min_span_tree(synthetic_data): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(gen_min_span_tree=True) + clusterer.fit(X) + # Check that the minimum spanning tree is generated + assert hasattr( + clusterer, "minimum_spanning_tree_" + ), "Minimum spanning tree should be generated when gen_min_span_tree=True" + + +def test_hdbscan_memory(synthetic_data, tmpdir): + X, _ = synthetic_data + from joblib import Memory + + memory = Memory(location=tmpdir) + clusterer = hdbscan.HDBSCAN(memory=memory) + clusterer.fit(X) + # Check that cache directory is used + # assert tmpdir.listdir(), "Cache directory should not be empty when memory caching is used" + + +def test_hdbscan_approx_min_span_tree(synthetic_data): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(approx_min_span_tree=True) + clusterer.fit(X) + # this parameter is ignored in cuML + + +@pytest.mark.parametrize("n_jobs", [1, -1]) +def test_hdbscan_core_dist_n_jobs(synthetic_data, n_jobs): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(core_dist_n_jobs=n_jobs) + clusterer.fit(X) + # We assume the code runs without error; no direct way to test n_jobs effect + assert True, f"HDBSCAN ran successfully with core_dist_n_jobs={n_jobs}" + + +def test_hdbscan_probabilities(synthetic_data): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN() + clusterer.fit(X) + # Check that cluster membership probabilities are available + assert hasattr( + clusterer, "probabilities_" + ), "Cluster membership probabilities should be available after fitting" + + +def 
test_hdbscan_fit_predict(synthetic_data): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN() + labels_fit = clusterer.fit(X).labels_ + labels_predict = clusterer.fit_predict(X) + # Check that labels from fit and fit_predict are the same + assert np.array_equal( + labels_fit, labels_predict + ), "Labels from fit and fit_predict should be the same" + + +def test_hdbscan_invalid_metric(synthetic_data): + X, _ = synthetic_data + with pytest.raises(ValueError): + clusterer = hdbscan.HDBSCAN(metric="invalid_metric") + clusterer.fit(X) + + +@pytest.mark.xfail(reason="Dispatching with sparse input not supported yet") +def test_hdbscan_sparse_input(): + from scipy.sparse import csr_matrix + + X, _ = make_blobs( + n_samples=100, + n_features=2, + centers=3, + cluster_std=0.5, + random_state=42, + ) + X_sparse = csr_matrix(X) + clusterer = hdbscan.HDBSCAN() + cluster_labels = clusterer.fit_predict(X_sparse) + # Check that clusters are formed + n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0) + assert n_clusters > 0, "Should find clusters with sparse input data" + + +def test_hdbscan_non_convex_shapes(): + X, y = make_moons(n_samples=300, noise=0.05, random_state=42) + clusterer = hdbscan.HDBSCAN(min_cluster_size=5) + cluster_labels = clusterer.fit_predict(X) + # Check that at least two clusters are found + n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0) + assert n_clusters >= 2, "Should find clusters in non-convex shapes" + + +def test_hdbscan_prediction(synthetic_data): + X_train, _ = synthetic_data + X_test, _ = make_blobs( + n_samples=100, + n_features=2, + centers=5, + cluster_std=0.5, + random_state=24, + ) + X_test = StandardScaler().fit_transform(X_test) + clusterer = hdbscan.HDBSCAN(prediction_data=True) + clusterer.fit(X_train) + test_labels, strengths = hdbscan.approximate_predict(clusterer, X_test) + # Check that labels are assigned to test data + assert ( + len(test_labels) == X_test.shape[0] + ), "Labels 
should be assigned to test data points" + + +def test_hdbscan_single_linkage_tree(synthetic_data): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(gen_min_span_tree=True) + clusterer.fit(X) + # Check that the single linkage tree is generated + assert hasattr( + clusterer, "single_linkage_tree_" + ), "Single linkage tree should be generated after fitting" + + +def test_hdbscan_condensed_tree(synthetic_data): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN() + clusterer.fit(X) + # Check that the condensed tree is available + assert hasattr( + clusterer, "condensed_tree_" + ), "Condensed tree should be available after fitting" + + +@pytest.mark.xfail(reason="Dispatching with examplars_ not supported yet") +def test_hdbscan_exemplars(synthetic_data): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN() + clusterer.fit(X) + # Check that cluster exemplars are available + assert hasattr( + clusterer, "exemplars_" + ), "Cluster exemplars should be available after fitting" + + +def test_hdbscan_prediction_data_with_prediction(synthetic_data): + X_train, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(prediction_data=True) + clusterer.fit(X_train) + # Use training data for prediction as a simple test + test_labels, strengths = hdbscan.approximate_predict(clusterer, X_train) + # Check that labels from prediction match original labels + assert np.array_equal( + clusterer.labels_, test_labels + ), "Predicted labels should match original labels for training data" + + +def test_hdbscan_predict_without_prediction_data(synthetic_data): + X_train, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(prediction_data=False) + clusterer.fit(X_train) + with pytest.raises((AttributeError, ValueError)): + hdbscan.approximate_predict(clusterer, X_train) + + +def test_hdbscan_min_cluster_size_effect(synthetic_data): + X, _ = synthetic_data + min_cluster_sizes = [5, 15, 30, 50] + n_clusters_list = [] + for size in min_cluster_sizes: + clusterer = 
hdbscan.HDBSCAN(min_cluster_size=size) + cluster_labels = clusterer.fit_predict(X) + n_clusters = len(set(cluster_labels)) - ( + 1 if -1 in cluster_labels else 0 + ) + n_clusters_list.append(n_clusters) + # Expect fewer clusters as min_cluster_size increases + assert n_clusters_list == sorted( + n_clusters_list, reverse=True + ), "Number of clusters should decrease as min_cluster_size increases" + + +def test_hdbscan_min_span_tree_effect(synthetic_data): + X, _ = synthetic_data + clusterer_with_tree = hdbscan.HDBSCAN(gen_min_span_tree=True) + clusterer_with_tree.fit(X) + clusterer_without_tree = hdbscan.HDBSCAN(gen_min_span_tree=False) + clusterer_without_tree.fit(X) + # Check that the minimum spanning tree affects the clustering (may not always be true) + assert np.array_equal( + clusterer_with_tree.labels_, clusterer_without_tree.labels_ + ), "Clustering should be consistent regardless of gen_min_span_tree" + + +def test_hdbscan_allow_single_cluster(synthetic_data): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(allow_single_cluster=True) + cluster_labels = clusterer.fit_predict(X) + # Check that clusters are formed + n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0) + assert ( + n_clusters >= 1 + ), "Should allow a single cluster when allow_single_cluster=True" diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_hdbscan_extended.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_hdbscan_extended.py new file mode 100644 index 0000000000..c43e600b09 --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_hdbscan_extended.py @@ -0,0 +1,214 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import pytest +import numpy as np +from sklearn.datasets import make_blobs, make_moons +from sklearn.preprocessing import StandardScaler +import hdbscan +from hdbscan import validity +from hdbscan import prediction + + +@pytest.fixture(scope="module") +def synthetic_data(): + X, y = make_blobs( + n_samples=500, + n_features=2, + centers=5, + cluster_std=0.5, + random_state=42, + ) + # Standardize features + X = StandardScaler().fit_transform(X) + return X, y + + +def test_hdbscan_approximate_predict(synthetic_data): + X_train, _ = synthetic_data + X_test, _ = make_blobs( + n_samples=100, + n_features=2, + centers=5, + cluster_std=0.5, + random_state=24, + ) + X_test = StandardScaler().fit_transform(X_test) + clusterer = hdbscan.HDBSCAN(prediction_data=True) + clusterer.fit(X_train) + test_labels, strengths = hdbscan.approximate_predict(clusterer, X_test) + # Check that labels are assigned to test data + assert ( + len(test_labels) == X_test.shape[0] + ), "Labels should be assigned to test data points" + assert ( + len(strengths) == X_test.shape[0] + ), "Strengths should be computed for test data points" + # Check that strengths are between 0 and 1 + assert np.all( + (strengths >= 0) & (strengths <= 1) + ), "Strengths should be between 0 and 1" + + +def test_hdbscan_membership_vector(synthetic_data): + X_train, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(prediction_data=True) + clusterer.fit(X_train) + point = X_train[0].reshape((1, 2)) + hdbscan.membership_vector(clusterer, point) + + +def 
test_hdbscan_all_points_membership_vectors(synthetic_data): + X_train, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(prediction_data=True) + clusterer.fit(X_train) + memberships = hdbscan.all_points_membership_vectors(clusterer) + # Check that the number of membership vectors matches the number of samples + assert ( + len(memberships) == X_train.shape[0] + ), "There should be a membership vector for each sample" + # Check that each membership vector sums to 1 + for membership in memberships: + # Check that all probabilities are between 0 and 1 + assert all( + 0.0 <= v <= 1.0 for v in membership + ), "Probabilities should be between 0 and 1" + + +def test_hdbscan_validity_index(synthetic_data): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN() + clusterer.fit(X) + score = validity.validity_index(X, clusterer.labels_, metric="euclidean") + # Check that the validity index is a finite number + assert np.isfinite(score), "Validity index should be a finite number" + + +def test_hdbscan_condensed_tree(synthetic_data): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN() + clusterer.fit(X) + condensed_tree = clusterer.condensed_tree_ + # Check that the condensed tree has the expected attributes + assert hasattr( + condensed_tree, "to_pandas" + ), "Condensed tree should have a 'to_pandas' method" + # Convert to pandas DataFrame and check columns + condensed_tree.to_pandas() + + +def test_hdbscan_single_linkage_tree_attribute(synthetic_data): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN() + clusterer.fit(X) + single_linkage_tree = clusterer.single_linkage_tree_ + # Check that the single linkage tree has the expected attributes + assert hasattr( + single_linkage_tree, "to_numpy" + ), "Single linkage tree should have a 'to_numpy' method" + # Convert to NumPy array and check shape + sl_tree_array = single_linkage_tree.to_numpy() + assert ( + sl_tree_array.shape[1] == 4 + ), "Single linkage tree array should have 4 columns" + + +def 
test_hdbscan_flat_clustering(synthetic_data): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN() + clusterer.fit(X) + # Extract clusters at a specific cluster_selection_epsilon + clusterer_flat = hdbscan.HDBSCAN(cluster_selection_epsilon=0.1) + clusterer_flat.fit(X) + # Check that clusters are formed + n_clusters_flat = len(set(clusterer_flat.labels_)) - ( + 1 if -1 in clusterer_flat.labels_ else 0 + ) + assert n_clusters_flat > 0, "Should find clusters with flat clustering" + + +def test_hdbscan_prediction_membership_vector(synthetic_data): + X_train, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(prediction_data=True) + clusterer.fit(X_train) + point = X_train[0].reshape((1, 2)) + prediction.membership_vector(clusterer, point) + + +def test_hdbscan_prediction_all_points_membership_vectors(synthetic_data): + X_train, _ = synthetic_data + clusterer = hdbscan.HDBSCAN(prediction_data=True) + clusterer.fit(X_train) + memberships = prediction.all_points_membership_vectors(clusterer) + # Check that the number of membership vectors matches the number of samples + assert ( + len(memberships) == X_train.shape[0] + ), "There should be a membership vector for each sample" + for membership in memberships: + # Check that all probabilities are between 0 and 1 + assert all( + 0.0 <= v <= 1.0 for v in membership + ), "Probabilities should be between 0 and 1" + + +def test_hdbscan_outlier_exposure(synthetic_data): + # Note: hdbscan may not have a function named 'outlier_exposure' + # This is a placeholder for any outlier detection functionality + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN() + clusterer.fit(X) + # Check if outlier scores are computed + if hasattr(clusterer, "outlier_scores_"): + outlier_scores = clusterer.outlier_scores_ + # Check that outlier scores are finite numbers + assert np.all( + np.isfinite(outlier_scores) + ), "Outlier scores should be finite numbers" + else: + pytest.skip( + "Outlier exposure functionality is not available in this version 
of HDBSCAN" + ) + + +# test requires networkx +# def test_hdbscan_extract_single_linkage_tree(synthetic_data): +# X, _ = synthetic_data +# clusterer = hdbscan.HDBSCAN() +# clusterer.fit(X) +# # Extract the single linkage tree +# sl_tree = clusterer.single_linkage_tree_.to_networkx() +# # Check that the tree has the correct number of nodes +# assert sl_tree.number_of_nodes() == X.shape[0], "Single linkage tree should have a node for each data point" + + +def test_hdbscan_get_exemplars(synthetic_data): + X, _ = synthetic_data + clusterer = hdbscan.HDBSCAN() + clusterer.fit(X) + if hasattr(clusterer, "exemplars_"): + exemplars = clusterer.exemplars_ + # Check that exemplars are available for each cluster + n_clusters = len(set(clusterer.labels_)) - ( + 1 if -1 in clusterer.labels_ else 0 + ) + assert ( + len(exemplars) == n_clusters + ), "There should be exemplars for each cluster" + else: + pytest.skip( + "Exemplar functionality is not available in this version of HDBSCAN" + ) diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_kmeans.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_kmeans.py new file mode 100644 index 0000000000..bae2d5c90f --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_kmeans.py @@ -0,0 +1,105 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest +import numpy as np +from sklearn.datasets import make_blobs +from sklearn.cluster import KMeans +from sklearn.metrics import adjusted_rand_score + + +@pytest.fixture(scope="module") +def clustering_data(): + X, y = make_blobs( + n_samples=300, centers=3, cluster_std=1.0, random_state=42 + ) + return X, y + + +@pytest.mark.parametrize("n_clusters", [2, 3, 4, 5]) +def test_kmeans_n_clusters(clustering_data, n_clusters): + X, y_true = clustering_data + kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(X) + y_pred = kmeans.labels_ + adjusted_rand_score(y_true, y_pred) + + +@pytest.mark.parametrize("init", ["k-means++", "random"]) +def test_kmeans_init(clustering_data, init): + X, y_true = clustering_data + kmeans = KMeans(n_clusters=3, init=init, random_state=42).fit(X) + y_pred = kmeans.labels_ + adjusted_rand_score(y_true, y_pred) + + +@pytest.mark.parametrize("n_init", [1, 5, 10, 20]) +def test_kmeans_n_init(clustering_data, n_init): + X, y_true = clustering_data + kmeans = KMeans(n_clusters=3, n_init=n_init, random_state=42).fit(X) + y_pred = kmeans.labels_ + adjusted_rand_score(y_true, y_pred) + + +@pytest.mark.parametrize("max_iter", [100, 300, 500]) +def test_kmeans_max_iter(clustering_data, max_iter): + X, y_true = clustering_data + kmeans = KMeans(n_clusters=3, max_iter=max_iter, random_state=42).fit(X) + y_pred = kmeans.labels_ + adjusted_rand_score(y_true, y_pred) + + +@pytest.mark.parametrize("tol", [1e-4, 1e-3, 1e-2]) +def test_kmeans_tol(clustering_data, tol): + X, y_true = clustering_data + kmeans = KMeans(n_clusters=3, tol=tol, random_state=42).fit(X) + y_pred = kmeans.labels_ + adjusted_rand_score(y_true, y_pred) + + +@pytest.mark.parametrize("algorithm", ["elkan", "lloyd"]) +def test_kmeans_algorithm(clustering_data, algorithm): + X, y_true = clustering_data + kmeans = KMeans(n_clusters=3, algorithm=algorithm, random_state=42).fit(X) + y_pred = kmeans.labels_ + adjusted_rand_score(y_true, y_pred) + + 
+@pytest.mark.parametrize("copy_x", [True, False]) +def test_kmeans_copy_x(clustering_data, copy_x): + X, y_true = clustering_data + X_original = X.copy() + kmeans = KMeans(n_clusters=3, copy_x=copy_x, random_state=42).fit(X) + if copy_x: + # X should remain unchanged + assert np.allclose( + X, X_original + ), "X has been modified when copy_x=True" + else: + # X might be modified when copy_x=False + pass # We cannot guarantee X remains unchanged + y_pred = kmeans.labels_ + adjusted_rand_score(y_true, y_pred) + + +def test_kmeans_random_state(clustering_data): + X, y_true = clustering_data + kmeans1 = KMeans(n_clusters=3, random_state=42).fit(X) + kmeans2 = KMeans(n_clusters=3, random_state=42).fit(X) + # With the same random_state, results should be the same + assert np.allclose(kmeans1.cluster_centers_, kmeans2.cluster_centers_) + kmeans3 = KMeans(n_clusters=3, random_state=24).fit(X) + # With different random_state, results might differ + assert not np.allclose(kmeans1.cluster_centers_, kmeans3.cluster_centers_) diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_kneighbors_classifier.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_kneighbors_classifier.py new file mode 100644 index 0000000000..2bdc487910 --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_kneighbors_classifier.py @@ -0,0 +1,189 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import pytest +import numpy as np +from sklearn.datasets import make_classification +from sklearn.neighbors import KNeighborsClassifier +from sklearn.preprocessing import StandardScaler +from sklearn.metrics import accuracy_score + + +@pytest.fixture(scope="module") +def classification_data(): + X, y = make_classification( + n_samples=500, + n_features=20, + n_informative=15, + n_redundant=5, + n_classes=3, + random_state=42, + ) + # Standardize features + X = StandardScaler().fit_transform(X) + return X, y + + +@pytest.mark.parametrize("n_neighbors", [1, 3, 5, 10]) +def test_knn_classifier_n_neighbors(classification_data, n_neighbors): + X, y = classification_data + model = KNeighborsClassifier(n_neighbors=n_neighbors) + model.fit(X, y) + y_pred = model.predict(X) + acc = accuracy_score(y, y_pred) + assert ( + acc > 0.7 + ), f"Accuracy should be reasonable with n_neighbors={n_neighbors}" + + +@pytest.mark.parametrize("weights", ["uniform"]) +def test_knn_classifier_weights(classification_data, weights): + X, y = classification_data + model = KNeighborsClassifier(weights=weights) + model.fit(X, y) + y_pred = model.predict(X) + acc = accuracy_score(y, y_pred) + assert acc > 0.7, f"Accuracy should be reasonable with weights={weights}" + + +@pytest.mark.parametrize( + "algorithm", ["auto", "ball_tree", "kd_tree", "brute"] +) +def test_knn_classifier_algorithm(classification_data, algorithm): + X, y = classification_data + model = KNeighborsClassifier(algorithm=algorithm) + model.fit(X, y) + y_pred = model.predict(X) + acc = accuracy_score(y, y_pred) + assert ( + acc > 0.7 + ), f"Accuracy should be reasonable with algorithm={algorithm}" + + +@pytest.mark.parametrize("leaf_size", [10, 30, 50]) +def test_knn_classifier_leaf_size(classification_data, leaf_size): + X, y = classification_data + model = KNeighborsClassifier(leaf_size=leaf_size) + model.fit(X, y) + 
y_pred = model.predict(X) + acc = accuracy_score(y, y_pred) + assert ( + acc > 0.7 + ), f"Accuracy should be reasonable with leaf_size={leaf_size}" + + +@pytest.mark.parametrize( + "metric", ["euclidean", "manhattan", "chebyshev", "minkowski"] +) +def test_knn_classifier_metric(classification_data, metric): + X, y = classification_data + model = KNeighborsClassifier(metric=metric) + model.fit(X, y) + y_pred = model.predict(X) + acc = accuracy_score(y, y_pred) + assert acc > 0.7, f"Accuracy should be reasonable with metric={metric}" + + +@pytest.mark.parametrize("p", [1, 2, 3]) +def test_knn_classifier_p_parameter(classification_data, p): + X, y = classification_data + model = KNeighborsClassifier(metric="minkowski", p=p) + model.fit(X, y) + y_pred = model.predict(X) + acc = accuracy_score(y, y_pred) + assert acc > 0.7, f"Accuracy should be reasonable with p={p}" + + +@pytest.mark.xfail(reason="Dispatching with callable not supported yet") +def test_knn_classifier_weights_callable(classification_data): + X, y = classification_data + + def custom_weights(distances): + return np.ones_like(distances) + + model = KNeighborsClassifier(weights=custom_weights) + model.fit(X, y) + y_pred = model.predict(X) + acc = accuracy_score(y, y_pred) + assert acc > 0.7, "Accuracy should be reasonable with custom weights" + + +def test_knn_classifier_invalid_algorithm(classification_data): + X, y = classification_data + with pytest.raises((ValueError, KeyError)): + model = KNeighborsClassifier(algorithm="invalid_algorithm") + model.fit(X, y) + + +def test_knn_classifier_invalid_metric(classification_data): + X, y = classification_data + with pytest.raises(ValueError): + model = KNeighborsClassifier(metric="invalid_metric") + model.fit(X, y) + + +def test_knn_classifier_invalid_weights(classification_data): + X, y = classification_data + with pytest.raises(ValueError): + model = KNeighborsClassifier(weights="invalid_weight") + model.fit(X, y) + + +def 
test_knn_classifier_predict_proba(classification_data): + X, y = classification_data + model = KNeighborsClassifier() + model.fit(X, y) + proba = model.predict_proba(X) + # Check that probabilities sum to 1 + assert np.allclose(proba.sum(axis=1), 1), "Probabilities should sum to 1" + # Check shape + assert proba.shape == ( + X.shape[0], + len(np.unique(y)), + ), "Probability matrix shape should be (n_samples, n_classes)" + + +def test_knn_classifier_sparse_input(): + from scipy.sparse import csr_matrix + + X, y = make_classification(n_samples=100, n_features=20, random_state=42) + X_sparse = csr_matrix(X) + model = KNeighborsClassifier() + model.fit(X_sparse, y) + y_pred = model.predict(X_sparse) + acc = accuracy_score(y, y_pred) + assert acc > 0.7, "Accuracy should be reasonable with sparse input" + + +def test_knn_classifier_multilabel(): + from sklearn.datasets import make_multilabel_classification + + X, y = make_multilabel_classification( + n_samples=100, n_features=20, n_classes=3, random_state=42 + ) + model = KNeighborsClassifier() + model.fit(X, y) + y_pred = model.predict(X) + # Check that the predicted shape matches the true labels + assert ( + y_pred.shape == y.shape + ), "Predicted labels should have the same shape as true labels" + # Calculate accuracy for multi-label + acc = (y_pred == y).mean() + assert ( + acc > 0.7 + ), "Accuracy should be reasonable for multi-label classification" diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_kneighbors_regressor.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_kneighbors_regressor.py new file mode 100644 index 0000000000..cb1c4b67b6 --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_kneighbors_regressor.py @@ -0,0 +1,163 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +import numpy as np +from sklearn.datasets import make_regression +from sklearn.neighbors import KNeighborsRegressor +from sklearn.preprocessing import StandardScaler +from sklearn.metrics import r2_score + + +@pytest.fixture(scope="module") +def regression_data(): + X, y = make_regression( + n_samples=500, + n_features=20, + n_informative=15, + noise=0.1, + random_state=42, + ) + # Standardize features + X = StandardScaler().fit_transform(X) + return X, y + + +@pytest.mark.parametrize("n_neighbors", [1, 3, 5, 10]) +def test_knn_regressor_n_neighbors(regression_data, n_neighbors): + X, y = regression_data + model = KNeighborsRegressor(n_neighbors=n_neighbors) + model.fit(X, y) + y_pred = model.predict(X) + r2_score(y, y_pred) + + +@pytest.mark.parametrize("weights", ["uniform"]) +def test_knn_regressor_weights(regression_data, weights): + X, y = regression_data + model = KNeighborsRegressor(weights=weights) + model.fit(X, y) + y_pred = model.predict(X) + r2 = r2_score(y, y_pred) + assert r2 > 0.7, f"R^2 score should be reasonable with weights={weights}" + + +@pytest.mark.parametrize( + "algorithm", ["auto", "ball_tree", "kd_tree", "brute"] +) +def test_knn_regressor_algorithm(regression_data, algorithm): + X, y = regression_data + model = KNeighborsRegressor(algorithm=algorithm) + model.fit(X, y) + y_pred = model.predict(X) + r2 = r2_score(y, y_pred) + assert ( + r2 > 0.7 + ), f"R^2 score should be 
reasonable with algorithm={algorithm}" + + +@pytest.mark.parametrize("leaf_size", [10, 30, 50]) +def test_knn_regressor_leaf_size(regression_data, leaf_size): + X, y = regression_data + model = KNeighborsRegressor(leaf_size=leaf_size) + model.fit(X, y) + y_pred = model.predict(X) + r2 = r2_score(y, y_pred) + assert ( + r2 > 0.7 + ), f"R^2 score should be reasonable with leaf_size={leaf_size}" + + +@pytest.mark.parametrize( + "metric", ["euclidean", "manhattan", "chebyshev", "minkowski"] +) +def test_knn_regressor_metric(regression_data, metric): + X, y = regression_data + model = KNeighborsRegressor(metric=metric) + model.fit(X, y) + y_pred = model.predict(X) + r2_score(y, y_pred) + + +@pytest.mark.parametrize("p", [1, 2, 3]) +def test_knn_regressor_p_parameter(regression_data, p): + X, y = regression_data + model = KNeighborsRegressor(metric="minkowski", p=p) + model.fit(X, y) + y_pred = model.predict(X) + r2 = r2_score(y, y_pred) + assert r2 > 0.7, f"R^2 score should be reasonable with p={p}" + + +@pytest.mark.xfail(reason="Dispatching with callable not supported yet") +def test_knn_regressor_weights_callable(regression_data): + X, y = regression_data + + def custom_weights(distances): + return np.ones_like(distances) + + model = KNeighborsRegressor(weights=custom_weights) + model.fit(X, y) + y_pred = model.predict(X) + r2 = r2_score(y, y_pred) + assert r2 > 0.7, "R^2 score should be reasonable with custom weights" + + +def test_knn_regressor_invalid_algorithm(regression_data): + X, y = regression_data + with pytest.raises((ValueError, KeyError)): + model = KNeighborsRegressor(algorithm="invalid_algorithm") + model.fit(X, y) + + +def test_knn_regressor_invalid_metric(regression_data): + X, y = regression_data + with pytest.raises(ValueError): + model = KNeighborsRegressor(metric="invalid_metric") + model.fit(X, y) + + +def test_knn_regressor_invalid_weights(regression_data): + X, y = regression_data + with pytest.raises(ValueError): + model = 
KNeighborsRegressor(weights="invalid_weight") + model.fit(X, y) + + +def test_knn_regressor_sparse_input(): + from scipy.sparse import csr_matrix + + X, y = make_regression(n_samples=100, n_features=20, random_state=42) + X_sparse = csr_matrix(X) + model = KNeighborsRegressor() + model.fit(X_sparse, y) + y_pred = model.predict(X_sparse) + r2_score(y, y_pred) + + +def test_knn_regressor_multioutput(): + X, y = make_regression( + n_samples=100, n_features=20, n_targets=3, random_state=42 + ) + model = KNeighborsRegressor() + model.fit(X, y) + y_pred = model.predict(X) + # Check that the predicted shape matches the true targets + assert ( + y_pred.shape == y.shape + ), "Predicted outputs should have the same shape as true outputs" + # Calculate R^2 score for multi-output regression + r2_score(y, y_pred) diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_lasso.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_lasso.py new file mode 100644 index 0000000000..52b97e72db --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_lasso.py @@ -0,0 +1,193 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest +import numpy as np +from sklearn.datasets import make_regression +from sklearn.linear_model import Lasso +from sklearn.metrics import r2_score +from sklearn.preprocessing import StandardScaler + + +@pytest.fixture(scope="module") +def regression_data(): + X, y, coef = make_regression( + n_samples=500, + n_features=20, + n_informative=10, + noise=0.1, + coef=True, + random_state=42, + ) + # Standardize features + X = StandardScaler().fit_transform(X) + return X, y, coef + + +@pytest.mark.parametrize("alpha", [0.1, 1.0, 10.0, 100.0]) +def test_lasso_alpha(regression_data, alpha): + X, y, _ = regression_data + model = Lasso(alpha=alpha, random_state=42) + model.fit(X, y) + y_pred = model.predict(X) + # Compute R^2 score + r2_score(y, y_pred) + + +def test_lasso_alpha_sparsity(regression_data): + X, y, _ = regression_data + alphas = [0.1, 1.0, 10.0, 100.0] + zero_counts = [] + for alpha in alphas: + model = Lasso(alpha=alpha, random_state=42) + model.fit(X, y) + zero_counts.append(np.sum(model.coef_ == 0)) + # Check that zero_counts increases with alpha + assert zero_counts == sorted( + zero_counts + ), "Number of zero coefficients should increase with alpha" + + +@pytest.mark.parametrize("max_iter", [100]) +def test_lasso_max_iter(regression_data, max_iter): + X, y, _ = regression_data + model = Lasso(max_iter=max_iter, random_state=42) + model.fit(X, y) + + +@pytest.mark.parametrize("tol", [1e-3]) +def test_lasso_tol(regression_data, tol): + X, y, _ = regression_data + model = Lasso(tol=tol, random_state=42) + model.fit(X, y) + y_pred = model.predict(X) + # Compute R^2 score + r2 = r2_score(y, y_pred) + assert r2 > 0.5, f"R^2 score should be reasonable for tol={tol}" + + +@pytest.mark.parametrize("fit_intercept", [True, False]) +def test_lasso_fit_intercept(regression_data, fit_intercept): + X, y, _ = regression_data + model = Lasso(fit_intercept=fit_intercept, random_state=42) + model.fit(X, y) + y_pred = model.predict(X) + # Compute R^2 score + 
r2 = r2_score(y, y_pred) + assert ( + r2 > 0.5 + ), f"R^2 score should be reasonable with fit_intercept={fit_intercept}" + + +def test_lasso_positive(regression_data): + X, y, _ = regression_data + model = Lasso(positive=True, random_state=42) + model.fit(X, y) + # All coefficients should be non-negative + assert np.all( + model.coef_ >= 0 + ), "All coefficients should be non-negative when positive=True" + + +def test_lasso_random_state(regression_data): + X, y, _ = regression_data + model1 = Lasso(selection="random", random_state=42) + model1.fit(X, y) + model2 = Lasso(selection="random", random_state=42) + model2.fit(X, y) + # Coefficients should be the same when random_state is fixed + np.testing.assert_allclose( + model1.coef_, + model2.coef_, + err_msg="Coefficients should be the same with the same random_state", + ) + model3 = Lasso(selection="random", random_state=24) + model3.fit(X, y) + + +def test_lasso_warm_start(regression_data): + X, y, _ = regression_data + model = Lasso(warm_start=True, random_state=42) + model.fit(X, y) + coef_old = model.coef_.copy() + # Fit again with different alpha + model.set_params(alpha=10.0) + model.fit(X, y) + coef_new = model.coef_ + # Coefficients should change after refitting with a different alpha + assert not np.allclose( + coef_old, coef_new + ), "Coefficients should update when warm_start=True" + + +@pytest.mark.parametrize("copy_X", [True, False]) +def test_lasso_copy_X(regression_data, copy_X): + X, y, _ = regression_data + X_original = X.copy() + model = Lasso(copy_X=copy_X, random_state=42) + model.fit(X, y) + if copy_X: + # X should remain unchanged + assert np.allclose( + X, X_original + ), "X has been modified when copy_X=True" + else: + # X might be modified when copy_X=False + pass # We cannot guarantee X remains unchanged + + +@pytest.mark.xfail(reason="cuML does not emit ConvergenceWarning yet.") +def test_lasso_convergence_warning(regression_data): + X, y, _ = regression_data + from sklearn.exceptions 
import ConvergenceWarning + + with pytest.warns(ConvergenceWarning): + model = Lasso(max_iter=1, random_state=42) + model.fit(X, y) + + +def test_lasso_coefficients_sparsity(regression_data): + X, y, _ = regression_data + model = Lasso(alpha=1.0, random_state=42) + model.fit(X, y) + coef_zero = np.sum(model.coef_ == 0) + assert ( + coef_zero > 0 + ), "There should be zero coefficients indicating sparsity" + + +@pytest.mark.parametrize("selection", ["cyclic", "random"]) +def test_lasso_selection(regression_data, selection): + X, y, _ = regression_data + model = Lasso(selection=selection, random_state=42) + model.fit(X, y) + y_pred = model.predict(X) + r2 = r2_score(y, y_pred) + assert ( + r2 > 0.5 + ), f"R^2 score should be reasonable with selection={selection}" + + +@pytest.mark.parametrize("precompute", [True, False]) +def test_lasso_precompute(regression_data, precompute): + X, y, _ = regression_data + model = Lasso(precompute=precompute, random_state=42) + model.fit(X, y) + y_pred = model.predict(X) + r2 = r2_score(y, y_pred) + assert ( + r2 > 0.5 + ), f"R^2 score should be reasonable with precompute={precompute}" diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_linear_regression.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_linear_regression.py new file mode 100644 index 0000000000..5aad625e6d --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_linear_regression.py @@ -0,0 +1,59 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import numpy as np +from sklearn.datasets import make_regression +from sklearn.linear_model import LinearRegression +from sklearn.metrics import r2_score + + +@pytest.fixture(scope="module") +def regression_data(): + X, y = make_regression( + n_samples=100, n_features=20, noise=0.1, random_state=42 + ) + return X, y + + +@pytest.mark.parametrize("fit_intercept", [True, False]) +def test_linear_regression_fit_intercept(regression_data, fit_intercept): + X, y = regression_data + lr = LinearRegression(fit_intercept=fit_intercept).fit(X, y) + lr.predict(X) + + +@pytest.mark.parametrize("copy_X", [True, False]) +def test_linear_regression_copy_X(regression_data, copy_X): + X, y = regression_data + X_original = X.copy() + LinearRegression(copy_X=copy_X).fit(X, y) + if copy_X: + # X should remain unchanged + assert np.array_equal( + X, X_original + ), "X has been modified when copy_X=True" + else: + # X might be modified when copy_X=False + pass # We cannot guarantee X remains unchanged + + +@pytest.mark.parametrize("positive", [True, False]) +def test_linear_regression_positive(regression_data, positive): + X, y = regression_data + lr = LinearRegression(positive=positive).fit(X, y) + lr.predict(X) + if positive: + # Verify that all coefficients are non-negative + assert np.all(lr.coef_ >= 0), "Not all coefficients are non-negative" diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_logistic_regression.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_logistic_regression.py new 
file mode 100644 index 0000000000..ea967ae55e --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_logistic_regression.py @@ -0,0 +1,195 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +import numpy as np +from sklearn.datasets import make_classification +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import accuracy_score + + +@pytest.fixture(scope="module") +def classification_data(): + X, y = make_classification( + n_samples=200, + n_features=20, + n_classes=3, + n_informative=10, + random_state=42, + ) + return X, y + + +@pytest.mark.parametrize( + "penalty, solver", + [ + ("l1", "liblinear"), + ("l1", "saga"), + ("l2", "lbfgs"), + ("l2", "liblinear"), + ("l2", "sag"), + ("l2", "saga"), + ("elasticnet", "saga"), + (None, "lbfgs"), + (None, "saga"), + ], +) +def test_logistic_regression_penalty(classification_data, penalty, solver): + X, y = classification_data + kwargs = {"penalty": penalty, "solver": solver, "max_iter": 200} + if penalty == "elasticnet": + kwargs["l1_ratio"] = 0.5 # l1_ratio is required for elasticnet + clf = LogisticRegression(**kwargs).fit(X, y) + y_pred = clf.predict(X) + accuracy_score(y, y_pred) + + +@pytest.mark.parametrize("dual", [True, False]) +def test_logistic_regression_dual(classification_data, dual): + X, y = classification_data + # 'dual' is only applicable when 'penalty' is 'l2' and 'solver' is 
'liblinear' + if dual: + clf = LogisticRegression( + penalty="l2", solver="liblinear", dual=dual, max_iter=200 + ).fit(X, y) + else: + clf = LogisticRegression( + penalty="l2", solver="liblinear", dual=dual, max_iter=200 + ).fit(X, y) + y_pred = clf.predict(X) + accuracy_score(y, y_pred) + + +@pytest.mark.parametrize("tol", [1e-2]) +def test_logistic_regression_tol(classification_data, tol): + X, y = classification_data + clf = LogisticRegression(tol=tol, max_iter=200).fit(X, y) + y_pred = clf.predict(X) + accuracy_score(y, y_pred) + + +@pytest.mark.parametrize("C", [0.01, 0.1, 1.0, 10.0, 100.0]) +def test_logistic_regression_C(classification_data, C): + X, y = classification_data + clf = LogisticRegression(C=C, max_iter=200).fit(X, y) + y_pred = clf.predict(X) + accuracy_score(y, y_pred) + + +@pytest.mark.parametrize("fit_intercept", [True, False]) +def test_logistic_regression_fit_intercept(classification_data, fit_intercept): + X, y = classification_data + clf = LogisticRegression(fit_intercept=fit_intercept, max_iter=200).fit( + X, y + ) + y_pred = clf.predict(X) + accuracy_score(y, y_pred) + + +@pytest.mark.parametrize("intercept_scaling", [0.5, 1.0, 2.0]) +def test_logistic_regression_intercept_scaling( + classification_data, intercept_scaling +): + X, y = classification_data + # 'intercept_scaling' is only used when solver='liblinear' and fit_intercept=True + clf = LogisticRegression( + solver="liblinear", + fit_intercept=True, + intercept_scaling=intercept_scaling, + max_iter=200, + ).fit(X, y) + y_pred = clf.predict(X) + accuracy_score(y, y_pred) + + +@pytest.mark.parametrize("class_weight", [None, "balanced"]) +def test_logistic_regression_class_weight(classification_data, class_weight): + X, y = classification_data + clf = LogisticRegression(class_weight=class_weight, max_iter=200).fit(X, y) + y_pred = clf.predict(X) + accuracy_score(y, y_pred) + + +def test_logistic_regression_class_weight_custom(classification_data): + X, y = classification_data + 
class_weights = {0: 1, 1: 2, 2: 1} + clf = LogisticRegression(class_weight=class_weights, max_iter=200).fit( + X, y + ) + y_pred = clf.predict(X) + accuracy_score(y, y_pred) + + +@pytest.mark.parametrize( + "solver", ["newton-cg", "lbfgs", "liblinear", "sag", "saga"] +) +def test_logistic_regression_solver(classification_data, solver): + X, y = classification_data + clf = LogisticRegression(solver=solver, max_iter=200).fit(X, y) + y_pred = clf.predict(X) + accuracy_score(y, y_pred) + + +@pytest.mark.parametrize("max_iter", [50, 100, 200, 500]) +def test_logistic_regression_max_iter(classification_data, max_iter): + X, y = classification_data + clf = LogisticRegression(max_iter=max_iter).fit(X, y) + y_pred = clf.predict(X) + accuracy_score(y, y_pred) + + +@pytest.mark.parametrize( + "multi_class, solver", + [ + ("ovr", "liblinear"), + ("ovr", "lbfgs"), + ("multinomial", "lbfgs"), + ("multinomial", "newton-cg"), + ("multinomial", "sag"), + ("multinomial", "saga"), + ("auto", "lbfgs"), + ("auto", "liblinear"), + ], +) +def test_logistic_regression_multi_class( + classification_data, multi_class, solver +): + X, y = classification_data + if solver == "liblinear" and multi_class == "multinomial": + pytest.skip("liblinear does not support multinomial multi_class") + clf = LogisticRegression( + multi_class=multi_class, solver=solver, max_iter=200 + ).fit(X, y) + y_pred = clf.predict(X) + accuracy_score(y, y_pred) + + +@pytest.mark.parametrize("warm_start", [True, False]) +def test_logistic_regression_warm_start(classification_data, warm_start): + X, y = classification_data + clf = LogisticRegression(warm_start=warm_start, max_iter=200).fit(X, y) + y_pred = clf.predict(X) + accuracy_score(y, y_pred) + + +@pytest.mark.parametrize("l1_ratio", [0.0, 0.5, 1.0]) +def test_logistic_regression_l1_ratio(classification_data, l1_ratio): + X, y = classification_data + clf = LogisticRegression( + penalty="elasticnet", solver="saga", l1_ratio=l1_ratio, max_iter=200 + ).fit(X, y) + 
y_pred = clf.predict(X) + accuracy_score(y, y_pred) diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_nearest_neighbors.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_nearest_neighbors.py new file mode 100644 index 0000000000..82eaeae327 --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_nearest_neighbors.py @@ -0,0 +1,193 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest +import numpy as np +from sklearn.datasets import make_blobs +from sklearn.neighbors import NearestNeighbors +from sklearn.preprocessing import StandardScaler +from sklearn.metrics import pairwise_distances + + +@pytest.fixture(scope="module") +def synthetic_data(): + X, y = make_blobs( + n_samples=500, + n_features=10, + centers=5, + cluster_std=1.0, + random_state=42, + ) + # Standardize features + X = StandardScaler().fit_transform(X) + return X, y + + +@pytest.mark.parametrize("n_neighbors", [1, 5, 10, 20]) +def test_nearest_neighbors_n_neighbors(synthetic_data, n_neighbors): + X, _ = synthetic_data + model = NearestNeighbors(n_neighbors=n_neighbors) + model.fit(X) + distances, indices = model.kneighbors(X) + # Check that the correct number of neighbors is returned + assert ( + indices.shape[1] == n_neighbors + ), f"Should return {n_neighbors} neighbors" + + +@pytest.mark.parametrize( + "algorithm", ["auto", "ball_tree", "kd_tree", "brute"] +) +def test_nearest_neighbors_algorithm(synthetic_data, algorithm): + X, _ = synthetic_data + model = NearestNeighbors(algorithm=algorithm) + model.fit(X) + distances, indices = model.kneighbors(X) + # Check that the output shape is correct + assert ( + indices.shape[0] == X.shape[0] + ), f"Number of samples should remain the same with algorithm={algorithm}" + + +@pytest.mark.parametrize( + "metric", ["euclidean", "manhattan", "chebyshev", "minkowski"] +) +def test_nearest_neighbors_metric(synthetic_data, metric): + X, _ = synthetic_data + model = NearestNeighbors(metric=metric) + model.fit(X) + model.kneighbors(X) + + +@pytest.mark.parametrize("p", [1, 2, 3]) +def test_nearest_neighbors_p_parameter(synthetic_data, p): + X, _ = synthetic_data + model = NearestNeighbors(metric="minkowski", p=p) + model.fit(X) + distances, indices = model.kneighbors(X) + + +@pytest.mark.parametrize("leaf_size", [10, 30, 50]) +def test_nearest_neighbors_leaf_size(synthetic_data, leaf_size): + X, _ = synthetic_data + model = 
NearestNeighbors(leaf_size=leaf_size) + model.fit(X) + + +@pytest.mark.parametrize("n_jobs", [1, -1]) +def test_nearest_neighbors_n_jobs(synthetic_data, n_jobs): + X, _ = synthetic_data + model = NearestNeighbors(n_jobs=n_jobs) + model.fit(X) + # We assume the code runs without error; no direct way to test n_jobs effect + assert True, f"NearestNeighbors ran successfully with n_jobs={n_jobs}" + + +@pytest.mark.xfail(reason="cuML doesn't have radius neighbors method") +def test_nearest_neighbors_radius(synthetic_data): + X, _ = synthetic_data + radius = 1.0 + model = NearestNeighbors(radius=radius) + model.fit(X) + distances, indices = model.radius_neighbors(X) + # Check that all returned distances are within the radius + for dist in distances: + assert np.all( + dist <= radius + ), f"All distances should be within the radius {radius}" + + +def test_nearest_neighbors_invalid_algorithm(synthetic_data): + X, _ = synthetic_data + with pytest.raises((ValueError, KeyError)): + model = NearestNeighbors(algorithm="invalid_algorithm") + model.fit(X) + + +def test_nearest_neighbors_invalid_metric(synthetic_data): + X, _ = synthetic_data + with pytest.raises(ValueError): + model = NearestNeighbors(metric="invalid_metric") + model.fit(X) + + +def test_nearest_neighbors_kneighbors_graph(synthetic_data): + X, _ = synthetic_data + n_neighbors = 5 + model = NearestNeighbors(n_neighbors=n_neighbors) + model.fit(X) + graph = model.kneighbors_graph(X) + # Check that the graph is of correct shape and type + assert graph.shape == ( + X.shape[0], + X.shape[0], + ), "Graph shape should be (n_samples, n_samples)" + assert graph.getformat() == "csr", "Graph should be in CSR format" + # Check that each row has n_neighbors non-zero entries + row_counts = np.diff(graph.indptr) + assert np.all( + row_counts == n_neighbors + ), f"Each sample should have {n_neighbors} neighbors in the graph" + + +@pytest.mark.xfail(reason="cuML doesn't have radius neighbors graph method") +def 
test_nearest_neighbors_radius_neighbors_graph(synthetic_data): + X, _ = synthetic_data + radius = 1.0 + model = NearestNeighbors(radius=radius) + model.fit(X) + graph = model.radius_neighbors_graph(X) + # Check that the graph is of correct shape and type + assert graph.shape == ( + X.shape[0], + X.shape[0], + ), "Graph shape should be (n_samples, n_samples)" + assert graph.getformat() == "csr", "Graph should be in CSR format" + # Check that non-zero entries correspond to distances within the radius + non_zero_indices = graph.nonzero() + pairwise_distances(X[non_zero_indices[0]], X[non_zero_indices[1]]) + + +@pytest.mark.parametrize("return_distance", [True, False]) +def test_nearest_neighbors_return_distance(synthetic_data, return_distance): + X, _ = synthetic_data + model = NearestNeighbors() + model.fit(X) + result = model.kneighbors(X, return_distance=return_distance) + if return_distance: + distances, indices = result + assert ( + distances.shape == indices.shape + ), "Distances and indices should have the same shape" + else: + indices = result + assert indices.shape == ( + X.shape[0], + model.n_neighbors, + ), "Indices shape should match (n_samples, n_neighbors)" + + +def test_nearest_neighbors_sparse_input(): + from scipy.sparse import csr_matrix + + X = csr_matrix(np.random.rand(100, 20)) + model = NearestNeighbors() + model.fit(X) + distances, indices = model.kneighbors(X) + assert distances.shape == ( + X.shape[0], + model.n_neighbors, + ), "Distances shape should match for sparse input" diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_pca.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_pca.py new file mode 100644 index 0000000000..89423d0f41 --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_pca.py @@ -0,0 +1,160 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +import numpy as np +from sklearn.datasets import make_classification +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + + +@pytest.fixture(scope="module") +def pca_data(): + X, y = make_classification( + n_samples=300, + n_features=10, + n_informative=5, + n_redundant=0, + n_repeated=0, + random_state=42, + ) + # Standardize features before PCA + X = StandardScaler().fit_transform(X) + return X, y + + +@pytest.mark.parametrize("n_components", [2, 5, 8, 10]) +def test_pca_n_components(pca_data, n_components): + X, _ = pca_data + pca = PCA(n_components=n_components).fit(X) + X_transformed = pca.transform(X) + # Check the shape of the transformed data + assert ( + X_transformed.shape[1] == n_components + ), f"Expected {n_components} components, got {X_transformed.shape[1]}" + # Check that explained variance ratios sum up appropriately + total_variance = np.sum(pca.explained_variance_ratio_) + assert ( + total_variance <= 1.1 + ), "Total explained variance cannot exceed with margin for parallel error" + assert ( + total_variance > 0.0 + ), "Total explained variance ratio should be positive" + + +@pytest.mark.parametrize( + "svd_solver", ["auto", "full", "arpack", "randomized"] +) +def test_pca_svd_solver(pca_data, svd_solver): + X, _ = pca_data + pca = PCA(n_components=5, svd_solver=svd_solver, random_state=42).fit(X) + X_transformed = pca.transform(X) + # 
Reconstruct the data + pca.inverse_transform(X_transformed) + + +@pytest.mark.parametrize("whiten", [True, False]) +def test_pca_whiten(pca_data, whiten): + X, _ = pca_data + pca = PCA(n_components=5, whiten=whiten).fit(X) + X_transformed = pca.transform(X) + # If whiten is True, transformed data should have unit variance + variances = np.var(X_transformed, axis=0) + if whiten: + np.testing.assert_allclose( + variances, + 1.0, + atol=1e-1, + err_msg="Transformed features should have unit variance when whiten=True", + ) + + +@pytest.mark.parametrize("tol", [0.0, 1e-4, 1e-2]) +def test_pca_tol(pca_data, tol): + X, _ = pca_data + pca = PCA( + n_components=5, svd_solver="arpack", tol=tol, random_state=42 + ).fit(X) + pca.transform(X) + # Since 'arpack' is iterative, tol might affect convergence + # Check that the explained variance ratio is reasonable + total_variance = np.sum(pca.explained_variance_ratio_) + assert ( + total_variance > 0.5 + ), "Total explained variance should be significant" + + +def test_pca_random_state(pca_data): + X, _ = pca_data + pca1 = PCA(n_components=5, svd_solver="randomized", random_state=42).fit(X) + pca2 = PCA(n_components=5, svd_solver="randomized", random_state=42).fit(X) + # With the same random_state, components should be the same + np.testing.assert_allclose( + pca1.components_, + pca2.components_, + err_msg="Components should be the same with the same random_state", + ) + + +@pytest.mark.parametrize("copy", [True, False]) +def test_pca_copy(pca_data, copy): + X, _ = pca_data + X_original = X.copy() + PCA(n_components=5, copy=copy).fit(X) + if copy: + # X should remain unchanged + assert np.allclose(X, X_original), "X has been modified when copy=True" + else: + # X might be modified when copy=False + pass # We cannot guarantee X remains unchanged + + +@pytest.mark.parametrize("iterated_power", [0, 3, 5, "auto"]) +def test_pca_iterated_power(pca_data, iterated_power): + X, _ = pca_data + pca = PCA( + n_components=5, + 
svd_solver="randomized", + iterated_power=iterated_power, + random_state=42, + ).fit(X) + pca.transform(X) + # Check that the explained variance ratio is reasonable + total_variance = np.sum(pca.explained_variance_ratio_) + assert ( + total_variance > 0.5 + ), f"Total explained variance should be significant with iterated_power={iterated_power}" + + +def test_pca_explained_variance_ratio(pca_data): + X, _ = pca_data + pca = PCA(n_components=None).fit(X) + total_variance = np.sum(pca.explained_variance_ratio_) + np.testing.assert_almost_equal( + total_variance, + 1.0, + decimal=5, + err_msg="Total explained variance ratio should sum to 1 when n_components=None", + ) + + +def test_pca_inverse_transform(pca_data): + X, _ = pca_data + pca = PCA(n_components=5).fit(X) + X_transformed = pca.transform(X) + X_reconstructed = pca.inverse_transform(X_transformed) + # Check reconstruction error + np.mean((X - X_reconstructed) ** 2) diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_ridge.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_ridge.py new file mode 100644 index 0000000000..25c595ae44 --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_ridge.py @@ -0,0 +1,154 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +import pytest +import numpy as np +from sklearn.datasets import make_regression +from sklearn.linear_model import Ridge +from sklearn.metrics import mean_squared_error, r2_score +from sklearn.preprocessing import StandardScaler + + +@pytest.fixture(scope="module") +def regression_data(): + X, y = make_regression( + n_samples=500, + n_features=20, + n_informative=10, + noise=0.1, + random_state=42, + ) + # Standardize features + X = StandardScaler().fit_transform(X) + return X, y + + +@pytest.mark.parametrize("alpha", [0.1, 1.0, 10.0, 100.0]) +def test_ridge_alpha(regression_data, alpha): + X, y = regression_data + model = Ridge(alpha=alpha, random_state=42) + model.fit(X, y) + y_pred = model.predict(X) + # Compute R^2 score + r2 = r2_score(y, y_pred) + assert r2 > 0.5, f"R^2 score should be reasonable for alpha={alpha}" + + +@pytest.mark.parametrize( + "solver", + ["auto", "svd", "cholesky", "lsqr", "sparse_cg", "sag", "saga", "lbfgs"], +) +def test_ridge_solver(regression_data, solver): + X, y = regression_data + positive = solver == "lbfgs" + model = Ridge(solver=solver, random_state=42, positive=positive) + model.fit(X, y) + y_pred = model.predict(X) + # Compute R^2 score + r2 = r2_score(y, y_pred) + assert r2 > 0.5, f"R^2 score should be reasonable with solver={solver}" + + +@pytest.mark.parametrize("max_iter", [100]) +def test_ridge_max_iter(regression_data, max_iter): + X, y = regression_data + model = Ridge(max_iter=max_iter, solver="sag", random_state=42) + model.fit(X, y) + + +@pytest.mark.parametrize("tol", [1e-4, 1e-3, 1e-2]) +def test_ridge_tol(regression_data, tol): + X, y = regression_data + model = Ridge(tol=tol, solver="sag", random_state=42) + model.fit(X, y) + y_pred = model.predict(X) + # Compute R^2 score + r2 = r2_score(y, y_pred) + assert r2 > 0.5, f"R^2 score should be reasonable for tol={tol}" + + +@pytest.mark.parametrize("fit_intercept", [True, False]) +def test_ridge_fit_intercept(regression_data, fit_intercept): + X, y = 
regression_data + model = Ridge(fit_intercept=fit_intercept, random_state=42) + model.fit(X, y) + y_pred = model.predict(X) + # Compute R^2 score + r2 = r2_score(y, y_pred) + assert ( + r2 > 0.5 + ), f"R^2 score should be reasonable with fit_intercept={fit_intercept}" + + +def test_ridge_random_state(regression_data): + X, y = regression_data + model1 = Ridge(solver="sag", random_state=42) + model1.fit(X, y) + model2 = Ridge(solver="sag", random_state=42) + model2.fit(X, y) + # Coefficients should be the same when random_state is fixed + np.testing.assert_allclose( + model1.coef_, + model2.coef_, + err_msg="Coefficients should be the same with the same random_state", + ) + model3 = Ridge(solver="sag", random_state=24) + model3.fit(X, y) + + +@pytest.mark.parametrize("copy_X", [True, False]) +def test_ridge_copy_X(regression_data, copy_X): + X, y = regression_data + X_original = X.copy() + model = Ridge(copy_X=copy_X, random_state=42) + model.fit(X, y) + if copy_X: + # X should remain unchanged + assert np.allclose( + X, X_original + ), "X has been modified when copy_X=True" + else: + # X might be modified when copy_X=False + pass # We cannot guarantee X remains unchanged + + +@pytest.mark.xfail(reason="cuML does not emit ConvergenceWarning yet.") +def test_ridge_convergence_warning(regression_data): + X, y = regression_data + from sklearn.exceptions import ConvergenceWarning + + with pytest.warns(ConvergenceWarning): + model = Ridge(max_iter=1, solver="sag", random_state=42) + model.fit(X, y) + + +def test_ridge_coefficients(regression_data): + X, y = regression_data + model = Ridge(alpha=1.0, random_state=42) + model.fit(X, y) + coef_nonzero = np.sum(model.coef_ != 0) + assert coef_nonzero > 0, "There should be non-zero coefficients" + + +def test_ridge_positive(regression_data): + X, y = regression_data + model = Ridge(positive=True, solver="lbfgs", random_state=42) + model.fit(X, y) + # All coefficients should be non-negative + assert np.all( + model.coef_ >= 0 
+ ), "All coefficients should be non-negative when positive=True" diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_tsne.py b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_tsne.py new file mode 100644 index 0000000000..21a8d50df2 --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_tsne.py @@ -0,0 +1,193 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest +import numpy as np +from sklearn.datasets import make_classification +from sklearn.manifold import TSNE +from sklearn.metrics import pairwise_distances +from sklearn.preprocessing import StandardScaler + + +@pytest.fixture(scope="module") +def synthetic_data(): + X, y = make_classification( + n_samples=100, + n_features=20, + n_informative=10, + n_redundant=10, + n_clusters_per_class=1, + n_classes=5, + random_state=42, + ) + # Standardize features + X = StandardScaler().fit_transform(X) + return X, y + + +@pytest.mark.parametrize("n_components", [2, 3]) +def test_tsne_n_components(synthetic_data, n_components): + X, _ = synthetic_data + model = TSNE(n_components=n_components, random_state=42) + X_embedded = model.fit_transform(X) + assert ( + X_embedded.shape[1] == n_components + ), f"Output dimensions should be {n_components}" + + +@pytest.mark.parametrize("perplexity", [50]) +def test_tsne_perplexity(synthetic_data, perplexity): + X, _ = synthetic_data + model = TSNE(perplexity=perplexity, random_state=42) + X_embedded = model.fit_transform(X) + # Check that the embedding has the correct shape + assert ( + X_embedded.shape[0] == X.shape[0] + ), "Number of samples should remain the same" + + +@pytest.mark.parametrize("early_exaggeration", [12.0]) +def test_tsne_early_exaggeration(synthetic_data, early_exaggeration): + X, _ = synthetic_data + model = TSNE(early_exaggeration=early_exaggeration, random_state=42) + X_embedded = model.fit_transform(X) + # Check that the embedding has the correct shape + assert ( + X_embedded.shape[0] == X.shape[0] + ), "Number of samples should remain the same" + + +@pytest.mark.parametrize("learning_rate", [200]) +def test_tsne_learning_rate(synthetic_data, learning_rate): + X, _ = synthetic_data + model = TSNE(learning_rate=learning_rate, random_state=42) + X_embedded = model.fit_transform(X) + # Check that the embedding has the correct shape + assert ( + X_embedded.shape[0] == X.shape[0] + ), "Number of samples 
should remain the same" + + +@pytest.mark.parametrize("n_iter", [250]) +def test_tsne_n_iter(synthetic_data, n_iter): + X, _ = synthetic_data + model = TSNE(n_iter=n_iter, random_state=42) + model.fit_transform(X) + # Since TSNE may perform additional iterations, check if n_iter_ is at least n_iter + assert ( + model.n_iter_ >= n_iter + ), f"Number of iterations should be at least {n_iter}" + + +@pytest.mark.parametrize("metric", ["euclidean", "manhattan", "cosine"]) +def test_tsne_metric(synthetic_data, metric): + X, _ = synthetic_data + model = TSNE(metric=metric, random_state=42) + X_embedded = model.fit_transform(X) + # Check that the embedding has the correct shape + assert ( + X_embedded.shape[0] == X.shape[0] + ), f"Embedding should have same number of samples with metric={metric}" + + +@pytest.mark.parametrize("init", ["random", "pca"]) +def test_tsne_init(synthetic_data, init): + X, _ = synthetic_data + model = TSNE(init=init, random_state=42) + X_embedded = model.fit_transform(X) + # Check that the embedding has the correct shape + assert ( + X_embedded.shape[0] == X.shape[0] + ), f"Embedding should have same number of samples with init={init}" + + +@pytest.mark.parametrize("method", ["barnes_hut", "exact"]) +def test_tsne_method(synthetic_data, method): + X, _ = synthetic_data + model = TSNE(method=method, random_state=42) + X_embedded = model.fit_transform(X) + # Check that the embedding has the correct shape + assert ( + X_embedded.shape[0] == X.shape[0] + ), f"Embedding should have same number of samples with method={method}" + + +@pytest.mark.parametrize("angle", [0.2]) +def test_tsne_angle(synthetic_data, angle): + X, _ = synthetic_data + model = TSNE(method="barnes_hut", angle=angle, random_state=42) + model.fit_transform(X) + # Check that the angle parameter is set correctly + assert model.angle == angle, f"Angle should be {angle}" + + +def test_tsne_random_state(synthetic_data): + X, _ = synthetic_data + model1 = TSNE(random_state=42) + 
X_embedded1 = model1.fit_transform(X) + model2 = TSNE(random_state=42) + X_embedded2 = model2.fit_transform(X) + # The embeddings should be the same when random_state is fixed + np.testing.assert_allclose( + X_embedded1, + X_embedded2, + atol=1e-5, + err_msg="Embeddings should be the same with the same random_state", + ) + + +def test_tsne_verbose(synthetic_data, capsys): + X, _ = synthetic_data + model = TSNE(verbose=1, random_state=42) + model.fit_transform(X) + captured = capsys.readouterr() + # Check that there is output when verbose=1 + assert len(captured.out) > 0, "There should be output when verbose=1" + + +def test_tsne_structure_preservation(synthetic_data): + X, y = synthetic_data + model = TSNE(random_state=42) + X_embedded = model.fit_transform(X) + # Compute pairwise distances in original and embedded spaces + dist_original = pairwise_distances(X) + dist_embedded = pairwise_distances(X_embedded) + # Compute correlation between the distances + np.corrcoef(dist_original.ravel(), dist_embedded.ravel())[0, 1] + + +@pytest.mark.parametrize("min_grad_norm", [1e-5]) +def test_tsne_min_grad_norm(synthetic_data, min_grad_norm): + X, _ = synthetic_data + model = TSNE(min_grad_norm=min_grad_norm, random_state=42) + model.fit_transform(X) + # Check that the min_grad_norm parameter is set correctly + assert ( + model.min_grad_norm == min_grad_norm + ), f"min_grad_norm should be {min_grad_norm}" + + +def test_tsne_metric_params(synthetic_data): + X, _ = synthetic_data + metric_params = {"p": 2} + model = TSNE( + metric="minkowski", metric_params=metric_params, random_state=42 + ) + X_embedded = model.fit_transform(X) + # Check that the embedding has the correct shape + assert ( + X_embedded.shape[0] == X.shape[0] + ), "Embedding should have same number of samples with custom metric_params" diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_tsvd.py 
b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_tsvd.py new file mode 100644 index 0000000000..f6e4b9534a --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_tsvd.py @@ -0,0 +1,177 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +import numpy as np +from sklearn.datasets import make_classification +from sklearn.decomposition import TruncatedSVD +from sklearn.preprocessing import StandardScaler +from scipy.sparse import csr_matrix + + +@pytest.fixture(scope="module") +def svd_data(): + X, y = make_classification( + n_samples=300, + n_features=50, + n_informative=10, + n_redundant=10, + random_state=42, + ) + # Data is returned dense; no sparse (CSR) conversion is performed here + return X, y + + +@pytest.mark.parametrize("n_components", [5, 10, 20, 30]) +def test_truncated_svd_n_components(svd_data, n_components): + X, _ = svd_data + svd = TruncatedSVD(n_components=n_components, random_state=42) + X_transformed = svd.fit_transform(X) + # Check the shape of the transformed data + assert ( + X_transformed.shape[1] == n_components + ), f"Expected {n_components} components, got {X_transformed.shape[1]}" + # Check that explained variance ratios sum up appropriately + total_variance = np.sum(svd.explained_variance_ratio_) + assert ( + total_variance <= 1.0 + ), "Total explained variance ratio cannot exceed 1" + assert ( + total_variance > 0.0 + ), "Total explained 
variance ratio should be positive" + + +@pytest.mark.parametrize("algorithm", ["randomized", "arpack"]) +def test_truncated_svd_algorithm(svd_data, algorithm): + X, _ = svd_data + svd = TruncatedSVD(n_components=10, algorithm=algorithm, random_state=42) + X_transformed = svd.fit_transform(X) + # Reconstruct the data + svd.inverse_transform(X_transformed) + + +@pytest.mark.parametrize("n_iter", [5, 7, 10]) +def test_truncated_svd_n_iter(svd_data, n_iter): + X, _ = svd_data + svd = TruncatedSVD(n_components=10, n_iter=n_iter, random_state=42) + svd.fit_transform(X) + # Check that the explained variance ratio is reasonable + total_variance = np.sum(svd.explained_variance_ratio_) + assert ( + total_variance > 0.5 + ), f"Total explained variance should be significant with n_iter={n_iter}" + + +def test_truncated_svd_random_state(svd_data): + X, _ = svd_data + svd1 = TruncatedSVD( + n_components=10, algorithm="randomized", random_state=42 + ) + svd2 = TruncatedSVD( + n_components=10, algorithm="randomized", random_state=42 + ) + svd1.fit_transform(X) + svd2.fit_transform(X) + # With the same random_state, components should be the same + np.testing.assert_allclose( + svd1.components_, + svd2.components_, + err_msg="Components should be the same with the same random_state", + ) + svd3 = TruncatedSVD( + n_components=10, algorithm="randomized", random_state=24 + ) + svd3.fit(X) + + +@pytest.mark.parametrize("tol", [0.0, 1e-4, 1e-2]) +def test_truncated_svd_tol(svd_data, tol): + X, _ = svd_data + svd = TruncatedSVD( + n_components=10, algorithm="arpack", tol=tol, random_state=42 + ) + svd.fit_transform(X) + # Check that the explained variance ratio is reasonable + total_variance = np.sum(svd.explained_variance_ratio_) + assert ( + total_variance > 0.5 + ), f"Total explained variance should be significant with tol={tol}" + + +@pytest.mark.parametrize( + "power_iteration_normalizer", ["auto", "QR", "LU", "none"] +) +def test_truncated_svd_power_iteration_normalizer( + svd_data, 
power_iteration_normalizer +): + X, _ = svd_data + svd = TruncatedSVD( + n_components=10, + power_iteration_normalizer=power_iteration_normalizer, + random_state=42, + ) + svd.fit_transform(X) + # Check that the explained variance ratio is reasonable + total_variance = np.sum(svd.explained_variance_ratio_) + assert ( + total_variance > 0.5 + ), f"Total explained variance should be significant with power_iteration_normalizer={power_iteration_normalizer}" + + +def test_truncated_svd_inverse_transform(svd_data): + X, _ = svd_data + svd = TruncatedSVD(n_components=10, random_state=42) + X_transformed = svd.fit_transform(X) + X_reconstructed = svd.inverse_transform(X_transformed) + # Check reconstruction error + np.mean((X - X_reconstructed) ** 2) + + +def test_truncated_svd_sparse_input_dense_output(svd_data): + X, _ = svd_data + svd = TruncatedSVD(n_components=10, random_state=42) + X_transformed = svd.fit_transform(X) + # The output should be dense even if input is sparse + assert not isinstance( + X_transformed, csr_matrix + ), "Transformed data should be dense" + + +def test_truncated_svd_components_norm(svd_data): + X, _ = svd_data + svd = TruncatedSVD(n_components=10, random_state=42) + svd.fit(X) + components_norm = np.linalg.norm(svd.components_, axis=1) + np.testing.assert_allclose( + components_norm, + 1.0, + atol=1e-5, + err_msg="Each component should have unit length", + ) + + +@pytest.mark.parametrize("n_oversamples", [5]) +def test_truncated_svd_n_oversamples(svd_data, n_oversamples): + X, _ = svd_data + svd = TruncatedSVD( + n_components=10, n_oversamples=n_oversamples, random_state=42 + ) + svd.fit_transform(X) + # Check that the explained variance ratio is reasonable + total_variance = np.sum(svd.explained_variance_ratio_) + assert ( + total_variance > 0.5 + ), f"Total explained variance should be significant with n_oversamples={n_oversamples}" diff --git a/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_umap.py 
b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_umap.py new file mode 100644 index 0000000000..543d545caf --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/estimators_hyperparams/test_accel_umap.py @@ -0,0 +1,172 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +import numpy as np +from sklearn.datasets import make_swiss_roll +from umap import UMAP +from sklearn.manifold import trustworthiness + + +@pytest.fixture(scope="module") +def manifold_data(): + X, _ = make_swiss_roll(n_samples=100, noise=0.05, random_state=42) + return X + + +@pytest.mark.parametrize("n_neighbors", [5]) +def test_umap_n_neighbors(manifold_data, n_neighbors): + X = manifold_data + umap = UMAP(n_neighbors=n_neighbors, random_state=42) + X_embedded = umap.fit_transform(X) + trust = trustworthiness(X, X_embedded, n_neighbors=5) + print(f"Trustworthiness with n_neighbors={n_neighbors}: {trust}") + + +@pytest.mark.parametrize("min_dist", [0.0, 0.5]) +def test_umap_min_dist(manifold_data, min_dist): + X = manifold_data + umap = UMAP(min_dist=min_dist, random_state=42) + X_embedded = umap.fit_transform(X) + trust = trustworthiness(X, X_embedded, n_neighbors=5) + print(f"Trustworthiness with min_dist={min_dist}: {trust}") + + +@pytest.mark.parametrize( + "metric", ["euclidean", "manhattan", "chebyshev", "cosine"] +) +def test_umap_metric(manifold_data, metric): + X = manifold_data + umap = 
UMAP(metric=metric, random_state=42) + X_embedded = umap.fit_transform(X) + trust = trustworthiness(X, X_embedded, n_neighbors=5) + print(f"Trustworthiness with metric={metric}: {trust}") + + +@pytest.mark.parametrize("n_components", [2, 3]) +def test_umap_n_components(manifold_data, n_components): + X = manifold_data + umap = UMAP(n_components=n_components, random_state=42) + X_embedded = umap.fit_transform(X) + trust = trustworthiness(X, X_embedded, n_neighbors=5) + print(f"Trustworthiness with n_components={n_components}: {trust}") + + +@pytest.mark.parametrize("spread", [0.5, 1.5]) +def test_umap_spread(manifold_data, spread): + X = manifold_data + umap = UMAP(spread=spread, random_state=42) + X_embedded = umap.fit_transform(X) + trust = trustworthiness(X, X_embedded, n_neighbors=5) + print(f"Trustworthiness with spread={spread}: {trust}") + + +@pytest.mark.parametrize("negative_sample_rate", [5]) +def test_umap_negative_sample_rate(manifold_data, negative_sample_rate): + X = manifold_data + umap = UMAP(negative_sample_rate=negative_sample_rate, random_state=42) + X_embedded = umap.fit_transform(X) + trust = trustworthiness(X, X_embedded, n_neighbors=5) + print( + f"Trustworthiness with negative_sample_rate={negative_sample_rate}: {trust}" + ) + + +@pytest.mark.parametrize("learning_rate", [0.1, 10.0]) +def test_umap_learning_rate(manifold_data, learning_rate): + X = manifold_data + umap = UMAP(learning_rate=learning_rate, random_state=42) + X_embedded = umap.fit_transform(X) + trust = trustworthiness(X, X_embedded, n_neighbors=5) + print(f"Trustworthiness with learning_rate={learning_rate}: {trust}") + + +@pytest.mark.parametrize("init", ["spectral", "random"]) +def test_umap_init(manifold_data, init): + X = manifold_data + umap = UMAP(init=init, random_state=42) + X_embedded = umap.fit_transform(X) + trust = trustworthiness(X, X_embedded, n_neighbors=5) + print(f"Trustworthiness with init={init}: {trust}") + + +@pytest.mark.parametrize("n_epochs", [100, 200, 
500]) +def test_umap_n_epochs(manifold_data, n_epochs): + X = manifold_data + umap = UMAP(n_epochs=n_epochs, random_state=42) + X_embedded = umap.fit_transform(X) + trust = trustworthiness(X, X_embedded, n_neighbors=5) + print(f"Trustworthiness with n_epochs={n_epochs}: {trust}") + + +@pytest.mark.parametrize("local_connectivity", [1, 2, 5]) +def test_umap_local_connectivity(manifold_data, local_connectivity): + X = manifold_data + umap = UMAP(local_connectivity=local_connectivity, random_state=42) + X_embedded = umap.fit_transform(X) + trust = trustworthiness(X, X_embedded, n_neighbors=5) + print( + f"Trustworthiness with local_connectivity={local_connectivity}: {trust}" + ) + + +@pytest.mark.parametrize("repulsion_strength", [0.5, 1.0, 2.0]) +def test_umap_repulsion_strength(manifold_data, repulsion_strength): + X = manifold_data + umap = UMAP(repulsion_strength=repulsion_strength, random_state=42) + X_embedded = umap.fit_transform(X) + trust = trustworthiness(X, X_embedded, n_neighbors=5) + print( + f"Trustworthiness with repulsion_strength={repulsion_strength}: {trust}" + ) + + +@pytest.mark.parametrize("metric_kwds", [{"p": 1}, {"p": 2}, {"p": 3}]) +def test_umap_metric_kwds(manifold_data, metric_kwds): + X = manifold_data + umap = UMAP(metric="minkowski", metric_kwds=metric_kwds, random_state=42) + X_embedded = umap.fit_transform(X) + trust = trustworthiness(X, X_embedded, n_neighbors=5) + print(f"Trustworthiness with metric_kwds={metric_kwds}: {trust}") + + +@pytest.mark.parametrize("angular_rp_forest", [True, False]) +def test_umap_angular_rp_forest(manifold_data, angular_rp_forest): + X = manifold_data + umap = UMAP(angular_rp_forest=angular_rp_forest, random_state=42) + X_embedded = umap.fit_transform(X) + trust = trustworthiness(X, X_embedded, n_neighbors=5) + print( + f"Trustworthiness with angular_rp_forest={angular_rp_forest}: {trust}" + ) + + +@pytest.mark.parametrize("densmap", [True, False]) +def test_umap_densmap(manifold_data, densmap): + X = 
manifold_data + umap = UMAP(densmap=densmap, random_state=42) + X_embedded = umap.fit_transform(X) + trust = trustworthiness(X, X_embedded, n_neighbors=5) + print(f"Trustworthiness with densmap={densmap}: {trust}") + + +@pytest.mark.parametrize("output_metric", ["euclidean", "manhattan"]) +def test_umap_output_metric(manifold_data, output_metric): + X = manifold_data + umap = UMAP(output_metric=output_metric, random_state=42) + X_embedded = umap.fit_transform(X) + trust = trustworthiness(X, X_embedded, n_neighbors=5) + print(f"Trustworthiness with output_metric={output_metric}: {trust}") diff --git a/python/cuml/cuml/tests/experimental/accel/test_basic_estimators.py b/python/cuml/cuml/tests/experimental/accel/test_basic_estimators.py new file mode 100644 index 0000000000..cacefa9241 --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/test_basic_estimators.py @@ -0,0 +1,142 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import numpy as np +from sklearn.datasets import make_classification, make_regression, make_blobs +from sklearn.linear_model import ( + LinearRegression, + LogisticRegression, + ElasticNet, + Ridge, + Lasso, +) +from sklearn.cluster import KMeans, DBSCAN +from sklearn.decomposition import PCA, TruncatedSVD +from sklearn.kernel_ridge import KernelRidge +from sklearn.manifold import TSNE +from sklearn.neighbors import ( + NearestNeighbors, + KNeighborsClassifier, + KNeighborsRegressor, +) +from sklearn.metrics import ( + mean_squared_error, + r2_score, + adjusted_rand_score, + accuracy_score, +) +from scipy.sparse import random as sparse_random + + +def test_kmeans(): + X, y_true = make_blobs(n_samples=100, centers=3, random_state=42) + clf = KMeans().fit(X) + clf.predict(X) + + +def test_dbscan(): + X, y_true = make_blobs(n_samples=100, centers=3, random_state=42) + clf = DBSCAN().fit(X) + clf.labels_ + + +def test_pca(): + X, _ = make_blobs(n_samples=100, centers=3, random_state=42) + pca = PCA().fit(X) + pca.transform(X) + + +def test_truncated_svd(): + X, _ = make_blobs(n_samples=100, centers=3, random_state=42) + svd = TruncatedSVD().fit(X) + svd.transform(X) + + +def test_linear_regression(): + X, y = make_regression( + n_samples=100, n_features=20, noise=0.1, random_state=42 + ) + lr = LinearRegression().fit(X, y) + lr.predict(X) + + +def test_logistic_regression(): + X, y = make_classification( + n_samples=100, n_features=20, n_classes=2, random_state=42 + ) + clf = LogisticRegression().fit(X, y) + clf.predict(X) + + +def test_elastic_net(): + X, y = make_regression( + n_samples=100, n_features=20, noise=0.1, random_state=42 + ) + enet = ElasticNet().fit(X, y) + enet.predict(X) + + +def test_ridge(): + X, y = make_regression( + n_samples=100, n_features=20, noise=0.1, random_state=42 + ) + ridge = Ridge().fit(X, y) + ridge.predict(X) + + +def test_lasso(): + X, y = make_regression( + n_samples=100, n_features=20, noise=0.1, random_state=42 + 
) + lasso = Lasso().fit(X, y) + lasso.predict(X) + + +def test_tsne(): + X, _ = make_blobs(n_samples=100, centers=3, n_features=20, random_state=42) + tsne = TSNE() + tsne.fit_transform(X) + + +def test_nearest_neighbors(): + X, _ = make_blobs(n_samples=100, centers=3, n_features=20, random_state=42) + nn = NearestNeighbors().fit(X) + distances, indices = nn.kneighbors(X) + assert distances.shape == (100, 5) + assert indices.shape == (100, 5) + + +def test_k_neighbors_classifier(): + X, y = make_classification( + n_samples=100, + n_features=20, + n_classes=3, + random_state=42, + n_informative=6, + ) + for weights in ["uniform", "distance"]: + for metric in ["euclidean", "manhattan"]: + knn = KNeighborsClassifier().fit(X, y) + knn.predict(X) + + +def test_k_neighbors_regressor(): + X, y = make_regression( + n_samples=100, n_features=20, noise=0.1, random_state=42 + ) + for weights in ["uniform", "distance"]: + for metric in ["euclidean", "manhattan"]: + knr = KNeighborsRegressor().fit(X, y) + knr.predict(X) diff --git a/python/cuml/cuml/tests/experimental/accel/test_pipeline.py b/python/cuml/cuml/tests/experimental/accel/test_pipeline.py new file mode 100644 index 0000000000..5327c6957e --- /dev/null +++ b/python/cuml/cuml/tests/experimental/accel/test_pipeline.py @@ -0,0 +1,147 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest +from sklearn.decomposition import PCA, TruncatedSVD +from sklearn.cluster import KMeans, DBSCAN +from sklearn.kernel_ridge import KernelRidge +from sklearn.linear_model import ( + LogisticRegression, + LinearRegression, + ElasticNet, + Ridge, + Lasso, +) +from sklearn.manifold import TSNE +from sklearn.neighbors import ( + NearestNeighbors, + KNeighborsClassifier, + KNeighborsRegressor, +) +from sklearn.pipeline import Pipeline +from sklearn.datasets import make_classification, make_regression +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score, mean_squared_error +from umap import UMAP +import hdbscan +import numpy as np + + +@pytest.fixture +def classification_data(): + # Create a synthetic dataset for binary classification + X, y = make_classification(n_samples=100, n_features=20, random_state=42) + return train_test_split(X, y, test_size=0.2, random_state=42) + + +@pytest.fixture +def regression_data(): + # Create a synthetic dataset for regression + X, y = make_regression( + n_samples=100, n_features=20, noise=0.1, random_state=42 + ) + return train_test_split(X, y, test_size=0.2, random_state=42) + + +classification_estimators = [ + LogisticRegression(), + KNeighborsClassifier(), +] + +regression_estimators = [ + LinearRegression(), + Ridge(), + Lasso(), + ElasticNet(), + KernelRidge(), + KNeighborsRegressor(), +] + + +@pytest.mark.parametrize( + "transformer", + [ + PCA(n_components=5), + TruncatedSVD(n_components=5), + KMeans(n_clusters=5, random_state=42), + ], +) +@pytest.mark.parametrize("estimator", classification_estimators) +def test_classification_transformers( + transformer, estimator, classification_data +): + X_train, X_test, y_train, y_test = classification_data + # Create pipeline with the transformer and estimator + pipeline = Pipeline( + [("transformer", transformer), ("classifier", estimator)] + ) + # Fit and predict + pipeline.fit(X_train, y_train) + pipeline.predict(X_test) 
+ # Ensure that the result is binary or multiclass classification + + +@pytest.mark.parametrize( + "transformer", + [ + PCA(n_components=5), + TruncatedSVD(n_components=5), + KMeans(n_clusters=5, random_state=42), + ], +) +@pytest.mark.parametrize("estimator", regression_estimators) +def test_regression_transformers(transformer, estimator, regression_data): + X_train, X_test, y_train, y_test = regression_data + # Create pipeline with the transformer and estimator + pipeline = Pipeline( + [("transformer", transformer), ("regressor", estimator)] + ) + # Fit and predict + pipeline.fit(X_train, y_train) + pipeline.predict(X_test) + + +@pytest.mark.parametrize( + "transformer", + [ + PCA(n_components=5), + TruncatedSVD(n_components=5), + KMeans(n_clusters=5, random_state=42), + ], +) +@pytest.mark.parametrize("estimator", [NearestNeighbors(), DBSCAN()]) +def test_unsupervised_neighbors(transformer, estimator, classification_data): + X_train, X_test, _, _ = classification_data + # Create pipeline with the transformer and unsupervised model + pipeline = Pipeline( + [("transformer", transformer), ("unsupervised", estimator)] + ) + # Fit the model (no predict needed for unsupervised learning) + pipeline.fit(X_train) + + +def test_umap_with_logistic_regression(classification_data): + X_train, X_test, y_train, y_test = classification_data + # Create pipeline with UMAP for dimensionality reduction and logistic regression + pipeline = Pipeline( + [ + ("umap", UMAP(n_components=5, random_state=42)), + ("classifier", LogisticRegression()), + ] + ) + # Fit and predict + pipeline.fit(X_train, y_train) + pipeline.predict(X_test) diff --git a/python/cuml/cuml/tests/test_tsne.py b/python/cuml/cuml/tests/test_tsne.py index 0abca2a926..fe119eb999 100644 --- a/python/cuml/cuml/tests/test_tsne.py +++ b/python/cuml/cuml/tests/test_tsne.py @@ -238,7 +238,7 @@ def test_tsne_large(nrows, ncols, method): def test_components_exception(): with pytest.raises(ValueError): - TSNE(n_components=3) + 
TSNE(n_components=3).fit(np.array([])) @pytest.mark.parametrize("input_type", ["cupy", "scipy"]) From 4fcdeb12bd25e4ec1880d263c2a56195c01aed31 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Fri, 22 Nov 2024 11:46:59 -0800 Subject: [PATCH 3/3] Stop excluding cutlass from symbol exclusion check (#6140) Depends on https://github.com/rapidsai/raft/pull/2503, which includes the kernel visibility fixes needed from cutlass. Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Kyle Edwards (https://github.com/KyleFromNVIDIA) URL: https://github.com/rapidsai/cuml/pull/6140 --- .github/workflows/pr.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index f9e1f066b5..f256966583 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -113,7 +113,6 @@ jobs: with: build_type: pull-request enable_check_symbols: true - symbol_exclusions: raft_cutlass conda-python-build: needs: conda-cpp-build secrets: inherit