From 88b81b084b66571ff9a77ef00f2edb10c0951b81 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 2 Dec 2024 14:59:05 -0800 Subject: [PATCH] Use sparse knn / distances from cuvs (#6143) Use the sparse knn and distances from cuvs instead of from raft. This also allows us to switch over to the cuvs DistanceType, instead of the raft DistanceType (which is now only needed for the RBC code) Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuml/pull/6143 --- cpp/bench/sg/dbscan.cu | 2 +- cpp/bench/sg/kmeans.cu | 3 +- cpp/bench/sg/linkage.cu | 5 +- cpp/examples/dbscan/dbscan_example.cpp | 2 +- cpp/include/cuml/cluster/dbscan.hpp | 10 +- cpp/include/cuml/cluster/hdbscan.hpp | 13 +- cpp/include/cuml/cluster/linkage.hpp | 9 +- cpp/include/cuml/manifold/tsne.h | 4 +- cpp/include/cuml/manifold/umapparams.h | 5 +- cpp/include/cuml/metrics/metrics.hpp | 20 +- cpp/include/cuml/neighbors/knn.hpp | 8 +- cpp/include/cuml/neighbors/knn_api.h | 2 +- cpp/include/cuml/neighbors/knn_sparse.hpp | 5 +- cpp/src/dbscan/dbscan.cu | 8 +- cpp/src/dbscan/dbscan.cuh | 6 +- cpp/src/dbscan/dbscan_api.cpp | 4 +- cpp/src/dbscan/runner.cuh | 12 +- cpp/src/dbscan/vertexdeg/algo.cuh | 6 +- cpp/src/dbscan/vertexdeg/runner.cuh | 2 +- cpp/src/hdbscan/detail/predict.cuh | 8 +- cpp/src/hdbscan/detail/reachability.cuh | 174 +-------------- cpp/src/hdbscan/detail/soft_clustering.cuh | 13 +- cpp/src/hdbscan/hdbscan.cu | 10 +- cpp/src/hdbscan/runner.h | 8 +- cpp/src/hierarchy/linkage.cu | 10 +- cpp/src/knn/knn.cu | 35 ++-- cpp/src/knn/knn_api.cpp | 4 +- cpp/src/knn/knn_opg_common.cuh | 2 +- cpp/src/knn/knn_sparse.cu | 50 +++-- cpp/src/metrics/pairwise_distance.cu | 22 +- cpp/src/metrics/silhouette_score.cu | 4 +- .../silhouette_score_batched_double.cu | 4 +- .../metrics/silhouette_score_batched_float.cu | 6 +- cpp/src/metrics/trustworthiness.cu | 6 +- cpp/src/tsne/distances.cuh | 53 +++-- cpp/src/tsne/tsne.cu 
| 3 +- cpp/src/tsne/tsne_runner.cuh | 2 +- cpp/src/umap/knn_graph/algo.cuh | 48 ++--- cpp/src/umap/supervised.cuh | 1 + cpp/src_prims/selection/knn.cuh | 3 +- cpp/test/prims/dist_adj.cu | 4 +- cpp/test/prims/distance_base.cuh | 24 +-- cpp/test/sg/dbscan_test.cu | 198 ++++++++++++++---- cpp/test/sg/hdbscan_test.cu | 45 ++-- cpp/test/sg/linkage_test.cu | 6 +- cpp/test/sg/rproj_test.cu | 2 +- cpp/test/sg/trustworthiness_test.cu | 2 +- cpp/test/sg/tsne_test.cu | 8 +- cpp/test/sg/umap_parametrizable_test.cu | 2 +- python/cuml/cuml/cluster/kmeans.pyx | 5 +- python/cuml/cuml/cluster/kmeans_utils.pxd | 25 +-- python/cuml/cuml/metrics/distance_type.pxd | 46 ++-- .../cuml/cuml/metrics/raft_distance_type.pxd | 40 ++++ python/cuml/cuml/metrics/trustworthiness.pyx | 6 +- .../cuml/cuml/neighbors/nearest_neighbors.pyx | 5 +- 55 files changed, 495 insertions(+), 515 deletions(-) create mode 100644 python/cuml/cuml/metrics/raft_distance_type.pxd diff --git a/cpp/bench/sg/dbscan.cu b/cpp/bench/sg/dbscan.cu index 34fa8631f2..1290f84200 100644 --- a/cpp/bench/sg/dbscan.cu +++ b/cpp/bench/sg/dbscan.cu @@ -57,7 +57,7 @@ class Dbscan : public BlobsFixture { this->params.ncols, D(dParams.eps), dParams.min_pts, - raft::distance::L2SqrtUnexpanded, + cuvs::distance::DistanceType::L2SqrtUnexpanded, this->data.y.data(), this->core_sample_indices, nullptr, diff --git a/cpp/bench/sg/kmeans.cu b/cpp/bench/sg/kmeans.cu index 7fb44a20fb..e974163599 100644 --- a/cpp/bench/sg/kmeans.cu +++ b/cpp/bench/sg/kmeans.cu @@ -19,9 +19,10 @@ #include #include -#include #include +#include + #include namespace ML { diff --git a/cpp/bench/sg/linkage.cu b/cpp/bench/sg/linkage.cu index db27b32e65..4ee88505bb 100644 --- a/cpp/bench/sg/linkage.cu +++ b/cpp/bench/sg/linkage.cu @@ -19,9 +19,10 @@ #include #include -#include #include +#include + #include namespace ML { @@ -55,7 +56,7 @@ class Linkage : public BlobsFixture { this->params.nrows, this->params.ncols, &out_arrs, - raft::distance::DistanceType::L2Unexpanded, 
+ cuvs::distance::DistanceType::L2Unexpanded, 15, 50); }); diff --git a/cpp/examples/dbscan/dbscan_example.cpp b/cpp/examples/dbscan/dbscan_example.cpp index 9c16668a12..3ba367cbdc 100644 --- a/cpp/examples/dbscan/dbscan_example.cpp +++ b/cpp/examples/dbscan/dbscan_example.cpp @@ -203,7 +203,7 @@ int main(int argc, char* argv[]) nCols, eps, minPts, - raft::distance::L2SqrtUnexpanded, + cuvs::distance::DistanceType::L2SqrtUnexpanded, d_labels, nullptr, nullptr, diff --git a/cpp/include/cuml/cluster/dbscan.hpp b/cpp/include/cuml/cluster/dbscan.hpp index 71fe7292c5..d691452db2 100644 --- a/cpp/include/cuml/cluster/dbscan.hpp +++ b/cpp/include/cuml/cluster/dbscan.hpp @@ -18,7 +18,7 @@ #include -#include +#include #include #include @@ -67,7 +67,7 @@ void fit(const raft::handle_t& handle, int n_cols, float eps, int min_pts, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int* labels, int* core_sample_indices = nullptr, float* sample_weight = nullptr, @@ -81,7 +81,7 @@ void fit(const raft::handle_t& handle, int n_cols, double eps, int min_pts, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int* labels, int* core_sample_indices = nullptr, double* sample_weight = nullptr, @@ -96,7 +96,7 @@ void fit(const raft::handle_t& handle, int64_t n_cols, float eps, int min_pts, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int64_t* labels, int64_t* core_sample_indices = nullptr, float* sample_weight = nullptr, @@ -110,7 +110,7 @@ void fit(const raft::handle_t& handle, int64_t n_cols, double eps, int min_pts, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int64_t* labels, int64_t* core_sample_indices = nullptr, double* sample_weight = nullptr, diff --git a/cpp/include/cuml/cluster/hdbscan.hpp b/cpp/include/cuml/cluster/hdbscan.hpp index eb1223fd88..9a469bc659 100644 --- a/cpp/include/cuml/cluster/hdbscan.hpp +++ b/cpp/include/cuml/cluster/hdbscan.hpp @@ -17,10 
+17,11 @@ #pragma once #include -#include #include +#include + #include namespace ML { @@ -424,7 +425,7 @@ void hdbscan(const raft::handle_t& handle, const float* X, size_t m, size_t n, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, HDBSCAN::Common::HDBSCANParams& params, HDBSCAN::Common::hdbscan_output& out, float* core_dists); @@ -456,7 +457,7 @@ void compute_all_points_membership_vectors( HDBSCAN::Common::CondensedHierarchy& condensed_tree, HDBSCAN::Common::PredictionData& prediction_data, const float* X, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float* membership_vec, size_t batch_size = 4096); @@ -467,7 +468,7 @@ void compute_membership_vector(const raft::handle_t& handle, const float* points_to_predict, size_t n_prediction_points, int min_samples, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float* membership_vec, size_t batch_size = 4096); @@ -478,7 +479,7 @@ void out_of_sample_predict(const raft::handle_t& handle, int* labels, const float* points_to_predict, size_t n_prediction_points, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int min_samples, int* out_labels, float* out_probabilities); @@ -501,7 +502,7 @@ void compute_core_dists(const raft::handle_t& handle, float* core_dists, size_t m, size_t n, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int min_samples); /** diff --git a/cpp/include/cuml/cluster/linkage.hpp b/cpp/include/cuml/cluster/linkage.hpp index f17fa11e21..d327e3dec6 100644 --- a/cpp/include/cuml/cluster/linkage.hpp +++ b/cpp/include/cuml/cluster/linkage.hpp @@ -17,9 +17,10 @@ #pragma once #include -#include #include +#include + namespace raft { class handle_t; } @@ -46,7 +47,7 @@ void single_linkage_pairwise(const raft::handle_t& handle, size_t m, size_t n, raft::hierarchy::linkage_output* out, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int 
n_clusters = 5); /** @@ -74,7 +75,7 @@ void single_linkage_neighbors( size_t m, size_t n, raft::hierarchy::linkage_output* out, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Unexpanded, + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded, int c = 15, int n_clusters = 5); @@ -83,7 +84,7 @@ void single_linkage_pairwise(const raft::handle_t& handle, size_t m, size_t n, raft::hierarchy::linkage_output* out, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int n_clusters = 5); }; // namespace ML diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h index b9330e5215..8c658b3c69 100644 --- a/cpp/include/cuml/manifold/tsne.h +++ b/cpp/include/cuml/manifold/tsne.h @@ -18,7 +18,7 @@ #include -#include +#include namespace raft { class handle_t; @@ -106,7 +106,7 @@ struct TSNEParams { bool square_distances = true; // Distance metric to use. - raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded; + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtExpanded; // Value of p for Minkowski distance float p = 2.0; diff --git a/cpp/include/cuml/manifold/umapparams.h b/cpp/include/cuml/manifold/umapparams.h index 71418198cf..a337c6cf64 100644 --- a/cpp/include/cuml/manifold/umapparams.h +++ b/cpp/include/cuml/manifold/umapparams.h @@ -19,9 +19,10 @@ #include #include -#include #include +#include + namespace ML { using nn_index_params = raft::neighbors::experimental::nn_descent::index_params; @@ -170,7 +171,7 @@ class UMAPParams { */ bool deterministic = true; - raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded; + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtExpanded; float p = 2.0; diff --git a/cpp/include/cuml/metrics/metrics.hpp b/cpp/include/cuml/metrics/metrics.hpp index 8d4fceb28a..cd61e61681 100644 --- a/cpp/include/cuml/metrics/metrics.hpp +++ 
b/cpp/include/cuml/metrics/metrics.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ #pragma once -#include +#include #include @@ -105,7 +105,7 @@ double silhouette_score(const raft::handle_t& handle, int* labels, int nLabels, double* silScores, - raft::distance::DistanceType metric); + cuvs::distance::DistanceType metric); namespace Batched { /** @@ -138,7 +138,7 @@ float silhouette_score(const raft::handle_t& handle, int n_labels, float* scores, int chunk, - raft::distance::DistanceType metric); + cuvs::distance::DistanceType metric); double silhouette_score(const raft::handle_t& handle, double* X, int n_rows, @@ -147,7 +147,7 @@ double silhouette_score(const raft::handle_t& handle, int n_labels, double* scores, int chunk, - raft::distance::DistanceType metric); + cuvs::distance::DistanceType metric); } // namespace Batched /** @@ -349,7 +349,7 @@ void pairwise_distance(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, bool isRowMajor = true, double metric_arg = 2.0); @@ -376,7 +376,7 @@ void pairwise_distance(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, bool isRowMajor = true, float metric_arg = 2.0f); @@ -393,7 +393,7 @@ void pairwiseDistance_sparse(const raft::handle_t& handle, int* y_indptr, int* x_indices, int* y_indices, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float metric_arg); void pairwiseDistance_sparse(const raft::handle_t& handle, float* x, @@ -408,7 +408,7 @@ void pairwiseDistance_sparse(const raft::handle_t& handle, int* y_indptr, int* x_indices, int* y_indices, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType 
metric, float metric_arg); /** @@ -425,7 +425,7 @@ void pairwiseDistance_sparse(const raft::handle_t& handle, * @tparam distance_type: Distance type to consider * @return Trustworthiness score */ -template +template double trustworthiness_score(const raft::handle_t& h, const math_t* X, math_t* X_embedded, diff --git a/cpp/include/cuml/neighbors/knn.hpp b/cpp/include/cuml/neighbors/knn.hpp index 43150cf976..8c7a2de263 100644 --- a/cpp/include/cuml/neighbors/knn.hpp +++ b/cpp/include/cuml/neighbors/knn.hpp @@ -16,10 +16,10 @@ #pragma once -#include #include #include // MetricProcessor +#include #include #include @@ -63,7 +63,7 @@ void brute_force_knn(const raft::handle_t& handle, int k, bool rowMajorIndex = false, bool rowMajorQuery = false, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded, float metric_arg = 2.0f, std::vector* translations = nullptr); @@ -79,7 +79,7 @@ void rbc_knn_query(const raft::handle_t& handle, float* out_dists); struct knnIndex { - raft::distance::DistanceType metric; + cuvs::distance::DistanceType metric; float metricArg; int nprobe; std::unique_ptr> metric_processor; @@ -123,7 +123,7 @@ struct IVFPQParam : IVFParam { void approx_knn_build_index(raft::handle_t& handle, knnIndex* index, knnIndexParam* params, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float metricArg, float* index_array, int n, diff --git a/cpp/include/cuml/neighbors/knn_api.h b/cpp/include/cuml/neighbors/knn_api.h index 0ba8ffbc8b..dde96f5ce9 100644 --- a/cpp/include/cuml/neighbors/knn_api.h +++ b/cpp/include/cuml/neighbors/knn_api.h @@ -43,7 +43,7 @@ extern "C" { * @param[in] rowMajorIndex is the index array in row major layout? * @param[in] rowMajorQuery is the query array in row major layout? * @param[in] metric_type the type of distance metric to use. 
This corresponds - * to the value in the raft::distance::DistanceType enum. + * to the value in the cuvs::distance::DistanceType enum. * Default is Euclidean (L2). * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This * is ignored if the metric_type is not Minkowski. diff --git a/cpp/include/cuml/neighbors/knn_sparse.hpp b/cpp/include/cuml/neighbors/knn_sparse.hpp index 3ac0d7969a..8650e4976a 100644 --- a/cpp/include/cuml/neighbors/knn_sparse.hpp +++ b/cpp/include/cuml/neighbors/knn_sparse.hpp @@ -18,9 +18,8 @@ #include -#include - #include +#include namespace raft { class handle_t; @@ -49,7 +48,7 @@ void brute_force_knn(raft::handle_t& handle, int k, size_t batch_size_index = DEFAULT_BATCH_SIZE, size_t batch_size_query = DEFAULT_BATCH_SIZE, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded, float metricArg = 0); }; // end namespace Sparse }; // end namespace ML diff --git a/cpp/src/dbscan/dbscan.cu b/cpp/src/dbscan/dbscan.cu index 9a08e21863..43c130297e 100644 --- a/cpp/src/dbscan/dbscan.cu +++ b/cpp/src/dbscan/dbscan.cu @@ -29,7 +29,7 @@ void fit(const raft::handle_t& handle, int n_cols, float eps, int min_pts, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int* labels, int* core_sample_indices, float* sample_weight, @@ -76,7 +76,7 @@ void fit(const raft::handle_t& handle, int n_cols, double eps, int min_pts, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int* labels, int* core_sample_indices, double* sample_weight, @@ -123,7 +123,7 @@ void fit(const raft::handle_t& handle, int64_t n_cols, float eps, int min_pts, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int64_t* labels, int64_t* core_sample_indices, float* sample_weight, @@ -170,7 +170,7 @@ void fit(const raft::handle_t& handle, int64_t n_cols, double eps, int min_pts, - 
raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int64_t* labels, int64_t* core_sample_indices, double* sample_weight, diff --git a/cpp/src/dbscan/dbscan.cuh b/cpp/src/dbscan/dbscan.cuh index 84c963f719..a8962010a4 100644 --- a/cpp/src/dbscan/dbscan.cuh +++ b/cpp/src/dbscan/dbscan.cuh @@ -101,7 +101,7 @@ void dbscanFitImpl(const raft::handle_t& handle, Index_ n_cols, T eps, Index_ min_pts, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, Index_* labels, Index_* core_sample_indices, T* sample_weight, @@ -114,7 +114,7 @@ void dbscanFitImpl(const raft::handle_t& handle, ML::Logger::get().setLevel(verbosity); // XXX: for algo_vd and algo_adj, 0 (naive) is no longer an option and has // been removed. - int algo_vd = (metric == raft::distance::Precomputed) ? 2 : 1; + int algo_vd = (metric == cuvs::distance::DistanceType::Precomputed) ? 2 : 1; int algo_adj = 1; int algo_ccl = 2; @@ -147,7 +147,7 @@ void dbscanFitImpl(const raft::handle_t& handle, RAFT_CUDA_TRY(cudaMemGetInfo(&free_memory, &total_memory)); // X can either be a feature matrix or distance matrix - size_t dataset_memory = (metric == raft::distance::Precomputed) + size_t dataset_memory = (metric == cuvs::distance::DistanceType::Precomputed) ? 
((size_t)n_rows * (size_t)n_rows * sizeof(T)) : ((size_t)n_rows * (size_t)n_cols * sizeof(T)); diff --git a/cpp/src/dbscan/dbscan_api.cpp b/cpp/src/dbscan/dbscan_api.cpp index 2c821826c3..a052b4e5b2 100644 --- a/cpp/src/dbscan/dbscan_api.cpp +++ b/cpp/src/dbscan/dbscan_api.cpp @@ -44,7 +44,7 @@ cumlError_t cumlSpDbscanFit(cumlHandle_t handle, n_cols, eps, min_pts, - raft::distance::L2SqrtUnexpanded, + cuvs::distance::DistanceType::L2SqrtUnexpanded, labels, core_sample_indices, NULL, @@ -87,7 +87,7 @@ cumlError_t cumlDpDbscanFit(cumlHandle_t handle, n_cols, eps, min_pts, - raft::distance::L2SqrtUnexpanded, + cuvs::distance::DistanceType::L2SqrtUnexpanded, labels, core_sample_indices, NULL, diff --git a/cpp/src/dbscan/runner.cuh b/cpp/src/dbscan/runner.cuh index 804be72fc7..e8a0b5ff51 100644 --- a/cpp/src/dbscan/runner.cuh +++ b/cpp/src/dbscan/runner.cuh @@ -121,7 +121,7 @@ std::size_t run(const raft::handle_t& handle, std::size_t batch_size, EpsNnMethod eps_nn_method, cudaStream_t stream, - raft::distance::DistanceType metric) + cuvs::distance::DistanceType metric) { const std::size_t align = 256; Index_ n_batches = raft::ceildiv((std::size_t)n_owned_rows, batch_size); @@ -136,8 +136,8 @@ std::size_t run(const raft::handle_t& handle, // switch compute mode based on feature dimension bool sparse_rbc_mode = eps_nn_method == EpsNnMethod::RBC; - if (sparse_rbc_mode && metric != raft::distance::DistanceType::L2SqrtExpanded && - metric != raft::distance::DistanceType::L2SqrtUnexpanded) { + if (sparse_rbc_mode && metric != cuvs::distance::DistanceType::L2SqrtExpanded && + metric != cuvs::distance::DistanceType::L2SqrtUnexpanded) { CUML_LOG_WARN("Metric not supported by RBC yet. Falling back to BRUTE_FORCE strategy."); sparse_rbc_mode = false; } @@ -220,7 +220,11 @@ std::size_t run(const raft::handle_t& handle, raft::neighbors::ball_cover::BallCoverIndex* rbc_index_ptr = nullptr; raft::neighbors::ball_cover::BallCoverIndex rbc_index( - handle, x, sparse_rbc_mode ? 
N : 0, sparse_rbc_mode ? D : 0, metric); + handle, + x, + sparse_rbc_mode ? N : 0, + sparse_rbc_mode ? D : 0, + static_cast(metric)); if (sparse_rbc_mode) { raft::neighbors::ball_cover::build_index(handle, rbc_index); diff --git a/cpp/src/dbscan/vertexdeg/algo.cuh b/cpp/src/dbscan/vertexdeg/algo.cuh index b19345b033..efb1299df2 100644 --- a/cpp/src/dbscan/vertexdeg/algo.cuh +++ b/cpp/src/dbscan/vertexdeg/algo.cuh @@ -22,7 +22,6 @@ #include #include -#include #include #include #include @@ -38,6 +37,7 @@ #include #include +#include #include namespace ML { @@ -157,7 +157,7 @@ void launcher(const raft::handle_t& handle, index_t start_vertex_id, index_t batch_size, cudaStream_t stream, - raft::distance::DistanceType metric) + cuvs::distance::DistanceType metric) { ASSERT(sizeof(index_t) == 4 || sizeof(index_t) == 8, "index_t should be 4 or 8 bytes"); @@ -167,7 +167,7 @@ void launcher(const raft::handle_t& handle, value_t eps2; // Compute adjacency matrix `adj` using Cosine or L2 metric. - if (metric == raft::distance::DistanceType::CosineExpanded) { + if (metric == cuvs::distance::DistanceType::CosineExpanded) { rmm::device_uvector rowNorms(m, stream); raft::linalg::rowNorm(rowNorms.data(), diff --git a/cpp/src/dbscan/vertexdeg/runner.cuh b/cpp/src/dbscan/vertexdeg/runner.cuh index 1138949106..1500e0f5a9 100644 --- a/cpp/src/dbscan/vertexdeg/runner.cuh +++ b/cpp/src/dbscan/vertexdeg/runner.cuh @@ -41,7 +41,7 @@ void run(const raft::handle_t& handle, Index_ start_vertex_id, Index_ batch_size, cudaStream_t stream, - raft::distance::DistanceType metric) + cuvs::distance::DistanceType metric) { Pack data = { rbc_index, vd, wght_sum, ia, ja, max_k, adj, x, sample_weight, eps, N, D}; diff --git a/cpp/src/hdbscan/detail/predict.cuh b/cpp/src/hdbscan/detail/predict.cuh index 966bb761a5..9cbe5fea19 100644 --- a/cpp/src/hdbscan/detail/predict.cuh +++ b/cpp/src/hdbscan/detail/predict.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -160,7 +160,7 @@ void _compute_knn_and_nearest_neighbor(const raft::handle_t& handle, size_t n_prediction_points, value_idx* min_mr_inds, value_t* prediction_lambdas, - raft::distance::DistanceType metric) + cuvs::distance::DistanceType metric) { auto stream = handle.get_stream(); size_t m = prediction_data.n_rows; @@ -233,12 +233,12 @@ void approximate_predict(const raft::handle_t& handle, value_idx* labels, const value_t* points_to_predict, size_t n_prediction_points, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int min_samples, value_idx* out_labels, value_t* out_probabilities) { - RAFT_EXPECTS(metric == raft::distance::DistanceType::L2SqrtExpanded, + RAFT_EXPECTS(metric == cuvs::distance::DistanceType::L2SqrtExpanded, "Currently only L2 expanded distance is supported"); auto stream = handle.get_stream(); diff --git a/cpp/src/hdbscan/detail/reachability.cuh b/cpp/src/hdbscan/detail/reachability.cuh index 03a7f7c0ad..e0c4c8a799 100644 --- a/cpp/src/hdbscan/detail/reachability.cuh +++ b/cpp/src/hdbscan/detail/reachability.cuh @@ -18,9 +18,7 @@ #include -#include #include -#include #include #include #include @@ -34,6 +32,8 @@ #include #include +#include + namespace ML { namespace HDBSCAN { namespace detail { @@ -93,7 +93,7 @@ void compute_knn(const raft::handle_t& handle, const value_t* search_items, size_t n_search_items, int k, - raft::distance::DistanceType metric) + cuvs::distance::DistanceType metric) { auto stream = handle.get_stream(); auto exec_policy = handle.get_thrust_policy(); @@ -139,10 +139,10 @@ void _compute_core_dists(const raft::handle_t& handle, value_t* core_dists, size_t m, size_t n, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int min_samples) { - RAFT_EXPECTS(metric == 
raft::distance::DistanceType::L2SqrtExpanded, + RAFT_EXPECTS(metric == cuvs::distance::DistanceType::L2SqrtExpanded, "Currently only L2 expanded distance is supported"); auto stream = handle.get_stream(); @@ -157,170 +157,6 @@ void _compute_core_dists(const raft::handle_t& handle, core_distances(dists.data(), min_samples, min_samples, m, core_dists, stream); } -// Functor to post-process distances into reachability space -template -struct ReachabilityPostProcess { - DI value_t operator()(value_t value, value_idx row, value_idx col) const - { - return max(core_dists[col], max(core_dists[row], alpha * value)); - } - - const value_t* core_dists; - value_t alpha; -}; - -/** - * Given core distances, Fuses computations of L2 distances between all - * points, projection into mutual reachability space, and k-selection. - * @tparam value_idx - * @tparam value_t - * @param[in] handle raft handle for resource reuse - * @param[out] out_inds output indices array (size m * k) - * @param[out] out_dists output distances array (size m * k) - * @param[in] X input data points (size m * n) - * @param[in] m number of rows in X - * @param[in] n number of columns in X - * @param[in] k neighborhood size (includes self-loop) - * @param[in] core_dists array of core distances (size m) - */ -template -void mutual_reachability_knn_l2(const raft::handle_t& handle, - value_idx* out_inds, - value_t* out_dists, - const value_t* X, - size_t m, - size_t n, - int k, - value_t* core_dists, - value_t alpha) -{ - // Create a functor to postprocess distances into mutual reachability space - // Note that we can't use a lambda for this here, since we get errors like: - // `A type local to a function cannot be used in the template argument of the - // enclosing parent function (and any parent classes) of an extended __device__ - // or __host__ __device__ lambda` - auto epilogue = ReachabilityPostProcess{core_dists, alpha}; - - auto X_view = raft::make_device_matrix_view(X, m, n); - std::vector> index = 
{X_view}; - - raft::neighbors::brute_force::knn( - handle, - index, - X_view, - raft::make_device_matrix_view(out_inds, m, static_cast(k)), - raft::make_device_matrix_view(out_dists, m, static_cast(k)), - // TODO: expand distance metrics to support more than just L2 distance - // https://github.com/rapidsai/cuml/issues/5301 - raft::distance::DistanceType::L2SqrtExpanded, - std::make_optional(2.0f), - std::nullopt, - epilogue); -} - -/** - * Constructs a mutual reachability graph, which is a k-nearest neighbors - * graph projected into mutual reachability space using the following - * function for each data point, where core_distance is the distance - * to the kth neighbor: max(core_distance(a), core_distance(b), d(a, b)) - * - * Unfortunately, points in the tails of the pdf (e.g. in sparse regions - * of the space) can have very large neighborhoods, which will impact - * nearby neighborhoods. Because of this, it's possible that the - * radius for points in the main mass, which might have a very small - * radius initially, to expand very large. As a result, the initial - * knn which was used to compute the core distances may no longer - * capture the actual neighborhoods after projection into mutual - * reachability space. - * - * For the experimental version, we execute the knn twice- once - * to compute the radii (core distances) and again to capture - * the final neighborhoods. Future iterations of this algorithm - * will work improve upon this "exact" version, by using - * more specialized data structures, such as space-partitioning - * structures. It has also been shown that approximate nearest - * neighbors can yield reasonable neighborhoods as the - * data sizes increase. 
- * - * @tparam value_idx - * @tparam value_t - * @param[in] handle raft handle for resource reuse - * @param[in] X input data points (size m * n) - * @param[in] m number of rows in X - * @param[in] n number of columns in X - * @param[in] metric distance metric to use - * @param[in] k neighborhood size - * @param[in] min_samples this neighborhood will be selected for core distances - * @param[in] alpha weight applied when internal distance is chosen for - * mutual reachability (value of 1.0 disables the weighting) - * @param[out] indptr CSR indptr of output knn graph (size m + 1) - * @param[out] core_dists output core distances array (size m) - * @param[out] out COO object, uninitialized on entry, on exit it stores the - * (symmetrized) maximum reachability distance for the k nearest - * neighbors. - */ -template -void mutual_reachability_graph(const raft::handle_t& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - int min_samples, - value_t alpha, - value_idx* indptr, - value_t* core_dists, - raft::sparse::COO& out) -{ - RAFT_EXPECTS(metric == raft::distance::DistanceType::L2SqrtExpanded, - "Currently only L2 expanded distance is supported"); - - auto stream = handle.get_stream(); - auto exec_policy = handle.get_thrust_policy(); - - rmm::device_uvector coo_rows(min_samples * m, stream); - rmm::device_uvector inds(min_samples * m, stream); - rmm::device_uvector dists(min_samples * m, stream); - - // perform knn - compute_knn(handle, X, inds.data(), dists.data(), m, n, X, m, min_samples, metric); - - // Slice core distances (distances to kth nearest neighbor) - core_distances(dists.data(), min_samples, min_samples, m, core_dists, stream); - - /** - * Compute L2 norm - */ - mutual_reachability_knn_l2( - handle, inds.data(), dists.data(), X, m, n, min_samples, core_dists, (value_t)1.0 / alpha); - - // self-loops get max distance - auto coo_rows_counting_itr = thrust::make_counting_iterator(0); - thrust::transform(exec_policy, 
- coo_rows_counting_itr, - coo_rows_counting_itr + (m * min_samples), - coo_rows.data(), - [min_samples] __device__(value_idx c) -> value_idx { return c / min_samples; }); - - raft::sparse::linalg::symmetrize( - handle, coo_rows.data(), inds.data(), dists.data(), m, m, min_samples * m, out); - - raft::sparse::convert::sorted_coo_to_csr(out.rows(), out.nnz, indptr, m + 1, stream); - - // self-loops get max distance - auto transform_in = - thrust::make_zip_iterator(thrust::make_tuple(out.rows(), out.cols(), out.vals())); - - thrust::transform(exec_policy, - transform_in, - transform_in + out.nnz, - out.vals(), - [=] __device__(const thrust::tuple& tup) { - return thrust::get<0>(tup) == thrust::get<1>(tup) - ? std::numeric_limits::max() - : thrust::get<2>(tup); - }); -} - }; // end namespace Reachability }; // end namespace detail }; // end namespace HDBSCAN diff --git a/cpp/src/hdbscan/detail/soft_clustering.cuh b/cpp/src/hdbscan/detail/soft_clustering.cuh index 5370ad2dfb..00ef1e6b04 100644 --- a/cpp/src/hdbscan/detail/soft_clustering.cuh +++ b/cpp/src/hdbscan/detail/soft_clustering.cuh @@ -24,7 +24,6 @@ #include #include -#include #include #include #include @@ -66,7 +65,7 @@ void dist_membership_vector(const raft::handle_t& handle, value_idx* exemplar_idx, value_idx* exemplar_label_offsets, value_t* dist_membership_vec, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, size_t batch_size, bool softmax = false) { @@ -96,7 +95,7 @@ void dist_membership_vector(const raft::handle_t& handle, query + batch_offset * n, samples_per_batch, n), raft::make_device_matrix_view(exemplars_dense.data(), n_exemplars, n), raft::make_device_matrix_view(dist.data(), samples_per_batch, n_exemplars), - static_cast(metric)); + metric); // compute the minimum distances to exemplars of each cluster value_idx n_elements = samples_per_batch * n_selected_clusters; @@ -391,7 +390,7 @@ void all_points_membership_vectors(const raft::handle_t& handle, 
Common::CondensedHierarchy& condensed_tree, Common::PredictionData& prediction_data, const value_t* X, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, value_t* membership_vec, size_t batch_size) { @@ -510,12 +509,12 @@ void membership_vector(const raft::handle_t& handle, const value_t* X, const value_t* points_to_predict, size_t n_prediction_points, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int min_samples, value_t* membership_vec, size_t batch_size) { - RAFT_EXPECTS(metric == raft::distance::DistanceType::L2SqrtExpanded, + RAFT_EXPECTS(metric == cuvs::distance::DistanceType::L2SqrtExpanded, "Currently only L2 expanded distance is supported"); auto stream = handle.get_stream(); @@ -548,7 +547,7 @@ void membership_vector(const raft::handle_t& handle, prediction_data.get_exemplar_idx(), prediction_data.get_exemplar_label_offsets(), dist_membership_vec.data(), - raft::distance::DistanceType::L2SqrtExpanded, + cuvs::distance::DistanceType::L2SqrtExpanded, batch_size); auto prediction_lambdas = diff --git a/cpp/src/hdbscan/hdbscan.cu b/cpp/src/hdbscan/hdbscan.cu index ea64d20f6b..8f5078c529 100644 --- a/cpp/src/hdbscan/hdbscan.cu +++ b/cpp/src/hdbscan/hdbscan.cu @@ -29,7 +29,7 @@ void hdbscan(const raft::handle_t& handle, const float* X, size_t m, size_t n, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, HDBSCAN::Common::HDBSCANParams& params, HDBSCAN::Common::hdbscan_output& out, float* core_dists) @@ -90,7 +90,7 @@ void compute_all_points_membership_vectors( HDBSCAN::Common::CondensedHierarchy& condensed_tree, HDBSCAN::Common::PredictionData& prediction_data, const float* X, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float* membership_vec, size_t batch_size) { @@ -105,7 +105,7 @@ void compute_membership_vector(const raft::handle_t& handle, const float* points_to_predict, size_t n_prediction_points, int min_samples, - 
raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float* membership_vec, size_t batch_size) { @@ -130,7 +130,7 @@ void out_of_sample_predict(const raft::handle_t& handle, int* labels, const float* points_to_predict, size_t n_prediction_points, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int min_samples, int* out_labels, float* out_probabilities) @@ -157,7 +157,7 @@ void compute_core_dists(const raft::handle_t& handle, float* core_dists, size_t m, size_t n, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int min_samples) { HDBSCAN::detail::Reachability::_compute_core_dists( diff --git a/cpp/src/hdbscan/runner.h b/cpp/src/hdbscan/runner.h index 2f3e554a20..9168d0e77c 100644 --- a/cpp/src/hdbscan/runner.h +++ b/cpp/src/hdbscan/runner.h @@ -159,7 +159,7 @@ void build_linkage(const raft::handle_t& handle, const value_t* X, size_t m, size_t n, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, Common::HDBSCANParams& params, value_t* core_dists, Common::robust_single_linkage_output& out) @@ -183,7 +183,7 @@ void build_linkage(const raft::handle_t& handle, raft::make_device_vector_view(mutual_reachability_indptr.data(), m + 1), raft::make_device_vector_view(core_dists, m), mutual_reachability_coo, - static_cast(metric), + metric, params.alpha); /** @@ -206,7 +206,7 @@ void build_linkage(const raft::handle_t& handle, color.data(), mutual_reachability_coo.nnz, red_op, - metric, + static_cast(metric), (size_t)10); /** @@ -229,7 +229,7 @@ void _fit_hdbscan(const raft::handle_t& handle, const value_t* X, size_t m, size_t n, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, Common::HDBSCANParams& params, value_idx* labels, value_t* core_dists, diff --git a/cpp/src/hierarchy/linkage.cu b/cpp/src/hierarchy/linkage.cu index 7094145fd2..87f2e8f7a2 100644 --- a/cpp/src/hierarchy/linkage.cu +++ b/cpp/src/hierarchy/linkage.cu @@ -1,5 +1,5 @@ 
/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,11 +26,11 @@ void single_linkage_pairwise(const raft::handle_t& handle, size_t m, size_t n, raft::cluster::linkage_output* out, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int n_clusters) { raft::cluster::single_linkage( - handle, X, m, n, metric, out, 0, n_clusters); + handle, X, m, n, static_cast(metric), out, 0, n_clusters); } void single_linkage_neighbors(const raft::handle_t& handle, @@ -38,12 +38,12 @@ void single_linkage_neighbors(const raft::handle_t& handle, size_t m, size_t n, raft::cluster::linkage_output* out, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, int c, int n_clusters) { raft::cluster::single_linkage( - handle, X, m, n, metric, out, c, n_clusters); + handle, X, m, n, static_cast(metric), out, c, n_clusters); } struct distance_graph_impl_int_float diff --git a/cpp/src/knn/knn.cu b/cpp/src/knn/knn.cu index c08cd47de2..0d20ec0130 100644 --- a/cpp/src/knn/knn.cu +++ b/cpp/src/knn/knn.cu @@ -49,7 +49,7 @@ void brute_force_knn(const raft::handle_t& handle, int k, bool rowMajorIndex, bool rowMajorQuery, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float metric_arg, std::vector* translations) { @@ -109,14 +109,14 @@ void brute_force_knn(const raft::handle_t& handle, idx = cuvs::neighbors::brute_force::build( current_handle, raft::make_device_matrix_view(input[i], sizes[i], D), - static_cast(metric), + metric, metric_arg); } else { idx = cuvs::neighbors::brute_force::build( current_handle, raft::make_device_matrix_view(input[i], sizes[i], D), - static_cast(metric), + metric, metric_arg); } @@ -177,7 +177,7 @@ void rbc_knn_query(const raft::handle_t& handle, void approx_knn_build_index(raft::handle_t& handle, knnIndex* 
index, knnIndexParam* params, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float metricArg, float* index_array, int n, @@ -189,14 +189,19 @@ void approx_knn_build_index(raft::handle_t& handle, auto ivf_ft_pams = dynamic_cast(params); auto ivf_pq_pams = dynamic_cast(params); - index->metric_processor = raft::spatial::knn::create_processor( - metric, n, D, 0, false, raft::resource::get_cuda_stream(handle)); + index->metric_processor = + raft::spatial::knn::create_processor(static_cast(metric), + n, + D, + 0, + false, + raft::resource::get_cuda_stream(handle)); // For cosine/correlation distance, the metric processor translates distance // to inner product via pre/post processing - pass the translated metric to // ANN index - if (metric == raft::distance::DistanceType::CosineExpanded || - metric == raft::distance::DistanceType::CorrelationExpanded) { - metric = index->metric = raft::distance::DistanceType::InnerProduct; + if (metric == cuvs::distance::DistanceType::CosineExpanded || + metric == cuvs::distance::DistanceType::CorrelationExpanded) { + metric = index->metric = cuvs::distance::DistanceType::InnerProduct; } index->metric_processor->preprocess(index_array); auto index_view = raft::make_device_matrix_view(index_array, n, D); @@ -204,7 +209,7 @@ void approx_knn_build_index(raft::handle_t& handle, if (ivf_ft_pams) { index->nprobe = ivf_ft_pams->nprobe; cuvs::neighbors::ivf_flat::index_params params; - params.metric = static_cast(metric); + params.metric = metric; params.metric_arg = metricArg; params.n_lists = ivf_ft_pams->nlist; @@ -213,7 +218,7 @@ void approx_knn_build_index(raft::handle_t& handle, } else if (ivf_pq_pams) { index->nprobe = ivf_pq_pams->nprobe; cuvs::neighbors::ivf_pq::index_params params; - params.metric = static_cast(metric); + params.metric = metric; params.metric_arg = metricArg; params.n_lists = ivf_pq_pams->nlist; params.pq_bits = ivf_pq_pams->n_bits; @@ -266,14 +271,14 @@ void 
approx_knn_search(raft::handle_t& handle, index->metric_processor->revert(query_array); // perform post-processing to show the real distances - if (index->metric == raft::distance::DistanceType::L2SqrtExpanded || - index->metric == raft::distance::DistanceType::L2SqrtUnexpanded || - index->metric == raft::distance::DistanceType::LpUnexpanded) { + if (index->metric == cuvs::distance::DistanceType::L2SqrtExpanded || + index->metric == cuvs::distance::DistanceType::L2SqrtUnexpanded || + index->metric == cuvs::distance::DistanceType::LpUnexpanded) { /** * post-processing */ float p = 0.5; // standard l2 - if (index->metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / index->metricArg; + if (index->metric == cuvs::distance::DistanceType::LpUnexpanded) p = 1.0 / index->metricArg; raft::linalg::unaryOp(distances, distances, n * k, diff --git a/cpp/src/knn/knn_api.cpp b/cpp/src/knn/knn_api.cpp index 62766635c4..22c9df0bb7 100644 --- a/cpp/src/knn/knn_api.cpp +++ b/cpp/src/knn/knn_api.cpp @@ -68,8 +68,8 @@ cumlError_t knn_search(const cumlHandle_t handle, cumlError_t status; raft::handle_t* handle_ptr; std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle); - raft::distance::DistanceType metric_distance_type = - static_cast(metric_type); + cuvs::distance::DistanceType metric_distance_type = + static_cast(metric_type); std::vector input_vec(n_params); std::vector sizes_vec(n_params); diff --git a/cpp/src/knn/knn_opg_common.cuh b/cpp/src/knn/knn_opg_common.cuh index bcf3fe81ae..e8b73b3a5d 100644 --- a/cpp/src/knn/knn_opg_common.cuh +++ b/cpp/src/knn/knn_opg_common.cuh @@ -454,7 +454,7 @@ void perform_local_knn(opg_knn_param& params, params.k, params.rowMajorIndex, params.rowMajorQuery, - raft::distance::DistanceType::L2SqrtExpanded, + cuvs::distance::DistanceType::L2SqrtExpanded, 2.0f, &start_indices_long); handle.sync_stream(handle.get_stream()); diff --git a/cpp/src/knn/knn_sparse.cu b/cpp/src/knn/knn_sparse.cu index f7392266f6..8768e8ab3c 
100644 --- a/cpp/src/knn/knn_sparse.cu +++ b/cpp/src/knn/knn_sparse.cu @@ -17,7 +17,8 @@ #include #include -#include + +#include namespace ML { namespace Sparse { @@ -40,29 +41,34 @@ void brute_force_knn(raft::handle_t& handle, int k, size_t batch_size_index, // approx 1M size_t batch_size_query, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float metricArg) { - raft::sparse::selection::brute_force_knn(idx_indptr, - idx_indices, - idx_data, - idx_nnz, - n_idx_rows, - n_idx_cols, - query_indptr, - query_indices, - query_data, - query_nnz, - n_query_rows, - n_query_cols, - output_indices, - output_dists, - k, - handle, - batch_size_index, - batch_size_query, - metric, - metricArg); + auto idx_structure = raft::make_device_compressed_structure_view( + const_cast(idx_indptr), const_cast(idx_indices), n_idx_rows, n_idx_cols, idx_nnz); + auto idx_csr = raft::make_device_csr_matrix_view(idx_data, idx_structure); + + auto query_structure = + raft::make_device_compressed_structure_view(const_cast(query_indptr), + const_cast(query_indices), + n_query_rows, + n_query_cols, + query_nnz); + auto query_csr = raft::make_device_csr_matrix_view(query_data, query_structure); + + cuvs::neighbors::brute_force::sparse_search_params search_params; + search_params.batch_size_index = batch_size_index; + search_params.batch_size_query = batch_size_query; + + auto index = cuvs::neighbors::brute_force::build(handle, idx_csr, metric, metricArg); + + cuvs::neighbors::brute_force::search( + handle, + search_params, + index, + query_csr, + raft::make_device_matrix_view(output_indices, n_query_rows, k), + raft::make_device_matrix_view(output_dists, n_query_rows, k)); } }; // namespace Sparse }; // namespace ML diff --git a/cpp/src/metrics/pairwise_distance.cu b/cpp/src/metrics/pairwise_distance.cu index 1f56ae8bf9..94d10129e8 100644 --- a/cpp/src/metrics/pairwise_distance.cu +++ b/cpp/src/metrics/pairwise_distance.cu @@ -19,7 +19,6 @@ #include #include -#include 
#include @@ -33,7 +32,7 @@ void pairwise_distance(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, bool isRowMajor, double metric_arg) { @@ -43,7 +42,7 @@ void pairwise_distance(const raft::handle_t& handle, raft::make_device_matrix_view(x, m, k), raft::make_device_matrix_view(y, n, k), raft::make_device_matrix_view(dist, m, n), - static_cast(metric), + metric, metric_arg); } else { cuvs::distance::pairwise_distance( @@ -51,7 +50,7 @@ void pairwise_distance(const raft::handle_t& handle, raft::make_device_matrix_view(x, m, k), raft::make_device_matrix_view(y, n, k), raft::make_device_matrix_view(dist, m, n), - static_cast(metric), + metric, metric_arg); } } @@ -63,7 +62,7 @@ void pairwise_distance(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, bool isRowMajor, float metric_arg) { @@ -73,7 +72,7 @@ void pairwise_distance(const raft::handle_t& handle, raft::make_device_matrix_view(x, m, k), raft::make_device_matrix_view(y, n, k), raft::make_device_matrix_view(dist, m, n), - static_cast(metric), + metric, metric_arg); } else { cuvs::distance::pairwise_distance( @@ -81,7 +80,7 @@ void pairwise_distance(const raft::handle_t& handle, raft::make_device_matrix_view(x, m, k), raft::make_device_matrix_view(y, n, k), raft::make_device_matrix_view(dist, m, n), - static_cast(metric), + metric, metric_arg); } } @@ -100,7 +99,7 @@ void pairwiseDistance_sparse(const raft::handle_t& handle, value_idx* y_indptr, value_idx* x_indices, value_idx* y_indices, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float metric_arg) { auto out = raft::make_device_matrix_view(dist, y_nrows, x_nrows); @@ -113,8 +112,7 @@ void pairwiseDistance_sparse(const raft::handle_t& handle, y_indptr, y_indices, y_nrows, n_cols, y_nnz); auto y_csr_view = raft::make_device_csr_matrix_view(y, y_structure); - 
raft::sparse::distance::pairwise_distance( - handle, y_csr_view, x_csr_view, out, metric, metric_arg); + cuvs::distance::pairwise_distance(handle, y_csr_view, x_csr_view, out, metric, metric_arg); } void pairwiseDistance_sparse(const raft::handle_t& handle, @@ -130,7 +128,7 @@ void pairwiseDistance_sparse(const raft::handle_t& handle, int* y_indptr, int* x_indices, int* y_indices, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float metric_arg) { pairwiseDistance_sparse(handle, @@ -163,7 +161,7 @@ void pairwiseDistance_sparse(const raft::handle_t& handle, int* y_indptr, int* x_indices, int* y_indices, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float metric_arg) { pairwiseDistance_sparse(handle, diff --git a/cpp/src/metrics/silhouette_score.cu b/cpp/src/metrics/silhouette_score.cu index cf4fdeb6fd..6ac668fe14 100644 --- a/cpp/src/metrics/silhouette_score.cu +++ b/cpp/src/metrics/silhouette_score.cu @@ -32,7 +32,7 @@ double silhouette_score(const raft::handle_t& handle, int* labels, int nLabels, double* silScores, - raft::distance::DistanceType metric) + cuvs::distance::DistanceType metric) { std::optional> silhouette_score_per_sample; if (silScores != NULL) { @@ -45,7 +45,7 @@ double silhouette_score(const raft::handle_t& handle, raft::make_device_vector_view(labels, nRows), silhouette_score_per_sample, nLabels, - static_cast(metric)); + metric); } } // namespace Metrics } // namespace ML diff --git a/cpp/src/metrics/silhouette_score_batched_double.cu b/cpp/src/metrics/silhouette_score_batched_double.cu index 3188ce0bc4..9c1c183dcc 100644 --- a/cpp/src/metrics/silhouette_score_batched_double.cu +++ b/cpp/src/metrics/silhouette_score_batched_double.cu @@ -34,7 +34,7 @@ double silhouette_score(const raft::handle_t& handle, int n_labels, double* scores, int chunk, - raft::distance::DistanceType metric) + cuvs::distance::DistanceType metric) { std::optional> silhouette_score_per_sample; if (scores != 
NULL) { @@ -48,7 +48,7 @@ double silhouette_score(const raft::handle_t& handle, silhouette_score_per_sample, n_labels, chunk, - static_cast(metric)); + metric); } } // namespace Batched diff --git a/cpp/src/metrics/silhouette_score_batched_float.cu b/cpp/src/metrics/silhouette_score_batched_float.cu index 0245375657..3676abdbcb 100644 --- a/cpp/src/metrics/silhouette_score_batched_float.cu +++ b/cpp/src/metrics/silhouette_score_batched_float.cu @@ -18,8 +18,8 @@ #include #include -#include +#include #include namespace ML { @@ -34,7 +34,7 @@ float silhouette_score(const raft::handle_t& handle, int n_labels, float* scores, int chunk, - raft::distance::DistanceType metric) + cuvs::distance::DistanceType metric) { std::optional> silhouette_score_per_sample; if (scores != NULL) { @@ -48,7 +48,7 @@ float silhouette_score(const raft::handle_t& handle, silhouette_score_per_sample, n_labels, chunk, - static_cast(metric)); + metric); } } // namespace Batched } // namespace Metrics diff --git a/cpp/src/metrics/trustworthiness.cu b/cpp/src/metrics/trustworthiness.cu index 724ef43ddd..bcc3a04803 100644 --- a/cpp/src/metrics/trustworthiness.cu +++ b/cpp/src/metrics/trustworthiness.cu @@ -37,7 +37,7 @@ namespace Metrics { * @tparam distance_type: Distance type to consider * @return Trustworthiness score */ -template +template double trustworthiness_score(const raft::handle_t& h, const math_t* X, math_t* X_embedded, @@ -52,11 +52,11 @@ double trustworthiness_score(const raft::handle_t& h, raft::make_device_matrix_view(X, n, m), raft::make_device_matrix_view(X_embedded, n, d), n_neighbors, - static_cast(distance_type), + distance_type, batchSize); } -template double trustworthiness_score( +template double trustworthiness_score( const raft::handle_t& h, const float* X, float* X_embedded, diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh index 31be50942c..c0c12984ec 100644 --- a/cpp/src/tsne/distances.cuh +++ b/cpp/src/tsne/distances.cuh @@ -23,11 +23,9 @@ #include 
#include -#include #include #include #include -#include #include #include @@ -37,6 +35,7 @@ #include #include +#include #include #include @@ -57,7 +56,7 @@ void get_distances(const raft::handle_t& handle, tsne_input& input, knn_graph& k_graph, cudaStream_t stream, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, value_t p); // dense, int64 indices @@ -66,7 +65,7 @@ void get_distances(const raft::handle_t& handle, manifold_dense_inputs_t& input, knn_graph& k_graph, cudaStream_t stream, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float p) { // TODO: for TSNE transform first fit some points then transform with 1/(1+d^2) @@ -74,8 +73,7 @@ void get_distances(const raft::handle_t& handle, auto k = k_graph.n_neighbors; auto X_view = raft::make_device_matrix_view(input.X, input.n, input.d); - auto idx = cuvs::neighbors::brute_force::build( - handle, X_view, static_cast(metric), p); + auto idx = cuvs::neighbors::brute_force::build(handle, X_view, metric, p); cuvs::neighbors::brute_force::search( handle, @@ -91,7 +89,7 @@ void get_distances(const raft::handle_t& handle, manifold_dense_inputs_t& input, knn_graph& k_graph, cudaStream_t stream, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float p) { throw raft::exception("Dense TSNE does not support 32-bit integer indices yet."); @@ -103,29 +101,26 @@ void get_distances(const raft::handle_t& handle, manifold_sparse_inputs_t& input, knn_graph& k_graph, cudaStream_t stream, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float p) { - raft::sparse::selection::brute_force_knn(input.indptr, - input.indices, - input.data, - input.nnz, - input.n, - input.d, - input.indptr, - input.indices, - input.data, - input.nnz, - input.n, - input.d, - k_graph.knn_indices, - k_graph.knn_dists, - k_graph.n_neighbors, - handle, - ML::Sparse::DEFAULT_BATCH_SIZE, - ML::Sparse::DEFAULT_BATCH_SIZE, - metric, - p); + auto 
input_structure = raft::make_device_compressed_structure_view( + input.indptr, input.indices, input.n, input.d, input.nnz); + auto input_csr = raft::make_device_csr_matrix_view(input.data, input_structure); + + cuvs::neighbors::brute_force::sparse_search_params search_params; + search_params.batch_size_index = ML::Sparse::DEFAULT_BATCH_SIZE; + search_params.batch_size_query = ML::Sparse::DEFAULT_BATCH_SIZE; + + auto index = cuvs::neighbors::brute_force::build(handle, input_csr, metric, p); + + cuvs::neighbors::brute_force::search( + handle, + search_params, + index, + input_csr, + raft::make_device_matrix_view(k_graph.knn_indices, input.n, k_graph.n_neighbors), + raft::make_device_matrix_view(k_graph.knn_dists, input.n, k_graph.n_neighbors)); } // sparse, int64 @@ -134,7 +129,7 @@ void get_distances(const raft::handle_t& handle, manifold_sparse_inputs_t& input, knn_graph& k_graph, cudaStream_t stream, - raft::distance::DistanceType metric, + cuvs::distance::DistanceType metric, float p) { throw raft::exception("Sparse TSNE does not support 64-bit integer indices yet."); diff --git a/cpp/src/tsne/tsne.cu b/cpp/src/tsne/tsne.cu index 4964e1a47f..5f2f5a5e9f 100644 --- a/cpp/src/tsne/tsne.cu +++ b/cpp/src/tsne/tsne.cu @@ -19,7 +19,8 @@ #include #include -#include + +#include namespace ML { template diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index 96df2a12ca..b735be0e63 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -26,7 +26,6 @@ #include #include -#include #include #include #include @@ -36,6 +35,7 @@ #include +#include #include namespace ML { diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index fa55397659..6617d72c00 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -27,15 +27,13 @@ #include #include #include -#include #include #include #include #include -#include -#include #include +#include #include #include @@ -97,7 +95,7 @@ inline void 
launcher(const raft::handle_t& handle, auto idx = cuvs::neighbors::brute_force::build( handle, raft::make_device_matrix_view(inputsA.X, inputsA.n, inputsA.d), - static_cast(params->metric), + params->metric, params->p); cuvs::neighbors::brute_force::search( @@ -162,26 +160,28 @@ inline void launcher(const raft::handle_t& handle, { RAFT_EXPECTS(params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN, "nn_descent does not support sparse inputs"); - raft::sparse::selection::brute_force_knn(inputsA.indptr, - inputsA.indices, - inputsA.data, - inputsA.nnz, - inputsA.n, - inputsA.d, - inputsB.indptr, - inputsB.indices, - inputsB.data, - inputsB.nnz, - inputsB.n, - inputsB.d, - out.knn_indices, - out.knn_dists, - n_neighbors, - handle, - ML::Sparse::DEFAULT_BATCH_SIZE, - ML::Sparse::DEFAULT_BATCH_SIZE, - params->metric, - params->p); + + auto a_structure = raft::make_device_compressed_structure_view( + inputsA.indptr, inputsA.indices, inputsA.n, inputsA.d, inputsA.nnz); + auto a_csr = raft::make_device_csr_matrix_view(inputsA.data, a_structure); + + auto b_structure = raft::make_device_compressed_structure_view( + inputsB.indptr, inputsB.indices, inputsB.n, inputsB.d, inputsB.nnz); + auto b_csr = raft::make_device_csr_matrix_view(inputsB.data, b_structure); + + cuvs::neighbors::brute_force::sparse_search_params search_params; + search_params.batch_size_index = ML::Sparse::DEFAULT_BATCH_SIZE; + search_params.batch_size_query = ML::Sparse::DEFAULT_BATCH_SIZE; + + auto index = cuvs::neighbors::brute_force::build(handle, a_csr, params->metric, params->p); + + cuvs::neighbors::brute_force::search( + handle, + search_params, + index, + b_csr, + raft::make_device_matrix_view(out.knn_indices, inputsB.n, n_neighbors), + raft::make_device_matrix_view(out.knn_dists, inputsB.n, n_neighbors)); } template <> diff --git a/cpp/src/umap/supervised.cuh b/cpp/src/umap/supervised.cuh index 1156005ad2..21ed42f157 100644 --- a/cpp/src/umap/supervised.cuh +++ 
b/cpp/src/umap/supervised.cuh @@ -27,6 +27,7 @@ #include #include +#include #include #include #include diff --git a/cpp/src_prims/selection/knn.cuh b/cpp/src_prims/selection/knn.cuh index b24f0d03e1..56c3b2b9f6 100644 --- a/cpp/src_prims/selection/knn.cuh +++ b/cpp/src_prims/selection/knn.cuh @@ -21,7 +21,6 @@ #include #include -#include #include #include #include @@ -29,6 +28,8 @@ #include #include +#include + #include #include #include diff --git a/cpp/test/prims/dist_adj.cu b/cpp/test/prims/dist_adj.cu index 5730b276ea..3b2c4d1e61 100644 --- a/cpp/test/prims/dist_adj.cu +++ b/cpp/test/prims/dist_adj.cu @@ -106,7 +106,7 @@ class DistanceAdjTest : public ::testing::TestWithParam( + getWorkspaceSize( x, y, m, n, k); rmm::device_uvector workspace(worksize, stream); @@ -114,7 +114,7 @@ class DistanceAdjTest : public ::testing::TestWithParam(x.data(), + distance(x.data(), y.data(), dist.data(), m, diff --git a/cpp/test/prims/distance_base.cuh b/cpp/test/prims/distance_base.cuh index 4a472779f7..46201aa878 100644 --- a/cpp/test/prims/distance_base.cuh +++ b/cpp/test/prims/distance_base.cuh @@ -35,7 +35,7 @@ CUML_KERNEL void naiveDistanceKernel(DataType* dist, int m, int n, int k, - raft::distance::DistanceType type, + cuvs::distance::DistanceType type, bool isRowMajor) { int midx = threadIdx.x + blockIdx.x * blockDim.x; @@ -48,8 +48,8 @@ CUML_KERNEL void naiveDistanceKernel(DataType* dist, auto diff = x[xidx] - y[yidx]; acc += diff * diff; } - if (type == raft::distance::DistanceType::L2SqrtExpanded || - type == raft::distance::DistanceType::L2SqrtUnexpanded) + if (type == cuvs::distance::DistanceType::L2SqrtExpanded || + type == cuvs::distance::DistanceType::L2SqrtUnexpanded) acc = raft::sqrt(acc); int outidx = isRowMajor ? 
midx * n + nidx : midx + m * nidx; dist[outidx] = acc; @@ -112,23 +112,23 @@ void naiveDistance(DataType* dist, int m, int n, int k, - raft::distance::DistanceType type, + cuvs::distance::DistanceType type, bool isRowMajor) { static const dim3 TPB(16, 32, 1); dim3 nblks(raft::ceildiv(m, (int)TPB.x), raft::ceildiv(n, (int)TPB.y), 1); switch (type) { - case raft::distance::DistanceType::L1: + case cuvs::distance::DistanceType::L1: naiveL1DistanceKernel < <>(dist, x, y, m, n, k, isRowMajor); break; - case raft::distance::DistanceType::L2SqrtUnexpanded: - case raft::distance::DistanceType::L2Unexpanded: - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2Expanded: + case cuvs::distance::DistanceType::L2SqrtUnexpanded: + case cuvs::distance::DistanceType::L2Unexpanded: + case cuvs::distance::DistanceType::L2SqrtExpanded: + case cuvs::distance::DistanceType::L2Expanded: naiveDistanceKernel < <>(dist, x, y, m, n, k, type, isRowMajor); break; - case raft::distance::DistanceType::CosineExpanded: + case cuvs::distance::DistanceType::CosineExpanded: naiveCosineDistanceKernel < <>(dist, x, y, m, n, k, isRowMajor); break; default: FAIL() << "should be here\n"; @@ -150,7 +150,7 @@ template return os; } -template +template void distanceLauncher(raft::resources const& handle, DataType* x, DataType* y, @@ -175,7 +175,7 @@ void distanceLauncher(raft::resources const& handle, handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor); } -template +template class DistanceTest : public ::testing::TestWithParam> { public: DistanceTest() diff --git a/cpp/test/sg/dbscan_test.cu b/cpp/test/sg/dbscan_test.cu index 69f5bcf29b..a46884fb39 100644 --- a/cpp/test/sg/dbscan_test.cu +++ b/cpp/test/sg/dbscan_test.cu @@ -21,11 +21,11 @@ #include #include -#include #include #include #include +#include #include #include @@ -52,7 +52,7 @@ struct DbscanInputs { int min_pts; size_t max_bytes_per_batch; unsigned long long int seed; - 
raft::distance::DistanceType metric; + cuvs::distance::DistanceType metric; }; template @@ -74,7 +74,8 @@ class DbscanTest : public ::testing::TestWithParam> { rmm::device_uvector out(params.n_row * params.n_col, stream); rmm::device_uvector l(params.n_row, stream); rmm::device_uvector dist( - params.metric == raft::distance::Precomputed ? params.n_row * params.n_row : 0, stream); + params.metric == cuvs::distance::DistanceType::Precomputed ? params.n_row * params.n_row : 0, + stream); make_blobs(handle, out.data(), @@ -91,7 +92,7 @@ class DbscanTest : public ::testing::TestWithParam> { 10.0f, params.seed); - if (params.metric == raft::distance::Precomputed) { + if (params.metric == cuvs::distance::DistanceType::Precomputed) { ML::Metrics::pairwise_distance(handle, out.data(), out.data(), @@ -99,7 +100,7 @@ class DbscanTest : public ::testing::TestWithParam> { params.n_row, params.n_row, params.n_col, - raft::distance::L2SqrtUnexpanded); + cuvs::distance::DistanceType::L2SqrtUnexpanded); } rmm::device_uvector labels(params.n_row, stream); @@ -109,17 +110,18 @@ class DbscanTest : public ::testing::TestWithParam> { handle.sync_stream(stream); - Dbscan::fit(handle, - params.metric == raft::distance::Precomputed ? dist.data() : out.data(), - params.n_row, - params.n_col, - params.eps, - params.min_pts, - params.metric, - labels.data(), - nullptr, - nullptr, - params.max_bytes_per_batch); + Dbscan::fit( + handle, + params.metric == cuvs::distance::DistanceType::Precomputed ? 
dist.data() : out.data(), + params.n_row, + params.n_col, + params.eps, + params.min_pts, + params.metric, + labels.data(), + nullptr, + nullptr, + params.max_bytes_per_batch); handle.sync_stream(stream); @@ -143,37 +145,149 @@ class DbscanTest : public ::testing::TestWithParam> { }; const std::vector> inputsf2 = { - {500, 16, 5, 0.01, 2, 2, (size_t)100, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {500, 16, 5, 0.01, 2, 2, (size_t)100, 1234ULL, raft::distance::Precomputed}, - {1000, 1000, 10, 0.01, 2, 2, (size_t)13e3, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {20000, 10000, 10, 0.01, 2, 2, (size_t)13e3, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {20000, 100, 5000, 0.01, 2, 2, (size_t)13e3, 1234ULL, raft::distance::L2SqrtUnexpanded}}; + {500, 16, 5, 0.01, 2, 2, (size_t)100, 1234ULL, cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {500, 16, 5, 0.01, 2, 2, (size_t)100, 1234ULL, cuvs::distance::DistanceType::Precomputed}, + {1000, + 1000, + 10, + 0.01, + 2, + 2, + (size_t)13e3, + 1234ULL, + cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {20000, + 10000, + 10, + 0.01, + 2, + 2, + (size_t)13e3, + 1234ULL, + cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {20000, + 100, + 5000, + 0.01, + 2, + 2, + (size_t)13e3, + 1234ULL, + cuvs::distance::DistanceType::L2SqrtUnexpanded}}; const std::vector> inputsf3 = { - {500, 16, 5, 0.01, 2, 2, (size_t)100, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {500, 16, 5, 0.01, 2, 2, (size_t)100, 1234ULL, raft::distance::Precomputed}, - {1000, 1000, 10, 0.01, 2, 2, (size_t)9e3, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {50000, 16, 5, 0.01, 2, 2, (size_t)9e3, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {20000, 10000, 10, 0.01, 2, 2, (size_t)9e3, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {20000, 100, 5000, 0.01, 2, 2, (size_t)9e3, 1234ULL, raft::distance::L2SqrtUnexpanded}}; + {500, 16, 5, 0.01, 2, 2, (size_t)100, 1234ULL, cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {500, 16, 5, 0.01, 2, 2, (size_t)100, 
1234ULL, cuvs::distance::DistanceType::Precomputed}, + {1000, + 1000, + 10, + 0.01, + 2, + 2, + (size_t)9e3, + 1234ULL, + cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {50000, 16, 5, 0.01, 2, 2, (size_t)9e3, 1234ULL, cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {20000, + 10000, + 10, + 0.01, + 2, + 2, + (size_t)9e3, + 1234ULL, + cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {20000, + 100, + 5000, + 0.01, + 2, + 2, + (size_t)9e3, + 1234ULL, + cuvs::distance::DistanceType::L2SqrtUnexpanded}}; const std::vector> inputsd2 = { - {50000, 16, 5, 0.01, 2, 2, (size_t)13e3, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {10000, 16, 5, 0.01, 2, 2, (size_t)13e3, 1234ULL, raft::distance::Precomputed}, - {500, 16, 5, 0.01, 2, 2, (size_t)100, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {1000, 1000, 10, 0.01, 2, 2, (size_t)13e3, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {100, 10000, 10, 0.01, 2, 2, (size_t)13e3, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {20000, 10000, 10, 0.01, 2, 2, (size_t)13e3, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {20000, 100, 5000, 0.01, 2, 2, (size_t)13e3, 1234ULL, raft::distance::L2SqrtUnexpanded}}; + {50000, 16, 5, 0.01, 2, 2, (size_t)13e3, 1234ULL, cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {10000, 16, 5, 0.01, 2, 2, (size_t)13e3, 1234ULL, cuvs::distance::DistanceType::Precomputed}, + {500, 16, 5, 0.01, 2, 2, (size_t)100, 1234ULL, cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {1000, + 1000, + 10, + 0.01, + 2, + 2, + (size_t)13e3, + 1234ULL, + cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {100, + 10000, + 10, + 0.01, + 2, + 2, + (size_t)13e3, + 1234ULL, + cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {20000, + 10000, + 10, + 0.01, + 2, + 2, + (size_t)13e3, + 1234ULL, + cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {20000, + 100, + 5000, + 0.01, + 2, + 2, + (size_t)13e3, + 1234ULL, + cuvs::distance::DistanceType::L2SqrtUnexpanded}}; const std::vector> inputsd3 = { - {50000, 16, 5, 0.01, 2, 2, 
(size_t)9e3, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {10000, 16, 5, 0.01, 2, 2, (size_t)9e3, 1234ULL, raft::distance::Precomputed}, - {500, 16, 5, 0.01, 2, 2, (size_t)100, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {1000, 1000, 10, 0.01, 2, 2, (size_t)9e3, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {100, 10000, 10, 0.01, 2, 2, (size_t)9e3, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {20000, 10000, 10, 0.01, 2, 2, (size_t)9e3, 1234ULL, raft::distance::L2SqrtUnexpanded}, - {20000, 100, 5000, 0.01, 2, 2, (size_t)9e3, 1234ULL, raft::distance::L2SqrtUnexpanded}}; + {50000, 16, 5, 0.01, 2, 2, (size_t)9e3, 1234ULL, cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {10000, 16, 5, 0.01, 2, 2, (size_t)9e3, 1234ULL, cuvs::distance::DistanceType::Precomputed}, + {500, 16, 5, 0.01, 2, 2, (size_t)100, 1234ULL, cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {1000, + 1000, + 10, + 0.01, + 2, + 2, + (size_t)9e3, + 1234ULL, + cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {100, + 10000, + 10, + 0.01, + 2, + 2, + (size_t)9e3, + 1234ULL, + cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {20000, + 10000, + 10, + 0.01, + 2, + 2, + (size_t)9e3, + 1234ULL, + cuvs::distance::DistanceType::L2SqrtUnexpanded}, + {20000, + 100, + 5000, + 0.01, + 2, + 2, + (size_t)9e3, + 1234ULL, + cuvs::distance::DistanceType::L2SqrtUnexpanded}}; typedef DbscanTest DbscanTestF_Int; TEST_P(DbscanTestF_Int, Result) { ASSERT_TRUE(score == 1.0); } @@ -242,7 +356,7 @@ class Dbscan2DSimple : public ::testing::TestWithParam> { 2, params.eps, params.min_pts, - raft::distance::L2SqrtUnexpanded, + cuvs::distance::DistanceType::L2SqrtUnexpanded, labels.data(), core_sample_indices_d.data(), sample_weight, diff --git a/cpp/test/sg/hdbscan_test.cu b/cpp/test/sg/hdbscan_test.cu index a7ce69b1bc..d90e9f4314 100644 --- a/cpp/test/sg/hdbscan_test.cu +++ b/cpp/test/sg/hdbscan_test.cu @@ -21,7 +21,6 @@ #include // build_dendrogram_host #include -#include #include #include #include @@ -35,6 +34,8 @@ 
#include #include +#include +#include #include #include #include @@ -105,7 +106,7 @@ class HDBSCANTest : public ::testing::TestWithParam> { data.data(), params.n_row, params.n_col, - raft::distance::DistanceType::L2SqrtExpanded, + cuvs::distance::DistanceType::L2SqrtExpanded, hdbscan_params, out, core_dists.data()); @@ -460,7 +461,7 @@ class AllPointsMembershipVectorsTest condensed_tree, prediction_data_, data.data(), - raft::distance::DistanceType::L2SqrtExpanded, + cuvs::distance::DistanceType::L2SqrtExpanded, membership_vec.data()); ASSERT_TRUE(MLCommon::devArrMatch(membership_vec.data(), @@ -571,17 +572,16 @@ class ApproximatePredictTest : public ::testing::TestWithParam mutual_reachability_indptr(params.n_row + 1, stream); raft::sparse::COO mutual_reachability_coo(stream, (params.min_samples + 1) * params.n_row * 2); - ML::HDBSCAN::detail::Reachability::mutual_reachability_graph( + + cuvs::neighbors::reachability::mutual_reachability_graph( handle, - data.data(), - (size_t)params.n_row, - (size_t)params.n_col, - raft::distance::DistanceType::L2SqrtExpanded, + raft::make_device_matrix_view(data.data(), params.n_row, params.n_col), params.min_samples + 1, - (float)1.0, - mutual_reachability_indptr.data(), - pred_data.get_core_dists(), - mutual_reachability_coo); + raft::make_device_vector_view(mutual_reachability_indptr.data(), params.n_row + 1), + raft::make_device_vector_view(core_dists.data(), params.n_row), + mutual_reachability_coo, + cuvs::distance::DistanceType::L2SqrtExpanded, + 1.0); transformLabels(handle, labels.data(), label_map.data(), params.n_row); ML::HDBSCAN::Common::generate_prediction_data(handle, @@ -602,7 +602,7 @@ class ApproximatePredictTest : public ::testing::TestWithParam(points_to_predict.data()), (size_t)params.n_points_to_predict, - raft::distance::DistanceType::L2SqrtExpanded, + cuvs::distance::DistanceType::L2SqrtExpanded, params.min_samples, out_labels.data(), out_probabilities.data()); @@ -726,17 +726,16 @@ class 
MembershipVectorTest : public ::testing::TestWithParam mutual_reachability_indptr(params.n_row + 1, stream); raft::sparse::COO mutual_reachability_coo(stream, (params.min_samples + 1) * params.n_row * 2); - ML::HDBSCAN::detail::Reachability::mutual_reachability_graph( + + cuvs::neighbors::reachability::mutual_reachability_graph( handle, - data.data(), - (size_t)params.n_row, - (size_t)params.n_col, - raft::distance::DistanceType::L2SqrtExpanded, + raft::make_device_matrix_view(data.data(), params.n_row, params.n_col), params.min_samples + 1, - (float)1.0, - mutual_reachability_indptr.data(), - prediction_data_.get_core_dists(), - mutual_reachability_coo); + raft::make_device_vector_view(mutual_reachability_indptr.data(), params.n_row + 1), + raft::make_device_vector_view(core_dists.data(), params.n_row), + mutual_reachability_coo, + cuvs::distance::DistanceType::L2SqrtExpanded, + 1.0); transformLabels(handle, labels.data(), label_map.data(), params.n_row); @@ -754,7 +753,7 @@ class MembershipVectorTest : public ::testing::TestWithParam #include -#include #include #include #include #include +#include #include #include @@ -92,7 +92,7 @@ class LinkageTest : public ::testing::TestWithParam> { params.n_row, params.n_col, &out_arrs, - raft::distance::DistanceType::L2Unexpanded, + cuvs::distance::DistanceType::L2Unexpanded, params.c, params.n_clusters); } else { @@ -101,7 +101,7 @@ class LinkageTest : public ::testing::TestWithParam> { params.n_row, params.n_col, &out_arrs, - raft::distance::DistanceType::L2Expanded, + cuvs::distance::DistanceType::L2Expanded, params.n_clusters); } diff --git a/cpp/test/sg/rproj_test.cu b/cpp/test/sg/rproj_test.cu index da02b9007f..cd6d0632d2 100644 --- a/cpp/test/sg/rproj_test.cu +++ b/cpp/test/sg/rproj_test.cu @@ -147,7 +147,7 @@ class RPROJTest : public ::testing::Test { void epsilon_check() { int D = johnson_lindenstrauss_min_dim(N, epsilon); - constexpr auto distance_type = raft::distance::DistanceType::L2SqrtUnexpanded; + constexpr 
auto distance_type = cuvs::distance::DistanceType::L2SqrtUnexpanded; rmm::device_uvector d_pdist(N * N, stream); ML::Metrics::pairwise_distance( diff --git a/cpp/test/sg/trustworthiness_test.cu b/cpp/test/sg/trustworthiness_test.cu index 0675acb391..b6daf47d05 100644 --- a/cpp/test/sg/trustworthiness_test.cu +++ b/cpp/test/sg/trustworthiness_test.cu @@ -319,7 +319,7 @@ class TrustworthinessScoreTest : public ::testing::Test { raft::update_device(d_X_embedded.data(), X_embedded.data(), X_embedded.size(), stream); // euclidean test - score = trustworthiness_score( + score = trustworthiness_score( h, d_X.data(), d_X_embedded.data(), 50, 30, 8, 5); } diff --git a/cpp/test/sg/tsne_test.cu b/cpp/test/sg/tsne_test.cu index 89b5cb2688..f1e3d47703 100644 --- a/cpp/test/sg/tsne_test.cu +++ b/cpp/test/sg/tsne_test.cu @@ -19,13 +19,13 @@ #include #include -#include #include #include #include #include +#include #include #include #include @@ -116,7 +116,7 @@ class TSNETest : public ::testing::TestWithParam { auto stream = handle.get_stream(); TSNEResults results; - auto DEFAULT_DISTANCE_METRIC = raft::distance::DistanceType::L2SqrtExpanded; + auto DEFAULT_DISTANCE_METRIC = cuvs::distance::DistanceType::L2SqrtExpanded; float minkowski_p = 2.0; // Setup parameters @@ -164,7 +164,7 @@ class TSNETest : public ::testing::TestWithParam { n, n, model_params.dim, - raft::distance::DistanceType::L2Expanded, + cuvs::distance::DistanceType::L2Expanded, false); handle.sync_stream(stream); @@ -195,7 +195,7 @@ class TSNETest : public ::testing::TestWithParam { // Produce trustworthiness score results.trustworthiness = - trustworthiness_score( + trustworthiness_score( handle, X_d.data(), Y_d.data(), n, p, model_params.dim, 5); return results; diff --git a/cpp/test/sg/umap_parametrizable_test.cu b/cpp/test/sg/umap_parametrizable_test.cu index 2980a79394..5e477f104e 100644 --- a/cpp/test/sg/umap_parametrizable_test.cu +++ b/cpp/test/sg/umap_parametrizable_test.cu @@ -239,7 +239,7 @@ class 
UMAPParametrizableTest : public ::testing::Test { ASSERT_TRUE(!has_nan(embedding_ptr, n_samples * umap_params.n_components, stream)); double trustworthiness = - trustworthiness_score( + trustworthiness_score( handle, X, embedding_ptr, diff --git a/python/cuml/cuml/cluster/kmeans.pyx b/python/cuml/cuml/cluster/kmeans.pyx index c36765fe34..48ad769cd8 100644 --- a/python/cuml/cuml/cluster/kmeans.pyx +++ b/python/cuml/cuml/cluster/kmeans.pyx @@ -36,7 +36,6 @@ IF GPUBUILD == 1: from cuml.metrics.distance_type cimport DistanceType from cuml.cluster.kmeans_utils cimport params as KMeansParams from cuml.cluster.kmeans_utils cimport KMeansPlusPlus, Random, Array - from cuml.cluster.kmeans_utils cimport DistanceType as CuvsDistanceType from cuml.internals.array import CumlArray from cuml.common.array_descriptor import CumlArrayDescriptor @@ -208,7 +207,7 @@ class KMeans(UniversalBase, params.tol = self.tol params.verbosity = self.verbose params.rng_state.seed = self.random_state - params.metric = CuvsDistanceType.L2Expanded # distance metric as squared L2: @todo - support other metrics # noqa: E501 + params.metric = DistanceType.L2Expanded # distance metric as squared L2: @todo - support other metrics # noqa: E501 params.batch_samples = self.max_samples_per_batch params.oversampling_factor = self.oversampling_factor params.n_init = self.n_init @@ -611,7 +610,7 @@ class KMeans(UniversalBase, cdef KMeansParams* params = \ self._get_kmeans_params() - params.metric = CuvsDistanceType.L2Expanded + params.metric = DistanceType.L2Expanded int_dtype = np.int32 if self.labels_.dtype == np.int32 else np.int64 diff --git a/python/cuml/cuml/cluster/kmeans_utils.pxd b/python/cuml/cuml/cluster/kmeans_utils.pxd index 17d58a49be..c4c10c4b74 100644 --- a/python/cuml/cuml/cluster/kmeans_utils.pxd +++ b/python/cuml/cuml/cluster/kmeans_utils.pxd @@ -18,30 +18,7 @@ import ctypes from libcpp cimport bool from cuml.common.rng_state cimport RngState - -cdef extern from "cuvs/distance/distance.hpp" 
namespace \ - "cuvs::distance": - ctypedef enum DistanceType: - L2Expanded "cuvs::distance::DistanceType::L2Expanded" - L2SqrtExpanded "cuvs::distance::DistanceType::L2SqrtExpanded" - CosineExpanded "cuvs::distance::DistanceType::CosineExpanded" - L1 "cuvs::distance::DistanceType::L1" - L2Unexpanded "cuvs::distance::DistanceType::L2Unexpanded" - L2SqrtUnexpanded "cuvs::distance::DistanceType::L2SqrtUnexpanded" - InnerProduct "cuvs::distance::DistanceType::InnerProduct" - Linf "cuvs::distance::DistanceType::Linf" - Canberra "cuvs::distance::DistanceType::Canberra" - LpUnexpanded "cuvs::distance::DistanceType::LpUnexpanded" - CorrelationExpanded "cuvs::distance::DistanceType::CorrelationExpanded" - JaccardExpanded "cuvs::distance::DistanceType::JaccardExpanded" - HellingerExpanded "cuvs::distance::DistanceType::HellingerExpanded" - Haversine "cuvs::distance::DistanceType::Haversine" - BrayCurtis "cuvs::distance::DistanceType::BrayCurtis" - JensenShannon "cuvs::distance::DistanceType::JensenShannon" - HammingUnexpanded "cuvs::distance::DistanceType::HammingUnexpanded" - KLDivergence "cuvs::distance::DistanceType::KLDivergence" - RusselRaoExpanded "cuvs::distance::DistanceType::RusselRaoExpanded" - DiceExpanded "cuvs::distance::DistanceType::DiceExpanded" +from cuml.metrics.distance_type cimport DistanceType cdef extern from "cuml/cluster/kmeans.hpp" namespace \ "cuvs::cluster::kmeans::params": diff --git a/python/cuml/cuml/metrics/distance_type.pxd b/python/cuml/cuml/metrics/distance_type.pxd index bd9e92e2ea..f8e2261092 100644 --- a/python/cuml/cuml/metrics/distance_type.pxd +++ b/python/cuml/cuml/metrics/distance_type.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# Copyright (c) 2021-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,27 +14,27 @@ # limitations under the License. 
# -cdef extern from "raft/distance/distance_types.hpp" namespace "raft::distance": +cdef extern from "cuvs/distance/distance.hpp" namespace "cuvs::distance": ctypedef enum DistanceType: - L2Expanded "raft::distance::DistanceType::L2Expanded" - L2SqrtExpanded "raft::distance::DistanceType::L2SqrtExpanded" - CosineExpanded "raft::distance::DistanceType::CosineExpanded" - L1 "raft::distance::DistanceType::L1" - L2Unexpanded "raft::distance::DistanceType::L2Unexpanded" - L2SqrtUnexpanded "raft::distance::DistanceType::L2SqrtUnexpanded" - InnerProduct "raft::distance::DistanceType::InnerProduct" - Linf "raft::distance::DistanceType::Linf" - Canberra "raft::distance::DistanceType::Canberra" - LpUnexpanded "raft::distance::DistanceType::LpUnexpanded" - CorrelationExpanded "raft::distance::DistanceType::CorrelationExpanded" - JaccardExpanded "raft::distance::DistanceType::JaccardExpanded" - HellingerExpanded "raft::distance::DistanceType::HellingerExpanded" - Haversine "raft::distance::DistanceType::Haversine" - BrayCurtis "raft::distance::DistanceType::BrayCurtis" - JensenShannon "raft::distance::DistanceType::JensenShannon" - HammingUnexpanded "raft::distance::DistanceType::HammingUnexpanded" - KLDivergence "raft::distance::DistanceType::KLDivergence" - RusselRaoExpanded "raft::distance::DistanceType::RusselRaoExpanded" - DiceExpanded "raft::distance::DistanceType::DiceExpanded" - Precomputed "raft::distance::DistanceType::Precomputed" + L2Expanded "cuvs::distance::DistanceType::L2Expanded" + L2SqrtExpanded "cuvs::distance::DistanceType::L2SqrtExpanded" + CosineExpanded "cuvs::distance::DistanceType::CosineExpanded" + L1 "cuvs::distance::DistanceType::L1" + L2Unexpanded "cuvs::distance::DistanceType::L2Unexpanded" + L2SqrtUnexpanded "cuvs::distance::DistanceType::L2SqrtUnexpanded" + InnerProduct "cuvs::distance::DistanceType::InnerProduct" + Linf "cuvs::distance::DistanceType::Linf" + Canberra "cuvs::distance::DistanceType::Canberra" + LpUnexpanded 
"cuvs::distance::DistanceType::LpUnexpanded" + CorrelationExpanded "cuvs::distance::DistanceType::CorrelationExpanded" + JaccardExpanded "cuvs::distance::DistanceType::JaccardExpanded" + HellingerExpanded "cuvs::distance::DistanceType::HellingerExpanded" + Haversine "cuvs::distance::DistanceType::Haversine" + BrayCurtis "cuvs::distance::DistanceType::BrayCurtis" + JensenShannon "cuvs::distance::DistanceType::JensenShannon" + HammingUnexpanded "cuvs::distance::DistanceType::HammingUnexpanded" + KLDivergence "cuvs::distance::DistanceType::KLDivergence" + RusselRaoExpanded "cuvs::distance::DistanceType::RusselRaoExpanded" + DiceExpanded "cuvs::distance::DistanceType::DiceExpanded" + Precomputed "cuvs::distance::DistanceType::Precomputed" diff --git a/python/cuml/cuml/metrics/raft_distance_type.pxd b/python/cuml/cuml/metrics/raft_distance_type.pxd new file mode 100644 index 0000000000..bd9e92e2ea --- /dev/null +++ b/python/cuml/cuml/metrics/raft_distance_type.pxd @@ -0,0 +1,40 @@ +# +# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +cdef extern from "raft/distance/distance_types.hpp" namespace "raft::distance": + + ctypedef enum DistanceType: + L2Expanded "raft::distance::DistanceType::L2Expanded" + L2SqrtExpanded "raft::distance::DistanceType::L2SqrtExpanded" + CosineExpanded "raft::distance::DistanceType::CosineExpanded" + L1 "raft::distance::DistanceType::L1" + L2Unexpanded "raft::distance::DistanceType::L2Unexpanded" + L2SqrtUnexpanded "raft::distance::DistanceType::L2SqrtUnexpanded" + InnerProduct "raft::distance::DistanceType::InnerProduct" + Linf "raft::distance::DistanceType::Linf" + Canberra "raft::distance::DistanceType::Canberra" + LpUnexpanded "raft::distance::DistanceType::LpUnexpanded" + CorrelationExpanded "raft::distance::DistanceType::CorrelationExpanded" + JaccardExpanded "raft::distance::DistanceType::JaccardExpanded" + HellingerExpanded "raft::distance::DistanceType::HellingerExpanded" + Haversine "raft::distance::DistanceType::Haversine" + BrayCurtis "raft::distance::DistanceType::BrayCurtis" + JensenShannon "raft::distance::DistanceType::JensenShannon" + HammingUnexpanded "raft::distance::DistanceType::HammingUnexpanded" + KLDivergence "raft::distance::DistanceType::KLDivergence" + RusselRaoExpanded "raft::distance::DistanceType::RusselRaoExpanded" + DiceExpanded "raft::distance::DistanceType::DiceExpanded" + Precomputed "raft::distance::DistanceType::Precomputed" diff --git a/python/cuml/cuml/metrics/trustworthiness.pyx b/python/cuml/cuml/metrics/trustworthiness.pyx index 6db6e8fb65..c07dcff951 100644 --- a/python/cuml/cuml/metrics/trustworthiness.pyx +++ b/python/cuml/cuml/metrics/trustworthiness.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2018-2023, NVIDIA CORPORATION. +# Copyright (c) 2018-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -28,10 +28,10 @@ from cuml.internals.input_utils import input_to_cuml_array from pylibraft.common.handle import Handle from pylibraft.common.handle cimport handle_t -cdef extern from "raft/distance/distance_types.hpp" namespace "raft::distance": +cdef extern from "cuvs/distance/distance.hpp" namespace "cuvs::distance": ctypedef int DistanceType - ctypedef DistanceType euclidean "(raft::distance::DistanceType)5" + ctypedef DistanceType euclidean "(cuvs::distance::DistanceType)5" cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics": diff --git a/python/cuml/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/cuml/neighbors/nearest_neighbors.pyx index 186142ab63..4f551d282c 100644 --- a/python/cuml/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/cuml/neighbors/nearest_neighbors.pyx @@ -38,6 +38,7 @@ from cuml.common import input_to_cuml_array from cuml.common.sparse_utils import is_sparse from cuml.common.sparse_utils import is_dense from cuml.metrics.distance_type cimport DistanceType +from cuml.metrics.raft_distance_type cimport DistanceType as RaftDistanceType from cuml.internals.api_decorators import device_interop_preparation from cuml.internals.api_decorators import enable_device_interop @@ -64,7 +65,7 @@ IF GPUBUILD == 1: float *X, uint32_t n_rows, uint32_t n_cols, - DistanceType metric) except + + RaftDistanceType metric) except + cdef extern from "cuml/neighbors/knn.hpp" namespace "ML": void brute_force_knn( @@ -426,7 +427,7 @@ class NearestNeighbors(UniversalBase, rbc_index = new BallCoverIndex[int64_t, float, uint32_t]( handle_[0], self._fit_X.ptr, self.n_samples_fit_, self.n_features_in_, - metric) + metric) rbc_build_index(handle_[0], deref(rbc_index)) self.knn_index = rbc_index