diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f2a55f6..adb61e0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,13 +15,18 @@ jobs: checks: runs-on: amd-rx7600 steps: - - name: Setup Rust - uses: tracel-ai/github-actions/setup-rust@v1.8 + # -------------------------------------------------------------------------------- + # We don't use our github-action because it seems that the cache does not work well + # with our AMD runner. + # cargo-audit is not found for example whereas it is correctly installed. + - name: checkout + uses: actions/checkout@v4 + - name: install rust + uses: dtolnay/rust-toolchain@master with: - rust-toolchain: stable - cache-key: stable-linux - linux-pre-cleanup: false - # -------------------------------------------------------------------------------- + components: rustfmt, clippy + toolchain: stable + # -------------------------------------------------------------------------------- - name: Audit run: cargo xtask check audit # -------------------------------------------------------------------------------- @@ -39,12 +44,14 @@ jobs: tests: runs-on: amd-rx7600 steps: - - name: Setup Rust - uses: tracel-ai/github-actions/setup-rust@v1.8 + # -------------------------------------------------------------------------------- + - name: checkout + uses: actions/checkout@v4 + - name: install rust + uses: dtolnay/rust-toolchain@master with: - rust-toolchain: stable - cache-key: stable-linux - linux-pre-cleanup: false + components: rustfmt, clippy + toolchain: stable # -------------------------------------------------------------------------------- - name: Lint run: cargo xtask check lint diff --git a/Cargo.lock b/Cargo.lock index 65cfdf3..a1c3703 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -172,7 +172,7 @@ checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "cubecl-hip-sys" -version = "6.3.0" +version = "6.3.1000" dependencies = [ "libc", "rstest", diff --git a/README.md b/README.md index c655cd6..8140691 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,12 @@ If a fix is required and the default ROCm version remains `6.2.4`, the `cubecl-h Add the crate [cubecl-hip-sys][2] to the `Cargo.toml` of your project and enable the feature corresponding to the version of ROCm you have installed. -If you no feature corresponds to your ROCm installation then read the next section to learn + +```toml +cubecl-hip-sys = { version = "6.3.1000", features = ["rocm__6_3_1"] } +``` + +If no feature corresponds to your ROCm installation then read the next section to learn how to generate and submit new bindings for your version. Next you need to point out where you installed ROCm so that `rustc` can link to your ROCM libraries. To do so set the variable `ROCM_PATH`, or `HIP_PATH` or the more specific `CUBECL_ROCM_PATH` to its diff --git a/crates/cubecl-hip-sys/Cargo.toml b/crates/cubecl-hip-sys/Cargo.toml index 81829ec..8cf963b 100644 --- a/crates/cubecl-hip-sys/Cargo.toml +++ b/crates/cubecl-hip-sys/Cargo.toml @@ -12,7 +12,7 @@ version.workspace = true rust-version = "1.81" [features] -default = ["rocm__6_3_1"] +default = [] # ROCm versions rocm__6_2_2 = [ "hip_41134" ] diff --git a/crates/cubecl-hip-sys/build.rs b/crates/cubecl-hip-sys/build.rs index 241614c..ca87ee6 100644 --- a/crates/cubecl-hip-sys/build.rs +++ b/crates/cubecl-hip-sys/build.rs @@ -5,14 +5,53 @@ const ROCM_HIP_FEATURE_PREFIX: &str = "CARGO_FEATURE_HIP_"; include!("src/build_script.rs"); +/// Return true if at least one rocm_x_x_x feature is set +fn is_rocm_feature_set() -> bool { + let mut enabled_features = Vec::new(); + + for (key, value) in env::vars() { + if key.starts_with(ROCM_FEATURE_PREFIX) && value == "1" { + enabled_features.push(format!( + "rocm__{}", + key.strip_prefix(ROCM_FEATURE_PREFIX).unwrap() + )); + } + } + + !enabled_features.is_empty() +} + +/// Make sure that at least one and only one rocm version feature is set +fn ensure_single_rocm_version_feature_set() { + let mut enabled_features = Vec::new(); + + for (key, value) in env::vars() { + if key.starts_with(ROCM_FEATURE_PREFIX) && value == "1" { + enabled_features.push(format!( + "rocm__{}", + key.strip_prefix(ROCM_FEATURE_PREFIX).unwrap() + )); + } + } + + match enabled_features.len() { + 1 => {} + 0 => panic!("No ROCm version feature is enabled. One ROCm version feature must be set."), + _ => panic!( + "Multiple ROCm version features are enabled: {:?}. Only one can be set.", + enabled_features + ), + } +} + /// Make sure that at least one and only one hip feature is set -fn ensure_single_rocm_hip_feature_set() { +fn ensure_single_hip_feature_set() { let mut enabled_features = Vec::new(); for (key, value) in env::vars() { if key.starts_with(ROCM_HIP_FEATURE_PREFIX) && value == "1" { enabled_features.push(format!( - "rocm__{}", + "hip_{}", key.strip_prefix(ROCM_HIP_FEATURE_PREFIX).unwrap() )); } @@ -31,6 +70,10 @@ fn ensure_single_rocm_hip_feature_set() { /// Checks if the version inside `rocm_path` matches crate version fn check_rocm_version(rocm_path: impl AsRef) -> std::io::Result { let rocm_system_version = get_rocm_system_version(rocm_path)?; + if !is_rocm_feature_set() { + // If there is no feature set but we found a system version we continue + return Ok(true); + } let rocm_feature_version = get_rocm_feature_version(); if rocm_system_version.major == rocm_feature_version.major { @@ -50,6 +93,34 @@ fn check_rocm_version(rocm_path: impl AsRef) -> std::io::Result { } } +/// If no rocm_x_x_x feature is set then we set the feature corresponding +/// to the passed ROCm path. +fn set_default_rocm_version(rocm_path: impl AsRef) -> std::io::Result<()> { + if is_rocm_feature_set() { + // a feature has been prodived to set the ROCm version + return Ok(()); + } + println!("cargo::warning=No `rocm__x_x_x` feature set. Using the version of a default installation of ROCm if found on the system. Consider setting a `rocm__x_x_x` feature in the Cargo.toml file of your crate."); + + // Set default feature with the version found on the system + let rocm_system_version = get_rocm_system_version(&rocm_path)?; + let hip_system_patch = get_hip_system_version(&rocm_path)?; + println!("cargo::warning=Found default version of ROCm on system: {rocm_system_version}. Associated HIP patch version is: {}", hip_system_patch.patch); + let default_rocm_feature = format!("rocm__{}", rocm_system_version).replace(".", "_"); + let default_hip_feature = format!("hip_{}", hip_system_patch.patch); + println!("cargo:rustc-cfg=feature=\"{}\"", default_rocm_feature); + println!("cargo:rustc-cfg=feature=\"{}\"", default_hip_feature); + env::set_var( + format!("{ROCM_FEATURE_PREFIX}{}", rocm_system_version).replace(".", "_"), + "1", + ); + env::set_var( + format!("{ROCM_HIP_FEATURE_PREFIX}{}", hip_system_patch.patch), + "1", + ); + Ok(()) +} + /// Return the ROCm version corresponding to the enabled rocm__ feature fn get_rocm_feature_version() -> Version { for (key, value) in env::vars() { @@ -115,7 +186,9 @@ fn main() { let rocm_path = rocm_path_candidates.find(|path| check_rocm_version(path).unwrap_or_default()); if let Some(valid_rocm_path) = rocm_path { - ensure_single_rocm_hip_feature_set(); + set_default_rocm_version(valid_rocm_path).unwrap(); + ensure_single_rocm_version_feature_set(); + ensure_single_hip_feature_set(); // verify HIP compatibility let Version { patch: hip_system_patch_version,