Skip to content

Commit 9e6e1e2

Browse files
committed
progress...
1 parent 7e7ce56 commit 9e6e1e2

9 files changed

+402
-203
lines changed

blas/impl/KokkosBlas1_scal_impl.hpp

+32-64
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ struct V_Scal_Functor {
5757
XV m_x;
5858
AV m_a;
5959

60-
V_Scal_Functor(const RV& r, const XV& x, const AV& a, const SizeType startingColumn)
60+
V_Scal_Functor(const RV& r, const XV& x, const AV& a)
6161
: m_r(r), m_x(x), m_a(a) {
6262
static_assert(Kokkos::is_view<RV>::value,
6363
"V_Scal_Functor: RV is not a Kokkos::View.");
@@ -68,15 +68,6 @@ struct V_Scal_Functor {
6868
"V_Scal_Functor: XV is not a Kokkos::View.");
6969
static_assert(RV::rank == 1, "V_Scal_Functor: RV is not rank 1.");
7070
static_assert(XV::rank == 1, "V_Scal_Functor: XV is not rank 1.");
71-
72-
73-
if constexpr (Kokkos::is_view_v<AV>) {
74-
if (startingColumn != 0) {
75-
m_a = Kokkos::subview(
76-
a,
77-
std::make_pair(startingColumn, static_cast<SizeType>(a.extent(0))));
78-
}
79-
}
8071
}
8172

8273
KOKKOS_INLINE_FUNCTION
@@ -105,54 +96,31 @@ struct V_Scal_Functor {
10596
}
10697
};
10798

108-
#if 0
109-
// Partial specialization of V_Scal_Functor that lets a be a scalar
110-
// (rather than a 1-D View, as in the most general version above).
111-
// This functor computes any of the following:
112-
//
113-
// 1. Y(i) = alpha*X(i) for alpha in -1,0,1
114-
// 2. Y(i) = a*X(i)
115-
template <class RV, class XV, int scalar_x, class SizeType>
116-
struct V_Scal_Functor<RV, typename XV::non_const_value_type, XV, scalar_x,
117-
SizeType> {
118-
typedef SizeType size_type;
119-
typedef Kokkos::ArithTraits<typename RV::non_const_value_type> ATS;
99+
/*! \brief
120100
121-
RV m_r;
122-
XV m_x;
123-
const typename XV::non_const_value_type m_a;
101+
r(i) = av * x(i)
102+
r(i) = av() * x(i)
124103
125-
V_Scal_Functor(const RV& r, const XV& x,
126-
const typename XV::non_const_value_type& a,
127-
const SizeType /* startingColumn */)
128-
: m_r(r), m_x(x), m_a(a) {}
104+
\param space
105+
\param r
106+
\param av
107+
\param x
108+
\param alphaHint A KokkosKernels::Impl::ScalarHint corresponding to the value of av. If not KokkosKernels::Impl::ScalarHint::none, may be used to optimize the implementation
109+
110+
\tparam SizeType
111+
\tparam ExecutionSpace
112+
\tparam RV
113+
\tparam AV
114+
\tparam XV
115+
116+
*/
117+
template <typename SizeType, typename ExecutionSpace, typename RV, typename AV, typename XV>
118+
void V_Scal_Generic(const ExecutionSpace& space, const RV& r, const AV& av,
119+
const XV& x,
120+
const KokkosKernels::Impl::ScalarHint &alphaHint = KokkosKernels::Impl::ScalarHint::none) {
121+
122+
// TODO: assert some things about AV
129123

130-
KOKKOS_INLINE_FUNCTION
131-
void operator()(const size_type& i) const {
132-
if (scalar_x == 0) {
133-
m_r(i) = ATS::zero();
134-
}
135-
if (scalar_x == -1) {
136-
m_r(i) = -m_x(i);
137-
}
138-
if (scalar_x == 1) {
139-
m_r(i) = m_x(i);
140-
}
141-
if (scalar_x == 2) {
142-
m_r(i) = m_a * m_x(i);
143-
}
144-
}
145-
};
146-
#endif
147-
148-
// Variant of MV_Scal_Generic for single vectors (1-D Views) r and x.
149-
// As above, av is either a 1-D View (and only its first entry will be
150-
// read), or a scalar.
151-
template <class execution_space, class RV, class AV, class XV, class SizeType>
152-
void V_Scal_Generic(const execution_space& space, const RV& r, const AV& av,
153-
const XV& x,
154-
const SizeType startingColumn,
155-
const KokkosKernels::Impl::ScalarHint &alphaHint) {
156124
static_assert(Kokkos::is_view<RV>::value,
157125
"V_Scal_Generic: RV is not a Kokkos::View.");
158126
static_assert(Kokkos::is_view<XV>::value,
@@ -161,26 +129,26 @@ void V_Scal_Generic(const execution_space& space, const RV& r, const AV& av,
161129
static_assert(XV::rank == 1, "V_Scal_Generic: XV is not rank 1.");
162130

163131
const SizeType numRows = x.extent(0);
164-
Kokkos::RangePolicy<execution_space, SizeType> policy(space, 0, numRows);
132+
Kokkos::RangePolicy<ExecutionSpace, SizeType> policy(space, 0, numRows);
165133

166134
if (alphaHint == KokkosKernels::Impl::ScalarHint::zero) {
167-
V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::zero, SizeType> op(r, x, av, startingColumn);
168-
Kokkos::parallel_for("KokkosBlas::Scal::S0", policy, op);
135+
V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::zero, SizeType> op(r, x, av);
136+
Kokkos::parallel_for("KokkosBlas::Scal::0", policy, op);
169137
return;
170138
}
171139
else if (alphaHint == KokkosKernels::Impl::ScalarHint::neg_one) {
172-
V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::neg_one, SizeType> op(r, x, av, startingColumn);
173-
Kokkos::parallel_for("KokkosBlas::Scal::S1", policy, op);
140+
V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::neg_one, SizeType> op(r, x, av);
141+
Kokkos::parallel_for("KokkosBlas::Scal::-1", policy, op);
174142
return;
175143
}
176144
else if (alphaHint == KokkosKernels::Impl::ScalarHint::pos_one) {
177-
V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::pos_one, SizeType> op(r, x, av, startingColumn);
178-
Kokkos::parallel_for("KokkosBlas::Scal::S2", policy, op);
145+
V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::pos_one, SizeType> op(r, x, av);
146+
Kokkos::parallel_for("KokkosBlas::Scal::1", policy, op);
179147
return;
180148
}
181149

182-
V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::none, SizeType> op(r, x, av, startingColumn);
183-
Kokkos::parallel_for("KokkosBlas::Scal::S3", policy, op);
150+
V_Scal_Functor<RV, AV, XV, KokkosKernels::Impl::ScalarHint::none, SizeType> op(r, x, av);
151+
Kokkos::parallel_for("KokkosBlas::Scal::none", policy, op);
184152
}
185153

186154
} // namespace Impl

blas/impl/KokkosBlas1_scal_mv_impl.hpp

+22-11
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ void MV_Scal_Generic(const execution_space& space, const RVector& r,
422422
template <class execution_space, class RMV, class AV, class XMV, class SizeType>
423423
void MV_Scal_Invoke_Left(const execution_space& space, const RMV& r,
424424
const AV& av, const XMV& x,
425-
const KokkosKernels::Impl::ScalarHint &a = KokkosKernels::Impl::ScalarHint::none) {
425+
const KokkosKernels::Impl::ScalarHint &aHint = KokkosKernels::Impl::ScalarHint::none) {
426426
const SizeType numCols = x.extent(1);
427427

428428
#if KOKKOSBLAS_OPTIMIZATION_LEVEL_SCAL <= 2
@@ -440,7 +440,7 @@ void MV_Scal_Invoke_Left(const execution_space& space, const RMV& r,
440440
typedef decltype(R_cur) RMV2D;
441441

442442
MV_Scal_Unrolled<execution_space, RMV2D, AV, XMV2D, 8, SizeType>(
443-
space, R_cur, av, X_cur, j, a);
443+
space, R_cur, av, X_cur, j, aHint);
444444
}
445445
for (; j + 4 <= numCols; j += 4) {
446446
const std::pair<SizeType, SizeType> rng(j, j + 4);
@@ -450,7 +450,7 @@ void MV_Scal_Invoke_Left(const execution_space& space, const RMV& r,
450450
typedef decltype(R_cur) RMV2D;
451451

452452
MV_Scal_Unrolled<execution_space, RMV2D, AV, XMV2D, 4, SizeType>(
453-
space, R_cur, av, X_cur, j, a);
453+
space, R_cur, av, X_cur, j, aHint);
454454
}
455455
for (; j < numCols; ++j) {
456456
// RMV and XMV need to turn 1-D.
@@ -459,8 +459,21 @@ void MV_Scal_Invoke_Left(const execution_space& space, const RMV& r,
459459
typedef decltype(r_cur) RV;
460460
typedef decltype(x_cur) XV;
461461

462-
V_Scal_Generic<execution_space, RV, AV, XV, SizeType>(space, r_cur, av,
463-
x_cur, j, a);
462+
// If AV is a rank-one vector, get a rank-0 subview
463+
// Otherwise, just pass along AV as-is
464+
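// The checks are nested because `if constexpr` conditions cannot short-circuit: AV::rank must only be named once AV is known to be a View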
465+
if constexpr (Kokkos::is_view_v<AV>) {
466+
if constexpr (AV::rank == 1) {
467+
auto a_cur = Kokkos::subview(av, j);
468+
V_Scal_Generic<SizeType>(space, r_cur, a_cur, x_cur, aHint);
469+
} else {
470+
V_Scal_Generic<SizeType>(space, r_cur, av, x_cur, aHint);
471+
}
472+
} else {
473+
V_Scal_Generic<SizeType>(space, r_cur, av, x_cur, aHint);
474+
}
475+
476+
464477
}
465478

466479
#else // KOKKOSBLAS_OPTIMIZATION_LEVEL_SCAL > 2
@@ -472,7 +485,7 @@ void MV_Scal_Invoke_Left(const execution_space& space, const RMV& r,
472485
typedef decltype(r_0) RV;
473486
typedef decltype(x_0) XV;
474487

475-
V_Scal_Generic<execution_space, RV, AV, XV, SizeType>(space, r_0, av, x_0,
488+
V_Scal_Generic<SizeType>(space, r_0, av, x_0,
476489
0, a);
477490
break;
478491
}
@@ -537,7 +550,7 @@ void MV_Scal_Invoke_Left(const execution_space& space, const RMV& r,
537550
space, r, av, x, 0, a);
538551
break;
539552
default:
540-
MV_Scal_Generic<execution_space, RMV, AV, XMV, SizeType>(space, r, av, x,
553+
MV_Scal_Generic<SizeType>(space, r, av, x,
541554
0, a);
542555
}
543556

@@ -574,11 +587,9 @@ void MV_Scal_Invoke_Right(const execution_space& space, const RMV& r,
574587

575588
RV r_0 = Kokkos::subview(r, Kokkos::ALL(), 0);
576589
XV x_0 = Kokkos::subview(x, Kokkos::ALL(), 0);
577-
V_Scal_Generic<execution_space, RMV, aVector, XMV, 1, SizeType>(space, r_0,
578-
av, x_0, a);
590+
V_Scal_Generic<SizeType>(space, r_0, av, x_0, a);
579591
} else {
580-
MV_Scal_Generic<execution_space, RMV, aVector, XMV, SizeType>(space, r, av,
581-
x, a);
592+
MV_Scal_Generic<SizeType>(space, r, av, x, a);
582593
}
583594
}
584595

blas/impl/KokkosBlas1_scal_spec.hpp

+66-18
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,12 @@ struct scal_eti_spec_avail {
3939

4040
//
4141
// Macro for declaration of full specialization availability
42-
// KokkosBlas::Impl::Scal for rank == 1. This is NOT for users!!! All
42+
// KokkosBlas::Impl::Scal for rank == 1 R and X. This is NOT for users!!! All
4343
// the declarations of full specializations go in this header file.
4444
// We may spread out definitions (see _INST macro below) across one or
4545
// more .cpp files.
4646
//
47+
// Alpha can either be scalar or rank 0
4748
#define KOKKOSBLAS1_SCAL_ETI_SPEC_AVAIL(SCALAR, LAYOUT, EXEC_SPACE, MEM_SPACE) \
4849
template <> \
4950
struct scal_eti_spec_avail< \
@@ -56,15 +57,28 @@ struct scal_eti_spec_avail {
5657
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
5758
1> { \
5859
enum : bool { value = true }; \
60+
}; \
61+
template <> \
62+
struct scal_eti_spec_avail< \
63+
EXEC_SPACE, \
64+
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
65+
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
66+
Kokkos::View<SCALAR, LAYOUT, Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
67+
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
68+
Kokkos::View<const SCALAR*, LAYOUT, \
69+
Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
70+
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
71+
1> { \
72+
enum : bool { value = true }; \
5973
};
60-
6174
//
6275
// Macro for declaration of full specialization availability
63-
// KokkosBlas::Impl::Scal for rank == 2. This is NOT for users!!! All
76+
// KokkosBlas::Impl::Scal for rank == 2 R and X. This is NOT for users!!! All
6477
// the declarations of full specializations go in this header file.
6578
// We may spread out definitions (see _DEF macro below) across one or
6679
// more .cpp files.
6780
//
81+
// Alpha can either be rank 1, rank 0, or scalar
6882
#define KOKKOSBLAS1_SCAL_MV_ETI_SPEC_AVAIL(SCALAR, LAYOUT, EXEC_SPACE, \
6983
MEM_SPACE) \
7084
template <> \
@@ -82,6 +96,20 @@ struct scal_eti_spec_avail {
8296
enum : bool { value = true }; \
8397
}; \
8498
template <> \
99+
struct scal_eti_spec_avail< \
100+
EXEC_SPACE, \
101+
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
102+
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
103+
Kokkos::View<const SCALAR, LAYOUT, \
104+
Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
105+
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
106+
Kokkos::View<const SCALAR**, LAYOUT, \
107+
Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
108+
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
109+
2> { \
110+
enum : bool { value = true }; \
111+
}; \
112+
template <> \
85113
struct scal_eti_spec_avail< \
86114
EXEC_SPACE, \
87115
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
@@ -151,7 +179,7 @@ struct Scal<execution_space, RV, typename XV::non_const_value_type, XV, 1,
151179
typeid(RV).name(), typeid(AV).name(), typeid(XV).name());
152180
#endif
153181

154-
const size_type numRows = X.extent(0);
182+
155183
KokkosKernels::Impl::ScalarHint alphaHint = KokkosKernels::Impl::ScalarHint::none;
156184
if (alpha == ATA::zero()) {
157185
alphaHint = KokkosKernels::Impl::ScalarHint::zero;
@@ -161,25 +189,22 @@ struct Scal<execution_space, RV, typename XV::non_const_value_type, XV, 1,
161189
alphaHint = KokkosKernels::Impl::ScalarHint::pos_one;
162190
}
163191

192+
const size_type numRows = X.extent(0);
164193
if (numRows < static_cast<size_type>(INT_MAX)) {
165-
typedef int index_type;
166-
V_Scal_Generic<execution_space, RV, AV, XV, index_type>(space, R, alpha,
167-
X, 0, alphaHint);
194+
V_Scal_Generic<int>(space, R, alpha, X, alphaHint);
168195
} else {
169-
typedef typename XV::size_type index_type;
170-
V_Scal_Generic<execution_space, RV, AV, XV, index_type>(space, R, alpha,
171-
X, 0, alphaHint);
196+
V_Scal_Generic<typename XV::size_type>(space, R, alpha, X, alphaHint);
172197
}
173198
Kokkos::Profiling::popRegion();
174199
}
175200
};
176201

177-
/// \brief Partial specialization of Scal for 2-D Views and 1-D View AV.
202+
/// \brief Partial specialization of Scal for 2-D Views and 1-D, 0-D, or scalar AV.
178203
///
179204
/// Compute any of the following:
180-
///
181-
/// 1. R(i,j) = a*X(i,j) for a in -1,0,1
182-
/// 2. R(i,j) = alpha(j)*X(i,j)
205+
/// 1. R(i,j) = av * X(i,j)
206+
/// 2. R(i,j) = av() * X(i,j)
207+
/// 3. R(i,j) = av(j) * X(i,j)
183208
template <class execution_space, class RMV, class AV, class XMV>
184209
struct Scal<execution_space, RMV, AV, XMV, 2, false,
185210
KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {
@@ -189,6 +214,9 @@ struct Scal<execution_space, RMV, AV, XMV, 2, false,
189214

190215
static void scal(const execution_space& space, const RMV& R, const AV& av,
191216
const XMV& X) {
217+
218+
// TODO: assert some things about AV
219+
192220
static_assert(Kokkos::is_view<RMV>::value,
193221
"KokkosBlas::Impl::"
194222
"Scal<2-D>: RMV is not a Kokkos::View.");
@@ -201,9 +229,6 @@ struct Scal<execution_space, RMV, AV, XMV, 2, false,
201229
static_assert(RMV::rank == 2,
202230
"KokkosBlas::Impl::Scal<2-D>: "
203231
"RMV is not rank 2.");
204-
static_assert(AV::rank == 1,
205-
"KokkosBlas::Impl::Scal<2-D>: "
206-
"AV is not rank 1.");
207232
static_assert(XMV::rank == 2,
208233
"KokkosBlas::Impl::Scal<2-D>: "
209234
"XMV is not rank 2.");
@@ -312,17 +337,29 @@ struct Scal<execution_space, RMV, typename XMV::non_const_value_type, XMV, 2,
312337

313338
//
314339
// Macro for declaration of full specialization of
315-
// KokkosBlas::Impl::Scal for rank == 2. This is NOT for users!!! All
340+
// KokkosBlas::Impl::Scal for rank == 1. This is NOT for users!!! All
316341
// the declarations of full specializations go in this header file.
317342
// We may spread out definitions (see _DEF macro below) across one or
318343
// more .cpp files.
319344
//
345+
// alpha can be either scalar or rank 0
320346
#define KOKKOSBLAS1_SCAL_ETI_SPEC_DECL(SCALAR, LAYOUT, EXEC_SPACE, MEM_SPACE) \
321347
extern template struct Scal< \
322348
EXEC_SPACE, \
323349
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
324350
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
325351
SCALAR, \
352+
Kokkos::View<const SCALAR*, LAYOUT, \
353+
Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
354+
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
355+
1, false, true>; \
356+
extern template struct Scal< \
357+
EXEC_SPACE, \
358+
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
359+
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
360+
Kokkos::View<const SCALAR, LAYOUT, \
361+
Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
362+
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
326363
Kokkos::View<const SCALAR*, LAYOUT, \
327364
Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
328365
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
@@ -334,6 +371,17 @@ struct Scal<execution_space, RMV, typename XMV::non_const_value_type, XMV, 2,
334371
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
335372
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
336373
SCALAR, \
374+
Kokkos::View<const SCALAR*, LAYOUT, \
375+
Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
376+
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
377+
1, false, true>; \
378+
template struct Scal< \
379+
EXEC_SPACE, \
380+
Kokkos::View<SCALAR*, LAYOUT, Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
381+
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
382+
Kokkos::View<const SCALAR, LAYOUT, \
383+
Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
384+
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
337385
Kokkos::View<const SCALAR*, LAYOUT, \
338386
Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
339387
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \

0 commit comments

Comments
 (0)