Migrate to type magic autotune (#2710)
wingertge authored Jan 20, 2025
1 parent b4d9d54 commit 949e77f
Showing 13 changed files with 249 additions and 452 deletions.
168 changes: 88 additions & 80 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
@@ -153,14 +153,14 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" }
# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2cc42af02671d90255ab823e29a4a3ad2e564333" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2cc42af02671d90255ab823e29a4a3ad2e564333" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
### For the release. ###
cubecl = { version = "0.4.0", default-features = false }
cubecl-common = { version = "0.4.0", default-features = false }
# cubecl = { version = "0.4.0", default-features = false }
# cubecl-common = { version = "0.4.0", default-features = false }

### For xtask crate ###
tracel-xtask = { version = "=1.1.8" }
31 changes: 1 addition & 30 deletions crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs
@@ -5,15 +5,14 @@ use cubecl::{
tile::{accelerated::Accelerated, TileMatmulFamily},
InvalidConfigError,
},
kernels::{matmul::AdvancedConfig, MatmulAvailabilityError},
kernels::matmul::AdvancedConfig,
},
prelude::*,
};

use super::{
base::{ConvolutionConfigFactory, ConvolutionFamily, ConvolutionProblem},
homogeneous::base::ImplicitGemmConvolutionFamily,
precision::ConvPrecision,
selection::ConvSelection,
};

@@ -47,34 +46,6 @@ pub trait Algorithm {
Self::GlobalConvolution::check_config(&config)?;
Ok(config)
}

/// Check availability of the matmul algorithm
fn check_availability<R: Runtime, CS: ConvPrecision>(
client: &ComputeClient<R::Server, R::Channel>,
config: &<Self::GlobalConvolution as ConvolutionConfigFactory>::Config,
) -> Result<(), MatmulAvailabilityError> {
Self::GlobalConvolution::check_availability::<R, CS>(client, config)
}

/// Determine whether the given convolution problem is valid to launch (within hardware limits)
fn can_launch<R: Runtime, CS: ConvPrecision>(
client: &ComputeClient<R::Server, R::Channel>,
problem: &ConvolutionProblem,
config: &<Self::GlobalConvolution as ConvolutionConfigFactory>::Config,
selection: &Self::Selection,
) -> bool {
if problem.options.groups > 1 || Self::check_availability::<R, CS>(client, config).is_err()
{
return false;
}

let cube_count = Self::cube_count(selection, problem);
let (max_x, max_y, max_z) = R::max_cube_count();
match cube_count {
CubeCount::Static(x, y, z) => x <= max_x && y <= max_y && z <= max_z,
_ => true,
}
}
}

/// Cmma convolution
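The `Algorithm` trait drops its `check_availability` and `can_launch` hooks in this file. For orientation, the cube-count test that `can_launch` performed can be recast in the Result-returning style the rest of this commit adopts. The sketch below is illustrative only and uses simplified local stand-ins for cubecl's `CubeCount` and the launch error type; it is not code from this commit.

// Illustrative sketch: the removed boolean cube-count check, expressed as a
// Result. `CubeCount` and `LaunchCheckError` are simplified stand-ins.
#[allow(dead_code)]
enum CubeCount {
    Static(u32, u32, u32),
    Dynamic,
}

#[derive(Debug)]
enum LaunchCheckError {
    CubeCountTooLarge { x: u32, y: u32, z: u32 },
}

fn validate_cube_count(
    cube_count: &CubeCount,
    (max_x, max_y, max_z): (u32, u32, u32),
) -> Result<(), LaunchCheckError> {
    match *cube_count {
        // A static dispatch must fit within the device's per-axis limits.
        CubeCount::Static(x, y, z) if x > max_x || y > max_y || z > max_z => {
            Err(LaunchCheckError::CubeCountTooLarge { x, y, z })
        }
        // Dynamic cube counts are resolved at launch time; nothing to check here.
        _ => Ok(()),
    }
}

fn main() {
    // 70_000 exceeds a hypothetical per-axis limit of 65_535, so this fails.
    let check = validate_cube_count(&CubeCount::Static(70_000, 1, 1), (65_535, 65_535, 65_535));
    assert!(check.is_err());
}
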
8 changes: 1 addition & 7 deletions crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs
@@ -6,7 +6,7 @@ use cubecl::linalg::{
stage::{StageMatmul, StageMatmulFamily},
InvalidConfigError, MatmulProblem, MatrixLayout,
},
kernels::{matmul::AdvancedConfig, MatmulAvailabilityError},
kernels::matmul::AdvancedConfig,
},
tensor::{ReadWrite, VirtualTensor},
};
@@ -91,12 +91,6 @@ pub trait ConvolutionConfigFactory: Send + Sync + 'static {
/// Asserts that the configuration for this matmul will lead to a valid computation
fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError>;

/// Checks if the client can handle the features used in this computation
fn check_availability<R: Runtime, CS: ConvPrecision>(
client: &ComputeClient<R::Server, R::Channel>,
config: &Self::Config,
) -> Result<(), MatmulAvailabilityError>;

fn make_config(
input: Self::Input,
problem: &ConvolutionProblem,
(file name and change summary not rendered for this section)
@@ -16,7 +16,7 @@ use cubecl::{
},
Ident, InvalidConfigError, MatrixLayout, StageDim,
},
kernels::{matmul::AdvancedConfig, MatmulAvailabilityError},
kernels::matmul::AdvancedConfig,
},
tensor::{ReadWrite, VirtualTensor},
},
@@ -194,13 +194,6 @@ where
SMM::check_config(&config.to_smm_config())
}

fn check_availability<R: Runtime, CS: ConvPrecision>(
client: &ComputeClient<R::Server, R::Channel>,
config: &Self::Config,
) -> Result<(), MatmulAvailabilityError> {
SMM::check_availability::<R, (CS::EG, CS::ES, CS::EA)>(client, &config.to_smm_config())
}

fn make_config(
input: Self::Input,
problem: &ConvolutionProblem,
60 changes: 6 additions & 54 deletions crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs
@@ -7,7 +7,7 @@ use burn_tensor::{
use cubecl::{
flex32,
ir::{Elem, FloatKind},
linalg::matmul::{self, components::MatrixLayout},
linalg::matmul::{self},
tensor_line_size, tf32, Feature,
};
use half::{bf16, f16};
@@ -23,7 +23,7 @@ use crate::{
algorithm::{Algorithm, ImplicitCmmaConv},
base::{ConvolutionLaunch, ConvolutionProblem},
},
nchw_to_nhwc, Conv2dAutotuneKey, ConvLaunchError,
nchw_to_nhwc, ConvLaunchError,
},
into_contiguous,
},
@@ -108,6 +108,10 @@ pub fn conv2d_gemm_with_algo<
where
SP::EG: JitElement,
{
if options.groups != 1 {
return Err(ConvLaunchError::Groups(options.groups));
}

let [batch_size, in_channels, height, width] = input.shape.dims();
let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims();

@@ -226,58 +230,6 @@
Ok(permute(out, &[0, 3, 1, 2]))
}

pub fn problem_from_key<R: JitRuntime, F: FloatElement>(
key: &Conv2dAutotuneKey,
out_h: usize,
out_w: usize,
) -> ConvolutionProblem {
let in_stride_2 = key.in_channels;
let in_stride_1 = key.width * in_stride_2;
let in_stride_0 = key.height * in_stride_1;

let m = key.batch_size * out_h * out_w;
let n = key.out_channels;
let k = key.kernel_size[0] * key.kernel_size[1] * key.in_channels;

let options = ConvOptions {
stride: key.stride,
padding: key.padding,
dilation: key.dilation,
groups: key.groups,
};

// Target 128 bit accesses
let available_vectorizations = R::supported_line_sizes()
.iter()
.copied()
.filter(|it| *it as usize * size_of::<F>() <= 16)
.collect::<Vec<_>>();
let lhs_line_size = tensor_line_size(
&available_vectorizations,
&[key.batch_size, key.height, key.width, key.in_channels],
&[in_stride_0, in_stride_1, in_stride_2, 1],
3,
);
let rhs_line_size = tensor_line_size(&available_vectorizations, &[k, n], &[n, 1], 1);
let out_line_size = tensor_line_size(&available_vectorizations, &[m, n], &[n, 1], 1);

ConvolutionProblem {
m,
n,
k,
lhs_layout: MatrixLayout::RowMajor,
rhs_layout: MatrixLayout::RowMajor,
lhs_line_size,
rhs_line_size,
out_line_size,
kernel_size: (key.kernel_size[0] as u32, key.kernel_size[1] as u32),
options,
out_shape_y: out_h,
out_shape_x: out_w,
has_bias: key.has_bias,
}
}

pub(crate) fn has_tf32<R: JitRuntime>(c: &JitTensor<R>) -> bool {
c.client
.properties()
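With `problem_from_key` and the boolean pre-checks removed, `conv2d_gemm_with_algo` now rejects grouped convolutions itself and reports the failure through its return value, which callers can forward with `?`. Below is a minimal, self-contained sketch of that control flow; every name is a stand-in, since the real functions carry tensor, runtime, and precision generics.

// Illustrative sketch: an early typed-error guard plus `?` propagation,
// mirroring the added `options.groups != 1` check. Stand-in names only.
#[derive(Debug)]
enum ConvLaunchErr {
    Groups(usize),
}

fn conv2d_gemm(groups: usize) -> Result<&'static str, ConvLaunchErr> {
    // Grouped convolutions are not supported by this launch path.
    if groups != 1 {
        return Err(ConvLaunchErr::Groups(groups));
    }
    Ok("launched implicit GEMM")
}

fn conv2d(groups: usize) -> Result<&'static str, ConvLaunchErr> {
    // `?` forwards the typed error to the caller (e.g. an autotune harness).
    let out = conv2d_gemm(groups)?;
    Ok(out)
}

fn main() {
    assert_eq!(conv2d(1).unwrap(), "launched implicit GEMM");
    assert!(matches!(conv2d(2), Err(ConvLaunchErr::Groups(2))));
}
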
61 changes: 39 additions & 22 deletions crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs
@@ -6,6 +6,7 @@ use cmma::{Matrix, MatrixIdent, MatrixLayout};
use cubecl::{
cube,
ir::{Elem, FloatKind},
linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError},
prelude::*,
Compiler, CubeCount, CubeDim, Feature,
};
@@ -66,7 +67,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement>(

let padded_batch_size = padded_batch_size(batch_size, out_h, out_w);

if !can_do_implicit_gemm::<R, F>(
check_availability::<R, F>(
batch_size,
in_channels,
out_channels,
@@ -75,15 +76,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement>(
out_h,
out_w,
&input.client,
) {
panic!(
"Requirements for implicit GEMM not met:
- CMMA must be available
- `groups` must be 1
- subcube size must be non-variable (might not hold on Intel)
"
);
}
)?;

// If input is contiguous NCHW, use custom transpose kernel
let input = match input.is_contiguous() {
@@ -643,7 +636,7 @@ fn load_weight_tile<F: Float, FMat: Float>(
}

#[allow(clippy::too_many_arguments)]
pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
pub(crate) fn check_availability<R: JitRuntime, E: FloatElement>(
batch_size: usize,
in_channels: usize,
out_channels: usize,
@@ -652,7 +645,7 @@ pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
out_h: usize,
out_w: usize,
client: &ComputeClient<R::Server, R::Channel>,
) -> bool {
) -> Result<(), ConvLaunchError> {
let cmma_k = match (
E::as_elem_native_unchecked(),
client
@@ -672,19 +665,43 @@ pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
let gemm_n = out_channels;
let gemm_k = in_channels * kernel_h * kernel_w;

let size = find_cmma_size::<R, E>(client, gemm_m as u32, gemm_k as u32, gemm_n as u32);

if let Some((cmma_m, cmma_k, cmma_n)) = size {
let warps_per_cube = 8;
let (cmma_m, cmma_n, cmma_k) =
find_cmma_size::<R, E>(client, gemm_m as u32, gemm_k as u32, gemm_n as u32).ok_or_else(
|| {
ConvLaunchError::Matmul(MatmulLaunchError::Unavailable(
MatmulAvailabilityError::CmmaInstructionUnavailable {
input: E::as_elem_native_unchecked(),
output: E::as_elem_native_unchecked(),
m: 16,
n: 16,
k: cmma_k as u32,
},
))
},
)?;

let warps_per_cube = 8;

let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::<f16>();
if <R::Compiler as Compiler>::max_shared_memory_size() < smem_size {
return Err(ConvLaunchError::Matmul(MatmulLaunchError::InvalidConfig(
Box::new("Not enough shared memory"),
)));
}

let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::<f16>();
let topology = client.properties().hardware_properties();
let not_intel = topology.plane_size_min >= 32;
let topology = client.properties().hardware_properties();
if topology.plane_size_min < 32 {
return Err(ConvLaunchError::Matmul(MatmulLaunchError::Unavailable(
MatmulAvailabilityError::PlaneDimUnsupported {
plane_dim: topology.plane_size_min,
},
)));
}

<R::Compiler as Compiler>::max_shared_memory_size() >= smem_size && groups == 1 && not_intel
} else {
false
if groups != 1 {
return Err(ConvLaunchError::Groups(groups));
}
Ok(())
}

fn padded_k(
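`can_do_implicit_gemm` is renamed to `check_availability` and returns `Result<(), ConvLaunchError>` instead of `bool`, replacing the former panic, so a dispatcher can see why the kernel is unavailable. The sketch below shows how such a dispatcher might branch on the typed error; `ConvError` is a local stand-in for the two variants visible in this diff (`Groups` and a wrapped matmul availability error), not the repository's actual type.

// Illustrative sketch: selecting a kernel from the typed availability result
// instead of a bare `false`. Stand-in types and kernel names only.
#[derive(Debug)]
enum ConvError {
    Groups(usize),
    MatmulUnavailable(&'static str),
}

fn choose_kernel(availability: Result<(), ConvError>) -> &'static str {
    match availability {
        Ok(()) => "implicit GEMM (CMMA)",
        // Grouped convolutions need a different kernel.
        Err(ConvError::Groups(_)) => "fallback kernel (grouped convolution)",
        // CMMA size, shared memory, or plane-size constraints were not met.
        Err(ConvError::MatmulUnavailable(_reason)) => "fallback kernel (hardware limits)",
    }
}

fn main() {
    assert_eq!(choose_kernel(Ok(())), "implicit GEMM (CMMA)");
    assert_eq!(
        choose_kernel(Err(ConvError::MatmulUnavailable("plane size below 32"))),
        "fallback kernel (hardware limits)"
    );
    assert_eq!(
        choose_kernel(Err(ConvError::Groups(2))),
        "fallback kernel (grouped convolution)"
    );
}
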