Skip to content

Commit

Permalink
Use the crate rocm_ features and change naming scheme of them
Browse files Browse the repository at this point in the history
In order to make it easier to parse the rocm version from the feature name
the new name scheme is 'rocm__x_x_x'
  • Loading branch information
syl20bnr committed Dec 9, 2024
1 parent d66ea7f commit c63b693
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 15 deletions.
6 changes: 3 additions & 3 deletions crates/cubecl-hip-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ version.workspace = true
rust-version = "1.81"

[features]
default = ["rocm_624"]
rocm_622 = []
rocm_624 = []
default = ["rocm__6_2_4"]
rocm__6_2_2 = []
rocm__6_2_4 = []

[dependencies]
libc = { workspace = true }
Expand Down
60 changes: 52 additions & 8 deletions crates/cubecl-hip-sys/build.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::env;
use std::{env, io};
use std::path::Path;

const ROCM_FEATURE_PREFIX: &str = "CARGO_FEATURE_ROCM__";

/// Reads a header inside the rocm folder, that contains the lib's version
fn get_system_hip_version(rocm_path: impl AsRef<Path>) -> std::io::Result<(u8, u8, u32)> {
let version_path = rocm_path.as_ref().join("include/hip/hip_version.h");
Expand Down Expand Up @@ -45,16 +47,56 @@ fn hip_header_patch_number_to_release_patch_number(number: u32) -> Option<u32> {
}
}

/// Return the ROCm version corresponding to the enabled feature
fn get_rocm_feature_version() -> io::Result<(u8, u8, u32)> {
for (key, value) in env::vars() {
if key.starts_with(ROCM_FEATURE_PREFIX) && value == "1" {
if let Some(version) = key.strip_prefix(ROCM_FEATURE_PREFIX) {
// Parse the version using `_` as the delimiter
let parts: Vec<&str> = version.split('_').collect();
if parts.len() == 3 {
if let (Ok(major), Ok(minor), Ok(patch)) = (
parts[0].parse::<u8>(),
parts[1].parse::<u8>(),
parts[2].parse::<u32>(),
) {
return Ok((major, minor, patch));
}
}
}
}
}

Err(io::Error::new(
io::ErrorKind::NotFound,
"No valid ROCm feature version found. One 'rocm_<version>' feature must be set.",
))
}

/// Make sure that feature is set correctly
fn ensure_single_rocm_feature_set() {
let mut enabled_features = Vec::new();

for (key, value) in env::vars() {
if key.starts_with(ROCM_FEATURE_PREFIX) && value == "1" {
enabled_features.push(format!("rocm__{}", key.strip_prefix(ROCM_FEATURE_PREFIX).unwrap()));
}
}

match enabled_features.len() {
1 => {},
0 => panic!("No ROCm version features are enabled. One ROCm version feature must be set."),
_ => panic!(
"Multiple ROCm version features are enabled: {:?}. Only one can be set.",
enabled_features
),
}
}

/// Checks if the version inside `rocm_path` matches crate version
fn check_version(rocm_path: impl AsRef<Path>) -> std::io::Result<bool> {
let (system_major, system_minor, system_patch) = get_system_hip_version(rocm_path)?;

// Can be fairly sure that crate's versioning won't fail
let crate_major = env!("CARGO_PKG_VERSION_MAJOR").parse::<u8>().unwrap();
let crate_minor = env!("CARGO_PKG_VERSION_MINOR").parse::<u8>().unwrap();
// Need at least u32 here, because of the crates reserved versions for revisions,
// and _unlikely_ possibility that ROCm's patch version will be higher than 16
let crate_patch = env!("CARGO_PKG_VERSION_PATCH").parse::<u32>().unwrap() / 1000;
let (crate_major, crate_minor, crate_patch) = get_rocm_feature_version()?;

if crate_major == system_major {
let mismatches = match (crate_minor == system_minor, crate_patch == system_patch) {
Expand All @@ -74,6 +116,8 @@ fn check_version(rocm_path: impl AsRef<Path>) -> std::io::Result<bool> {
}

fn main() {
ensure_single_rocm_feature_set();

println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-env-changed=CUBECL_ROCM_PATH");
println!("cargo:rerun-if-env-changed=ROCM_PATH");
Expand Down
8 changes: 4 additions & 4 deletions crates/cubecl-hip-sys/src/bindings/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#[cfg(feature = "rocm_622")]
#[cfg(feature = "rocm__6_2_2")]
mod bindings_622;
#[cfg(feature = "rocm_622")]
#[cfg(feature = "rocm__6_2_2")]
pub use bindings_622::*;
#[cfg(feature = "rocm_624")]
#[cfg(feature = "rocm__6_2_4")]
mod bindings_624;
#[cfg(feature = "rocm_624")]
#[cfg(feature = "rocm__6_2_4")]
pub use bindings_624::*;

0 comments on commit c63b693

Please sign in to comment.