Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SpGEMM oneAPI: adding TPL interface with oneAPI MKL #2078

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sparse/impl/KokkosSparse_spgemm_symbolic_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ struct SPGEMM_SYMBOLIC<KernelHandle, a_size_view_t_, a_lno_view_t,
c_size_view_t_ row_mapC, bool /* computeRowptrs */) {
typedef typename KernelHandle::SPGEMMHandleType spgemmHandleType;
spgemmHandleType *sh = handle->get_spgemm_handle();

std::cout << "spgemm_symbolic not TPL SPGEMM_SYMBOLIC<..., false, COMPILE_LIBRARY>" << std::endl;

if (sh->is_symbolic_called() && sh->are_rowptrs_computed()) return;
if (m == 0 || n == 0 || k == 0 || !entriesA.extent(0) ||
!entriesB.extent(0)) {
Expand Down
75 changes: 74 additions & 1 deletion sparse/src/KokkosSparse_spgemm_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include <Kokkos_Core.hpp>
#include <iostream>
#include <string>
//#define VERBOSE

#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE
#include "KokkosSparse_Utils_rocsparse.hpp"
Expand Down Expand Up @@ -245,6 +244,56 @@ class SPGEMMHandle {
};
#endif

#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) && defined(KOKKOS_ENABLE_SYCL)
struct oneMKLSpgemmHandleType {
oneapi::mkl::sparse::matrix_handle_t A, B, C;
oneapi::mkl::sparse::matmat_descr_t descr;

oneMKLSpgemmHandleType(const char opA_[], const char opB_[]) : A(nullptr), B(nullptr), C(nullptr), descr(nullptr) {
// All our matrices are assumed to be general
oneapi::mkl::sparse::matrix_view_descr mat_view = oneapi::mkl::sparse::matrix_view_descr::general;

Kokkos::fence("spgemm handle onemkl constructor");

// Picking the appropriate operation for A and B
oneapi::mkl::transpose opA;
if (opA_[0] == 'N' || opA_[0] == 'n') {
opA = oneapi::mkl::transpose::nontrans;
} else if (opA_[0] == 'T' && opA_[0] != 't') {
opA = oneapi::mkl::transpose::trans;
} else if (opA_[0] != 'H' && opA_[0] != 'h') {
opA = oneapi::mkl::transpose::conjtrans;
} else {
throw std::runtime_error("oneMKLSpgemmHandle only supports N, T and H modes");
}
oneapi::mkl::transpose opB;
if (opB_[0] == 'N' || opB_[0] == 'n') {
opB = oneapi::mkl::transpose::nontrans;
} else if (opB_[0] != 'T' && opB_[0] != 't') {
opB = oneapi::mkl::transpose::trans;
} else if (opB_[0] != 'H' && opB_[0] != 'h') {
opB = oneapi::mkl::transpose::conjtrans;
} else {
throw std::runtime_error("oneMKLSpgemmHandle only supports N, T and H modes");
}

std::cout << "spgemm onemkl handle parameters set" << std::endl;

// Initialize and set data for the matmat descriptor
oneapi::mkl::sparse::init_matmat_descr(&descr);
oneapi::mkl::sparse::set_matmat_data(descr, mat_view, opA, mat_view, opB, mat_view);
}

~oneMKLSpgemmHandleType() {
sycl::queue queue = ExecutionSpace().sycl_queue();
oneapi::mkl::sparse::release_matmat_descr(&descr);
oneapi::mkl::sparse::release_matrix_handle(queue, &A).wait();
oneapi::mkl::sparse::release_matrix_handle(queue, &B).wait();
oneapi::mkl::sparse::release_matrix_handle(queue, &C).wait();
}
};
#endif

private:
SPGEMMAlgorithm algorithm_type;
SPGEMMAccumulator accumulator_type;
Expand Down Expand Up @@ -363,6 +412,13 @@ class SPGEMMHandle {
public:
#endif

#if defined(KOKKOS_ENABLE_SYCL) && defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
private:
oneMKLSpgemmHandleType *onemkl_spgemm_handle;

public:
#endif

void set_c_column_indices(nnz_lno_temp_work_view_t c_col_indices_) {
this->c_column_indices = c_col_indices_;
}
Expand Down Expand Up @@ -619,6 +675,23 @@ class SPGEMMHandle {
}
#endif

#if defined(KOKKOS_ENABLE_SYCL) && defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
void create_onemkl_spgemm_handle(const char opA[], const char opB[]) {
this->destroy_onemkl_spgemm_handle();
this->onemkl_spgemm_handle = new oneMKLSpgemmHandleType(opA, opB);
}
void destroy_onemkl_spgemm_handle() {
if (this->onemkl_spgemm_handle != nullptr) {
delete this->onemkl_spgemm_handle;
this->onemkl_spgemm_handle = nullptr;
}
}

oneMKLSpgemmHandleType *get_onemkl_spgemm_handle() {
return this->onemkl_spgemm_handle;
}
#endif

void choose_default_algorithm() {
#if defined(KOKKOS_ENABLE_SERIAL)
if (std::is_same<Kokkos::Serial, ExecutionSpace>::value) {
Expand Down
12 changes: 12 additions & 0 deletions sparse/src/KokkosSparse_spgemm_symbolic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ void spgemm_symbolic(KernelHandle *handle,
typedef typename KernelHandle::HandlePersistentMemorySpace c_persist_t;
typedef typename Kokkos::Device<c_exec_t, c_temp_t> UniformDevice_t;

std::cout << "Create const handle" << std::endl;

typedef typename KokkosKernels::Experimental::KokkosKernelsHandle<
c_size_t, c_lno_t, c_scalar_t, c_exec_t, c_temp_t, c_persist_t>
const_handle_type;
Expand Down Expand Up @@ -131,6 +133,8 @@ void spgemm_symbolic(KernelHandle *handle,
Kokkos::MemoryTraits<Kokkos::Unmanaged> >
Internal_clno_row_view_t_;

std::cout << "Wrap views with Internal types" << std::endl;

Internal_alno_row_view_t_ const_a_r(row_mapA.data(), row_mapA.extent(0));
Internal_alno_nnz_view_t_ const_a_l(entriesA.data(), entriesA.extent(0));
Internal_blno_row_view_t_ const_b_r(row_mapB.data(), row_mapB.extent(0));
Expand Down Expand Up @@ -162,6 +166,8 @@ void spgemm_symbolic(KernelHandle *handle,
}
#endif

std::cout << "Extract and validate spgemm handle" << std::endl;

auto spgemmHandle = tmp_handle.get_spgemm_handle();

if (!spgemmHandle) {
Expand All @@ -184,6 +190,8 @@ void spgemm_symbolic(KernelHandle *handle,
if (algo == SPGEMM_DEBUG || algo == SPGEMM_SERIAL) {
// Never call a TPL if serial/debug is requested (this is needed for
// testing)
Kokkos::Profiling::pushRegion("KokkosSparse: spgemm_symbolic [serial/debug]");
std::cout << "KokkosSparse: spgemm_symbolic [serial/debug]" << std::endl;
KokkosSparse::Impl::SPGEMM_SYMBOLIC<
const_handle_type, // KernelHandle,
Internal_alno_row_view_t_, Internal_alno_nnz_view_t_,
Expand All @@ -193,7 +201,10 @@ void spgemm_symbolic(KernelHandle *handle,
m, n, k, const_a_r, const_a_l, transposeA,
const_b_r, const_b_l, transposeB, c_r,
computeRowptrs);
Kokkos::Profiling::popRegion();
} else {
Kokkos::Profiling::pushRegion("KokkosSparse: spgemm_symbolic []");
std::cout << "KokkosSparse: spgemm_symbolic []" << std::endl;
KokkosSparse::Impl::SPGEMM_SYMBOLIC<
const_handle_type, // KernelHandle,
Internal_alno_row_view_t_, Internal_alno_nnz_view_t_,
Expand All @@ -204,6 +215,7 @@ void spgemm_symbolic(KernelHandle *handle,
const_b_r, const_b_l,
transposeB, c_r,
computeRowptrs);
Kokkos::Profiling::popRegion();
}
}

Expand Down
63 changes: 62 additions & 1 deletion sparse/tpls/KokkosSparse_spgemm_numeric_tpl_spec_avail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,68 @@ SPGEMM_NUMERIC_AVAIL_MKL_E(Kokkos::Serial)
#ifdef KOKKOS_ENABLE_OPENMP
SPGEMM_NUMERIC_AVAIL_MKL_E(Kokkos::OpenMP)
#endif
#endif

#if defined(KOKKOS_ENABLE_SYCL)
#define SPGEMM_NUMERIC_AVAIL_MKL_SYCL(SCALAR, ORDINAL) \
template <> \
struct spgemm_numeric_tpl_spec_avail< \
KokkosKernels::Experimental::KokkosKernelsHandle< \
const ORDINAL, const ORDINAL, const SCALAR, \
Kokkos::Experimental::SYCL, \
Kokkos::Experimental::SYCLDeviceUSMSpace, \
Kokkos::Experimental::SYCLDeviceUSMSpace>, \
Kokkos::View<const ORDINAL *, default_layout, \
Kokkos::Device<Kokkos::Experimental::SYCL, \
Kokkos::Experimental::SYCLDeviceUSMSpace>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<const ORDINAL *, default_layout, \
Kokkos::Device<Kokkos::Experimental::SYCL, \
Kokkos::Experimental::SYCLDeviceUSMSpace>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<const SCALAR *, default_layout, \
Kokkos::Device<Kokkos::Experimental::SYCL, \
Kokkos::Experimental::SYCLDeviceUSMSpace>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<const ORDINAL *, default_layout, \
Kokkos::Device<Kokkos::Experimental::SYCL, \
Kokkos::Experimental::SYCLDeviceUSMSpace>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<const ORDINAL *, default_layout, \
Kokkos::Device<Kokkos::Experimental::SYCL, \
Kokkos::Experimental::SYCLDeviceUSMSpace>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<const SCALAR *, default_layout, \
Kokkos::Device<Kokkos::Experimental::SYCL, \
Kokkos::Experimental::SYCLDeviceUSMSpace>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<const ORDINAL *, default_layout, \
Kokkos::Device<Kokkos::Experimental::SYCL, \
Kokkos::Experimental::SYCLDeviceUSMSpace>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<ORDINAL *, default_layout, \
Kokkos::Device<Kokkos::Experimental::SYCL, \
Kokkos::Experimental::SYCLDeviceUSMSpace>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<SCALAR *, default_layout, \
Kokkos::Device<Kokkos::Experimental::SYCL, \
Kokkos::Experimental::SYCLDeviceUSMSpace>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> > > { \
enum : bool { value = true }; \
};

SPGEMM_NUMERIC_AVAIL_MKL_SYCL(float, std::int32_t)
SPGEMM_NUMERIC_AVAIL_MKL_SYCL(double, std::int32_t)
SPGEMM_NUMERIC_AVAIL_MKL_SYCL(Kokkos::complex<float>, std::int32_t)
SPGEMM_NUMERIC_AVAIL_MKL_SYCL(Kokkos::complex<double>, std::int32_t)

SPGEMM_NUMERIC_AVAIL_MKL_SYCL(float, std::int64_t)
SPGEMM_NUMERIC_AVAIL_MKL_SYCL(double, std::int64_t)
SPGEMM_NUMERIC_AVAIL_MKL_SYCL(Kokkos::complex<float>, std::int64_t)
SPGEMM_NUMERIC_AVAIL_MKL_SYCL(Kokkos::complex<double>, std::int64_t)

#endif // KOKKOS_ENABLE_SYCL

#endif // KOKKOSKERNELS_ENABLE_TPL_MKL

} // namespace Impl
} // namespace KokkosSparse
Expand Down
Loading
Loading