Skip to content

Commit

Permalink
Wgpu fusion auto-vectorized operations (#1123)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Jan 8, 2024
1 parent ab67b6b commit de5f932
Show file tree
Hide file tree
Showing 22 changed files with 1,191 additions and 348 deletions.
6 changes: 5 additions & 1 deletion backend-comparison/benches/custom_gelu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ impl<B: Backend, const D: usize> Benchmark for CustomGeluBenchmark<B, D> {
fn sync(&self) {
B::sync(&self.device)
}

fn num_samples(&self) -> usize {
50
}
}

fn gelu_custom<B, const D: usize, Erf>(x: Tensor<B, D>, erf: Erf) -> Tensor<B, D>
Expand Down Expand Up @@ -87,7 +91,7 @@ fn erf_positive<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
fn bench<B: Backend>(device: &B::Device) {
const D: usize = 3;
let shape: Shape<D> = [32, 512, 2048].into();
let num_repeats = 10;
let num_repeats = 1;

let reference_gelu = CustomGeluBenchmark::<B, D>::new(
shape.clone(),
Expand Down
166 changes: 150 additions & 16 deletions burn-wgpu/src/codegen/function.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use super::Elem;
use super::Item;
use serde::{Deserialize, Serialize};
use std::fmt::Display;

/// Not all functions are native to WGSL, so this struct allows to support more functions.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub enum Function {
Powf(Elem),
Erf(Elem),
Powf(Item),
Erf(Item),
#[cfg(target_os = "macos")]
SafeTanh(Elem),
SafeTanh(Item),
}

impl Display for Function {
Expand All @@ -22,10 +22,12 @@ impl Display for Function {
}
}

fn format_powf(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result {
fn format_powf(f: &mut core::fmt::Formatter<'_>, item: &Item) -> core::fmt::Result {
let elem = item.elem();

f.write_fmt(format_args!(
"
fn powf(lhs: {elem}, rhs: {elem}) -> {elem} {{
fn powf_scalar(lhs: {elem}, rhs: {elem}) -> {elem} {{
let modulo = rhs % 2.0;
if (modulo == 0.0) {{
Expand All @@ -40,17 +42,61 @@ fn powf(lhs: {elem}, rhs: {elem}) -> {elem} {{
}}
}}
"
))
))?;

match item {
Item::Vec4(elem) => f.write_fmt(format_args!(
"
fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
return vec4(
powf_scalar(lhs[0], rhs),
powf_scalar(lhs[1], rhs),
powf_scalar(lhs[2], rhs),
powf_scalar(lhs[3], rhs),
);
}}
"
)),
Item::Vec3(elem) => f.write_fmt(format_args!(
"
fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
return vec3(
powf_scalar(lhs[0], rhs),
powf_scalar(lhs[1], rhs),
powf_scalar(lhs[2], rhs),
);
}}
"
)),
Item::Vec2(elem) => f.write_fmt(format_args!(
"
fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
return vec2(
powf_scalar(lhs[0], rhs),
powf_scalar(lhs[1], rhs),
);
}}
"
)),
Item::Scalar(elem) => f.write_fmt(format_args!(
"
fn powf(lhs: {elem}, rhs: {elem}) -> {elem} {{
return powf_scalar(lhs, rhs);
}}
"
)),
}
}

fn format_erf(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result {
fn format_erf(f: &mut core::fmt::Formatter<'_>, ty: &Item) -> core::fmt::Result {
let elem = ty.elem();
f.write_fmt(format_args!(
"
/// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations
///
/// > (maximum error: 1.5×10−7)
/// > All of these approximations are valid for x ≥ 0. To use these approximations for negative x, use the fact that erf x is an odd function, so erf x = −erf(−x).
fn erf_positive(x: {elem}) -> {elem} {{
fn erf_positive_scalar(x: {elem}) -> {elem} {{
let p = 0.3275911;
let a1 = 0.254829592;
let a2 = -0.284496736;
Expand All @@ -64,29 +110,117 @@ fn erf_positive(x: {elem}) -> {elem} {{
return 1.0 - (tmp * t * exp(-x * x));
}}
fn erf(x: {elem}) -> {elem} {{
fn erf_scalar(x: {elem}) -> {elem} {{
if (x < 0.0) {{
return -1.0 * erf_positive(-1.0 * x);
return -1.0 * erf_positive_scalar(-1.0 * x);
}}
return erf_positive(x);
return erf_positive_scalar(x);
}}
"
))
))?;

match ty {
Item::Vec4(_) => f.write_fmt(format_args!(
"
fn erf(x: {ty}) -> {ty} {{
return vec4(
erf_scalar(x[0]),
erf_scalar(x[1]),
erf_scalar(x[2]),
erf_scalar(x[3]),
);
}}
"
)),
Item::Vec3(_) => f.write_fmt(format_args!(
"
fn erf(x: {ty}) -> {ty} {{
return vec3(
erf_scalar(x[0]),
erf_scalar(x[1]),
erf_scalar(x[2]),
);
}}
"
)),
Item::Vec2(_) => f.write_fmt(format_args!(
"
fn erf(x: {ty}) -> {ty} {{
return vec2(
erf_scalar(x[0]),
erf_scalar(x[1]),
);
}}
"
)),
Item::Scalar(_) => f.write_fmt(format_args!(
"
fn erf(x: {ty}) -> {ty} {{
return erf_scalar(x);
}}
"
)),
}
}

#[cfg(target_os = "macos")]
fn format_safe_tanh(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result {
fn format_safe_tanh(f: &mut core::fmt::Formatter<'_>, item: &Item) -> core::fmt::Result {
let elem = item.elem();

f.write_fmt(format_args!(
"
/// Metal has a weird numerical behaviour with tanh for inputs over 43.0
fn safe_tanh(x: {elem}) -> {elem} {{
fn safe_tanh_scalar(x: {elem}) -> {elem} {{
if x > 43.0 {{
return 1.0;
}} else {{
return tanh(x);
}}
}}
"
))
))?;

match item {
Item::Vec4(_) => f.write_fmt(format_args!(
"
fn safe_tanh(x: {item}) -> {item} {{
return vec4(
safe_tanh_scalar(x[0]),
safe_tanh_scalar(x[1]),
safe_tanh_scalar(x[2]),
safe_tanh_scalar(x[3]),
);
}}
"
)),
Item::Vec3(_) => f.write_fmt(format_args!(
"
fn safe_tanh(x: {item}) -> {item} {{
return vec3(
safe_tanh_scalar(x[0]),
safe_tanh_scalar(x[1]),
safe_tanh_scalar(x[2]),
);
}}
"
)),
Item::Vec2(_) => f.write_fmt(format_args!(
"
fn safe_tanh(x: {item}) -> {item} {{
return vec2(
safe_tanh_scalar(x[0]),
safe_tanh_scalar(x[1]),
);
}}
"
)),
Item::Scalar(_) => f.write_fmt(format_args!(
"
fn safe_tanh(x: {item}) -> {item} {{
return safe_tanh_scalar(x);
}}
"
)),
}
}
Loading

0 comments on commit de5f932

Please sign in to comment.