@@ -39,12 +39,9 @@ pub struct Context<'a, H> {
39
39
pub scalar_u8 : & ' a Vec < u8 > ,
40
40
}
41
41
42
- #[ derive( Default ) ]
43
42
pub ( crate ) struct OperationConverter {
44
43
tensors_relative2global : HashMap < TensorId , TensorDescription > ,
45
44
tensors_global2relative : HashMap < TensorId , TensorDescription > ,
46
- /// Only useful to create new shape ID.
47
- /// You should use tensor descriptions to retrieve the proper shape.
48
45
shapes_global2relative : HashMap < usize , usize > ,
49
46
scalar_f32 : Vec < f32 > ,
50
47
scalar_f16 : Vec < f16 > ,
@@ -59,6 +56,32 @@ pub(crate) struct OperationConverter {
59
56
scalar_u8 : Vec < u8 > ,
60
57
}
61
58
59
+ impl Default for OperationConverter {
60
+ fn default ( ) -> Self {
61
+ let mut val = Self {
62
+ tensors_relative2global : Default :: default ( ) ,
63
+ tensors_global2relative : Default :: default ( ) ,
64
+ shapes_global2relative : Default :: default ( ) ,
65
+ scalar_f32 : Default :: default ( ) ,
66
+ scalar_f16 : Default :: default ( ) ,
67
+ scalar_bf16 : Default :: default ( ) ,
68
+ scalar_i64 : Default :: default ( ) ,
69
+ scalar_i32 : Default :: default ( ) ,
70
+ scalar_i16 : Default :: default ( ) ,
71
+ scalar_i8 : Default :: default ( ) ,
72
+ scalar_u64 : Default :: default ( ) ,
73
+ scalar_u32 : Default :: default ( ) ,
74
+ scalar_u16 : Default :: default ( ) ,
75
+ scalar_u8 : Default :: default ( ) ,
76
+ } ;
77
+
78
+ // global 1 is always shape id 0.
79
+ val. shapes_global2relative . insert ( 1 , 0 ) ;
80
+
81
+ val
82
+ }
83
+ }
84
+
62
85
/// Fork of a [context](Context) which owns its data.
63
86
pub struct ContextOwned < H > {
64
87
tensors : HashMap < TensorId , TensorDescription > ,
@@ -180,7 +203,11 @@ impl OperationConverter {
180
203
pub ( crate ) fn clear ( & mut self ) {
181
204
self . tensors_relative2global . clear ( ) ;
182
205
self . tensors_global2relative . clear ( ) ;
206
+
183
207
self . shapes_global2relative . clear ( ) ;
208
+ // global 1 is always shape id 0.
209
+ self . shapes_global2relative . insert ( 1 , 0 ) ;
210
+
184
211
self . scalar_f32 . clear ( ) ;
185
212
self . scalar_f16 . clear ( ) ;
186
213
self . scalar_bf16 . clear ( ) ;
@@ -1129,7 +1156,7 @@ impl RelativeOps for BaseOperationDescription {
1129
1156
BaseOperationDescription :: ToDevice ( desc. to_relative ( converter) )
1130
1157
}
1131
1158
BaseOperationDescription :: Reshape ( desc) => {
1132
- BaseOperationDescription :: Reshape ( ReshapeDescription {
1159
+ BaseOperationDescription :: Reshape ( UnaryOperationDescription {
1133
1160
input : desc. input . to_relative ( converter) ,
1134
1161
out : desc. out . to_relative ( converter) ,
1135
1162
} )
@@ -1246,6 +1273,7 @@ impl RelativeOps for TensorDescription {
1246
1273
// We never saw this dim value before, therefore we create a new ID.
1247
1274
let dim_id = converter. shapes_global2relative . len ( ) ;
1248
1275
relative_shape. push ( dim_id) ;
1276
+
1249
1277
converter. shapes_global2relative . insert ( * dim, dim_id) ;
1250
1278
}
1251
1279
}
@@ -1300,7 +1328,7 @@ mod tests {
1300
1328
tensor1_local,
1301
1329
TensorDescription {
1302
1330
id: TensorId :: new( 0 ) ,
1303
- shape: vec![ 0 , 1 , 2 ] ,
1331
+ shape: vec![ 1 , 2 , 3 ] ,
1304
1332
status: TensorStatus :: ReadOnly ,
1305
1333
dtype: DType :: F32
1306
1334
}
@@ -1309,7 +1337,7 @@ mod tests {
1309
1337
tensor2_local,
1310
1338
TensorDescription {
1311
1339
id: TensorId :: new( 1 ) ,
1312
- shape: vec![ 0 , 3 , 2 ] ,
1340
+ shape: vec![ 1 , 4 , 3 ] ,
1313
1341
status: TensorStatus :: ReadOnly ,
1314
1342
dtype: DType :: F32
1315
1343
}
0 commit comments