 //! DASIP, 2018

 use crate::{
-    backends::jit::connected_components::stats_from_opts, ConnectedStatsOptions,
-    ConnectedStatsPrimitive, Connectivity,
+    backends::jit::{connected_components::stats_from_opts, prefix_sum::prefix_sum},
+    ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity,
 };
 use burn_jit::{
     kernel,
     ops::{into_data_sync, numeric::zeros_device},
     tensor::JitTensor,
     BoolElement, FloatElement, IntElement, JitBackend, JitRuntime,
 };
-use burn_tensor::ops::IntTensorOps;
-use cubecl::{calculate_cube_count_elemwise, prelude::*, Feature};
+use burn_tensor::{ops::IntTensorOps, Shape};
+use cubecl::{prelude::*, Feature};

 const BLOCK_H: u32 = 4;

@@ -380,6 +380,7 @@ fn analysis<I: Int, BT: CubePrimitive>(
         while label != u32::cast_from(labels[b_offs + label]) - 1 {
             label = u32::cast_from(labels[b_offs + label]) - 1;
         }
+        label += 1;

         Atomic::add(&area[b_offs + label], I::cast_from(count));

@@ -397,13 +398,17 @@ fn analysis<I: Int, BT: CubePrimitive>(
         label = plane_broadcast(label, UNIT_POS_X - s_dist);

         if p {
-            labels[labels_index] = I::cast_from(label + 1);
+            labels[labels_index] = I::cast_from(label);
         }
     }
 }

 #[cube(launch)]
-fn compact_labels<I: Int>(labels: &mut Tensor<I>, remap: &Tensor<I>) {
+fn compact_labels<I: Int>(
+    labels: &mut Tensor<I>,
+    remap: &Tensor<I>,
+    max_label: &Tensor<Atomic<I>>,
+) {
     let batch = ABSOLUTE_POS_Z;
     let x = ABSOLUTE_POS_X;
     let y = ABSOLUTE_POS_Y;
@@ -416,7 +421,9 @@ fn compact_labels<I: Int>(labels: &mut Tensor<I>, remap: &Tensor<I>) {

     let label = u32::cast_from(labels[labels_pos]);
     if label != 0 {
-        labels[labels_pos] = remap[label];
+        let new_label = remap[label];
+        labels[labels_pos] = new_label;
+        Atomic::max(&max_label[batch], new_label);
     }
 }

@@ -433,11 +440,9 @@ fn compact_stats<I: Int>(
     bottom: &Tensor<I>,
     bottom_new: &mut Tensor<I>,
     remap: &Tensor<I>,
-    max_label: u32,
-    #[comptime] opts: ConnectedStatsOptions,
 ) {
     let label = ABSOLUTE_POS_X;
-    if label > max_label {
+    if label >= remap.len() {
         terminate!();
     }

@@ -448,12 +453,12 @@ fn compact_stats<I: Int>(
     let new_label = u32::cast_from(remap[label]);

     area_new[new_label] = area;
-    if opts.bounds_enabled {
-        top_new[new_label] = top[label];
-        left_new[new_label] = left[label];
-        right_new[new_label] = right[label];
-        bottom_new[new_label] = bottom[label];
-    }
+    // This should be gated, but the Eq bound is only implemented for tuples of up to 12 elements,
+    // so the opts can't be passed in. It's not unsafe, just potentially unnecessary work.
+    top_new[new_label] = top[label];
+    left_new[new_label] = left[label];
+    right_new[new_label] = right[label];
+    bottom_new[new_label] = bottom[label];
 }

 #[allow(clippy::type_complexity)]
@@ -525,7 +530,7 @@ pub fn hardware_accelerated<R: JitRuntime, F: FloatElement, I: IntElement, BT: B
         batches as u32,
     );

-    let stats = stats_from_opts(labels.clone(), stats_opt);
+    let mut stats = stats_from_opts(labels.clone(), stats_opt);

     if stats_opt == ConnectedStatsOptions::none() {
         relabeling::launch::<I, BT, R>(
@@ -553,28 +558,35 @@ pub fn hardware_accelerated<R: JitRuntime, F: FloatElement, I: IntElement, BT: B
     if stats_opt.compact_labels {
         let max_labels = into_data_sync::<R, I>(stats.max_label.clone()).convert::<u32>();
         let max_label = *max_labels.as_slice::<u32>().unwrap().iter().max().unwrap() as usize;
-        let sliced = kernel::slice::<R, I>(stats.area.clone(), &[0..batches, 0..max_label + 1]);
+        let sliced = kernel::slice::<R, I>(
+            stats.area.clone(),
+            &[0..batches, 0..(max_label + 1).next_multiple_of(4)],
+        );
         let present = JitBackend::<R, F, I, BT>::int_not_equal_elem(sliced, I::new(0));
-        let relabel = JitBackend::<R, F, I, BT>::int_prefix_sum(present);
+        let present = kernel::cast::<R, BT, I>(present);
+        let relabel = prefix_sum::<R, I>(present);

         let cube_dim = CubeDim::default();
         let cube_count = CubeCount::new_3d(
             (cols as u32).div_ceil(cube_dim.x),
             (rows as u32).div_ceil(cube_dim.y),
             batches as u32,
         );
-        compact_labels::launch(
+        stats.max_label =
+            zeros_device::<R, I>(client.clone(), device.clone(), Shape::new([batches]));
+        compact_labels::launch::<I, R>(
             &client,
             cube_count,
             cube_dim,
             labels.as_tensor_arg::<I>(1),
             relabel.as_tensor_arg::<I>(1),
+            stats.max_label.as_tensor_arg::<I>(1),
         );

         let cube_dim = CubeDim::new_1d(256);
         let cube_count =
             CubeCount::new_3d((rows * cols).div_ceil(256) as u32, 1, batches as u32);
-        compact_stats::launch(
+        compact_stats::launch::<I, R>(
             &client,
             cube_count,
             cube_dim,
@@ -589,8 +601,6 @@ pub fn hardware_accelerated<R: JitRuntime, F: FloatElement, I: IntElement, BT: B
             stats.bottom.copy().as_tensor_arg::<I>(1),
             stats.bottom.as_tensor_arg::<I>(1),
             relabel.as_tensor_arg::<I>(1),
-            ScalarArg::new(max_label as u32),
-            stats_opt,
         );
     }
 }
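
Note on the compaction scheme the host-side changes above drive: labels stay 1-based (0 is background), a presence mask over the per-label area counts is prefix-summed into a dense remap table, every pixel's label is rewritten through that table, and the largest compacted label is tracked per batch (the Atomic::max added to compact_labels). Below is a minimal single-batch CPU sketch of that idea for orientation only, assuming an inclusive prefix sum (which is what keeps the compacted labels 1-based); the function name `compact` and the plain-Vec types are illustrative, not part of the crate.

/// Minimal CPU sketch of label compaction via prefix sum (illustrative only).
/// `labels` holds 1-based component labels per pixel (0 = background);
/// `area` holds per-label pixel counts indexed by that 1-based label.
fn compact(labels: &mut [u32], area: &[u32]) -> (Vec<u32>, u32) {
    // Presence mask: 1 for every label that actually occurs.
    let present: Vec<u32> = area.iter().map(|&a| u32::from(a != 0)).collect();

    // Inclusive prefix sum turns the mask into a dense remap table:
    // remap[old_label] = compacted label (still 1-based; 0 stays background).
    let mut remap = vec![0u32; area.len()];
    let mut running = 0;
    for (slot, &p) in remap.iter_mut().zip(&present) {
        running += p;
        *slot = running;
    }

    // Rewrite every pixel and track the largest compacted label,
    // mirroring the Atomic::max added to `compact_labels`.
    let mut max_label = 0;
    for l in labels.iter_mut() {
        if *l != 0 {
            *l = remap[*l as usize];
            max_label = max_label.max(*l);
        }
    }
    (remap, max_label)
}

fn main() {
    // Components 1 and 4 survived merging; 2 and 3 did not, so their areas are zero.
    let mut labels = vec![0, 1, 1, 0, 4, 4, 4];
    let area = vec![0, 2, 0, 0, 3]; // indexed by 1-based label; slot 0 is background
    let (remap, max_label) = compact(&mut labels, &area);
    assert_eq!(remap, vec![0, 1, 1, 1, 2]);
    assert_eq!(labels, vec![0, 1, 1, 0, 2, 2, 2]);
    assert_eq!(max_label, 2);
}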