Skip to content

Commit

Permalink
Added template parameter EnableNullptr (default false)
Browse files: browse the repository at this point in the history
  • Loading branch information
ProExpertProg committed Jul 17, 2024
1 parent 012d623 commit 0b6c76e
Showing 1 changed file with 38 additions and 31 deletions.
69 changes: 38 additions & 31 deletions include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -335,7 +335,8 @@ struct VisitorAuxLoad{
template<
class ThreadMap,
class Element,
class StrideMNL
class StrideMNL,
bool EnableNullptr = false
>
struct VisitorRowBroadcast {

Expand Down Expand Up @@ -404,29 +405,31 @@ struct VisitorRowBroadcast {
auto coord_v = filter(tC_cRow);
auto dst_v = filter(tC_rRow);

if (params_ptr->ptr_row) {
// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
bool guard = get<1>(coord_v(i)) < n;
cutlass::arch::global_load<VecType, sizeof(VecType)>(dst_v(i), (void const*)&src_v(i), guard);
}
} else {
// In this case we are loading from a scalar and broadcasting
VecType filled_vec;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VecLength; i++) {
reinterpret_cast<Element*>(&filled_vec)[i] = params_ptr->null_default;
}
if constexpr (EnableNullptr) {
if (params_ptr->ptr_row == nullptr) {
// In this case we are loading from a scalar and broadcasting
VecType filled_vec;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VecLength; i++) {
reinterpret_cast<Element *>(&filled_vec)[i] = params_ptr->null_default;
}

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if(get<1>(coord_v(i)) < n)
{
dst_v(i) = filled_vec;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if (get<1>(coord_v(i)) < n) {
dst_v(i) = filled_vec;
}
}
return;
}
}

// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
bool guard = get<1>(coord_v(i)) < n;
cutlass::arch::global_load<VecType, sizeof(VecType)>(dst_v(i), (void const *)&src_v(i), guard);
}
}

template <class ElementAccumulator, int FragmentSize>
Expand Down Expand Up @@ -483,7 +486,8 @@ struct VisitorRowBroadcast {
template<
class ThreadMap,
class Element,
class StrideMNL = Stride<_1,_0,_0>
class StrideMNL = Stride<_1,_0,_0>,
bool EnableNullptr = false
>
struct VisitorColBroadcast {

Expand Down Expand Up @@ -550,20 +554,23 @@ struct VisitorColBroadcast {
pred(i) = get<0>(tC_cCol(i)) < m;
}

if (params_ptr->ptr_col) {
// In this case we are loading from a column vector and broadcasting
copy_if(pred, tC_gCol, tC_rCol);
} else {
// In this case we are loading from a scalar and broadcasting
auto dst_v = filter(tC_rCol);
if constexpr (EnableNullptr) {
if (params_ptr->ptr_col == nullptr) {
// In this case we are loading from a scalar and broadcasting
auto dst_v = filter(tC_rCol);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(dst_v); ++i) {
if(pred(i)){
dst_v(i) = params_ptr->null_default;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(dst_v); ++i) {
if (pred(i)) {
dst_v(i) = params_ptr->null_default;
}
}
return;
}
}

// In this case we are loading from a column vector and broadcasting
copy_if(pred, tC_gCol, tC_rCol);
}

template <class ElementAccumulator, int FragmentSize>
Expand Down

0 comments on commit 0b6c76e

Please sign in to comment.