diff --git a/.github/workflows/mpi.yml b/.github/workflows/mpi.yml new file mode 100644 index 00000000..cdce35e1 --- /dev/null +++ b/.github/workflows/mpi.yml @@ -0,0 +1,46 @@ +name: MPI Tests + +on: [pull_request, push] + +jobs: + test: + name: Test MPI on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + rust: [stable] + + steps: + - uses: actions/checkout@v3 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + toolchain: ${{ matrix.rust }} + + - name: Cache cargo registry + uses: actions/cache@v3 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + + - name: Build + run: cargo build --verbose + + - name: Run tests + run: | + # Run single process tests + RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo +nightly run --release --bin=gkr -- -s keccak -f fr -t 16 + + # Run multi-process tests with mpirun + RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" mpiexec -n 2 cargo +nightly run --release --bin=gkr-mpi -- -s keccak -f fr + + - name: Run specific MPI tests + run: | + RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" mpiexec -n 2 cargo +nightly run --release --bin=gkr-mpi -- -s keccak -f gf2ext128 + RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" mpiexec -n 2 cargo +nightly run --release --bin=gkr-mpi -- -s keccak -f m31ext3 + RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" mpiexec -n 2 cargo +nightly run --release --bin=gkr-mpi -- -s keccak -f fr + RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" mpiexec -n 2 cargo +nightly run --release --bin=gkr-mpi -- -s poseidon -f m31ext3 \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index c837a681..01ab40ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1586,6 +1586,7 @@ dependencies = [ "arith", "mersenne31", "mpi", + "rayon", ] [[package]] diff --git a/config/mpi_config/Cargo.toml b/config/mpi_config/Cargo.toml index 41deab80..9d06fd55 100644 --- a/config/mpi_config/Cargo.toml +++ b/config/mpi_config/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] arith = { path = "../../arith" } mpi.workspace = true +rayon.workspace = true [dev-dependencies] mersenne31 = { path = "../../arith/mersenne31"} \ No newline at end of file diff --git a/config/mpi_config/build.rs b/config/mpi_config/build.rs new file mode 100644 index 00000000..c58933ae --- /dev/null +++ b/config/mpi_config/build.rs @@ -0,0 +1,110 @@ +use std::process::Command; +use std::env; + +fn main() { + // First check if mpicc is available + let mpicc_check = Command::new("which") + .arg("mpicc") + .output(); + + + if let Err(_) = mpicc_check { + println!("cargo:warning=mpicc not found, attempting to install..."); + + // Detect the operating system + let os = env::consts::OS; + + match os { + "linux" => { + // Try to detect the package manager + let apt_check = Command::new("which") + .arg("apt") + .output(); + + let dnf_check = Command::new("which") + .arg("dnf") + .output(); + + if apt_check.is_ok() { + // Debian/Ubuntu + eprintln!("cargo:warning=Using apt to install OpenMPI..."); + let status = Command::new("sudo") + .args(&["apt", "update"]) + .status() + .expect("Failed to run apt update"); + + if !status.success() { + panic!("Failed to update apt"); + } + + let status = Command::new("sudo") + .args(&["apt", "install", "-y", "openmpi-bin", "libopenmpi-dev"]) + .status() + .expect("Failed to install OpenMPI"); + + if !status.success() { + panic!("Failed to install OpenMPI"); + } + } else if dnf_check.is_ok() { + // Fedora/RHEL + eprintln!("cargo:warning=Using dnf to install OpenMPI..."); + let status = Command::new("sudo") + .args(&["dnf", "install", "-y", "openmpi", "openmpi-devel"]) + .status() + .expect("Failed to install OpenMPI"); + + if !status.success() { + panic!("Failed to install OpenMPI"); + } + } else { + panic!("Unsupported Linux distribution. Please install OpenMPI manually."); + } + }, + "macos" => { + // Check for Homebrew + let brew_check = Command::new("which") + .arg("brew") + .output(); + + if brew_check.is_ok() { + eprintln!("cargo:warning=Using Homebrew to install OpenMPI..."); + let status = Command::new("brew") + .args(&["install", "open-mpi"]) + .status() + .expect("Failed to install OpenMPI"); + + if !status.success() { + panic!("Failed to install OpenMPI"); + } + } else { + panic!("Homebrew not found. Please install Homebrew first or install OpenMPI manually."); + } + }, + _ => panic!("Unsupported operating system. Please install OpenMPI manually."), + } + } + + // After installation (or if already installed), set up compilation flags + eprintln!("cargo:rustc-link-search=/usr/lib"); + eprintln!("cargo:rustc-link-lib=mpi"); + + // Get MPI compilation flags + let output = Command::new("mpicc") + .arg("-show") + .output() + .expect("Failed to run mpicc"); + + let flags = String::from_utf8_lossy(&output.stdout); + + // Parse the flags and add them to the build + for flag in flags.split_whitespace() { + if flag.starts_with("-L") { + eprintln!("cargo:rustc-link-search=native={}", &flag[2..]); + } else if flag.starts_with("-l") { + eprintln!("cargo:rustc-link-lib={}", &flag[2..]); + } + } + + // Force rebuild if build.rs changes + eprintln!("cargo:rerun-if-changed=build.rs"); +} \ No newline at end of file diff --git a/config/mpi_config/src/lib.rs b/config/mpi_config/src/lib.rs index 2fae9994..55f7a4b6 100644 --- a/config/mpi_config/src/lib.rs +++ b/config/mpi_config/src/lib.rs @@ -288,4 +288,4 @@ impl MPIConfig { } } -unsafe impl Send for MPIConfig {} +unsafe impl Send for MPIConfig {} \ No newline at end of file diff --git a/config/mpi_config/tests/gather_vec.rs b/config/mpi_config/tests/gather_vec.rs index 979869d3..66f9497f 100644 --- a/config/mpi_config/tests/gather_vec.rs +++ b/config/mpi_config/tests/gather_vec.rs @@ -6,7 +6,7 @@ use mpi_config::MPIConfig; fn test_gather_vec() { const TEST_SIZE: usize = (1 << 10) + 1; - let mpi_config = MPIConfig::new(); + let mut mpi_config = MPIConfig::new(); let mut local_vec = vec![M31::ZERO; TEST_SIZE]; for i in 0..TEST_SIZE { local_vec[i] = M31::from((mpi_config.world_rank() * TEST_SIZE + i) as u32);