
Commit

Optimize MG variance calculation for dataset standardization for logistic regression (#6138)

MG variance calculation currently invokes the raft SG vars API. However, the abs() step of the raft SG vars API introduces errors for skewed data distributions (e.g., when one GPU gets small values 1 and 2 and the other GPU gets large values 98 and 99).

The PR avoids the effect of abs() when invoking SG vars to compute the MG variance. The key idea is to pass a vector of zeros as the mean when calling SG vars, and to subtract the contribution of the global mean afterwards.
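A minimal host-only sketch of that idea (not the cuML/RAFT code; the two partitions, variable names, and printed output are illustrative) shows that per-partition sums of squares computed against a zero mean, summed across partitions and then corrected by the global-mean term, reproduce the sample variance of the skewed example above:

#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
  // Two "GPUs" holding the skewed partitions from the commit message.
  const std::vector<std::vector<double>> partitions = {{1.0, 2.0}, {98.0, 99.0}};

  // Global sample count and mean (in the MG code these come from an allreduce).
  std::size_t n_samples = 0;
  double sum            = 0.0;
  for (const auto& p : partitions) {
    n_samples += p.size();
    for (double x : p) sum += x;
  }
  const double mean = sum / static_cast<double>(n_samples);

  // Step 1: each partition computes its sum of squares, i.e. the SG vars call
  // with a zero mean vector, scaled by 1 / (n_samples - 1).
  // Step 2: sum the partial results (the allreduce).
  double var = 0.0;
  for (const auto& p : partitions) {
    double sum_sq = 0.0;
    for (double x : p) sum_sq += x * x;
    var += sum_sq / static_cast<double>(n_samples - 1);
  }

  // Step 3: subtract the mean contribution, n / (n - 1) * mean^2, keeping the
  // result non-negative to guard against floating-point cancellation.
  const double scaled_mean_sq =
    static_cast<double>(n_samples) / static_cast<double>(n_samples - 1) * mean * mean;
  const double diff = var - scaled_mean_sq;
  var               = diff >= 0.0 ? diff : var;

  // Reference: direct sample variance over all values.
  double ref = 0.0;
  for (const auto& p : partitions)
    for (double x : p) ref += (x - mean) * (x - mean);
  ref /= static_cast<double>(n_samples - 1);

  std::printf("two-step: %f  direct: %f\n", var, ref);  // both ~3136.666667
  return 0;
}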

Authors:
  - Jinfeng Li (https://github.com/lijinf2)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #6138
lijinf2 authored Dec 5, 2024
1 parent 7e6dbc0 commit de96f3a
Showing 1 changed file with 54 additions and 17 deletions.
71 changes: 54 additions & 17 deletions cpp/src/glm/qn/mg/standardization.cuh
@@ -40,6 +40,59 @@ namespace ML {
namespace GLM {
namespace opg {

/**
* @brief Compute variance of the input matrix across all GPUs
*
* Variance operation is assumed to be performed on a given column.
*
* @tparam T the data type
* @param handle the internal cuml handle object
* @param X the input dense matrix
* @param n_samples number of rows of data across all GPUs
* @param mean_vector_all_samples the mean vector of rows of data across all GPUs
* @param var_vector the output variance vector
*/
template <typename T>
void vars(const raft::handle_t& handle,
const SimpleDenseMat<T>& X,
size_t n_samples,
T* mean_vector_all_samples,
T* var_vector)
{
const T* input_data = X.data;
int D = X.n;
int num_rows = X.m;
bool col_major = (X.ord == COL_MAJOR);
auto stream = handle.get_stream();
auto& comm = handle.get_comms();

rmm::device_uvector<T> zero(D, handle.get_stream());
SimpleVec<T> zero_vec(zero.data(), D);
zero_vec.fill(0., stream);

// get sum of squares on every column
raft::stats::vars(var_vector, input_data, zero.data(), D, num_rows, false, !col_major, stream);
T weight = n_samples < 1 ? T(0) : T(1) * num_rows / T(n_samples - 1);
raft::linalg::multiplyScalar(var_vector, var_vector, weight, D, stream);
comm.allreduce(var_vector, var_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);
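
// At this point every rank's var_vector holds sum(x^2) / (n_samples - 1)
// accumulated over all GPUs; subtracting n_samples / (n_samples - 1) * mean^2
// below yields the sample variance.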

// subtract mean
weight = n_samples <= 1 ? T(1) : T(n_samples) / T(n_samples - 1);
raft::linalg::binaryOp(
var_vector,
var_vector,
mean_vector_all_samples,
D,
[weight] __device__(const T v, const T m) {
T scaled_m = weight * m * m;
T diff = v - scaled_m;
// avoid negative variance that is due to precision loss of floating point arithmetic
return diff >= 0. ? diff : v;
},
stream);
}

template <typename T>
void mean_stddev(const raft::handle_t& handle,
const SimpleDenseMat<T>& X,
@@ -60,23 +113,7 @@ void mean_stddev(const raft::handle_t& handle,
comm.allreduce(mean_vector, mean_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);

raft::stats::vars(stddev_vector, input_data, mean_vector, D, num_rows, false, !col_major, stream);
weight = n_samples < 1 ? T(0) : T(1) * num_rows / T(n_samples - 1);
raft::linalg::multiplyScalar(stddev_vector, stddev_vector, weight, D, stream);
comm.allreduce(stddev_vector, stddev_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);

// avoid negative variance that is due to precision loss of floating point arithmetic
weight = n_samples < 1 ? T(0) : T(1) / T(n_samples - 1);
weight = n_samples * weight;
auto no_neg_op = [weight] __device__(const T a, const T b) -> T {
if (a >= 0) return a;

return a + weight * b * b;
};

raft::linalg::binaryOp(stddev_vector, stddev_vector, mean_vector, D, no_neg_op, stream);

vars<T>(handle, X, n_samples, mean_vector, stddev_vector);
raft::linalg::sqrt(stddev_vector, stddev_vector, D, handle.get_stream());
}

