@@ -9,13 +9,58 @@ use std::{marker::PhantomData, sync::Arc};
9
9
#[ allow( clippy:: arc_with_non_send_sync) ]
10
10
pub type StorageRef = Arc < * mut c_void > ;
11
11
12
+ /// A reference to a tensor storage.
13
+ #[ derive( PartialEq , Debug , Clone ) ]
14
+ pub enum Storage {
15
+ /// When a tensor is a partial view of another tensor.
16
+ View {
17
+ /// Storage reference for the whole buffer.
18
+ buffer_ref : StorageRef ,
19
+ /// Storage reference for the partial buffer.
20
+ view_ref : StorageRef ,
21
+ } ,
22
+ /// When a tensor use all of its buffer.
23
+ Owned {
24
+ /// Storage reference for the whole buffer.
25
+ buffer_ref : StorageRef ,
26
+ } ,
27
+ }
28
+
29
+ impl Storage {
30
+ /// Check if the storage can be used inplace.
31
+ pub fn can_mut ( & self ) -> bool {
32
+ match self {
33
+ Storage :: View {
34
+ buffer_ref : start_ref,
35
+ view_ref,
36
+ } => Arc :: strong_count ( start_ref) == 1 && Arc :: strong_count ( view_ref) == 1 ,
37
+ Storage :: Owned {
38
+ buffer_ref : start_ref,
39
+ } => Arc :: strong_count ( start_ref) == 1 ,
40
+ }
41
+ }
42
+
43
+ /// Get the whole buffer reference.
44
+ pub fn buffer_ref ( & self ) -> & StorageRef {
45
+ match self {
46
+ Storage :: View {
47
+ buffer_ref : start_ref,
48
+ view_ref : _,
49
+ } => start_ref,
50
+ Storage :: Owned {
51
+ buffer_ref : start_ref,
52
+ } => start_ref,
53
+ }
54
+ }
55
+ }
56
+
12
57
/// A tensor that uses the tch backend.
13
58
#[ derive( Debug , PartialEq ) ]
14
59
pub struct TchTensor < E : tch:: kind:: Element , const D : usize > {
15
60
/// Handle to the tensor. Call methods on this field.
16
61
pub tensor : tch:: Tensor ,
17
62
/// The tensor's storage
18
- pub storage : StorageRef ,
63
+ pub storage : Storage ,
19
64
phantom : PhantomData < E > ,
20
65
}
21
66
@@ -27,26 +72,49 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
27
72
/// instead.
28
73
pub fn new ( tensor : tch:: Tensor ) -> Self {
29
74
#[ allow( clippy:: arc_with_non_send_sync) ]
30
- let data = Arc :: new ( tensor. data_ptr ( ) ) ;
75
+ let storage = Storage :: Owned {
76
+ buffer_ref : Arc :: new ( tensor. data_ptr ( ) ) ,
77
+ } ;
31
78
32
79
Self {
33
80
tensor,
34
81
phantom : PhantomData ,
35
- storage : data ,
82
+ storage,
36
83
}
37
84
}
38
85
39
86
/// Create a tensor that was created from an operation executed on a parent tensor.
40
87
///
41
88
/// If the child tensor shared the same storage as its parent, it will be cloned, effectively
42
89
/// tracking how much tensors point to the same memory space.
43
- pub fn from_existing ( tensor : tch:: Tensor , storage_parent : StorageRef ) -> Self {
90
+ pub fn from_existing ( tensor : tch:: Tensor , storage_parent : Storage ) -> Self {
44
91
let storage_child = tensor. data_ptr ( ) ;
92
+ let mut is_a_new_tensor = true ;
93
+
94
+ match & storage_parent {
95
+ Storage :: View {
96
+ buffer_ref : start_ref,
97
+ view_ref,
98
+ } => {
99
+ if storage_child == * start_ref. as_ref ( ) || storage_child == * view_ref. as_ref ( ) {
100
+ is_a_new_tensor = false ;
101
+ }
102
+ }
103
+ Storage :: Owned {
104
+ buffer_ref : start_ref,
105
+ } => {
106
+ if storage_child == * start_ref. as_ref ( ) {
107
+ is_a_new_tensor = false ;
108
+ }
109
+ }
110
+ } ;
45
111
46
- #[ allow( clippy:: arc_with_non_send_sync) ]
47
- let storage = match storage_child == * storage_parent {
48
- true => storage_parent. clone ( ) ,
49
- false => Arc :: new ( storage_child) ,
112
+ let storage = match is_a_new_tensor {
113
+ true => Storage :: Owned {
114
+ #[ allow( clippy:: arc_with_non_send_sync) ]
115
+ buffer_ref : Arc :: new ( storage_child) ,
116
+ } ,
117
+ false => storage_parent. clone ( ) ,
50
118
} ;
51
119
52
120
Self {
@@ -57,10 +125,15 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
57
125
}
58
126
59
127
/// Create a tensor that uses a part of its parent tensor such as slice and narrow.
60
- pub fn partial ( tensor : tch:: Tensor , storage_parent : StorageRef ) -> Self {
128
+ pub fn partial ( tensor : tch:: Tensor , storage_parent : Storage ) -> Self {
129
+ let storage = Storage :: View {
130
+ buffer_ref : storage_parent. buffer_ref ( ) . clone ( ) ,
131
+ #[ allow( clippy:: arc_with_non_send_sync) ]
132
+ view_ref : Arc :: new ( tensor. data_ptr ( ) ) ,
133
+ } ;
61
134
Self {
62
135
tensor,
63
- storage : storage_parent ,
136
+ storage,
64
137
phantom : PhantomData ,
65
138
}
66
139
}
@@ -96,7 +169,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
96
169
& mut self ,
97
170
func : F ,
98
171
) -> Option < TchTensor < EOut , D_OUT > > {
99
- if Arc :: strong_count ( & self . storage ) > 1 {
172
+ if ! self . storage . can_mut ( ) {
100
173
return None ;
101
174
}
102
175
@@ -113,7 +186,7 @@ impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
113
186
FOwn : Fn ( tch:: Tensor ) -> tch:: Tensor ,
114
187
FRef : Fn ( & tch:: Tensor ) -> tch:: Tensor ,
115
188
{
116
- if Arc :: strong_count ( & self . storage ) > 1 {
189
+ if ! self . storage . can_mut ( ) {
117
190
return TchTensor :: from_existing ( fref ( & self . tensor ) , self . storage ) ;
118
191
}
119
192
0 commit comments