Skip to content

Commit

Permalink
Matrix Shape type
Browse files Browse the repository at this point in the history
  • Loading branch information
Pencilcaseman committed Aug 20, 2023
1 parent 2d79293 commit 479d7c7
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 82 deletions.
3 changes: 3 additions & 0 deletions librapid/include/librapid/array/arrayContainer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ namespace librapid {

LIBRAPID_DEFINE_AS_TYPE(size_t dims COMMA typename StorageScalar,
array::ArrayContainer<Shape<dims> COMMA StorageScalar>);

LIBRAPID_DEFINE_AS_TYPE(typename StorageScalar,
array::ArrayContainer<MatrixShape COMMA StorageScalar>);
} // namespace typetraits

namespace array {
Expand Down
8 changes: 8 additions & 0 deletions librapid/include/librapid/array/arrayTypeDef.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ namespace librapid {
template<typename StorageType>
using ArrayRef = array::ArrayContainer<Shape<LIBRAPID_MAX_ARRAY_DIMS>, StorageType>;

template<typename Scalar, typename Backend = backend::CPU>
using Matrix =
array::ArrayContainer<MatrixShape,
typename detail::TypeDefStorageEvaluator<Scalar, Backend>::Type>;

template<typename Scalar, size_t... Dimensions>
using MatrixF = array::ArrayContainer<MatrixShape, FixedStorage<Scalar, Dimensions...>>;

/// A reference type for Array Function objects. Use this to accept Function objects as
/// parameters since the compiler cannot determine the templates for the typedef by default.
/// Additionally, this can be used to store references to Function objects.
Expand Down
247 changes: 221 additions & 26 deletions librapid/include/librapid/array/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
namespace librapid {
namespace typetraits {
LIBRAPID_DEFINE_AS_TYPE(size_t N, Shape<N>);
LIBRAPID_DEFINE_AS_TYPE_NO_TEMPLATE(MatrixShape);
}

template<size_t N = 32>
template<size_t N = LIBRAPID_MAX_ARRAY_DIMS>
class Shape {
public:
using SizeType = uint32_t;
Expand Down Expand Up @@ -90,12 +91,12 @@ namespace librapid {
/// Return a Shape object with \p dims dimensions, all initialized to zero.
/// \param dims Number of dimensions
/// \return New Shape object
static auto zeros(size_t dims) -> Shape;
LIBRAPID_ALWAYS_INLINE static auto zeros(size_t dims) -> Shape;

/// Return a Shape object with \p dims dimensions, all initialized to one.
/// \param dims Number of dimensions
/// \return New Shape object
static auto ones(size_t dims) -> Shape;
LIBRAPID_ALWAYS_INLINE static auto ones(size_t dims) -> Shape;

/// Access an element of the Shape object
/// \tparam Index Typename of the index
Expand Down Expand Up @@ -124,13 +125,14 @@ namespace librapid {

/// Return the number of dimensions in the Shape object
/// \return Number of dimensions
LIBRAPID_NODISCARD auto ndim() const -> int;
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto ndim() const -> int;

/// Return a subshape of the Shape object
/// \param start Starting index
/// \param end Ending index
/// \return Subshape
LIBRAPID_NODISCARD auto subshape(size_t start, size_t end) const -> Shape;
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto subshape(size_t start, size_t end) const
-> Shape;

/// Return the number of elements the Shape object represents
/// \return Number of elements
Expand All @@ -144,6 +146,64 @@ namespace librapid {
std::array<SizeType, N> m_data;
};

class MatrixShape {
public:
using SizeType = uint32_t;
static constexpr size_t MaxDimensions = 2;

MatrixShape() = default;

template<typename Scalar, size_t Rows, size_t Cols>
explicit MatrixShape(const FixedStorage<Scalar, Rows, Cols> &fixed);

template<typename V>
MatrixShape(const std::initializer_list<V> &vals);

template<typename V>
explicit MatrixShape(const std::vector<V> &vals);

MatrixShape(const MatrixShape &other) = default;

MatrixShape(MatrixShape &&other) noexcept = default;

template<typename V>
auto operator=(const std::initializer_list<V> &vals) -> MatrixShape &;

template<typename V>
auto operator=(const std::vector<V> &vals) -> MatrixShape &;

MatrixShape &operator=(const MatrixShape &other) = default;

MatrixShape &operator=(MatrixShape &&other) noexcept = default;

static auto zeros() -> MatrixShape;

static auto ones() -> MatrixShape;

template<typename Index>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](Index index) const
-> const SizeType &;

template<typename Index>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](Index index) -> SizeType &;

LIBRAPID_ALWAYS_INLINE auto operator<=>(const MatrixShape &other) const = default;

LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto ndim() const -> int;

LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto subshape(size_t start, size_t end) const
-> Shape<LIBRAPID_MAX_ARRAY_DIMS>;

LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto size() const -> size_t;

template<typename T_, typename Char, typename Ctx>
void str(const fmt::formatter<T_, Char> &format, Ctx &ctx) const;

private:
SizeType m_rows;
SizeType m_cols;
};

namespace detail {
template<typename T, size_t... Dims>
Shape<LIBRAPID_MAX_ARRAY_DIMS> shapeFromFixedStorage(const FixedStorage<T, Dims...> &) {
Expand Down Expand Up @@ -239,15 +299,17 @@ namespace librapid {
template<typename Index>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto Shape<N>::operator[](Index index) const
-> const SizeType & {
LIBRAPID_ASSERT(static_cast<T>(index) < m_dims, "Index out of bounds");
static_assert(std::is_integral_v<Index>, "Index must be an integral type");
LIBRAPID_ASSERT(index < m_dims, "Index out of bounds");
LIBRAPID_ASSERT(index >= 0, "Index out of bounds");
return m_data[index];
}

template<size_t N>
template<typename Index>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto Shape<N>::operator[](Index index) -> SizeType & {
LIBRAPID_ASSERT(static_cast<T>(index) < m_dims, "Index out of bounds");
static_assert(std::is_integral_v<Index>, "Index must be an integral type");
LIBRAPID_ASSERT(index < m_dims, "Index out of bounds");
LIBRAPID_ASSERT(index >= 0, "Index out of bounds");
return m_data[index];
}
Expand Down Expand Up @@ -302,6 +364,122 @@ namespace librapid {
fmt::format_to(ctx.out(), ")");
}

template<typename Scalar, size_t Rows, size_t Cols>
MatrixShape::MatrixShape(const FixedStorage<Scalar, Rows, Cols> &) :
m_rows(Rows), m_cols(Cols) {}

template<typename V>
MatrixShape::MatrixShape(const std::initializer_list<V> &vals) {
LIBRAPID_ASSERT(vals.size() == 2, "MatrixShape must be initialized with 2 values");
m_rows = *(vals.begin());
m_cols = *(vals.begin() + 1);
}

template<typename V>
MatrixShape::MatrixShape(const std::vector<V> &vals) {
LIBRAPID_ASSERT(vals.size() == 2, "MatrixShape must be initialized with 2 values");
m_rows = vals[0];
m_cols = vals[1];
}

template<typename V>
auto MatrixShape::operator=(const std::initializer_list<V> &vals) -> MatrixShape & {
LIBRAPID_ASSERT(vals.size() == 2, "MatrixShape must be initialized with 2 values");
m_rows = *(vals.begin());
m_cols = *(vals.begin() + 1);
return *this;
}

template<typename V>
auto MatrixShape::operator=(const std::vector<V> &vals) -> MatrixShape & {
LIBRAPID_ASSERT(vals.size() == 2, "MatrixShape must be initialized with 2 values");
m_rows = vals[0];
m_cols = vals[1];
return *this;
}

LIBRAPID_ALWAYS_INLINE auto MatrixShape::zeros() -> MatrixShape { return MatrixShape({0, 0}); }

LIBRAPID_ALWAYS_INLINE auto MatrixShape::ones() -> MatrixShape { return MatrixShape({1, 1}); }

template<typename Index>
auto MatrixShape::operator[](Index index) const -> const SizeType & {
static_assert(std::is_integral_v<Index>, "Index must be an integral type");
LIBRAPID_ASSERT(index < 2, "Index out of bounds");
LIBRAPID_ASSERT(index >= 0, "Index out of bounds");

return index == 0 ? m_rows : m_cols;
}

template<typename Index>
auto MatrixShape::operator[](Index index) -> SizeType & {
static_assert(std::is_integral_v<Index>, "Index must be an integral type");
LIBRAPID_ASSERT(index < 2, "Index out of bounds");
LIBRAPID_ASSERT(index >= 0, "Index out of bounds");

return index == 0 ? m_rows : m_cols;
}

constexpr auto MatrixShape::ndim() const -> int { return 2; }

auto MatrixShape::subshape(size_t start, size_t end) const -> Shape<LIBRAPID_MAX_ARRAY_DIMS> {
LIBRAPID_ASSERT(start <= end, "Start index must be less than end index");
LIBRAPID_ASSERT(end <= 2,
"End index must be less than or equal to the number of dimensions");
LIBRAPID_ASSERT(start >= 0, "Start index must be greater than or equal to 0");

Shape<LIBRAPID_MAX_ARRAY_DIMS> res;
res[0] = m_rows;
res[1] = m_cols;
return res.subshape(start, end);
}

auto MatrixShape::size() const -> size_t { return m_rows * m_cols; }

template<typename T_, typename Char, typename Ctx>
void MatrixShape::str(const fmt::formatter<T_, Char> &format, Ctx &ctx) const {
fmt::format_to(ctx.out(), "MatrixShape(");
format.format(m_rows, ctx);
fmt::format_to(ctx.out(), ", ");
format.format(m_cols, ctx);
fmt::format_to(ctx.out(), ")");
}

template<size_t N>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator==(const Shape<N> &lhs,
const MatrixShape &rhs) -> bool {
return lhs.ndim() == 2 && lhs[0] == rhs[0] && lhs[1] == rhs[1];
}

template<size_t N>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator==(const MatrixShape &lhs,
const Shape<N> &rhs) -> bool {
return rhs == lhs;
}

template<size_t N>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator!=(const Shape<N> &lhs,
const MatrixShape &rhs) -> bool {
return !(lhs == rhs);
}

template<size_t N>
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator!=(const MatrixShape &lhs,
const Shape<N> &rhs) -> bool {
return !(lhs == rhs);
}

namespace typetraits {
template<typename T>
struct IsSizeType : std::false_type {};

template<size_t N>
struct IsSizeType<Shape<N>> : std::true_type {};

template<>
struct IsSizeType<MatrixShape> : std::true_type {};
} // namespace typetraits

/// Returns true if all inputs have the same shape
/// \tparam T1 Type of the first input
/// \tparam N1 Number of dimensions of the first input
Expand All @@ -313,40 +491,36 @@ namespace librapid {
/// \param second Second input
/// \param shapes Remaining (optional) inputs
/// \return True if all inputs have the same shape, false otherwise
template<size_t N1, size_t N2, size_t... Nn>
LIBRAPID_NODISCARD LIBRAPID_INLINE bool
shapesMatch(const Shape<N1> &first, const Shape<N2> &second, const Shape<Nn> &...shapes) {
if constexpr (sizeof...(Nn) == 0) {
template<typename First, typename Second, typename... Rest,
std::enable_if_t<typetraits::IsSizeType<First>::value &&
typetraits::IsSizeType<Second>::value &&
(typetraits::IsSizeType<Rest>::value && ...),
int> = 0>
LIBRAPID_NODISCARD LIBRAPID_INLINE bool shapesMatch(const First &first, const Second &second,
const Rest &...shapes) {
if constexpr (sizeof...(Rest) == 0) {
return first == second;
} else {
return first == second && shapesMatch(first, shapes...);
}
}

/// \sa shapesMatch
template<size_t N1, size_t N2, size_t... Nn>
template<typename First, typename Second, typename... Rest,
std::enable_if_t<typetraits::IsSizeType<First>::value &&
typetraits::IsSizeType<Second>::value &&
(typetraits::IsSizeType<Rest>::value && ...),
int> = 0>
LIBRAPID_NODISCARD LIBRAPID_INLINE bool
shapesMatch(const std::tuple<Shape<N1>, Shape<N2>, Shape<Nn>...> &shapes) {
if constexpr (sizeof...(Nn) == 0) {
shapesMatch(const std::tuple<First, Second, Rest...> &shapes) {
if constexpr (sizeof...(Rest) == 0) {
return std::get<0>(shapes) == std::get<1>(shapes);
} else {
return std::get<0>(shapes) == std::get<1>(shapes) &&
shapesMatch(std::apply(
[](auto, auto, auto... rest) { return std::make_tuple(rest...); }, shapes));
}
}

namespace typetraits {
template<typename T>
struct IsSizeType {
using value = std::false_type;
};

template<size_t N>
struct IsSizeType<Shape<N>> {
using value = std::true_type;
};
} // namespace typetraits
} // namespace librapid

// Support FMT printing
Expand All @@ -359,6 +533,27 @@ struct fmt::formatter<librapid::Shape<N>> {
using Base = fmt::formatter<SizeType, char>;
Base m_base;

public:
template<typename ParseContext>
FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * {
return m_base.parse(ctx);
}

template<typename FormatContext>
FMT_CONSTEXPR auto format(const Type &val, FormatContext &ctx) const -> decltype(ctx.out()) {
val.str(m_base, ctx);
return ctx.out();
}
};

template<>
struct fmt::formatter<librapid::MatrixShape> {
private:
using Type = librapid::MatrixShape;
using SizeType = librapid::MatrixShape::SizeType;
using Base = fmt::formatter<SizeType, char>;
Base m_base;

public:
template<typename ParseContext>
FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * {
Expand Down
2 changes: 2 additions & 0 deletions librapid/include/librapid/core/forward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ namespace librapid {
template<size_t N>
class Shape;

class MatrixShape;

template<size_t N>
class Stride;

Expand Down
Loading

0 comments on commit 479d7c7

Please sign in to comment.