From dd0396dd5bb2d464ddf34ae8ba03664bfd7bcf94 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 22 Jan 2025 15:29:52 -0500 Subject: [PATCH] Fix db-pedia-infer backend (#2736) --- .../examples/db-pedia-infer.rs | 32 ++++++------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/examples/text-classification/examples/db-pedia-infer.rs b/examples/text-classification/examples/db-pedia-infer.rs index 490ed3b97e..027eb76122 100644 --- a/examples/text-classification/examples/db-pedia-infer.rs +++ b/examples/text-classification/examples/db-pedia-infer.rs @@ -1,6 +1,6 @@ use text_classification::DbPediaDataset; -use burn::tensor::backend::AutodiffBackend; +use burn::tensor::backend::Backend; #[cfg(not(feature = "f16"))] #[allow(dead_code)] @@ -8,7 +8,7 @@ type ElemType = f32; #[cfg(feature = "f16")] type ElemType = burn::tensor::f16; -pub fn launch(device: B::Device) { +pub fn launch(device: B::Device) { text_classification::inference::infer::( device, "/tmp/text-classification-db-pedia", @@ -34,24 +34,18 @@ pub fn launch(device: B::Device) { feature = "ndarray-blas-accelerate", ))] mod ndarray { - use burn::backend::{ - ndarray::{NdArray, NdArrayDevice}, - Autodiff, - }; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; use crate::{launch, ElemType}; pub fn run() { - launch::>>(NdArrayDevice::Cpu); + launch::>(NdArrayDevice::Cpu); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::{ - libtorch::{LibTorch, LibTorchDevice}, - Autodiff, - }; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; use crate::{launch, ElemType}; @@ -61,35 +55,29 @@ mod tch_gpu { #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; - launch::>>(device); + launch::>(device); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::{ - tch::{LibTorch, LibTorchDevice}, - Autodiff, - }; + use burn::backend::tch::{LibTorch, LibTorchDevice}; use crate::{launch, ElemType}; pub fn run() { - launch::>>(LibTorchDevice::Cpu); + launch::>(LibTorchDevice::Cpu); } } #[cfg(feature = "wgpu")] mod wgpu { - use burn::backend::{ - wgpu::{Wgpu, WgpuDevice}, - Autodiff, - }; + use burn::backend::wgpu::{Wgpu, WgpuDevice}; use crate::{launch, ElemType}; pub fn run() { - launch::>>(WgpuDevice::default()); + launch::>(WgpuDevice::default()); } }