From 0b6c76e2d9158a9ec4e4a32853d15862a100c87f Mon Sep 17 00:00:00 2001 From: luka Date: Wed, 17 Jul 2024 14:03:14 -0400 Subject: [PATCH] Added template parameter EnableNullptr (default false) --- .../threadblock/fusion/visitor_load.hpp | 69 ++++++++++--------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp b/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp index 623845913e..221ddf4479 100644 --- a/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp +++ b/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp @@ -335,7 +335,8 @@ struct VisitorAuxLoad{ template< class ThreadMap, class Element, - class StrideMNL + class StrideMNL, + bool EnableNullptr = false > struct VisitorRowBroadcast { @@ -404,29 +405,31 @@ struct VisitorRowBroadcast { auto coord_v = filter(tC_cRow); auto dst_v = filter(tC_rRow); - if (params_ptr->ptr_row) { - // In this case we are loading from a row vector and broadcasting - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(src_v); ++i) { - bool guard = get<1>(coord_v(i)) < n; - cutlass::arch::global_load(dst_v(i), (void const*)&src_v(i), guard); - } - } else { - // In this case we are loading from a scalar and broadcasting - VecType filled_vec; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < VecLength; i++) { - reinterpret_cast(&filled_vec)[i] = params_ptr->null_default; - } + if constexpr (EnableNullptr) { + if (params_ptr->ptr_row == nullptr) { + // In this case we are loading from a scalar and broadcasting + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = params_ptr->null_default; + } - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(src_v); ++i) { - if(get<1>(coord_v(i)) < n) - { - dst_v(i) = filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } } + return; } } + + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load(dst_v(i), (void const *)&src_v(i), guard); + } } template @@ -483,7 +486,8 @@ struct VisitorRowBroadcast { template< class ThreadMap, class Element, - class StrideMNL = Stride<_1,_0,_0> + class StrideMNL = Stride<_1,_0,_0>, + bool EnableNullptr = false > struct VisitorColBroadcast { @@ -550,20 +554,23 @@ struct VisitorColBroadcast { pred(i) = get<0>(tC_cCol(i)) < m; } - if (params_ptr->ptr_col) { - // In this case we are loading from a column vector and broadcasting - copy_if(pred, tC_gCol, tC_rCol); - } else { - // In this case we are loading from a scalar and broadcasting - auto dst_v = filter(tC_rCol); + if constexpr (EnableNullptr) { + if (params_ptr->ptr_col == nullptr) { + // In this case we are loading from a scalar and broadcasting + auto dst_v = filter(tC_rCol); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(dst_v); ++i) { - if(pred(i)){ - dst_v(i) = params_ptr->null_default; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(dst_v); ++i) { + if (pred(i)) { + dst_v(i) = params_ptr->null_default; + } } + return; } } + + // In this case we are loading from a column vector and broadcasting + copy_if(pred, tC_gCol, tC_rCol); } template