Skip to content

Commit

Permalink
Merge branch 'main' into feat/burn-vision
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Jan 31, 2025
2 parents 021360b + cb0854c commit a1d727f
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 714 deletions.
851 changes: 180 additions & 671 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ log = { default-features = false, version = "0.4.25" }
md5 = "0.7.0"
paste = "1"
percent-encoding = "2.3.1"
polars = { version = "0.44.2", features = ["lazy"] }
polars = { version = "0.46.0", features = ["lazy"] }
pretty_assertions = "1.4.1"
proc-macro2 = "1.0.93"
protobuf = "3.7.1"
Expand Down
14 changes: 7 additions & 7 deletions backend-comparison/src/burnbenchapp/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ enum BackendValues {
CandleCuda,
#[strum(to_string = "candle-metal")]
CandleMetal,
#[strum(to_string = "cuda")]
Cuda,
#[strum(to_string = "cuda-fusion")]
CudaFusion,
#[cfg(target_os = "linux")]
#[strum(to_string = "hip")]
Hip,
#[strum(to_string = "ndarray")]
Ndarray,
#[strum(to_string = "ndarray-blas-accelerate")]
Expand All @@ -82,13 +89,6 @@ enum BackendValues {
WgpuSpirv,
#[strum(to_string = "wgpu-spirv-fusion")]
WgpuSpirvFusion,
#[strum(to_string = "cuda-jit")]
CudaJit,
#[strum(to_string = "cuda-jit-fusion")]
CudaJitFusion,
#[cfg(target_os = "linux")]
#[strum(to_string = "hip-jit")]
HipJit,
}

#[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
Expand Down
22 changes: 11 additions & 11 deletions crates/burn-dataset/src/dataset/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,20 +269,20 @@ mod tests {
}

fn create_test_dataframe() -> DataFrame {
let s0 = Column::Series(Series::new("int32".into(), &[1i32, 2i32, 3i32]));
let s1 = Column::Series(Series::new("bool".into(), &[true, false, true]));
let s2 = Column::Series(Series::new("float64".into(), &[1.1f64, 2.2f64, 3.3f64]));
let s3 = Column::Series(Series::new("string".into(), &["Boo", "Boo2", "Boo3"]));
let s6 = Column::Series(Series::new("int16".into(), &[1i16, 2i16, 3i16]));
let s8 = Column::Series(Series::new("uint32".into(), &[1u32, 2u32, 3u32]));
let s9 = Column::Series(Series::new("uint64".into(), &[1u64, 2u64, 3u64]));
let s10 = Column::Series(Series::new("float32".into(), &[1.1f32, 2.2f32, 3.3f32]));
let s11 = Column::Series(Series::new("int64".into(), &[1i64, 2i64, 3i64]));
let s12 = Column::Series(Series::new("int8".into(), &[1i8, 2i8, 3i8]));
let s0 = Column::new("int32".into(), &[1i32, 2i32, 3i32]);
let s1 = Column::new("bool".into(), &[true, false, true]);
let s2 = Column::new("float64".into(), &[1.1f64, 2.2f64, 3.3f64]);
let s3 = Column::new("string".into(), &["Boo", "Boo2", "Boo3"]);
let s6 = Column::new("int16".into(), &[1i16, 2i16, 3i16]);
let s8 = Column::new("uint32".into(), &[1u32, 2u32, 3u32]);
let s9 = Column::new("uint64".into(), &[1u64, 2u64, 3u64]);
let s10 = Column::new("float32".into(), &[1.1f32, 2.2f32, 3.3f32]);
let s11 = Column::new("int64".into(), &[1i64, 2i64, 3i64]);
let s12 = Column::new("int8".into(), &[1i8, 2i8, 3i8]);

let binary_data: Vec<&[u8]> = vec![&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]];

let s13 = Column::Series(Series::new("binary".into(), binary_data));
let s13 = Column::new("binary".into(), binary_data);
DataFrame::new(vec![s0, s1, s2, s3, s6, s8, s9, s10, s11, s12, s13]).unwrap()
}

Expand Down
4 changes: 2 additions & 2 deletions crates/burn-remote/src/server/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ impl<B: ReprBackend> SessionManager<B> {

impl<B: ReprBackend> Session<B> {
fn new(runner: Runner<B>) -> Self {
let (sender, reveiver) = std::sync::mpsc::sync_channel(1);
let (sender, receiver) = std::sync::mpsc::sync_channel(1);
Self {
runner,
streams: Default::default(),
sender,
receiver: Some(reveiver),
receiver: Some(receiver),
}
}

Expand Down
4 changes: 2 additions & 2 deletions crates/burn-tensor/src/tensor/backend/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ mod tests {
}

#[test]
fn should_build_indices_2d_complexe() {
fn should_build_indices_2d_complex() {
let shape = Shape::new([2, 3]);

let indices = build_indices(&shape, Order::Left);
Expand All @@ -206,7 +206,7 @@ mod tests {
}

#[test]
fn should_build_indices_3d_complexe() {
fn should_build_indices_3d_complex() {
let shape = Shape::new([2, 5, 3]);

let indices = build_indices(&shape, Order::Left);
Expand Down
29 changes: 10 additions & 19 deletions examples/wgan/examples/wgan-generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,18 @@ pub fn launch<B: Backend>(device: B::Device) {
feature = "ndarray-blas-accelerate",
))]
mod ndarray {
use burn::backend::{
ndarray::{NdArray, NdArrayDevice},
Autodiff,
};
use burn::backend::ndarray::{NdArray, NdArrayDevice};

use crate::launch;

pub fn run() {
launch::<Autodiff<NdArray>>(NdArrayDevice::Cpu);
launch::<NdArray>(NdArrayDevice::Cpu);
}
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

use crate::launch;

Expand All @@ -38,41 +32,38 @@ mod tch_gpu {
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;

launch::<Autodiff<LibTorch>>(device);
launch::<LibTorch>(device);
}
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

use crate::launch;

pub fn run() {
launch::<Autodiff<LibTorch>>(LibTorchDevice::Cpu);
launch::<LibTorch>(LibTorchDevice::Cpu);
}
}

#[cfg(feature = "wgpu")]
mod wgpu {
use crate::launch;
use burn::backend::{wgpu::Wgpu, Autodiff};
use burn::backend::wgpu::Wgpu;

pub fn run() {
launch::<Autodiff<Wgpu>>(Default::default());
launch::<Wgpu>(Default::default());
}
}

#[cfg(feature = "cuda")]
mod cuda {
use crate::launch;
use burn::backend::{Autodiff, Cuda};
use burn::backend::Cuda;

pub fn run() {
launch::<Autodiff<Cuda>>(Default::default());
launch::<Cuda>(Default::default());
}
}

Expand Down
2 changes: 1 addition & 1 deletion examples/wgan/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl<B: Backend> Discriminator<B> {
}
}

// Use model config to construct a generative and adverserial model
// Use model config to construct a generative and adversarial model
#[derive(Config, Debug)]
pub struct ModelConfig {
/// Dimensionality of the latent space
Expand Down

0 comments on commit a1d727f

Please sign in to comment.