Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Add new burn-vision crate with one initial op #2753

Open
This pull request wants to merge 25 commits into the base branch `main` from the feature branch.
Choose a base branch
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e7a84b9
Update cubecl
wingertge Jan 25, 2025
267ec63
Update to scope merge
wingertge Jan 26, 2025
71584e0
Fix bitwise shift
wingertge Jan 26, 2025
5838364
Initial JIT implementation
wingertge Jan 28, 2025
d292f85
Merge branch 'main' into feat/burn-vision
wingertge Jan 28, 2025
9e65150
Move testgen to burn-jit
wingertge Jan 28, 2025
0484f51
Improve HA4/8 algo
wingertge Jan 28, 2025
f62a9ee
Terminate units past the predefined 32 plane size
wingertge Jan 28, 2025
8edac2b
move jit backend back into `burn-vision` and make tests work
wingertge Jan 30, 2025
05b40e3
Add initial CPU implementation without stats
wingertge Jan 30, 2025
7708993
Implement stats
wingertge Jan 31, 2025
aeea3a8
Implement all backends except fusion
wingertge Jan 31, 2025
a994ca7
Fix autodiff to use GPU when available
wingertge Jan 31, 2025
866307b
Fixes and cleanup
wingertge Jan 31, 2025
a8e3994
Add docs
wingertge Jan 31, 2025
021360b
Update cubecl
wingertge Jan 31, 2025
a1d727f
Merge branch 'main' into feat/burn-vision
wingertge Jan 31, 2025
01ff01b
Compact labels for JIT
wingertge Feb 1, 2025
d790113
Improve JIT backend implementation by adding label compaction
wingertge Feb 2, 2025
15c431c
Use GPU reduction for max label
wingertge Feb 2, 2025
e3ec085
Manually fuse presence and prefix sum
wingertge Feb 2, 2025
11c8f1f
Make prefix sum more generic over line size
wingertge Feb 2, 2025
ee5ad73
Merge branch 'main' into feat/burn-vision
wingertge Feb 3, 2025
e6126c8
Add vision tests to xtask
wingertge Feb 3, 2025
1bbf50a
Fix CPU and other review stuff
wingertge Feb 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 36 additions & 16 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff94be8c9a79d1ac2e44829c2b4ec5a7e91b82e2" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff94be8c9a79d1ac2e44829c2b4ec5a7e91b82e2" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
Expand Down
3 changes: 3 additions & 0 deletions crates/burn-candle/src/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ use burn_tensor::Element;
use candle_core::{FloatDType, Tensor, WithDType};
use half::{bf16, f16};

/// Candle element
pub trait CandleElement: Element + WithDType {}
/// Candle float element
pub trait FloatCandleElement: CandleElement + FloatDType {}
/// Candle int element
pub trait IntCandleElement: CandleElement {}

impl CandleElement for f64 {}
Expand Down
1 change: 1 addition & 0 deletions crates/burn-candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ mod ops;
mod tensor;

pub use backend::*;
pub use element::*;
pub use tensor::*;

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-cuda/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda"
version.workspace = true

[features]
default = ["fusion", "autotune", "burn-jit/default", "cubecl/default"]
autotune = ["burn-jit/autotune"]
default = ["fusion", "autotune", "burn-jit/default", "cubecl/default"]
doc = ["burn-jit/doc"]
fusion = ["burn-fusion", "burn-jit/fusion"]
std = ["burn-jit/std", "cubecl/std"]
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub(crate) use flip::*;
pub(crate) use repeat_dim::*;
pub(crate) use select::*;
pub(crate) use select_assign::*;
pub(crate) use slice::*;
pub use slice::*;
pub(crate) use slice_assign::*;

pub(crate) use gather::*;
Expand Down
3 changes: 2 additions & 1 deletion crates/burn-jit/src/kernel/index/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use burn_tensor::Shape;
use cubecl::{calculate_cube_count_elemwise, prelude::*};
use std::ops::Range;

pub(crate) fn slice<R: JitRuntime, E: JitElement>(
/// Slice a jit tensor with a set of ranges
pub fn slice<R: JitRuntime, E: JitElement>(
tensor: JitTensor<R>,
indices: &[Range<usize>],
) -> JitTensor<R> {
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ pub mod reduce;

pub(crate) use clamp::*;
pub(crate) use comparison::*;
pub(crate) use index::*;
pub use index::*;
3 changes: 2 additions & 1 deletion crates/burn-jit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
extern crate derive_new;
extern crate alloc;

mod ops;
/// Utilities for implementing JIT kernels
pub mod ops;

/// Kernel module
pub mod kernel;
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-jit/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ pub(crate) fn swap_dims<R: JitRuntime>(
tensor
}

/// Permute a tensor's dimensions
pub fn permute<R: JitRuntime>(mut tensor: JitTensor<R>, axes: &[usize]) -> JitTensor<R> {
// remap strides
tensor.strides = axes.iter().map(|i| tensor.strides[*i]).collect();
Expand Down Expand Up @@ -135,7 +136,8 @@ pub(crate) fn expand<R: JitRuntime>(tensor: JitTensor<R>, target_shape: Shape) -
}
}

pub(crate) fn reshape<R: JitRuntime>(tensor: JitTensor<R>, shape: Shape) -> JitTensor<R> {
/// Reshape a jit tensor to a new shape
pub fn reshape<R: JitRuntime>(tensor: JitTensor<R>, shape: Shape) -> JitTensor<R> {
// TODO: Not force standard layout all the time (improve performance).
let tensor = kernel::into_contiguous(tensor);

Expand Down
5 changes: 3 additions & 2 deletions crates/burn-jit/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod qtensor;
mod transaction;

pub(crate) mod base;
pub(crate) use base::*;
pub use base::*;

pub(crate) mod numeric;
/// Numeric utility functions for jit backends
pub mod numeric;
Loading