Skip to content

Commit

Permalink
Continue refactoring library to use new changes/optimisations
Browse files Browse the repository at this point in the history
  • Loading branch information
Pencilcaseman committed Aug 21, 2023
1 parent f918959 commit e591166
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 179 deletions.
21 changes: 10 additions & 11 deletions librapid/include/librapid/array/arrayContainer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,11 @@ namespace librapid {
template<typename T>
struct IsArrayContainer : std::false_type {};

template<size_t dims, typename StorageScalar>
struct IsArrayContainer<array::ArrayContainer<Shape<dims>, StorageScalar>>
: std::true_type {};
template<typename StorageScalar>
struct IsArrayContainer<array::ArrayContainer<Shape, StorageScalar>> : std::true_type {};

LIBRAPID_DEFINE_AS_TYPE(size_t dims COMMA typename StorageScalar,
array::ArrayContainer<Shape<dims> COMMA StorageScalar>);
LIBRAPID_DEFINE_AS_TYPE(typename StorageScalar,
array::ArrayContainer<Shape COMMA StorageScalar>);

LIBRAPID_DEFINE_AS_TYPE(typename StorageScalar,
array::ArrayContainer<MatrixShape COMMA StorageScalar>);
Expand All @@ -88,7 +87,7 @@ namespace librapid {
public:
using StorageType = StorageType_;
using ShapeType = ShapeType_;
using StrideType = Stride<32>;
using StrideType = Stride;
using SizeType = typename ShapeType::SizeType;
using Scalar = typename StorageType::Scalar;
using Packet = typename typetraits::TypeInfo<Scalar>::Packet;
Expand Down Expand Up @@ -582,11 +581,11 @@ namespace librapid {
auto subSize = res.shape().size();
int64_t storageSize = sizeof(typename StorageType_::Scalar);
cl_buffer_region region {index * subSize * storageSize, subSize * storageSize};
res.m_storage.set(
res.m_storage =
StorageType_(m_storage.data().createSubBuffer(
StorageType_::bufferFlags, CL_BUFFER_CREATE_TYPE_REGION, &region),
subSize,
false));
false);
return res;
#else
LIBRAPID_ERROR("OpenCL support not enabled");
Expand All @@ -597,7 +596,7 @@ namespace librapid {
res.m_shape = m_shape.subshape(1, ndim());
auto subSize = res.shape().size();
Scalar *begin = m_storage.begin().get() + index * subSize;
res.m_storage.set(StorageType_(begin, subSize, false));
res.m_storage = StorageType_(begin, subSize, false);
return res;
#else
LIBRAPID_ERROR("CUDA support not enabled");
Expand All @@ -610,7 +609,7 @@ namespace librapid {
auto subSize = res.shape().size();
Scalar *begin = m_storage.begin() + index * subSize;
Scalar *end = begin + subSize;
res.m_storage.set(StorageType_(begin, end, false));
res.m_storage = StorageType_(begin, end, false);
return res;
}
}
Expand Down Expand Up @@ -845,7 +844,7 @@ namespace librapid {
template<typename T, typename Char, typename Ctx>
LIBRAPID_ALWAYS_INLINE void ArrayContainer<ShapeType_, StorageType_>::str(
const fmt::formatter<T, Char> &format, char bracket, char separator, Ctx &ctx) const {
GeneralArrayView(*this).str(format, bracket, separator, ctx);
createGeneralArrayView(*this).str(format, bracket, separator, ctx);
}
} // namespace array

Expand Down
10 changes: 5 additions & 5 deletions librapid/include/librapid/array/arrayTypeDef.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,25 @@ namespace librapid {
/// `backend::CPU`, `backend::CUDA` or any Storage interface
/// \tparam Scalar The scalar type of the array.
/// \tparam StorageType The storage type of the array.
template<typename Scalar, typename StorageType = backend::CPU>
template<typename Scalar, typename StorageType = backend::CPU, typename ShapeType = Shape,
typename std::enable_if_t<typetraits::IsSizeType<ShapeType>::value, int> = 0>
using Array =
array::ArrayContainer<Shape<LIBRAPID_MAX_ARRAY_DIMS>,
array::ArrayContainer<Shape,
typename detail::TypeDefStorageEvaluator<Scalar, StorageType>::Type>;

/// A definition for fixed-size array objects.
/// \tparam Scalar The scalar type of the array.
/// \tparam Dimensions The dimensions of the array.
/// \see Array
template<typename Scalar, size_t... Dimensions>
using ArrayF =
array::ArrayContainer<Shape<LIBRAPID_MAX_ARRAY_DIMS>, FixedStorage<Scalar, Dimensions...>>;
using ArrayF = array::ArrayContainer<Shape, FixedStorage<Scalar, Dimensions...>>;

/// A reference type for Array objects. Use this to accept Array objects as parameters since
/// the compiler cannot determine the templates tingle for the Array typedef. For more
/// granularity, you can also accept a raw ArrayContainer object. \tparam StorageType The
/// storage type of the array. \see Array \see ArrayF \see Function \see FunctionRef
template<typename StorageType>
using ArrayRef = array::ArrayContainer<Shape<LIBRAPID_MAX_ARRAY_DIMS>, StorageType>;
using ArrayRef = array::ArrayContainer<Shape, StorageType>;

template<typename Scalar, typename Backend = backend::CPU>
using Matrix =
Expand Down
36 changes: 21 additions & 15 deletions librapid/include/librapid/array/generalArrayView.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,33 @@ namespace librapid {
LIBRAPID_DEFINE_AS_TYPE(typename T, array::GeneralArrayView<T>);
} // namespace typetraits

template<typename T>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto createGeneralArrayView(T &&array) {
return array::GeneralArrayView<T>(std::forward<T>(array));
}

namespace array {
template<typename ArrayViewType>
class GeneralArrayView {
public:
// using ArrayType = T;
using BaseType = typename std::decay_t<ArrayViewType>;
using Scalar = typename typetraits::TypeInfo<BaseType>::Scalar;
using Reference = BaseType &;
using ConstReference = const BaseType &;
using Backend = typename typetraits::TypeInfo<BaseType>::Backend;
using ArrayType = Array<Scalar, Backend>;
using StrideType = typename ArrayType::StrideType;
using ShapeType = typename ArrayType::ShapeType;
using Iterator = detail::ArrayIterator<GeneralArrayView>;
using StrideType = typename BaseType::StrideType;
using ShapeType = typename BaseType::ShapeType;
using StorageType = typename BaseType::StorageType;
// using ArrayType = Array<Scalar, Backend>;
using ArrayType = array::ArrayContainer<ShapeType, StorageType>;
using Iterator = detail::ArrayIterator<GeneralArrayView>;

/// Default constructor should never be used
GeneralArrayView() = delete;

/// Copy an ArrayView object
/// \param array The array to copy
LIBRAPID_ALWAYS_INLINE GeneralArrayView(ArrayViewType &array);
// LIBRAPID_ALWAYS_INLINE GeneralArrayView(ArrayViewType &array);

/// Copy an ArrayView object (not const)
/// \param array The array to copy
Expand All @@ -57,7 +63,7 @@ namespace librapid {
/// Assigns a temporary ArrayView to this ArrayView.
/// \param other The ArrayView to move.
/// \return A reference to this ArrayView.
// ArrayView &operator=(ArrayView &&other) noexcept = default;
GeneralArrayView &operator=(GeneralArrayView &&other) noexcept = default;

/// Assign a scalar value to this ArrayView. This function should only be used to
/// assign to a zero-dimensional "scalar" ArrayView, and will throw an error if used
Expand Down Expand Up @@ -140,17 +146,17 @@ namespace librapid {
Ctx &ctx) const;

private:
ArrayViewType &m_ref;
ArrayViewType m_ref;
ShapeType m_shape;
StrideType m_stride;
int64_t m_offset = 0;
};

template<typename ArrayViewType>
LIBRAPID_ALWAYS_INLINE
GeneralArrayView<ArrayViewType>::GeneralArrayView(ArrayViewType &array) :
m_ref(array),
m_shape(array.shape()), m_stride(array.shape()) {}
// template<typename ArrayViewType>
// LIBRAPID_ALWAYS_INLINE
// GeneralArrayView<ArrayViewType>::GeneralArrayView(ArrayViewType &array) :
// m_ref(array),
// m_shape(array.shape()), m_stride(array.shape()) {}

template<typename ArrayViewType>
LIBRAPID_ALWAYS_INLINE
Expand Down Expand Up @@ -202,7 +208,7 @@ namespace librapid {
"Index {} out of bounds in ArrayContainer::operator[] with leading dimension={}",
index,
m_shape[0]);
GeneralArrayView<T> view(m_ref);
GeneralArrayView view(createGeneralArrayView(m_ref));
const auto stride = Stride(m_shape);
view.setShape(m_shape.subshape(1, ndim()));
if (ndim() == 1)
Expand All @@ -221,7 +227,7 @@ namespace librapid {
"Index {} out of bounds in ArrayContainer::operator[] with leading dimension={}",
index,
m_shape[0]);
GeneralArrayView<T> view(m_ref);
GeneralArrayView view(createGeneralArrayView(m_ref));
const auto stride = Stride(m_shape);
view.setShape(m_shape.subshape(1, ndim()));
if (ndim() == 1)
Expand Down
16 changes: 10 additions & 6 deletions librapid/include/librapid/array/linalg/transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,9 +634,11 @@ namespace librapid {
}
}; // namespace array

template<typename T, typename ShapeType = Shape<size_t, 32>,
typename std::enable_if_t<
typetraits::TypeInfo<T>::type == detail::LibRapidType::ArrayContainer, int> = 0>
template<typename T, typename ShapeType = MatrixShape,
typename std::enable_if_t<typetraits::TypeInfo<T>::type ==
detail::LibRapidType::ArrayContainer &&
typetraits::IsSizeType<ShapeType>::value,
int> = 0>
auto transpose(T &&array, const ShapeType &axes = ShapeType()) {
// If axes is empty, transpose the array in reverse order
ShapeType newAxes = axes;
Expand All @@ -648,9 +650,11 @@ namespace librapid {
return array::Transpose(array, newAxes);
}

template<typename T, typename ShapeType = Shape<size_t, 32>,
typename std::enable_if_t<
typetraits::TypeInfo<T>::type != detail::LibRapidType::ArrayContainer, int> = 0>
template<typename T, typename ShapeType = MatrixShape,
typename std::enable_if_t<typetraits::TypeInfo<T>::type !=
detail::LibRapidType::ArrayContainer &&
typetraits::IsSizeType<ShapeType>::value,
int> = 0>
auto transpose(const T &function, const ShapeType &axes = ShapeType()) {
// If axes is empty, transpose the array in reverse order
auto array = function.eval();
Expand Down
20 changes: 10 additions & 10 deletions librapid/include/librapid/array/pseudoConstructors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ namespace librapid {
/// \tparam N Maximum number of dimensions of the Shape
/// \param shape Shape of the Array
/// \return Array filled with zeros
template<typename Scalar = double, typename Backend = backend::CPU,
size_t N = LIBRAPID_MAX_ARRAY_DIMS>
Array<Scalar, Backend> zeros(const Shape<N> &shape) {
template<typename Scalar = double, typename Backend = backend::CPU, typename ShapeType = Shape,
typename std::enable_if_t<typetraits::IsSizeType<ShapeType>::value, int> = 0>
Array<Scalar, Backend> zeros(const ShapeType &shape) {
return Array<Scalar, Backend>(shape, Scalar(0));
}

Expand All @@ -79,9 +79,9 @@ namespace librapid {
/// \tparam N Maximum number of dimensions of the Shape
/// \param shape Shape of the Array
/// \return Array filled with ones
template<typename Scalar = double, typename Backend = backend::CPU,
size_t N = LIBRAPID_MAX_ARRAY_DIMS>
Array<Scalar, Backend> ones(const Shape<N> &shape) {
template<typename Scalar = double, typename Backend = backend::CPU, typename ShapeType = Shape,
typename std::enable_if_t<typetraits::IsSizeType<ShapeType>::value, int> = 0>
Array<Scalar, Backend> ones(const ShapeType &shape) {
return Array<Scalar, Backend>(shape, Scalar(1));
}

Expand Down Expand Up @@ -109,11 +109,11 @@ namespace librapid {
/// \tparam N Maximum number of dimensions of the Shape
/// \param shape Shape of the Array
/// \return Array filled with numbers from 0 to N-1
template<typename Scalar = int64_t, typename Backend = backend::CPU,
size_t N = LIBRAPID_MAX_ARRAY_DIMS>
Array<Scalar, Backend> ordered(const Shape<N> &shape) {
template<typename Scalar = int64_t, typename Backend = backend::CPU, typename ShapeType = Shape,
typename std::enable_if_t<typetraits::IsSizeType<ShapeType>::value, int> = 0>
Array<Scalar, Backend> ordered(const ShapeType &shape) {
Array<Scalar, Backend> result(shape);
for (size_t i = 0; i < shape.size(); i++) { result.storage()[i] = Scalar(i); }
for (size_t i = 0; i < result.size(); i++) { result.storage()[i] = Scalar(i); }
return result;
}

Expand Down
Loading

0 comments on commit e591166

Please sign in to comment.