From c63b693e4c43fff6f7ac252f73a34e617cf6ada4 Mon Sep 17 00:00:00 2001 From: syl20bnr Date: Mon, 9 Dec 2024 12:55:58 -0500 Subject: [PATCH] Use the crate rocm_ features and change naming scheme of them In order to make it easier to parse the rocm version from the feature name the new name scheme is 'rocm__x_x_x' --- crates/cubecl-hip-sys/Cargo.toml | 6 +-- crates/cubecl-hip-sys/build.rs | 60 ++++++++++++++++++++--- crates/cubecl-hip-sys/src/bindings/mod.rs | 8 +-- 3 files changed, 59 insertions(+), 15 deletions(-) diff --git a/crates/cubecl-hip-sys/Cargo.toml b/crates/cubecl-hip-sys/Cargo.toml index def93be..4573bb5 100644 --- a/crates/cubecl-hip-sys/Cargo.toml +++ b/crates/cubecl-hip-sys/Cargo.toml @@ -12,9 +12,9 @@ version.workspace = true rust-version = "1.81" [features] -default = ["rocm_624"] -rocm_622 = [] -rocm_624 = [] +default = ["rocm__6_2_4"] +rocm__6_2_2 = [] +rocm__6_2_4 = [] [dependencies] libc = { workspace = true } diff --git a/crates/cubecl-hip-sys/build.rs b/crates/cubecl-hip-sys/build.rs index b9bb248..88012ff 100644 --- a/crates/cubecl-hip-sys/build.rs +++ b/crates/cubecl-hip-sys/build.rs @@ -1,6 +1,8 @@ -use std::env; +use std::{env, io}; use std::path::Path; +const ROCM_FEATURE_PREFIX: &str = "CARGO_FEATURE_ROCM__"; + /// Reads a header inside the rocm folder, that contains the lib's version fn get_system_hip_version(rocm_path: impl AsRef) -> std::io::Result<(u8, u8, u32)> { let version_path = rocm_path.as_ref().join("include/hip/hip_version.h"); @@ -45,16 +47,56 @@ fn hip_header_patch_number_to_release_patch_number(number: u32) -> Option { } } +/// Return the ROCm version corresponding to the enabled feature +fn get_rocm_feature_version() -> io::Result<(u8, u8, u32)> { + for (key, value) in env::vars() { + if key.starts_with(ROCM_FEATURE_PREFIX) && value == "1" { + if let Some(version) = key.strip_prefix(ROCM_FEATURE_PREFIX) { + // Parse the version using `_` as the delimiter + let parts: Vec<&str> = version.split('_').collect(); + if parts.len() == 3 { + if let (Ok(major), Ok(minor), Ok(patch)) = ( + parts[0].parse::(), + parts[1].parse::(), + parts[2].parse::(), + ) { + return Ok((major, minor, patch)); + } + } + } + } + } + + Err(io::Error::new( + io::ErrorKind::NotFound, + "No valid ROCm feature version found. One 'rocm_' feature must be set.", + )) +} + +/// Make sure that feature is set correctly +fn ensure_single_rocm_feature_set() { + let mut enabled_features = Vec::new(); + + for (key, value) in env::vars() { + if key.starts_with(ROCM_FEATURE_PREFIX) && value == "1" { + enabled_features.push(format!("rocm__{}", key.strip_prefix(ROCM_FEATURE_PREFIX).unwrap())); + } + } + + match enabled_features.len() { + 1 => {}, + 0 => panic!("No ROCm version features are enabled. One ROCm version feature must be set."), + _ => panic!( + "Multiple ROCm version features are enabled: {:?}. Only one can be set.", + enabled_features + ), + } +} + /// Checks if the version inside `rocm_path` matches crate version fn check_version(rocm_path: impl AsRef) -> std::io::Result { let (system_major, system_minor, system_patch) = get_system_hip_version(rocm_path)?; - - // Can be fairly sure that crate's versioning won't fail - let crate_major = env!("CARGO_PKG_VERSION_MAJOR").parse::().unwrap(); - let crate_minor = env!("CARGO_PKG_VERSION_MINOR").parse::().unwrap(); - // Need at least u32 here, because of the crates reserved versions for revisions, - // and _unlikely_ possibility that ROCm's patch version will be higher than 16 - let crate_patch = env!("CARGO_PKG_VERSION_PATCH").parse::().unwrap() / 1000; + let (crate_major, crate_minor, crate_patch) = get_rocm_feature_version()?; if crate_major == system_major { let mismatches = match (crate_minor == system_minor, crate_patch == system_patch) { @@ -74,6 +116,8 @@ fn check_version(rocm_path: impl AsRef) -> std::io::Result { } fn main() { + ensure_single_rocm_feature_set(); + println!("cargo:rerun-if-changed=build.rs"); println!("cargo:rerun-if-env-changed=CUBECL_ROCM_PATH"); println!("cargo:rerun-if-env-changed=ROCM_PATH"); diff --git a/crates/cubecl-hip-sys/src/bindings/mod.rs b/crates/cubecl-hip-sys/src/bindings/mod.rs index 0770c79..0c85688 100644 --- a/crates/cubecl-hip-sys/src/bindings/mod.rs +++ b/crates/cubecl-hip-sys/src/bindings/mod.rs @@ -1,8 +1,8 @@ -#[cfg(feature = "rocm_622")] +#[cfg(feature = "rocm__6_2_2")] mod bindings_622; -#[cfg(feature = "rocm_622")] +#[cfg(feature = "rocm__6_2_2")] pub use bindings_622::*; -#[cfg(feature = "rocm_624")] +#[cfg(feature = "rocm__6_2_4")] mod bindings_624; -#[cfg(feature = "rocm_624")] +#[cfg(feature = "rocm__6_2_4")] pub use bindings_624::*;