From 40aa96acee0144051943967492fe6c2c9d1f1e55 Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Fri, 28 Jun 2024 23:55:36 +0800 Subject: [PATCH 1/6] add ferret with io --- crates/mpz-common/src/ideal.rs | 4 +- crates/mpz-core/src/ggm_tree.rs | 52 ++-- crates/mpz-ot-core/src/ferret/mod.rs | 19 +- crates/mpz-ot-core/src/ferret/mpcot/mod.rs | 5 +- .../mpz-ot-core/src/ferret/mpcot/receiver.rs | 30 +-- .../src/ferret/mpcot/receiver_regular.rs | 30 +-- crates/mpz-ot-core/src/ferret/mpcot/sender.rs | 30 +-- .../src/ferret/mpcot/sender_regular.rs | 30 +-- crates/mpz-ot-core/src/ferret/receiver.rs | 20 +- crates/mpz-ot-core/src/ferret/sender.rs | 24 +- crates/mpz-ot-core/src/ferret/spcot/mod.rs | 81 ++++-- .../mpz-ot-core/src/ferret/spcot/receiver.rs | 243 +++++++++++++----- crates/mpz-ot-core/src/ferret/spcot/sender.rs | 156 +++++++---- crates/mpz-ot/src/ferret/error.rs | 67 +++++ crates/mpz-ot/src/ferret/mod.rs | 175 +++++++++++++ crates/mpz-ot/src/ferret/mpcot/error.rs | 59 +++++ crates/mpz-ot/src/ferret/mpcot/mod.rs | 165 ++++++++++++ crates/mpz-ot/src/ferret/mpcot/receiver.rs | 192 ++++++++++++++ crates/mpz-ot/src/ferret/mpcot/sender.rs | 166 ++++++++++++ crates/mpz-ot/src/ferret/receiver.rs | 192 ++++++++++++++ crates/mpz-ot/src/ferret/sender.rs | 160 ++++++++++++ crates/mpz-ot/src/ferret/spcot/error.rs | 59 +++++ crates/mpz-ot/src/ferret/spcot/mod.rs | 103 ++++++++ crates/mpz-ot/src/ferret/spcot/receiver.rs | 164 ++++++++++++ crates/mpz-ot/src/ferret/spcot/sender.rs | 144 +++++++++++ crates/mpz-ot/src/ideal/cot.rs | 11 +- crates/mpz-ot/src/lib.rs | 1 + 27 files changed, 2126 insertions(+), 256 deletions(-) create mode 100644 crates/mpz-ot/src/ferret/error.rs create mode 100644 crates/mpz-ot/src/ferret/mod.rs create mode 100644 crates/mpz-ot/src/ferret/mpcot/error.rs create mode 100644 crates/mpz-ot/src/ferret/mpcot/mod.rs create mode 100644 crates/mpz-ot/src/ferret/mpcot/receiver.rs create mode 100644 crates/mpz-ot/src/ferret/mpcot/sender.rs create mode 100644 crates/mpz-ot/src/ferret/receiver.rs create mode 100644 crates/mpz-ot/src/ferret/sender.rs create mode 100644 crates/mpz-ot/src/ferret/spcot/error.rs create mode 100644 crates/mpz-ot/src/ferret/spcot/mod.rs create mode 100644 crates/mpz-ot/src/ferret/spcot/receiver.rs create mode 100644 crates/mpz-ot/src/ferret/spcot/sender.rs diff --git a/crates/mpz-common/src/ideal.rs b/crates/mpz-common/src/ideal.rs index 804472ef..7fcb1628 100644 --- a/crates/mpz-common/src/ideal.rs +++ b/crates/mpz-common/src/ideal.rs @@ -18,7 +18,7 @@ struct Buffer { } /// The ideal functionality from the perspective of Alice. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Alice { f: Arc>, buffer: Arc>, @@ -79,7 +79,7 @@ impl Alice { } /// The ideal functionality from the perspective of Bob. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Bob { f: Arc>, buffer: Arc>, diff --git a/crates/mpz-core/src/ggm_tree.rs b/crates/mpz-core/src/ggm_tree.rs index 913fffb6..840efcc6 100644 --- a/crates/mpz-core/src/ggm_tree.rs +++ b/crates/mpz-core/src/ggm_tree.rs @@ -32,33 +32,35 @@ impl GgmTree { assert_eq!(k0.len(), self.depth); assert_eq!(k1.len(), self.depth); let mut buf = [Block::ZERO; 8]; - self.tkprp.expand_1to2(tree, seed); - k0[0] = tree[0]; - k1[0] = tree[1]; + if self.depth > 1 { + self.tkprp.expand_1to2(tree, seed); + k0[0] = tree[0]; + k1[0] = tree[1]; - self.tkprp.expand_2to4(&mut buf, tree); - k0[1] = buf[0] ^ buf[2]; - k1[1] = buf[1] ^ buf[3]; - tree[0..4].copy_from_slice(&buf[0..4]); - - for h in 2..self.depth { - k0[h] = Block::ZERO; - k1[h] = Block::ZERO; - - // How many nodes there are in this layer - let sz = 1 << h; - for i in (0..=sz - 4).rev().step_by(4) { - self.tkprp.expand_4to8(&mut buf, &tree[i..]); - k0[h] ^= buf[0]; - k0[h] ^= buf[2]; - k0[h] ^= buf[4]; - k0[h] ^= buf[6]; - k1[h] ^= buf[1]; - k1[h] ^= buf[3]; - k1[h] ^= buf[5]; - k1[h] ^= buf[7]; + self.tkprp.expand_2to4(&mut buf, tree); + k0[1] = buf[0] ^ buf[2]; + k1[1] = buf[1] ^ buf[3]; + tree[0..4].copy_from_slice(&buf[0..4]); - tree[2 * i..2 * i + 8].copy_from_slice(&buf); + for h in 2..self.depth { + k0[h] = Block::ZERO; + k1[h] = Block::ZERO; + + // How many nodes there are in this layer + let sz = 1 << h; + for i in (0..=sz - 4).rev().step_by(4) { + self.tkprp.expand_4to8(&mut buf, &tree[i..]); + k0[h] ^= buf[0]; + k0[h] ^= buf[2]; + k0[h] ^= buf[4]; + k0[h] ^= buf[6]; + k1[h] ^= buf[1]; + k1[h] ^= buf[3]; + k1[h] ^= buf[5]; + k1[h] ^= buf[7]; + + tree[2 * i..2 * i + 8].copy_from_slice(&buf); + } } } } diff --git a/crates/mpz-ot-core/src/ferret/mod.rs b/crates/mpz-ot-core/src/ferret/mod.rs index 3ad7701e..bbbf264a 100644 --- a/crates/mpz-ot-core/src/ferret/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mod.rs @@ -36,11 +36,12 @@ pub const LPN_PARAMETERS_UNIFORM: LpnParameters = LpnParameters { }; /// The type of Lpn parameters. -#[derive(Debug)] +#[derive(Debug, Clone, Copy, Default)] pub enum LpnType { /// Uniform error distribution. Uniform, /// Regular error distribution. + #[default] Regular, } @@ -48,7 +49,6 @@ pub enum LpnType { mod tests { use super::*; - use msgs::LpnMatrixSeed; use receiver::Receiver; use sender::Sender; @@ -56,7 +56,6 @@ mod tests { use crate::test::assert_cot; use crate::{MPCOTReceiverOutput, MPCOTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput}; use mpz_core::{lpn::LpnParameters, prg::Prg}; - use rand::SeedableRng; const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters { n: 9600, @@ -66,7 +65,7 @@ mod tests { #[test] fn ferret_test() { - let mut prg = Prg::from_seed([1u8; 16].into()); + let mut prg = Prg::new(); let delta = prg.random_block(); let mut ideal_cot = IdealCOT::default(); let mut ideal_mpcot = IdealMpcot::default(); @@ -101,18 +100,8 @@ mod tests { ) .unwrap(); - let LpnMatrixSeed { - seed: lpn_matrix_seed, - } = seed; - let mut sender = sender - .setup( - delta, - LPN_PARAMETERS_TEST, - LpnType::Regular, - lpn_matrix_seed, - &v, - ) + .setup(delta, LPN_PARAMETERS_TEST, LpnType::Regular, seed, &v) .unwrap(); // extend once diff --git a/crates/mpz-ot-core/src/ferret/mpcot/mod.rs b/crates/mpz-ot-core/src/ferret/mpcot/mod.rs index e74dc38a..047780d4 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/mod.rs @@ -16,11 +16,10 @@ mod tests { use crate::ideal::spcot::IdealSpcot; use crate::{SPCOTReceiverOutput, SPCOTSenderOutput}; use mpz_core::prg::Prg; - use rand::SeedableRng; #[test] fn mpcot_general_test() { - let mut prg = Prg::from_seed([1u8; 16].into()); + let mut prg = Prg::new(); let delta = prg.random_block(); let mut ideal_spcot = IdealSpcot::new_with_delta(delta); @@ -96,7 +95,7 @@ mod tests { #[test] fn mpcot_regular_test() { - let mut prg = Prg::from_seed([2u8; 16].into()); + let mut prg = Prg::new(); let delta = prg.random_block(); let mut ideal_spcot = IdealSpcot::new_with_delta(delta); diff --git a/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs b/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs index 0f8613af..e4d362da 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs @@ -32,11 +32,11 @@ impl Receiver { /// # Argument /// /// * `hash_seed` - Random seed to generate hashes, will be sent to the sender. - pub fn setup(self, hash_seed: Block) -> (Receiver, HashSeed) { + pub fn setup(self, hash_seed: Block) -> (Receiver, HashSeed) { let mut prg = Prg::from_seed(hash_seed); let hashes = std::array::from_fn(|_| AesEncryptor::new(prg.random_block())); let recv = Receiver { - state: state::PreExtension { + state: state::Extension { counter: 0, hashes: Arc::new(hashes), }, @@ -48,7 +48,7 @@ impl Receiver { } } -impl Receiver { +impl Receiver { /// Performs the hash procedure in MPCOT extension. /// Outputs the length of each bucket plus 1. /// @@ -63,7 +63,7 @@ impl Receiver { self, alphas: &[u32], n: u32, - ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { + ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { if alphas.len() as u32 > n { return Err(ReceiverError::InvalidInput( "length of alphas should not exceed n".to_string(), @@ -104,7 +104,7 @@ impl Receiver { } let receiver = Receiver { - state: state::Extension { + state: state::ExtensionInternal { counter: self.state.counter, m, n, @@ -117,7 +117,7 @@ impl Receiver { Ok((receiver, p)) } } -impl Receiver { +impl Receiver { /// Performs MPCOT extension. /// /// See Step 5 in Figure 7. @@ -128,7 +128,7 @@ impl Receiver { pub fn extend( self, rt: &[Vec], - ) -> Result<(Receiver, Vec), ReceiverError> { + ) -> Result<(Receiver, Vec), ReceiverError> { if rt.len() != self.state.m { return Err(ReceiverError::InvalidInput( "the length rt should be m".to_string(), @@ -165,7 +165,7 @@ impl Receiver { } let receiver = Receiver { - state: state::PreExtension { + state: state::Extension { counter: self.state.counter + 1, hashes: self.state.hashes, }, @@ -182,8 +182,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The receiver's state. @@ -200,20 +200,20 @@ pub mod state { /// The receiver's state before extending. /// /// In this state the receiver performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Current MPCOT counter pub(super) counter: usize, /// The hashes to generate Cuckoo hash table. pub(super) hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, } - impl State for PreExtension {} + impl State for Extension {} - opaque_debug::implement!(PreExtension); + opaque_debug::implement!(Extension); /// The receiver's state of extension. /// /// In this state the receiver performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Current MPCOT counter pub(super) counter: usize, /// Current length of Cuckoo hash table, will possibly be changed in each extension. @@ -228,7 +228,7 @@ pub mod state { pub(super) buckets_length: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs b/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs index 2b226108..e1e7edfe 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs @@ -19,13 +19,13 @@ impl Receiver { } /// Completes the setup phase of the protocol. - pub fn setup(self) -> Receiver { + pub fn setup(self) -> Receiver { Receiver { - state: state::PreExtension { counter: 0 }, + state: state::Extension { counter: 0 }, } } } -impl Receiver { +impl Receiver { /// Performs the prepare procedure in MPCOT extension. /// Outputs the indices for SPCOT. /// @@ -38,7 +38,7 @@ impl Receiver { self, alphas: &[u32], n: u32, - ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { + ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { let t = alphas.len() as u32; if t > n { return Err(ReceiverError::InvalidInput( @@ -91,7 +91,7 @@ impl Receiver { .collect(); let receiver = Receiver { - state: state::Extension { + state: state::ExtensionInternal { counter: self.state.counter, n, queries_length, @@ -103,7 +103,7 @@ impl Receiver { } } -impl Receiver { +impl Receiver { /// Performs MPCOT extension. /// /// # Arguments. @@ -112,7 +112,7 @@ impl Receiver { pub fn extend( self, rt: &[Vec], - ) -> Result<(Receiver, Vec), ReceiverError> { + ) -> Result<(Receiver, Vec), ReceiverError> { if rt .iter() .zip(self.state.queries_depth.iter()) @@ -130,7 +130,7 @@ impl Receiver { } let receiver = Receiver { - state: state::PreExtension { + state: state::Extension { counter: self.state.counter + 1, }, }; @@ -145,8 +145,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The receiver's state. @@ -162,19 +162,19 @@ pub mod state { /// The receiver's state before extending. /// /// In this state the receiver performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Current MPCOT counter pub(super) counter: usize, } - impl State for PreExtension {} + impl State for Extension {} - opaque_debug::implement!(PreExtension); + opaque_debug::implement!(Extension); /// The receiver's state after the setup phase. /// /// In this state the receiver performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Current MPCOT counter #[allow(dead_code)] pub(super) counter: usize, @@ -186,7 +186,7 @@ pub mod state { pub(super) queries_depth: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/sender.rs b/crates/mpz-ot-core/src/ferret/mpcot/sender.rs index f1e49105..ad025574 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/sender.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/sender.rs @@ -31,12 +31,12 @@ impl Sender { /// /// * `delta` - The sender's global secret. /// * `hash_seed` - The seed for Cuckoo hash sent by the receiver. - pub fn setup(self, delta: Block, hash_seed: HashSeed) -> Sender { + pub fn setup(self, delta: Block, hash_seed: HashSeed) -> Sender { let HashSeed { seed: hash_seed } = hash_seed; let mut prg = Prg::from_seed(hash_seed); let hashes = std::array::from_fn(|_| AesEncryptor::new(prg.random_block())); Sender { - state: state::PreExtension { + state: state::Extension { delta, counter: 0, hashes: Arc::new(hashes), @@ -45,7 +45,7 @@ impl Sender { } } -impl Sender { +impl Sender { /// Performs the hash procedure in MPCOT extension. /// Outputs the length of each bucket plus 1. /// @@ -59,7 +59,7 @@ impl Sender { self, t: u32, n: u32, - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if t > n { return Err(SenderError::InvalidInput( "t should not exceed n".to_string(), @@ -86,7 +86,7 @@ impl Sender { } let sender = Sender { - state: state::Extension { + state: state::ExtensionInternal { delta: self.state.delta, counter: self.state.counter, m, @@ -101,7 +101,7 @@ impl Sender { } } -impl Sender { +impl Sender { /// Performs MPCOT extension. /// /// See Step 5 in Figure 7. @@ -112,7 +112,7 @@ impl Sender { pub fn extend( self, st: &[Vec], - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if st.len() != self.state.m { return Err(SenderError::InvalidInput( "the length st should be m".to_string(), @@ -147,7 +147,7 @@ impl Sender { } let sender = Sender { - state: state::PreExtension { + state: state::Extension { delta: self.state.delta, counter: self.state.counter + 1, hashes: self.state.hashes, @@ -166,8 +166,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The sender's state. @@ -184,7 +184,7 @@ pub mod state { /// The sender's state before extending. /// /// In this state the sender performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter @@ -193,13 +193,13 @@ pub mod state { pub(super) hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, } - impl State for PreExtension {} - opaque_debug::implement!(PreExtension); + impl State for Extension {} + opaque_debug::implement!(Extension); /// The sender's state of extension. /// /// In this state the sender performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter @@ -217,7 +217,7 @@ pub mod state { pub(super) buckets_length: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs b/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs index db0646b6..7afa5106 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs @@ -23,14 +23,14 @@ impl Sender { /// # Argument. /// /// * `delta` - The sender's global secret. - pub fn setup(self, delta: Block) -> Sender { + pub fn setup(self, delta: Block) -> Sender { Sender { - state: state::PreExtension { delta, counter: 0 }, + state: state::Extension { delta, counter: 0 }, } } } -impl Sender { +impl Sender { /// Performs the prepare procedure in MPCOT extension. /// Outputs the information for SPCOT. /// @@ -42,7 +42,7 @@ impl Sender { self, t: u32, n: u32, - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if t > n { return Err(SenderError::InvalidInput( "t should not exceed n".to_string(), @@ -78,7 +78,7 @@ impl Sender { } let sender = Sender { - state: state::Extension { + state: state::ExtensionInternal { delta: self.state.delta, counter: self.state.counter, n, @@ -91,7 +91,7 @@ impl Sender { } } -impl Sender { +impl Sender { /// Performs MPCOT extension. /// /// # Arguments. @@ -100,7 +100,7 @@ impl Sender { pub fn extend( self, st: &[Vec], - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if st .iter() .zip(self.state.queries_depth.iter()) @@ -117,7 +117,7 @@ impl Sender { } let sender = Sender { - state: state::PreExtension { + state: state::Extension { delta: self.state.delta, counter: self.state.counter + 1, }, @@ -135,8 +135,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The sender's state. @@ -153,20 +153,20 @@ pub mod state { /// The sender's state before extending. /// /// In this state the sender performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter pub(super) counter: usize, } - impl State for PreExtension {} - opaque_debug::implement!(PreExtension); + impl State for Extension {} + opaque_debug::implement!(Extension); /// The sender's state after the setup phase. /// /// In this state the sender performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter @@ -179,7 +179,7 @@ pub mod state { pub(super) queries_depth: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/receiver.rs b/crates/mpz-ot-core/src/ferret/receiver.rs index 4d08c69b..e5939c60 100644 --- a/crates/mpz-ot-core/src/ferret/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/receiver.rs @@ -4,7 +4,10 @@ use mpz_core::{ Block, }; -use crate::ferret::{error::ReceiverError, LpnType}; +use crate::{ + ferret::{error::ReceiverError, LpnType}, + TransferId, +}; use super::msgs::LpnMatrixSeed; @@ -59,6 +62,7 @@ impl Receiver { u: u.to_vec(), w: w.to_vec(), e: Vec::default(), + id: TransferId::default(), }, }, LpnMatrixSeed { seed }, @@ -69,10 +73,6 @@ impl Receiver { impl Receiver { /// The prepare precedure of extension, sample error vectors and outputs information for MPCOT. /// See step 3 and 4. - /// - /// # Arguments. - /// - /// * `lpn_type` - The type of LPN parameters. pub fn get_mpcot_query(&mut self) -> (Vec, usize) { match self.state.lpn_type { LpnType::Uniform => { @@ -105,6 +105,8 @@ impl Receiver { return Err(ReceiverError("the length of r should be n".to_string())); } + self.state.id.next(); + // Compute z = A * w + r. let mut z = r.to_vec(); self.state.lpn_encoder.compute(&mut z, &self.state.w); @@ -133,6 +135,11 @@ impl Receiver { Ok((x_, z_)) } + + /// Returns id + pub fn id(&self) -> TransferId { + self.state.id + } } /// The receiver's state. @@ -176,6 +183,9 @@ pub mod state { /// Receiver's lpn error vector. pub(super) e: Vec, + + /// TransferID + pub(super) id: TransferId, } impl State for Extension {} diff --git a/crates/mpz-ot-core/src/ferret/sender.rs b/crates/mpz-ot-core/src/ferret/sender.rs index 9e8db180..2af3e4ae 100644 --- a/crates/mpz-ot-core/src/ferret/sender.rs +++ b/crates/mpz-ot-core/src/ferret/sender.rs @@ -4,7 +4,12 @@ use mpz_core::{ Block, }; -use crate::ferret::{error::SenderError, LpnType}; +use crate::{ + ferret::{error::SenderError, LpnType}, + TransferId, +}; + +use super::msgs::LpnMatrixSeed; /// Ferret sender. #[derive(Debug, Default)] @@ -36,7 +41,7 @@ impl Sender { delta: Block, lpn_parameters: LpnParameters, lpn_type: LpnType, - seed: Block, + seed: LpnMatrixSeed, v: &[Block], ) -> Result, SenderError> { if v.len() != lpn_parameters.k { @@ -44,6 +49,7 @@ impl Sender { "the length of v should be equal to k".to_string(), )); } + let LpnMatrixSeed { seed } = seed; let lpn_encoder = LpnEncoder::<10>::new(seed, lpn_parameters.k as u32); Ok(Sender { @@ -54,6 +60,7 @@ impl Sender { lpn_type, lpn_encoder, v: v.to_vec(), + id: TransferId::default(), }, }) } @@ -63,6 +70,7 @@ impl Sender { /// Outputs the information for MPCOT. /// /// See step 3 and 4. + #[inline] pub fn get_mpcot_query(&self) -> (u32, u32) { ( self.state.lpn_parameters.t as u32, @@ -83,6 +91,8 @@ impl Sender { return Err(SenderError("the length of s should be n".to_string())); } + self.state.id.next(); + // Compute y = A * v + s let mut y = s.to_vec(); self.state.lpn_encoder.compute(&mut y, &self.state.v); @@ -97,10 +107,17 @@ impl Sender { Ok(y_) } + + /// Returns id + pub fn id(&self) -> TransferId { + self.state.id + } } /// The sender's state. pub mod state { + use crate::TransferId; + use super::*; mod sealed { @@ -141,6 +158,9 @@ pub mod state { /// Sender's COT message in the setup phase. pub(super) v: Vec, + + /// TransferID. + pub(crate) id: TransferId, } impl State for Extension {} diff --git a/crates/mpz-ot-core/src/ferret/spcot/mod.rs b/crates/mpz-ot-core/src/ferret/spcot/mod.rs index 802efb66..63ebea15 100644 --- a/crates/mpz-ot-core/src/ferret/spcot/mod.rs +++ b/crates/mpz-ot-core/src/ferret/spcot/mod.rs @@ -7,8 +7,6 @@ pub mod sender; #[cfg(test)] mod tests { - use mpz_core::prg::Prg; - use super::{receiver::Receiver as SpcotReceiver, sender::Sender as SpcotSender}; use crate::{ferret::CSP, ideal::cot::IdealCOT, RCOTReceiverOutput, RCOTSenderOutput}; @@ -18,49 +16,82 @@ mod tests { let sender = SpcotSender::new(); let receiver = SpcotReceiver::new(); - let mut prg = Prg::new(); - let sender_seed = prg.random_block(); let delta = ideal_cot.delta(); - let mut sender = sender.setup(delta, sender_seed); + let mut sender = sender.setup(delta); let mut receiver = receiver.setup(); - let h1 = 8; - let alpha1 = 3; + let hs = [8, 4, 10]; + let alphas = [3, 2, 4]; - // Extend once - let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h1); + let h_sum = hs.iter().sum(); + // batch extension + let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h_sum); let RCOTReceiverOutput { - choices: rs, - msgs: ts, + choices: rss, + msgs: tss, .. } = msg_for_receiver; - let RCOTSenderOutput { msgs: qs, .. } = msg_for_sender; - let maskbits = receiver.extend_mask_bits(h1, alpha1, &rs).unwrap(); - let msg_from_sender = sender.extend(h1, &qs, maskbits).unwrap(); + let RCOTSenderOutput { msgs: qss, .. } = msg_for_sender; + + let maskbits = receiver.extend_mask_bits(&hs, &alphas, &rss).unwrap(); + + let msg_from_sender = sender.extend(&hs, &qss, &maskbits).unwrap(); + + receiver + .extend(&hs, &alphas, &tss, &msg_from_sender) + .unwrap(); + + // Check + let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(CSP); + + let RCOTReceiverOutput { + choices: x_star, + msgs: z_star, + .. + } = msg_for_receiver; + + let RCOTSenderOutput { msgs: y_star, .. } = msg_for_sender; + + let check_from_receiver = receiver.check_pre(&x_star).unwrap(); - receiver.extend(h1, alpha1, &ts, msg_from_sender).unwrap(); + let (mut output_sender, check) = sender.check(&y_star, check_from_receiver).unwrap(); - // Extend twice - let h2 = 4; - let alpha2 = 2; + let output_receiver = receiver.check(&z_star, check).unwrap(); - let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h2); + assert!(output_sender + .iter_mut() + .zip(output_receiver.iter()) + .all(|(vs, (ws, alpha))| { + vs[*alpha as usize] ^= delta; + vs == ws + })); + + // extend twice + let hs = [6, 9, 8]; + let alphas = [2, 1, 3]; + + let h_sum = hs.iter().sum(); + + let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h_sum); let RCOTReceiverOutput { - choices: rs, - msgs: ts, + choices: rss, + msgs: tss, .. } = msg_for_receiver; - let RCOTSenderOutput { msgs: qs, .. } = msg_for_sender; - let maskbits = receiver.extend_mask_bits(h2, alpha2, &rs).unwrap(); + let RCOTSenderOutput { msgs: qss, .. } = msg_for_sender; + + let maskbits = receiver.extend_mask_bits(&hs, &alphas, &rss).unwrap(); - let msg_from_sender = sender.extend(h2, &qs, maskbits).unwrap(); + let msg_from_sender = sender.extend(&hs, &qss, &maskbits).unwrap(); - receiver.extend(h2, alpha2, &ts, msg_from_sender).unwrap(); + receiver + .extend(&hs, &alphas, &tss, &msg_from_sender) + .unwrap(); // Check let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(CSP); diff --git a/crates/mpz-ot-core/src/ferret/spcot/receiver.rs b/crates/mpz-ot-core/src/ferret/spcot/receiver.rs index 5e860f31..baf10ae2 100644 --- a/crates/mpz-ot-core/src/ferret/spcot/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/spcot/receiver.rs @@ -6,6 +6,10 @@ use mpz_core::{ utils::blake3, Block, }; use rand_core::SeedableRng; +#[cfg(feature = "rayon")] +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; use super::msgs::{CheckFromReceiver, CheckFromSender, ExtendFromSender, MaskBits}; @@ -43,71 +47,101 @@ impl Receiver { } impl Receiver { - /// Performs the mask bit step in extension. + /// Performs the mask bit step in batch in extension. /// /// See step 4 in Figure 6. /// /// # Arguments /// - /// * `h` - The depth of the GGM tree. - /// * `alpha` - The chosen position. - /// * `rs` - The message from COT ideal functionality for the receiver. Only the random bits are used. + /// * `hs` - The depths of the GGM trees. + /// * `alphas` - The vector of chosen positions. + /// * `rss` - The message from COT ideal functionality for the receiver for all the tress. Only the random bits are used. pub fn extend_mask_bits( &mut self, - h: usize, - alpha: u32, - rs: &[bool], - ) -> Result { + hs: &[usize], + alphas: &[u32], + rss: &[bool], + ) -> Result, ReceiverError> { if self.state.extended { return Err(ReceiverError::InvalidState( "extension is not allowed".to_string(), )); } - if alpha >= (1 << h) { + if alphas.len() != hs.len() { + return Err(ReceiverError::InvalidLength( + "the length of alphas should be the length of hs".to_string(), + )); + } + + if alphas + .iter() + .zip(hs.iter()) + .any(|(alpha, h)| *alpha >= (1 << h)) + { return Err(ReceiverError::InvalidInput( "the input pos should be no more than 2^h-1".to_string(), )); } - if rs.len() != h { + let h_sum = hs.iter().sum(); + + if rss.len() != h_sum { return Err(ReceiverError::InvalidLength( - "the length of r should be h".to_string(), + "the length of r should be the sum of h".to_string(), )); } - // Step 4 in Figure 6 + let mut rs_s = vec![Vec::::new(); hs.len()]; + let mut rss_vec = rss.to_vec(); + for (index, h) in hs.iter().enumerate() { + rs_s[index] = rss_vec.drain(0..*h).collect(); + } - let bs: Vec = alpha - .iter_msb0() - .skip(32 - h) - // Computes alpha_i XOR r_i XOR 1. - .zip(rs.iter()) - .map(|(alpha, &r)| alpha == r) - .collect(); + // Step 4 in Figure 6 + let mut bss = vec![Vec::::new(); hs.len()]; + + let iter = bss + .iter_mut() + .zip(alphas.iter()) + .zip(hs.iter()) + .zip(rs_s.iter()) + .map(|(((bs, alpha), h), rs)| (bs, alpha, h, rs)); + + for (bs, alpha, h, rs) in iter { + *bs = alpha + .iter_msb0() + .skip(32 - h) + // Computes alpha_i XOR r_i XOR 1. + .zip(rs.iter()) + .map(|(alpha, &r)| alpha == r) + .collect(); + } // Updates hasher. - self.state.hasher.update(&bs.to_bytes()); + self.state.hasher.update(&bss.to_bytes()); + + let res: Vec = bss.into_iter().map(|bs| MaskBits { bs }).collect(); - Ok(MaskBits { bs }) + Ok(res) } - /// Performs the GGM reconstruction step in extension. This function can be called multiple times before checking. + /// Performs the GGM reconstruction step in batch in extension. This function can be called multiple times before checking. /// /// See step 5 in Figure 6. /// /// # Arguments /// - /// * `h` - The depth of the GGM tree. - /// * `alpha` - The chosen position. - /// * `ts` - The message from COT ideal functionality for the receiver. Only the chosen blocks are used. - /// * `extendfs` - The message sent by the sender. + /// * `hs` - The depths of the GGM trees. + /// * `alphas` - The vector of chosen positions. + /// * `tss` - The message from COT ideal functionality for the receiver. Only the chosen blocks are used. + /// * `extendfss` - The vector of messages sent by the sender. pub fn extend( &mut self, - h: usize, - alpha: u32, - ts: &[Block], - extendfs: ExtendFromSender, + hs: &[usize], + alphas: &[u32], + tss: &[Block], + extendfss: &[ExtendFromSender], ) -> Result<(), ReceiverError> { if self.state.extended { return Err(ReceiverError::InvalidState( @@ -115,61 +149,122 @@ impl Receiver { )); } - if alpha >= (1 << h) { + if alphas.len() != hs.len() { + return Err(ReceiverError::InvalidLength( + "the length of alphas should be the length of hs".to_string(), + )); + } + + if alphas + .iter() + .zip(hs.iter()) + .any(|(alpha, h)| *alpha >= (1 << h)) + { return Err(ReceiverError::InvalidInput( "the input pos should be no more than 2^h-1".to_string(), )); } - let ExtendFromSender { ms, sum } = extendfs; - if ts.len() != h { + let h_sum = hs.iter().sum(); + + if tss.len() != h_sum { return Err(ReceiverError::InvalidLength( - "the length of t should be h".to_string(), + "the length of tss should be the sum of h".to_string(), )); } - if ms.len() != h { + let mut ts_s = vec![Vec::::new(); hs.len()]; + let mut tss_vec = tss.to_vec(); + for (index, h) in hs.iter().enumerate() { + ts_s[index] = tss_vec.drain(0..*h).collect(); + } + + if extendfss.len() != hs.len() { return Err(ReceiverError::InvalidLength( - "the length of M should be h".to_string(), + "the length of extendfss should be the length of hs".to_string(), )); } - // Updates hasher - self.state.hasher.update(&ms.to_bytes()); - self.state.hasher.update(&sum.to_bytes()); - - let alpha_bar_vec: Vec = alpha.iter_msb0().skip(32 - h).map(|a| !a).collect(); - - // Step 5 in Figure 6. - let k: Vec = ms - .into_iter() - .zip(ts) - .zip(alpha_bar_vec.iter()) - .enumerate() - .map(|(i, (([m0, m1], &t), &b))| { - let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); - if !b { - // H(t, i|ell) ^ M0 - FIXED_KEY_AES.tccr(tweak, t) ^ m0 - } else { - // H(t, i|ell) ^ M1 - FIXED_KEY_AES.tccr(tweak, t) ^ m1 - } - }) - .collect(); + let mut ms_s = vec![Vec::<[Block; 2]>::new(); hs.len()]; + let mut sum_s = vec![Block::ZERO; hs.len()]; - // Reconstructs GGM tree except `ws[alpha]`. - let ggm_tree = GgmTree::new(h); - let mut tree = vec![Block::ZERO; 1 << h]; - ggm_tree.reconstruct(&mut tree, &k, &alpha_bar_vec); + for (index, extendfs) in extendfss.iter().enumerate() { + ms_s[index].clone_from(&extendfs.ms); + sum_s[index] = extendfs.sum; + } + + if ms_s.iter().zip(hs.iter()).any(|(ms, h)| ms.len() != *h) { + return Err(ReceiverError::InvalidLength( + "the length of ms should be h".to_string(), + )); + } + // Updates hasher + self.state.hasher.update(&ms_s.to_bytes()); + self.state.hasher.update(&sum_s.to_bytes()); + + let mut trees = vec![Vec::::new(); hs.len()]; + + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")]{ + let iter = alphas + .par_iter() + .zip(ms_s.par_iter()) + .zip(sum_s.par_iter()) + .zip(hs.par_iter()) + .zip(ts_s.par_iter()) + .zip(trees.par_iter_mut()) + .map(|(((((alpha, ms), sum), h), ts), tree)| (alpha, ms, sum, h, ts, tree)); + }else{ + let iter = alphas + .iter() + .zip(ms_s.iter()) + .zip(sum_s.iter()) + .zip(hs.iter()) + .zip(ts_s.iter()) + .zip(trees.iter_mut()) + .map(|(((((alpha, ms), sum), h), ts), tree)| (alpha, ms, sum, h, ts, tree)); + } + } - // Sets `tree[alpha]`, which is `ws[alpha]`. - tree[alpha as usize] = tree.iter().fold(sum, |acc, &x| acc ^ x); + iter.for_each(|(alpha, ms, sum, h, ts, tree)| { + let alpha_bar_vec: Vec = alpha.iter_msb0().skip(32 - h).map(|a| !a).collect(); + + // Step 5 in Figure 6. + let k: Vec = ms + .iter() + .zip(ts) + .zip(alpha_bar_vec.iter()) + .enumerate() + .map(|(i, (([m0, m1], &t), &b))| { + let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); + if !b { + // H(t, i|ell) ^ M0 + FIXED_KEY_AES.tccr(tweak, t) ^ *m0 + } else { + // H(t, i|ell) ^ M1 + FIXED_KEY_AES.tccr(tweak, t) ^ *m1 + } + }) + .collect(); + + // Reconstructs GGM tree except `ws[alpha]`. + let ggm_tree = GgmTree::new(*h); + *tree = vec![Block::ZERO; 1 << h]; + ggm_tree.reconstruct(tree, &k, &alpha_bar_vec); + + // Sets `tree[alpha]`, which is `ws[alpha]`. + tree[(*alpha) as usize] = tree.iter().fold(*sum, |acc, &x| acc ^ x); + }); + + for tree in trees { + self.state.unchecked_ws.extend_from_slice(&tree); + } - self.state.unchecked_ws.extend_from_slice(&tree); - self.state.alphas_and_length.push((alpha, 1 << h)); + for (alpha, h) in alphas.iter().zip(hs.iter()) { + self.state.alphas_and_length.push((*alpha, 1 << h)); + } - self.state.exec_counter += 1; + self.state.exec_counter += hs.len(); Ok(()) } @@ -248,7 +343,6 @@ impl Receiver { } self.state.cot_counter += self.state.unchecked_ws.len(); - self.state.extended = true; let mut res = Vec::new(); for (alpha, n) in &self.state.alphas_and_length { @@ -256,8 +350,19 @@ impl Receiver { res.push((tmp, *alpha)); } + self.state.hasher = blake3::Hasher::new(); + self.state.alphas_and_length.clear(); + self.state.chis.clear(); + self.state.unchecked_ws.clear(); + Ok(res) } + + /// Complete extension. + #[inline] + pub fn finalize(&mut self) { + self.state.extended = true; + } } /// The receiver's state. diff --git a/crates/mpz-ot-core/src/ferret/spcot/sender.rs b/crates/mpz-ot-core/src/ferret/spcot/sender.rs index fef1327e..a62ad3bb 100644 --- a/crates/mpz-ot-core/src/ferret/spcot/sender.rs +++ b/crates/mpz-ot-core/src/ferret/spcot/sender.rs @@ -5,6 +5,10 @@ use mpz_core::{ utils::blake3, Block, }; use rand_core::SeedableRng; +#[cfg(feature = "rayon")] +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; use super::msgs::{CheckFromReceiver, CheckFromSender, ExtendFromSender, MaskBits}; @@ -29,8 +33,7 @@ impl Sender { /// # Arguments /// /// * `delta` - The sender's global secret. - /// * `seed` - The random seed to generate PRG. - pub fn setup(self, delta: Block, seed: Block) -> Sender { + pub fn setup(self, delta: Block) -> Sender { Sender { state: state::Extension { delta, @@ -39,7 +42,6 @@ impl Sender { cot_counter: 0, exec_counter: 0, extended: false, - prg: Prg::from_seed(seed), hasher: blake3::Hasher::new(), }, } @@ -47,85 +49,137 @@ impl Sender { } impl Sender { - /// Performs the SPCOT extension. + /// Performs batch SPCOT extension. /// /// See Step 1-5 in Figure 6. /// /// # Arguments /// - /// * `h` - The depth of the GGM tree. - /// * `qs`- The blocks received by calling the COT functionality. - /// * `mask`- The mask bits sent by the receiver. + /// * `hs` - The depths of the GGM trees. + /// * `qss`- The blocks received by calling the COT functionality for hs trees. + /// * `masks`- The vector of mask bits sent by the receiver. pub fn extend( &mut self, - h: usize, - qs: &[Block], - mask: MaskBits, - ) -> Result { + hs: &[usize], + qss: &[Block], + masks: &[MaskBits], + ) -> Result, SenderError> { if self.state.extended { return Err(SenderError::InvalidState( "extension is not allowed".to_string(), )); } - if qs.len() != h { + let h_sum = hs.iter().sum(); + + if qss.len() != h_sum { return Err(SenderError::InvalidLength( - "the length of q should be h".to_string(), + "the length of qss should be the sum of h".to_string(), )); } - let MaskBits { bs } = mask; + let mut qs_s = vec![Vec::::new(); hs.len()]; + let mut qss_vec = qss.to_vec(); + for (index, h) in hs.iter().enumerate() { + qs_s[index] = qss_vec.drain(0..*h).collect(); + } - if bs.len() != h { + if masks.len() != hs.len() { + return Err(SenderError::InvalidLength( + "the length of masks should be the length of hs".to_string(), + )); + } + + let bss: Vec> = masks.iter().map(|m| m.clone().bs).collect(); + + if bss.iter().zip(hs.iter()).any(|(b, h)| b.len() != *h) { return Err(SenderError::InvalidLength( "the length of b should be h".to_string(), )); } // Updates hasher. - self.state.hasher.update(&bs.to_bytes()); + self.state.hasher.update(&bss.to_bytes()); // Step 3-4, Figure 6. // Generates a GGM tree with depth h and seed s. - let s = self.state.prg.random_block(); - let ggm_tree = GgmTree::new(h); - let mut k0 = vec![Block::ZERO; h]; - let mut k1 = vec![Block::ZERO; h]; - let mut tree = vec![Block::ZERO; 1 << h]; - ggm_tree.gen(s, &mut tree, &mut k0, &mut k1); + let mut trees = vec![Vec::::new(); hs.len()]; + let mut ms_s = vec![Vec::<[Block; 2]>::new(); hs.len()]; + let mut sum_s = vec![Block::ZERO; hs.len()]; + + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")]{ + let iter = trees + .par_iter_mut().zip(hs.par_iter()) + .zip(qs_s.par_iter()) + .zip(bss.par_iter()) + .zip(ms_s.par_iter_mut()) + .zip(sum_s.par_iter_mut()) + .map(|(((((tree, h), qs), bs), ms), sum)| (tree, h, qs, bs, ms, sum)); + }else{ + let iter = trees + .iter_mut() + .zip(hs.iter()) + .zip(qs_s.iter()) + .zip(bss.iter()) + .zip(ms_s.iter_mut()) + .zip(sum_s.iter_mut()) + .map(|(((((tree, h), qs), bs), ms), sum)| (tree, h, qs, bs, ms, sum)); + } + } + + iter.for_each(|(tree, h, qs, bs, ms, sum)| { + let s = Prg::new().random_block(); + let ggm_tree = GgmTree::new(*h); + let mut k0 = vec![Block::ZERO; *h]; + let mut k1 = vec![Block::ZERO; *h]; + *tree = vec![Block::ZERO; 1 << h]; + ggm_tree.gen(s, tree, &mut k0, &mut k1); + + // Computes the sum of the leaves and delta. + *sum = tree.iter().fold(self.state.delta, |acc, &x| acc ^ x); + + // Computes M0 and M1. + for (((i, &q), b), (k0, k1)) in + qs.iter().enumerate().zip(bs).zip(k0.into_iter().zip(k1)) + { + let mut m = if *b { + [q ^ self.state.delta, q] + } else { + [q, q ^ self.state.delta] + }; + let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); + FIXED_KEY_AES.tccr_many(&[tweak, tweak], &mut m); + m[0] ^= k0; + m[1] ^= k1; + ms.push(m); + } + }); // Stores the tree, i.e., the possible output of sender. - self.state.unchecked_vs.extend_from_slice(&tree); + for tree in trees { + self.state.unchecked_vs.extend_from_slice(&tree); + } // Stores the length of this extension. - self.state.vs_length.push(1 << h); - - // Computes the sum of the leaves and delta. - let sum = tree.iter().fold(self.state.delta, |acc, &x| acc ^ x); - - // Computes M0 and M1. - let mut ms: Vec<[Block; 2]> = Vec::with_capacity(qs.len()); - for (((i, &q), b), (k0, k1)) in qs.iter().enumerate().zip(bs).zip(k0.into_iter().zip(k1)) { - let mut m = if b { - [q ^ self.state.delta, q] - } else { - [q, q ^ self.state.delta] - }; - let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); - FIXED_KEY_AES.tccr_many(&[tweak, tweak], &mut m); - m[0] ^= k0; - m[1] ^= k1; - ms.push(m); + for h in hs { + self.state.vs_length.push(1 << h); } // Updates hasher - self.state.hasher.update(&ms.to_bytes()); - self.state.hasher.update(&sum.to_bytes()); + self.state.hasher.update(&ms_s.to_bytes()); + self.state.hasher.update(&sum_s.to_bytes()); - self.state.exec_counter += 1; + self.state.exec_counter += hs.len(); + + let res: Vec = ms_s + .into_iter() + .zip(sum_s.iter()) + .map(|(ms, &sum)| ExtendFromSender { ms, sum }) + .collect(); - Ok(ExtendFromSender { ms, sum }) + Ok(res) } /// Performs the consistency check for the resulting COTs. @@ -193,10 +247,18 @@ impl Sender { res.push(tmp); } - self.state.extended = true; + self.state.hasher = blake3::Hasher::new(); + self.state.unchecked_vs.clear(); + self.state.vs_length.clear(); Ok((res, CheckFromSender { hashed_v })) } + + /// Complete extension. + #[inline] + pub fn finalize(&mut self) { + self.state.extended = true; + } } /// The sender's state. @@ -239,8 +301,6 @@ pub mod state { /// This is to prevent the receiver from extending twice pub(super) extended: bool, - /// A PRG to generate random strings. - pub(super) prg: Prg, /// A hasher to generate chi seed. pub(super) hasher: blake3::Hasher, } diff --git a/crates/mpz-ot/src/ferret/error.rs b/crates/mpz-ot/src/ferret/error.rs new file mode 100644 index 00000000..6952f0ec --- /dev/null +++ b/crates/mpz-ot/src/ferret/error.rs @@ -0,0 +1,67 @@ +use crate::OTError; + +/// A Ferret sender error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum SenderError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::error::SenderError), + #[error(transparent)] + MPCOTSenderError(#[from] crate::ferret::mpcot::SenderError), + #[error(transparent)] + RandomCOTError(#[from] OTError), + #[error("{0}")] + StateError(String), + #[error("{0}")] + MPCOTSenderTypeError(String), +} + +impl From for OTError { + fn from(err: SenderError) -> Self { + match err { + SenderError::IOError(e) => e.into(), + e => OTError::SenderError(Box::new(e)), + } + } +} + +impl From for SenderError { + fn from(err: crate::ferret::sender::StateError) -> Self { + SenderError::StateError(err.to_string()) + } +} + +/// A Ferret receiver error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ReceiverError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::error::ReceiverError), + #[error(transparent)] + MPCOTReceiverError(#[from] crate::ferret::mpcot::ReceiverError), + #[error(transparent)] + RandomCOTError(#[from] OTError), + #[error("{0}")] + StateError(String), + #[error("{0}")] + MPCOTReceiverTypeError(String), +} + +impl From for OTError { + fn from(err: ReceiverError) -> Self { + match err { + ReceiverError::IOError(e) => e.into(), + e => OTError::ReceiverError(Box::new(e)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::ferret::receiver::StateError) -> Self { + ReceiverError::StateError(err.to_string()) + } +} diff --git a/crates/mpz-ot/src/ferret/mod.rs b/crates/mpz-ot/src/ferret/mod.rs new file mode 100644 index 00000000..2b2047b9 --- /dev/null +++ b/crates/mpz-ot/src/ferret/mod.rs @@ -0,0 +1,175 @@ +//! An implementation of the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) protocol. +mod error; +mod mpcot; +mod receiver; +mod sender; +mod spcot; + +pub use error::{ReceiverError, SenderError}; +pub use receiver::Receiver; +pub use sender::Sender; + +use mpz_core::lpn::LpnParameters; +use mpz_ot_core::ferret::LpnType; + +/// Configuration of Ferret. +#[derive(Debug)] +pub struct FerretConfig { + rcot: RandomCOT, + setup_rcot: SetupRandomCOT, + lpn_parameters: LpnParameters, + lpn_type: LpnType, +} + +impl FerretConfig { + /// Create a new instance. + /// + /// # Arguments. + /// + /// * `rcot` - The rcot for MPCOT. + /// * `setup_rcot` - The rcot for setup. + /// * `lpn_parameters` - The parameters of LPN. + /// * `lpn_type` - The type of LPN. + pub fn new( + rcot: RandomCOT, + setup_rcot: SetupRandomCOT, + lpn_parameters: LpnParameters, + lpn_type: LpnType, + ) -> Self { + Self { + rcot, + setup_rcot, + lpn_parameters, + lpn_type, + } + } + + /// Get rcot + pub fn rcot(&self) -> RandomCOT { + self.rcot.clone() + } + + /// Get the setup rcot + pub fn setup_rcot(&mut self) -> &mut SetupRandomCOT { + &mut self.setup_rcot + } + + /// Get the lpn type + pub fn lpn_type(&self) -> LpnType { + self.lpn_type + } + + /// Get the lpn parameters + pub fn lpn_parameters(&self) -> LpnParameters { + self.lpn_parameters + } +} + +#[cfg(test)] +mod tests { + use futures::TryFutureExt; + use mpz_common::executor::test_st_executor; + use mpz_core::{lpn::LpnParameters, Block}; + use mpz_ot_core::{ferret::LpnType, test::assert_cot, RCOTReceiverOutput, RCOTSenderOutput}; + + use crate::{ + ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, + OTError, RandomCOTReceiver, RandomCOTSender, + }; + + use super::*; + + // l = n - k = 8380 + const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters { + n: 9600, + k: 1220, + t: 600, + }; + + fn setup() -> ( + Sender, + Receiver, + Block, + ) { + let (mut rcot_sender, rcot_receiver) = ideal_rcot(); + + let sender_config = FerretConfig::new( + rcot_sender.clone(), + rcot_sender.clone(), + LPN_PARAMETERS_TEST, + LpnType::Regular, + ); + + let receiver_config = FerretConfig::new( + rcot_receiver.clone(), + rcot_receiver, + LPN_PARAMETERS_TEST, + LpnType::Regular, + ); + + let delta = rcot_sender.alice().get_mut().delta(); + + let sender = Sender::new(sender_config); + + let receiver = Receiver::new(receiver_config); + + (sender, receiver, delta) + } + + #[tokio::test] + async fn test_ferret() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let (mut sender, mut receiver, delta) = setup(); + + tokio::try_join!( + sender + .setup_with_delta(&mut ctx_sender, delta) + .map_err(OTError::from), + receiver.setup(&mut ctx_receiver).map_err(OTError::from) + ) + .unwrap(); + + // extend once. + let count = 8000; + let ( + RCOTSenderOutput { + id: sender_id, + msgs: u, + }, + RCOTReceiverOutput { + id: receiver_id, + choices: b, + msgs: w, + }, + ) = tokio::try_join!( + sender.send_random_correlated(&mut ctx_sender, count), + receiver.receive_random_correlated(&mut ctx_receiver, count) + ) + .unwrap(); + + assert_eq!(sender_id, receiver_id); + assert_cot(delta, &b, &u, &w); + + // extend twice + let count = 9000; + let ( + RCOTSenderOutput { + id: sender_id, + msgs: u, + }, + RCOTReceiverOutput { + id: receiver_id, + choices: b, + msgs: w, + }, + ) = tokio::try_join!( + sender.send_random_correlated(&mut ctx_sender, count), + receiver.receive_random_correlated(&mut ctx_receiver, count) + ) + .unwrap(); + + assert_eq!(sender_id, receiver_id); + assert_cot(delta, &b, &u, &w); + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot/error.rs b/crates/mpz-ot/src/ferret/mpcot/error.rs new file mode 100644 index 00000000..238808d0 --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot/error.rs @@ -0,0 +1,59 @@ +use crate::OTError; + +/// A MPCOT sender error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum SenderError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::mpcot::error::SenderError), + #[error(transparent)] + SPCOTSenderError(#[from] crate::ferret::spcot::SenderError), + #[error("{0}")] + StateError(String), +} + +impl From for OTError { + fn from(err: SenderError) -> Self { + match err { + SenderError::IOError(e) => e.into(), + e => OTError::SenderError(Box::new(e)), + } + } +} + +impl From for SenderError { + fn from(err: crate::ferret::mpcot::sender::StateError) -> Self { + SenderError::StateError(err.to_string()) + } +} + +/// A MPCOT receiver error +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ReceiverError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::mpcot::error::ReceiverError), + #[error(transparent)] + SpcotReceiverError(#[from] crate::ferret::spcot::ReceiverError), + #[error("{0}")] + StateError(String), +} + +impl From for OTError { + fn from(err: ReceiverError) -> Self { + match err { + ReceiverError::IOError(e) => e.into(), + e => OTError::ReceiverError(Box::new(e)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::ferret::mpcot::receiver::StateError) -> Self { + ReceiverError::StateError(err.to_string()) + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot/mod.rs b/crates/mpz-ot/src/ferret/mpcot/mod.rs new file mode 100644 index 00000000..598b5734 --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot/mod.rs @@ -0,0 +1,165 @@ +//! Implementation of the Multiple-Point COT (mpcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. + +mod error; +mod receiver; +mod sender; + +pub(crate) use error::{ReceiverError, SenderError}; +pub(crate) use receiver::Receiver; +pub(crate) use sender::Sender; + +#[cfg(test)] +mod tests { + use futures::TryFutureExt; + use mpz_common::executor::test_st_executor; + use mpz_core::Block; + use mpz_ot_core::ferret::LpnType; + + use crate::{ + ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, + OTError, + }; + + use receiver::Receiver; + use sender::Sender; + + use super::*; + + fn setup( + lpn_type: LpnType, + ) -> ( + Sender, + Receiver, + IdealCOTSender, + IdealCOTReceiver, + Block, + ) { + let (mut rcot_sender, rcot_receiver) = ideal_rcot(); + + let delta = rcot_sender.alice().get_mut().delta(); + + let sender = Sender::new(lpn_type); + + let receiver = Receiver::new(lpn_type); + + (sender, receiver, rcot_sender, rcot_receiver, delta) + } + + #[tokio::test] + async fn test_mpcot() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let (mut sender, mut receiver, rcot_sender, rcot_receiver, delta) = setup(LpnType::Uniform); + + let alphas = [0, 1, 3, 4, 2]; + let t = alphas.len(); + let n = 10; + + tokio::try_join!( + sender + .setup_with_delta(&mut ctx_sender, delta, rcot_sender) + .map_err(OTError::from), + receiver + .setup(&mut ctx_receiver, rcot_receiver) + .map_err(OTError::from) + ) + .unwrap(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender + .extend(&mut ctx_sender, t as u32, n) + .map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, n) + .map_err(OTError::from) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + + // extend twice + let alphas = [5, 1, 7, 2]; + let t = alphas.len(); + let n = 16; + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender + .extend(&mut ctx_sender, t as u32, n) + .map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, n) + .map_err(OTError::from) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + + sender.finalize().unwrap(); + receiver.finalize().unwrap(); + + let (mut sender, mut receiver, rcot_sender, rcot_receiver, delta) = setup(LpnType::Regular); + + // extend once. + let alphas = [0, 3, 4, 7, 9]; + let t = alphas.len(); + let n = 10; + + tokio::try_join!( + sender + .setup_with_delta(&mut ctx_sender, delta, rcot_sender) + .map_err(OTError::from), + receiver + .setup(&mut ctx_receiver, rcot_receiver) + .map_err(OTError::from) + ) + .unwrap(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender + .extend(&mut ctx_sender, t as u32, n) + .map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, n) + .map_err(OTError::from) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + + // extend twice. + let alphas = [0, 3, 7, 9, 14, 15]; + let t = alphas.len(); + let n = 16; + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender + .extend(&mut ctx_sender, t as u32, n) + .map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, n) + .map_err(OTError::from) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + + sender.finalize().unwrap(); + receiver.finalize().unwrap(); + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot/receiver.rs b/crates/mpz-ot/src/ferret/mpcot/receiver.rs new file mode 100644 index 00000000..e2553efd --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot/receiver.rs @@ -0,0 +1,192 @@ +use crate::{ + ferret::{mpcot::error::ReceiverError, spcot::Receiver as SpcotReceiver}, + RandomCOTReceiver, +}; +use enum_try_as_inner::EnumTryAsInner; + +use mpz_common::Context; +use mpz_core::{prg::Prg, Block}; +use mpz_ot_core::ferret::{ + mpcot::{ + receiver::{state as uniform_state, Receiver as UniformReceiverCore}, + receiver_regular::{state as regular_state, Receiver as RegularReceiverCore}, + }, + LpnType, +}; +use serio::SinkExt; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + UniformInitialized(UniformReceiverCore), + UniformExtension(UniformReceiverCore), + RegularInitialized(RegularReceiverCore), + RegularExtension(RegularReceiverCore), + Complete, + Error, +} + +/// MPCOT receiver. +#[derive(Debug)] +pub(crate) struct Receiver { + state: State, + spcot: SpcotReceiver, + lpn_type: LpnType, +} + +impl Receiver { + /// Creates a new Sender. + /// + /// # Arguments. + /// + /// * `lpn_type` - The type of LPN. + pub(crate) fn new(lpn_type: LpnType) -> Self { + match lpn_type { + LpnType::Uniform => Self { + state: State::UniformInitialized(UniformReceiverCore::new()), + spcot: crate::ferret::spcot::Receiver::new(), + lpn_type, + }, + LpnType::Regular => Self { + state: State::RegularInitialized(RegularReceiverCore::new()), + spcot: crate::ferret::spcot::Receiver::new(), + lpn_type, + }, + } + } + + /// Performs setup for receiver. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `rcot` - The random COT used by Receiver. + pub(crate) async fn setup( + &mut self, + ctx: &mut Ctx, + rcot: RandomCOT, + ) -> Result<(), ReceiverError> { + match self.lpn_type { + LpnType::Uniform => { + let ext_receiver = std::mem::replace(&mut self.state, State::Error) + .try_into_uniform_initialized()?; + + let hash_seed = Prg::new().random_block(); + + let (ext_receiver, hash_seed) = ext_receiver.setup(hash_seed); + + ctx.io_mut().send(hash_seed).await?; + + self.state = State::UniformExtension(ext_receiver); + } + LpnType::Regular => { + let ext_receiver = std::mem::replace(&mut self.state, State::Error) + .try_into_regular_initialized()?; + + let ext_receiver = ext_receiver.setup(); + + self.state = State::RegularExtension(ext_receiver); + } + } + + self.spcot.setup(rcot)?; + + Ok(()) + } + + /// Performs MPCOT extension. + /// + /// + /// # Arguments + /// + /// * `ctx` - The context, + /// * `alphas` - The queried indices. + /// * `n` - The total number of indices. + pub(crate) async fn extend( + &mut self, + ctx: &mut Ctx, + alphas: &[u32], + n: u32, + ) -> Result, ReceiverError> + where + RandomCOT: RandomCOTReceiver, + { + let alphas_vec = alphas.to_vec(); + + match self.lpn_type { + LpnType::Uniform => { + let ext_receiver = std::mem::replace(&mut self.state, State::Error) + .try_into_uniform_extension()?; + + let (ext_receiver, h_and_pos) = + Backend::spawn(move || ext_receiver.pre_extend(&alphas_vec, n)).await?; + + let mut hs = vec![0usize; h_and_pos.len()]; + + let mut pos = vec![0u32; h_and_pos.len()]; + for (index, (h, p)) in h_and_pos.iter().enumerate() { + hs[index] = *h; + pos[index] = *p; + } + + self.spcot.extend(ctx, &pos, &hs).await?; + + let rt = self.spcot.check(ctx).await?; + + let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); + let (ext_receiver, output) = + Backend::spawn(move || ext_receiver.extend(&rt)).await?; + + self.state = State::UniformExtension(ext_receiver); + + Ok(output) + } + + LpnType::Regular => { + let ext_receiver = std::mem::replace(&mut self.state, State::Error) + .try_into_regular_extension()?; + + let (ext_receiver, h_and_pos) = + Backend::spawn(move || ext_receiver.pre_extend(&alphas_vec, n)).await?; + + let mut hs = vec![0usize; h_and_pos.len()]; + + let mut pos = vec![0u32; h_and_pos.len()]; + for (index, (h, p)) in h_and_pos.iter().enumerate() { + hs[index] = *h; + pos[index] = *p; + } + + self.spcot.extend(ctx, &pos, &hs).await?; + + let rt = self.spcot.check(ctx).await?; + + let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); + let (ext_receiver, output) = + Backend::spawn(move || ext_receiver.extend(&rt)).await?; + + self.state = State::RegularExtension(ext_receiver); + + Ok(output) + } + } + } + + /// Complete extension. + pub(crate) fn finalize(&mut self) -> Result<(), ReceiverError> { + match self.lpn_type { + LpnType::Uniform => { + std::mem::replace(&mut self.state, State::Error).try_into_uniform_extension()?; + } + LpnType::Regular => { + std::mem::replace(&mut self.state, State::Error).try_into_regular_extension()?; + } + } + + self.spcot.finalize()?; + self.state = State::Complete; + + Ok(()) + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot/sender.rs b/crates/mpz-ot/src/ferret/mpcot/sender.rs new file mode 100644 index 00000000..a0256276 --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot/sender.rs @@ -0,0 +1,166 @@ +use crate::{ + ferret::{mpcot::error::SenderError, spcot::Sender as SpcotSender}, + RandomCOTSender, +}; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot_core::ferret::{ + mpcot::{ + msgs::HashSeed, + sender::{state as uniform_state, Sender as UniformSenderCore}, + sender_regular::{state as regular_state, Sender as RegularSenderCore}, + }, + LpnType, +}; +use serio::stream::IoStreamExt; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + UniformInitialized(UniformSenderCore), + UniformExtension(UniformSenderCore), + RegularInitialized(RegularSenderCore), + RegularExtension(RegularSenderCore), + Complete, + Error, +} + +/// MPCOT sender. +#[derive(Debug)] +pub(crate) struct Sender { + state: State, + spcot: SpcotSender, + lpn_type: LpnType, +} + +impl Sender { + /// Creates a new Sender. + /// + /// # Arguments. + /// + /// * `lpn_type` - The type of LPN. + pub(crate) fn new(lpn_type: LpnType) -> Self { + match lpn_type { + LpnType::Uniform => Self { + state: State::UniformInitialized(UniformSenderCore::new()), + spcot: crate::ferret::spcot::Sender::new(), + lpn_type, + }, + LpnType::Regular => Self { + state: State::RegularInitialized(RegularSenderCore::new()), + spcot: crate::ferret::spcot::Sender::new(), + lpn_type, + }, + } + } + + /// Performs setup with provided delta. + /// + /// # Arguments + /// + /// * `ctx` - The channel. + /// * `delta` - The delta value to use for OT extension. + /// * `rcot` - The random COT used by Sender. + pub(crate) async fn setup_with_delta( + &mut self, + ctx: &mut Ctx, + delta: Block, + rcot: RandomCOT, + ) -> Result<(), SenderError> { + match self.lpn_type { + LpnType::Uniform => { + let ext_sender = std::mem::replace(&mut self.state, State::Error) + .try_into_uniform_initialized()?; + + let hash_seed: HashSeed = ctx.io_mut().expect_next().await?; + + let ext_sender = ext_sender.setup(delta, hash_seed); + + self.state = State::UniformExtension(ext_sender); + } + + LpnType::Regular => { + let ext_sender = std::mem::replace(&mut self.state, State::Error) + .try_into_regular_initialized()?; + + let ext_sender = ext_sender.setup(delta); + + self.state = State::RegularExtension(ext_sender); + } + } + + self.spcot.setup_with_delta(delta, rcot)?; + + Ok(()) + } + + /// Performs MPCOT extension. + /// + /// + /// # Arguments. + /// + /// * `ctx` - The context. + /// * `t` - The number of queried indices. + /// * `n` - The total number of indices. + pub(crate) async fn extend( + &mut self, + ctx: &mut Ctx, + t: u32, + n: u32, + ) -> Result, SenderError> + where + RandomCOT: RandomCOTSender, + { + match self.lpn_type { + LpnType::Uniform => { + let ext_sender = std::mem::replace(&mut self.state, State::Error) + .try_into_uniform_extension()?; + + let (ext_sender, hs) = Backend::spawn(move || ext_sender.pre_extend(t, n)).await?; + + self.spcot.extend(ctx, &hs).await?; + + let st = self.spcot.check(ctx).await?; + + let (ext_sender, output) = Backend::spawn(move || ext_sender.extend(&st)).await?; + + self.state = State::UniformExtension(ext_sender); + Ok(output) + } + LpnType::Regular => { + let ext_sender = std::mem::replace(&mut self.state, State::Error) + .try_into_regular_extension()?; + + let (ext_sender, hs) = Backend::spawn(move || ext_sender.pre_extend(t, n)).await?; + + self.spcot.extend(ctx, &hs).await?; + + let st = self.spcot.check(ctx).await?; + + let (ext_sender, output) = Backend::spawn(move || ext_sender.extend(&st)).await?; + + self.state = State::RegularExtension(ext_sender); + Ok(output) + } + } + } + + /// Complete extension. + pub(crate) fn finalize(&mut self) -> Result<(), SenderError> { + match self.lpn_type { + LpnType::Uniform => { + std::mem::replace(&mut self.state, State::Error).try_into_uniform_extension()?; + } + LpnType::Regular => { + std::mem::replace(&mut self.state, State::Error).try_into_regular_extension()?; + } + } + + self.spcot.finalize()?; + self.state = State::Complete; + + Ok(()) + } +} diff --git a/crates/mpz-ot/src/ferret/receiver.rs b/crates/mpz-ot/src/ferret/receiver.rs new file mode 100644 index 00000000..520506e8 --- /dev/null +++ b/crates/mpz-ot/src/ferret/receiver.rs @@ -0,0 +1,192 @@ +use crate::{ + ferret::{mpcot::Receiver as MpcotReceiver, ReceiverError}, + RandomCOTReceiver, +}; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::{prg::Prg, Block}; +use mpz_ot_core::{ + ferret::receiver::{state, Receiver as ReceiverCore}, + RCOTReceiverOutput, +}; +use serio::SinkExt; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +use super::FerretConfig; +use crate::{async_trait, OTError}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + Initialized(ReceiverCore), + Extension(ReceiverCore), + Complete, + Error, +} + +/// Ferret Receiver. +#[derive(Debug)] +pub struct Receiver { + state: State, + mpcot: MpcotReceiver, + config: FerretConfig, +} + +impl Receiver +where + RandomCOT: Send + Default + Clone, + SetupRandomCOT: Send, +{ + /// Creates a new Receiver. + /// + /// # Arguments. + /// + /// * `config` - Ferret configuration. + pub fn new(config: FerretConfig) -> Self { + Self { + state: State::Initialized(ReceiverCore::new()), + mpcot: MpcotReceiver::new(config.lpn_type()), + config, + } + } + + /// Setup for receiver. + /// + /// # Arguments. + /// + /// * `ctx` - The channel context. + pub async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), ReceiverError> + where + Ctx: Context, + SetupRandomCOT: RandomCOTReceiver, + { + let ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let rcot = self.config.rcot(); + self.mpcot.setup(ctx, rcot).await?; + + let params = self.config.lpn_parameters(); + let lpn_type = self.config.lpn_type(); + + // Get random blocks from ideal Random COT. + + let RCOTReceiverOutput { + choices: u, + msgs: w, + .. + } = self + .config + .setup_rcot() + .receive_random_correlated(ctx, params.k) + .await?; + + let seed = Prg::new().random_block(); + + let (ext_receiver, seed) = ext_receiver.setup(params, lpn_type, seed, &u, &w)?; + + ctx.io_mut().send(seed).await?; + + self.state = State::Extension(ext_receiver); + + Ok(()) + } + + /// Performs extension. + /// + /// # Arguments + /// + /// * `ctx` - The channel context. + async fn extend(&mut self, ctx: &mut Ctx) -> Result<(Vec, Vec), ReceiverError> + where + Ctx: Context, + RandomCOT: RandomCOTReceiver, + { + let mut ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + let (alphas, n) = ext_receiver.get_mpcot_query(); + + let r = self.mpcot.extend(ctx, &alphas, n as u32).await?; + + let (ext_receiver, choices, msgs) = Backend::spawn(move || { + ext_receiver + .extend(&r) + .map(|(choices, msgs)| (ext_receiver, choices, msgs)) + }) + .await?; + + self.state = State::Extension(ext_receiver); + + Ok((choices, msgs)) + } + + /// Complete extension + pub fn finalize(&mut self) -> Result<(), ReceiverError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + self.state = State::Complete; + self.mpcot.finalize()?; + + Ok(()) + } +} + +#[async_trait] +impl RandomCOTReceiver + for Receiver +where + Ctx: Context, + RandomCOT: RandomCOTReceiver + Send + Clone + Default + 'static, + SetupRandomCOT: Send + 'static, +{ + async fn receive_random_correlated( + &mut self, + ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + let (mut choices_buffer, mut msgs_buffer) = self.extend(ctx).await?; + + assert_eq!(choices_buffer.len(), msgs_buffer.len()); + + let l = choices_buffer.len(); + + let id = self + .state + .try_as_extension() + .map_err(ReceiverError::from)? + .id(); + + if count <= l { + let choices_res = choices_buffer.drain(..count).collect(); + + let msgs_res = msgs_buffer.drain(..count).collect(); + + return Ok(RCOTReceiverOutput { + id, + choices: choices_res, + msgs: msgs_res, + }); + } else { + let mut choices_res = choices_buffer; + let mut msgs_res = msgs_buffer; + + for _ in 0..count / l - 1 { + (choices_buffer, msgs_buffer) = self.extend(ctx).await?; + + choices_res.extend_from_slice(&choices_buffer); + msgs_res.extend_from_slice(&msgs_buffer); + } + + (choices_buffer, msgs_buffer) = self.extend(ctx).await?; + + choices_res.extend_from_slice(&choices_buffer[0..count % l]); + msgs_res.extend_from_slice(&msgs_buffer[0..count % l]); + + return Ok(RCOTReceiverOutput { + id, + choices: choices_res, + msgs: msgs_res, + }); + } + } +} diff --git a/crates/mpz-ot/src/ferret/sender.rs b/crates/mpz-ot/src/ferret/sender.rs new file mode 100644 index 00000000..709ff8e2 --- /dev/null +++ b/crates/mpz-ot/src/ferret/sender.rs @@ -0,0 +1,160 @@ +use crate::{ferret::mpcot::Sender as MpcotSender, RandomCOTSender}; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::sender::{state, Sender as SenderCore}, + RCOTSenderOutput, +}; +use serio::stream::IoStreamExt; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +use super::{FerretConfig, SenderError}; +use crate::{async_trait, OTError}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + Initialized(SenderCore), + Extension(SenderCore), + Complete, + Error, +} + +/// Ferret Sender. +#[derive(Debug)] +pub struct Sender { + state: State, + mpcot: MpcotSender, + config: FerretConfig, +} + +impl Sender +where + RandomCOT: Send + Default + Clone, + SetupRandomCOT: Send, +{ + /// Creates a new Sender. + pub fn new(config: FerretConfig) -> Self { + Self { + state: State::Initialized(SenderCore::new()), + mpcot: MpcotSender::new(config.lpn_type()), + config, + } + } + + /// Setup with provided delta. + /// + /// # Argument + /// + /// * `ctx` - The channel context. + /// * `delta` - The provided delta used for sender. + pub async fn setup_with_delta( + &mut self, + ctx: &mut Ctx, + delta: Block, + ) -> Result<(), SenderError> + where + Ctx: Context, + SetupRandomCOT: RandomCOTSender, + { + let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let rcot = self.config.rcot(); + + self.mpcot.setup_with_delta(ctx, delta, rcot).await?; + + let params = self.config.lpn_parameters(); + let lpn_type = self.config.lpn_type(); + + // Get random blocks from ideal Random COT. + let RCOTSenderOutput { msgs: v, .. } = self + .config + .setup_rcot() + .send_random_correlated(ctx, params.k) + .await?; + + // Get seed for LPN matrix from receiver. + let seed = ctx.io_mut().expect_next().await?; + + // Ferret core setup. + let ext_sender = ext_sender.setup(delta, params, lpn_type, seed, &v)?; + + self.state = State::Extension(ext_sender); + + Ok(()) + } + + /// Performs extension. + /// + /// # Argument + /// + /// * `ctx` - The channel context. + async fn extend(&mut self, ctx: &mut Ctx) -> Result, SenderError> + where + RandomCOT: RandomCOTSender, + { + let mut ext_sender = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + let (t, n) = ext_sender.get_mpcot_query(); + + let s = self.mpcot.extend(ctx, t, n).await?; + + let (ext_sender, output) = + Backend::spawn(move || ext_sender.extend(&s).map(|output| (ext_sender, output))) + .await?; + self.state = State::Extension(ext_sender); + + Ok(output) + } + + /// Complete extension + pub fn finalize(&mut self) -> Result<(), SenderError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + self.state = State::Complete; + self.mpcot.finalize()?; + + Ok(()) + } +} + +#[async_trait] +impl RandomCOTSender + for Sender +where + Ctx: Context, + RandomCOT: RandomCOTSender + Send + Default + Clone + 'static, + SetupRandomCOT: Send + 'static, +{ + async fn send_random_correlated( + &mut self, + ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + let mut buffer = self.extend(ctx).await?; + let l = buffer.len(); + + let id = self + .state + .try_as_extension() + .map_err(SenderError::from)? + .id(); + + if count <= l { + let res = buffer.drain(..count).collect(); + return Ok(RCOTSenderOutput { id, msgs: res }); + } else { + let mut res = buffer; + for _ in 0..count / l - 1 { + buffer = self.extend(ctx).await?; + res.extend_from_slice(&buffer); + } + + buffer = self.extend(ctx).await?; + res.extend_from_slice(&buffer[0..count % l]); + + return Ok(RCOTSenderOutput { id, msgs: res }); + } + } +} diff --git a/crates/mpz-ot/src/ferret/spcot/error.rs b/crates/mpz-ot/src/ferret/spcot/error.rs new file mode 100644 index 00000000..0fa9dc9c --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot/error.rs @@ -0,0 +1,59 @@ +use crate::OTError; + +/// A SPCOT sender error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum SenderError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::spcot::error::SenderError), + #[error(transparent)] + RandomCOTError(#[from] OTError), + #[error("{0}")] + StateError(String), +} + +impl From for OTError { + fn from(err: SenderError) -> Self { + match err { + SenderError::IOError(e) => e.into(), + e => OTError::SenderError(Box::new(e)), + } + } +} + +impl From for SenderError { + fn from(err: crate::ferret::spcot::sender::StateError) -> Self { + SenderError::StateError(err.to_string()) + } +} + +/// A SPCOT receiver error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ReceiverError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::spcot::error::ReceiverError), + #[error(transparent)] + RandomCOTError(#[from] OTError), + #[error("{0}")] + StateError(String), +} + +impl From for OTError { + fn from(err: ReceiverError) -> Self { + match err { + ReceiverError::IOError(e) => e.into(), + e => OTError::ReceiverError(Box::new(e)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::ferret::spcot::receiver::StateError) -> Self { + ReceiverError::StateError(err.to_string()) + } +} diff --git a/crates/mpz-ot/src/ferret/spcot/mod.rs b/crates/mpz-ot/src/ferret/spcot/mod.rs new file mode 100644 index 00000000..6e53fd28 --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot/mod.rs @@ -0,0 +1,103 @@ +//! Implementation of the Single-Point COT (spcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. + +mod error; +mod receiver; +mod sender; + +pub(crate) use error::{ReceiverError, SenderError}; +pub(crate) use receiver::Receiver; +pub(crate) use sender::Sender; + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, + OTError, + }; + use futures::TryFutureExt; + use mpz_common::executor::test_st_executor; + use mpz_core::Block; + + fn setup() -> ( + Sender, + Receiver, + IdealCOTSender, + IdealCOTReceiver, + Block, + ) { + let (mut rcot_sender, rcot_receiver) = ideal_rcot(); + + let delta = rcot_sender.alice().get_mut().delta(); + + let sender = Sender::new(); + let receiver = Receiver::new(); + + (sender, receiver, rcot_sender, rcot_receiver, delta) + } + + #[tokio::test] + async fn test_spcot() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let (mut sender, mut receiver, rcot_sender, rcot_receiver, delta) = setup(); + + // shold set the same delta as in RCOT. + sender.setup_with_delta(delta, rcot_sender).unwrap(); + receiver.setup(rcot_receiver).unwrap(); + + let hs = [8, 4]; + let alphas = [4, 2]; + + tokio::try_join!( + sender.extend(&mut ctx_sender, &hs).map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, &hs) + .map_err(OTError::from) + ) + .unwrap(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender.check(&mut ctx_sender).map_err(OTError::from), + receiver.check(&mut ctx_receiver).map_err(OTError::from) + ) + .unwrap(); + + assert!(output_sender + .iter_mut() + .zip(output_receiver.iter()) + .all(|(vs, (ws, alpha))| { + vs[*alpha as usize] ^= delta; + vs == ws + })); + + // extend twice. + let hs = [6, 9, 8]; + let alphas = [2, 1, 3]; + + tokio::try_join!( + sender.extend(&mut ctx_sender, &hs).map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, &hs) + .map_err(OTError::from) + ) + .unwrap(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender.check(&mut ctx_sender).map_err(OTError::from), + receiver.check(&mut ctx_receiver).map_err(OTError::from) + ) + .unwrap(); + + assert!(output_sender + .iter_mut() + .zip(output_receiver.iter()) + .all(|(vs, (ws, alpha))| { + vs[*alpha as usize] ^= delta; + vs == ws + })); + + sender.finalize().unwrap(); + receiver.finalize().unwrap(); + } +} diff --git a/crates/mpz-ot/src/ferret/spcot/receiver.rs b/crates/mpz-ot/src/ferret/spcot/receiver.rs new file mode 100644 index 00000000..3c48bfad --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot/receiver.rs @@ -0,0 +1,164 @@ +use crate::{ferret::spcot::error::ReceiverError, RandomCOTReceiver}; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::{ + spcot::{ + msgs::ExtendFromSender, + receiver::{state, Receiver as ReceiverCore}, + }, + CSP, + }, + RCOTReceiverOutput, +}; +use serio::{stream::IoStreamExt, SinkExt}; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + Initialized(ReceiverCore), + Extension(Box>), + Complete, + Error, +} + +/// SPCOT Receiver. +#[derive(Debug)] +pub(crate) struct Receiver { + state: State, + rcot: RandomCOT, +} + +impl Receiver { + /// Creates a new Receiver. + pub(crate) fn new() -> Self { + Self { + state: State::Initialized(ReceiverCore::new()), + rcot: Default::default(), + } + } + + /// Performs setup for receiver. + /// + /// # Arguments. + /// + /// * `rcot` - The random COT used by the receiver. + pub(crate) fn setup(&mut self, rcot: RandomCOT) -> Result<(), ReceiverError> { + let ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let ext_receiver = ext_receiver.setup(); + self.state = State::Extension(Box::new(ext_receiver)); + self.rcot = rcot; + Ok(()) + } + + /// Performs spcot extension for receiver. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `alphas`` - The vector of chosen positions. + /// * `h` - The depth of GGM tree. + pub(crate) async fn extend( + &mut self, + ctx: &mut Ctx, + alphas: &[u32], + hs: &[usize], + ) -> Result<(), ReceiverError> + where + RandomCOT: RandomCOTReceiver, + { + let mut ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + let h = hs.iter().sum(); + let RCOTReceiverOutput { + choices: rss, + msgs: tss, + .. + } = self.rcot.receive_random_correlated(ctx, h).await?; + + // extend + let h_in = hs.to_vec(); + let alphas_in = alphas.to_vec(); + let (mut ext_receiver, masks) = Backend::spawn(move || { + ext_receiver + .extend_mask_bits(&h_in, &alphas_in, &rss) + .map(|mask| (ext_receiver, mask)) + }) + .await?; + + ctx.io_mut().send(masks).await?; + + let extendfss: Vec = ctx.io_mut().expect_next().await?; + + let h_in = hs.to_vec(); + let alphas_in = alphas.to_vec(); + let ext_receiver = Backend::spawn(move || { + ext_receiver + .extend(&h_in, &alphas_in, &tss, &extendfss) + .map(|_| ext_receiver) + }) + .await?; + + self.state = State::Extension(ext_receiver); + + Ok(()) + } + + /// Performs batch check for SPCOT extension. + /// + /// # Arguments + /// + /// * `ctx` - The context. + pub(crate) async fn check( + &mut self, + ctx: &mut Ctx, + ) -> Result, u32)>, ReceiverError> + where + RandomCOT: RandomCOTReceiver, + { + let mut ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + // batch check + let RCOTReceiverOutput { + choices: x_star, + msgs: z_star, + .. + } = self.rcot.receive_random_correlated(ctx, CSP).await?; + + let (mut ext_receiver, checkfr) = Backend::spawn(move || { + ext_receiver + .check_pre(&x_star) + .map(|checkfr| (ext_receiver, checkfr)) + }) + .await?; + + ctx.io_mut().send(checkfr).await?; + let check = ctx.io_mut().expect_next().await?; + + let (ext_receiver, output) = Backend::spawn(move || { + ext_receiver + .check(&z_star, check) + .map(|output| (ext_receiver, output)) + }) + .await?; + + self.state = State::Extension(ext_receiver); + + Ok(output) + } + + /// Complete extension. + pub(crate) fn finalize(&mut self) -> Result<(), ReceiverError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + self.state = State::Complete; + + Ok(()) + } +} diff --git a/crates/mpz-ot/src/ferret/spcot/sender.rs b/crates/mpz-ot/src/ferret/spcot/sender.rs new file mode 100644 index 00000000..9178b787 --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot/sender.rs @@ -0,0 +1,144 @@ +use crate::{ferret::spcot::error::SenderError, RandomCOTSender}; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::{ + spcot::{ + msgs::MaskBits, + sender::{state, Sender as SenderCore}, + }, + CSP, + }, + RCOTSenderOutput, +}; +use serio::{stream::IoStreamExt, SinkExt}; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + Initialized(SenderCore), + Extension(Box>), + Complete, + Error, +} + +/// SPCOT sender. +#[derive(Debug)] +pub(crate) struct Sender { + state: State, + rcot: RandomCOT, +} + +impl Sender { + /// Creates a new Sender. + pub(crate) fn new() -> Self { + Self { + state: State::Initialized(SenderCore::new()), + rcot: Default::default(), + } + } + + /// Performs setup with the provided delta. + /// + /// # Arguments + /// + /// * `delta` - The delta value to use for OT extension. + /// * `rcot` - The random COT used by the sender. + pub(crate) fn setup_with_delta( + &mut self, + delta: Block, + rcot: RandomCOT, + ) -> Result<(), SenderError> { + let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let ext_sender = ext_sender.setup(delta); + + self.state = State::Extension(Box::new(ext_sender)); + self.rcot = rcot; + Ok(()) + } + + /// Performs spcot extension for sender. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `hs` - The depths of GGM trees. + pub(crate) async fn extend( + &mut self, + ctx: &mut Ctx, + hs: &[usize], + ) -> Result<(), SenderError> + where + RandomCOT: RandomCOTSender, + { + let mut ext_sender = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + let h = hs.iter().sum(); + let RCOTSenderOutput { msgs: qss, .. } = self.rcot.send_random_correlated(ctx, h).await?; + + let masks: Vec = ctx.io_mut().expect_next().await?; + + // extend + let h_in = hs.to_vec(); + let (ext_sender, extend_msg) = Backend::spawn(move || { + ext_sender + .extend(&h_in, &qss, &masks) + .map(|extend_msg| (ext_sender, extend_msg)) + }) + .await?; + + ctx.io_mut().send(extend_msg).await?; + + self.state = State::Extension(ext_sender); + + Ok(()) + } + + /// Performs batch check for SPCOT extension. + /// + /// # Arguments + /// + /// * `ctx` - The context. + pub(crate) async fn check( + &mut self, + ctx: &mut Ctx, + ) -> Result>, SenderError> + where + RandomCOT: RandomCOTSender, + { + let mut ext_sender = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + // batch check + let RCOTSenderOutput { msgs: y_star, .. } = + self.rcot.send_random_correlated(ctx, CSP).await?; + + let checkfr = ctx.io_mut().expect_next().await?; + + let (ext_sender, output, check_msg) = Backend::spawn(move || { + ext_sender + .check(&y_star, checkfr) + .map(|(output, check_msg)| (ext_sender, output, check_msg)) + }) + .await?; + + ctx.io_mut().send(check_msg).await?; + + self.state = State::Extension(ext_sender); + + Ok(output) + } + + /// Complete extension. + pub(crate) fn finalize(&mut self) -> Result<(), SenderError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + self.state = State::Complete; + + Ok(()) + } +} diff --git a/crates/mpz-ot/src/ideal/cot.rs b/crates/mpz-ot/src/ideal/cot.rs index b0084957..18233dfe 100644 --- a/crates/mpz-ot/src/ideal/cot.rs +++ b/crates/mpz-ot/src/ideal/cot.rs @@ -46,9 +46,16 @@ pub fn ideal_rcot() -> (IdealCOTSender, IdealCOTReceiver) { } /// Ideal COT sender. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct IdealCOTSender(Alice); +impl IdealCOTSender { + /// Returns Alice. + pub fn alice(&mut self) -> &mut Alice { + &mut self.0 + } +} + #[async_trait] impl OTSetup for IdealCOTSender where @@ -98,7 +105,7 @@ impl RandomCOTSender for IdealCOTSender { } /// Ideal COT receiver. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct IdealCOTReceiver(Bob); #[async_trait] diff --git a/crates/mpz-ot/src/lib.rs b/crates/mpz-ot/src/lib.rs index b9871eab..0e4d1b48 100644 --- a/crates/mpz-ot/src/lib.rs +++ b/crates/mpz-ot/src/lib.rs @@ -10,6 +10,7 @@ )] pub mod chou_orlandi; +pub mod ferret; #[cfg(any(test, feature = "ideal"))] pub mod ideal; pub mod kos; From ef72b9b88820f789400d14996c0ba24745a007a4 Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Sat, 29 Jun 2024 00:11:59 +0800 Subject: [PATCH 2/6] cargo clippy --- crates/mpz-ot/src/ferret/mpcot/error.rs | 4 ++-- crates/mpz-ot/src/ferret/spcot/error.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/mpz-ot/src/ferret/mpcot/error.rs b/crates/mpz-ot/src/ferret/mpcot/error.rs index 238808d0..e300bf0d 100644 --- a/crates/mpz-ot/src/ferret/mpcot/error.rs +++ b/crates/mpz-ot/src/ferret/mpcot/error.rs @@ -2,7 +2,7 @@ use crate::OTError; /// A MPCOT sender error. #[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] +#[allow(missing_docs, clippy::enum_variant_names)] pub enum SenderError { #[error(transparent)] IOError(#[from] std::io::Error), @@ -31,7 +31,7 @@ impl From for SenderError { /// A MPCOT receiver error #[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] +#[allow(missing_docs, clippy::enum_variant_names)] pub enum ReceiverError { #[error(transparent)] IOError(#[from] std::io::Error), diff --git a/crates/mpz-ot/src/ferret/spcot/error.rs b/crates/mpz-ot/src/ferret/spcot/error.rs index 0fa9dc9c..5f23f466 100644 --- a/crates/mpz-ot/src/ferret/spcot/error.rs +++ b/crates/mpz-ot/src/ferret/spcot/error.rs @@ -2,7 +2,7 @@ use crate::OTError; /// A SPCOT sender error. #[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] +#[allow(missing_docs, clippy::enum_variant_names)] pub enum SenderError { #[error(transparent)] IOError(#[from] std::io::Error), @@ -31,7 +31,7 @@ impl From for SenderError { /// A SPCOT receiver error. #[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] +#[allow(missing_docs, clippy::enum_variant_names)] pub enum ReceiverError { #[error(transparent)] IOError(#[from] std::io::Error), From 116035e929c217aae106ed253d731f169781227f Mon Sep 17 00:00:00 2001 From: "sinu.eth" <65924192+sinui0@users.noreply.github.com> Date: Thu, 15 Aug 2024 19:08:06 -0700 Subject: [PATCH 3/6] refactor(mpz-ot): ferret clean up (#173) * refactor(mpz-ot): ferret clean up * buffer OTs, setup rcot only invoked once * fix mpcot test --------- Co-authored-by: Xiang Xie --- crates/mpz-common/src/ideal.rs | 4 +- crates/mpz-ot-core/src/ferret/mod.rs | 30 +- crates/mpz-ot-core/src/ferret/receiver.rs | 51 ++- crates/mpz-ot-core/src/ferret/sender.rs | 49 ++- crates/mpz-ot-core/src/lib.rs | 2 +- crates/mpz-ot/src/ferret/error.rs | 375 ++++++++++++++++++--- crates/mpz-ot/src/ferret/mod.rs | 123 ++----- crates/mpz-ot/src/ferret/mpcot.rs | 185 ++++++++++ crates/mpz-ot/src/ferret/mpcot/error.rs | 59 ---- crates/mpz-ot/src/ferret/mpcot/mod.rs | 165 --------- crates/mpz-ot/src/ferret/mpcot/receiver.rs | 192 ----------- crates/mpz-ot/src/ferret/mpcot/sender.rs | 166 --------- crates/mpz-ot/src/ferret/receiver.rs | 265 +++++++++------ crates/mpz-ot/src/ferret/sender.rs | 298 +++++++++++----- crates/mpz-ot/src/ferret/spcot.rs | 167 +++++++++ crates/mpz-ot/src/ferret/spcot/error.rs | 59 ---- crates/mpz-ot/src/ferret/spcot/mod.rs | 103 ------ crates/mpz-ot/src/ferret/spcot/receiver.rs | 164 --------- crates/mpz-ot/src/ferret/spcot/sender.rs | 144 -------- crates/mpz-ot/src/ideal/cot.rs | 16 +- crates/mpz-ot/src/lib.rs | 13 +- 21 files changed, 1227 insertions(+), 1403 deletions(-) create mode 100644 crates/mpz-ot/src/ferret/mpcot.rs delete mode 100644 crates/mpz-ot/src/ferret/mpcot/error.rs delete mode 100644 crates/mpz-ot/src/ferret/mpcot/mod.rs delete mode 100644 crates/mpz-ot/src/ferret/mpcot/receiver.rs delete mode 100644 crates/mpz-ot/src/ferret/mpcot/sender.rs create mode 100644 crates/mpz-ot/src/ferret/spcot.rs delete mode 100644 crates/mpz-ot/src/ferret/spcot/error.rs delete mode 100644 crates/mpz-ot/src/ferret/spcot/mod.rs delete mode 100644 crates/mpz-ot/src/ferret/spcot/receiver.rs delete mode 100644 crates/mpz-ot/src/ferret/spcot/sender.rs diff --git a/crates/mpz-common/src/ideal.rs b/crates/mpz-common/src/ideal.rs index 7fcb1628..1b6b3181 100644 --- a/crates/mpz-common/src/ideal.rs +++ b/crates/mpz-common/src/ideal.rs @@ -35,7 +35,7 @@ impl Clone for Alice { impl Alice { /// Returns a lock to the ideal functionality. - pub fn get_mut(&mut self) -> MutexGuard<'_, F> { + pub fn lock(&self) -> MutexGuard<'_, F> { self.f.lock().unwrap() } @@ -96,7 +96,7 @@ impl Clone for Bob { impl Bob { /// Returns a lock to the ideal functionality. - pub fn get_mut(&mut self) -> MutexGuard<'_, F> { + pub fn lock(&self) -> MutexGuard<'_, F> { self.f.lock().unwrap() } diff --git a/crates/mpz-ot-core/src/ferret/mod.rs b/crates/mpz-ot-core/src/ferret/mod.rs index bbbf264a..0e27f0a9 100644 --- a/crates/mpz-ot-core/src/ferret/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mod.rs @@ -52,9 +52,11 @@ mod tests { use receiver::Receiver; use sender::Sender; - use crate::ideal::{cot::IdealCOT, mpcot::IdealMpcot}; - use crate::test::assert_cot; - use crate::{MPCOTReceiverOutput, MPCOTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput}; + use crate::{ + ideal::{cot::IdealCOT, mpcot::IdealMpcot}, + test::assert_cot, + MPCOTReceiverOutput, MPCOTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput, + }; use mpz_core::{lpn::LpnParameters, prg::Prg}; const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters { @@ -111,8 +113,15 @@ mod tests { let (MPCOTSenderOutput { s, .. }, MPCOTReceiverOutput { r, .. }) = ideal_mpcot.extend(&query.0, query.1); - let msgs = sender.extend(&s).unwrap(); - let (choices, received) = receiver.extend(&r).unwrap(); + sender.extend(s).unwrap(); + receiver.extend(r).unwrap(); + + let RCOTSenderOutput { msgs, .. } = sender.consume(2).unwrap(); + let RCOTReceiverOutput { + choices, + msgs: received, + .. + } = receiver.consume(2).unwrap(); assert_cot(delta, &choices, &msgs, &received); @@ -123,8 +132,15 @@ mod tests { let (MPCOTSenderOutput { s, .. }, MPCOTReceiverOutput { r, .. }) = ideal_mpcot.extend(&query.0, query.1); - let msgs = sender.extend(&s).unwrap(); - let (choices, received) = receiver.extend(&r).unwrap(); + sender.extend(s).unwrap(); + receiver.extend(r).unwrap(); + + let RCOTSenderOutput { msgs, .. } = sender.consume(sender.remaining()).unwrap(); + let RCOTReceiverOutput { + choices, + msgs: received, + .. + } = receiver.consume(receiver.remaining()).unwrap(); assert_cot(delta, &choices, &msgs, &received); } diff --git a/crates/mpz-ot-core/src/ferret/receiver.rs b/crates/mpz-ot-core/src/ferret/receiver.rs index e5939c60..1cfd1e08 100644 --- a/crates/mpz-ot-core/src/ferret/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/receiver.rs @@ -1,4 +1,6 @@ //! Ferret receiver +use std::collections::VecDeque; + use mpz_core::{ lpn::{LpnEncoder, LpnParameters}, Block, @@ -6,7 +8,7 @@ use mpz_core::{ use crate::{ ferret::{error::ReceiverError, LpnType}, - TransferId, + RCOTReceiverOutput, TransferId, }; use super::msgs::LpnMatrixSeed; @@ -63,6 +65,8 @@ impl Receiver { w: w.to_vec(), e: Vec::default(), id: TransferId::default(), + choices_buffer: VecDeque::new(), + msgs_buffer: VecDeque::new(), }, }, LpnMatrixSeed { seed }, @@ -71,6 +75,16 @@ impl Receiver { } impl Receiver { + /// Returns the current transfer id. + pub fn id(&self) -> TransferId { + self.state.id + } + + /// Returns the number of remaining COTs. + pub fn remaining(&self) -> usize { + self.state.choices_buffer.len() + } + /// The prepare precedure of extension, sample error vectors and outputs information for MPCOT. /// See step 3 and 4. pub fn get_mpcot_query(&mut self) -> (Vec, usize) { @@ -100,7 +114,7 @@ impl Receiver { /// # Arguments. /// /// * `r` - The vector received from the MPCOT protocol. - pub fn extend(&mut self, r: &[Block]) -> Result<(Vec, Vec), ReceiverError> { + pub fn extend(&mut self, r: Vec) -> Result<(), ReceiverError> { if r.len() != self.state.lpn_parameters.n { return Err(ReceiverError("the length of r should be n".to_string())); } @@ -108,7 +122,7 @@ impl Receiver { self.state.id.next(); // Compute z = A * w + r. - let mut z = r.to_vec(); + let mut z = r; self.state.lpn_encoder.compute(&mut z, &self.state.w); // Compute x = A * u + e. @@ -133,12 +147,32 @@ impl Receiver { // Update counter self.state.counter += 1; - Ok((x_, z_)) + self.state.choices_buffer.extend(x_); + self.state.msgs_buffer.extend(z_); + + Ok(()) } - /// Returns id - pub fn id(&self) -> TransferId { - self.state.id + /// Consumes `count` COTs. + pub fn consume( + &mut self, + count: usize, + ) -> Result, ReceiverError> { + if count > self.state.choices_buffer.len() { + return Err(ReceiverError(format!( + "insufficient OTs: {} < {count}", + self.state.choices_buffer.len() + ))); + } + + let choices = self.state.choices_buffer.drain(0..count).collect(); + let msgs = self.state.msgs_buffer.drain(0..count).collect(); + + Ok(RCOTReceiverOutput { + id: self.state.id.next(), + choices, + msgs, + }) } } @@ -186,6 +220,9 @@ pub mod state { /// TransferID pub(super) id: TransferId, + /// Extended OTs buffers. + pub(super) choices_buffer: VecDeque, + pub(super) msgs_buffer: VecDeque, } impl State for Extension {} diff --git a/crates/mpz-ot-core/src/ferret/sender.rs b/crates/mpz-ot-core/src/ferret/sender.rs index 2af3e4ae..436d6003 100644 --- a/crates/mpz-ot-core/src/ferret/sender.rs +++ b/crates/mpz-ot-core/src/ferret/sender.rs @@ -1,4 +1,6 @@ //! Ferret sender. +use std::collections::VecDeque; + use mpz_core::{ lpn::{LpnEncoder, LpnParameters}, Block, @@ -6,7 +8,7 @@ use mpz_core::{ use crate::{ ferret::{error::SenderError, LpnType}, - TransferId, + RCOTSenderOutput, TransferId, }; use super::msgs::LpnMatrixSeed; @@ -61,12 +63,28 @@ impl Sender { lpn_encoder, v: v.to_vec(), id: TransferId::default(), + msgs_buffer: VecDeque::new(), }, }) } } impl Sender { + /// Returns the current transfer id. + pub fn id(&self) -> TransferId { + self.state.id + } + + /// Returns the number of remaining COTs. + pub fn remaining(&self) -> usize { + self.state.msgs_buffer.len() + } + + /// Returns the delta correlation. + pub fn delta(&self) -> Block { + self.state.delta + } + /// Outputs the information for MPCOT. /// /// See step 3 and 4. @@ -86,7 +104,7 @@ impl Sender { /// # Arguments. /// /// * `s` - The vector received from the MPCOT protocol. - pub fn extend(&mut self, s: &[Block]) -> Result, SenderError> { + pub fn extend(&mut self, s: Vec) -> Result<(), SenderError> { if s.len() != self.state.lpn_parameters.n { return Err(SenderError("the length of s should be n".to_string())); } @@ -94,7 +112,7 @@ impl Sender { self.state.id.next(); // Compute y = A * v + s - let mut y = s.to_vec(); + let mut y = s; self.state.lpn_encoder.compute(&mut y, &self.state.v); let y_ = y.split_off(self.state.lpn_parameters.k); @@ -104,13 +122,26 @@ impl Sender { // Update counter self.state.counter += 1; + self.state.msgs_buffer.extend(y_); - Ok(y_) + Ok(()) } - /// Returns id - pub fn id(&self) -> TransferId { - self.state.id + /// Consumes `count` COTs. + pub fn consume(&mut self, count: usize) -> Result, SenderError> { + if count > self.state.msgs_buffer.len() { + return Err(SenderError(format!( + "insufficient OTs: {} < {count}", + self.state.msgs_buffer.len() + ))); + } + + let msgs = self.state.msgs_buffer.drain(0..count).collect(); + + Ok(RCOTSenderOutput { + id: self.state.id.next(), + msgs, + }) } } @@ -159,8 +190,10 @@ pub mod state { /// Sender's COT message in the setup phase. pub(super) v: Vec, - /// TransferID. + /// Transfer ID. pub(crate) id: TransferId, + /// COT messages buffer. + pub(super) msgs_buffer: VecDeque, } impl State for Extension {} diff --git a/crates/mpz-ot-core/src/lib.rs b/crates/mpz-ot-core/src/lib.rs index 8dd77287..dcfffc59 100644 --- a/crates/mpz-ot-core/src/lib.rs +++ b/crates/mpz-ot-core/src/lib.rs @@ -45,7 +45,7 @@ impl std::fmt::Display for TransferId { impl TransferId { /// Returns the current transfer ID, incrementing `self` in-place. - pub(crate) fn next(&mut self) -> Self { + pub fn next(&mut self) -> Self { let id = *self; self.0 += 1; id diff --git a/crates/mpz-ot/src/ferret/error.rs b/crates/mpz-ot/src/ferret/error.rs index 6952f0ec..4e428a4b 100644 --- a/crates/mpz-ot/src/ferret/error.rs +++ b/crates/mpz-ot/src/ferret/error.rs @@ -1,67 +1,342 @@ -use crate::OTError; +use std::fmt::Display; -/// A Ferret sender error. +/// Ferret sender error. #[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum SenderError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::ferret::error::SenderError), - #[error(transparent)] - MPCOTSenderError(#[from] crate::ferret::mpcot::SenderError), - #[error(transparent)] - RandomCOTError(#[from] OTError), - #[error("{0}")] - StateError(String), - #[error("{0}")] - MPCOTSenderTypeError(String), -} - -impl From for OTError { - fn from(err: SenderError) -> Self { - match err { - SenderError::IOError(e) => e.into(), - e => OTError::SenderError(Box::new(e)), +pub struct SenderError { + kind: SenderErrorKind, + #[source] + source: Option>, +} + +impl SenderError { + pub(crate) fn state(msg: impl Into) -> Self { + Self { + kind: SenderErrorKind::State, + source: Some(msg.into().into()), + } + } + + pub(crate) fn io(msg: impl Into) -> Self { + Self { + kind: SenderErrorKind::Io, + source: Some(msg.into().into()), + } + } +} + +#[derive(Debug)] +enum SenderErrorKind { + Io, + State, + Core, + Rcot, + Mpcot, +} + +impl Display for SenderError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + SenderErrorKind::Io => f.write_str("io error")?, + SenderErrorKind::State => f.write_str("state error")?, + SenderErrorKind::Core => f.write_str("core error")?, + SenderErrorKind::Rcot => f.write_str("rcot error")?, + SenderErrorKind::Mpcot => f.write_str("mpcot error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } +} + +impl From for SenderError { + fn from(err: std::io::Error) -> Self { + Self { + kind: SenderErrorKind::Io, + source: Some(Box::new(err)), + } + } +} + +impl From for SenderError { + fn from(err: mpz_ot_core::ferret::error::SenderError) -> Self { + Self { + kind: SenderErrorKind::Core, + source: Some(Box::new(err)), + } + } +} + +impl From for SenderError { + fn from(err: crate::OTError) -> Self { + Self { + kind: SenderErrorKind::Rcot, + source: Some(Box::new(err)), + } + } +} + +impl From for SenderError { + fn from(err: MPCOTError) -> Self { + Self { + kind: SenderErrorKind::Mpcot, + source: Some(Box::new(err)), } } } -impl From for SenderError { - fn from(err: crate::ferret::sender::StateError) -> Self { - SenderError::StateError(err.to_string()) +impl From for crate::OTError { + fn from(err: SenderError) -> Self { + crate::OTError::SenderError(Box::new(err)) } } -/// A Ferret receiver error. +/// Ferret receiver error. #[derive(Debug, thiserror::Error)] -#[allow(missing_docs)] -pub enum ReceiverError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::ferret::error::ReceiverError), - #[error(transparent)] - MPCOTReceiverError(#[from] crate::ferret::mpcot::ReceiverError), - #[error(transparent)] - RandomCOTError(#[from] OTError), - #[error("{0}")] - StateError(String), - #[error("{0}")] - MPCOTReceiverTypeError(String), -} - -impl From for OTError { +pub struct ReceiverError { + kind: ReceiverErrorKind, + #[source] + source: Option>, +} + +impl ReceiverError { + pub(crate) fn state(msg: impl Into) -> Self { + Self { + kind: ReceiverErrorKind::State, + source: Some(msg.into().into()), + } + } + + pub(crate) fn io(msg: impl Into) -> Self { + Self { + kind: ReceiverErrorKind::Io, + source: Some(msg.into().into()), + } + } +} + +#[derive(Debug)] +enum ReceiverErrorKind { + Io, + State, + Core, + Rcot, + Mpcot, +} + +impl Display for ReceiverError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + ReceiverErrorKind::Io => f.write_str("io error")?, + ReceiverErrorKind::State => f.write_str("state error")?, + ReceiverErrorKind::Core => f.write_str("core error")?, + ReceiverErrorKind::Rcot => f.write_str("rcot error")?, + ReceiverErrorKind::Mpcot => f.write_str("mpcot error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } +} + +impl From for ReceiverError { + fn from(err: std::io::Error) -> Self { + Self { + kind: ReceiverErrorKind::Io, + source: Some(Box::new(err)), + } + } +} + +impl From for ReceiverError { + fn from(err: mpz_ot_core::ferret::error::ReceiverError) -> Self { + Self { + kind: ReceiverErrorKind::Core, + source: Some(Box::new(err)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::OTError) -> Self { + Self { + kind: ReceiverErrorKind::Rcot, + source: Some(Box::new(err)), + } + } +} + +impl From for ReceiverError { + fn from(err: MPCOTError) -> Self { + Self { + kind: ReceiverErrorKind::Mpcot, + source: Some(Box::new(err)), + } + } +} + +impl From for crate::OTError { fn from(err: ReceiverError) -> Self { - match err { - ReceiverError::IOError(e) => e.into(), - e => OTError::ReceiverError(Box::new(e)), + crate::OTError::ReceiverError(Box::new(err)) + } +} + +mod mpcot { + use super::*; + + /// MPCOT error. + #[derive(Debug, thiserror::Error)] + pub(crate) struct MPCOTError { + kind: ErrorKind, + #[source] + source: Option>, + } + + #[derive(Debug)] + enum ErrorKind { + Io, + Core, + Rcot, + Spcot, + } + + impl Display for MPCOTError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + ErrorKind::Io => f.write_str("io error")?, + ErrorKind::Core => f.write_str("core error")?, + ErrorKind::Rcot => f.write_str("rcot error")?, + ErrorKind::Spcot => f.write_str("spcot error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } + } + + impl From for MPCOTError { + fn from(err: std::io::Error) -> Self { + Self { + kind: ErrorKind::Io, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: mpz_ot_core::ferret::mpcot::error::SenderError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: mpz_ot_core::ferret::mpcot::error::ReceiverError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: SPCOTError) -> Self { + Self { + kind: ErrorKind::Spcot, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: crate::OTError) -> Self { + Self { + kind: ErrorKind::Rcot, + source: Some(Box::new(err)), + } } } } +pub(crate) use mpcot::MPCOTError; + +mod spcot { + use super::*; + + /// SPCOT error. + #[derive(Debug, thiserror::Error)] + pub(crate) struct SPCOTError { + kind: ErrorKind, + #[source] + source: Option>, + } + + #[derive(Debug)] + enum ErrorKind { + Io, + Core, + Rcot, + } + + impl Display for SPCOTError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + ErrorKind::Io => f.write_str("io error")?, + ErrorKind::Core => f.write_str("core error")?, + ErrorKind::Rcot => f.write_str("rcot error")?, + } -impl From for ReceiverError { - fn from(err: crate::ferret::receiver::StateError) -> Self { - ReceiverError::StateError(err.to_string()) + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } + } + + impl From for SPCOTError { + fn from(err: std::io::Error) -> Self { + Self { + kind: ErrorKind::Io, + source: Some(Box::new(err)), + } + } + } + + impl From for SPCOTError { + fn from(err: mpz_ot_core::ferret::spcot::error::SenderError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for SPCOTError { + fn from(err: mpz_ot_core::ferret::spcot::error::ReceiverError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for SPCOTError { + fn from(err: crate::OTError) -> Self { + Self { + kind: ErrorKind::Rcot, + source: Some(Box::new(err)), + } + } } } +pub(crate) use spcot::SPCOTError; diff --git a/crates/mpz-ot/src/ferret/mod.rs b/crates/mpz-ot/src/ferret/mod.rs index 2b2047b9..086e5e8b 100644 --- a/crates/mpz-ot/src/ferret/mod.rs +++ b/crates/mpz-ot/src/ferret/mod.rs @@ -13,47 +13,26 @@ use mpz_core::lpn::LpnParameters; use mpz_ot_core::ferret::LpnType; /// Configuration of Ferret. -#[derive(Debug)] -pub struct FerretConfig { - rcot: RandomCOT, - setup_rcot: SetupRandomCOT, +#[derive(Debug, Clone)] +pub struct FerretConfig { lpn_parameters: LpnParameters, lpn_type: LpnType, } -impl FerretConfig { +impl FerretConfig { /// Create a new instance. /// /// # Arguments. /// - /// * `rcot` - The rcot for MPCOT. - /// * `setup_rcot` - The rcot for setup. /// * `lpn_parameters` - The parameters of LPN. /// * `lpn_type` - The type of LPN. - pub fn new( - rcot: RandomCOT, - setup_rcot: SetupRandomCOT, - lpn_parameters: LpnParameters, - lpn_type: LpnType, - ) -> Self { + pub fn new(lpn_parameters: LpnParameters, lpn_type: LpnType) -> Self { Self { - rcot, - setup_rcot, lpn_parameters, lpn_type, } } - /// Get rcot - pub fn rcot(&self) -> RandomCOT { - self.rcot.clone() - } - - /// Get the setup rcot - pub fn setup_rcot(&mut self) -> &mut SetupRandomCOT { - &mut self.setup_rcot - } - /// Get the lpn type pub fn lpn_type(&self) -> LpnType { self.lpn_type @@ -67,17 +46,14 @@ impl FerretConfig { #[cfg(test)] mod tests { - use futures::TryFutureExt; + use super::*; + use futures::TryFutureExt as _; use mpz_common::executor::test_st_executor; - use mpz_core::{lpn::LpnParameters, Block}; + use mpz_core::lpn::LpnParameters; use mpz_ot_core::{ferret::LpnType, test::assert_cot, RCOTReceiverOutput, RCOTSenderOutput}; + use rstest::*; - use crate::{ - ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, - OTError, RandomCOTReceiver, RandomCOTSender, - }; - - use super::*; + use crate::{ideal::cot::ideal_rcot, Correlation, OTError, RandomCOTReceiver, RandomCOTSender}; // l = n - k = 8380 const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters { @@ -86,73 +62,46 @@ mod tests { t: 600, }; - fn setup() -> ( - Sender, - Receiver, - Block, - ) { - let (mut rcot_sender, rcot_receiver) = ideal_rcot(); - - let sender_config = FerretConfig::new( - rcot_sender.clone(), - rcot_sender.clone(), - LPN_PARAMETERS_TEST, - LpnType::Regular, - ); - - let receiver_config = FerretConfig::new( - rcot_receiver.clone(), - rcot_receiver, - LPN_PARAMETERS_TEST, - LpnType::Regular, - ); - - let delta = rcot_sender.alice().get_mut().delta(); - - let sender = Sender::new(sender_config); - - let receiver = Receiver::new(receiver_config); - - (sender, receiver, delta) - } - + #[rstest] + #[case::uniform(LpnType::Uniform)] + #[case::regular(LpnType::Regular)] #[tokio::test] - async fn test_ferret() { + async fn test_ferret(#[case] lpn_type: LpnType) { let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); - let (mut sender, mut receiver, delta) = setup(); + let (rcot_sender, rcot_receiver) = ideal_rcot(); + + let config = FerretConfig::new(LPN_PARAMETERS_TEST, lpn_type); + + let mut sender = Sender::new(config.clone(), rcot_sender); + let mut receiver = Receiver::new(config, rcot_receiver); tokio::try_join!( - sender - .setup_with_delta(&mut ctx_sender, delta) - .map_err(OTError::from), + sender.setup(&mut ctx_sender).map_err(OTError::from), receiver.setup(&mut ctx_receiver).map_err(OTError::from) ) .unwrap(); // extend once. - let count = 8000; - let ( - RCOTSenderOutput { - id: sender_id, - msgs: u, - }, - RCOTReceiverOutput { - id: receiver_id, - choices: b, - msgs: w, - }, - ) = tokio::try_join!( - sender.send_random_correlated(&mut ctx_sender, count), - receiver.receive_random_correlated(&mut ctx_receiver, count) + let count = LPN_PARAMETERS_TEST.k; + tokio::try_join!( + sender.extend(&mut ctx_sender, count).map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, count) + .map_err(OTError::from) ) .unwrap(); - assert_eq!(sender_id, receiver_id); - assert_cot(delta, &b, &u, &w); - // extend twice - let count = 9000; + let count = 10000; + tokio::try_join!( + sender.extend(&mut ctx_sender, count).map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, count) + .map_err(OTError::from) + ) + .unwrap(); + let ( RCOTSenderOutput { id: sender_id, @@ -170,6 +119,6 @@ mod tests { .unwrap(); assert_eq!(sender_id, receiver_id); - assert_cot(delta, &b, &u, &w); + assert_cot(sender.delta(), &b, &u, &w); } } diff --git a/crates/mpz-ot/src/ferret/mpcot.rs b/crates/mpz-ot/src/ferret/mpcot.rs new file mode 100644 index 00000000..be7de33a --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot.rs @@ -0,0 +1,185 @@ +//! Implementation of the Multiple-Point COT (mpcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. + +use mpz_common::{cpu::CpuBackend, Context}; +use mpz_core::{prg::Prg, Block}; +use mpz_ot_core::ferret::{ + mpcot::{ + msgs::HashSeed, receiver::Receiver as UniformReceiverCore, + receiver_regular::Receiver as RegularReceiverCore, sender::Sender as UniformSender, + sender_regular::Sender as RegularSender, + }, + LpnType, +}; +use serio::{stream::IoStreamExt as _, SinkExt as _}; + +use crate::{ + ferret::{error::MPCOTError as Error, spcot}, + RandomCOTReceiver, RandomCOTSender, +}; + +/// MPCOT send. +/// +/// # Arguments. +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT sender. +/// * `delta` - Delta correlation. +/// * `lpn_type` - The type of LPN. +/// * `t` - The number of queried indices. +/// * `n` - The total number of indices. +pub(crate) async fn send>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + delta: Block, + lpn_type: LpnType, + t: u32, + n: u32, +) -> Result, Error> { + match lpn_type { + LpnType::Uniform => { + let hash_seed: HashSeed = ctx.io_mut().expect_next().await?; + + let (sender, hs) = CpuBackend::blocking(move || { + UniformSender::new() + .setup(delta, hash_seed) + .pre_extend(t, n) + }) + .await?; + + let st = spcot::send(ctx, rcot, delta, &hs).await?; + + let (_, output) = CpuBackend::blocking(move || sender.extend(&st)).await?; + + Ok(output) + } + LpnType::Regular => { + let (sender, hs) = + CpuBackend::blocking(move || RegularSender::new().setup(delta).pre_extend(t, n)) + .await?; + + let st = spcot::send(ctx, rcot, delta, &hs).await?; + + let (_, output) = CpuBackend::blocking(move || sender.extend(&st)).await?; + + Ok(output) + } + } +} + +/// MPCOT receive. +/// +/// # Arguments +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT receiver. +/// * `lpn_type` - The type of LPN. +/// * `alphas` - The queried indices. +/// * `n` - The total number of indices. +pub(crate) async fn receive>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + lpn_type: LpnType, + alphas: Vec, + n: u32, +) -> Result, Error> { + match lpn_type { + LpnType::Uniform => { + let hash_seed = Prg::new().random_block(); + + let (receiver, hash_seed) = UniformReceiverCore::new().setup(hash_seed); + + ctx.io_mut().send(hash_seed).await?; + + let (receiver, h_and_pos) = + CpuBackend::blocking(move || receiver.pre_extend(&alphas, n)).await?; + + let mut hs = vec![0usize; h_and_pos.len()]; + + let mut pos = vec![0u32; h_and_pos.len()]; + for (index, (h, p)) in h_and_pos.iter().enumerate() { + hs[index] = *h; + pos[index] = *p; + } + + let rt = spcot::receive(ctx, rcot, &pos, &hs).await?; + let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); + let (_, output) = CpuBackend::blocking(move || receiver.extend(&rt)).await?; + + Ok(output) + } + LpnType::Regular => { + let receiver = RegularReceiverCore::new().setup(); + + let (receiver, h_and_pos) = + CpuBackend::blocking(move || receiver.pre_extend(&alphas, n)).await?; + + let mut hs = vec![0usize; h_and_pos.len()]; + + let mut pos = vec![0u32; h_and_pos.len()]; + for (index, (h, p)) in h_and_pos.iter().enumerate() { + hs[index] = *h; + pos[index] = *p; + } + + let rt = spcot::receive(ctx, rcot, &pos, &hs).await?; + let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); + let (_, output) = CpuBackend::blocking(move || receiver.extend(&rt)).await?; + + Ok(output) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ideal::cot::ideal_rcot; + use mpz_common::executor::test_st_executor; + use mpz_ot_core::ferret::LpnType; + use rstest::*; + + #[rstest] + #[case(LpnType::Uniform)] + #[case(LpnType::Regular)] + #[tokio::test] + async fn test_mpcot(#[case] lpn_type: LpnType) { + use crate::Correlation; + + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + let (mut rcot_sender, mut rcot_receiver) = ideal_rcot(); + + let alphas = match lpn_type { + LpnType::Uniform => vec![0, 1, 3, 4, 2], + LpnType::Regular => vec![0, 3, 4, 7, 9], + }; + + let t = alphas.len(); + let n = 10; + let delta = rcot_sender.delta(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + send( + &mut ctx_sender, + &mut rcot_sender, + delta, + lpn_type, + t as u32, + n + ), + receive( + &mut ctx_receiver, + &mut rcot_receiver, + lpn_type, + alphas.clone(), + n + ) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot/error.rs b/crates/mpz-ot/src/ferret/mpcot/error.rs deleted file mode 100644 index e300bf0d..00000000 --- a/crates/mpz-ot/src/ferret/mpcot/error.rs +++ /dev/null @@ -1,59 +0,0 @@ -use crate::OTError; - -/// A MPCOT sender error. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs, clippy::enum_variant_names)] -pub enum SenderError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::ferret::mpcot::error::SenderError), - #[error(transparent)] - SPCOTSenderError(#[from] crate::ferret::spcot::SenderError), - #[error("{0}")] - StateError(String), -} - -impl From for OTError { - fn from(err: SenderError) -> Self { - match err { - SenderError::IOError(e) => e.into(), - e => OTError::SenderError(Box::new(e)), - } - } -} - -impl From for SenderError { - fn from(err: crate::ferret::mpcot::sender::StateError) -> Self { - SenderError::StateError(err.to_string()) - } -} - -/// A MPCOT receiver error -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs, clippy::enum_variant_names)] -pub enum ReceiverError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::ferret::mpcot::error::ReceiverError), - #[error(transparent)] - SpcotReceiverError(#[from] crate::ferret::spcot::ReceiverError), - #[error("{0}")] - StateError(String), -} - -impl From for OTError { - fn from(err: ReceiverError) -> Self { - match err { - ReceiverError::IOError(e) => e.into(), - e => OTError::ReceiverError(Box::new(e)), - } - } -} - -impl From for ReceiverError { - fn from(err: crate::ferret::mpcot::receiver::StateError) -> Self { - ReceiverError::StateError(err.to_string()) - } -} diff --git a/crates/mpz-ot/src/ferret/mpcot/mod.rs b/crates/mpz-ot/src/ferret/mpcot/mod.rs deleted file mode 100644 index 598b5734..00000000 --- a/crates/mpz-ot/src/ferret/mpcot/mod.rs +++ /dev/null @@ -1,165 +0,0 @@ -//! Implementation of the Multiple-Point COT (mpcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. - -mod error; -mod receiver; -mod sender; - -pub(crate) use error::{ReceiverError, SenderError}; -pub(crate) use receiver::Receiver; -pub(crate) use sender::Sender; - -#[cfg(test)] -mod tests { - use futures::TryFutureExt; - use mpz_common::executor::test_st_executor; - use mpz_core::Block; - use mpz_ot_core::ferret::LpnType; - - use crate::{ - ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, - OTError, - }; - - use receiver::Receiver; - use sender::Sender; - - use super::*; - - fn setup( - lpn_type: LpnType, - ) -> ( - Sender, - Receiver, - IdealCOTSender, - IdealCOTReceiver, - Block, - ) { - let (mut rcot_sender, rcot_receiver) = ideal_rcot(); - - let delta = rcot_sender.alice().get_mut().delta(); - - let sender = Sender::new(lpn_type); - - let receiver = Receiver::new(lpn_type); - - (sender, receiver, rcot_sender, rcot_receiver, delta) - } - - #[tokio::test] - async fn test_mpcot() { - let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); - - let (mut sender, mut receiver, rcot_sender, rcot_receiver, delta) = setup(LpnType::Uniform); - - let alphas = [0, 1, 3, 4, 2]; - let t = alphas.len(); - let n = 10; - - tokio::try_join!( - sender - .setup_with_delta(&mut ctx_sender, delta, rcot_sender) - .map_err(OTError::from), - receiver - .setup(&mut ctx_receiver, rcot_receiver) - .map_err(OTError::from) - ) - .unwrap(); - - let (mut output_sender, output_receiver) = tokio::try_join!( - sender - .extend(&mut ctx_sender, t as u32, n) - .map_err(OTError::from), - receiver - .extend(&mut ctx_receiver, &alphas, n) - .map_err(OTError::from) - ) - .unwrap(); - - for i in alphas { - output_sender[i as usize] ^= delta; - } - - assert_eq!(output_sender, output_receiver); - - // extend twice - let alphas = [5, 1, 7, 2]; - let t = alphas.len(); - let n = 16; - - let (mut output_sender, output_receiver) = tokio::try_join!( - sender - .extend(&mut ctx_sender, t as u32, n) - .map_err(OTError::from), - receiver - .extend(&mut ctx_receiver, &alphas, n) - .map_err(OTError::from) - ) - .unwrap(); - - for i in alphas { - output_sender[i as usize] ^= delta; - } - - assert_eq!(output_sender, output_receiver); - - sender.finalize().unwrap(); - receiver.finalize().unwrap(); - - let (mut sender, mut receiver, rcot_sender, rcot_receiver, delta) = setup(LpnType::Regular); - - // extend once. - let alphas = [0, 3, 4, 7, 9]; - let t = alphas.len(); - let n = 10; - - tokio::try_join!( - sender - .setup_with_delta(&mut ctx_sender, delta, rcot_sender) - .map_err(OTError::from), - receiver - .setup(&mut ctx_receiver, rcot_receiver) - .map_err(OTError::from) - ) - .unwrap(); - - let (mut output_sender, output_receiver) = tokio::try_join!( - sender - .extend(&mut ctx_sender, t as u32, n) - .map_err(OTError::from), - receiver - .extend(&mut ctx_receiver, &alphas, n) - .map_err(OTError::from) - ) - .unwrap(); - - for i in alphas { - output_sender[i as usize] ^= delta; - } - - assert_eq!(output_sender, output_receiver); - - // extend twice. - let alphas = [0, 3, 7, 9, 14, 15]; - let t = alphas.len(); - let n = 16; - - let (mut output_sender, output_receiver) = tokio::try_join!( - sender - .extend(&mut ctx_sender, t as u32, n) - .map_err(OTError::from), - receiver - .extend(&mut ctx_receiver, &alphas, n) - .map_err(OTError::from) - ) - .unwrap(); - - for i in alphas { - output_sender[i as usize] ^= delta; - } - - assert_eq!(output_sender, output_receiver); - - sender.finalize().unwrap(); - receiver.finalize().unwrap(); - } -} diff --git a/crates/mpz-ot/src/ferret/mpcot/receiver.rs b/crates/mpz-ot/src/ferret/mpcot/receiver.rs deleted file mode 100644 index e2553efd..00000000 --- a/crates/mpz-ot/src/ferret/mpcot/receiver.rs +++ /dev/null @@ -1,192 +0,0 @@ -use crate::{ - ferret::{mpcot::error::ReceiverError, spcot::Receiver as SpcotReceiver}, - RandomCOTReceiver, -}; -use enum_try_as_inner::EnumTryAsInner; - -use mpz_common::Context; -use mpz_core::{prg::Prg, Block}; -use mpz_ot_core::ferret::{ - mpcot::{ - receiver::{state as uniform_state, Receiver as UniformReceiverCore}, - receiver_regular::{state as regular_state, Receiver as RegularReceiverCore}, - }, - LpnType, -}; -use serio::SinkExt; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; - -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] -pub(crate) enum State { - UniformInitialized(UniformReceiverCore), - UniformExtension(UniformReceiverCore), - RegularInitialized(RegularReceiverCore), - RegularExtension(RegularReceiverCore), - Complete, - Error, -} - -/// MPCOT receiver. -#[derive(Debug)] -pub(crate) struct Receiver { - state: State, - spcot: SpcotReceiver, - lpn_type: LpnType, -} - -impl Receiver { - /// Creates a new Sender. - /// - /// # Arguments. - /// - /// * `lpn_type` - The type of LPN. - pub(crate) fn new(lpn_type: LpnType) -> Self { - match lpn_type { - LpnType::Uniform => Self { - state: State::UniformInitialized(UniformReceiverCore::new()), - spcot: crate::ferret::spcot::Receiver::new(), - lpn_type, - }, - LpnType::Regular => Self { - state: State::RegularInitialized(RegularReceiverCore::new()), - spcot: crate::ferret::spcot::Receiver::new(), - lpn_type, - }, - } - } - - /// Performs setup for receiver. - /// - /// # Arguments - /// - /// * `ctx` - The context. - /// * `rcot` - The random COT used by Receiver. - pub(crate) async fn setup( - &mut self, - ctx: &mut Ctx, - rcot: RandomCOT, - ) -> Result<(), ReceiverError> { - match self.lpn_type { - LpnType::Uniform => { - let ext_receiver = std::mem::replace(&mut self.state, State::Error) - .try_into_uniform_initialized()?; - - let hash_seed = Prg::new().random_block(); - - let (ext_receiver, hash_seed) = ext_receiver.setup(hash_seed); - - ctx.io_mut().send(hash_seed).await?; - - self.state = State::UniformExtension(ext_receiver); - } - LpnType::Regular => { - let ext_receiver = std::mem::replace(&mut self.state, State::Error) - .try_into_regular_initialized()?; - - let ext_receiver = ext_receiver.setup(); - - self.state = State::RegularExtension(ext_receiver); - } - } - - self.spcot.setup(rcot)?; - - Ok(()) - } - - /// Performs MPCOT extension. - /// - /// - /// # Arguments - /// - /// * `ctx` - The context, - /// * `alphas` - The queried indices. - /// * `n` - The total number of indices. - pub(crate) async fn extend( - &mut self, - ctx: &mut Ctx, - alphas: &[u32], - n: u32, - ) -> Result, ReceiverError> - where - RandomCOT: RandomCOTReceiver, - { - let alphas_vec = alphas.to_vec(); - - match self.lpn_type { - LpnType::Uniform => { - let ext_receiver = std::mem::replace(&mut self.state, State::Error) - .try_into_uniform_extension()?; - - let (ext_receiver, h_and_pos) = - Backend::spawn(move || ext_receiver.pre_extend(&alphas_vec, n)).await?; - - let mut hs = vec![0usize; h_and_pos.len()]; - - let mut pos = vec![0u32; h_and_pos.len()]; - for (index, (h, p)) in h_and_pos.iter().enumerate() { - hs[index] = *h; - pos[index] = *p; - } - - self.spcot.extend(ctx, &pos, &hs).await?; - - let rt = self.spcot.check(ctx).await?; - - let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); - let (ext_receiver, output) = - Backend::spawn(move || ext_receiver.extend(&rt)).await?; - - self.state = State::UniformExtension(ext_receiver); - - Ok(output) - } - - LpnType::Regular => { - let ext_receiver = std::mem::replace(&mut self.state, State::Error) - .try_into_regular_extension()?; - - let (ext_receiver, h_and_pos) = - Backend::spawn(move || ext_receiver.pre_extend(&alphas_vec, n)).await?; - - let mut hs = vec![0usize; h_and_pos.len()]; - - let mut pos = vec![0u32; h_and_pos.len()]; - for (index, (h, p)) in h_and_pos.iter().enumerate() { - hs[index] = *h; - pos[index] = *p; - } - - self.spcot.extend(ctx, &pos, &hs).await?; - - let rt = self.spcot.check(ctx).await?; - - let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); - let (ext_receiver, output) = - Backend::spawn(move || ext_receiver.extend(&rt)).await?; - - self.state = State::RegularExtension(ext_receiver); - - Ok(output) - } - } - } - - /// Complete extension. - pub(crate) fn finalize(&mut self) -> Result<(), ReceiverError> { - match self.lpn_type { - LpnType::Uniform => { - std::mem::replace(&mut self.state, State::Error).try_into_uniform_extension()?; - } - LpnType::Regular => { - std::mem::replace(&mut self.state, State::Error).try_into_regular_extension()?; - } - } - - self.spcot.finalize()?; - self.state = State::Complete; - - Ok(()) - } -} diff --git a/crates/mpz-ot/src/ferret/mpcot/sender.rs b/crates/mpz-ot/src/ferret/mpcot/sender.rs deleted file mode 100644 index a0256276..00000000 --- a/crates/mpz-ot/src/ferret/mpcot/sender.rs +++ /dev/null @@ -1,166 +0,0 @@ -use crate::{ - ferret::{mpcot::error::SenderError, spcot::Sender as SpcotSender}, - RandomCOTSender, -}; -use enum_try_as_inner::EnumTryAsInner; -use mpz_common::Context; -use mpz_core::Block; -use mpz_ot_core::ferret::{ - mpcot::{ - msgs::HashSeed, - sender::{state as uniform_state, Sender as UniformSenderCore}, - sender_regular::{state as regular_state, Sender as RegularSenderCore}, - }, - LpnType, -}; -use serio::stream::IoStreamExt; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; - -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] -pub(crate) enum State { - UniformInitialized(UniformSenderCore), - UniformExtension(UniformSenderCore), - RegularInitialized(RegularSenderCore), - RegularExtension(RegularSenderCore), - Complete, - Error, -} - -/// MPCOT sender. -#[derive(Debug)] -pub(crate) struct Sender { - state: State, - spcot: SpcotSender, - lpn_type: LpnType, -} - -impl Sender { - /// Creates a new Sender. - /// - /// # Arguments. - /// - /// * `lpn_type` - The type of LPN. - pub(crate) fn new(lpn_type: LpnType) -> Self { - match lpn_type { - LpnType::Uniform => Self { - state: State::UniformInitialized(UniformSenderCore::new()), - spcot: crate::ferret::spcot::Sender::new(), - lpn_type, - }, - LpnType::Regular => Self { - state: State::RegularInitialized(RegularSenderCore::new()), - spcot: crate::ferret::spcot::Sender::new(), - lpn_type, - }, - } - } - - /// Performs setup with provided delta. - /// - /// # Arguments - /// - /// * `ctx` - The channel. - /// * `delta` - The delta value to use for OT extension. - /// * `rcot` - The random COT used by Sender. - pub(crate) async fn setup_with_delta( - &mut self, - ctx: &mut Ctx, - delta: Block, - rcot: RandomCOT, - ) -> Result<(), SenderError> { - match self.lpn_type { - LpnType::Uniform => { - let ext_sender = std::mem::replace(&mut self.state, State::Error) - .try_into_uniform_initialized()?; - - let hash_seed: HashSeed = ctx.io_mut().expect_next().await?; - - let ext_sender = ext_sender.setup(delta, hash_seed); - - self.state = State::UniformExtension(ext_sender); - } - - LpnType::Regular => { - let ext_sender = std::mem::replace(&mut self.state, State::Error) - .try_into_regular_initialized()?; - - let ext_sender = ext_sender.setup(delta); - - self.state = State::RegularExtension(ext_sender); - } - } - - self.spcot.setup_with_delta(delta, rcot)?; - - Ok(()) - } - - /// Performs MPCOT extension. - /// - /// - /// # Arguments. - /// - /// * `ctx` - The context. - /// * `t` - The number of queried indices. - /// * `n` - The total number of indices. - pub(crate) async fn extend( - &mut self, - ctx: &mut Ctx, - t: u32, - n: u32, - ) -> Result, SenderError> - where - RandomCOT: RandomCOTSender, - { - match self.lpn_type { - LpnType::Uniform => { - let ext_sender = std::mem::replace(&mut self.state, State::Error) - .try_into_uniform_extension()?; - - let (ext_sender, hs) = Backend::spawn(move || ext_sender.pre_extend(t, n)).await?; - - self.spcot.extend(ctx, &hs).await?; - - let st = self.spcot.check(ctx).await?; - - let (ext_sender, output) = Backend::spawn(move || ext_sender.extend(&st)).await?; - - self.state = State::UniformExtension(ext_sender); - Ok(output) - } - LpnType::Regular => { - let ext_sender = std::mem::replace(&mut self.state, State::Error) - .try_into_regular_extension()?; - - let (ext_sender, hs) = Backend::spawn(move || ext_sender.pre_extend(t, n)).await?; - - self.spcot.extend(ctx, &hs).await?; - - let st = self.spcot.check(ctx).await?; - - let (ext_sender, output) = Backend::spawn(move || ext_sender.extend(&st)).await?; - - self.state = State::RegularExtension(ext_sender); - Ok(output) - } - } - } - - /// Complete extension. - pub(crate) fn finalize(&mut self) -> Result<(), SenderError> { - match self.lpn_type { - LpnType::Uniform => { - std::mem::replace(&mut self.state, State::Error).try_into_uniform_extension()?; - } - LpnType::Regular => { - std::mem::replace(&mut self.state, State::Error).try_into_regular_extension()?; - } - } - - self.spcot.finalize()?; - self.state = State::Complete; - - Ok(()) - } -} diff --git a/crates/mpz-ot/src/ferret/receiver.rs b/crates/mpz-ot/src/ferret/receiver.rs index 520506e8..d04e5d29 100644 --- a/crates/mpz-ot/src/ferret/receiver.rs +++ b/crates/mpz-ot/src/ferret/receiver.rs @@ -1,52 +1,61 @@ -use crate::{ - ferret::{mpcot::Receiver as MpcotReceiver, ReceiverError}, - RandomCOTReceiver, -}; -use enum_try_as_inner::EnumTryAsInner; -use mpz_common::Context; +use std::mem; + +use async_trait::async_trait; +use mpz_common::{cpu::CpuBackend, Allocate, Context, Preprocess}; use mpz_core::{prg::Prg, Block}; use mpz_ot_core::{ - ferret::receiver::{state, Receiver as ReceiverCore}, + ferret::{ + receiver::{state, Receiver as ReceiverCore}, + LpnType, CSP, CUCKOO_HASH_NUM, + }, RCOTReceiverOutput, }; use serio::SinkExt; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; -use super::FerretConfig; -use crate::{async_trait, OTError}; +use crate::{ + ferret::{mpcot, FerretConfig, ReceiverError}, + OTError, RandomCOTReceiver, +}; -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] +#[derive(Debug)] pub(crate) enum State { Initialized(ReceiverCore), Extension(ReceiverCore), - Complete, Error, } +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, State::Error) + } +} + /// Ferret Receiver. #[derive(Debug)] -pub struct Receiver { +pub struct Receiver { state: State, - mpcot: MpcotReceiver, - config: FerretConfig, + config: FerretConfig, + rcot: RandomCOT, + alloc: usize, + buffer: ReceiverBuffer, + buffer_len: usize, } -impl Receiver -where - RandomCOT: Send + Default + Clone, - SetupRandomCOT: Send, -{ +impl Receiver { /// Creates a new Receiver. /// /// # Arguments. /// - /// * `config` - Ferret configuration. - pub fn new(config: FerretConfig) -> Self { + /// * `config` - The Ferret config. + /// * `rcot` - The random COT in setup. + pub fn new(config: FerretConfig, rcot: RandomCOT) -> Self { Self { state: State::Initialized(ReceiverCore::new()), - mpcot: MpcotReceiver::new(config.lpn_type()), config, + rcot, + alloc: 0, + buffer: Default::default(), + buffer_len: 0, } } @@ -58,36 +67,59 @@ where pub async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), ReceiverError> where Ctx: Context, - SetupRandomCOT: RandomCOTReceiver, + RandomCOT: RandomCOTReceiver, { - let ext_receiver = - std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; - - let rcot = self.config.rcot(); - self.mpcot.setup(ctx, rcot).await?; + let State::Initialized(receiver) = self.state.take() else { + return Err(ReceiverError::state("receiver not in initialized state")); + }; let params = self.config.lpn_parameters(); let lpn_type = self.config.lpn_type(); - // Get random blocks from ideal Random COT. + // Compute the number of buffered OTs. + self.buffer_len = match lpn_type { + // The number here is a rough estimation to ensure sufficient buffer. + // It is hard to precisely compute the number because of the Cuckoo hashes. + LpnType::Uniform => { + let m = (1.5 * (params.t as f32)).ceil() as usize; + m * ((2 * CUCKOO_HASH_NUM * params.n / m) + .checked_next_power_of_two() + .expect("The length should be less than usize::MAX / 2 - 1") + .ilog2() as usize) + + CSP + } + // In our chosen paramters, we always set n is divided by t and n/t is a power of 2. + LpnType::Regular => { + assert!(params.n % params.t == 0 && (params.n / params.t).is_power_of_two()); + params.t * ((params.n / params.t).ilog2() as usize) + CSP + } + }; + // Get random blocks from ideal Random COT. let RCOTReceiverOutput { - choices: u, - msgs: w, - .. + choices: mut u, + msgs: mut w, + id, } = self - .config - .setup_rcot() - .receive_random_correlated(ctx, params.k) + .rcot + .receive_random_correlated(ctx, params.k + self.buffer_len) .await?; + // Initiate buffer. + let buffer = RCOTReceiverOutput { + id, + choices: u.drain(0..self.buffer_len).collect(), + msgs: w.drain(0..self.buffer_len).collect(), + }; + self.buffer = ReceiverBuffer::new(buffer); + let seed = Prg::new().random_block(); - let (ext_receiver, seed) = ext_receiver.setup(params, lpn_type, seed, &u, &w)?; + let (receiver, seed) = receiver.setup(params, lpn_type, seed, &u, &w)?; ctx.io_mut().send(seed).await?; - self.state = State::Extension(ext_receiver); + self.state = State::Extension(receiver); Ok(()) } @@ -96,97 +128,126 @@ where /// /// # Arguments /// - /// * `ctx` - The channel context. - async fn extend(&mut self, ctx: &mut Ctx) -> Result<(Vec, Vec), ReceiverError> + /// * `ctx` - Thread context. + /// * `count` - The number of OTs to extend. + pub async fn extend(&mut self, ctx: &mut Ctx, count: usize) -> Result<(), ReceiverError> where Ctx: Context, - RandomCOT: RandomCOTReceiver, + RandomCOT: RandomCOTReceiver + Send, { - let mut ext_receiver = - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + let State::Extension(mut receiver) = self.state.take() else { + return Err(ReceiverError::state("receiver not in extension state")); + }; - let (alphas, n) = ext_receiver.get_mpcot_query(); + let lpn_type = self.config.lpn_type(); + let target = receiver.remaining() + count; + while receiver.remaining() < target { + let (alphas, n) = receiver.get_mpcot_query(); - let r = self.mpcot.extend(ctx, &alphas, n as u32).await?; + let r = mpcot::receive(ctx, &mut self.buffer, lpn_type, alphas, n as u32).await?; - let (ext_receiver, choices, msgs) = Backend::spawn(move || { - ext_receiver - .extend(&r) - .map(|(choices, msgs)| (ext_receiver, choices, msgs)) - }) - .await?; + receiver = CpuBackend::blocking(move || receiver.extend(r).map(|()| receiver)).await?; - self.state = State::Extension(ext_receiver); + // Update receiver buffer. + let buffer = receiver + .consume(self.buffer_len) + .map_err(ReceiverError::from) + .map_err(OTError::from)?; - Ok((choices, msgs)) - } + self.buffer = ReceiverBuffer::new(buffer); + } - /// Complete extension - pub fn finalize(&mut self) -> Result<(), ReceiverError> { - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - self.state = State::Complete; - self.mpcot.finalize()?; + self.state = State::Extension(receiver); Ok(()) } } #[async_trait] -impl RandomCOTReceiver - for Receiver +impl RandomCOTReceiver for Receiver where - Ctx: Context, - RandomCOT: RandomCOTReceiver + Send + Clone + Default + 'static, - SetupRandomCOT: Send + 'static, + RandomCOT: Send, { async fn receive_random_correlated( &mut self, - ctx: &mut Ctx, + _ctx: &mut Ctx, count: usize, ) -> Result, OTError> { - let (mut choices_buffer, mut msgs_buffer) = self.extend(ctx).await?; - - assert_eq!(choices_buffer.len(), msgs_buffer.len()); - - let l = choices_buffer.len(); - - let id = self - .state - .try_as_extension() - .map_err(ReceiverError::from)? - .id(); - - if count <= l { - let choices_res = choices_buffer.drain(..count).collect(); + let State::Extension(receiver) = &mut self.state else { + return Err(ReceiverError::state("receiver not in extension state").into()); + }; + + receiver + .consume(count) + .map_err(ReceiverError::from) + .map_err(OTError::from) + } +} - let msgs_res = msgs_buffer.drain(..count).collect(); +impl Allocate for Receiver { + fn alloc(&mut self, count: usize) { + self.alloc += count; + } +} - return Ok(RCOTReceiverOutput { - id, - choices: choices_res, - msgs: msgs_res, - }); - } else { - let mut choices_res = choices_buffer; - let mut msgs_res = msgs_buffer; +#[async_trait] +impl Preprocess for Receiver +where + Ctx: Context, + RandomCOT: RandomCOTReceiver + Send, +{ + type Error = ReceiverError; - for _ in 0..count / l - 1 { - (choices_buffer, msgs_buffer) = self.extend(ctx).await?; + async fn preprocess(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + let count = mem::take(&mut self.alloc); + self.extend(ctx, count).await + } +} - choices_res.extend_from_slice(&choices_buffer); - msgs_res.extend_from_slice(&msgs_buffer); - } +#[derive(Debug)] +struct ReceiverBuffer { + buffer: RCOTReceiverOutput, +} - (choices_buffer, msgs_buffer) = self.extend(ctx).await?; +impl ReceiverBuffer { + fn new(buffer: RCOTReceiverOutput) -> Self { + Self { buffer } + } +} - choices_res.extend_from_slice(&choices_buffer[0..count % l]); - msgs_res.extend_from_slice(&msgs_buffer[0..count % l]); +impl Default for ReceiverBuffer { + fn default() -> Self { + ReceiverBuffer { + buffer: RCOTReceiverOutput { + id: Default::default(), + choices: Vec::new(), + msgs: Vec::new(), + }, + } + } +} - return Ok(RCOTReceiverOutput { - id, - choices: choices_res, - msgs: msgs_res, - }); +#[async_trait] +impl RandomCOTReceiver for ReceiverBuffer { + async fn receive_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + if count > self.buffer.choices.len() { + return Err(ReceiverError::io(format!( + "insufficient OTs: {} < {count}", + self.buffer.choices.len() + )) + .into()); } + + let choices = self.buffer.choices.drain(0..count).collect(); + let msgs = self.buffer.msgs.drain(0..count).collect(); + Ok(RCOTReceiverOutput { + id: self.buffer.id.next(), + choices, + msgs, + }) } } diff --git a/crates/mpz-ot/src/ferret/sender.rs b/crates/mpz-ot/src/ferret/sender.rs index 709ff8e2..187c1744 100644 --- a/crates/mpz-ot/src/ferret/sender.rs +++ b/crates/mpz-ot/src/ferret/sender.rs @@ -1,45 +1,60 @@ -use crate::{ferret::mpcot::Sender as MpcotSender, RandomCOTSender}; -use enum_try_as_inner::EnumTryAsInner; -use mpz_common::Context; +use std::mem; + +use crate::{ferret::mpcot, Correlation, RandomCOTSender}; +use async_trait::async_trait; +use mpz_common::{cpu::CpuBackend, Allocate, Context, Preprocess}; use mpz_core::Block; use mpz_ot_core::{ - ferret::sender::{state, Sender as SenderCore}, + ferret::{ + sender::{state, Sender as SenderCore}, + LpnType, CSP, CUCKOO_HASH_NUM, + }, RCOTSenderOutput, }; use serio::stream::IoStreamExt; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; use super::{FerretConfig, SenderError}; -use crate::{async_trait, OTError}; +use crate::OTError; -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] +#[derive(Debug)] pub(crate) enum State { Initialized(SenderCore), Extension(SenderCore), - Complete, Error, } +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, State::Error) + } +} + /// Ferret Sender. #[derive(Debug)] -pub struct Sender { +pub struct Sender { state: State, - mpcot: MpcotSender, - config: FerretConfig, + config: FerretConfig, + rcot: RandomCOT, + alloc: usize, + buffer: SenderBuffer, + buffer_len: usize, } -impl Sender -where - RandomCOT: Send + Default + Clone, - SetupRandomCOT: Send, -{ +impl Sender { /// Creates a new Sender. - pub fn new(config: FerretConfig) -> Self { + /// + /// # Argument + /// + /// `config` - The Ferret config. + /// `rcot` - The random COT in setup. + pub fn new(config: FerretConfig, rcot: RandomCOT) -> Self { Self { state: State::Initialized(SenderCore::new()), - mpcot: MpcotSender::new(config.lpn_type()), config, + rcot, + alloc: 0, + buffer: Default::default(), + buffer_len: 0, } } @@ -48,39 +63,57 @@ where /// # Argument /// /// * `ctx` - The channel context. - /// * `delta` - The provided delta used for sender. - pub async fn setup_with_delta( - &mut self, - ctx: &mut Ctx, - delta: Block, - ) -> Result<(), SenderError> + pub async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), SenderError> where Ctx: Context, - SetupRandomCOT: RandomCOTSender, + RandomCOT: RandomCOTSender + Correlation, { - let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; - - let rcot = self.config.rcot(); - - self.mpcot.setup_with_delta(ctx, delta, rcot).await?; + let State::Initialized(sender) = self.state.take() else { + return Err(SenderError::state("sender not in initialized state")); + }; let params = self.config.lpn_parameters(); let lpn_type = self.config.lpn_type(); + // Compute the number of buffered OTs. + self.buffer_len = match lpn_type { + // The number here is a rough estimation to ensure sufficient buffer. + // It is hard to precisely compute the number because of the Cuckoo hashes. + LpnType::Uniform => { + let m = (1.5 * (params.t as f32)).ceil() as usize; + m * ((2 * CUCKOO_HASH_NUM * params.n / m) + .checked_next_power_of_two() + .expect("The length should be less than usize::MAX / 2 - 1") + .ilog2() as usize) + + CSP + } + // In our chosen paramters, we always set n is divided by t and n/t is a power of 2. + LpnType::Regular => { + assert!(params.n % params.t == 0 && (params.n / params.t).is_power_of_two()); + params.t * ((params.n / params.t).ilog2() as usize) + CSP + } + }; + // Get random blocks from ideal Random COT. - let RCOTSenderOutput { msgs: v, .. } = self - .config - .setup_rcot() - .send_random_correlated(ctx, params.k) + let RCOTSenderOutput { msgs: mut v, id } = self + .rcot + .send_random_correlated(ctx, params.k + self.buffer_len) .await?; + // Initiate buffer. + let buffer = RCOTSenderOutput { + id, + msgs: v.drain(0..self.buffer_len).collect(), + }; + self.buffer = SenderBuffer::new(self.rcot.delta(), buffer); + // Get seed for LPN matrix from receiver. let seed = ctx.io_mut().expect_next().await?; // Ferret core setup. - let ext_sender = ext_sender.setup(delta, params, lpn_type, seed, &v)?; + let sender = sender.setup(self.rcot.delta(), params, lpn_type, seed, &v)?; - self.state = State::Extension(ext_sender); + self.state = State::Extension(sender); Ok(()) } @@ -89,72 +122,173 @@ where /// /// # Argument /// - /// * `ctx` - The channel context. - async fn extend(&mut self, ctx: &mut Ctx) -> Result, SenderError> + /// * `ctx` - Thread context. + /// * `count` - The number of OTs to extend. + pub async fn extend( + &mut self, + ctx: &mut Ctx, + count: usize, + ) -> Result<(), SenderError> where - RandomCOT: RandomCOTSender, + RandomCOT: RandomCOTSender + Send, { - let mut ext_sender = - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + let State::Extension(mut sender) = self.state.take() else { + return Err(SenderError::state("sender not in extension state")); + }; - let (t, n) = ext_sender.get_mpcot_query(); + let lpn_type = self.config.lpn_type(); + let delta = sender.delta(); + let target = sender.remaining() + count; + while sender.remaining() < target { + let (t, n) = sender.get_mpcot_query(); - let s = self.mpcot.extend(ctx, t, n).await?; + let s = mpcot::send(ctx, &mut self.buffer, delta, lpn_type, t, n).await?; - let (ext_sender, output) = - Backend::spawn(move || ext_sender.extend(&s).map(|output| (ext_sender, output))) - .await?; - self.state = State::Extension(ext_sender); + sender = CpuBackend::blocking(move || sender.extend(s).map(|()| sender)).await?; - Ok(output) - } + // Update sender buffer. + let buffer = sender + .consume(self.buffer_len) + .map_err(SenderError::from) + .map_err(OTError::from)?; - /// Complete extension - pub fn finalize(&mut self) -> Result<(), SenderError> { - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - self.state = State::Complete; - self.mpcot.finalize()?; + self.buffer = SenderBuffer::new(delta, buffer); + } + + self.state = State::Extension(sender); Ok(()) } } +impl Correlation for Sender +where + RandomCOT: Correlation, +{ + type Correlation = Block; + + fn delta(&self) -> Self::Correlation { + self.rcot.delta() + } +} + #[async_trait] -impl RandomCOTSender - for Sender +impl RandomCOTSender for Sender where - Ctx: Context, - RandomCOT: RandomCOTSender + Send + Default + Clone + 'static, - SetupRandomCOT: Send + 'static, + RandomCOT: Correlation + Send, { async fn send_random_correlated( &mut self, - ctx: &mut Ctx, + _ctx: &mut Ctx, count: usize, ) -> Result, OTError> { - let mut buffer = self.extend(ctx).await?; - let l = buffer.len(); - - let id = self - .state - .try_as_extension() - .map_err(SenderError::from)? - .id(); - - if count <= l { - let res = buffer.drain(..count).collect(); - return Ok(RCOTSenderOutput { id, msgs: res }); - } else { - let mut res = buffer; - for _ in 0..count / l - 1 { - buffer = self.extend(ctx).await?; - res.extend_from_slice(&buffer); - } + let State::Extension(sender) = &mut self.state else { + return Err(SenderError::state("sender not in extension state").into()); + }; + + sender + .consume(count) + .map_err(SenderError::from) + .map_err(OTError::from) + } +} + +impl Allocate for Sender { + fn alloc(&mut self, count: usize) { + self.alloc += count; + } +} + +#[async_trait] +impl Preprocess for Sender +where + Ctx: Context, + RandomCOT: RandomCOTSender + Send, +{ + type Error = SenderError; - buffer = self.extend(ctx).await?; - res.extend_from_slice(&buffer[0..count % l]); + async fn preprocess(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + let count = mem::take(&mut self.alloc); + self.extend(ctx, count).await + } +} + +#[derive(Debug)] +struct SenderBuffer { + delta: Block, + buffer: RCOTSenderOutput, +} + +impl SenderBuffer { + fn new(delta: Block, buffer: RCOTSenderOutput) -> Self { + Self { delta, buffer } + } +} + +impl Default for SenderBuffer { + fn default() -> Self { + let buffer = RCOTSenderOutput { + id: Default::default(), + msgs: Vec::new(), + }; + Self { + delta: Block::ZERO, + buffer, + } + } +} +impl Correlation for SenderBuffer { + type Correlation = Block; + + fn delta(&self) -> Self::Correlation { + self.delta + } +} - return Ok(RCOTSenderOutput { id, msgs: res }); +#[async_trait] +impl RandomCOTSender for SenderBuffer { + async fn send_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + if count > self.buffer.msgs.len() { + return Err(SenderError::io(format!( + "insufficient OTs: {} < {count}", + self.buffer.msgs.len() + )) + .into()); } + + let msgs = self.buffer.msgs.drain(0..count).collect(); + Ok(RCOTSenderOutput { + id: self.buffer.id.next(), + msgs, + }) + } +} + +#[derive(Debug)] +struct BootstrappedSender<'a>(&'a mut SenderCore); + +impl Correlation for BootstrappedSender<'_> { + type Correlation = Block; + + fn delta(&self) -> Block { + self.0.delta() + } +} + +#[async_trait] +impl RandomCOTSender for BootstrappedSender<'_> { + async fn send_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + self.0 + .consume(count) + .map_err(SenderError::from) + .map_err(OTError::from) } } diff --git a/crates/mpz-ot/src/ferret/spcot.rs b/crates/mpz-ot/src/ferret/spcot.rs new file mode 100644 index 00000000..bccad692 --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot.rs @@ -0,0 +1,167 @@ +//! Implementation of the Single-Point COT (spcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. + +use mpz_common::{cpu::CpuBackend, Context}; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::{ + spcot::{ + msgs::{ExtendFromSender, MaskBits}, + receiver::Receiver as ReceiverCore, + sender::Sender as SenderCore, + }, + CSP, + }, + RCOTReceiverOutput, RCOTSenderOutput, +}; +use serio::{stream::IoStreamExt as _, SinkExt as _}; + +use crate::{ferret::error::SPCOTError as Error, RandomCOTReceiver, RandomCOTSender}; + +/// SPCOT send. +/// +/// # Arguments +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT sender. +/// * `delta` - Delta correlation. +/// * `hs` - The depth of the GGM trees. +pub(crate) async fn send>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + delta: Block, + hs: &[usize], +) -> Result>, Error> { + let mut sender = SenderCore::new().setup(delta); + + let h = hs.iter().sum(); + let RCOTSenderOutput { msgs: qss, .. } = rcot.send_random_correlated(ctx, h).await?; + + let masks: Vec = ctx.io_mut().expect_next().await?; + + // extend + let h_in = hs.to_vec(); + let (mut sender, extend_msg) = CpuBackend::blocking(move || { + sender + .extend(&h_in, &qss, &masks) + .map(|extend_msg| (sender, extend_msg)) + }) + .await?; + + ctx.io_mut().send(extend_msg).await?; + + // batch check + let RCOTSenderOutput { msgs: y_star, .. } = rcot.send_random_correlated(ctx, CSP).await?; + + let checkfr = ctx.io_mut().expect_next().await?; + + let (output, check_msg) = CpuBackend::blocking(move || { + sender + .check(&y_star, checkfr) + .map(|(output, check_msg)| (output, check_msg)) + }) + .await?; + + ctx.io_mut().send(check_msg).await?; + + Ok(output) +} + +/// SPCOT receive. +/// +/// # Arguments +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT receiver. +/// * `alphas` - Vector of chosen positions. +/// * `hs` - The depth of the GGM trees. +pub(crate) async fn receive>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + alphas: &[u32], + hs: &[usize], +) -> Result, u32)>, Error> { + let mut receiver = ReceiverCore::new().setup(); + + let h = hs.iter().sum(); + let RCOTReceiverOutput { + choices: rss, + msgs: tss, + .. + } = rcot.receive_random_correlated(ctx, h).await?; + + // extend + let h_in = hs.to_vec(); + let alphas_in = alphas.to_vec(); + let (mut receiver, masks) = CpuBackend::blocking(move || { + receiver + .extend_mask_bits(&h_in, &alphas_in, &rss) + .map(|mask| (receiver, mask)) + }) + .await?; + + ctx.io_mut().send(masks).await?; + + let extendfss: Vec = ctx.io_mut().expect_next().await?; + + let h_in = hs.to_vec(); + let alphas_in = alphas.to_vec(); + let mut receiver = CpuBackend::blocking(move || { + receiver + .extend(&h_in, &alphas_in, &tss, &extendfss) + .map(|_| receiver) + }) + .await?; + + // batch check + let RCOTReceiverOutput { + choices: x_star, + msgs: z_star, + .. + } = rcot.receive_random_correlated(ctx, CSP).await?; + + let (mut receiver, checkfr) = CpuBackend::blocking(move || { + receiver + .check_pre(&x_star) + .map(|checkfr| (receiver, checkfr)) + }) + .await?; + + ctx.io_mut().send(checkfr).await?; + let check = ctx.io_mut().expect_next().await?; + + let output = + CpuBackend::blocking(move || receiver.check(&z_star, check).map(|output| output)).await?; + + Ok(output) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ideal::cot::ideal_rcot, Correlation}; + use mpz_common::executor::test_st_executor; + + #[tokio::test] + async fn test_spcot() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + let (mut rcot_sender, mut rcot_receiver) = ideal_rcot(); + + let hs = [8usize, 4]; + let alphas = [4u32, 2]; + let delta = rcot_sender.delta(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + send(&mut ctx_sender, &mut rcot_sender, delta, &hs), + receive(&mut ctx_receiver, &mut rcot_receiver, &alphas, &hs) + ) + .unwrap(); + + assert!(output_sender + .iter_mut() + .zip(output_receiver.iter()) + .all(|(vs, (ws, alpha))| { + vs[*alpha as usize] ^= delta; + vs == ws + })); + } +} diff --git a/crates/mpz-ot/src/ferret/spcot/error.rs b/crates/mpz-ot/src/ferret/spcot/error.rs deleted file mode 100644 index 5f23f466..00000000 --- a/crates/mpz-ot/src/ferret/spcot/error.rs +++ /dev/null @@ -1,59 +0,0 @@ -use crate::OTError; - -/// A SPCOT sender error. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs, clippy::enum_variant_names)] -pub enum SenderError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::ferret::spcot::error::SenderError), - #[error(transparent)] - RandomCOTError(#[from] OTError), - #[error("{0}")] - StateError(String), -} - -impl From for OTError { - fn from(err: SenderError) -> Self { - match err { - SenderError::IOError(e) => e.into(), - e => OTError::SenderError(Box::new(e)), - } - } -} - -impl From for SenderError { - fn from(err: crate::ferret::spcot::sender::StateError) -> Self { - SenderError::StateError(err.to_string()) - } -} - -/// A SPCOT receiver error. -#[derive(Debug, thiserror::Error)] -#[allow(missing_docs, clippy::enum_variant_names)] -pub enum ReceiverError { - #[error(transparent)] - IOError(#[from] std::io::Error), - #[error(transparent)] - CoreError(#[from] mpz_ot_core::ferret::spcot::error::ReceiverError), - #[error(transparent)] - RandomCOTError(#[from] OTError), - #[error("{0}")] - StateError(String), -} - -impl From for OTError { - fn from(err: ReceiverError) -> Self { - match err { - ReceiverError::IOError(e) => e.into(), - e => OTError::ReceiverError(Box::new(e)), - } - } -} - -impl From for ReceiverError { - fn from(err: crate::ferret::spcot::receiver::StateError) -> Self { - ReceiverError::StateError(err.to_string()) - } -} diff --git a/crates/mpz-ot/src/ferret/spcot/mod.rs b/crates/mpz-ot/src/ferret/spcot/mod.rs deleted file mode 100644 index 6e53fd28..00000000 --- a/crates/mpz-ot/src/ferret/spcot/mod.rs +++ /dev/null @@ -1,103 +0,0 @@ -//! Implementation of the Single-Point COT (spcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. - -mod error; -mod receiver; -mod sender; - -pub(crate) use error::{ReceiverError, SenderError}; -pub(crate) use receiver::Receiver; -pub(crate) use sender::Sender; - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, - OTError, - }; - use futures::TryFutureExt; - use mpz_common::executor::test_st_executor; - use mpz_core::Block; - - fn setup() -> ( - Sender, - Receiver, - IdealCOTSender, - IdealCOTReceiver, - Block, - ) { - let (mut rcot_sender, rcot_receiver) = ideal_rcot(); - - let delta = rcot_sender.alice().get_mut().delta(); - - let sender = Sender::new(); - let receiver = Receiver::new(); - - (sender, receiver, rcot_sender, rcot_receiver, delta) - } - - #[tokio::test] - async fn test_spcot() { - let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); - - let (mut sender, mut receiver, rcot_sender, rcot_receiver, delta) = setup(); - - // shold set the same delta as in RCOT. - sender.setup_with_delta(delta, rcot_sender).unwrap(); - receiver.setup(rcot_receiver).unwrap(); - - let hs = [8, 4]; - let alphas = [4, 2]; - - tokio::try_join!( - sender.extend(&mut ctx_sender, &hs).map_err(OTError::from), - receiver - .extend(&mut ctx_receiver, &alphas, &hs) - .map_err(OTError::from) - ) - .unwrap(); - - let (mut output_sender, output_receiver) = tokio::try_join!( - sender.check(&mut ctx_sender).map_err(OTError::from), - receiver.check(&mut ctx_receiver).map_err(OTError::from) - ) - .unwrap(); - - assert!(output_sender - .iter_mut() - .zip(output_receiver.iter()) - .all(|(vs, (ws, alpha))| { - vs[*alpha as usize] ^= delta; - vs == ws - })); - - // extend twice. - let hs = [6, 9, 8]; - let alphas = [2, 1, 3]; - - tokio::try_join!( - sender.extend(&mut ctx_sender, &hs).map_err(OTError::from), - receiver - .extend(&mut ctx_receiver, &alphas, &hs) - .map_err(OTError::from) - ) - .unwrap(); - - let (mut output_sender, output_receiver) = tokio::try_join!( - sender.check(&mut ctx_sender).map_err(OTError::from), - receiver.check(&mut ctx_receiver).map_err(OTError::from) - ) - .unwrap(); - - assert!(output_sender - .iter_mut() - .zip(output_receiver.iter()) - .all(|(vs, (ws, alpha))| { - vs[*alpha as usize] ^= delta; - vs == ws - })); - - sender.finalize().unwrap(); - receiver.finalize().unwrap(); - } -} diff --git a/crates/mpz-ot/src/ferret/spcot/receiver.rs b/crates/mpz-ot/src/ferret/spcot/receiver.rs deleted file mode 100644 index 3c48bfad..00000000 --- a/crates/mpz-ot/src/ferret/spcot/receiver.rs +++ /dev/null @@ -1,164 +0,0 @@ -use crate::{ferret::spcot::error::ReceiverError, RandomCOTReceiver}; -use enum_try_as_inner::EnumTryAsInner; -use mpz_common::Context; -use mpz_core::Block; -use mpz_ot_core::{ - ferret::{ - spcot::{ - msgs::ExtendFromSender, - receiver::{state, Receiver as ReceiverCore}, - }, - CSP, - }, - RCOTReceiverOutput, -}; -use serio::{stream::IoStreamExt, SinkExt}; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; - -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] -pub(crate) enum State { - Initialized(ReceiverCore), - Extension(Box>), - Complete, - Error, -} - -/// SPCOT Receiver. -#[derive(Debug)] -pub(crate) struct Receiver { - state: State, - rcot: RandomCOT, -} - -impl Receiver { - /// Creates a new Receiver. - pub(crate) fn new() -> Self { - Self { - state: State::Initialized(ReceiverCore::new()), - rcot: Default::default(), - } - } - - /// Performs setup for receiver. - /// - /// # Arguments. - /// - /// * `rcot` - The random COT used by the receiver. - pub(crate) fn setup(&mut self, rcot: RandomCOT) -> Result<(), ReceiverError> { - let ext_receiver = - std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; - - let ext_receiver = ext_receiver.setup(); - self.state = State::Extension(Box::new(ext_receiver)); - self.rcot = rcot; - Ok(()) - } - - /// Performs spcot extension for receiver. - /// - /// # Arguments - /// - /// * `ctx` - The context. - /// * `alphas`` - The vector of chosen positions. - /// * `h` - The depth of GGM tree. - pub(crate) async fn extend( - &mut self, - ctx: &mut Ctx, - alphas: &[u32], - hs: &[usize], - ) -> Result<(), ReceiverError> - where - RandomCOT: RandomCOTReceiver, - { - let mut ext_receiver = - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - let h = hs.iter().sum(); - let RCOTReceiverOutput { - choices: rss, - msgs: tss, - .. - } = self.rcot.receive_random_correlated(ctx, h).await?; - - // extend - let h_in = hs.to_vec(); - let alphas_in = alphas.to_vec(); - let (mut ext_receiver, masks) = Backend::spawn(move || { - ext_receiver - .extend_mask_bits(&h_in, &alphas_in, &rss) - .map(|mask| (ext_receiver, mask)) - }) - .await?; - - ctx.io_mut().send(masks).await?; - - let extendfss: Vec = ctx.io_mut().expect_next().await?; - - let h_in = hs.to_vec(); - let alphas_in = alphas.to_vec(); - let ext_receiver = Backend::spawn(move || { - ext_receiver - .extend(&h_in, &alphas_in, &tss, &extendfss) - .map(|_| ext_receiver) - }) - .await?; - - self.state = State::Extension(ext_receiver); - - Ok(()) - } - - /// Performs batch check for SPCOT extension. - /// - /// # Arguments - /// - /// * `ctx` - The context. - pub(crate) async fn check( - &mut self, - ctx: &mut Ctx, - ) -> Result, u32)>, ReceiverError> - where - RandomCOT: RandomCOTReceiver, - { - let mut ext_receiver = - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - // batch check - let RCOTReceiverOutput { - choices: x_star, - msgs: z_star, - .. - } = self.rcot.receive_random_correlated(ctx, CSP).await?; - - let (mut ext_receiver, checkfr) = Backend::spawn(move || { - ext_receiver - .check_pre(&x_star) - .map(|checkfr| (ext_receiver, checkfr)) - }) - .await?; - - ctx.io_mut().send(checkfr).await?; - let check = ctx.io_mut().expect_next().await?; - - let (ext_receiver, output) = Backend::spawn(move || { - ext_receiver - .check(&z_star, check) - .map(|output| (ext_receiver, output)) - }) - .await?; - - self.state = State::Extension(ext_receiver); - - Ok(output) - } - - /// Complete extension. - pub(crate) fn finalize(&mut self) -> Result<(), ReceiverError> { - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - self.state = State::Complete; - - Ok(()) - } -} diff --git a/crates/mpz-ot/src/ferret/spcot/sender.rs b/crates/mpz-ot/src/ferret/spcot/sender.rs deleted file mode 100644 index 9178b787..00000000 --- a/crates/mpz-ot/src/ferret/spcot/sender.rs +++ /dev/null @@ -1,144 +0,0 @@ -use crate::{ferret::spcot::error::SenderError, RandomCOTSender}; -use enum_try_as_inner::EnumTryAsInner; -use mpz_common::Context; -use mpz_core::Block; -use mpz_ot_core::{ - ferret::{ - spcot::{ - msgs::MaskBits, - sender::{state, Sender as SenderCore}, - }, - CSP, - }, - RCOTSenderOutput, -}; -use serio::{stream::IoStreamExt, SinkExt}; -use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; - -#[derive(Debug, EnumTryAsInner)] -#[derive_err(Debug)] -pub(crate) enum State { - Initialized(SenderCore), - Extension(Box>), - Complete, - Error, -} - -/// SPCOT sender. -#[derive(Debug)] -pub(crate) struct Sender { - state: State, - rcot: RandomCOT, -} - -impl Sender { - /// Creates a new Sender. - pub(crate) fn new() -> Self { - Self { - state: State::Initialized(SenderCore::new()), - rcot: Default::default(), - } - } - - /// Performs setup with the provided delta. - /// - /// # Arguments - /// - /// * `delta` - The delta value to use for OT extension. - /// * `rcot` - The random COT used by the sender. - pub(crate) fn setup_with_delta( - &mut self, - delta: Block, - rcot: RandomCOT, - ) -> Result<(), SenderError> { - let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; - - let ext_sender = ext_sender.setup(delta); - - self.state = State::Extension(Box::new(ext_sender)); - self.rcot = rcot; - Ok(()) - } - - /// Performs spcot extension for sender. - /// - /// # Arguments - /// - /// * `ctx` - The context. - /// * `hs` - The depths of GGM trees. - pub(crate) async fn extend( - &mut self, - ctx: &mut Ctx, - hs: &[usize], - ) -> Result<(), SenderError> - where - RandomCOT: RandomCOTSender, - { - let mut ext_sender = - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - let h = hs.iter().sum(); - let RCOTSenderOutput { msgs: qss, .. } = self.rcot.send_random_correlated(ctx, h).await?; - - let masks: Vec = ctx.io_mut().expect_next().await?; - - // extend - let h_in = hs.to_vec(); - let (ext_sender, extend_msg) = Backend::spawn(move || { - ext_sender - .extend(&h_in, &qss, &masks) - .map(|extend_msg| (ext_sender, extend_msg)) - }) - .await?; - - ctx.io_mut().send(extend_msg).await?; - - self.state = State::Extension(ext_sender); - - Ok(()) - } - - /// Performs batch check for SPCOT extension. - /// - /// # Arguments - /// - /// * `ctx` - The context. - pub(crate) async fn check( - &mut self, - ctx: &mut Ctx, - ) -> Result>, SenderError> - where - RandomCOT: RandomCOTSender, - { - let mut ext_sender = - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - // batch check - let RCOTSenderOutput { msgs: y_star, .. } = - self.rcot.send_random_correlated(ctx, CSP).await?; - - let checkfr = ctx.io_mut().expect_next().await?; - - let (ext_sender, output, check_msg) = Backend::spawn(move || { - ext_sender - .check(&y_star, checkfr) - .map(|(output, check_msg)| (ext_sender, output, check_msg)) - }) - .await?; - - ctx.io_mut().send(check_msg).await?; - - self.state = State::Extension(ext_sender); - - Ok(output) - } - - /// Complete extension. - pub(crate) fn finalize(&mut self) -> Result<(), SenderError> { - std::mem::replace(&mut self.state, State::Error).try_into_extension()?; - - self.state = State::Complete; - - Ok(()) - } -} diff --git a/crates/mpz-ot/src/ideal/cot.rs b/crates/mpz-ot/src/ideal/cot.rs index 18233dfe..bc7df0a6 100644 --- a/crates/mpz-ot/src/ideal/cot.rs +++ b/crates/mpz-ot/src/ideal/cot.rs @@ -11,7 +11,9 @@ use mpz_ot_core::{ ideal::cot::IdealCOT, COTReceiverOutput, COTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput, }; -use crate::{COTReceiver, COTSender, OTError, OTSetup, RandomCOTReceiver, RandomCOTSender}; +use crate::{ + COTReceiver, COTSender, Correlation, OTError, OTSetup, RandomCOTReceiver, RandomCOTSender, +}; fn cot( f: &mut IdealCOT, @@ -82,6 +84,14 @@ where } } +impl Correlation for IdealCOTSender { + type Correlation = Block; + + fn delta(&self) -> Block { + self.0.lock().delta() + } +} + #[async_trait] impl COTSender for IdealCOTSender { async fn send_correlated( @@ -170,7 +180,7 @@ mod tests { let (mut ctx_a, mut ctx_b) = test_st_executor(8); let (mut alice, mut bob) = ideal_cot(); - let delta = alice.0.get_mut().delta(); + let delta = alice.delta(); let count = 10; let choices = (0..count).map(|_| rng.gen()).collect::>(); @@ -201,7 +211,7 @@ mod tests { let (mut ctx_a, mut ctx_b) = test_st_executor(8); let (mut alice, mut bob) = ideal_rcot(); - let delta = alice.0.get_mut().delta(); + let delta = alice.delta(); let count = 10; diff --git a/crates/mpz-ot/src/lib.rs b/crates/mpz-ot/src/lib.rs index 0e4d1b48..c1508883 100644 --- a/crates/mpz-ot/src/lib.rs +++ b/crates/mpz-ot/src/lib.rs @@ -61,9 +61,18 @@ pub trait OTSender { async fn send(&mut self, ctx: &mut Ctx, msgs: &[T]) -> Result; } +/// Correlation of COT messages. +pub trait Correlation { + /// The type of the correlation. + type Correlation; + + /// Returns the correlation. + fn delta(&self) -> Self::Correlation; +} + /// A correlated oblivious transfer sender. #[async_trait] -pub trait COTSender { +pub trait COTSender: Correlation { /// Obliviously transfers the correlated messages to the receiver. /// /// Returns the `0`-bit messages that were obliviously transferred. @@ -97,7 +106,7 @@ pub trait RandomOTSender { /// A random correlated oblivious transfer sender. #[async_trait] -pub trait RandomCOTSender { +pub trait RandomCOTSender: Correlation { /// Obliviously transfers the correlated messages to the receiver. /// /// Returns the `0`-bit messages that were obliviously transferred. From 5f4e90f904fc7973fd5eed22b2b15ef3d02f882a Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Tue, 20 Aug 2024 07:09:36 +0800 Subject: [PATCH 4/6] fix clippy --- crates/mpz-ot-core/src/chou_orlandi/receiver.rs | 2 +- crates/mpz-ot-core/src/chou_orlandi/sender.rs | 2 +- crates/mpz-ot-core/src/ferret/receiver.rs | 4 ++-- crates/mpz-ot-core/src/ferret/sender.rs | 4 ++-- crates/mpz-ot-core/src/ideal/cot.rs | 2 +- crates/mpz-ot-core/src/ideal/mpcot.rs | 2 +- crates/mpz-ot-core/src/ideal/ot.rs | 2 +- crates/mpz-ot-core/src/ideal/rot.rs | 4 ++-- crates/mpz-ot-core/src/ideal/spcot.rs | 2 +- crates/mpz-ot-core/src/kos/receiver.rs | 2 +- crates/mpz-ot-core/src/kos/sender.rs | 2 +- crates/mpz-ot-core/src/lib.rs | 2 +- crates/mpz-ot/src/ferret/receiver.rs | 10 +++++----- crates/mpz-ot/src/ferret/sender.rs | 2 +- crates/mpz-ot/src/ferret/spcot.rs | 3 +-- 15 files changed, 22 insertions(+), 23 deletions(-) diff --git a/crates/mpz-ot-core/src/chou_orlandi/receiver.rs b/crates/mpz-ot-core/src/chou_orlandi/receiver.rs index 403802f9..d9638951 100644 --- a/crates/mpz-ot-core/src/chou_orlandi/receiver.rs +++ b/crates/mpz-ot-core/src/chou_orlandi/receiver.rs @@ -153,7 +153,7 @@ impl Receiver { let SenderPayload { id, payload } = payload; // Check that the transfer id matches - let expected_id = current_id.next(); + let expected_id = current_id.next_id(); if id != expected_id { return Err(ReceiverError::IdMismatch(expected_id, id)); } diff --git a/crates/mpz-ot-core/src/chou_orlandi/sender.rs b/crates/mpz-ot-core/src/chou_orlandi/sender.rs index 09a8b5a6..328354eb 100644 --- a/crates/mpz-ot-core/src/chou_orlandi/sender.rs +++ b/crates/mpz-ot-core/src/chou_orlandi/sender.rs @@ -139,7 +139,7 @@ impl Sender { } = receiver_payload; // Check that the transfer id matches - let expected_id = current_id.next(); + let expected_id = current_id.next_id(); if id != expected_id { return Err(SenderError::IdMismatch(expected_id, id)); } diff --git a/crates/mpz-ot-core/src/ferret/receiver.rs b/crates/mpz-ot-core/src/ferret/receiver.rs index 1cfd1e08..782d2b9e 100644 --- a/crates/mpz-ot-core/src/ferret/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/receiver.rs @@ -119,7 +119,7 @@ impl Receiver { return Err(ReceiverError("the length of r should be n".to_string())); } - self.state.id.next(); + self.state.id.next_id(); // Compute z = A * w + r. let mut z = r; @@ -169,7 +169,7 @@ impl Receiver { let msgs = self.state.msgs_buffer.drain(0..count).collect(); Ok(RCOTReceiverOutput { - id: self.state.id.next(), + id: self.state.id.next_id(), choices, msgs, }) diff --git a/crates/mpz-ot-core/src/ferret/sender.rs b/crates/mpz-ot-core/src/ferret/sender.rs index 436d6003..e6af6452 100644 --- a/crates/mpz-ot-core/src/ferret/sender.rs +++ b/crates/mpz-ot-core/src/ferret/sender.rs @@ -109,7 +109,7 @@ impl Sender { return Err(SenderError("the length of s should be n".to_string())); } - self.state.id.next(); + self.state.id.next_id(); // Compute y = A * v + s let mut y = s; @@ -139,7 +139,7 @@ impl Sender { let msgs = self.state.msgs_buffer.drain(0..count).collect(); Ok(RCOTSenderOutput { - id: self.state.id.next(), + id: self.state.id.next_id(), msgs, }) } diff --git a/crates/mpz-ot-core/src/ideal/cot.rs b/crates/mpz-ot-core/src/ideal/cot.rs index a28abef8..a842129d 100644 --- a/crates/mpz-ot-core/src/ideal/cot.rs +++ b/crates/mpz-ot-core/src/ideal/cot.rs @@ -76,7 +76,7 @@ impl IdealCOT { .collect(); self.counter += count; - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); ( RCOTSenderOutput { id, msgs }, diff --git a/crates/mpz-ot-core/src/ideal/mpcot.rs b/crates/mpz-ot-core/src/ideal/mpcot.rs index 44a5595f..c038331b 100644 --- a/crates/mpz-ot-core/src/ideal/mpcot.rs +++ b/crates/mpz-ot-core/src/ideal/mpcot.rs @@ -60,7 +60,7 @@ impl IdealMpcot { self.counter += 1; } - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); (MPCOTSenderOutput { id, s }, MPCOTReceiverOutput { id, r }) } diff --git a/crates/mpz-ot-core/src/ideal/ot.rs b/crates/mpz-ot-core/src/ideal/ot.rs index e389066e..76ebe630 100644 --- a/crates/mpz-ot-core/src/ideal/ot.rs +++ b/crates/mpz-ot-core/src/ideal/ot.rs @@ -55,7 +55,7 @@ impl IdealOT { self.counter += choices.len(); self.choices.extend(choices); - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); (OTSenderOutput { id }, OTReceiverOutput { id, msgs: chosen }) } diff --git a/crates/mpz-ot-core/src/ideal/rot.rs b/crates/mpz-ot-core/src/ideal/rot.rs index 8a8b5d68..e29b9204 100644 --- a/crates/mpz-ot-core/src/ideal/rot.rs +++ b/crates/mpz-ot-core/src/ideal/rot.rs @@ -68,7 +68,7 @@ impl IdealROT { .collect(); self.counter += count; - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); ( ROTSenderOutput { id, msgs }, @@ -103,7 +103,7 @@ impl IdealROT { .collect(); self.counter += choices.len(); - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); ( ROTSenderOutput { id, msgs }, diff --git a/crates/mpz-ot-core/src/ideal/spcot.rs b/crates/mpz-ot-core/src/ideal/spcot.rs index 12c5f829..93b3c720 100644 --- a/crates/mpz-ot-core/src/ideal/spcot.rs +++ b/crates/mpz-ot-core/src/ideal/spcot.rs @@ -61,7 +61,7 @@ impl IdealSpcot { self.counter += n; } - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); (SPCOTSenderOutput { id, v }, SPCOTReceiverOutput { id, w }) } diff --git a/crates/mpz-ot-core/src/kos/receiver.rs b/crates/mpz-ot-core/src/kos/receiver.rs index fdcad328..127c4f1d 100644 --- a/crates/mpz-ot-core/src/kos/receiver.rs +++ b/crates/mpz-ot-core/src/kos/receiver.rs @@ -330,7 +330,7 @@ impl Receiver { )); } - let id = self.state.transfer_id.next(); + let id = self.state.transfer_id.next_id(); let index = self.state.index - self.state.keys.len(); Ok(ReceiverKeys { diff --git a/crates/mpz-ot-core/src/kos/sender.rs b/crates/mpz-ot-core/src/kos/sender.rs index 24917940..23edff5c 100644 --- a/crates/mpz-ot-core/src/kos/sender.rs +++ b/crates/mpz-ot-core/src/kos/sender.rs @@ -294,7 +294,7 @@ impl Sender { return Err(SenderError::InsufficientSetup(count, self.state.keys.len())); } - let id = self.state.transfer_id.next(); + let id = self.state.transfer_id.next_id(); Ok(SenderKeys { id, diff --git a/crates/mpz-ot-core/src/lib.rs b/crates/mpz-ot-core/src/lib.rs index dcfffc59..b0b69260 100644 --- a/crates/mpz-ot-core/src/lib.rs +++ b/crates/mpz-ot-core/src/lib.rs @@ -45,7 +45,7 @@ impl std::fmt::Display for TransferId { impl TransferId { /// Returns the current transfer ID, incrementing `self` in-place. - pub fn next(&mut self) -> Self { + pub fn next_id(&mut self) -> Self { let id = *self; self.0 += 1; id diff --git a/crates/mpz-ot/src/ferret/receiver.rs b/crates/mpz-ot/src/ferret/receiver.rs index d04e5d29..fbbb38eb 100644 --- a/crates/mpz-ot/src/ferret/receiver.rs +++ b/crates/mpz-ot/src/ferret/receiver.rs @@ -19,8 +19,8 @@ use crate::{ #[derive(Debug)] pub(crate) enum State { - Initialized(ReceiverCore), - Extension(ReceiverCore), + Initialized(Box>), + Extension(Box>), Error, } @@ -50,7 +50,7 @@ impl Receiver { /// * `rcot` - The random COT in setup. pub fn new(config: FerretConfig, rcot: RandomCOT) -> Self { Self { - state: State::Initialized(ReceiverCore::new()), + state: State::Initialized(Box::new(ReceiverCore::new())), config, rcot, alloc: 0, @@ -119,7 +119,7 @@ impl Receiver { ctx.io_mut().send(seed).await?; - self.state = State::Extension(receiver); + self.state = State::Extension(Box::new(receiver)); Ok(()) } @@ -245,7 +245,7 @@ impl RandomCOTReceiver for ReceiverBuffer { let choices = self.buffer.choices.drain(0..count).collect(); let msgs = self.buffer.msgs.drain(0..count).collect(); Ok(RCOTReceiverOutput { - id: self.buffer.id.next(), + id: self.buffer.id.next_id(), choices, msgs, }) diff --git a/crates/mpz-ot/src/ferret/sender.rs b/crates/mpz-ot/src/ferret/sender.rs index 187c1744..02884b2c 100644 --- a/crates/mpz-ot/src/ferret/sender.rs +++ b/crates/mpz-ot/src/ferret/sender.rs @@ -262,7 +262,7 @@ impl RandomCOTSender for SenderBuffer { let msgs = self.buffer.msgs.drain(0..count).collect(); Ok(RCOTSenderOutput { - id: self.buffer.id.next(), + id: self.buffer.id.next_id(), msgs, }) } diff --git a/crates/mpz-ot/src/ferret/spcot.rs b/crates/mpz-ot/src/ferret/spcot.rs index bccad692..5fcb6e6c 100644 --- a/crates/mpz-ot/src/ferret/spcot.rs +++ b/crates/mpz-ot/src/ferret/spcot.rs @@ -57,7 +57,6 @@ pub(crate) async fn send>( let (output, check_msg) = CpuBackend::blocking(move || { sender .check(&y_star, checkfr) - .map(|(output, check_msg)| (output, check_msg)) }) .await?; @@ -130,7 +129,7 @@ pub(crate) async fn receive Date: Tue, 20 Aug 2024 08:01:39 +0800 Subject: [PATCH 5/6] params --- crates/mpz-ot-core/src/ferret/mod.rs | 12 ++++++------ crates/mpz-ot/examples/ferret.rs | 1 + 2 files changed, 7 insertions(+), 6 deletions(-) create mode 100644 crates/mpz-ot/examples/ferret.rs diff --git a/crates/mpz-ot-core/src/ferret/mod.rs b/crates/mpz-ot-core/src/ferret/mod.rs index 0e27f0a9..6b478b06 100644 --- a/crates/mpz-ot-core/src/ferret/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mod.rs @@ -22,17 +22,17 @@ pub const CUCKOO_TRIAL_NUM: usize = 100; /// LPN parameters with regular noise. /// Derived from https://github.com/emp-toolkit/emp-ot/blob/master/emp-ot/ferret/constants.h pub const LPN_PARAMETERS_REGULAR: LpnParameters = LpnParameters { - n: 10180608, - k: 124000, - t: 4971, + n: 10_180_608, + k: 124_000, + t: 4_971, }; /// LPN parameters with uniform noise. /// Derived from Table 2. pub const LPN_PARAMETERS_UNIFORM: LpnParameters = LpnParameters { - n: 10616092, - k: 588160, - t: 1324, + n: 10_616_092, + k: 588_160, + t: 1_324, }; /// The type of Lpn parameters. diff --git a/crates/mpz-ot/examples/ferret.rs b/crates/mpz-ot/examples/ferret.rs new file mode 100644 index 00000000..f328e4d9 --- /dev/null +++ b/crates/mpz-ot/examples/ferret.rs @@ -0,0 +1 @@ +fn main() {} From 84399bd0976fa89b7203f2612a5de069b0c558ba Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Wed, 21 Aug 2024 10:08:18 +0800 Subject: [PATCH 6/6] add default ferret configs --- crates/mpz-ot-core/src/ferret/mod.rs | 19 ---- crates/mpz-ot/src/ferret/mod.rs | 132 +++++++++++++++++++++++++++ crates/mpz-ot/src/ferret/spcot.rs | 9 +- 3 files changed, 134 insertions(+), 26 deletions(-) diff --git a/crates/mpz-ot-core/src/ferret/mod.rs b/crates/mpz-ot-core/src/ferret/mod.rs index 6b478b06..ac73c005 100644 --- a/crates/mpz-ot-core/src/ferret/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mod.rs @@ -1,7 +1,4 @@ //! An implementation of the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) protocol. - -use mpz_core::lpn::LpnParameters; - pub mod cuckoo; pub mod error; pub mod mpcot; @@ -19,22 +16,6 @@ pub const CUCKOO_HASH_NUM: usize = 3; /// Trial numbers in Cuckoo hash insertion. pub const CUCKOO_TRIAL_NUM: usize = 100; -/// LPN parameters with regular noise. -/// Derived from https://github.com/emp-toolkit/emp-ot/blob/master/emp-ot/ferret/constants.h -pub const LPN_PARAMETERS_REGULAR: LpnParameters = LpnParameters { - n: 10_180_608, - k: 124_000, - t: 4_971, -}; - -/// LPN parameters with uniform noise. -/// Derived from Table 2. -pub const LPN_PARAMETERS_UNIFORM: LpnParameters = LpnParameters { - n: 10_616_092, - k: 588_160, - t: 1_324, -}; - /// The type of Lpn parameters. #[derive(Debug, Clone, Copy, Default)] pub enum LpnType { diff --git a/crates/mpz-ot/src/ferret/mod.rs b/crates/mpz-ot/src/ferret/mod.rs index 086e5e8b..9d421885 100644 --- a/crates/mpz-ot/src/ferret/mod.rs +++ b/crates/mpz-ot/src/ferret/mod.rs @@ -44,6 +44,138 @@ impl FerretConfig { } } +/// Ferret config with regular LPN parameters. +/// Parameters for setup with small extension output. +pub const FERRET_REGULAR_SETUP_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 102_400, + k: 6_750, + t: 1_600, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for extension with small extension output. +pub const FERRET_REGULAR_EXTENSION_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 1_740_800, + k: 66_400, + t: 1700, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for setup with medium extension output. +pub const FERRET_REGULAR_SETUP_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 283_648, + k: 18_584, + t: 1_108, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for extension with medium extension output. +pub const FERRET_REGULAR_EXTENSION_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 5_324_800, + k: 240_000, + t: 1_300, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for setup with large extension output. +pub const FERRET_REGULAR_SETUP_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 518_656, + k: 34_643, + t: 1_013, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for extension with large extension output. +pub const FERRET_REGULAR_EXTENSION_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 10_485_760, + k: 458_000, + t: 1280, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for setup with small extension output. +pub const FERRET_UNIFORM_SETUP_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 98_000, + k: 4_450, + t: 1_600, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for extension with small extension output. +pub const FERRET_UNIFORM_EXTENSION_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 1_071_888, + k: 40_800, + t: 1720, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for setup with medium extension output. +pub const FERRET_UNIFORM_SETUP_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 283_648, + k: 18_584, + t: 1_108, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for extension with medium extension output. +pub const FERRET_UNIFORM_EXTENSION_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 5_324_800, + k: 240_000, + t: 1_300, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for setup with large extension output. +pub const FERRET_UNIFORM_SETUP_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 545_656, + k: 34_643, + t: 1_050, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for extension with large extension output. +pub const FERRET_UNIFORM_EXTENSION_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 10_488_928, + k: 458_000, + t: 1_280, + }, + lpn_type: LpnType::Uniform, +}; + #[cfg(test)] mod tests { use super::*; diff --git a/crates/mpz-ot/src/ferret/spcot.rs b/crates/mpz-ot/src/ferret/spcot.rs index 5fcb6e6c..e63a1aa9 100644 --- a/crates/mpz-ot/src/ferret/spcot.rs +++ b/crates/mpz-ot/src/ferret/spcot.rs @@ -54,11 +54,7 @@ pub(crate) async fn send>( let checkfr = ctx.io_mut().expect_next().await?; - let (output, check_msg) = CpuBackend::blocking(move || { - sender - .check(&y_star, checkfr) - }) - .await?; + let (output, check_msg) = CpuBackend::blocking(move || sender.check(&y_star, checkfr)).await?; ctx.io_mut().send(check_msg).await?; @@ -128,8 +124,7 @@ pub(crate) async fn receive