From 930dfee1e5c65dd966f08dfc43c343b8dd853456 Mon Sep 17 00:00:00 2001 From: Andrew Reisner Date: Tue, 13 Aug 2024 21:18:31 -0600 Subject: [PATCH] Simplify templates in par_strength --- raptor/par_strength.cpp | 90 +++++++++++------------- raptor/ruge_stuben/par_interpolation.cpp | 25 +------ 2 files changed, 44 insertions(+), 71 deletions(-) diff --git a/raptor/par_strength.cpp b/raptor/par_strength.cpp index 7b62d696..defed6c9 100644 --- a/raptor/par_strength.cpp +++ b/raptor/par_strength.cpp @@ -4,6 +4,7 @@ #include #include "core/par_matrix.hpp" +#include "core/matrix_traits.hpp" using namespace raptor; @@ -64,15 +65,17 @@ template<> struct norm_coupling static constexpr double strongest(double a, double b) { return std::max(std::abs(a), b); } }; -template +template struct mat_args { - const std::conditional_t mat; + T & mat; const int * variables; int beg; }; +template +mat_args(T&, const int*, int)->mat_args; template -constexpr double value(Matrix & mat, int i) { +constexpr double value(CSRMatrix & mat, int i) { return mat.vals[i]; } template @@ -83,8 +86,8 @@ constexpr double value(BSRMatrix & mat, int i) { if (P::comp(val, curr)) curr = val; return curr; } -template -constexpr double strongest_element(int i, int row_var, const mat_args & a) { +template +constexpr double strongest_element(int i, int row_var, const mat_args & a) { auto curr = P::init; for (int j = a.beg; j < a.mat.idx1[i+1]; ++j) { auto col = a.mat.idx2[j]; @@ -96,9 +99,9 @@ constexpr double strongest_element(int i, int row_var, const mat_args & } return curr; } -template +template constexpr double strongest_connection(int row, int row_var, int num_variables, - mat_args on_proc, mat_args off_proc) { + mat_args on_proc, mat_args off_proc) { if (num_variables == 1) { return P::strongest( strongest_element(row, row_var, on_proc), @@ -109,11 +112,12 @@ constexpr double strongest_connection(int row, int row_var, int num_variables, strongest_element(row, row_var, off_proc)); } } -template -struct append_args : mat_args { Matrix & soc; }; -template +template +struct append_args : mat_args { Matrix & soc; }; + +template constexpr void add_connections(int row, int row_var, double threshold, - const append_args & args) { + const append_args & args) { for (int j = args.beg; j < args.mat.idx1[row+1]; ++j) { auto col = args.mat.idx2[j]; if constexpr (filter) @@ -126,9 +130,9 @@ constexpr void add_connections(int row, int row_var, double threshold, } } } -template +template constexpr void add_strong_connections(int row, int row_var, int num_variables, double threshold, - append_args on_proc, append_args off_proc) { + append_args on_proc, append_args off_proc) { if (num_variables == 1) { add_connections(row, row_var, threshold, on_proc); add_connections(row, row_var, threshold, off_proc); @@ -141,6 +145,8 @@ constexpr void add_strong_connections(int row, int row_var, int num_variables, d void hybrid_strength(ParCSRMatrix & A, ParCSRMatrix & S, double theta, int num_variables, int *variables, int *off_variables) { + auto & on_proc = dynamic_cast(*A.on_proc); + auto & off_proc = dynamic_cast(*A.off_proc); for (int i = 0; i < A.local_num_rows; i++) { auto row_start_on = A.on_proc->idx1[i]; @@ -165,9 +171,9 @@ void hybrid_strength(ParCSRMatrix & A, ParCSRMatrix & S, auto row_scale = [&]() { auto get_row_scale = [&](auto comp) { using P = decltype(comp); - return strongest_connection(i, row_var, num_variables, - {*A.on_proc, variables, row_start_on}, - {*A.off_proc, off_variables, row_start_off}); + return strongest_connection(i, row_var, num_variables, + {on_proc, variables, row_start_on}, + {off_proc, off_variables, row_start_off}); }; if (diag < 0.0) return get_row_scale(positive_coupling{}); @@ -188,9 +194,9 @@ void hybrid_strength(ParCSRMatrix & A, ParCSRMatrix & S, // row_max * theta auto add_row = [&](auto comp) { using P = decltype(comp); - add_strong_connections(i, row_var, num_variables, threshold, - {{*A.on_proc, variables, row_start_on}, *S.on_proc}, - {{*A.off_proc, off_variables, row_start_off}, *S.off_proc}); + add_strong_connections(i, row_var, num_variables, threshold, + {{on_proc, variables, row_start_on}, *S.on_proc}, + {{off_proc, off_variables, row_start_off}, *S.off_proc}); }; if (diag < 0) add_row(positive_coupling{}); @@ -202,11 +208,13 @@ void hybrid_strength(ParCSRMatrix & A, ParCSRMatrix & S, } } // hybrid_strength -template -void norm_strength(ParCSRMatrix & A, ParCSRMatrix & S, +template = true> +void norm_strength(T & A, ParCSRMatrix & S, double theta, int num_variables, int *variables, int *off_variables) { - auto *bsr_diag = dynamic_cast(A.on_proc); - auto *bsr_offd = dynamic_cast(A.off_proc); + using seq_t = sequential_matrix_t; + auto & diag = dynamic_cast(*A.on_proc); + auto & offd = dynamic_cast(*A.off_proc); + using P = norm_coupling; for (int i = 0; i < A.local_num_rows; i++) { @@ -224,25 +232,18 @@ void norm_strength(ParCSRMatrix & A, ParCSRMatrix & S, auto row_var = (num_variables > 1) ? variables[i] : -1; // Find value with max magnitude in row - auto row_scale = [&]() { - if constexpr (is_bsr) - return strongest_connection(i, row_var, num_variables, - {*bsr_diag, variables, row_start_on}, - {*bsr_offd, off_variables, row_start_off}); - else - return strongest_connection(i, row_var, num_variables, - {*A.on_proc, variables, row_start_on}, - {*A.off_proc, off_variables, row_start_off}); - }(); + auto row_scale = strongest_connection(i, row_var, num_variables, + {diag, variables, row_start_on}, + {offd, off_variables, row_start_off}); // Multiply row max magnitude by theta auto threshold = row_scale * theta; // Always add diagonal S.on_proc->idx2[S.on_proc->nnz] = i; - if constexpr (is_bsr) + if constexpr (is_bsr_v) S.on_proc->vals[S.on_proc->nnz] = has_zero_diag ? 0 : - value

(dynamic_cast(*A.on_proc), row_start_on - 1); + value

(diag, row_start_on - 1); else S.on_proc->vals[S.on_proc->nnz] = has_zero_diag ? 0 : A.on_proc->vals[row_start_on - 1]; S.on_proc->nnz++; @@ -250,14 +251,9 @@ void norm_strength(ParCSRMatrix & A, ParCSRMatrix & S, // Add all off-diagonal entries to strength // if magnitude greater than equal to // row_max * theta - if constexpr (is_bsr) - add_strong_connections(i, row_var, num_variables, threshold, - {{*bsr_diag, variables, row_start_on}, *S.on_proc}, - {{*bsr_offd, off_variables, row_start_off}, *S.off_proc}); - else - add_strong_connections(i, row_var, num_variables, threshold, - {*A.on_proc, variables, row_start_on}, - {*A.off_proc, off_variables, row_start_off}); + add_strong_connections(i, row_var, num_variables, threshold, + {{diag, variables, row_start_on}, *S.on_proc}, + {{offd, off_variables, row_start_off}, *S.off_proc}); } S.on_proc->idx1[i+1] = S.on_proc->nnz; S.off_proc->idx1[i+1] = S.off_proc->nnz; @@ -291,12 +287,12 @@ ParCSRMatrix* classical_strength(ParCSRMatrix* A, double theta, bool tap_amg, in classical::init_strength(*A->on_proc, *S->on_proc); classical::init_strength(*A->off_proc, *S->off_proc); - auto is_bsr = dynamic_cast(A); - if (!is_bsr) { + auto bsr = dynamic_cast(A); + if (!bsr) { classical::hybrid_strength(*A, *S, theta, num_variables, variables, off_variables); } else { - classical::norm_strength(*A, *S, theta, num_variables, - variables, off_variables); + classical::norm_strength(*bsr, *S, theta, num_variables, + variables, off_variables); } classical::finalize_strength(*A, *S); diff --git a/raptor/ruge_stuben/par_interpolation.cpp b/raptor/ruge_stuben/par_interpolation.cpp index fa1382ac..8778b293 100644 --- a/raptor/ruge_stuben/par_interpolation.cpp +++ b/raptor/ruge_stuben/par_interpolation.cpp @@ -2,6 +2,7 @@ // License: Simplified BSD, http://opensource.org/licenses/BSD-2-Clause #include "assert.h" #include "raptor/core/types.hpp" +#include "raptor/core/matrix_traits.hpp" #include "raptor/core/par_matrix.hpp" #include "raptor/ruge_stuben/par_interpolation.hpp" @@ -1978,33 +1979,9 @@ ParBSRMatrix * one_point_interpolation(const ParBSRMatrix & A, return ret; } -template -using is_bsr_or_csr = std::enable_if_t || - std::is_same_v, bool>; - -template struct is_bsr : std::false_type {}; -template <> struct is_bsr : std::true_type {}; -template inline constexpr bool is_bsr_v = is_bsr::value; - -BSRMatrix & bsr_cast(Matrix &mat) { return dynamic_cast(mat); } -const BSRMatrix & bsr_cast(const Matrix & mat) { return dynamic_cast(mat); } - -template struct matrix_value; -template <> struct matrix_value { using type = double; }; -template <> struct matrix_value { using type = double*; }; -template -using matrix_value_t = typename matrix_value::type; - -template struct sequential_matrix; -template<> struct sequential_matrix { using type = BSRMatrix; }; -template<> struct sequential_matrix { using type = CSRMatrix; }; -template -using sequential_matrix_t = typename sequential_matrix::type; - namespace lair { namespace { - /* Helper type providing access to received rows based on whether they are on_proc or off_proc