Skip to content

Commit

Permalink
kahan_sum
Browse files Browse the repository at this point in the history
  • Loading branch information
zclllyybb committed Nov 12, 2024
1 parent 7296cfd commit 3ae6e44
Showing 1 changed file with 34 additions and 7 deletions.
41 changes: 34 additions & 7 deletions be/src/vec/aggregate_functions/aggregate_function_covar.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,27 +62,36 @@ namespace doris::vectorized {

template <typename T>
struct BaseData {
BaseData() : sum_x(0.0), sum_y(0.0), sum_xy(0.0), count(0) {}
BaseData() = default;
virtual ~BaseData() = default;

// Serializes the accumulator state into `buf`.
// Each running sum is written immediately followed by its compensation term, then the
// row count; read() consumes the fields in this exact order.
void write(BufferWritable& buf) const {
    const auto emit = [&buf](const auto& field) { write_binary(field, buf); };

    emit(sum_x);
    emit(sum_x_com);
    emit(sum_y);
    emit(sum_y_com);
    emit(sum_xy);
    emit(sum_xy_com);
    emit(count);
}

// Restores the accumulator state from `buf`.
// Fields are consumed in the same order write() produces them: each running sum followed
// by its compensation term, then the row count.
void read(BufferReadable& buf) {
    const auto load = [&buf](auto& field) { read_binary(field, buf); };

    load(sum_x);
    load(sum_x_com);
    load(sum_y);
    load(sum_y_com);
    load(sum_xy);
    load(sum_xy_com);
    load(count);
}

// Returns the accumulator to its empty state: all running sums, all compensation terms
// and the row count become zero.
void reset() {
    sum_x = sum_y = sum_xy = 0.0;
    sum_x_com = sum_y_com = sum_xy_com = 0;
    count = 0;
}

Expand All @@ -96,6 +105,21 @@ struct BaseData {
return val;
};

// One step of compensated (Kahan) summation: adds `src_data` to `dest_data`.
//
// `dest_com` carries the rounding excess of the previous step (stored sum minus exact sum).
// It is subtracted from the incoming value first, and afterwards replaced by the excess
// produced by this step.
void kahan_sum(long double src_data, long double& dest_data, long double& dest_com) {
    // Incoming value with the previous step's rounding excess removed.
    const long double adjusted = src_data - dest_com;
    const long double updated = dest_data + adjusted;
    // (updated - dest_data) is what was actually added; the difference to `adjusted`
    // is the part lost (or gained) to rounding in this step.
    dest_com = (updated - dest_data) - adjusted;
    dest_data = updated;
}

// Merges another compensated partial sum (`src_data`, `src_com`) into
// (`dest_data`, `dest_com`).
//
// Convention (same as kahan_sum): a compensation term holds the rounding *excess* of its
// sum, i.e. the exact value is approximately `data - com`. The merged state therefore has
// to subtract the combined compensation from the raw sum, not add it, and must not keep
// the already-applied compensation around, otherwise the next kahan_sum() call would
// apply it a second time.
//
// @param src_data   running sum of the state being merged in
// @param src_com    compensation term belonging to `src_data`
// @param dest_data  running sum of this state; receives the merged, corrected sum
// @param dest_com   compensation term of this state; receives the rounding excess left
//                   over after the correction was folded into `dest_data`
void kahan_merge(long double src_data, long double src_com, long double& dest_data,
                 long double& dest_com) {
    const long double raw_sum = dest_data + src_data;

    // Branch-free TwoSum: recovers the rounding excess of `raw_sum` exactly, regardless of
    // which operand has the larger magnitude (partial sums can be merged in any order).
    const long double src_part = raw_sum - dest_data;
    const long double dest_part = raw_sum - src_part;
    const long double raw_excess = (src_part - src_data) + (dest_part - dest_data);

    // Total amount by which `raw_sum` overstates the exact merged value.
    const long double pending = src_com + dest_com + raw_excess;

    // Fold the correction into the sum and keep only what this last subtraction lost.
    const long double corrected = raw_sum - pending;
    dest_com = (corrected - raw_sum) + pending;
    dest_data = corrected;
}

// Cov(X, Y) = E(XY) - E(X)E(Y)
double get_pop_result() const {
if (count == 1) {
Expand All @@ -112,9 +136,9 @@ struct BaseData {
if (rhs.count == 0) {
return;
}
sum_x += rhs.sum_x;
sum_y += rhs.sum_y;
sum_xy += rhs.sum_xy;
kahan_merge(rhs.sum_x, rhs.sum_x_com, sum_x, sum_x_com);
kahan_merge(rhs.sum_y, rhs.sum_y_com, sum_y, sum_y_com);
kahan_merge(rhs.sum_xy, rhs.sum_xy_com, sum_xy, sum_xy_com);
count += rhs.count;
}

Expand All @@ -126,15 +150,18 @@ struct BaseData {
assert_cast<const ColumnVector<T>&, TypeCheckOnRelease::DISABLE>(*column_y);
double source_data_y = sources_y.get_data()[row_num];

sum_x += source_data_x;
sum_y += source_data_y;
sum_xy += source_data_x * source_data_y;
// Accumulate every running total together with its own compensation term.
// The x*y product belongs to sum_xy/sum_xy_com: routing it into sum_y would corrupt
// sum_y and leave sum_xy without this row's contribution.
kahan_sum(source_data_x, sum_x, sum_x_com);
kahan_sum(source_data_y, sum_y, sum_y_com);
kahan_sum(source_data_x * source_data_y, sum_xy, sum_xy_com);
count += 1;
}

// Accumulator state. Every running sum is paired with a `_com` compensation term that is
// updated through kahan_sum()/kahan_merge(); all fields are serialized by write()/read().

// Running sum of the x inputs.
long double sum_x {};
// Compensation term paired with sum_x.
long double sum_x_com {};
// Running sum of the y inputs.
long double sum_y {};
// Compensation term paired with sum_y.
long double sum_y_com {};
// Running sum for the x*y products (the E(XY) part of the covariance formula).
long double sum_xy {};
// Compensation term paired with sum_xy.
long double sum_xy_com {};
// Number of rows accumulated: incremented by one per added row, summed when merging.
int64_t count {};
};

Expand Down

0 comments on commit 3ae6e44

Please sign in to comment.