Skip to content

Commit 2de270f

Browse files
Fix tch view data corruption (#1434)
1 parent 61c0474 commit 2de270f

File tree

2 files changed

+97
-12
lines changed

2 files changed

+97
-12
lines changed

crates/burn-tch/src/tensor.rs

+85-12
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,58 @@ use std::{marker::PhantomData, sync::Arc};
99
#[allow(clippy::arc_with_non_send_sync)]
1010
pub type StorageRef = Arc<*mut c_void>;
1111

12+
/// A reference to a tensor storage.
13+
#[derive(PartialEq, Debug, Clone)]
14+
pub enum Storage {
15+
/// When a tensor is a partial view of another tensor.
16+
View {
17+
/// Storage reference for the whole buffer.
18+
buffer_ref: StorageRef,
19+
/// Storage reference for the partial buffer.
20+
view_ref: StorageRef,
21+
},
22+
/// When a tensor use all of its buffer.
23+
Owned {
24+
/// Storage reference for the whole buffer.
25+
buffer_ref: StorageRef,
26+
},
27+
}
28+
29+
impl Storage {
30+
/// Check if the storage can be used inplace.
31+
pub fn can_mut(&self) -> bool {
32+
match self {
33+
Storage::View {
34+
buffer_ref: start_ref,
35+
view_ref,
36+
} => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1,
37+
Storage::Owned {
38+
buffer_ref: start_ref,
39+
} => Arc::strong_count(start_ref) == 1,
40+
}
41+
}
42+
43+
/// Get the whole buffer reference.
44+
pub fn buffer_ref(&self) -> &StorageRef {
45+
match self {
46+
Storage::View {
47+
buffer_ref: start_ref,
48+
view_ref: _,
49+
} => start_ref,
50+
Storage::Owned {
51+
buffer_ref: start_ref,
52+
} => start_ref,
53+
}
54+
}
55+
}
56+
1257
/// A tensor that uses the tch backend.
1358
#[derive(Debug, PartialEq)]
1459
pub struct TchTensor<E: tch::kind::Element, const D: usize> {
1560
/// Handle to the tensor. Call methods on this field.
1661
pub tensor: tch::Tensor,
1762
/// The tensor's storage
18-
pub storage: StorageRef,
63+
pub storage: Storage,
1964
phantom: PhantomData<E>,
2065
}
2166

@@ -27,26 +72,49 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
2772
/// instead.
2873
pub fn new(tensor: tch::Tensor) -> Self {
2974
#[allow(clippy::arc_with_non_send_sync)]
30-
let data = Arc::new(tensor.data_ptr());
75+
let storage = Storage::Owned {
76+
buffer_ref: Arc::new(tensor.data_ptr()),
77+
};
3178

3279
Self {
3380
tensor,
3481
phantom: PhantomData,
35-
storage: data,
82+
storage,
3683
}
3784
}
3885

3986
/// Create a tensor that was created from an operation executed on a parent tensor.
4087
///
4188
/// If the child tensor shared the same storage as its parent, it will be cloned, effectively
4289
/// tracking how much tensors point to the same memory space.
43-
pub fn from_existing(tensor: tch::Tensor, storage_parent: StorageRef) -> Self {
90+
pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self {
4491
let storage_child = tensor.data_ptr();
92+
let mut is_a_new_tensor = true;
93+
94+
match &storage_parent {
95+
Storage::View {
96+
buffer_ref: start_ref,
97+
view_ref,
98+
} => {
99+
if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() {
100+
is_a_new_tensor = false;
101+
}
102+
}
103+
Storage::Owned {
104+
buffer_ref: start_ref,
105+
} => {
106+
if storage_child == *start_ref.as_ref() {
107+
is_a_new_tensor = false;
108+
}
109+
}
110+
};
45111

46-
#[allow(clippy::arc_with_non_send_sync)]
47-
let storage = match storage_child == *storage_parent {
48-
true => storage_parent.clone(),
49-
false => Arc::new(storage_child),
112+
let storage = match is_a_new_tensor {
113+
true => Storage::Owned {
114+
#[allow(clippy::arc_with_non_send_sync)]
115+
buffer_ref: Arc::new(storage_child),
116+
},
117+
false => storage_parent.clone(),
50118
};
51119

52120
Self {
@@ -57,10 +125,15 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
57125
}
58126

59127
/// Create a tensor that uses a part of its parent tensor such as slice and narrow.
60-
pub fn partial(tensor: tch::Tensor, storage_parent: StorageRef) -> Self {
128+
pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self {
129+
let storage = Storage::View {
130+
buffer_ref: storage_parent.buffer_ref().clone(),
131+
#[allow(clippy::arc_with_non_send_sync)]
132+
view_ref: Arc::new(tensor.data_ptr()),
133+
};
61134
Self {
62135
tensor,
63-
storage: storage_parent,
136+
storage,
64137
phantom: PhantomData,
65138
}
66139
}
@@ -96,7 +169,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
96169
&mut self,
97170
func: F,
98171
) -> Option<TchTensor<EOut, D_OUT>> {
99-
if Arc::strong_count(&self.storage) > 1 {
172+
if !self.storage.can_mut() {
100173
return None;
101174
}
102175

@@ -113,7 +186,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
113186
FOwn: Fn(tch::Tensor) -> tch::Tensor,
114187
FRef: Fn(&tch::Tensor) -> tch::Tensor,
115188
{
116-
if Arc::strong_count(&self.storage) > 1 {
189+
if !self.storage.can_mut() {
117190
return TchTensor::from_existing(fref(&self.tensor), self.storage);
118191
}
119192

crates/burn-tensor/src/tests/ops/reshape.rs

+12
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,18 @@ mod tests {
7070
assert_eq!(reshaped.shape(), [4, 3].into());
7171
}
7272

73+
#[test]
74+
fn should_not_corrupt_after_slice() {
75+
let zeros = Tensor::<TestBackend, 1>::zeros([2], &Default::default());
76+
zeros.clone().slice([1..2]).reshape([1]).exp();
77+
78+
// May lead to zeroes being equal to [0.0, 1.0]
79+
assert_eq!(
80+
zeros.to_data(),
81+
Tensor::<TestBackend, 1>::zeros([2], &Default::default()).to_data()
82+
);
83+
}
84+
7385
#[test]
7486
#[should_panic]
7587
fn multiple_neg_ones() {

0 commit comments

Comments
 (0)