From 4878b863c8e69826641f86a0f559367c0aa511d6 Mon Sep 17 00:00:00 2001 From: Zack Angelo Date: Mon, 15 Jul 2024 13:35:02 -0500 Subject: [PATCH 1/2] metal: precompile kernels --- candle-metal-kernels/Cargo.toml | 4 ++ candle-metal-kernels/build.rs | 117 ++++++++++++++++++++++++++++++++ candle-metal-kernels/src/lib.rs | 65 ++++++++++++------ 3 files changed, 165 insertions(+), 21 deletions(-) create mode 100644 candle-metal-kernels/build.rs diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index e7a85f1f6..20cb9b601 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -23,3 +23,7 @@ half = { version = "2.3.1", features = [ "rand_distr", ] } rand = "0.8.5" + +[build-dependencies] +anyhow = "1.0.44" +convert_case = "0.6.0" diff --git a/candle-metal-kernels/build.rs b/candle-metal-kernels/build.rs new file mode 100644 index 000000000..c79957c89 --- /dev/null +++ b/candle-metal-kernels/build.rs @@ -0,0 +1,117 @@ +use anyhow::Result; +use std::io::Write; +use std::path::{Path, PathBuf}; +use std::{env, fs}; + +fn main() -> Result<()> { + let kernel_files: Vec = kernel_source_files()?; + let mut metallib_files: Vec = Vec::with_capacity(kernel_files.len()); + + for kernel_file in kernel_files { + let ir_path = compile_kernel(kernel_file)?; + let metallib_path = link_kernel(ir_path)?; + metallib_files.push(metallib_path); + } + + gen_metallibs_rs(metallib_files)?; + + Ok(()) +} + +fn kernel_source_files() -> Result> { + let manifest_dir = env::var("CARGO_MANIFEST_DIR")?; + let src_dir = Path::new(&manifest_dir).join("src"); + + let mut paths = Vec::new(); + for entry in fs::read_dir(src_dir)? { + let entry = entry.unwrap(); + let path = entry.path(); + if path.extension().map(|ext| ext.to_str().unwrap()) == Some("metal") { + paths.push(path); + } + } + + Ok(paths) +} + +fn compile_kernel(kernel_path: impl AsRef) -> Result { + let out_dir = std::env::var("OUT_DIR")?; + + let ir_file_name = format!( + "{}.ir", + kernel_path.as_ref().file_stem().unwrap().to_str().unwrap() + ); + + let output_file = Path::new(&out_dir).join(ir_file_name); + + let mut command = std::process::Command::new("xcrun"); + command.arg("metal"); + command.arg("-c"); + command.arg(format!("{}", kernel_path.as_ref().display())); + command.arg("-o"); + command.arg(format!("{}", output_file.display())); + + let status = command.status()?; + + if !status.success() { + return Err(anyhow::anyhow!( + "Failed to compile kernel file: {:?}", + kernel_path.as_ref() + )); + } + + Ok(output_file) +} + +fn link_kernel(ir_path: impl AsRef) -> Result { + let out_dir = std::env::var("OUT_DIR")?; + + let metallib_file_name = format!( + "{}.metallib", + ir_path.as_ref().file_stem().unwrap().to_str().unwrap() + ); + + let output_file = Path::new(&out_dir).join(metallib_file_name); + + let mut command = std::process::Command::new("xcrun"); + command.arg("metallib"); + command.arg(format!("{}", ir_path.as_ref().display())); + command.arg("-o"); + command.arg(format!("{}", output_file.display())); + + let status = command.status()?; + + if !status.success() { + return Err(anyhow::anyhow!( + "Failed to link kernel file: {:?}", + ir_path.as_ref() + )); + } + + Ok(output_file) +} + +fn gen_metallibs_rs(metallibs: Vec) -> Result<()> { + use convert_case::{Case, Casing}; + + // generate a rust source file that contains an include_bytes constant + // for every metallib file + let out_dir = std::env::var("OUT_DIR")?; + let out_file = Path::new(&out_dir).join("candle_metallibs.rs"); + + let mut file = fs::File::create(&out_file)?; + + // writeln!(file, "pub mod metallibs {{")?; + for metallib in metallibs { + let name = metallib.file_stem().unwrap().to_str().unwrap(); + writeln!( + file, + " pub const {}: &'static [u8] = include_bytes!(\"{}\");", + name.to_case(Case::ScreamingSnake), + metallib.display() + )?; + } + // writeln!(file, "}}")?; + + Ok(()) +} diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 1815dd321..bd550d9ba 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -23,6 +23,10 @@ const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); const QUANTIZED: &str = include_str!("quantized.metal"); const SORT: &str = include_str!("sort.metal"); +mod metallibs { + include!(concat!(env!("OUT_DIR"), "/candle_metallibs.rs")); +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Source { Affine, @@ -187,21 +191,29 @@ impl Kernels { } } - fn get_library_source(&self, source: Source) -> &'static str { - match source { - Source::Affine => AFFINE, - Source::Unary => UNARY, - Source::Binary => BINARY, - Source::Ternary => TERNARY, - Source::Indexing => INDEXING, - Source::Cast => CAST, - Source::Reduce => REDUCE, - Source::Conv => CONV, - Source::Random => RANDOM, - Source::Quantized => QUANTIZED, - Source::Sort => SORT, - Source::Mfa => panic!("Invalid lib"), - } + // fn get_library_source(&self, source: Source) -> &'static str { + // match source { + // Source::Affine => AFFINE, + // Source::Unary => UNARY, + // Source::Binary => BINARY, + // Source::Ternary => TERNARY, + // Source::Indexing => INDEXING, + // Source::Cast => CAST, + // Source::Reduce => REDUCE, + // Source::Conv => CONV, + // Source::Random => RANDOM, + // Source::Quantized => QUANTIZED, + // Source::Sort => SORT, + // Source::Mfa => panic!("Invalid lib"), + // } + // } + + fn load_metallib(device: &Device, data: &[u8]) -> Result { + device.new_library_with_data(data).map_err(|e| { + MetalKernelError::LoadLibraryError(format!( + "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}" + )) + }) } /// Load the give library from its [`source`]. @@ -224,12 +236,23 @@ impl Kernels { )) })? } - source => { - let source_content = self.get_library_source(source); - device - .new_library_with_source(source_content, &CompileOptions::new()) - .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? - } + Source::Affine => Self::load_metallib(device, metallibs::AFFINE)?, + Source::Indexing => Self::load_metallib(device, metallibs::INDEXING)?, + Source::Unary => Self::load_metallib(device, metallibs::UNARY)?, + Source::Binary => Self::load_metallib(device, metallibs::BINARY)?, + Source::Ternary => Self::load_metallib(device, metallibs::TERNARY)?, + Source::Cast => Self::load_metallib(device, metallibs::CAST)?, + Source::Reduce => Self::load_metallib(device, metallibs::REDUCE)?, + Source::Conv => Self::load_metallib(device, metallibs::CONV)?, + Source::Random => Self::load_metallib(device, metallibs::RANDOM)?, + Source::Quantized => Self::load_metallib(device, metallibs::QUANTIZED)?, + Source::Sort => Self::load_metallib(device, metallibs::SORT)?, + // source => { + // let source_content = self.get_library_source(source); + // device + // .new_library_with_source(source_content, &CompileOptions::new()) + // .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? + // } }; libraries.insert(source, lib.clone()); Ok(lib) From ac1cb58421737ba01de98c439217c3c305e28382 Mon Sep 17 00:00:00 2001 From: Zack Angelo Date: Mon, 15 Jul 2024 13:54:46 -0500 Subject: [PATCH 2/2] re-run build script if metal kernel changes --- candle-metal-kernels/build.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/candle-metal-kernels/build.rs b/candle-metal-kernels/build.rs index c79957c89..b40b26af0 100644 --- a/candle-metal-kernels/build.rs +++ b/candle-metal-kernels/build.rs @@ -37,10 +37,10 @@ fn kernel_source_files() -> Result> { fn compile_kernel(kernel_path: impl AsRef) -> Result { let out_dir = std::env::var("OUT_DIR")?; - let ir_file_name = format!( - "{}.ir", - kernel_path.as_ref().file_stem().unwrap().to_str().unwrap() - ); + let file_stem = kernel_path.as_ref().file_stem().unwrap().to_str().unwrap(); + let ir_file_name = format!("{}.ir", file_stem,); + + println!("cargo:rerun-if-changed=src/{}.metal", file_stem); let output_file = Path::new(&out_dir).join(ir_file_name); @@ -101,17 +101,15 @@ fn gen_metallibs_rs(metallibs: Vec) -> Result<()> { let mut file = fs::File::create(&out_file)?; - // writeln!(file, "pub mod metallibs {{")?; for metallib in metallibs { let name = metallib.file_stem().unwrap().to_str().unwrap(); writeln!( file, - " pub const {}: &'static [u8] = include_bytes!(\"{}\");", + "pub const {}: &'static [u8] = include_bytes!(\"{}\");", name.to_case(Case::ScreamingSnake), metallib.display() )?; } - // writeln!(file, "}}")?; Ok(()) }