From 711f767e9efdc399d361331ed0496f889ef7d0c6 Mon Sep 17 00:00:00 2001 From: zhenfei Date: Thu, 26 Dec 2024 10:46:54 -0500 Subject: [PATCH 1/3] wip --- Cargo.lock | 1 + config/mpi_config/Cargo.toml | 1 + config/mpi_config/src/lib.rs | 357 ++++++++++++++++---------- config/mpi_config/tests/gather_vec.rs | 2 +- 4 files changed, 231 insertions(+), 130 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c837a681..01ab40ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1586,6 +1586,7 @@ dependencies = [ "arith", "mersenne31", "mpi", + "rayon", ] [[package]] diff --git a/config/mpi_config/Cargo.toml b/config/mpi_config/Cargo.toml index 41deab80..9d06fd55 100644 --- a/config/mpi_config/Cargo.toml +++ b/config/mpi_config/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] arith = { path = "../../arith" } mpi.workspace = true +rayon.workspace = true [dev-dependencies] mersenne31 = { path = "../../arith/mersenne31"} \ No newline at end of file diff --git a/config/mpi_config/src/lib.rs b/config/mpi_config/src/lib.rs index 2fae9994..8e95503d 100644 --- a/config/mpi_config/src/lib.rs +++ b/config/mpi_config/src/lib.rs @@ -1,12 +1,12 @@ use std::{cmp, fmt::Debug}; use arith::Field; -use mpi::{ - environment::Universe, - ffi, - topology::{Process, SimpleCommunicator}, - traits::*, -}; +// use mpi::{ +// environment::Universe, +// ffi, +// topology::{Process, SimpleCommunicator}, +// traits::*, +// }; #[macro_export] macro_rules! root_println { @@ -17,113 +17,152 @@ macro_rules! root_println { }; } -static mut UNIVERSE: Option = None; -static mut WORLD: Option = None; +// static mut UNIVERSE: Option = None; +// static mut WORLD: Option = None; -#[derive(Clone)] +#[derive(Clone, Debug, PartialEq)] pub struct MPIConfig { - pub universe: Option<&'static mpi::environment::Universe>, - pub world: Option<&'static SimpleCommunicator>, + /// The shared memory between all the processes + pub universe: Vec, + /// The local memory of the current process + pub worlds: Vec>, + // pub universe: Option<&'static mpi::environment::Universe>, + // pub world: Option<&'static SimpleCommunicator>, + /// The number of worlds pub world_size: i32, + /// The current world rank pub world_rank: i32, } impl Default for MPIConfig { fn default() -> Self { Self { - universe: None, - world: None, + universe: Vec::new(), + worlds: Vec::new(), world_size: 1, world_rank: 0, } } } -impl Debug for MPIConfig { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let universe_fmt = if self.universe.is_none() { - Option::::None - } else { - Some(self.universe.unwrap().buffer_size()) - }; +// impl Debug for MPIConfig { +// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +// let universe_fmt = if self.universe.is_none() { +// Option::::None +// } else { +// Some(self.universe.unwrap().buffer_size()) +// }; - let world_fmt = if self.world.is_none() { - Option::::None - } else { - Some(0usize) - }; +// let world_fmt = if self.world.is_none() { +// Option::::None +// } else { +// Some(0usize) +// }; - f.debug_struct("MPIConfig") - .field("universe", &universe_fmt) - .field("world", &world_fmt) - .field("world_size", &self.world_size) - .field("world_rank", &self.world_rank) - .finish() - } -} +// f.debug_struct("MPIConfig") +// .field("universe", &universe_fmt) +// .field("world", &world_fmt) +// .field("world_size", &self.world_size) +// .field("world_rank", &self.world_rank) +// .finish() +// } +// } -// Note: may not be correct -impl PartialEq for MPIConfig { - fn eq(&self, other: &Self) -> bool { 
- self.world_rank == other.world_rank && self.world_size == other.world_size - } -} +// // Note: may not be correct +// impl PartialEq for MPIConfig { +// fn eq(&self, other: &Self) -> bool { +// self.world_rank == other.world_rank && self.world_size == other.world_size +// } +// } /// MPI toolkit: impl MPIConfig { const ROOT_RANK: i32 = 0; - /// The communication limit for MPI is 2^30. Save 10 bits for #parties here. - const CHUNK_SIZE: usize = 1usize << 20; + // /// The communication limit for MPI is 2^30. Save 10 bits for #parties here. + // const CHUNK_SIZE: usize = 1usize << 20; // OK if already initialized, mpi::initialize() will return None #[allow(static_mut_refs)] pub fn init() { - unsafe { - let universe = mpi::initialize(); - if universe.is_some() { - UNIVERSE = universe; - WORLD = Some(UNIVERSE.as_ref().unwrap().world()); - } - } + // do nothing + + + + + // unsafe { + // let universe = mpi::initialize(); + // if universe.is_some() { + // UNIVERSE = universe; + // WORLD = Some(UNIVERSE.as_ref().unwrap().world()); + // } + // } } #[inline] pub fn finalize() { - unsafe { ffi::MPI_Finalize() }; + // do nothing + // unsafe { ffi::MPI_Finalize() }; } #[allow(static_mut_refs)] pub fn new() -> Self { Self::init(); - let universe = unsafe { UNIVERSE.as_ref() }; - let world = unsafe { WORLD.as_ref() }; - let world_size = if let Some(world) = world { - world.size() - } else { - 1 - }; - let world_rank = if let Some(world) = world { - world.rank() - } else { - 0 + // let universe = unsafe { UNIVERSE.as_ref() }; + // let world = unsafe { WORLD.as_ref() }; + // let world_size = if let Some(world) = world { + // world.size() + // } else { + // 1 + // }; + // let world_rank = if let Some(world) = world { + // world.rank() + // } else { + // 0 + // }; + // Self { + // universe, + // world, + // world_size, + // world_rank, + // } + + let num_worlds = rayon::current_num_threads() as i32; + let world_rank = match rayon::current_thread_index() { + Some(rank) => rank as i32, + None => 0, }; + + let universe = vec![]; + let worlds = vec![vec![]; num_worlds as usize]; + Self { universe, - world, - world_size, + worlds, + world_size: num_worlds, world_rank, } } #[inline] pub fn new_for_verifier(world_size: i32) -> Self { + + let universe = vec![]; + let worlds = vec![vec![]; world_size as usize]; + Self { - universe: None, - world: None, - world_size, + universe, + worlds, + world_size: world_size, world_rank: 0, } + + // Self { + // universe: None, + // world: None, + // world_size, + // world_rank: 0, + // } } /// Return an u8 vector sharing THE SAME MEMORY SLOT with the input. 
@@ -143,84 +182,141 @@ impl MPIConfig { } #[allow(clippy::collapsible_else_if)] - pub fn gather_vec(&self, local_vec: &Vec, global_vec: &mut Vec) { - unsafe { - if self.world_size == 1 { - *global_vec = local_vec.clone() - } else { - assert!(!self.is_root() || global_vec.len() == local_vec.len() * self.world_size()); - + pub fn gather_vec(&mut self, local_vec: &Vec, global_vec: &mut Vec) { + if self.world_size == 1 { + *global_vec = local_vec.clone(); + return; + } + + // For non-root processes, we just need to store our local vector in our world's memory + if !self.is_root() { + assert!(global_vec.is_empty(), "Non-root processes should have empty global vectors"); + + // Convert local vector to bytes and store in our world's slot + unsafe { let local_vec_u8 = Self::vec_to_u8_bytes(local_vec); - let local_n_bytes = local_vec_u8.len(); - let n_chunks = (local_n_bytes + Self::CHUNK_SIZE - 1) / Self::CHUNK_SIZE; - if n_chunks == 1 { - if self.world_rank == Self::ROOT_RANK { - let mut global_vec_u8 = Self::vec_to_u8_bytes(global_vec); - self.root_process() - .gather_into_root(&local_vec_u8, &mut global_vec_u8); - global_vec_u8.leak(); // discard control of the memory - } else { - self.root_process().gather_into(&local_vec_u8); - } - } else { - if self.world_rank == Self::ROOT_RANK { - let mut chunk_buffer_u8 = vec![0u8; Self::CHUNK_SIZE * self.world_size()]; - let mut global_vec_u8 = Self::vec_to_u8_bytes(global_vec); - for i in 0..n_chunks { - let local_start = i * Self::CHUNK_SIZE; - let local_end = cmp::min(local_start + Self::CHUNK_SIZE, local_n_bytes); - self.root_process().gather_into_root( - &local_vec_u8[local_start..local_end], - &mut chunk_buffer_u8, - ); - - // distribute the data to where they belong to in global vec - let actual_chunk_size = local_end - local_start; - for j in 0..self.world_size() { - let global_start = j * local_n_bytes + local_start; - let global_end = global_start + actual_chunk_size; - global_vec_u8[global_start..global_end].copy_from_slice( - &chunk_buffer_u8[j * Self::CHUNK_SIZE - ..j * Self::CHUNK_SIZE + actual_chunk_size], - ); - } - } - global_vec_u8.leak(); // discard control of the memory - } else { - for i in 0..n_chunks { - let local_start = i * Self::CHUNK_SIZE; - let local_end = cmp::min(local_start + Self::CHUNK_SIZE, local_n_bytes); - self.root_process() - .gather_into(&local_vec_u8[local_start..local_end]); - } - } - } - local_vec_u8.leak(); // discard control of the memory + self.worlds[self.world_rank as usize] = local_vec_u8; } + return; } + + // For root process, we need to gather all vectors + assert!( + global_vec.len() == local_vec.len() * self.world_size(), + "Root's global_vec size must match total data size" + ); + + // First, store root's local vector in the beginning of global_vec + let root_data_size = local_vec.len(); + global_vec[..root_data_size].copy_from_slice(local_vec); + + // Then gather data from other processes + for rank in 1..self.world_size { + let rank = rank as usize; + let start_idx = rank * root_data_size; + let end_idx = start_idx + root_data_size; + + // Get the bytes from the corresponding world's memory + let world_bytes = &self.worlds[rank]; + + // Safety: We're reconstructing the vector with the same layout + unsafe { + let other_vec = Vec::::from_raw_parts( + world_bytes.as_ptr() as *mut F, + root_data_size, + root_data_size, + ); + + // Copy the data to the appropriate position in global_vec + global_vec[start_idx..end_idx].copy_from_slice(&other_vec); + + // Don't drop the vector since we don't own the 
memory + std::mem::forget(other_vec); + } + } + + // unsafe { + // if self.world_size == 1 { + // *global_vec = local_vec.clone() + // } else { + // assert!(!self.is_root() || global_vec.len() == local_vec.len() * self.world_size()); + + // let local_vec_u8 = Self::vec_to_u8_bytes(local_vec); + // let local_n_bytes = local_vec_u8.len(); + // let n_chunks = (local_n_bytes + Self::CHUNK_SIZE - 1) / Self::CHUNK_SIZE; + // if n_chunks == 1 { + // if self.world_rank == Self::ROOT_RANK { + // let mut global_vec_u8 = Self::vec_to_u8_bytes(global_vec); + // self.root_process() + // .gather_into_root(&local_vec_u8, &mut global_vec_u8); + // global_vec_u8.leak(); // discard control of the memory + // } else { + // self.root_process().gather_into(&local_vec_u8); + // } + // } else { + // if self.world_rank == Self::ROOT_RANK { + // let mut chunk_buffer_u8 = vec![0u8; Self::CHUNK_SIZE * self.world_size()]; + // let mut global_vec_u8 = Self::vec_to_u8_bytes(global_vec); + // for i in 0..n_chunks { + // let local_start = i * Self::CHUNK_SIZE; + // let local_end = cmp::min(local_start + Self::CHUNK_SIZE, local_n_bytes); + // self.root_process().gather_into_root( + // &local_vec_u8[local_start..local_end], + // &mut chunk_buffer_u8, + // ); + + // // distribute the data to where they belong to in global vec + // let actual_chunk_size = local_end - local_start; + // for j in 0..self.world_size() { + // let global_start = j * local_n_bytes + local_start; + // let global_end = global_start + actual_chunk_size; + // global_vec_u8[global_start..global_end].copy_from_slice( + // &chunk_buffer_u8[j * Self::CHUNK_SIZE + // ..j * Self::CHUNK_SIZE + actual_chunk_size], + // ); + // } + // } + // global_vec_u8.leak(); // discard control of the memory + // } else { + // for i in 0..n_chunks { + // let local_start = i * Self::CHUNK_SIZE; + // let local_end = cmp::min(local_start + Self::CHUNK_SIZE, local_n_bytes); + // self.root_process() + // .gather_into(&local_vec_u8[local_start..local_end]); + // } + // } + // } + // local_vec_u8.leak(); // discard control of the memory + // } + // } } - /// Root process broadcase a value f into all the processes + /// Root process broadcast a value f into all the processes #[inline] pub fn root_broadcast_f(&self, f: &mut F) { - unsafe { + // unsafe { if self.world_size == 1 { } else { - let mut vec_u8 = Self::elem_to_u8_bytes(f, F::SIZE); - self.root_process().broadcast_into(&mut vec_u8); - vec_u8.leak(); + + + // let mut vec_u8 = Self::elem_to_u8_bytes(f, F::SIZE); + // self.root_process().broadcast_into(&mut vec_u8); + // vec_u8.leak(); } - } + // } } #[inline] + /// copy the root process's memory to the buffer pub fn root_broadcast_bytes(&self, bytes: &mut Vec) { - self.root_process().broadcast_into(bytes); + // self.root_process().broadcast_into(bytes); + bytes.clear(); + bytes.copy_from_slice(&self.root_process()); } /// sum up all local values #[inline] - pub fn sum_vec(&self, local_vec: &Vec) -> Vec { + pub fn sum_vec(&mut self, local_vec: &Vec) -> Vec { if self.world_size == 1 { local_vec.clone() } else if self.world_rank == Self::ROOT_RANK { @@ -241,7 +337,7 @@ impl MPIConfig { /// coef has a length of mpi_world_size #[inline] - pub fn coef_combine_vec(&self, local_vec: &Vec, coef: &[F]) -> Vec { + pub fn coef_combine_vec(&mut self, local_vec: &Vec, coef: &[F]) -> Vec { if self.world_size == 1 { // Warning: literally, it should be coef[0] * local_vec // but coef[0] is always one in our use case of self.world_size = 1 @@ -278,13 +374,16 @@ impl MPIConfig { } 
#[inline(always)] - pub fn root_process(&self) -> Process { - self.world.unwrap().process_at_rank(Self::ROOT_RANK) + pub fn root_process(&self) -> &Vec { + &self.worlds[0] + + // self.world.unwrap().process_at_rank(Self::ROOT_RANK) } #[inline(always)] pub fn barrier(&self) { - self.world.unwrap().barrier(); + // do nothing + // self.world.unwrap().barrier(); } } diff --git a/config/mpi_config/tests/gather_vec.rs b/config/mpi_config/tests/gather_vec.rs index 979869d3..66f9497f 100644 --- a/config/mpi_config/tests/gather_vec.rs +++ b/config/mpi_config/tests/gather_vec.rs @@ -6,7 +6,7 @@ use mpi_config::MPIConfig; fn test_gather_vec() { const TEST_SIZE: usize = (1 << 10) + 1; - let mpi_config = MPIConfig::new(); + let mut mpi_config = MPIConfig::new(); let mut local_vec = vec![M31::ZERO; TEST_SIZE]; for i in 0..TEST_SIZE { local_vec[i] = M31::from((mpi_config.world_rank() * TEST_SIZE + i) as u32); From 96e44a2d4d5ee40671562e35094761475f28a113 Mon Sep 17 00:00:00 2001 From: zhenfei Date: Thu, 26 Dec 2024 10:59:34 -0500 Subject: [PATCH 2/3] 1 --- config/mpi_config/build.rs | 110 +++++++++++ config/mpi_config/src/lib.rs | 359 +++++++++++++---------------------- 2 files changed, 240 insertions(+), 229 deletions(-) create mode 100644 config/mpi_config/build.rs diff --git a/config/mpi_config/build.rs b/config/mpi_config/build.rs new file mode 100644 index 00000000..c58933ae --- /dev/null +++ b/config/mpi_config/build.rs @@ -0,0 +1,110 @@ +use std::process::Command; +use std::env; + +fn main() { + // First check if mpicc is available + let mpicc_check = Command::new("which") + .arg("mpicc") + .output(); + + + if let Err(_) = mpicc_check { + println!("cargo:warning=mpicc not found, attempting to install..."); + + // Detect the operating system + let os = env::consts::OS; + + match os { + "linux" => { + // Try to detect the package manager + let apt_check = Command::new("which") + .arg("apt") + .output(); + + let dnf_check = Command::new("which") + .arg("dnf") + .output(); + + if apt_check.is_ok() { + // Debian/Ubuntu + eprintln!("cargo:warning=Using apt to install OpenMPI..."); + let status = Command::new("sudo") + .args(&["apt", "update"]) + .status() + .expect("Failed to run apt update"); + + if !status.success() { + panic!("Failed to update apt"); + } + + let status = Command::new("sudo") + .args(&["apt", "install", "-y", "openmpi-bin", "libopenmpi-dev"]) + .status() + .expect("Failed to install OpenMPI"); + + if !status.success() { + panic!("Failed to install OpenMPI"); + } + } else if dnf_check.is_ok() { + // Fedora/RHEL + eprintln!("cargo:warning=Using dnf to install OpenMPI..."); + let status = Command::new("sudo") + .args(&["dnf", "install", "-y", "openmpi", "openmpi-devel"]) + .status() + .expect("Failed to install OpenMPI"); + + if !status.success() { + panic!("Failed to install OpenMPI"); + } + } else { + panic!("Unsupported Linux distribution. Please install OpenMPI manually."); + } + }, + "macos" => { + // Check for Homebrew + let brew_check = Command::new("which") + .arg("brew") + .output(); + + if brew_check.is_ok() { + eprintln!("cargo:warning=Using Homebrew to install OpenMPI..."); + let status = Command::new("brew") + .args(&["install", "open-mpi"]) + .status() + .expect("Failed to install OpenMPI"); + + if !status.success() { + panic!("Failed to install OpenMPI"); + } + } else { + panic!("Homebrew not found. Please install Homebrew first or install OpenMPI manually."); + } + }, + _ => panic!("Unsupported operating system. 
Please install OpenMPI manually."), + } + } + + // After installation (or if already installed), set up compilation flags + eprintln!("cargo:rustc-link-search=/usr/lib"); + eprintln!("cargo:rustc-link-lib=mpi"); + + // Get MPI compilation flags + let output = Command::new("mpicc") + .arg("-show") + .output() + .expect("Failed to run mpicc"); + + let flags = String::from_utf8_lossy(&output.stdout); + + // Parse the flags and add them to the build + for flag in flags.split_whitespace() { + if flag.starts_with("-L") { + eprintln!("cargo:rustc-link-search=native={}", &flag[2..]); + } else if flag.starts_with("-l") { + eprintln!("cargo:rustc-link-lib={}", &flag[2..]); + } + } + + // Force rebuild if build.rs changes + eprintln!("cargo:rerun-if-changed=build.rs"); +} \ No newline at end of file diff --git a/config/mpi_config/src/lib.rs b/config/mpi_config/src/lib.rs index 8e95503d..55f7a4b6 100644 --- a/config/mpi_config/src/lib.rs +++ b/config/mpi_config/src/lib.rs @@ -1,12 +1,12 @@ use std::{cmp, fmt::Debug}; use arith::Field; -// use mpi::{ -// environment::Universe, -// ffi, -// topology::{Process, SimpleCommunicator}, -// traits::*, -// }; +use mpi::{ + environment::Universe, + ffi, + topology::{Process, SimpleCommunicator}, + traits::*, +}; #[macro_export] macro_rules! root_println { @@ -17,152 +17,113 @@ macro_rules! root_println { }; } -// static mut UNIVERSE: Option = None; -// static mut WORLD: Option = None; +static mut UNIVERSE: Option = None; +static mut WORLD: Option = None; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone)] pub struct MPIConfig { - /// The shared memory between all the processes - pub universe: Vec, - /// The local memory of the current process - pub worlds: Vec>, - // pub universe: Option<&'static mpi::environment::Universe>, - // pub world: Option<&'static SimpleCommunicator>, - /// The number of worlds + pub universe: Option<&'static mpi::environment::Universe>, + pub world: Option<&'static SimpleCommunicator>, pub world_size: i32, - /// The current world rank pub world_rank: i32, } impl Default for MPIConfig { fn default() -> Self { Self { - universe: Vec::new(), - worlds: Vec::new(), + universe: None, + world: None, world_size: 1, world_rank: 0, } } } -// impl Debug for MPIConfig { -// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { -// let universe_fmt = if self.universe.is_none() { -// Option::::None -// } else { -// Some(self.universe.unwrap().buffer_size()) -// }; +impl Debug for MPIConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let universe_fmt = if self.universe.is_none() { + Option::::None + } else { + Some(self.universe.unwrap().buffer_size()) + }; -// let world_fmt = if self.world.is_none() { -// Option::::None -// } else { -// Some(0usize) -// }; + let world_fmt = if self.world.is_none() { + Option::::None + } else { + Some(0usize) + }; -// f.debug_struct("MPIConfig") -// .field("universe", &universe_fmt) -// .field("world", &world_fmt) -// .field("world_size", &self.world_size) -// .field("world_rank", &self.world_rank) -// .finish() -// } -// } + f.debug_struct("MPIConfig") + .field("universe", &universe_fmt) + .field("world", &world_fmt) + .field("world_size", &self.world_size) + .field("world_rank", &self.world_rank) + .finish() + } +} -// // Note: may not be correct -// impl PartialEq for MPIConfig { -// fn eq(&self, other: &Self) -> bool { -// self.world_rank == other.world_rank && self.world_size == other.world_size -// } -// } +// Note: may not be correct +impl PartialEq for 
MPIConfig { + fn eq(&self, other: &Self) -> bool { + self.world_rank == other.world_rank && self.world_size == other.world_size + } +} /// MPI toolkit: impl MPIConfig { const ROOT_RANK: i32 = 0; - // /// The communication limit for MPI is 2^30. Save 10 bits for #parties here. - // const CHUNK_SIZE: usize = 1usize << 20; + /// The communication limit for MPI is 2^30. Save 10 bits for #parties here. + const CHUNK_SIZE: usize = 1usize << 20; // OK if already initialized, mpi::initialize() will return None #[allow(static_mut_refs)] pub fn init() { - // do nothing - - - - - // unsafe { - // let universe = mpi::initialize(); - // if universe.is_some() { - // UNIVERSE = universe; - // WORLD = Some(UNIVERSE.as_ref().unwrap().world()); - // } - // } + unsafe { + let universe = mpi::initialize(); + if universe.is_some() { + UNIVERSE = universe; + WORLD = Some(UNIVERSE.as_ref().unwrap().world()); + } + } } #[inline] pub fn finalize() { - // do nothing - // unsafe { ffi::MPI_Finalize() }; + unsafe { ffi::MPI_Finalize() }; } #[allow(static_mut_refs)] pub fn new() -> Self { Self::init(); - // let universe = unsafe { UNIVERSE.as_ref() }; - // let world = unsafe { WORLD.as_ref() }; - // let world_size = if let Some(world) = world { - // world.size() - // } else { - // 1 - // }; - // let world_rank = if let Some(world) = world { - // world.rank() - // } else { - // 0 - // }; - // Self { - // universe, - // world, - // world_size, - // world_rank, - // } - - let num_worlds = rayon::current_num_threads() as i32; - let world_rank = match rayon::current_thread_index() { - Some(rank) => rank as i32, - None => 0, + let universe = unsafe { UNIVERSE.as_ref() }; + let world = unsafe { WORLD.as_ref() }; + let world_size = if let Some(world) = world { + world.size() + } else { + 1 + }; + let world_rank = if let Some(world) = world { + world.rank() + } else { + 0 }; - - let universe = vec![]; - let worlds = vec![vec![]; num_worlds as usize]; - Self { universe, - worlds, - world_size: num_worlds, + world, + world_size, world_rank, } } #[inline] pub fn new_for_verifier(world_size: i32) -> Self { - - let universe = vec![]; - let worlds = vec![vec![]; world_size as usize]; - Self { - universe, - worlds, - world_size: world_size, + universe: None, + world: None, + world_size, world_rank: 0, } - - // Self { - // universe: None, - // world: None, - // world_size, - // world_rank: 0, - // } } /// Return an u8 vector sharing THE SAME MEMORY SLOT with the input. 
@@ -182,141 +143,84 @@ impl MPIConfig { } #[allow(clippy::collapsible_else_if)] - pub fn gather_vec(&mut self, local_vec: &Vec, global_vec: &mut Vec) { - if self.world_size == 1 { - *global_vec = local_vec.clone(); - return; - } - - // For non-root processes, we just need to store our local vector in our world's memory - if !self.is_root() { - assert!(global_vec.is_empty(), "Non-root processes should have empty global vectors"); - - // Convert local vector to bytes and store in our world's slot - unsafe { + pub fn gather_vec(&self, local_vec: &Vec, global_vec: &mut Vec) { + unsafe { + if self.world_size == 1 { + *global_vec = local_vec.clone() + } else { + assert!(!self.is_root() || global_vec.len() == local_vec.len() * self.world_size()); + let local_vec_u8 = Self::vec_to_u8_bytes(local_vec); - self.worlds[self.world_rank as usize] = local_vec_u8; - } - return; - } - - // For root process, we need to gather all vectors - assert!( - global_vec.len() == local_vec.len() * self.world_size(), - "Root's global_vec size must match total data size" - ); - - // First, store root's local vector in the beginning of global_vec - let root_data_size = local_vec.len(); - global_vec[..root_data_size].copy_from_slice(local_vec); - - // Then gather data from other processes - for rank in 1..self.world_size { - let rank = rank as usize; - let start_idx = rank * root_data_size; - let end_idx = start_idx + root_data_size; - - // Get the bytes from the corresponding world's memory - let world_bytes = &self.worlds[rank]; - - // Safety: We're reconstructing the vector with the same layout - unsafe { - let other_vec = Vec::::from_raw_parts( - world_bytes.as_ptr() as *mut F, - root_data_size, - root_data_size, - ); - - // Copy the data to the appropriate position in global_vec - global_vec[start_idx..end_idx].copy_from_slice(&other_vec); - - // Don't drop the vector since we don't own the memory - std::mem::forget(other_vec); + let local_n_bytes = local_vec_u8.len(); + let n_chunks = (local_n_bytes + Self::CHUNK_SIZE - 1) / Self::CHUNK_SIZE; + if n_chunks == 1 { + if self.world_rank == Self::ROOT_RANK { + let mut global_vec_u8 = Self::vec_to_u8_bytes(global_vec); + self.root_process() + .gather_into_root(&local_vec_u8, &mut global_vec_u8); + global_vec_u8.leak(); // discard control of the memory + } else { + self.root_process().gather_into(&local_vec_u8); + } + } else { + if self.world_rank == Self::ROOT_RANK { + let mut chunk_buffer_u8 = vec![0u8; Self::CHUNK_SIZE * self.world_size()]; + let mut global_vec_u8 = Self::vec_to_u8_bytes(global_vec); + for i in 0..n_chunks { + let local_start = i * Self::CHUNK_SIZE; + let local_end = cmp::min(local_start + Self::CHUNK_SIZE, local_n_bytes); + self.root_process().gather_into_root( + &local_vec_u8[local_start..local_end], + &mut chunk_buffer_u8, + ); + + // distribute the data to where they belong to in global vec + let actual_chunk_size = local_end - local_start; + for j in 0..self.world_size() { + let global_start = j * local_n_bytes + local_start; + let global_end = global_start + actual_chunk_size; + global_vec_u8[global_start..global_end].copy_from_slice( + &chunk_buffer_u8[j * Self::CHUNK_SIZE + ..j * Self::CHUNK_SIZE + actual_chunk_size], + ); + } + } + global_vec_u8.leak(); // discard control of the memory + } else { + for i in 0..n_chunks { + let local_start = i * Self::CHUNK_SIZE; + let local_end = cmp::min(local_start + Self::CHUNK_SIZE, local_n_bytes); + self.root_process() + .gather_into(&local_vec_u8[local_start..local_end]); + } + } + } + 
local_vec_u8.leak(); // discard control of the memory } } - - // unsafe { - // if self.world_size == 1 { - // *global_vec = local_vec.clone() - // } else { - // assert!(!self.is_root() || global_vec.len() == local_vec.len() * self.world_size()); - - // let local_vec_u8 = Self::vec_to_u8_bytes(local_vec); - // let local_n_bytes = local_vec_u8.len(); - // let n_chunks = (local_n_bytes + Self::CHUNK_SIZE - 1) / Self::CHUNK_SIZE; - // if n_chunks == 1 { - // if self.world_rank == Self::ROOT_RANK { - // let mut global_vec_u8 = Self::vec_to_u8_bytes(global_vec); - // self.root_process() - // .gather_into_root(&local_vec_u8, &mut global_vec_u8); - // global_vec_u8.leak(); // discard control of the memory - // } else { - // self.root_process().gather_into(&local_vec_u8); - // } - // } else { - // if self.world_rank == Self::ROOT_RANK { - // let mut chunk_buffer_u8 = vec![0u8; Self::CHUNK_SIZE * self.world_size()]; - // let mut global_vec_u8 = Self::vec_to_u8_bytes(global_vec); - // for i in 0..n_chunks { - // let local_start = i * Self::CHUNK_SIZE; - // let local_end = cmp::min(local_start + Self::CHUNK_SIZE, local_n_bytes); - // self.root_process().gather_into_root( - // &local_vec_u8[local_start..local_end], - // &mut chunk_buffer_u8, - // ); - - // // distribute the data to where they belong to in global vec - // let actual_chunk_size = local_end - local_start; - // for j in 0..self.world_size() { - // let global_start = j * local_n_bytes + local_start; - // let global_end = global_start + actual_chunk_size; - // global_vec_u8[global_start..global_end].copy_from_slice( - // &chunk_buffer_u8[j * Self::CHUNK_SIZE - // ..j * Self::CHUNK_SIZE + actual_chunk_size], - // ); - // } - // } - // global_vec_u8.leak(); // discard control of the memory - // } else { - // for i in 0..n_chunks { - // let local_start = i * Self::CHUNK_SIZE; - // let local_end = cmp::min(local_start + Self::CHUNK_SIZE, local_n_bytes); - // self.root_process() - // .gather_into(&local_vec_u8[local_start..local_end]); - // } - // } - // } - // local_vec_u8.leak(); // discard control of the memory - // } - // } } - /// Root process broadcast a value f into all the processes + /// Root process broadcase a value f into all the processes #[inline] pub fn root_broadcast_f(&self, f: &mut F) { - // unsafe { + unsafe { if self.world_size == 1 { } else { - - - // let mut vec_u8 = Self::elem_to_u8_bytes(f, F::SIZE); - // self.root_process().broadcast_into(&mut vec_u8); - // vec_u8.leak(); + let mut vec_u8 = Self::elem_to_u8_bytes(f, F::SIZE); + self.root_process().broadcast_into(&mut vec_u8); + vec_u8.leak(); } - // } + } } #[inline] - /// copy the root process's memory to the buffer pub fn root_broadcast_bytes(&self, bytes: &mut Vec) { - // self.root_process().broadcast_into(bytes); - bytes.clear(); - bytes.copy_from_slice(&self.root_process()); + self.root_process().broadcast_into(bytes); } /// sum up all local values #[inline] - pub fn sum_vec(&mut self, local_vec: &Vec) -> Vec { + pub fn sum_vec(&self, local_vec: &Vec) -> Vec { if self.world_size == 1 { local_vec.clone() } else if self.world_rank == Self::ROOT_RANK { @@ -337,7 +241,7 @@ impl MPIConfig { /// coef has a length of mpi_world_size #[inline] - pub fn coef_combine_vec(&mut self, local_vec: &Vec, coef: &[F]) -> Vec { + pub fn coef_combine_vec(&self, local_vec: &Vec, coef: &[F]) -> Vec { if self.world_size == 1 { // Warning: literally, it should be coef[0] * local_vec // but coef[0] is always one in our use case of self.world_size = 1 @@ -374,17 +278,14 @@ impl MPIConfig { } 
#[inline(always)] - pub fn root_process(&self) -> &Vec { - &self.worlds[0] - - // self.world.unwrap().process_at_rank(Self::ROOT_RANK) + pub fn root_process(&self) -> Process { + self.world.unwrap().process_at_rank(Self::ROOT_RANK) } #[inline(always)] pub fn barrier(&self) { - // do nothing - // self.world.unwrap().barrier(); + self.world.unwrap().barrier(); } } -unsafe impl Send for MPIConfig {} +unsafe impl Send for MPIConfig {} \ No newline at end of file From 34dce795b9eac81791dbe32d622d977f21bf3a9d Mon Sep 17 00:00:00 2001 From: zhenfei Date: Thu, 26 Dec 2024 11:07:35 -0500 Subject: [PATCH 3/3] 1 --- .github/workflows/mpi.yml | 46 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 .github/workflows/mpi.yml diff --git a/.github/workflows/mpi.yml b/.github/workflows/mpi.yml new file mode 100644 index 00000000..cdce35e1 --- /dev/null +++ b/.github/workflows/mpi.yml @@ -0,0 +1,46 @@ +name: MPI Tests + +on: [pull_request, push] + +jobs: + test: + name: Test MPI on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + rust: [stable] + + steps: + - uses: actions/checkout@v3 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + toolchain: ${{ matrix.rust }} + + - name: Cache cargo registry + uses: actions/cache@v3 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + + - name: Build + run: cargo build --verbose + + - name: Run tests + run: | + # Run single process tests + RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo +nightly run --release --bin=gkr -- -s keccak -f fr -t 16 + + # Run multi-process tests with mpirun + RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" mpiexec -n 2 cargo +nightly run --release --bin=gkr-mpi -- -s keccak -f fr + + - name: Run specific MPI tests + run: | + RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" mpiexec -n 2 cargo +nightly run --release --bin=gkr-mpi -- -s keccak -f gf2ext128 + RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" mpiexec -n 2 cargo +nightly run --release --bin=gkr-mpi -- -s keccak -f m31ext3 + RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" mpiexec -n 2 cargo +nightly run --release --bin=gkr-mpi -- -s keccak -f fr + RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" mpiexec -n 2 cargo +nightly run --release --bin=gkr-mpi -- -s poseidon -f m31ext3 \ No newline at end of file
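
For reference, the chunked MPI `gather_vec` (removed in PATCH 1 and restored in PATCH 2) splits each party's byte buffer into `CHUNK_SIZE` pieces — the MPI message limit is 2^30 bytes, with 10 bits reserved for the party count — gathers one round per chunk, and reassembles the rounds at the root using per-party offsets. Below is a minimal, MPI-free sketch of that offset arithmetic only; it is not part of the patches, and the constants are shrunk so the multi-chunk path actually runs (the real code uses CHUNK_SIZE = 1 << 20).

// Minimal sketch of the chunked-gather offset math from `gather_vec`.
// Plain Vec<u8> buffers stand in for the MPI gather calls; all names and
// constants here are illustrative, not part of the patch series.
const CHUNK_SIZE: usize = 4;
const WORLD_SIZE: usize = 3;

fn main() {
    // Each simulated party owns `local_n_bytes` of data.
    let local_n_bytes = 10usize;
    let locals: Vec<Vec<u8>> = (0..WORLD_SIZE)
        .map(|rank| (0..local_n_bytes).map(|i| (rank * local_n_bytes + i) as u8).collect())
        .collect();

    // Root-side buffer laid out party-major: [party 0 | party 1 | party 2].
    let mut global = vec![0u8; local_n_bytes * WORLD_SIZE];

    let n_chunks = (local_n_bytes + CHUNK_SIZE - 1) / CHUNK_SIZE;
    for i in 0..n_chunks {
        let local_start = i * CHUNK_SIZE;
        let local_end = (local_start + CHUNK_SIZE).min(local_n_bytes);
        let actual_chunk_size = local_end - local_start;

        // In the real code this buffer is filled by `gather_into_root`;
        // here it is filled by copying from the simulated parties.
        let mut chunk_buffer = vec![0u8; CHUNK_SIZE * WORLD_SIZE];
        for (j, local) in locals.iter().enumerate() {
            chunk_buffer[j * CHUNK_SIZE..j * CHUNK_SIZE + actual_chunk_size]
                .copy_from_slice(&local[local_start..local_end]);
        }

        // Scatter each party's chunk into its slot of the global buffer,
        // exactly as the root does after every gather round.
        for j in 0..WORLD_SIZE {
            let global_start = j * local_n_bytes + local_start;
            global[global_start..global_start + actual_chunk_size].copy_from_slice(
                &chunk_buffer[j * CHUNK_SIZE..j * CHUNK_SIZE + actual_chunk_size],
            );
        }
    }

    // The gathered layout matches the size assertion in `gather_vec`:
    // global.len() == local.len() * world_size, ordered by rank.
    assert_eq!(global, (0..(local_n_bytes * WORLD_SIZE) as u8).collect::<Vec<u8>>());
    println!("gathered {} bytes from {} parties", global.len(), WORLD_SIZE);
}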