Skip to content

Commit 79d6e19

Browse files
committed
Fix db-pedia-infer backend
1 parent b33bd24 commit 79d6e19

File tree

1 file changed

+10
-22
lines changed

1 file changed

+10
-22
lines changed

examples/text-classification/examples/db-pedia-infer.rs

+10-22
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
use text_classification::DbPediaDataset;
22

3-
use burn::tensor::backend::AutodiffBackend;
3+
use burn::tensor::backend::Backend;
44

55
#[cfg(not(feature = "f16"))]
66
#[allow(dead_code)]
77
type ElemType = f32;
88
#[cfg(feature = "f16")]
99
type ElemType = burn::tensor::f16;
1010

11-
pub fn launch<B: AutodiffBackend>(device: B::Device) {
11+
pub fn launch<B: Backend>(device: B::Device) {
1212
text_classification::inference::infer::<B, DbPediaDataset>(
1313
device,
1414
"/tmp/text-classification-db-pedia",
@@ -34,24 +34,18 @@ pub fn launch<B: AutodiffBackend>(device: B::Device) {
3434
feature = "ndarray-blas-accelerate",
3535
))]
3636
mod ndarray {
37-
use burn::backend::{
38-
ndarray::{NdArray, NdArrayDevice},
39-
Autodiff,
40-
};
37+
use burn::backend::ndarray::{NdArray, NdArrayDevice};
4138

4239
use crate::{launch, ElemType};
4340

4441
pub fn run() {
45-
launch::<Autodiff<NdArray<ElemType>>>(NdArrayDevice::Cpu);
42+
launch::<NdArray<ElemType>>(NdArrayDevice::Cpu);
4643
}
4744
}
4845

4946
#[cfg(feature = "tch-gpu")]
5047
mod tch_gpu {
51-
use burn::backend::{
52-
libtorch::{LibTorch, LibTorchDevice},
53-
Autodiff,
54-
};
48+
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
5549

5650
use crate::{launch, ElemType};
5751

@@ -61,35 +55,29 @@ mod tch_gpu {
6155
#[cfg(target_os = "macos")]
6256
let device = LibTorchDevice::Mps;
6357

64-
launch::<Autodiff<LibTorch<ElemType>>>(device);
58+
launch::<LibTorch<ElemType>>(device);
6559
}
6660
}
6761

6862
#[cfg(feature = "tch-cpu")]
6963
mod tch_cpu {
70-
use burn::backend::{
71-
tch::{LibTorch, LibTorchDevice},
72-
Autodiff,
73-
};
64+
use burn::backend::tch::{LibTorch, LibTorchDevice};
7465

7566
use crate::{launch, ElemType};
7667

7768
pub fn run() {
78-
launch::<Autodiff<LibTorch<ElemType>>>(LibTorchDevice::Cpu);
69+
launch::<LibTorch<ElemType>>(LibTorchDevice::Cpu);
7970
}
8071
}
8172

8273
#[cfg(feature = "wgpu")]
8374
mod wgpu {
84-
use burn::backend::{
85-
wgpu::{Wgpu, WgpuDevice},
86-
Autodiff,
87-
};
75+
use burn::backend::wgpu::{Wgpu, WgpuDevice};
8876

8977
use crate::{launch, ElemType};
9078

9179
pub fn run() {
92-
launch::<Autodiff<Wgpu<ElemType, i32>>>(WgpuDevice::default());
80+
launch::<Wgpu<ElemType, i32>>(WgpuDevice::default());
9381
}
9482
}
9583

0 commit comments

Comments
 (0)