Skip to content

Commit 2c871e6

Browse files
committed
Continue propagating changes
1 parent e591166 commit 2c871e6

8 files changed

+488
-165
lines changed

librapid/include/librapid/array/arrayContainer.hpp

+145-85
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ namespace librapid {
9292
using Scalar = typename StorageType::Scalar;
9393
using Packet = typename typetraits::TypeInfo<Scalar>::Packet;
9494
using Backend = typename typetraits::TypeInfo<ArrayContainer>::Backend;
95-
using Iterator = detail::ArrayIterator<GeneralArrayView<ArrayContainer>>;
95+
using Iterator = detail::ArrayIterator<GeneralArrayView<ArrayContainer, ShapeType>>;
9696

9797
using DirectSubscriptType = typename detail::SubscriptType<StorageType>::Direct;
9898
using DirectRefSubscriptType = typename detail::SubscriptType<StorageType>::Ref;
@@ -132,12 +132,16 @@ namespace librapid {
132132

133133
/// Constructs an array container from a shape
134134
/// \param shape The shape of the array container
135-
LIBRAPID_ALWAYS_INLINE explicit ArrayContainer(const ShapeType &shape);
135+
LIBRAPID_ALWAYS_INLINE explicit ArrayContainer(const Shape &shape);
136+
LIBRAPID_ALWAYS_INLINE explicit ArrayContainer(const MatrixShape &shape);
137+
LIBRAPID_ALWAYS_INLINE explicit ArrayContainer(const VectorShape &shape);
136138

137139
/// Create an array container from a shape and a scalar value. The scalar value
138140
/// represents the value the memory is initialized with. \param shape The shape of the
139141
/// array container \param value The value to initialize the memory with
140-
LIBRAPID_ALWAYS_INLINE ArrayContainer(const ShapeType &shape, const Scalar &value);
142+
LIBRAPID_ALWAYS_INLINE ArrayContainer(const Shape &shape, const Scalar &value);
143+
LIBRAPID_ALWAYS_INLINE ArrayContainer(const MatrixShape &shape, const Scalar &value);
144+
LIBRAPID_ALWAYS_INLINE ArrayContainer(const VectorShape &shape, const Scalar &value);
141145

142146
/// Allows for a fixed-size array to be constructed with a fill value
143147
/// \param value The value to fill the array with
@@ -369,7 +373,7 @@ namespace librapid {
369373

370374
template<typename ShapeType_, typename StorageType_>
371375
LIBRAPID_ALWAYS_INLINE
372-
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const ShapeType &shape) :
376+
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const Shape &shape) :
373377
m_shape(shape),
374378
m_size(shape.size()), m_storage(m_size) {
375379
static_assert(!typetraits::IsFixedStorage<StorageType_>::value,
@@ -380,7 +384,67 @@ namespace librapid {
380384

381385
template<typename ShapeType_, typename StorageType_>
382386
LIBRAPID_ALWAYS_INLINE
383-
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const ShapeType &shape,
387+
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const MatrixShape &shape) :
388+
m_shape(shape),
389+
m_size(shape.size()), m_storage(m_size) {
390+
static_assert(!typetraits::IsFixedStorage<StorageType_>::value,
391+
"For a compile-time-defined shape, "
392+
"the storage type must be "
393+
"a FixedStorage object");
394+
}
395+
396+
template<typename ShapeType_, typename StorageType_>
397+
LIBRAPID_ALWAYS_INLINE
398+
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const VectorShape &shape) :
399+
m_shape(shape),
400+
m_size(shape.size()), m_storage(m_size) {
401+
static_assert(!typetraits::IsFixedStorage<StorageType_>::value,
402+
"For a compile-time-defined shape, "
403+
"the storage type must be "
404+
"a FixedStorage object");
405+
}
406+
407+
template<typename ShapeType_, typename StorageType_>
408+
LIBRAPID_ALWAYS_INLINE
409+
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const Shape &shape,
410+
const Scalar &value) :
411+
m_shape(shape),
412+
m_size(shape.size()), m_storage(m_size, value) {
413+
static_assert(typetraits::IsStorage<StorageType_>::value ||
414+
typetraits::IsOpenCLStorage<StorageType_>::value ||
415+
typetraits::IsCudaStorage<StorageType_>::value,
416+
"For a runtime-defined shape, "
417+
"the storage type must be "
418+
"either a Storage or a "
419+
"CudaStorage object");
420+
static_assert(!typetraits::IsFixedStorage<StorageType_>::value,
421+
"For a compile-time-defined shape, "
422+
"the storage type must be "
423+
"a FixedStorage object");
424+
}
425+
426+
template<typename ShapeType_, typename StorageType_>
427+
LIBRAPID_ALWAYS_INLINE
428+
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const MatrixShape &shape,
429+
const Scalar &value) :
430+
m_shape(shape),
431+
m_size(shape.size()), m_storage(m_size, value) {
432+
static_assert(typetraits::IsStorage<StorageType_>::value ||
433+
typetraits::IsOpenCLStorage<StorageType_>::value ||
434+
typetraits::IsCudaStorage<StorageType_>::value,
435+
"For a runtime-defined shape, "
436+
"the storage type must be "
437+
"either a Storage or a "
438+
"CudaStorage object");
439+
static_assert(!typetraits::IsFixedStorage<StorageType_>::value,
440+
"For a compile-time-defined shape, "
441+
"the storage type must be "
442+
"a FixedStorage object");
443+
}
444+
445+
template<typename ShapeType_, typename StorageType_>
446+
LIBRAPID_ALWAYS_INLINE
447+
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const VectorShape &shape,
384448
const Scalar &value) :
385449
m_shape(shape),
386450
m_size(shape.size()), m_storage(m_size, value) {
@@ -525,44 +589,42 @@ namespace librapid {
525589
index,
526590
m_shape[0]);
527591

528-
if constexpr (typetraits::IsOpenCLStorage<StorageType_>::value) {
529-
#if defined(LIBRAPID_HAS_OPENCL)
530-
ArrayContainer res;
531-
res.m_shape = m_shape.subshape(1, ndim());
532-
auto subSize = res.shape().size();
533-
int64_t storageSize = sizeof(typename StorageType_::Scalar);
534-
cl_buffer_region region {index * subSize * storageSize, subSize * storageSize};
535-
res.m_storage =
536-
StorageType_(m_storage.data().createSubBuffer(
537-
StorageType_::bufferFlags, CL_BUFFER_CREATE_TYPE_REGION, &region),
538-
subSize,
539-
false);
540-
return res;
541-
#else
542-
LIBRAPID_ERROR("OpenCL support not enabled");
543-
#endif // LIBRAPID_HAS_OPENCL
544-
} else if constexpr (typetraits::IsCudaStorage<StorageType_>::value) {
545-
#if defined(LIBRAPID_HAS_CUDA)
546-
ArrayContainer res;
547-
res.m_shape = m_shape.subshape(1, ndim());
548-
auto subSize = res.shape().size();
549-
Scalar *begin = m_storage.begin().get() + index * subSize;
550-
res.m_storage = StorageType_(begin, subSize, false);
551-
return res;
552-
#else
553-
LIBRAPID_ERROR("CUDA support not enabled");
554-
#endif // LIBRAPID_HAS_CUDA
555-
} else if constexpr (typetraits::IsFixedStorage<StorageType_>::value) {
556-
return GeneralArrayView(*this)[index];
557-
} else {
558-
ArrayContainer res;
559-
res.m_shape = m_shape.subshape(1, ndim());
560-
auto subSize = res.shape().size();
561-
Scalar *begin = m_storage.begin() + index * subSize;
562-
Scalar *end = begin + subSize;
563-
res.m_storage = StorageType_(begin, end, false);
564-
return res;
565-
}
592+
return createGeneralArrayView(*this)[index];
593+
594+
// if constexpr (typetraits::IsOpenCLStorage<StorageType_>::value) {
595+
// #if defined(LIBRAPID_HAS_OPENCL)
596+
// ArrayContainer res;
597+
// res.m_shape = m_shape.subshape(1, ndim());
598+
// auto subSize = res.shape().size();
599+
// int64_t storageSize = sizeof(typename StorageType_::Scalar);
600+
// cl_buffer_region region {index * subSize * storageSize, subSize *
601+
// storageSize}; res.m_storage =
602+
// StorageType_(m_storage.data().createSubBuffer(
603+
// StorageType_::bufferFlags,
604+
// CL_BUFFER_CREATE_TYPE_REGION, &region), subSize,
605+
// false); return res; #else LIBRAPID_ERROR("OpenCL support
606+
// not enabled"); #endif // LIBRAPID_HAS_OPENCL } else if constexpr
607+
//(typetraits::IsCudaStorage<StorageType_>::value) { #if defined(LIBRAPID_HAS_CUDA)
608+
// ArrayContainer res;
609+
// res.m_shape = m_shape.subshape(1, ndim());
610+
// auto subSize = res.shape().size();
611+
// Scalar *begin = m_storage.begin().get() + index * subSize;
612+
// res.m_storage = StorageType_(begin, subSize, false);
613+
// return res;
614+
// #else
615+
// LIBRAPID_ERROR("CUDA support not enabled");
616+
// #endif // LIBRAPID_HAS_CUDA
617+
// } else if constexpr (typetraits::IsFixedStorage<StorageType_>::value) {
618+
// return GeneralArrayView(*this)[index];
619+
// } else {
620+
// ArrayContainer res;
621+
// res.m_shape = m_shape.subshape(1, ndim());
622+
// auto subSize = res.shape().size();
623+
// Scalar *begin = m_storage.begin() + index * subSize;
624+
// Scalar *end = begin + subSize;
625+
// res.m_storage = StorageType_(begin, end, false);
626+
// return res;
627+
// }
566628
}
567629

568630
template<typename ShapeType_, typename StorageType_>
@@ -574,44 +636,42 @@ namespace librapid {
574636
index,
575637
m_shape[0]);
576638

577-
if constexpr (typetraits::IsOpenCLStorage<StorageType_>::value) {
578-
#if defined(LIBRAPID_HAS_OPENCL)
579-
ArrayContainer res;
580-
res.m_shape = m_shape.subshape(1, ndim());
581-
auto subSize = res.shape().size();
582-
int64_t storageSize = sizeof(typename StorageType_::Scalar);
583-
cl_buffer_region region {index * subSize * storageSize, subSize * storageSize};
584-
res.m_storage =
585-
StorageType_(m_storage.data().createSubBuffer(
586-
StorageType_::bufferFlags, CL_BUFFER_CREATE_TYPE_REGION, &region),
587-
subSize,
588-
false);
589-
return res;
590-
#else
591-
LIBRAPID_ERROR("OpenCL support not enabled");
592-
#endif // LIBRAPID_HAS_OPENCL
593-
} else if constexpr (typetraits::IsCudaStorage<StorageType_>::value) {
594-
#if defined(LIBRAPID_HAS_CUDA)
595-
ArrayContainer res;
596-
res.m_shape = m_shape.subshape(1, ndim());
597-
auto subSize = res.shape().size();
598-
Scalar *begin = m_storage.begin().get() + index * subSize;
599-
res.m_storage = StorageType_(begin, subSize, false);
600-
return res;
601-
#else
602-
LIBRAPID_ERROR("CUDA support not enabled");
603-
#endif // LIBRAPID_HAS_CUDA
604-
} else if constexpr (typetraits::IsFixedStorage<StorageType_>::value) {
605-
return GeneralArrayView(*this)[index];
606-
} else {
607-
ArrayContainer res;
608-
res.m_shape = m_shape.subshape(1, ndim());
609-
auto subSize = res.shape().size();
610-
Scalar *begin = m_storage.begin() + index * subSize;
611-
Scalar *end = begin + subSize;
612-
res.m_storage = StorageType_(begin, end, false);
613-
return res;
614-
}
639+
return createGeneralArrayView(*this)[index];
640+
641+
// if constexpr (typetraits::IsOpenCLStorage<StorageType_>::value) {
642+
// #if defined(LIBRAPID_HAS_OPENCL)
643+
// ArrayContainer res;
644+
// res.m_shape = m_shape.subshape(1, ndim());
645+
// auto subSize = res.shape().size();
646+
// int64_t storageSize = sizeof(typename StorageType_::Scalar);
647+
// cl_buffer_region region {index * subSize * storageSize, subSize *
648+
// storageSize}; res.m_storage =
649+
// StorageType_(m_storage.data().createSubBuffer(
650+
// StorageType_::bufferFlags,
651+
// CL_BUFFER_CREATE_TYPE_REGION, &region), subSize,
652+
// false); return res; #else LIBRAPID_ERROR("OpenCL support
653+
// not enabled"); #endif // LIBRAPID_HAS_OPENCL } else if constexpr
654+
//(typetraits::IsCudaStorage<StorageType_>::value) { #if defined(LIBRAPID_HAS_CUDA)
655+
// ArrayContainer res;
656+
// res.m_shape = m_shape.subshape(1, ndim());
657+
// auto subSize = res.shape().size();
658+
// Scalar *begin = m_storage.begin().get() + index * subSize;
659+
// res.m_storage = StorageType_(begin, subSize, false);
660+
// return res;
661+
// #else
662+
// LIBRAPID_ERROR("CUDA support not enabled");
663+
// #endif // LIBRAPID_HAS_CUDA
664+
// } else if constexpr (typetraits::IsFixedStorage<StorageType_>::value) {
665+
// return GeneralArrayView(*this)[index];
666+
// } else {
667+
// ArrayContainer res;
668+
// res.m_shape = m_shape.subshape(1, ndim());
669+
// auto subSize = res.shape().size();
670+
// Scalar *begin = m_storage.begin() + index * subSize;
671+
// Scalar *end = begin + subSize;
672+
// res.m_storage = StorageType_(begin, end, false);
673+
// return res;
674+
// }
615675
}
616676

617677
template<typename ShapeType_, typename StorageType_>
@@ -854,8 +914,8 @@ namespace librapid {
854914
static constexpr bool val = false;
855915
};
856916

857-
template<typename T>
858-
struct IsArrayType<ArrayRef<T>> {
917+
template<typename T, typename V>
918+
struct IsArrayType<ArrayRef<T, V>> {
859919
static constexpr bool val = true;
860920
};
861921

@@ -864,8 +924,8 @@ namespace librapid {
864924
static constexpr bool val = true;
865925
};
866926

867-
template<typename T>
868-
struct IsArrayType<array::GeneralArrayView<T>> {
927+
template<typename T, typename S>
928+
struct IsArrayType<array::GeneralArrayView<T, S>> {
869929
static constexpr bool val = true;
870930
};
871931

librapid/include/librapid/array/arrayTypeDef.hpp

+12-5
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ namespace librapid {
2828
/// `backend::CPU`, `backend::CUDA` or any Storage interface
2929
/// \tparam Scalar The scalar type of the array.
3030
/// \tparam StorageType The storage type of the array.
31-
template<typename Scalar, typename StorageType = backend::CPU, typename ShapeType = Shape,
32-
typename std::enable_if_t<typetraits::IsSizeType<ShapeType>::value, int> = 0>
31+
template<typename Scalar, typename StorageType = backend::CPU>
3332
using Array =
3433
array::ArrayContainer<Shape,
3534
typename detail::TypeDefStorageEvaluator<Scalar, StorageType>::Type>;
@@ -45,8 +44,8 @@ namespace librapid {
4544
/// the compiler cannot determine the templates tingle for the Array typedef. For more
4645
/// granularity, you can also accept a raw ArrayContainer object. \tparam StorageType The
4746
/// storage type of the array. \see Array \see ArrayF \see Function \see FunctionRef
48-
template<typename StorageType>
49-
using ArrayRef = array::ArrayContainer<Shape, StorageType>;
47+
template<typename ShapeType, typename StorageType>
48+
using ArrayRef = array::ArrayContainer<ShapeType, StorageType>;
5049

5150
template<typename Scalar, typename Backend = backend::CPU>
5251
using Matrix =
@@ -56,6 +55,14 @@ namespace librapid {
5655
template<typename Scalar, size_t... Dimensions>
5756
using MatrixF = array::ArrayContainer<MatrixShape, FixedStorage<Scalar, Dimensions...>>;
5857

58+
template<typename Scalar, typename Backend = backend::CPU>
59+
using Array1D =
60+
array::ArrayContainer<VectorShape,
61+
typename detail::TypeDefStorageEvaluator<Scalar, Backend>::Type>;
62+
63+
template<typename Scalar, size_t... Dimensions>
64+
using Array1DF = array::ArrayContainer<VectorShape, FixedStorage<Scalar, Dimensions...>>;
65+
5966
/// A reference type for Array Function objects. Use this to accept Function objects as
6067
/// parameters since the compiler cannot determine the templates for the typedef by default.
6168
/// Additionally, this can be used to store references to Function objects.
@@ -70,7 +77,7 @@ namespace librapid {
7077
namespace array {
7178
/// An intermediate type to represent a slice or view of an array.
7279
/// \tparam T The type of the array.
73-
template<typename T>
80+
template<typename T, typename ShapeType_>
7481
class GeneralArrayView;
7582

7683
template<typename T>

0 commit comments

Comments
 (0)