Skip to content

Commit

Permalink
GEMM: implicit MemberType
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikołaj Zuzek committed Sep 20, 2022
1 parent bdff413 commit 2292855
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 124 deletions.
2 changes: 1 addition & 1 deletion batched/dense/unit_test/Test_Batched_TeamInverseLU.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ struct Functor_BatchedTeamGemm {
}
member.team_barrier();

KokkosBlas::TeamGemm<MemberType, typename ParamTagType::transA,
KokkosBlas::TeamGemm<typename ParamTagType::transA,
typename ParamTagType::transB,
AlgoTagType>::invoke(member, _alpha, aa, bb, _beta,
cc);
Expand Down
2 changes: 1 addition & 1 deletion batched/dense/unit_test/Test_Batched_TeamSolveLU.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ struct Functor_BatchedTeamGemm {
}
member.team_barrier();

KokkosBlas::TeamGemm<MemberType, typename ParamTagType::transA,
KokkosBlas::TeamGemm<typename ParamTagType::transA,
typename ParamTagType::transB,
AlgoTagType>::invoke(member, _alpha, aa, bb, _beta,
cc);
Expand Down
3 changes: 1 addition & 2 deletions batched/dense/unit_test/Test_Batched_TeamVectorSolveUTV2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ struct Functor_TestBatchedTeamVectorSolveUTV2 {
TeamVectorCopy<MemberType, Trans::NoTranspose>::invoke(member, aa, ac);

/// bb = AA*xx
KokkosBlas::TeamVectorGemm<MemberType, Trans::NoTranspose,
Trans::NoTranspose,
KokkosBlas::TeamVectorGemm<Trans::NoTranspose, Trans::NoTranspose,
Algo::Gemm::Unblocked>::invoke(member, one, aa,
xx, zero, bb);
member.team_barrier();
Expand Down
104 changes: 48 additions & 56 deletions blas/impl/KokkosBlas3_team_gemm_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,10 @@ namespace KokkosBlas {
/// NT/NT
///

template <typename MemberType>
struct TeamGemm<MemberType, Trans::NoTranspose, Trans::NoTranspose,
Algo::Gemm::Unblocked> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <>
struct TeamGemm<Trans::NoTranspose, Trans::NoTranspose, Algo::Gemm::Unblocked> {
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const BViewType &B, const ScalarType beta, const CViewType &C) {
Expand All @@ -81,11 +80,10 @@ struct TeamGemm<MemberType, Trans::NoTranspose, Trans::NoTranspose,
}
};

template <typename MemberType>
struct TeamGemm<MemberType, Trans::NoTranspose, Trans::NoTranspose,
Algo::Gemm::Blocked> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <>
struct TeamGemm<Trans::NoTranspose, Trans::NoTranspose, Algo::Gemm::Blocked> {
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const BViewType &B, const ScalarType beta, const CViewType &C) {
Expand All @@ -102,11 +100,10 @@ struct TeamGemm<MemberType, Trans::NoTranspose, Trans::NoTranspose,
/// T/NT
///

template <typename MemberType>
struct TeamGemm<MemberType, Trans::Transpose, Trans::NoTranspose,
Algo::Gemm::Unblocked> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <>
struct TeamGemm<Trans::Transpose, Trans::NoTranspose, Algo::Gemm::Unblocked> {
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const BViewType &B, const ScalarType beta, const CViewType &C) {
Expand All @@ -119,11 +116,10 @@ struct TeamGemm<MemberType, Trans::Transpose, Trans::NoTranspose,
}
};

template <typename MemberType>
struct TeamGemm<MemberType, Trans::Transpose, Trans::NoTranspose,
Algo::Gemm::Blocked> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <>
struct TeamGemm<Trans::Transpose, Trans::NoTranspose, Algo::Gemm::Blocked> {
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const BViewType &B, const ScalarType beta, const CViewType &C) {
Expand All @@ -140,11 +136,10 @@ struct TeamGemm<MemberType, Trans::Transpose, Trans::NoTranspose,
/// NT/T
///

template <typename MemberType>
struct TeamGemm<MemberType, Trans::NoTranspose, Trans::Transpose,
Algo::Gemm::Unblocked> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <>
struct TeamGemm<Trans::NoTranspose, Trans::Transpose, Algo::Gemm::Unblocked> {
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const BViewType &B, const ScalarType beta, const CViewType &C) {
Expand All @@ -157,11 +152,10 @@ struct TeamGemm<MemberType, Trans::NoTranspose, Trans::Transpose,
}
};

template <typename MemberType>
struct TeamGemm<MemberType, Trans::NoTranspose, Trans::Transpose,
Algo::Gemm::Blocked> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <>
struct TeamGemm<Trans::NoTranspose, Trans::Transpose, Algo::Gemm::Blocked> {
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const BViewType &B, const ScalarType beta, const CViewType &C) {
Expand All @@ -178,11 +172,10 @@ struct TeamGemm<MemberType, Trans::NoTranspose, Trans::Transpose,
/// T/T
///

template <typename MemberType>
struct TeamGemm<MemberType, Trans::Transpose, Trans::Transpose,
Algo::Gemm::Unblocked> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <>
struct TeamGemm<Trans::Transpose, Trans::Transpose, Algo::Gemm::Unblocked> {
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const BViewType &B, const ScalarType beta, const CViewType &C) {
Expand All @@ -195,11 +188,10 @@ struct TeamGemm<MemberType, Trans::Transpose, Trans::Transpose,
}
};

template <typename MemberType>
struct TeamGemm<MemberType, Trans::Transpose, Trans::Transpose,
Algo::Gemm::Blocked> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <>
struct TeamGemm<Trans::Transpose, Trans::Transpose, Algo::Gemm::Blocked> {
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const BViewType &B, const ScalarType beta, const CViewType &C) {
Expand All @@ -224,11 +216,11 @@ struct TeamGemm<MemberType, Trans::Transpose, Trans::Transpose,
/// NT/NT
///

template <typename MemberType>
struct TeamVectorGemm<MemberType, Trans::NoTranspose, Trans::NoTranspose,
template <>
struct TeamVectorGemm<Trans::NoTranspose, Trans::NoTranspose,
Algo::Gemm::Unblocked> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const BViewType &B, const ScalarType beta, const CViewType &C) {
Expand All @@ -245,11 +237,11 @@ struct TeamVectorGemm<MemberType, Trans::NoTranspose, Trans::NoTranspose,
/// T/NT
///

template <typename MemberType>
struct TeamVectorGemm<MemberType, Trans::Transpose, Trans::NoTranspose,
template <>
struct TeamVectorGemm<Trans::Transpose, Trans::NoTranspose,
Algo::Gemm::Unblocked> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const BViewType &B, const ScalarType beta, const CViewType &C) {
Expand All @@ -266,11 +258,11 @@ struct TeamVectorGemm<MemberType, Trans::Transpose, Trans::NoTranspose,
/// NT/T
///

template <typename MemberType>
struct TeamVectorGemm<MemberType, Trans::NoTranspose, Trans::Transpose,
template <>
struct TeamVectorGemm<Trans::NoTranspose, Trans::Transpose,
Algo::Gemm::Unblocked> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const BViewType &B, const ScalarType beta, const CViewType &C) {
Expand All @@ -287,11 +279,11 @@ struct TeamVectorGemm<MemberType, Trans::NoTranspose, Trans::Transpose,
/// T/T
///

template <typename MemberType>
struct TeamVectorGemm<MemberType, Trans::Transpose, Trans::Transpose,
template <>
struct TeamVectorGemm<Trans::Transpose, Trans::Transpose,
Algo::Gemm::Unblocked> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const BViewType &B, const ScalarType beta, const CViewType &C) {
Expand Down
55 changes: 25 additions & 30 deletions blas/src/KokkosBlas3_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,10 @@ struct SerialGemm {
/// Team Impl
/// =========

template <typename MemberType, typename ArgTransA, typename ArgTransB,
typename ArgAlgo>
template <typename ArgTransA, typename ArgTransB, typename ArgAlgo>
struct TeamGemm {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType& member, const ScalarType alpha, const AViewType& A,
const BViewType& B, const ScalarType beta, const CViewType& C);
Expand All @@ -291,11 +290,10 @@ struct TeamGemm {
/// TeamVector Impl
/// =========

template <typename MemberType, typename ArgTransA, typename ArgTransB,
typename ArgAlgo>
template <typename ArgTransA, typename ArgTransB, typename ArgAlgo>
struct TeamVectorGemm {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType& member, const ScalarType alpha, const AViewType& A,
const BViewType& B, const ScalarType beta, const CViewType& C);
Expand All @@ -304,20 +302,19 @@ struct TeamVectorGemm {
///
/// Selective Interface
///
/// Primary template of the mode-selecting Gemm interface. It only declares the
/// static `invoke` entry point (no body here); the partial specializations on
/// Mode::Serial, Mode::Team and Mode::TeamVector are the ones that carry an
/// `invoke` body. ArgAlgo defaults to Algo::Gemm::Default.
///
/// NOTE(review): this text is a diff rendering without +/- markers, so two
/// variants of each template header appear back to back. Presumably the first
/// of each pair (MemberType as a class template parameter) is the removed line
/// and the second (MemberType as a template parameter of `invoke`) is its
/// replacement -- confirm against the actual header.
// Variant with MemberType in the class template parameter list:
template <typename MemberType, typename ArgTransA, typename ArgTransB,
typename ArgMode, typename ArgAlgo = Algo::Gemm::Default>
// Variant without MemberType in the class template parameter list:
template <typename ArgTransA, typename ArgTransB, typename ArgMode,
typename ArgAlgo = Algo::Gemm::Default>
struct Gemm {
// Variant of the member template header without MemberType:
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
// Variant of the member template header that takes MemberType itself, so it
// can be deduced from the `member` argument of `invoke`:
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
/// Declaration only. Takes the member handle, the scalars alpha and beta, and
/// the views A, B and C; returns int.
KOKKOS_FORCEINLINE_FUNCTION static int invoke(
const MemberType& member, const ScalarType alpha, const AViewType& A,
const BViewType& B, const ScalarType beta, const CViewType& C);
};

template <typename MemberType, typename ArgTransA, typename ArgTransB,
typename ArgAlgo>
struct Gemm<MemberType, ArgTransA, ArgTransB, Mode::Serial, ArgAlgo> {
template <typename ScalarType, typename AViewType,
template <typename ArgTransA, typename ArgTransB, typename ArgAlgo>
struct Gemm<ArgTransA, ArgTransB, Mode::Serial, ArgAlgo> {
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_FORCEINLINE_FUNCTION static int invoke(const MemberType& /* member */,
const ScalarType alpha,
Expand All @@ -330,29 +327,27 @@ struct Gemm<MemberType, ArgTransA, ArgTransB, Mode::Serial, ArgAlgo> {
}
};

/// Partial specialization of Gemm for Mode::Team. Its `invoke` passes
/// member, alpha, A, B, beta and C through unchanged to TeamGemm::invoke with
/// the same ArgTransA / ArgTransB / ArgAlgo tags and returns that call's
/// result.
///
/// NOTE(review): this text is a diff rendering without +/- markers, so the
/// struct header, the member template header and the return statement each
/// appear twice. Presumably the variant that names MemberType as a class
/// template argument is the removed one and the variant that leaves it to
/// `invoke` is the replacement -- confirm against the actual header.
// Variant with MemberType as a class template parameter:
template <typename MemberType, typename ArgTransA, typename ArgTransB,
typename ArgAlgo>
struct Gemm<MemberType, ArgTransA, ArgTransB, Mode::Team, ArgAlgo> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
// Variant with MemberType as a template parameter of `invoke`:
template <typename ArgTransA, typename ArgTransB, typename ArgAlgo>
struct Gemm<ArgTransA, ArgTransB, Mode::Team, ArgAlgo> {
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_FORCEINLINE_FUNCTION static int invoke(
const MemberType& member, const ScalarType alpha, const AViewType& A,
const BViewType& B, const ScalarType beta, const CViewType& C) {
// Forwarding call that spells MemberType out as a TeamGemm template argument:
return TeamGemm<MemberType, ArgTransA, ArgTransB, ArgAlgo>::invoke(
member, alpha, A, B, beta, C);
// Forwarding call that passes only the transpose/algorithm tags to TeamGemm:
return TeamGemm<ArgTransA, ArgTransB, ArgAlgo>::invoke(member, alpha, A, B,
beta, C);
}
};

/// Partial specialization of Gemm for Mode::TeamVector. Its `invoke` passes
/// member, alpha, A, B, beta and C through unchanged to TeamVectorGemm::invoke
/// with the same ArgTransA / ArgTransB / ArgAlgo tags and returns that call's
/// result.
///
/// NOTE(review): this text is a diff rendering without +/- markers, so the
/// struct header, the member template header and the return statement each
/// appear twice. Presumably the variant that names MemberType as a class
/// template argument is the removed one and the variant that leaves it to
/// `invoke` is the replacement -- confirm against the actual header.
// Variant with MemberType as a class template parameter:
template <typename MemberType, typename ArgTransA, typename ArgTransB,
typename ArgAlgo>
struct Gemm<MemberType, ArgTransA, ArgTransB, Mode::TeamVector, ArgAlgo> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
// Variant with MemberType as a template parameter of `invoke`:
template <typename ArgTransA, typename ArgTransB, typename ArgAlgo>
struct Gemm<ArgTransA, ArgTransB, Mode::TeamVector, ArgAlgo> {
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_FORCEINLINE_FUNCTION static int invoke(
const MemberType& member, const ScalarType alpha, const AViewType& A,
const BViewType& B, const ScalarType beta, const CViewType& C) {
// Forwarding call that spells MemberType out as a TeamVectorGemm template
// argument:
return TeamVectorGemm<MemberType, ArgTransA, ArgTransB, ArgAlgo>::invoke(
member, alpha, A, B, beta, C);
// Forwarding call that passes only the transpose/algorithm tags to
// TeamVectorGemm:
return TeamVectorGemm<ArgTransA, ArgTransB, ArgAlgo>::invoke(member, alpha,
A, B, beta, C);
}
};

Expand Down
2 changes: 1 addition & 1 deletion blas/unit_test/Test_Blas3_team_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct Functor_TestBatchedTeamGemm {
auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL());
auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL());

KokkosBlas::TeamGemm<MemberType, typename ParamTagType::transA,
KokkosBlas::TeamGemm<typename ParamTagType::transA,
typename ParamTagType::transB,
AlgoTagType>::invoke(member, _alpha, aa, bb, _beta,
cc);
Expand Down
2 changes: 1 addition & 1 deletion blas/unit_test/Test_Blas3_teamvector_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct Functor_TestBatchedTeamVector {
auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL());
auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL());

KokkosBlas::TeamVectorGemm<MemberType, typename ParamTagType::transA,
KokkosBlas::TeamVectorGemm<typename ParamTagType::transA,
typename ParamTagType::transB,
AlgoTagType>::invoke(member, _alpha, aa, bb,
_beta, cc);
Expand Down
Loading

0 comments on commit 2292855

Please sign in to comment.