@@ -403,10 +403,29 @@ struct VisitorRowBroadcast {
403
403
auto src_v = filter (tC_gRow);
404
404
auto coord_v = filter (tC_cRow);
405
405
auto dst_v = filter (tC_rRow);
406
- CUTLASS_PRAGMA_UNROLL
407
- for (int i = 0 ; i < size (src_v); ++i) {
408
- bool guard = get<1 >(coord_v (i)) < n;
409
- cutlass::arch::global_load<VecType, sizeof (VecType)>(dst_v (i), (void const *)&src_v (i), guard);
406
+
407
+ if (params_ptr->ptr_row ) {
408
+ // In this case we are loading from a row vector and broadcasting
409
+ CUTLASS_PRAGMA_UNROLL
410
+ for (int i = 0 ; i < size (src_v); ++i) {
411
+ bool guard = get<1 >(coord_v (i)) < n;
412
+ cutlass::arch::global_load<VecType, sizeof (VecType)>(dst_v (i), (void const *)&src_v (i), guard);
413
+ }
414
+ } else {
415
+ // In this case we are loading from a scalar and broadcasting
416
+ VecType filled_vec;
417
+ CUTLASS_PRAGMA_UNROLL
418
+ for (int i = 0 ; i < VecLength; i++) {
419
+ reinterpret_cast <Element*>(&filled_vec)[i] = params_ptr->null_default ;
420
+ }
421
+
422
+ CUTLASS_PRAGMA_UNROLL
423
+ for (int i = 0 ; i < size (src_v); ++i) {
424
+ if (get<1 >(coord_v (i)) < n)
425
+ {
426
+ dst_v (i) = filled_vec;
427
+ }
428
+ }
410
429
}
411
430
}
412
431
@@ -524,12 +543,27 @@ struct VisitorColBroadcast {
524
543
CUTLASS_DEVICE void
525
544
begin_epilogue () {
526
545
clear (tC_rCol);
546
+
527
547
Tensor pred = make_tensor<bool >(shape (tC_gCol));
528
548
CUTLASS_PRAGMA_UNROLL
529
549
for (int i = 0 ; i < size (pred); ++i) {
530
550
pred (i) = get<0 >(tC_cCol (i)) < m;
531
551
}
532
- copy_if (pred, tC_gCol, tC_rCol);
552
+
553
+ if (params_ptr->ptr_col ) {
554
+ // In this case we are loading from a column vector and broadcasting
555
+ copy_if (pred, tC_gCol, tC_rCol);
556
+ } else {
557
+ // In this case we are loading from a scalar and broadcasting
558
+ auto dst_v = filter (tC_rCol);
559
+
560
+ CUTLASS_PRAGMA_UNROLL
561
+ for (int i = 0 ; i < size (dst_v); ++i) {
562
+ if (pred (i)){
563
+ dst_v (i) = params_ptr->null_default ;
564
+ }
565
+ }
566
+ }
533
567
}
534
568
535
569
template <class ElementAccumulator , int FragmentSize>
0 commit comments