Add zero-checks to axpy-like operations #1573

Status: Draft. Wants to merge 1 commit into base: develop.
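The change follows the BLAS convention for advanced apply (y = alpha * op(A) * x + beta * y): when beta is zero, the output is overwritten rather than scaled, so NaN or uninitialized values already stored in the output cannot leak into the result. A minimal standalone sketch of that convention (illustrative only, not Ginkgo's API):

#include <cstddef>
#include <vector>

// Illustrative axpby following the convention this PR enforces:
// beta == 0 means "ignore the previous contents of y entirely",
// so a NaN already sitting in y[i] is discarded instead of propagated.
template <typename T>
void axpby(T alpha, const std::vector<T>& x, T beta, std::vector<T>& y)
{
    for (std::size_t i = 0; i < y.size(); ++i) {
        y[i] = (beta == T{}) ? alpha * x[i] : alpha * x[i] + beta * y[i];
    }
}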
7 changes: 5 additions & 2 deletions common/cuda_hip/matrix/csr_kernels.hpp.inc
@@ -380,7 +380,9 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_merge_path_spmv(
    merge_path_spmv<items_per_thread>(
        num_rows, val, col_idxs, row_ptrs, srow, b, c, row_out, val_out,
        [&alpha_val](const type& x) { return alpha_val * x; },
-       [&beta_val](const type& x) { return beta_val * x; });
+       [&beta_val](const type& x) {
+           return is_zero(beta_val) ? zero(beta_val) : beta_val * x;
+       });
Comment on lines +383 to +385 (Member):

Minor performance comment: you could try switching the lambda and the zero check, i.e.

is_zero(beta_val)
    ? [&beta_val](const type& x) { return zero(beta_val); }
    : [&beta_val](const type& x) { return beta_val * x; }

But this might not work, since the two branches of the ?: operator have different types. It might also increase compile times, since it might compile the kernel twice.
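For illustration, a minimal self-contained sketch of the proposed dispatch (not Ginkgo code; apply_scale is a hypothetical stand-in for the kernel launch): the branch moves to the host, the per-element functor stays branch-free, and each functor type instantiates the kernel separately.

#include <cmath>
#include <iostream>

// Hypothetical stand-in for a templated kernel launch: each distinct
// functor type instantiates (and compiles) this "kernel" separately.
template <typename ScaleFn>
void apply_scale(double* data, int n, ScaleFn scale)
{
    for (int i = 0; i < n; ++i) {
        data[i] = scale(data[i]);
    }
}

int main()
{
    double c[] = {1.0, std::nan(""), 3.0};
    const double beta = 0.0;
    // Branch on the host instead of inside the functor; the two lambdas
    // have distinct types, so they cannot share one ?: expression.
    if (beta == 0.0) {
        apply_scale(c, 3, [](double) { return 0.0; });
    } else {
        apply_scale(c, 3, [beta](double x) { return beta * x; });
    }
    std::cout << c[1] << '\n';  // prints 0: the stale NaN was overwritten
}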

}


@@ -480,7 +482,8 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_classical_spmv(
    device_classical_spmv<subwarp_size>(
        num_rows, val, col_idxs, row_ptrs, b, c,
        [&alpha_val, &beta_val](const type& x, const type& y) {
-           return alpha_val * x + beta_val * y;
+           return is_zero(beta_val) ? alpha_val * x
+                                    : alpha_val * x + beta_val * y;
        });
}

5 changes: 4 additions & 1 deletion common/cuda_hip/matrix/ell_kernels.hpp.inc
@@ -124,7 +124,10 @@ __global__ __launch_bounds__(default_block_size) void spmv(
        num_stored_elements_per_row, b, c, c_stride,
        [&alpha_val, &beta_val](const auto& x, const OutputValueType& y) {
            return static_cast<OutputValueType>(
-               alpha_val * x + static_cast<arithmetic_type>(beta_val * y));
+               is_zero(beta_val)
+                   ? alpha_val * x
+                   : alpha_val * x +
+                         static_cast<arithmetic_type>(beta_val * y));
        });
    }
}
4 changes: 3 additions & 1 deletion common/cuda_hip/matrix/sellp_kernels.hpp.inc
@@ -51,7 +51,9 @@ __global__ __launch_bounds__(default_block_size) void advanced_spmv_kernel(
            }
        }
        c[row * c_stride + column_id] =
-           beta[0] * c[row * c_stride + column_id] + alpha[0] * val;
+           is_zero(beta[0])
+               ? alpha[0] * val
+               : beta[0] * c[row * c_stride + column_id] + alpha[0] * val;
    }
}

3 changes: 2 additions & 1 deletion common/cuda_hip/matrix/sparsity_csr_kernels.hpp.inc
@@ -74,7 +74,8 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_classical_spmv(
    device_classical_spmv<subwarp_size>(
        num_rows, val, col_idxs, row_ptrs, b, c,
        [&alpha_val, &beta_val](const type& x, const type& y) {
-           return alpha_val * x + beta_val * y;
+           return is_zero(beta_val) ? alpha_val * x
+                                    : alpha_val * x + beta_val * y;
        });
}

28 changes: 22 additions & 6 deletions common/unified/matrix/dense_kernels.template.cpp
@@ -77,14 +77,22 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
        run_kernel(
            exec,
            [] GKO_KERNEL(auto row, auto col, auto alpha, auto x) {
-               x(row, col) *= alpha[col];
+               if (is_zero(alpha[col])) {
+                   x(row, col) = zero(alpha[col]);
+               } else {
+                   x(row, col) *= alpha[col];
+               }
            },
            x->get_size(), alpha->get_const_values(), x);
    } else {
        run_kernel(
            exec,
            [] GKO_KERNEL(auto row, auto col, auto alpha, auto x) {
-               x(row, col) *= alpha[0];
+               if (is_zero(alpha[0])) {
+                   x(row, col) = zero(alpha[0]);
+               } else {
+                   x(row, col) *= alpha[0];
+               }
            },
            x->get_size(), alpha->get_const_values(), x);
    }
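The diff leans on Ginkgo's zero/is_zero/is_nonzero helpers throughout. A rough sketch of the semantics assumed here (the real definitions live in Ginkgo's math utilities and may differ in detail):

// Sketch of the assumed helper semantics; not the actual Ginkgo code.
template <typename T>
constexpr T zero()
{
    return T{};
}

// Overload deducing T from an argument, so zero(alpha[0]) works without
// spelling out the value type.
template <typename T>
constexpr T zero(const T&)
{
    return zero<T>();
}

template <typename T>
constexpr bool is_zero(T value)
{
    return value == zero<T>();
}

template <typename T>
constexpr bool is_nonzero(T value)
{
    return !is_zero(value);
}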
@@ -130,7 +138,9 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
        run_kernel(
            exec,
            [] GKO_KERNEL(auto row, auto col, auto alpha, auto x, auto y) {
-               y(row, col) += alpha[0] * x(row, col);
+               if (is_nonzero(alpha[0])) {
+                   y(row, col) += alpha[0] * x(row, col);
+               }
            },
            x->get_size(), alpha->get_const_values(), x, y);
}

@@ -153,7 +163,9 @@ void sub_scaled(std::shared_ptr<const DefaultExecutor> exec,
        run_kernel(
            exec,
            [] GKO_KERNEL(auto row, auto col, auto alpha, auto x, auto y) {
-               y(row, col) -= alpha[0] * x(row, col);
+               if (is_nonzero(alpha[0])) {
+                   y(row, col) -= alpha[0] * x(row, col);
+               }
            },
            x->get_size(), alpha->get_const_values(), x, y);
}

@@ -170,7 +182,9 @@ void add_scaled_diag(std::shared_ptr<const DefaultExecutor> exec,
    run_kernel(
        exec,
        [] GKO_KERNEL(auto i, auto alpha, auto diag, auto y) {
-           y(i, i) += alpha[0] * diag[i];
+           if (is_nonzero(alpha[0])) {
+               y(i, i) += alpha[0] * diag[i];
+           }
        },
        x->get_size()[0], alpha->get_const_values(), x->get_const_values(), y);
}

@@ -186,7 +200,9 @@ void sub_scaled_diag(std::shared_ptr<const DefaultExecutor> exec,
    run_kernel(
        exec,
        [] GKO_KERNEL(auto i, auto alpha, auto diag, auto y) {
-           y(i, i) -= alpha[0] * diag[i];
+           if (is_nonzero(alpha[0])) {
+               y(i, i) -= alpha[0] * diag[i];
+           }
        },
        x->get_size()[0], alpha->get_const_values(), x->get_const_values(), y);
}
9 changes: 6 additions & 3 deletions dpcpp/matrix/csr_kernels.dp.cpp
@@ -490,8 +490,10 @@ void abstract_merge_path_spmv(
    merge_path_spmv<items_per_thread>(
        num_rows, val, col_idxs, row_ptrs, srow, b, c, row_out, val_out,
        [&alpha_val](const type& x) { return alpha_val * x; },
-       [&beta_val](const type& x) { return beta_val * x; }, item_ct1,
-       shared_row_ptrs);
+       [&beta_val](const type& x) {
+           return is_zero(beta_val) ? zero(beta_val) : beta_val * x;
+       },
+       item_ct1, shared_row_ptrs);
}

template <int items_per_thread, typename matrix_accessor,
@@ -713,7 +715,8 @@ void abstract_classical_spmv(
    device_classical_spmv<subgroup_size>(
        num_rows, val, col_idxs, row_ptrs, b, c,
        [&alpha_val, &beta_val](const type& x, const type& y) {
-           return alpha_val * x + beta_val * y;
+           return is_zero(beta_val) ? alpha_val * x
+                                    : alpha_val * x + beta_val * y;
        },
        item_ct1);
}
5 changes: 4 additions & 1 deletion dpcpp/matrix/ell_kernels.dp.cpp
@@ -239,7 +239,10 @@ void spmv(
        num_stored_elements_per_row, b, c, c_stride,
        [&alpha_val, &beta_val](const auto& x, const OutputValueType& y) {
            return static_cast<OutputValueType>(
-               alpha_val * x + static_cast<arithmetic_type>(beta_val * y));
+               is_zero(beta_val)
+                   ? alpha_val * x
+                   : alpha_val * x +
+                         static_cast<arithmetic_type>(beta_val * y));
        },
        item_ct1, storage);
}
4 changes: 3 additions & 1 deletion dpcpp/matrix/sellp_kernels.dp.cpp
@@ -96,7 +96,9 @@ void advanced_spmv_kernel(size_type num_rows, size_type num_right_hand_sides,
            }
        }
        c[row * c_stride + column_id] =
-           beta[0] * c[row * c_stride + column_id] + alpha[0] * val;
+           is_zero(beta[0])
+               ? alpha[0] * val
+               : alpha[0] * val + beta[0] * c[row * c_stride + column_id];
    }
}

3 changes: 2 additions & 1 deletion dpcpp/matrix/sparsity_csr_kernels.dp.cpp
@@ -132,7 +132,8 @@ void abstract_classical_spmv(
    device_classical_spmv<subgroup_size>(
        num_rows, val, col_idxs, row_ptrs, b, c,
        [&alpha_val, &beta_val](const type& x, const type& y) {
-           return alpha_val * x + beta_val * y;
+           return is_zero(beta_val) ? alpha_val * x
+                                    : alpha_val * x + beta_val * y;
        },
        item_ct1);
}
2 changes: 1 addition & 1 deletion omp/matrix/csr_kernels.cpp
@@ -110,7 +110,7 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
#pragma omp parallel for
    for (size_type row = 0; row < a->get_size()[0]; ++row) {
        for (size_type j = 0; j < c->get_size()[1]; ++j) {
-           auto sum = c_vals(row, j) * vbeta;
+           auto sum = is_zero(vbeta) ? zero(vbeta) : c_vals(row, j) * vbeta;
            for (size_type k = row_ptrs[row];
                 k < static_cast<size_type>(row_ptrs[row + 1]); ++k) {
                arithmetic_type val = a_vals(k);
2 changes: 1 addition & 1 deletion omp/matrix/dense_kernels.cpp
@@ -124,7 +124,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
#pragma omp parallel for
    for (size_type row = 0; row < c->get_size()[0]; ++row) {
        for (size_type col = 0; col < c->get_size()[1]; ++col) {
-           c->at(row, col) *= zero<ValueType>();
+           c->at(row, col) = zero<ValueType>();
        }
    }
}
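This fix matters beyond style: IEEE 754 multiplication propagates NaN and infinity, so scaling by zero does not clear them. A standalone two-line demonstration (not Ginkgo code):

#include <cmath>
#include <iostream>

int main()
{
    double c = std::nan("");
    std::cout << c * 0.0 << '\n';  // nan: multiplying by zero keeps the NaN
    c = 0.0;
    std::cout << c << '\n';        // 0: assignment discards the old value
}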
4 changes: 3 additions & 1 deletion omp/matrix/ell_kernels.cpp
@@ -211,7 +211,9 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
    const auto alpha_val = arithmetic_type{alpha->at(0, 0)};
    const auto beta_val = arithmetic_type{beta->at(0, 0)};
    auto out = [&](auto i, auto j, auto value) {
-       return alpha_val * value + beta_val * arithmetic_type{c->at(i, j)};
+       return is_zero(beta_val) ? alpha_val * value
+                                : alpha_val * value +
+                                      beta_val * arithmetic_type{c->at(i, j)};
    };
    if (num_rhs == 1) {
        spmv_small_rhs<1>(exec, a, b, c, out);
6 changes: 5 additions & 1 deletion omp/matrix/fbcsr_kernels.cpp
@@ -105,7 +105,11 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
    for (IndexType ibrow = 0; ibrow < nbrows; ++ibrow) {
        for (IndexType row = ibrow * bs; row < (ibrow + 1) * bs; ++row) {
            for (IndexType rhs = 0; rhs < nvecs; rhs++) {
-               c->at(row, rhs) *= vbeta;
+               if (is_zero(vbeta)) {
+                   c->at(row, rhs) = zero(vbeta);
+               } else {
+                   c->at(row, rhs) *= vbeta;
+               }
            }
        }
        for (IndexType inz = row_ptrs[ibrow]; inz < row_ptrs[ibrow + 1];
3 changes: 2 additions & 1 deletion omp/matrix/sellp_kernels.cpp
@@ -176,7 +176,8 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
    const auto alpha_val = alpha->at(0, 0);
    const auto beta_val = beta->at(0, 0);
    auto out = [&](auto i, auto j, auto value) {
-       return alpha_val * value + beta_val * c->at(i, j);
+       return is_zero(beta_val) ? alpha_val * value
+                                : alpha_val * value + beta_val * c->at(i, j);
    };
    if (num_rhs == 1) {
        spmv_small_rhs<1>(exec, a, b, c, out);
4 changes: 3 additions & 1 deletion omp/matrix/sparsity_csr_kernels.cpp
@@ -93,7 +93,9 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
                    val * static_cast<arithmetic_type>(b->at(col_idxs[k], j));
            }
            c->at(row, j) = static_cast<OutputValueType>(
-               vbeta * static_cast<arithmetic_type>(c->at(row, j)) +
+               (is_zero(vbeta)
+                    ? zero(vbeta)
+                    : vbeta * static_cast<arithmetic_type>(c->at(row, j))) +
                valpha * temp_val);
        }
    }
2 changes: 1 addition & 1 deletion reference/matrix/csr_kernels.cpp
@@ -106,7 +106,7 @@ void advanced_spmv(std::shared_ptr<const ReferenceExecutor> exec,
    auto c_vals = acc::helper::build_rrm_accessor<arithmetic_type>(c);
    for (size_type row = 0; row < a->get_size()[0]; ++row) {
        for (size_type j = 0; j < c->get_size()[1]; ++j) {
-           auto sum = c_vals(row, j) * vbeta;
+           auto sum = is_zero(vbeta) ? zero(vbeta) : c_vals(row, j) * vbeta;
            for (size_type k = row_ptrs[row];
                 k < static_cast<size_type>(row_ptrs[row + 1]); ++k) {
                arithmetic_type val = a_vals(k);
36 changes: 24 additions & 12 deletions reference/matrix/dense_kernels.cpp
@@ -77,7 +77,7 @@ void apply(std::shared_ptr<const ReferenceExecutor> exec,
    } else {
        for (size_type row = 0; row < c->get_size()[0]; ++row) {
            for (size_type col = 0; col < c->get_size()[1]; ++col) {
-               c->at(row, col) *= zero<ValueType>();
+               c->at(row, col) = zero<ValueType>();
            }
        }
    }
@@ -133,7 +133,11 @@ void scale(std::shared_ptr<const ReferenceExecutor> exec,
    if (alpha->get_size()[1] == 1) {
        for (size_type i = 0; i < x->get_size()[0]; ++i) {
            for (size_type j = 0; j < x->get_size()[1]; ++j) {
-               x->at(i, j) *= alpha->at(0, 0);
+               if (is_zero(alpha->at(0, 0))) {
+                   x->at(i, j) = zero<ValueType>();
+               } else {
+                   x->at(i, j) *= alpha->at(0, 0);
+               }
            }
        }
    } else {
@@ -178,9 +182,11 @@ void add_scaled(std::shared_ptr<const ReferenceExecutor> exec,
                const matrix::Dense<ValueType>* x, matrix::Dense<ValueType>* y)
{
    if (alpha->get_size()[1] == 1) {
-       for (size_type i = 0; i < x->get_size()[0]; ++i) {
-           for (size_type j = 0; j < x->get_size()[1]; ++j) {
-               y->at(i, j) += alpha->at(0, 0) * x->at(i, j);
+       if (is_nonzero(alpha->at(0, 0))) {
+           for (size_type i = 0; i < x->get_size()[0]; ++i) {
+               for (size_type j = 0; j < x->get_size()[1]; ++j) {
+                   y->at(i, j) += alpha->at(0, 0) * x->at(i, j);
+               }
            }
        }
    } else {
@@ -202,9 +208,11 @@ void sub_scaled(std::shared_ptr<const ReferenceExecutor> exec,
                const matrix::Dense<ValueType>* x, matrix::Dense<ValueType>* y)
{
    if (alpha->get_size()[1] == 1) {
-       for (size_type i = 0; i < x->get_size()[0]; ++i) {
-           for (size_type j = 0; j < x->get_size()[1]; ++j) {
-               y->at(i, j) -= alpha->at(0, 0) * x->at(i, j);
+       if (is_nonzero(alpha->at(0, 0))) {
+           for (size_type i = 0; i < x->get_size()[0]; ++i) {
+               for (size_type j = 0; j < x->get_size()[1]; ++j) {
+                   y->at(i, j) -= alpha->at(0, 0) * x->at(i, j);
+               }
            }
        }
    } else {
@@ -227,8 +235,10 @@ void add_scaled_diag(std::shared_ptr<const ReferenceExecutor> exec,
                     matrix::Dense<ValueType>* y)
{
    const auto diag_values = x->get_const_values();
-   for (size_type i = 0; i < x->get_size()[0]; i++) {
-       y->at(i, i) += alpha->at(0, 0) * diag_values[i];
+   if (is_nonzero(alpha->at(0, 0))) {
+       for (size_type i = 0; i < x->get_size()[0]; i++) {
+           y->at(i, i) += alpha->at(0, 0) * diag_values[i];
+       }
    }
}

@@ -242,8 +252,10 @@ void sub_scaled_diag(std::shared_ptr<const ReferenceExecutor> exec,
                     matrix::Dense<ValueType>* y)
{
    const auto diag_values = x->get_const_values();
-   for (size_type i = 0; i < x->get_size()[0]; i++) {
-       y->at(i, i) -= alpha->at(0, 0) * diag_values[i];
+   if (is_nonzero(alpha->at(0, 0))) {
+       for (size_type i = 0; i < x->get_size()[0]; i++) {
+           y->at(i, i) -= alpha->at(0, 0) * diag_values[i];
+       }
    }
}
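A pattern worth noting across the diff, sketched below with simplified types (add_scaled_hoisted and add_scaled_per_element are illustrative names, not Ginkgo functions): the reference and OMP kernels hoist the is_nonzero guard outside the loop nest, while the portable device kernels branch inside the per-element lambda, since only the element-wise functor crosses the kernel-launch boundary.

#include <cstddef>
#include <vector>

// Hoisted guard: one branch total (reference-kernel style).
void add_scaled_hoisted(double alpha, const std::vector<double>& x,
                        std::vector<double>& y)
{
    if (alpha != 0.0) {
        for (std::size_t i = 0; i < y.size(); ++i) {
            y[i] += alpha * x[i];
        }
    }
}

// Per-element guard (device-lambda style); alpha is loop-invariant, so
// the compiler can usually hoist the branch anyway.
void add_scaled_per_element(double alpha, const std::vector<double>& x,
                            std::vector<double>& y)
{
    for (std::size_t i = 0; i < y.size(); ++i) {
        if (alpha != 0.0) {
            y[i] += alpha * x[i];
        }
    }
}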

4 changes: 2 additions & 2 deletions reference/matrix/ell_kernels.cpp
@@ -109,8 +109,8 @@ void advanced_spmv(std::shared_ptr<const ReferenceExecutor> exec,

    for (size_type j = 0; j < c->get_size()[1]; j++) {
        for (size_type row = 0; row < a->get_size()[0]; row++) {
-           arithmetic_type result = c->at(row, j);
-           result *= beta_val;
+           arithmetic_type result =
+               is_zero(beta_val) ? zero(beta_val) : beta_val * c->at(row, j);
Comment (Member):

Suggested change:
-               is_zero(beta_val) ? zero(beta_val) : beta_val * c->at(row, j);
+               is_zero(beta_val) ? zero(beta_val) : beta_val * static_cast<arithmetic_type>(c->at(row, j));

to make it compile.

            for (size_type i = 0; i < num_stored_elements_per_row; i++) {
                arithmetic_type val = a_vals(row + i * stride);
                auto col = a->col_at(row, i);
6 changes: 5 additions & 1 deletion reference/matrix/fbcsr_kernels.cpp
@@ -103,7 +103,11 @@ void advanced_spmv(const std::shared_ptr<const ReferenceExecutor>,
    for (IndexType ibrow = 0; ibrow < nbrows; ++ibrow) {
        for (IndexType row = ibrow * bs; row < (ibrow + 1) * bs; ++row) {
            for (IndexType rhs = 0; rhs < nvecs; rhs++) {
-               c->at(row, rhs) *= vbeta;
+               if (is_zero(vbeta)) {
+                   c->at(row, rhs) = zero(vbeta);
+               } else {
+                   c->at(row, rhs) *= vbeta;
+               }
            }
        }

6 changes: 5 additions & 1 deletion reference/matrix/sellp_kernels.cpp
@@ -83,7 +83,11 @@ void advanced_spmv(std::shared_ptr<const ReferenceExecutor> exec,
                break;
            }
            for (size_type j = 0; j < c->get_size()[1]; j++) {
-               c->at(global_row, j) *= vbeta;
+               if (is_nonzero(vbeta)) {
+                   c->at(global_row, j) *= vbeta;
+               } else {
+                   c->at(global_row, j) = zero<ValueType>();
+               }
            }
            for (size_type i = 0; i < slice_lengths[slice]; i++) {
                auto val = a->val_at(row, slice_sets[slice], i);
4 changes: 3 additions & 1 deletion reference/matrix/sparsity_csr_kernels.cpp
@@ -89,7 +89,9 @@ void advanced_spmv(std::shared_ptr<const ReferenceExecutor> exec,
                    val * static_cast<arithmetic_type>(b->at(col_idxs[k], j));
            }
            c->at(row, j) = static_cast<OutputValueType>(
-               vbeta * static_cast<arithmetic_type>(c->at(row, j)) +
+               (is_zero(vbeta)
+                    ? zero(vbeta)
+                    : vbeta * static_cast<arithmetic_type>(c->at(row, j))) +
                valpha * temp_val);
        }
    }