Skip to content

Commit

Permalink
Merge pull request #6160 from rapidsai/branch-24.12
Browse files Browse the repository at this point in the history
Forward-merge branch-24.12 into branch-25.02
  • Loading branch information
GPUtester authored Dec 5, 2024
2 parents b1278ab + de96f3a commit ba482f2
Showing 1 changed file with 54 additions and 17 deletions.
71 changes: 54 additions & 17 deletions cpp/src/glm/qn/mg/standardization.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,59 @@ namespace ML {
namespace GLM {
namespace opg {

/**
* @brief Compute variance of the input matrix across all GPUs
*
* Variance operation is assumed to be performed on a given column.
*
* @tparam T the data type
* @param handle the internal cuml handle object
* @param X the input dense matrix
* @param n_samples number of rows of data across all GPUs
* @param mean_vector_all_samples the mean vector of rows of data across all GPUs
* @param var_vector the output variance vector
*/
template <typename T>
void vars(const raft::handle_t& handle,
const SimpleDenseMat<T>& X,
size_t n_samples,
T* mean_vector_all_samples,
T* var_vector)
{
const T* input_data = X.data;
int D = X.n;
int num_rows = X.m;
bool col_major = (X.ord == COL_MAJOR);
auto stream = handle.get_stream();
auto& comm = handle.get_comms();

rmm::device_uvector<T> zero(D, handle.get_stream());
SimpleVec<T> zero_vec(zero.data(), D);
zero_vec.fill(0., stream);

// get sum of squares on every column
raft::stats::vars(var_vector, input_data, zero.data(), D, num_rows, false, !col_major, stream);
T weight = n_samples < 1 ? T(0) : T(1) * num_rows / T(n_samples - 1);
raft::linalg::multiplyScalar(var_vector, var_vector, weight, D, stream);
comm.allreduce(var_vector, var_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);

// subtract mean
weight = n_samples <= 1 ? T(1) : T(n_samples) / T(n_samples - 1);
raft::linalg::binaryOp(
var_vector,
var_vector,
mean_vector_all_samples,
D,
[weight] __device__(const T v, const T m) {
T scaled_m = weight * m * m;
T diff = v - scaled_m;
// avoid negative variance that is due to precision loss of floating point arithmetic
return diff >= 0. ? diff : v;
},
stream);
}

template <typename T>
void mean_stddev(const raft::handle_t& handle,
const SimpleDenseMat<T>& X,
Expand All @@ -60,23 +113,7 @@ void mean_stddev(const raft::handle_t& handle,
comm.allreduce(mean_vector, mean_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);

raft::stats::vars(stddev_vector, input_data, mean_vector, D, num_rows, false, !col_major, stream);
weight = n_samples < 1 ? T(0) : T(1) * num_rows / T(n_samples - 1);
raft::linalg::multiplyScalar(stddev_vector, stddev_vector, weight, D, stream);
comm.allreduce(stddev_vector, stddev_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);

// avoid negative variance that is due to precision loss of floating point arithmetic
weight = n_samples < 1 ? T(0) : T(1) / T(n_samples - 1);
weight = n_samples * weight;
auto no_neg_op = [weight] __device__(const T a, const T b) -> T {
if (a >= 0) return a;

return a + weight * b * b;
};

raft::linalg::binaryOp(stddev_vector, stddev_vector, mean_vector, D, no_neg_op, stream);

vars<T>(handle, X, n_samples, mean_vector, stddev_vector);
raft::linalg::sqrt(stddev_vector, stddev_vector, D, handle.get_stream());
}

Expand Down

0 comments on commit ba482f2

Please sign in to comment.