Skip to content

Commit

Permalink
GEMM: implicit MemberType
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikołaj Zuzek committed Sep 9, 2022
1 parent 83a091b commit ef75dc3
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 124 deletions.
50 changes: 18 additions & 32 deletions perf_test/blas/blas3/KokkosBlas3_gemm_perf_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,9 +719,8 @@ struct parallel_batched_gemm {
auto svB = Kokkos::subview(gemm_args_.B, i, Kokkos::ALL(), Kokkos::ALL());
auto svC = Kokkos::subview(gemm_args_.C, i, Kokkos::ALL(), Kokkos::ALL());

KokkosBlas::TeamGemm<MemberType, TransAType, TransBType,
BlockingType>::invoke(member, gemm_args_.alpha, svA,
svB, gemm_args_.beta, svC);
KokkosBlas::TeamGemm<TransAType, TransBType, BlockingType>::invoke(
member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC);
}

KOKKOS_INLINE_FUNCTION
Expand All @@ -731,9 +730,8 @@ struct parallel_batched_gemm {
auto svB = Kokkos::subview(gemm_args_.B, Kokkos::ALL(), Kokkos::ALL(), i);
auto svC = Kokkos::subview(gemm_args_.C, Kokkos::ALL(), Kokkos::ALL(), i);

KokkosBlas::TeamGemm<MemberType, TransAType, TransBType,
BlockingType>::invoke(member, gemm_args_.alpha, svA,
svB, gemm_args_.beta, svC);
KokkosBlas::TeamGemm<TransAType, TransBType, BlockingType>::invoke(
member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC);
}

KOKKOS_INLINE_FUNCTION
Expand All @@ -746,10 +744,8 @@ struct parallel_batched_gemm {
auto svC =
Kokkos::subview(gemm_args_.C, team_idx, Kokkos::ALL(), Kokkos::ALL());

KokkosBlas::TeamVectorGemm<MemberType, TransAType, TransBType,
BlockingType>::invoke(member, gemm_args_.alpha,
svA, svB, gemm_args_.beta,
svC);
KokkosBlas::TeamVectorGemm<TransAType, TransBType, BlockingType>::invoke(
member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC);
}

KOKKOS_INLINE_FUNCTION
Expand All @@ -763,10 +759,8 @@ struct parallel_batched_gemm {
auto svC =
Kokkos::subview(gemm_args_.C, Kokkos::ALL(), Kokkos::ALL(), team_idx);

KokkosBlas::TeamVectorGemm<MemberType, TransAType, TransBType,
BlockingType>::invoke(member, gemm_args_.alpha,
svA, svB, gemm_args_.beta,
svC);
KokkosBlas::TeamVectorGemm<TransAType, TransBType, BlockingType>::invoke(
member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC);
}

KOKKOS_INLINE_FUNCTION
Expand All @@ -782,7 +776,7 @@ struct parallel_batched_gemm {
auto svC = Kokkos::subview(gemm_args_.Cv.ivec_4d, i, Kokkos::ALL(),
Kokkos::ALL(), vector_lane);

KokkosBlas::Gemm<MemberType, TransAType, TransBType, AlgoMode,
KokkosBlas::Gemm<TransAType, TransBType, AlgoMode,
BlockingType>::invoke(member, gemm_args_.alpha, svA,
svB, gemm_args_.beta, svC);
});
Expand All @@ -802,7 +796,7 @@ struct parallel_batched_gemm {
auto svC = Kokkos::subview(gemm_args_.Cv.ivec_4d, vector_lane,
Kokkos::ALL(), Kokkos::ALL(), i);

KokkosBlas::Gemm<MemberType, TransAType, TransBType, AlgoMode,
KokkosBlas::Gemm<TransAType, TransBType, AlgoMode,
BlockingType>::invoke(member, gemm_args_.alpha, svA,
svB, gemm_args_.beta, svC);
});
Expand Down Expand Up @@ -1066,10 +1060,8 @@ struct parallel_batched_gemm_experiment2_3_4 {

// Uses TeamThreadRange over C-rows
// ThreadVectorRange over C-cols
KokkosBlas::TeamVectorGemm<MemberType, TransAType, TransBType,
BlockingType>::invoke(member, gemm_args_.alpha,
svA, svB, gemm_args_.beta,
svC);
KokkosBlas::TeamVectorGemm<TransAType, TransBType, BlockingType>::invoke(
member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC);
}

// Experiment 3
Expand All @@ -1096,10 +1088,8 @@ struct parallel_batched_gemm_experiment2_3_4 {
auto svC_col = Kokkos::subview(svC, Kokkos::ALL(), lane_idx);
// TeamGemm Calls TeamThreadRange over M*N meaning the flat M*N array
// is split over all threads of the team
KokkosBlas::TeamGemm<MemberType, TransAType, TransBType,
BlockingType>::invoke(member, gemm_args_.alpha,
svA, svB_col,
gemm_args_.beta, svC_col);
KokkosBlas::TeamGemm<TransAType, TransBType, BlockingType>::invoke(
member, gemm_args_.alpha, svA, svB_col, gemm_args_.beta, svC_col);
});
}

Expand Down Expand Up @@ -1128,10 +1118,8 @@ struct parallel_batched_gemm_experiment2_3_4 {
auto svC_row = Kokkos::subview(svC, lane_idx, Kokkos::ALL());
// TeamGemm Calls TeamThreadRange over M*N meaning the flat M*N array
// is split over all threads of the team
KokkosBlas::TeamGemm<MemberType, TransAType, TransBType,
BlockingType>::invoke(member, gemm_args_.alpha,
svA_row, svB,
gemm_args_.beta, svC_row);
KokkosBlas::TeamGemm<TransAType, TransBType, BlockingType>::invoke(
member, gemm_args_.alpha, svA_row, svB, gemm_args_.beta, svC_row);
});
}
};
Expand Down Expand Up @@ -1412,10 +1400,8 @@ class parallel_batched_gemm_experiment6 {
auto svC = Kokkos::subview(C, i, Kokkos::ALL(), Kokkos::ALL());

// Uses two serial for-loops internally
KokkosBlas::TeamVectorGemm<MemberType, TransAType, TransBType,
BlockingType>::invoke(member, gemm_args.alpha,
svA, svB, gemm_args.beta,
svC);
KokkosBlas::TeamVectorGemm<TransAType, TransBType, BlockingType>::invoke(
member, gemm_args.alpha, svA, svB, gemm_args.beta, svC);
}
};

Expand Down
55 changes: 25 additions & 30 deletions src/blas/KokkosBlas3_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,10 @@ struct SerialGemm {
/// Team Impl
/// =========

template <typename MemberType, typename ArgTransA, typename ArgTransB,
typename ArgAlgo>
template <typename ArgTransA, typename ArgTransB, typename ArgAlgo>
struct TeamGemm {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType& member, const ScalarType alpha, const AViewType& A,
const BViewType& B, const ScalarType beta, const CViewType& C);
Expand All @@ -291,11 +290,10 @@ struct TeamGemm {
/// TeamVector Impl
/// =========

template <typename MemberType, typename ArgTransA, typename ArgTransB,
typename ArgAlgo>
template <typename ArgTransA, typename ArgTransB, typename ArgAlgo>
struct TeamVectorGemm {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType& member, const ScalarType alpha, const AViewType& A,
const BViewType& B, const ScalarType beta, const CViewType& C);
Expand All @@ -304,20 +302,19 @@ struct TeamVectorGemm {
///
/// Selective Interface
///
// Primary template of the mode-selecting Gemm front end; it only declares
// invoke(), no body is given in this block.
// NOTE(review): this span is a rendered diff without +/- markers, so each
// changed declaration appears twice; the first copy appears to be the
// pre-commit text and the second the post-commit replacement.
// First copy: MemberType is the leading class-level template parameter.
template <typename MemberType, typename ArgTransA, typename ArgTransB,
typename ArgMode, typename ArgAlgo = Algo::Gemm::Default>
// Second copy: MemberType is no longer a class-level template parameter.
template <typename ArgTransA, typename ArgTransB, typename ArgMode,
typename ArgAlgo = Algo::Gemm::Default>
struct Gemm {
// First copy: invoke() is templated on the scalar and view types only.
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
// Second copy: MemberType becomes the first template parameter of invoke(),
// where it can be deduced from the `member` argument instead of being
// spelled out by the caller.
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_FORCEINLINE_FUNCTION static int invoke(
const MemberType& member, const ScalarType alpha, const AViewType& A,
const BViewType& B, const ScalarType beta, const CViewType& C);
};

template <typename MemberType, typename ArgTransA, typename ArgTransB,
typename ArgAlgo>
struct Gemm<MemberType, ArgTransA, ArgTransB, Mode::Serial, ArgAlgo> {
template <typename ScalarType, typename AViewType,
template <typename ArgTransA, typename ArgTransB, typename ArgAlgo>
struct Gemm<ArgTransA, ArgTransB, Mode::Serial, ArgAlgo> {
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_FORCEINLINE_FUNCTION static int invoke(const MemberType& /* member */,
const ScalarType alpha,
Expand All @@ -330,29 +327,27 @@ struct Gemm<MemberType, ArgTransA, ArgTransB, Mode::Serial, ArgAlgo> {
}
};

// Partial specialization of Gemm for Mode::Team: invoke() forwards all of its
// arguments to TeamGemm::invoke and returns that call's int result.
// NOTE(review): this span is a rendered diff without +/- markers; each changed
// piece appears twice, apparently pre-commit text first and post-commit text
// second.
// First copy: specialization header with MemberType as a class-level
// parameter and invoke() templated on scalar/view types only.
template <typename MemberType, typename ArgTransA, typename ArgTransB,
typename ArgAlgo>
struct Gemm<MemberType, ArgTransA, ArgTransB, Mode::Team, ArgAlgo> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
// Second copy: MemberType moves from the class-level parameter list to the
// invoke() template parameter list.
template <typename ArgTransA, typename ArgTransB, typename ArgAlgo>
struct Gemm<ArgTransA, ArgTransB, Mode::Team, ArgAlgo> {
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_FORCEINLINE_FUNCTION static int invoke(
const MemberType& member, const ScalarType alpha, const AViewType& A,
const BViewType& B, const ScalarType beta, const CViewType& C) {
// First copy of the forwarding call: MemberType passed explicitly to TeamGemm.
return TeamGemm<MemberType, ArgTransA, ArgTransB, ArgAlgo>::invoke(
member, alpha, A, B, beta, C);
// Second copy: MemberType is omitted from the TeamGemm template arguments.
return TeamGemm<ArgTransA, ArgTransB, ArgAlgo>::invoke(member, alpha, A, B,
beta, C);
}
};

// Partial specialization of Gemm for Mode::TeamVector: invoke() forwards all
// of its arguments to TeamVectorGemm::invoke and returns that call's int
// result.
// NOTE(review): this span is a rendered diff without +/- markers; each changed
// piece appears twice, apparently pre-commit text first and post-commit text
// second.
// First copy: specialization header with MemberType as a class-level
// parameter and invoke() templated on scalar/view types only.
template <typename MemberType, typename ArgTransA, typename ArgTransB,
typename ArgAlgo>
struct Gemm<MemberType, ArgTransA, ArgTransB, Mode::TeamVector, ArgAlgo> {
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
// Second copy: MemberType moves from the class-level parameter list to the
// invoke() template parameter list.
template <typename ArgTransA, typename ArgTransB, typename ArgAlgo>
struct Gemm<ArgTransA, ArgTransB, Mode::TeamVector, ArgAlgo> {
template <typename MemberType, typename ScalarType, typename AViewType,
typename BViewType, typename CViewType>
KOKKOS_FORCEINLINE_FUNCTION static int invoke(
const MemberType& member, const ScalarType alpha, const AViewType& A,
const BViewType& B, const ScalarType beta, const CViewType& C) {
// First copy of the forwarding call: MemberType passed explicitly to
// TeamVectorGemm.
return TeamVectorGemm<MemberType, ArgTransA, ArgTransB, ArgAlgo>::invoke(
member, alpha, A, B, beta, C);
// Second copy: MemberType is omitted from the TeamVectorGemm template
// arguments.
return TeamVectorGemm<ArgTransA, ArgTransB, ArgAlgo>::invoke(member, alpha,
A, B, beta, C);
}
};

Expand Down
Loading

0 comments on commit ef75dc3

Please sign in to comment.