Skip to content

Commit

Permalink
Merge branch 'main' into feat/fuse-reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Feb 3, 2025
2 parents 6ce7c3e + 6b2e66b commit f0c82e2
Show file tree
Hide file tree
Showing 109 changed files with 2,935 additions and 1,285 deletions.
1,150 changes: 309 additions & 841 deletions Cargo.lock

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ version = "0.17.0"
atomic_float = "1"
bytemuck = "1.21.0"
candle-core = { version = "0.8" }
clap = { version = "4.5.26", features = ["derive"] }
clap = { version = "4.5.27", features = ["derive"] }
colored = "2.1.0"
console_error_panic_hook = "0.1.7"
csv = "1.3.1"
Expand All @@ -47,14 +47,14 @@ globwalk = "0.9.1"
hashbrown = "0.15.2"
hound = "3.5.1"
image = "0.25.5"
indicatif = "0.17.9"
indicatif = "0.17.11"
js-sys = "0.3.72"
libm = "0.2.11"
log = { default-features = false, version = "0.4.25" }
md5 = "0.7.0"
paste = "1"
percent-encoding = "2.3.1"
polars = { version = "0.44.2", features = ["lazy"] }
polars = { version = "0.46.0", features = ["lazy"] }
pretty_assertions = "1.4.1"
proc-macro2 = "1.0.93"
protobuf = "3.7.1"
Expand Down Expand Up @@ -101,7 +101,7 @@ ratatui = "0.29.0"

# WGPU stuff
text_placeholder = "0.5.1"
wgpu = "24.0.0"
wgpu = "24.0.1"

# Benchmarks and Burnbench
arboard = "3.4.1"
Expand Down Expand Up @@ -141,11 +141,11 @@ serde = { version = "1.0.217", default-features = false, features = [
"alloc",
] } # alloc is for no_std, derive is needed
serde_json = { version = "1.0.137", default-features = false }
uuid = { version = "1.12.0", default-features = false }
uuid = { version = "1.12.1", default-features = false }

libc = "0.2.169"
nvml-wrapper = "0.10.0"
sysinfo = "0.32.1"
sysinfo = "0.33.1"
systemstat = "0.2.3"
tch = "0.15.0"

Expand Down
8 changes: 4 additions & 4 deletions backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ candle-accelerate = ["burn/candle", "burn/accelerate"]
candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle-cuda"]
candle-metal = ["burn/candle", "burn/metal"]
cuda-jit = ["burn/cuda-jit"]
cuda-jit-fusion = ["cuda-jit", "burn/fusion"]
cuda = ["burn/cuda"]
cuda-fusion = ["cuda", "burn/fusion"]
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
hip-jit = ["burn/hip-jit"]
hip = ["burn/hip"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
Expand All @@ -27,7 +27,7 @@ tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu", "burn/autotune"]
wgpu-fusion = ["wgpu", "burn/fusion"]
wgpu-spirv = ["burn/wgpu-spirv", "burn/autotune"]
wgpu-spirv = ["burn/vulkan", "burn/autotune"]
wgpu-spirv-fusion = ["wgpu-spirv", "burn/fusion"]

[dependencies]
Expand Down
14 changes: 7 additions & 7 deletions backend-comparison/src/burnbenchapp/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ enum BackendValues {
CandleCuda,
#[strum(to_string = "candle-metal")]
CandleMetal,
#[strum(to_string = "cuda")]
Cuda,
#[strum(to_string = "cuda-fusion")]
CudaFusion,
#[cfg(target_os = "linux")]
#[strum(to_string = "hip")]
Hip,
#[strum(to_string = "ndarray")]
Ndarray,
#[strum(to_string = "ndarray-blas-accelerate")]
Expand All @@ -82,13 +89,6 @@ enum BackendValues {
WgpuSpirv,
#[strum(to_string = "wgpu-spirv-fusion")]
WgpuSpirvFusion,
#[strum(to_string = "cuda-jit")]
CudaJit,
#[strum(to_string = "cuda-jit-fusion")]
CudaJitFusion,
#[cfg(target_os = "linux")]
#[strum(to_string = "hip-jit")]
HipJit,
}

#[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
Expand Down
20 changes: 10 additions & 10 deletions backend-comparison/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ macro_rules! bench_on_backend {
let feature_name = "wgpu-spirv";
#[cfg(feature = "wgpu-spirv-fusion")]
let feature_name = "wgpu-spirv-fusion";
#[cfg(feature = "cuda-jit")]
let feature_name = "cuda-jit";
#[cfg(feature = "cuda-jit-fusion")]
let feature_name = "cuda-jit-fusion";
#[cfg(feature = "hip-jit")]
let feature_name = "hip-jit";
#[cfg(feature = "cuda")]
let feature_name = "cuda";
#[cfg(feature = "cuda-fusion")]
let feature_name = "cuda-fusion";
#[cfg(feature = "hip")]
let feature_name = "hip";

#[cfg(any(feature = "wgpu"))]
{
Expand Down Expand Up @@ -172,16 +172,16 @@ macro_rules! bench_on_backend {
$fn_name::<Candle>(&device, feature_name, url, token);
}

#[cfg(feature = "cuda-jit")]
#[cfg(feature = "cuda")]
{
use burn::backend::cuda_jit::{Cuda, CudaDevice};
use burn::backend::cuda::{Cuda, CudaDevice};

$fn_name::<Cuda<half::f16>>(&CudaDevice::default(), feature_name, url, token);
}

#[cfg(feature = "hip-jit")]
#[cfg(feature = "hip")]
{
use burn::backend::hip_jit::{Hip, HipDevice};
use burn::backend::hip::{Hip, HipDevice};

$fn_name::<Hip<half::f16>>(&HipDevice::default(), feature_name, url, token);
}
Expand Down
2 changes: 1 addition & 1 deletion backend-comparison/src/persistence/system_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl BenchmarkSystemInfo {

fn enumerate_cpus() -> Vec<String> {
let system = sysinfo::System::new_with_specifics(
sysinfo::RefreshKind::new().with_cpu(sysinfo::CpuRefreshKind::everything()),
sysinfo::RefreshKind::nothing().with_cpu(sysinfo::CpuRefreshKind::everything()),
);
let cpu_names: HashSet<String> = system
.cpus()
Expand Down
4 changes: 2 additions & 2 deletions burn-book/src/advanced/no-std.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ We are using ndarray, so we just need to define the NdArray backend as usual
use burn::{backend::NdArray, tensor::Tensor};

type Backend = NdArray<f32>;
type BackendDeice = <Backend as burn::tensor::backend::Backend>::Device;
type BackendDevice = <Backend as burn::tensor::backend::Backend>::Device;
```

Then inside the `main` function add
```rs
use your_model::Model;

// Get a default device for the backend
let device = BackendDeice::default();
let device = BackendDevice::default();

// Create a new model and load the state
let model: Model<Backend> = Model::default();
Expand Down
Loading

0 comments on commit f0c82e2

Please sign in to comment.