
Commit 2e4c82f

Fix repeat for dims > 1 (#1713)
1 parent 3a02a54 commit 2e4c82f

12 files changed (+136 −90 lines)
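
In short, `Tensor::repeat(dim, times)` no longer requires the repeated dimension to have size 1: the output size along `dim` becomes the input size times `times`. A minimal sketch of the new behavior, written in the style of the tests added at the bottom of this diff (`TestBackend`, `Tensor`, and `Data` come from burn's test harness):

let tensor = Tensor::<TestBackend, 2>::from_data(
    Data::from([[1.0, 2.0], [3.0, 4.0]]),
    &Default::default(),
);

// Before this commit the call below panicked ("Can only repeat dimension
// with dim=1"); now the repeated size is multiplied: [2, 2] -> [2, 4].
let repeated = tensor.repeat(1, 2).into_data();

assert_eq!(repeated, Data::from([[1.0, 2.0, 1.0, 2.0], [3.0, 4.0, 3.0, 4.0]]));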

crates/burn-fusion/src/ops/boolean.rs (+1 −1)

@@ -564,7 +564,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
 
         let stream = tensor.stream;
         let mut shape = tensor.shape.clone();
-        shape[dim] = times;
+        shape[dim] *= times;
         let out = tensor.client.tensor_uninitialized(shape);
 
         let desc = RepeatOperationDescription {

crates/burn-fusion/src/ops/float.rs (+1 −1)

@@ -1620,7 +1620,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
 
         let stream = tensor.stream;
         let mut shape = tensor.shape.clone();
-        shape[dim] = times;
+        shape[dim] *= times;
         let out = tensor.client.tensor_uninitialized(shape);
 
         let desc = RepeatOperationDescription {

crates/burn-fusion/src/ops/int.rs (+1 −1)

@@ -1665,7 +1665,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
 
         let stream = tensor.stream;
         let mut shape = tensor.shape.clone();
-        shape[dim] = times;
+        shape[dim] *= times;
         let out = tensor.client.tensor_uninitialized(shape);
 
         let desc = RepeatOperationDescription {

crates/burn-jit/src/kernel/index/repeat.rs (+12 −13)

@@ -38,19 +38,21 @@ impl RepeatComputeShader {
 
         let stride_input = scope.create_local(Elem::UInt);
         let stride_output = scope.create_local(Elem::UInt);
-        let shape_output = scope.create_local(Elem::UInt);
+        let shape = scope.create_local(Elem::UInt);
 
         for i in 0..self.rank {
+            gpu!(scope, stride_input = stride(input, i));
+            gpu!(scope, stride_output = stride(output, i));
             if i != self.dim {
-                gpu!(scope, stride_input = stride(input, i));
-                gpu!(scope, stride_output = stride(output, i));
-                gpu!(scope, shape_output = shape(output, i));
-
-                gpu!(scope, offset_local = id / stride_output);
-                gpu!(scope, offset_local = offset_local % shape_output);
-                gpu!(scope, offset_local = offset_local * stride_input);
-                gpu!(scope, offset_input += offset_local);
+                gpu!(scope, shape = shape(output, i));
+            } else {
+                gpu!(scope, shape = shape(input, i));
             }
+
+            gpu!(scope, offset_local = id / stride_output);
+            gpu!(scope, offset_local = offset_local % shape);
+            gpu!(scope, offset_local = offset_local * stride_input);
+            gpu!(scope, offset_input += offset_local);
         }
 
         let result = scope.create_local(input.item());

@@ -108,12 +110,9 @@ pub(crate) fn repeat<R: Runtime, E: JitElement, const D1: usize>(
     times: usize,
 ) -> JitTensor<R, E, D1> {
     let mut shape = input.shape.clone();
-    if shape.dims[dim] != 1 {
-        panic!("Can only repeat dimension with dim=1");
-    }
 
     // Create output handle
-    shape.dims[dim] = times;
+    shape.dims[dim] *= times;
     let num_elems_output = shape.num_elements();
     let handle = input
         .client
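
The kernel rewrite above hoists the stride loads and the offset accumulation out of the `if`, so every dimension contributes to the input offset; the only per-dimension difference left is which shape feeds the modulus (the input shape on the repeated axis, so reads wrap and tile the source). A plain-Rust sketch of the per-element arithmetic the generated shader performs (names mirror the shader locals; this is an illustration, not the burn-jit API):

// For output element `id`, recover each dimension's coordinate from the
// output strides, wrap it, and re-project it onto the input strides.
fn input_offset(
    id: usize,
    rank: usize,
    dim: usize,
    in_strides: &[usize],
    out_strides: &[usize],
    in_shape: &[usize],
    out_shape: &[usize],
) -> usize {
    let mut offset_input = 0;
    for i in 0..rank {
        let shape = if i != dim { out_shape[i] } else { in_shape[i] };
        offset_input += (id / out_strides[i]) % shape * in_strides[i];
    }
    offset_input
}

fn main() {
    // 1-D case: input shape [3] repeated twice -> output shape [6];
    // output element 4 reads input element 4 % 3 = 1.
    assert_eq!(input_offset(4, 1, 0, &[1], &[1], &[3], &[6]), 1);
}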

crates/burn-tensor/src/tensor/api/base.rs (−4)

@@ -564,10 +564,6 @@ where
     }
 
     /// Repeat the tensor along the given dimension.
-    ///
-    /// # Panics
-    ///
-    /// If the selected dimension more than one item.
     pub fn repeat(self, dim: usize, times: usize) -> Self {
         Self::new(K::repeat(self.primitive, dim, times))
     }

crates/burn-tensor/src/tensor/ops/bool_tensor.rs (+10 −23)

@@ -1,4 +1,7 @@
-use super::{cat::cat_with_slice_assign, BoolTensor, Device, FloatTensor, IntTensor};
+use super::{
+    cat::cat_with_slice_assign, repeat::repeat_with_slice_assign, BoolTensor, Device, FloatTensor,
+    IntTensor,
+};
 use crate::{
     backend::Backend, chunk, narrow, tensor::Shape, Bool, Data, ElementConversion, Tensor,
 };

@@ -174,28 +177,12 @@ pub trait BoolTensorOps<B: Backend> {
         dim: usize,
         times: usize,
     ) -> BoolTensor<B, D> {
-        let mut shape = Self::bool_shape(&tensor);
-        if shape.dims[dim] != 1 {
-            panic!("Can only repeat dimension with dim=1");
-        }
-        shape.dims[dim] = times;
-
-        let mut i = 0;
-        let ranges_select_all = [0; D].map(|_| {
-            let start = 0;
-            let end = shape.dims[i];
-            i += 1;
-            start..end
-        });
-
-        let mut tensor_output = Self::bool_empty(shape, &Self::bool_device(&tensor));
-        for i in 0..times {
-            let mut ranges = ranges_select_all.clone();
-            ranges[dim] = i..i + 1;
-            tensor_output = Self::bool_slice_assign(tensor_output, ranges, tensor.clone());
-        }
-
-        tensor_output
+        repeat_with_slice_assign::<B, D, Bool>(
+            Tensor::<B, D, Bool>::from_primitive(tensor),
+            dim,
+            times,
+        )
+        .into_primitive()
     }
 
     /// Concatenates the tensors along the given dimension.

crates/burn-tensor/src/tensor/ops/int_tensor.rs (+7 −22)

@@ -1,4 +1,5 @@
 use super::cat::cat_with_slice_assign;
+use super::repeat::repeat_with_slice_assign;
 use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
 use crate::Tensor;
 use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Int};

@@ -270,28 +271,12 @@ pub trait IntTensorOps<B: Backend> {
         dim: usize,
         times: usize,
     ) -> IntTensor<B, D> {
-        let mut shape = Self::int_shape(&tensor);
-        if shape.dims[dim] != 1 {
-            panic!("Can only repeat dimension with dim=1");
-        }
-        shape.dims[dim] = times;
-
-        let mut i = 0;
-        let indices_select_all = [0; D].map(|_| {
-            let start = 0;
-            let end = shape.dims[i];
-            i += 1;
-            start..end
-        });
-
-        let mut tensor_output = Self::int_empty(shape, &Self::int_device(&tensor));
-        for i in 0..times {
-            let mut indices = indices_select_all.clone();
-            indices[dim] = i..i + 1;
-            tensor_output = Self::int_slice_assign(tensor_output, indices, tensor.clone());
-        }
-
-        tensor_output
+        repeat_with_slice_assign::<B, D, Int>(
+            Tensor::<B, D, Int>::from_primitive(tensor),
+            dim,
+            times,
+        )
+        .into_primitive()
     }
 
     /// Concatenates the given tensors along the given dimension.

crates/burn-tensor/src/tensor/ops/modules/cat.rs (+1 −3)

@@ -19,10 +19,8 @@ pub(crate) fn cat_with_slice_assign<B: Backend, const D: usize, K: TensorKind<B>
 
     let mut i = 0;
     let indices_select_all = [0; D].map(|_| {
-        let start = 0;
-        let end = shape.dims[i];
         i += 1;
-        start..end
+        0..shape.dims[i - 1]
     });
 
     let mut output_index = 0;
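
The `cat_with_slice_assign` cleanup above folds the `start`/`end` temporaries into a single range expression; the external counter stays because `array::map` does not pass an index to its closure. A standalone sketch of the pattern (hypothetical `full_ranges` helper, not part of burn):

use core::ops::Range;

// Build one full range per dimension: [0..dims[0], 0..dims[1], ...].
fn full_ranges<const D: usize>(dims: [usize; D]) -> [Range<usize>; D] {
    let mut i = 0;
    [0; D].map(|_| {
        i += 1;
        0..dims[i - 1]
    })
}

fn main() {
    assert_eq!(full_ranges([2, 3, 4]), [0..2, 0..3, 0..4]);
}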

crates/burn-tensor/src/tensor/ops/modules/mod.rs (+2)

@@ -3,6 +3,8 @@ pub mod conv;
 
 /// Module with cat operation
 pub(crate) mod cat;
+/// Module with repeat operation
+pub(crate) mod repeat;
 /// Module with unfold operations.
 pub(crate) mod unfold;
 

crates/burn-tensor/src/tensor/ops/modules/repeat.rs (new file, +36)

@@ -0,0 +1,36 @@
+use crate::{backend::Backend, BasicOps, Tensor, TensorKind};
+
+pub(crate) fn repeat_with_slice_assign<
+    B: Backend,
+    const D: usize,
+    K: TensorKind<B> + BasicOps<B>,
+>(
+    tensor: Tensor<B, D, K>,
+    dim: usize,
+    times: usize,
+) -> Tensor<B, D, K> {
+    let mut shape = tensor.shape();
+    let device = tensor.device();
+
+    let original_dim_length = shape.dims[dim];
+    shape.dims[dim] *= times;
+
+    let mut tensor_output = Tensor::empty(shape.clone(), &device);
+
+    let mut i = 0;
+    let indices_select_all = [0; D].map(|_| {
+        i += 1;
+        0..shape.dims[i - 1]
+    });
+
+    let mut output_index = 0;
+    for _ in 0..times {
+        let mut indices = indices_select_all.clone();
+        indices[dim] = output_index..output_index + original_dim_length;
+        output_index += original_dim_length;
+
+        tensor_output = tensor_output.slice_assign(indices, tensor.clone());
+    }
+
+    tensor_output
+}
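
A worked sketch of the tiling above (hypothetical call site; the helper is `pub(crate)`, so it would sit inside burn-tensor, with `B` some Backend and `device` its device):

let input = Tensor::<B, 2>::from_data(
    Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
    &device,
);

// shape.dims[1] goes from 3 to 6, and the loop runs twice:
//   pass 0: indices = [0..2, 0..3]  (writes columns 0..3)
//   pass 1: indices = [0..2, 3..6]  (writes columns 3..6)
let output = repeat_with_slice_assign::<B, 2, Float>(input, 1, 2);
// output: [[1.0, 2.0, 3.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 4.0, 5.0, 6.0]]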

crates/burn-tensor/src/tensor/ops/tensor.rs (+3 −22)

@@ -1,4 +1,5 @@
 use super::cat::cat_with_slice_assign;
+use super::repeat::repeat_with_slice_assign;
 use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor};
 use crate::backend::BackendBridge;
 use crate::Tensor;

@@ -193,28 +194,8 @@ pub trait FloatTensorOps<B: Backend> {
         dim: usize,
         times: usize,
     ) -> FloatTensor<B, D> {
-        let mut shape = B::float_shape(&tensor);
-        if shape.dims[dim] != 1 {
-            panic!("Can only repeat dimension with dim=1");
-        }
-        shape.dims[dim] = times;
-
-        let mut i = 0;
-        let indices_select_all = [0; D].map(|_| {
-            let start = 0;
-            let end = shape.dims[i];
-            i += 1;
-            start..end
-        });
-
-        let mut tensor_output = B::float_empty(shape, &B::float_device(&tensor));
-        for i in 0..times {
-            let mut indices = indices_select_all.clone();
-            indices[dim] = i..i + 1;
-            tensor_output = B::float_slice_assign(tensor_output, indices, tensor.clone());
-        }
-
-        tensor_output
+        repeat_with_slice_assign::<B, D, Float>(Tensor::<B, D>::from_primitive(tensor), dim, times)
+            .into_primitive()
     }
 
     /// Adds two tensors together.

crates/burn-tensor/src/tests/ops/repeat.rs (+62)

@@ -45,4 +45,66 @@ mod tests {
         let data_expected = Data::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]);
         assert_eq!(data_expected, data_actual);
     }
+
+    #[test]
+    fn should_support_float_repeat_on_dims_larger_than_1() {
+        let data = Data::from([
+            [[1.0, 2.0], [3.0, 4.0]],
+            [[5.0, 6.0], [7.0, 8.0]],
+            [[9.0, 10.0], [11.0, 12.0]],
+            [[13.0, 14.0], [15.0, 16.0]],
+        ]);
+        let tensor = Tensor::<TestBackend, 3>::from_data(data, &Default::default());
+
+        let data_actual = tensor.repeat(2, 2).into_data();
+
+        let data_expected = Data::from([
+            [[1.0, 2.0, 1.0, 2.0], [3.0, 4.0, 3.0, 4.0]],
+            [[5.0, 6.0, 5.0, 6.0], [7.0, 8.0, 7.0, 8.0]],
+            [[9.0, 10.0, 9.0, 10.0], [11.0, 12.0, 11.0, 12.0]],
+            [[13.0, 14.0, 13.0, 14.0], [15.0, 16.0, 15.0, 16.0]],
+        ]);
+
+        assert_eq!(data_expected, data_actual);
+    }
+
+    #[test]
+    fn should_support_int_repeat_on_dims_larger_than_1() {
+        let data = Data::from([
+            [[1, 2], [3, 4]],
+            [[5, 6], [7, 8]],
+            [[9, 10], [11, 12]],
+            [[13, 14], [15, 16]],
+        ]);
+        let tensor = Tensor::<TestBackend, 3, Int>::from_data(data, &Default::default());
+
+        let data_actual = tensor.repeat(2, 3).into_data();
+
+        let data_expected = Data::from([
+            [[1, 2, 1, 2, 1, 2], [3, 4, 3, 4, 3, 4]],
+            [[5, 6, 5, 6, 5, 6], [7, 8, 7, 8, 7, 8]],
+            [[9, 10, 9, 10, 9, 10], [11, 12, 11, 12, 11, 12]],
+            [[13, 14, 13, 14, 13, 14], [15, 16, 15, 16, 15, 16]],
+        ]);
+
+        assert_eq!(data_expected, data_actual);
+    }
+
+    #[test]
+    fn should_support_bool_repeat_on_dims_larger_than_1() {
+        let data = Data::from([
+            [[false, true], [true, false]],
+            [[true, true], [false, false]],
+        ]);
+        let tensor = Tensor::<TestBackend, 3, Bool>::from_data(data, &Default::default());
+
+        let data_actual = tensor.repeat(1, 2).into_data();
+
+        let data_expected = Data::from([
+            [[false, true], [true, false], [false, true], [true, false]],
+            [[true, true], [false, false], [true, true], [false, false]],
+        ]);
+
+        assert_eq!(data_expected, data_actual);
+    }
 }
