@@ -45,4 +45,66 @@ mod tests {
45
45
let data_expected = Data :: from ( [ [ 0 , 1 , 2 ] , [ 0 , 1 , 2 ] , [ 0 , 1 , 2 ] , [ 0 , 1 , 2 ] ] ) ;
46
46
assert_eq ! ( data_expected, data_actual) ;
47
47
}
48
+
49
+ #[ test]
50
+ fn should_support_float_repeat_on_dims_larger_than_1 ( ) {
51
+ let data = Data :: from ( [
52
+ [ [ 1.0 , 2.0 ] , [ 3.0 , 4.0 ] ] ,
53
+ [ [ 5.0 , 6.0 ] , [ 7.0 , 8.0 ] ] ,
54
+ [ [ 9.0 , 10.0 ] , [ 11.0 , 12.0 ] ] ,
55
+ [ [ 13.0 , 14.0 ] , [ 15.0 , 16.0 ] ] ,
56
+ ] ) ;
57
+ let tensor = Tensor :: < TestBackend , 3 > :: from_data ( data, & Default :: default ( ) ) ;
58
+
59
+ let data_actual = tensor. repeat ( 2 , 2 ) . into_data ( ) ;
60
+
61
+ let data_expected = Data :: from ( [
62
+ [ [ 1.0 , 2.0 , 1.0 , 2.0 ] , [ 3.0 , 4.0 , 3.0 , 4.0 ] ] ,
63
+ [ [ 5.0 , 6.0 , 5.0 , 6.0 ] , [ 7.0 , 8.0 , 7.0 , 8.0 ] ] ,
64
+ [ [ 9.0 , 10.0 , 9.0 , 10.0 ] , [ 11.0 , 12.0 , 11.0 , 12.0 ] ] ,
65
+ [ [ 13.0 , 14.0 , 13.0 , 14.0 ] , [ 15.0 , 16.0 , 15.0 , 16.0 ] ] ,
66
+ ] ) ;
67
+
68
+ assert_eq ! ( data_expected, data_actual) ;
69
+ }
70
+
71
+ #[ test]
72
+ fn should_support_int_repeat_on_dims_larger_than_1 ( ) {
73
+ let data = Data :: from ( [
74
+ [ [ 1 , 2 ] , [ 3 , 4 ] ] ,
75
+ [ [ 5 , 6 ] , [ 7 , 8 ] ] ,
76
+ [ [ 9 , 10 ] , [ 11 , 12 ] ] ,
77
+ [ [ 13 , 14 ] , [ 15 , 16 ] ] ,
78
+ ] ) ;
79
+ let tensor = Tensor :: < TestBackend , 3 , Int > :: from_data ( data, & Default :: default ( ) ) ;
80
+
81
+ let data_actual = tensor. repeat ( 2 , 3 ) . into_data ( ) ;
82
+
83
+ let data_expected = Data :: from ( [
84
+ [ [ 1 , 2 , 1 , 2 , 1 , 2 ] , [ 3 , 4 , 3 , 4 , 3 , 4 ] ] ,
85
+ [ [ 5 , 6 , 5 , 6 , 5 , 6 ] , [ 7 , 8 , 7 , 8 , 7 , 8 ] ] ,
86
+ [ [ 9 , 10 , 9 , 10 , 9 , 10 ] , [ 11 , 12 , 11 , 12 , 11 , 12 ] ] ,
87
+ [ [ 13 , 14 , 13 , 14 , 13 , 14 ] , [ 15 , 16 , 15 , 16 , 15 , 16 ] ] ,
88
+ ] ) ;
89
+
90
+ assert_eq ! ( data_expected, data_actual) ;
91
+ }
92
+
93
+ #[ test]
94
+ fn should_support_bool_repeat_on_dims_larger_than_1 ( ) {
95
+ let data = Data :: from ( [
96
+ [ [ false , true ] , [ true , false ] ] ,
97
+ [ [ true , true ] , [ false , false ] ] ,
98
+ ] ) ;
99
+ let tensor = Tensor :: < TestBackend , 3 , Bool > :: from_data ( data, & Default :: default ( ) ) ;
100
+
101
+ let data_actual = tensor. repeat ( 1 , 2 ) . into_data ( ) ;
102
+
103
+ let data_expected = Data :: from ( [
104
+ [ [ false , true ] , [ true , false ] , [ false , true ] , [ true , false ] ] ,
105
+ [ [ true , true ] , [ false , false ] , [ true , true ] , [ false , false ] ] ,
106
+ ] ) ;
107
+
108
+ assert_eq ! ( data_expected, data_actual) ;
109
+ }
48
110
}
0 commit comments