1
1
use text_classification:: DbPediaDataset ;
2
2
3
- use burn:: tensor:: backend:: AutodiffBackend ;
3
+ use burn:: tensor:: backend:: Backend ;
4
4
5
5
#[ cfg( not( feature = "f16" ) ) ]
6
6
#[ allow( dead_code) ]
7
7
type ElemType = f32 ;
8
8
#[ cfg( feature = "f16" ) ]
9
9
type ElemType = burn:: tensor:: f16 ;
10
10
11
- pub fn launch < B : AutodiffBackend > ( device : B :: Device ) {
11
+ pub fn launch < B : Backend > ( device : B :: Device ) {
12
12
text_classification:: inference:: infer :: < B , DbPediaDataset > (
13
13
device,
14
14
"/tmp/text-classification-db-pedia" ,
@@ -34,24 +34,18 @@ pub fn launch<B: AutodiffBackend>(device: B::Device) {
34
34
feature = "ndarray-blas-accelerate" ,
35
35
) ) ]
36
36
mod ndarray {
37
- use burn:: backend:: {
38
- ndarray:: { NdArray , NdArrayDevice } ,
39
- Autodiff ,
40
- } ;
37
+ use burn:: backend:: ndarray:: { NdArray , NdArrayDevice } ;
41
38
42
39
use crate :: { launch, ElemType } ;
43
40
44
41
pub fn run ( ) {
45
- launch :: < Autodiff < NdArray < ElemType > > > ( NdArrayDevice :: Cpu ) ;
42
+ launch :: < NdArray < ElemType > > ( NdArrayDevice :: Cpu ) ;
46
43
}
47
44
}
48
45
49
46
#[ cfg( feature = "tch-gpu" ) ]
50
47
mod tch_gpu {
51
- use burn:: backend:: {
52
- libtorch:: { LibTorch , LibTorchDevice } ,
53
- Autodiff ,
54
- } ;
48
+ use burn:: backend:: libtorch:: { LibTorch , LibTorchDevice } ;
55
49
56
50
use crate :: { launch, ElemType } ;
57
51
@@ -61,35 +55,29 @@ mod tch_gpu {
61
55
#[ cfg( target_os = "macos" ) ]
62
56
let device = LibTorchDevice :: Mps ;
63
57
64
- launch :: < Autodiff < LibTorch < ElemType > > > ( device) ;
58
+ launch :: < LibTorch < ElemType > > ( device) ;
65
59
}
66
60
}
67
61
68
62
#[ cfg( feature = "tch-cpu" ) ]
69
63
mod tch_cpu {
70
- use burn:: backend:: {
71
- tch:: { LibTorch , LibTorchDevice } ,
72
- Autodiff ,
73
- } ;
64
+ use burn:: backend:: tch:: { LibTorch , LibTorchDevice } ;
74
65
75
66
use crate :: { launch, ElemType } ;
76
67
77
68
pub fn run ( ) {
78
- launch :: < Autodiff < LibTorch < ElemType > > > ( LibTorchDevice :: Cpu ) ;
69
+ launch :: < LibTorch < ElemType > > ( LibTorchDevice :: Cpu ) ;
79
70
}
80
71
}
81
72
82
73
#[ cfg( feature = "wgpu" ) ]
83
74
mod wgpu {
84
- use burn:: backend:: {
85
- wgpu:: { Wgpu , WgpuDevice } ,
86
- Autodiff ,
87
- } ;
75
+ use burn:: backend:: wgpu:: { Wgpu , WgpuDevice } ;
88
76
89
77
use crate :: { launch, ElemType } ;
90
78
91
79
pub fn run ( ) {
92
- launch :: < Autodiff < Wgpu < ElemType , i32 > > > ( WgpuDevice :: default ( ) ) ;
80
+ launch :: < Wgpu < ElemType , i32 > > ( WgpuDevice :: default ( ) ) ;
93
81
}
94
82
}
95
83
0 commit comments