@@ -106,10 +106,14 @@ namespace librapid {
106
106
// / \return Class of the array multiplication
107
107
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE MatmulClass matmulClass () const ;
108
108
109
+ LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ShapeType calculateShape () const ;
110
+
109
111
// / \brief Determine the shape of the result
110
112
// / \return Shape of the result
111
113
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ShapeType shape () const ;
112
114
115
+ LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE size_t size () const ;
116
+
113
117
// / \brief Determine the number of dimensions of the result
114
118
// / \return Number of dimensions of the result
115
119
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t ndim () const ;
@@ -173,6 +177,9 @@ namespace librapid {
173
177
ScalarA m_alpha; // Scaling factor for A
174
178
TypeB m_b; // Second array
175
179
ScalarB m_beta; // Scaling factor for B
180
+
181
+ ShapeType m_shape;
182
+ size_t m_size;
176
183
};
177
184
178
185
template <typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
@@ -182,23 +189,26 @@ namespace librapid {
182
189
TypeB &&b, Beta beta) :
183
190
m_transA (transA),
184
191
m_transB(transB), m_a(std::forward<TypeA>(a)), m_alpha(static_cast <ScalarA>(alpha)),
185
- m_b(std::forward<TypeB>(b)), m_beta(static_cast <ScalarB>(beta)) {}
192
+ m_b(std::forward<TypeB>(b)), m_beta(static_cast <ScalarB>(beta)),
193
+ m_shape(calculateShape()), m_size(m_shape.size()) {}
186
194
187
195
template <typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
188
196
typename StorageTypeB, typename Alpha, typename Beta>
189
197
ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha,
190
198
Beta>::ArrayMultiply(TypeA &&a, TypeB &&b) :
191
199
m_transA (false ),
192
200
m_transB(false ), m_a(std::forward<TypeA>(a)), m_alpha(1 ),
193
- m_b(std::forward<TypeB>(b)), m_beta(0 ) {}
201
+ m_b(std::forward<TypeB>(b)), m_beta(0 ), m_shape(calculateShape()),
202
+ m_size(m_shape.size()) {}
194
203
195
204
template <typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
196
205
typename StorageTypeB, typename Alpha, typename Beta>
197
206
ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha,
198
207
Beta>::ArrayMultiply(bool transA, bool transB, TypeA &&a, TypeB &&b) :
199
208
m_transA (transA),
200
209
m_transB(transB), m_a(std::forward<TypeA>(a)), m_alpha(1 ),
201
- m_b(std::forward<TypeB>(b)), m_beta(0 ) {}
210
+ m_b(std::forward<TypeB>(b)), m_beta(0 ), m_shape(calculateShape()),
211
+ m_size(m_shape.size()) {}
202
212
203
213
template <typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
204
214
typename StorageTypeB, typename Alpha, typename Beta>
@@ -268,8 +278,8 @@ namespace librapid {
268
278
269
279
template <typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
270
280
typename StorageTypeB, typename Alpha, typename Beta>
271
- auto ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha, Beta>::shape()
272
- const -> ShapeType {
281
+ auto ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha,
282
+ Beta>::calculateShape() const -> ShapeType {
273
283
const auto &shapeA = m_a.shape ();
274
284
const auto &shapeB = m_b.shape ();
275
285
MatmulClass matmulClass = this ->matmulClass ();
@@ -294,6 +304,21 @@ namespace librapid {
294
304
return {1 };
295
305
}
296
306
307
+ template <typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
308
+ typename StorageTypeB, typename Alpha, typename Beta>
309
+ auto ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha, Beta>::shape()
310
+ const -> ShapeType {
311
+ return m_shape;
312
+ }
313
+
314
+ template <typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
315
+ typename StorageTypeB, typename Alpha, typename Beta>
316
+ auto
317
+ ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha, Beta>::size() const
318
+ -> size_t {
319
+ return m_size;
320
+ }
321
+
297
322
template <typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
298
323
typename StorageTypeB, typename Alpha, typename Beta>
299
324
auto
0 commit comments