Skip to content

Commit 012d623

Browse files
tlrmchlsmth authored and ProExpertProg committed
Allow scalar broadcasting in VisitorRowBroadcast and VisitorColBroadcast
1 parent 56b46e2 commit 012d623

File tree

1 file changed: 39 additions (+39) and 5 deletions (-5)

include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp

+39-5
Original file line number | Diff line number | Diff line change
@@ -403,10 +403,29 @@ struct VisitorRowBroadcast {
403403
auto src_v = filter(tC_gRow);
404404
auto coord_v = filter(tC_cRow);
405405
auto dst_v = filter(tC_rRow);
406-
CUTLASS_PRAGMA_UNROLL
407-
for (int i = 0; i < size(src_v); ++i) {
408-
bool guard = get<1>(coord_v(i)) < n;
409-
cutlass::arch::global_load<VecType, sizeof(VecType)>(dst_v(i), (void const*)&src_v(i), guard);
406+
407+
if (params_ptr->ptr_row) {
408+
// In this case we are loading from a row vector and broadcasting
409+
CUTLASS_PRAGMA_UNROLL
410+
for (int i = 0; i < size(src_v); ++i) {
411+
bool guard = get<1>(coord_v(i)) < n;
412+
cutlass::arch::global_load<VecType, sizeof(VecType)>(dst_v(i), (void const*)&src_v(i), guard);
413+
}
414+
} else {
415+
// In this case we are loading from a scalar and broadcasting
416+
VecType filled_vec;
417+
CUTLASS_PRAGMA_UNROLL
418+
for (int i = 0; i < VecLength; i++) {
419+
reinterpret_cast<Element*>(&filled_vec)[i] = params_ptr->null_default;
420+
}
421+
422+
CUTLASS_PRAGMA_UNROLL
423+
for (int i = 0; i < size(src_v); ++i) {
424+
if(get<1>(coord_v(i)) < n)
425+
{
426+
dst_v(i) = filled_vec;
427+
}
428+
}
410429
}
411430
}
412431

@@ -524,12 +543,27 @@ struct VisitorColBroadcast {
524543
CUTLASS_DEVICE void
525544
begin_epilogue() {
526545
clear(tC_rCol);
546+
527547
Tensor pred = make_tensor<bool>(shape(tC_gCol));
528548
CUTLASS_PRAGMA_UNROLL
529549
for (int i = 0; i < size(pred); ++i) {
530550
pred(i) = get<0>(tC_cCol(i)) < m;
531551
}
532-
copy_if(pred, tC_gCol, tC_rCol);
552+
553+
if (params_ptr->ptr_col) {
554+
// In this case we are loading from a column vector and broadcasting
555+
copy_if(pred, tC_gCol, tC_rCol);
556+
} else {
557+
// In this case we are loading from a scalar and broadcasting
558+
auto dst_v = filter(tC_rCol);
559+
560+
CUTLASS_PRAGMA_UNROLL
561+
for (int i = 0; i < size(dst_v); ++i) {
562+
if(pred(i)){
563+
dst_v(i) = params_ptr->null_default;
564+
}
565+
}
566+
}
533567
}
534568

535569
template <class ElementAccumulator, int FragmentSize>

0 commit comments

Comments
 (0)