Skip to content

Commit

Permalink
Store sizes directly alongside shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
Pencilcaseman committed Aug 20, 2023
1 parent 705ba28 commit 7f435a2
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 13 deletions.
30 changes: 22 additions & 8 deletions librapid/include/librapid/array/arrayContainer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ namespace librapid {
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE typename ShapeType::SizeType
ndim() const noexcept;

LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto size() const noexcept -> size_t;

/// Return the shape of the array container. This is an immutable reference.
/// \return The shape of the array container.
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const ShapeType &shape() const noexcept;
Expand Down Expand Up @@ -343,28 +345,30 @@ namespace librapid {

private:
ShapeType m_shape; // The shape type of the array
size_t m_size; // The size of the array
StorageType m_storage; // The storage container of the array
};

template<typename ShapeType_, typename StorageType_>
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer() :
m_shape(StorageType_::template defaultShape<ShapeType_>()) {}
m_shape(StorageType_::template defaultShape<ShapeType_>()), m_size(0) {}

template<typename ShapeType_, typename StorageType_>
template<typename T>
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(
const std::initializer_list<T> &data) :
m_shape({data.size()}),
m_storage(StorageType::fromData(data)) {}
m_size(data.size()), m_storage(StorageType::fromData(data)) {}

template<typename ShapeType_, typename StorageType_>
template<typename T>
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const std::vector<T> &data) :
m_shape({data.size()}), m_storage(StorageType::fromData(data)) {}
m_shape({data.size()}), m_size(data.size()),
m_storage(StorageType::fromData(data)) {}

template<typename ShapeType_, typename StorageType_>
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const ShapeType &shape) :
m_shape(shape), m_storage(shape.size()) {
m_shape(shape), m_size(shape.size()), m_storage(m_size) {
static_assert(!typetraits::IsFixedStorage<StorageType_>::value,
"For a compile-time-defined shape, "
"the storage type must be "
Expand All @@ -375,7 +379,7 @@ namespace librapid {
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const ShapeType &shape,
const Scalar &value) :
m_shape(shape),
m_storage(shape.size(), value) {
m_size(shape.size()), m_storage(m_size, value) {
static_assert(typetraits::IsStorage<StorageType_>::value ||
typetraits::IsOpenCLStorage<StorageType_>::value ||
typetraits::IsCudaStorage<StorageType_>::value,
Expand All @@ -391,7 +395,8 @@ namespace librapid {

template<typename ShapeType_, typename StorageType_>
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const Scalar &value) :
m_shape(detail::shapeFromFixedStorage(m_storage)), m_storage(value) {
m_shape(detail::shapeFromFixedStorage(m_storage)), m_size(m_shape.size()),
m_storage(m_size) {
static_assert(typetraits::IsFixedStorage<StorageType_>::value,
"For a compile-time-defined shape, "
"the storage type must be "
Expand All @@ -400,7 +405,8 @@ namespace librapid {

template<typename ShapeType_, typename StorageType_>
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(ShapeType_ &&shape) :
m_shape(std::forward<ShapeType_>(shape)), m_storage(m_shape.size()) {}
m_shape(std::forward<ShapeType_>(shape)), m_size(m_shape.size()),
m_storage(m_size) {}

template<typename ShapeType_, typename StorageType_>
template<typename TransposeType>
Expand All @@ -423,7 +429,7 @@ namespace librapid {
auto ArrayContainer<ShapeType_, StorageType_>::assign(
const detail::Function<desc, Functor_, Args...> &function) -> ArrayContainer & {
using FunctionType = detail::Function<desc, Functor_, Args...>;
m_storage.resize(function.shape().size(), 0);
m_storage.resize(function.size(), 0);
if constexpr (std::is_same_v<typename FunctionType::Backend, backend::OpenCL> ||
std::is_same_v<typename FunctionType::Backend, backend::CUDA>) {
detail::assign(*this, function);
Expand All @@ -443,6 +449,7 @@ namespace librapid {
ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(
const detail::Function<desc, Functor_, Args...> &function) LIBRAPID_RELEASE_NOEXCEPT
: m_shape(function.shape()),
m_size(function.size()),
m_storage(m_shape.size()) {
assign(function);
}
Expand All @@ -459,6 +466,7 @@ namespace librapid {
auto ArrayContainer<ShapeType_, StorageType_>::operator=(
const Transpose<TransposeType> &transpose) -> ArrayContainer & {
m_shape = transpose.shape();
m_size = transpose.size();
m_storage.resize(m_shape.size(), 0);
transpose.applyTo(*this);
return *this;
Expand All @@ -471,6 +479,7 @@ namespace librapid {
const linalg::ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha,
Beta> &arrayMultiply) -> ArrayContainer & {
m_shape = arrayMultiply.shape();
m_size = arrayMultiply.size();
m_storage.resize(m_shape.size(), 0);
arrayMultiply.applyTo(*this);
return *this;
Expand Down Expand Up @@ -652,6 +661,11 @@ namespace librapid {
return m_shape.ndim();
}

template<typename ShapeType_, typename StorageType_>
auto ArrayContainer<ShapeType_, StorageType_>::size() const noexcept -> size_t {
return m_size;
}

template<typename ShapeType_, typename StorageType_>
auto ArrayContainer<ShapeType_, StorageType_>::shape() const noexcept -> const ShapeType & {
return m_shape;
Expand Down
2 changes: 1 addition & 1 deletion librapid/include/librapid/array/assignOps.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ namespace librapid {
detail::Function<descriptor::Trivial, Functor_, Args...>>::allowVectorisation &&
Function::argsAreSameType;

const size_t size = function.shape().size();
const size_t size = function.size();
const size_t vectorSize = size - (size % packetWidth);

LIBRAPID_ASSUME(vectorSize % packetWidth == 0);
Expand Down
18 changes: 14 additions & 4 deletions librapid/include/librapid/array/function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,11 @@ namespace librapid {
/// \return A reference to this Function.
LIBRAPID_ALWAYS_INLINE Function &operator=(Function &&other) noexcept = default;

LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto size() const -> size_t;

/// Return the shape of the Function's result
/// \return The shape of the Function's result
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto shape() const;
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto shape() const -> const ShapeType &;

/// Return the arguments in the Function
/// \return The arguments in the Function
Expand Down Expand Up @@ -201,15 +203,23 @@ namespace librapid {

Functor m_functor;
std::tuple<Args...> m_args;
ShapeType m_shape;
size_t m_size = 0;
};

template<typename desc, typename Functor, typename... Args>
Function<desc, Functor, Args...>::Function(Functor &&functor, Args &&...args) :
m_functor(std::forward<Functor>(functor)), m_args(std::forward<Args>(args)...) {}
m_functor(std::forward<Functor>(functor)), m_args(std::forward<Args>(args)...),
m_shape(typetraits::TypeInfo<Functor>::getShape(m_args)), m_size(m_shape.size()) {}

template<typename desc, typename Functor, typename... Args>
auto Function<desc, Functor, Args...>::shape() const -> const ShapeType & {
return m_shape;
}

template<typename desc, typename Functor, typename... Args>
auto Function<desc, Functor, Args...>::shape() const {
return typetraits::TypeInfo<Functor>::getShape(m_args);
auto Function<desc, Functor, Args...>::size() const -> size_t {
return m_size;
}

template<typename desc, typename Functor, typename... Args>
Expand Down

0 comments on commit 7f435a2

Please sign in to comment.