Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui committed Feb 3, 2025
2 parents 68a0f93 + 9f00320 commit 7ecc9c4
Show file tree
Hide file tree
Showing 51 changed files with 8,991 additions and 210 deletions.
8,345 changes: 8,345 additions & 0 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ globwalk = "0.9.1"
hashbrown = "0.15.2"
hound = "3.5.1"
image = "0.25.5"
indicatif = "0.17.9"
indicatif = "0.17.11"
js-sys = "0.3.72"
libm = "0.2.11"
log = { default-features = false, version = "0.4.25" }
md5 = "0.7.0"
paste = "1"
percent-encoding = "2.3.1"
polars = { version = "0.44.2", features = ["lazy"] }
polars = { version = "0.46.0", features = ["lazy"] }
pretty_assertions = "1.4.1"
proc-macro2 = "1.0.93"
protobuf = "3.7.1"
Expand Down Expand Up @@ -145,16 +145,16 @@ uuid = { version = "1.12.1", default-features = false }

libc = "0.2.169"
nvml-wrapper = "0.10.0"
sysinfo = "0.32.1"
sysinfo = "0.33.1"
systemstat = "0.2.3"
tch = "0.15.0"

ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
Expand Down
8 changes: 4 additions & 4 deletions backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ candle-accelerate = ["burn/candle", "burn/accelerate"]
candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle-cuda"]
candle-metal = ["burn/candle", "burn/metal"]
cuda-jit = ["burn/cuda-jit"]
cuda-jit-fusion = ["cuda-jit", "burn/fusion"]
cuda = ["burn/cuda"]
cuda-fusion = ["cuda", "burn/fusion"]
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
hip-jit = ["burn/hip-jit"]
hip = ["burn/hip"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
Expand All @@ -27,7 +27,7 @@ tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu", "burn/autotune"]
wgpu-fusion = ["wgpu", "burn/fusion"]
wgpu-spirv = ["burn/wgpu-spirv", "burn/autotune"]
wgpu-spirv = ["burn/vulkan", "burn/autotune"]
wgpu-spirv-fusion = ["wgpu-spirv", "burn/fusion"]

[dependencies]
Expand Down
14 changes: 7 additions & 7 deletions backend-comparison/src/burnbenchapp/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ enum BackendValues {
CandleCuda,
#[strum(to_string = "candle-metal")]
CandleMetal,
#[strum(to_string = "cuda")]
Cuda,
#[strum(to_string = "cuda-fusion")]
CudaFusion,
#[cfg(target_os = "linux")]
#[strum(to_string = "hip")]
Hip,
#[strum(to_string = "ndarray")]
Ndarray,
#[strum(to_string = "ndarray-blas-accelerate")]
Expand All @@ -82,13 +89,6 @@ enum BackendValues {
WgpuSpirv,
#[strum(to_string = "wgpu-spirv-fusion")]
WgpuSpirvFusion,
#[strum(to_string = "cuda-jit")]
CudaJit,
#[strum(to_string = "cuda-jit-fusion")]
CudaJitFusion,
#[cfg(target_os = "linux")]
#[strum(to_string = "hip-jit")]
HipJit,
}

#[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
Expand Down
20 changes: 10 additions & 10 deletions backend-comparison/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ macro_rules! bench_on_backend {
let feature_name = "wgpu-spirv";
#[cfg(feature = "wgpu-spirv-fusion")]
let feature_name = "wgpu-spirv-fusion";
#[cfg(feature = "cuda-jit")]
let feature_name = "cuda-jit";
#[cfg(feature = "cuda-jit-fusion")]
let feature_name = "cuda-jit-fusion";
#[cfg(feature = "hip-jit")]
let feature_name = "hip-jit";
#[cfg(feature = "cuda")]
let feature_name = "cuda";
#[cfg(feature = "cuda-fusion")]
let feature_name = "cuda-fusion";
#[cfg(feature = "hip")]
let feature_name = "hip";

#[cfg(any(feature = "wgpu"))]
{
Expand Down Expand Up @@ -172,16 +172,16 @@ macro_rules! bench_on_backend {
$fn_name::<Candle>(&device, feature_name, url, token);
}

#[cfg(feature = "cuda-jit")]
#[cfg(feature = "cuda")]
{
use burn::backend::cuda_jit::{Cuda, CudaDevice};
use burn::backend::cuda::{Cuda, CudaDevice};

$fn_name::<Cuda<half::f16>>(&CudaDevice::default(), feature_name, url, token);
}

#[cfg(feature = "hip-jit")]
#[cfg(feature = "hip")]
{
use burn::backend::hip_jit::{Hip, HipDevice};
use burn::backend::hip::{Hip, HipDevice};

$fn_name::<Hip<half::f16>>(&HipDevice::default(), feature_name, url, token);
}
Expand Down
2 changes: 1 addition & 1 deletion backend-comparison/src/persistence/system_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl BenchmarkSystemInfo {

fn enumerate_cpus() -> Vec<String> {
let system = sysinfo::System::new_with_specifics(
sysinfo::RefreshKind::new().with_cpu(sysinfo::CpuRefreshKind::everything()),
sysinfo::RefreshKind::nothing().with_cpu(sysinfo::CpuRefreshKind::everything()),
);
let cpu_names: HashSet<String> = system
.cpus()
Expand Down
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/module.md
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,4 @@ Burn comes with built-in modules that you can use to build your own modules.
| `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
| `MseLoss` | `nn.MSELoss` |
| `HuberLoss` | `nn.HuberLoss` |
| `PoissonNllLoss` | `nn.PoissonNLLLoss` |
19 changes: 10 additions & 9 deletions crates/burn-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ doc = [
"ndarray",
"tch",
"wgpu",
"cuda-jit",
"hip-jit",
"cuda",
"hip",
"audio",
"vision",
"autodiff",
Expand Down Expand Up @@ -88,7 +88,7 @@ fusion = ["burn-wgpu?/fusion", "burn-cuda?/fusion"]

## Backend features
accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"]
autotune = ["burn-wgpu?/autotune"]
autotune = ["burn-wgpu?/autotune", "burn-cuda?/autotune", "burn-hip?/autotune"]
blas-netlib = ["burn-ndarray?/blas-netlib"]
metal = ["burn-candle?/metal"]
openblas = ["burn-ndarray?/blas-openblas"]
Expand All @@ -100,26 +100,27 @@ template = ["burn-wgpu?/template"]

candle = ["burn-candle"]
candle-cuda = ["candle", "burn-candle/cuda"]
cuda-jit = ["burn-cuda"]
hip-jit = ["burn-hip"]
cuda = ["burn-cuda"]
hip = ["burn-hip"]
ndarray = ["burn-ndarray"]
tch = ["burn-tch"]
wgpu = ["burn-wgpu"]
wgpu-spirv = ["wgpu", "burn-wgpu/spirv"]
vulkan = ["wgpu", "burn-wgpu/vulkan"]
webgpu = ["wgpu", "burn-wgpu/webgpu"]

# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.
record-item-custom-serde = ["thiserror", "regex"]

# Serialization formats
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]

test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray.
test-hip = ["hip-jit"] # To use hip during testing, default uses ndarray.
test-cuda = ["cuda"] # To use cuda during testing, default uses ndarray.
test-hip = ["hip"] # To use hip during testing, default uses ndarray.
test-tch = ["tch"] # To use tch during testing, default uses ndarray.
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
test-wgpu-spirv = [
"test-wgpu",
"wgpu-spirv",
"vulkan",
] # To use wgpu-spirv during testing, default uses ndarray.

[dependencies]
Expand Down
22 changes: 14 additions & 8 deletions crates/burn-core/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,29 @@ pub use burn_wgpu as wgpu;
#[cfg(feature = "wgpu")]
pub use burn_wgpu::Wgpu;

#[cfg(feature = "cuda-jit")]
pub use burn_cuda as cuda_jit;
#[cfg(feature = "webgpu")]
pub use burn_wgpu::WebGpu;

#[cfg(feature = "cuda-jit")]
pub use burn_cuda::Cuda as CudaJit;
#[cfg(feature = "vulkan")]
pub use burn_wgpu::Vulkan;

#[cfg(feature = "cuda")]
pub use burn_cuda as cuda;

#[cfg(feature = "cuda")]
pub use burn_cuda::Cuda;

#[cfg(feature = "candle")]
pub use burn_candle as candle;

#[cfg(feature = "candle")]
pub use burn_candle::Candle;

#[cfg(feature = "hip-jit")]
pub use burn_hip as hip_jit;
#[cfg(feature = "hip")]
pub use burn_hip as hip;

#[cfg(feature = "hip-jit")]
pub use burn_hip::Hip as HipJit;
#[cfg(feature = "hip")]
pub use burn_hip::Hip;

#[cfg(feature = "tch")]
pub use burn_tch as libtorch;
Expand Down
2 changes: 2 additions & 0 deletions crates/burn-core/src/nn/loss/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ mod binary_cross_entropy;
mod cross_entropy;
mod huber;
mod mse;
mod poisson;
mod reduction;

pub use binary_cross_entropy::*;
pub use cross_entropy::*;
pub use huber::*;
pub use mse::*;
pub use poisson::*;
pub use reduction::*;
Loading

0 comments on commit 7ecc9c4

Please sign in to comment.