Skip to content

Commit

Permalink
implement benchmark for reduce kernel (#2692)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxtremblay authored Jan 13, 2025
1 parent 3e90b6e commit 3990a8a
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 10 deletions.
4 changes: 4 additions & 0 deletions backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ path = "benches/resnet.rs"
harness = false
name = "autodiff"

[[bench]]
harness = false
name = "reduce"

[[bin]]
name = "burnbench"
path = "src/bin/burnbench.rs"
1 change: 1 addition & 0 deletions backend-comparison/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ Available Benchmarks:
- conv-transpose3d
- conv2d
- conv3d
- reduce
```

#### Run benchmarks
Expand Down
102 changes: 102 additions & 0 deletions backend-comparison/benches/reduce.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};

/// A single reduction operation to benchmark, together with the axis to
/// reduce over where the operation is axis-wise.
enum Instruction {
    /// Index of the minimum value along the given axis (`Tensor::argmin`).
    ArgMin(usize),
    /// Sum along the given axis (`Tensor::sum_dim`).
    SumDim(usize),
    /// Sum over every element of the tensor (`Tensor::sum`).
    Sum,
}

/// Benchmark harness running one reduce [`Instruction`] against a
/// pre-generated random tensor on a fixed backend device.
struct ReduceBenchmark<B: Backend> {
    // Which reduction `execute` performs.
    instruction: Instruction,
    // Input dimensions, kept so `shapes()` can report them.
    shape: Shape,
    // Device the tensor lives on; used by `sync` to wait for completion.
    device: B::Device,
    // Random input created once in `new` and cloned on every execution,
    // since the reduce ops consume their receiver.
    tensor: Tensor<B, 3>,
}

impl<B: Backend> ReduceBenchmark<B> {
    /// Builds a benchmark for `instruction` on `device`, allocating a random
    /// `[4096, 512, 64]` input tensor up front so that tensor creation is
    /// not measured by `execute`.
    pub fn new(instruction: Instruction, device: B::Device) -> Self {
        let dims = [4096, 512, 64];
        let input = Tensor::random(Shape::new(dims), Distribution::Default, &device);

        Self {
            instruction,
            shape: Shape::new(dims),
            device,
            tensor: input,
        }
    }
}

impl<B: Backend> Benchmark for ReduceBenchmark<B> {
    type Args = ();

    /// No per-run setup: the input tensor is prepared once in `new`.
    fn prepare(&self) -> Self::Args {}

    /// Runs the configured reduction once. The cached input is cloned each
    /// time because the reduce ops consume their receiver.
    fn execute(&self, _: Self::Args) {
        let input = self.tensor.clone();

        match self.instruction {
            Instruction::ArgMin(axis) => {
                let _ = input.argmin(axis);
            }
            Instruction::SumDim(axis) => {
                let _ = input.sum_dim(axis);
            }
            Instruction::Sum => {
                let _ = input.sum();
            }
        }
    }

    /// Human-readable benchmark name, including the reduced axis when the
    /// operation is axis-wise.
    fn name(&self) -> String {
        match self.instruction {
            Instruction::ArgMin(axis) => format!("reduce-argmin-{axis}"),
            Instruction::SumDim(axis) => format!("reduce-sum-{axis}"),
            Instruction::Sum => "reduce-sum-full".to_string(),
        }
    }

    /// Blocks until all queued work on the benchmark device has completed.
    fn sync(&self) {
        B::sync(&self.device)
    }

    /// Reports the single input shape used by this benchmark.
    fn shapes(&self) -> Vec<Vec<usize>> {
        vec![self.shape.dims.clone()]
    }
}

#[allow(dead_code)]
fn bench<B: Backend>(
device: &B::Device,
feature_name: &str,
url: Option<&str>,
token: Option<&str>,
) {
let mut benchmarks = Vec::new();

for axis in 0..3 {
benchmarks.push(ReduceBenchmark::<B>::new(
Instruction::ArgMin(axis),
device.clone(),
));

benchmarks.push(ReduceBenchmark::<B>::new(
Instruction::SumDim(axis),
device.clone(),
));
}

benchmarks.push(ReduceBenchmark::<B>::new(Instruction::Sum, device.clone()));

save::<B>(
benchmarks.into_iter().map(run_benchmark).collect(),
device,
feature_name,
url,
token,
)
.unwrap();
}

fn main() {
    // Expands to per-backend `bench::<...>` calls gated by feature flags
    // (wgpu, tch, ndarray, candle, cuda-jit, hip-jit, ...).
    backend_comparison::bench_on_backend!();
}
2 changes: 2 additions & 0 deletions backend-comparison/src/burnbenchapp/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ enum BenchmarkValues {
Conv2d,
#[strum(to_string = "conv3d")]
Conv3d,
#[strum(to_string = "reduce")]
Reduce,
}

pub fn execute() {
Expand Down
23 changes: 13 additions & 10 deletions backend-comparison/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ fn update_panic_hook() {
#[macro_export]
macro_rules! bench_on_backend {
() => {
$crate::bench_on_backend!(bench)
};
($fn_name:ident) => {
use std::env;
backend_comparison::init_log().unwrap();

Expand Down Expand Up @@ -99,14 +102,14 @@ macro_rules! bench_on_backend {
{
use burn::backend::wgpu::{Wgpu, WgpuDevice};

bench::<Wgpu<f32, i32>>(&WgpuDevice::default(), feature_name, url, token);
$fn_name::<Wgpu<f32, i32>>(&WgpuDevice::default(), feature_name, url, token);
}

#[cfg(any(feature = "wgpu-spirv"))]
{
use burn::backend::wgpu::{Wgpu, WgpuDevice};

bench::<Wgpu<half::f16, i32>>(&WgpuDevice::default(), feature_name, url, token);
$fn_name::<Wgpu<half::f16, i32>>(&WgpuDevice::default(), feature_name, url, token);
}

#[cfg(feature = "tch-gpu")]
Expand All @@ -117,15 +120,15 @@ macro_rules! bench_on_backend {
let device = LibTorchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;
bench::<LibTorch<half::f16>>(&device, feature_name, url, token);
$fn_name::<LibTorch<half::f16>>(&device, feature_name, url, token);
}

#[cfg(feature = "tch-cpu")]
{
use burn::backend::{libtorch::LibTorchDevice, LibTorch};

let device = LibTorchDevice::Cpu;
bench::<LibTorch>(&device, feature_name, url, token);
$fn_name::<LibTorch>(&device, feature_name, url, token);
}

#[cfg(any(
Expand All @@ -139,7 +142,7 @@ macro_rules! bench_on_backend {
use burn::backend::NdArray;

let device = NdArrayDevice::Cpu;
bench::<NdArray>(&device, feature_name, url, token);
$fn_name::<NdArray>(&device, feature_name, url, token);
}

#[cfg(feature = "candle-cpu")]
Expand All @@ -148,7 +151,7 @@ macro_rules! bench_on_backend {
use burn::backend::Candle;

let device = CandleDevice::Cpu;
bench::<Candle>(&device, feature_name, url, token);
$fn_name::<Candle>(&device, feature_name, url, token);
}

#[cfg(feature = "candle-cuda")]
Expand All @@ -157,7 +160,7 @@ macro_rules! bench_on_backend {
use burn::backend::Candle;

let device = CandleDevice::cuda(0);
bench::<Candle>(&device, feature_name, url, token);
$fn_name::<Candle>(&device, feature_name, url, token);
}

#[cfg(feature = "candle-metal")]
Expand All @@ -166,21 +169,21 @@ macro_rules! bench_on_backend {
use burn::backend::Candle;

let device = CandleDevice::metal(0);
bench::<Candle>(&device, feature_name, url, token);
$fn_name::<Candle>(&device, feature_name, url, token);
}

#[cfg(feature = "cuda-jit")]
{
use burn::backend::cuda_jit::{Cuda, CudaDevice};

bench::<Cuda<half::f16>>(&CudaDevice::default(), feature_name, url, token);
$fn_name::<Cuda<half::f16>>(&CudaDevice::default(), feature_name, url, token);
}

#[cfg(feature = "hip-jit")]
{
use burn::backend::hip_jit::{Hip, HipDevice};

bench::<Hip<half::f16>>(&HipDevice::default(), feature_name, url, token);
$fn_name::<Hip<half::f16>>(&HipDevice::default(), feature_name, url, token);
}
};
}
Expand Down

0 comments on commit 3990a8a

Please sign in to comment.