diff --git a/crates/mpz-core/benches/ggm.rs b/crates/mpz-core/benches/ggm.rs index 3b902caf..fb540d91 100644 --- a/crates/mpz-core/benches/ggm.rs +++ b/crates/mpz-core/benches/ggm.rs @@ -1,33 +1,25 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use mpz_core::{block::Block, ggm_tree::GgmTree}; +use mpz_core::{block::Block, ggm::GgmTree}; #[allow(clippy::all)] fn criterion_benchmark(c: &mut Criterion) { c.bench_function("ggm::gen::1K", move |bench| { let depth = 10; - let ggm = GgmTree::new(depth); - let mut tree = vec![Block::ZERO; 1 << (depth)]; - let mut k0 = vec![Block::ZERO; depth]; - let mut k1 = vec![Block::ZERO; depth]; let seed = rand::random::(); + let mut leaves = vec![Block::ZERO; 1 << depth]; bench.iter(|| { - black_box(ggm.gen( - black_box(seed), - black_box(&mut tree), - black_box(&mut k0), - black_box(&mut k1), - )); + GgmTree::new_from_seed(depth, seed, &mut leaves); + black_box(&leaves); }); }); c.bench_function("ggm::reconstruction::1K", move |bench| { let depth = 10; - let ggm = GgmTree::new(depth); - let mut tree = vec![Block::ZERO; 1 << (depth)]; - let k = vec![Block::ZERO; depth]; - let alpha = vec![false; depth]; + let sums = vec![Block::ZERO; depth]; + let mut leaves = vec![Block::ZERO; 1 << depth]; bench.iter(|| { - black_box(ggm.reconstruct(black_box(&mut tree), black_box(&k), black_box(&alpha))) + GgmTree::new_partial(depth, &sums, 420, &mut leaves); + black_box(&leaves); }); }); } diff --git a/crates/mpz-core/benches/lpn.rs b/crates/mpz-core/benches/lpn.rs index 73ddf049..dc439cca 100644 --- a/crates/mpz-core/benches/lpn.rs +++ b/crates/mpz-core/benches/lpn.rs @@ -7,7 +7,7 @@ fn criterion_benchmark(c: &mut Criterion) { let seed = Block::ZERO; let k = 5_060; let n = 166_400; - let lpn = LpnEncoder::<10>::new(seed, k); + let lpn = LpnEncoder::<10>::new(k); let mut x = vec![Block::ZERO; k as usize]; let mut y = vec![Block::ZERO; n]; let mut prg = Prg::new(); @@ -15,7 +15,7 @@ fn criterion_benchmark(c: &mut Criterion) { prg.random_blocks(&mut y); bench.iter(|| { #[allow(clippy::unit_arg)] - black_box(lpn.compute(&mut y, &x)); + black_box(lpn.compute(seed, &mut y, &x)); }); }); @@ -23,7 +23,7 @@ fn criterion_benchmark(c: &mut Criterion) { let seed = Block::ZERO; let k = 158_000; let n = 10_168_320; - let lpn = LpnEncoder::<10>::new(seed, k); + let lpn = LpnEncoder::<10>::new(k); let mut x = vec![Block::ZERO; k as usize]; let mut y = vec![Block::ZERO; n]; let mut prg = Prg::new(); @@ -31,7 +31,7 @@ fn criterion_benchmark(c: &mut Criterion) { prg.random_blocks(&mut y); bench.iter(|| { #[allow(clippy::unit_arg)] - black_box(lpn.compute(&mut y, &x)); + black_box(lpn.compute(seed, &mut y, &x)); }); }); @@ -39,7 +39,7 @@ fn criterion_benchmark(c: &mut Criterion) { let seed = Block::ZERO; let k = 588_160; let n = 10_616_092; - let lpn = LpnEncoder::<10>::new(seed, k); + let lpn = LpnEncoder::<10>::new(k); let mut x = vec![Block::ZERO; k as usize]; let mut y = vec![Block::ZERO; n]; let mut prg = Prg::new(); @@ -47,7 +47,7 @@ fn criterion_benchmark(c: &mut Criterion) { prg.random_blocks(&mut y); bench.iter(|| { #[allow(clippy::unit_arg)] - black_box(lpn.compute(&mut y, &x)); + black_box(lpn.compute(seed, &mut y, &x)); }); }); } diff --git a/crates/mpz-core/src/aes.rs b/crates/mpz-core/src/aes.rs index c24cb95c..39ba227b 100644 --- a/crates/mpz-core/src/aes.rs +++ b/crates/mpz-core/src/aes.rs @@ -109,8 +109,8 @@ impl FixedKeyAes { .for_each(|(a, b)| *a ^= *b); } - /// Circular correlation-robust hash function instantiated using fixed-key AES - /// (cf., §7.3). + /// Circular correlation-robust hash function instantiated using fixed-key + /// AES (cf., §7.3). /// /// `π(σ(x)) ⊕ σ(x)`, where `π` is instantiated using fixed-key AES /// @@ -120,8 +120,8 @@ impl FixedKeyAes { self.cr(Block::sigma(block)) } - /// Circular correlation-robust hash function instantiated using fixed-key AES - /// (cf., §7.3). + /// Circular correlation-robust hash function instantiated using fixed-key + /// AES (cf., §7.3). /// /// `π(σ(x)) ⊕ σ(x)`, where `π` is instantiated using fixed-key AES /// @@ -159,6 +159,11 @@ impl AesEncryptor { blk } + /// Encrypt a block in-place. + pub fn encrypt_block_inplace(&self, blk: &mut Block) { + self.0.encrypt_block(blk.as_generic_array_mut()); + } + /// Encrypt many blocks in-place. #[inline(always)] pub fn encrypt_many_blocks(&self, blks: &mut [Block; N]) { @@ -177,7 +182,8 @@ impl AesEncryptor { /// /// Each batch of NM blocks is encrypted by a corresponding AES key. /// - /// **Only the first NK * NM blocks of blks are handled, the rest are ignored.** + /// **Only the first NK * NM blocks of blks are handled, the rest are + /// ignored.** /// /// # Arguments /// diff --git a/crates/mpz-core/src/block.rs b/crates/mpz-core/src/block.rs index bba91868..5e5e0e5f 100644 --- a/crates/mpz-core/src/block.rs +++ b/crates/mpz-core/src/block.rs @@ -7,6 +7,10 @@ use generic_array::{typenum::consts::U16, GenericArray}; use itybity::{BitIterable, BitLength, FromBitIterator, GetBit, Lsb0, Msb0}; use rand::{distributions::Standard, prelude::Distribution, CryptoRng, Rng}; use serde::{Deserialize, Serialize}; +use std::{ + fmt::{Debug, Display}, + slice::from_raw_parts, +}; /// A block of 128 bits #[repr(transparent)] @@ -147,6 +151,28 @@ impl Block { bytemuck::cast([x[1], x[0]]) } + /// Converts a slice of blocks to a slice of bytes. + pub fn as_flattened_bytes(slice: &[Self]) -> &[u8] { + // This is equivalent to `<[[u8; 16]]>::as_flattened` + + // SAFETY: `slice.len() * Block::LEN` cannot overflow because `slice` is + // already in the address space. + let len = unsafe { slice.len().unchecked_mul(Self::LEN) }; + // SAFETY: `[u8]` is layout-identical to `[u8; 16]` of which block is a newtype. + unsafe { from_raw_parts(slice.as_ptr().cast(), len) } + } + + /// Converts a slice of block arrays to a slice of bytes. + pub fn array_as_flattened_bytes(slice: &[[Self; N]]) -> &[u8] { + // This is equivalent to `<[[u8; 16 * N]]>::as_flattened` + + // SAFETY: `slice.len() * N * Block::LEN` cannot overflow because `slice` is + // already in the address space. + let len = unsafe { slice.len().unchecked_mul(N * Self::LEN) }; + // SAFETY: `[u8]` is layout-identical to `[u8; 16]` of which block is a newtype. + unsafe { from_raw_parts(slice.as_ptr().cast(), len) } + } + /// Converts a block to a [`GenericArray`](cipher::generic_array::GenericArray) from the [`generic-array`](https://docs.rs/generic-array/latest/generic_array/) crate. #[allow(dead_code)] @@ -182,6 +208,17 @@ impl Block { } } +impl Display for Block { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("Block(")?; + for byte in self.0.iter() { + write!(f, "{:02x}", byte)?; + } + f.write_str(")")?; + Ok(()) + } +} + /// A trait for converting a type to blocks pub trait BlockSerialize { /// The block representation of the type diff --git a/crates/mpz-core/src/ggm.rs b/crates/mpz-core/src/ggm.rs new file mode 100644 index 00000000..3413491b --- /dev/null +++ b/crates/mpz-core/src/ggm.rs @@ -0,0 +1,234 @@ +//! GGM tree. + +use std::ops::Range; + +use itybity::ToBits; + +use crate::{tkprp::TwoKeyPrp, Block}; + +/// Returns the range of nodes at the given layer. +#[inline] +fn layer(n: usize) -> Range { + let start = (1 << n) - 1; + let end = start + (1 << n); + + start..end +} + +/// Returns the width of the tree at the given depth. +#[inline] +fn width(n: usize) -> usize { + 1 << n +} + +/// GGM tree. +pub struct GgmTree<'a> { + depth: usize, + buf: Vec, + leaves: &'a mut [Block], +} + +impl<'a> GgmTree<'a> { + /// Creates a new GGM tree. + /// + /// # Arguments + /// + /// * `depth` - The depth of the tree. + /// * `seed` - The seed of the tree. + /// * `leaves` - The leaves of the tree. + pub fn new_from_seed(depth: usize, seed: Block, leaves: &'a mut [Block]) -> Self { + assert_eq!(leaves.len(), 1 << depth, "invalid length of leaves"); + + let mut buf = vec![Block::ZERO; (1 << depth) - 1]; + + let tkprp = TwoKeyPrp::new([Block::ZERO, Block::ONE]); + + buf[0] = seed; + for n in 0..depth - 1 { + let parents = layer(n); + let children = layer(n + 1); + + let (parents, children) = buf[parents.start..children.end].split_at_mut(width(n)); + + tkprp.expand(parents, children); + } + + // Expand the last layer. + tkprp.expand(&buf[layer(depth - 1)], leaves); + + Self { depth, buf, leaves } + } + + /// Recovers a partial GGM tree which is missing a leaf at the given + /// index. Missing nodes in the tree are set to zero. + /// + /// # Panics + /// + /// - If the position is out of bounds. + /// - If the length of the sums is not equal to the depth minus one. + /// + /// # Arguments + /// + /// * `depth` - Depth of the tree. + /// * `sums` - Sum of the left or right nodes for each layer. + /// * `idx` - Index of the missing leaf. + /// * `leaves` - Leaves of the tree. + pub fn new_partial(depth: usize, sums: &[Block], idx: usize, leaves: &'a mut [Block]) -> Self { + assert!(idx < 1 << depth, "index out of bounds"); + assert_eq!(sums.len(), depth, "invalid length of sums"); + + let mut buf = vec![Block::ZERO; (1 << depth) - 1]; + + let tkprp = TwoKeyPrp::new([Block::ZERO, Block::ONE]); + + // The path length is equal to the depth of the tree. + let idx = idx as u32; + let path = idx.iter_msb0().skip(32 - depth); + + // Recovers the value of the sibling node. + fn recover(layer: &mut [Block], sum: Block, offset: usize, select: bool) { + layer[offset + select as usize] = Block::ZERO; + layer[offset + !select as usize] = Block::ZERO; + + let value = layer + .iter() + .skip(!select as usize) + .step_by(2) + .fold(sum, |acc, value| acc ^ value); + + layer[offset + !select as usize] = value; + } + + let mut offset = 0; + for ((select, sum), n) in path.zip(sums).zip(1..depth + 1) { + if n < depth - 1 { + let (inputs, outputs) = + buf[layer(n).start..layer(n + 1).end].split_at_mut(width(n)); + + recover(inputs, *sum, offset, select); + + tkprp.expand(inputs, outputs); + } else if n == depth - 1 { + let inputs = &mut buf[layer(n)]; + + recover(inputs, *sum, offset, select); + + tkprp.expand(inputs, leaves); + } else if n == depth { + recover(leaves, *sum, offset, select); + + break; + } + + offset += select as usize; + offset <<= 1; + } + + Self { depth, buf, leaves } + } + + /// Returns the root of the tree. + pub fn root(&self) -> &Block { + &self.buf[0] + } + + /// Returns the depth of the tree. + pub fn depth(&self) -> usize { + self.depth + } + + /// Returns the layer at the given depth. + pub fn layer(&self, depth: usize) -> Option<&[Block]> { + if depth < self.depth { + return Some(&self.buf[layer(depth)]); + } else if depth == self.depth { + return Some(&self.leaves); + } + + None + } + + /// Returns an iterator over the layers of the GGM tree. + pub fn iter_layers(&self) -> impl Iterator { + (0..=self.depth).flat_map(|i| self.layer(i)) + } + + /// Returns the sums of the left and right nodes for each layer. + pub fn layer_sums(&self) -> impl Iterator + '_ { + self.iter_layers().skip(1).map(|layer| { + let mut left = Block::ZERO; + let mut right = Block::ZERO; + + for nodes in layer.chunks_exact(2) { + left ^= nodes[0]; + right ^= nodes[1]; + } + + [left, right] + }) + } + + /// Returns the leaves of the GGM tree. + pub fn leaves(&self) -> &[Block] { + self.leaves + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ggm() { + let seed = Block::ONES; + let depth = 4; + + let mut leaves = vec![Block::ZERO; 1 << depth]; + + GgmTree::new_from_seed(depth, seed, &mut leaves); + + assert_ne!(leaves, vec![Block::ZERO; 1 << depth]); + } + + #[test] + fn test_ggm_get_layer() { + let seed = Block::ONES; + let depth = 4; + + let mut leaves = vec![Block::ZERO; 1 << depth]; + + let ggm = GgmTree::new_from_seed(depth, seed, &mut leaves); + + for i in 0..depth { + let layer = ggm.layer(i).unwrap(); + + assert_eq!(layer.len(), 1 << i); + } + } + + #[test] + fn test_ggm_partial() { + let seed = Block::ONES; + let depth = 4; + + let mut full_leaves = vec![Block::ZERO; 1 << depth]; + let ggm = GgmTree::new_from_seed(depth, seed, &mut full_leaves); + + for i in 0..1 << depth { + let path = i as u32; + let sums = ggm + .layer_sums() + .zip(path.iter_msb0().skip(32 - depth)) + .map(|(sum, select)| sum[!select as usize]) + .collect::>(); + + let mut leaves = vec![Block::ZERO; 1 << depth]; + let ggm_partial = GgmTree::new_partial(depth, &sums, i, &mut leaves); + let mut full_leaves = ggm.leaves().to_vec(); + + full_leaves[i] = Block::ZERO; + + assert_eq!(ggm_partial.leaves(), full_leaves.as_slice()); + } + } +} diff --git a/crates/mpz-core/src/ggm_tree.rs b/crates/mpz-core/src/ggm_tree.rs deleted file mode 100644 index aead069a..00000000 --- a/crates/mpz-core/src/ggm_tree.rs +++ /dev/null @@ -1,175 +0,0 @@ -//! Implement GGM tree for OT. -//! Implementation of GGM based on the procedure explained in the write-up -//! (, Page 14) - -use crate::{tkprp::TwoKeyPrp, Block}; - -/// Struct of GGM -pub struct GgmTree { - tkprp: TwoKeyPrp, - depth: usize, -} - -impl GgmTree { - ///New GgmTree instance. - #[inline(always)] - pub fn new(depth: usize) -> Self { - let tkprp = TwoKeyPrp::new([Block::ZERO, Block::from(1u128.to_le_bytes())]); - Self { tkprp, depth } - } - - /// Create a GGM tree in-place. - /// - /// # Arguments - /// - /// * `seed` - a seed. - /// * `tree` - the destination to write the GGM (binary tree) `tree`, with - /// size `2^{depth}`. - /// * `k0` - XORs of all the left-node values in each level, with size - /// `depth`. - /// * `k1`- XORs of all the right-node values in each level, with size - /// `depth`. - // This implementation is adapted from EMP Toolkit. - pub fn gen(&self, seed: Block, tree: &mut [Block], k0: &mut [Block], k1: &mut [Block]) { - assert_eq!(tree.len(), 1 << (self.depth)); - assert_eq!(k0.len(), self.depth); - assert_eq!(k1.len(), self.depth); - let mut buf = [Block::ZERO; 8]; - self.tkprp.expand_1to2(tree, seed); - k0[0] = tree[0]; - k1[0] = tree[1]; - - self.tkprp.expand_2to4(&mut buf, tree); - k0[1] = buf[0] ^ buf[2]; - k1[1] = buf[1] ^ buf[3]; - tree[0..4].copy_from_slice(&buf[0..4]); - - for h in 2..self.depth { - k0[h] = Block::ZERO; - k1[h] = Block::ZERO; - - // How many nodes there are in this layer - let sz = 1 << h; - for i in (0..=sz - 4).rev().step_by(4) { - self.tkprp.expand_4to8(&mut buf, &tree[i..]); - k0[h] ^= buf[0]; - k0[h] ^= buf[2]; - k0[h] ^= buf[4]; - k0[h] ^= buf[6]; - k1[h] ^= buf[1]; - k1[h] ^= buf[3]; - k1[h] ^= buf[5]; - k1[h] ^= buf[7]; - - tree[2 * i..2 * i + 8].copy_from_slice(&buf); - } - } - } - - /// Reconstruct the GGM tree except the value in a given position. - /// - /// This reconstructs the GGM tree entirely except `tree[pos] == - /// Block::ZERO`. The bit decomposition of `pos` is the complement of - /// `alpha`. i.e., `pos[i] = 1 xor alpha[i]`. - /// - /// # Arguments - /// - /// * `k` - a slice of blocks with length `depth`, the values of k are - /// chosen via OT from k0 and k1. For the i-th value, if `alpha[i] == 1, - /// k[i] = k1[i]; else k[i] = k0[i]`. - /// * `alpha` - a slice of bits with length `depth`. - /// * `tree` - the destination to write the GGM tree. - pub fn reconstruct(&self, tree: &mut [Block], k: &[Block], alpha: &[bool]) { - assert_eq!(tree.len(), 1 << (self.depth)); - assert_eq!(k.len(), self.depth); - assert_eq!(alpha.len(), self.depth); - - let mut pos = 0; - for i in 1..=self.depth { - pos *= 2; - tree[pos] = Block::ZERO; - tree[pos + 1] = Block::ZERO; - if !alpha[i - 1] { - self.reconstruct_layer(i, false, pos, k[i - 1], tree); - pos += 1; - } else { - self.reconstruct_layer(i, true, pos + 1, k[i - 1], tree); - } - } - } - - // Handle each layer. - fn reconstruct_layer( - &self, - depth: usize, - left_or_right: bool, - pos: usize, - k: Block, - tree: &mut [Block], - ) { - // How many nodes there are in this layer - let sz = 1 << depth; - - let mut sum = Block::ZERO; - let start = if left_or_right { 1 } else { 0 }; - - for i in (start..sz).step_by(2) { - sum ^= tree[i]; - } - tree[pos] = sum ^ k; - - if depth == (self.depth) { - return; - } - - let mut buf = [Block::ZERO; 8]; - if sz == 2 { - self.tkprp.expand_2to4(&mut buf, tree); - tree[0..4].copy_from_slice(&buf[0..4]); - } else { - for i in (0..=sz - 4).rev().step_by(4) { - self.tkprp.expand_4to8(&mut buf, &tree[i..]); - tree[2 * i..2 * i + 8].copy_from_slice(&buf); - } - } - } -} - -#[test] -fn ggm_test() { - use crate::{ggm_tree::GgmTree, Block}; - - let depth = 3; - let mut tree = vec![Block::ZERO; 1 << depth]; - let mut k0 = vec![Block::ZERO; depth]; - let mut k1 = vec![Block::ZERO; depth]; - let mut k = vec![Block::ZERO; depth]; - let alpha = [false, true, false]; - let mut pos = 0; - - for a in alpha { - pos <<= 1; - if !a { - pos += 1; - } - } - - let ggm = GgmTree::new(depth); - - ggm.gen(Block::ZERO, &mut tree, &mut k0, &mut k1); - - for i in 0..depth { - if alpha[i] { - k[i] = k1[i]; - } else { - k[i] = k0[i]; - } - } - - let mut tree_reconstruct = vec![Block::ZERO; 1 << depth]; - ggm.reconstruct(&mut tree_reconstruct, &k, &alpha); - - assert_eq!(tree_reconstruct[pos], Block::ZERO); - tree_reconstruct[pos] = tree[pos]; - assert_eq!(tree, tree_reconstruct); -} diff --git a/crates/mpz-core/src/lib.rs b/crates/mpz-core/src/lib.rs index b9824a2c..6bb9a1a1 100644 --- a/crates/mpz-core/src/lib.rs +++ b/crates/mpz-core/src/lib.rs @@ -6,7 +6,7 @@ pub mod aes; pub mod bitvec; pub mod block; pub mod commit; -pub mod ggm_tree; +pub mod ggm; pub mod hash; pub mod lpn; pub mod prg; diff --git a/crates/mpz-core/src/lpn.rs b/crates/mpz-core/src/lpn.rs index 2543fdb2..5de84209 100644 --- a/crates/mpz-core/src/lpn.rs +++ b/crates/mpz-core/src/lpn.rs @@ -1,42 +1,44 @@ //! Implement LPN with local linear code. -//! More specifically, a local linear code is a random boolean matrix with at most D non-zero values in each row. +//! More specifically, a local linear code is a random boolean matrix with at +//! most D non-zero values in each row. use crate::{prp::Prp, Block}; -use rand::{seq::SliceRandom, thread_rng}; +use rand::Rng; use rayon::prelude::*; + /// An LPN encoder. /// -/// The `seed` defines a sparse binary matrix `A` with at most `D` non-zero values in each row. +/// The `seed` defines a sparse binary matrix `A` with at most `D` non-zero +/// values in each row. /// /// Given a vector `x` and `e`, compute `y = Ax + e`. /// -/// `A` - is a binary matrix with `k` columns and `n` rows. The concrete number of `n` is determined by the input length. `A` will be generated on-the-fly. +/// `A` - is a binary matrix with `k` columns and `n` rows. The concrete number +/// of `n` is determined by the input length. `A` will be generated on-the-fly. /// /// `x` - is a `F_{2^128}` vector with length `k`. /// /// `e` - is a `F_{2^128}` vector with length `n`. /// -/// Note that in the standard LPN problem, `x` is a binary vector, `e` is a sparse binary vector. The way we defined here is a more generic way in term of computing `y`. +/// Note that in the standard LPN problem, `x` is a binary vector, `e` is a +/// sparse binary vector. The way we defined here is a more generic way in term +/// of computing `y`. pub struct LpnEncoder { - /// The seed to generate the random sparse matrix A. - seed: Block, - /// The length of the secret, i.e., x. k: u32, - /// A mask to optimize reduction operation. mask: u32, } impl LpnEncoder { /// Create a new LPN instance. - pub fn new(seed: Block, k: u32) -> Self { + pub fn new(k: u32) -> Self { let mut mask = 1; while mask < k { mask <<= 1; mask |= 0x1; } - Self { seed, k, mask } + Self { k, mask } } /// Compute 4 rows as a batch, this is for the `compute` function. @@ -82,16 +84,17 @@ impl LpnEncoder { /// /// # Arguments /// + /// * `seed` - The seed for PRP. /// * `x` - Secret vector with length `k`. /// * `y` - Error vector with length `n`, this is actually `e` in LPN. /// /// # Panics /// /// Panics if `x.len() !=k` or `y.len() != n`. - pub fn compute(&self, y: &mut [Block], x: &[Block]) { + pub fn compute(&self, seed: Block, y: &mut [Block], x: &[Block]) { assert_eq!(x.len() as u32, self.k); assert!(x.len() >= D); - let prp = Prp::new(self.seed); + let prp = Prp::new(seed); let size = y.len() - (y.len() % 4); cfg_if::cfg_if! { @@ -112,54 +115,68 @@ impl LpnEncoder { } } +/// LPN type. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LpnType { + /// LPN with a uniform error vector. + Uniform, + /// LPN with an error vector that has non-zero entries distributed + /// regularly. + Regular, +} + /// Lpn paramters #[derive(Copy, Clone, Debug)] pub struct LpnParameters { - /// The length of output vecotrs. + /// Length of the output vector. pub n: usize, - /// The length of the secret vector + /// Length of the secret vector. pub k: usize, - /// The Hamming Weight of error vectors + /// Hamming weight of the error vector. pub t: usize, } +/// Samples indices for non-zero entries in the error vector. +/// +/// # Panics +/// +/// Panics if `ty` is `Regular` and `len` is not a multiple of `count`. +/// +/// # Arguments +/// +/// * `rng` - Random number generator. +/// * `ty` - LPN type. +/// * `len` - Length of the error vector. +/// * `count` - Hamming weight. +pub fn sample_error_indices( + rng: &mut R, + ty: LpnType, + len: usize, + count: usize, +) -> Vec { + match ty { + LpnType::Uniform => rand::seq::index::sample(rng, len, count).into_vec(), + LpnType::Regular => { + assert_eq!(len % count, 0); + let step = len / count; + (0..count) + .map(|i| rng.gen_range(i * step..(i + 1) * step)) + .collect() + } + } +} + impl LpnParameters { /// Create a new LpnParameters instance. pub fn new(n: usize, k: usize, t: usize) -> Self { assert!(t <= n); LpnParameters { n, k, t } } - - /// Sample a uniform error vector with HW t. - pub fn sample_uniform_error_vector(&self) -> Vec { - let one: Block = bytemuck::cast(1_u128); - let mut res = vec![Block::ZERO; self.n]; - res[0..self.t].iter_mut().for_each(|x| *x = one); - let mut rng = thread_rng(); - res.shuffle(&mut rng); - res - } - - /// Sample a regular error vector with HW t - pub fn sample_regular_error_vector(&self) -> Vec { - assert_eq!(self.n % self.t, 0); - let one: Block = bytemuck::cast(1_u128); - let mut res = vec![Block::ZERO; self.n]; - let mut rng = thread_rng(); - - res.chunks_exact_mut(self.n / self.t).for_each(|x| { - x[0] = one; - x.shuffle(&mut rng); - }); - res - } } #[cfg(test)] mod tests { - use crate::lpn::LpnEncoder; - use crate::prp::Prp; - use crate::Block; + use crate::{lpn::LpnEncoder, prp::Prp, Block}; impl LpnEncoder { #[allow(dead_code)] @@ -185,10 +202,10 @@ mod tests { } #[allow(dead_code)] - pub(crate) fn compute_naive(&self, y: &mut [Block], x: &[Block]) { + pub(crate) fn compute_naive(&self, seed: Block, y: &mut [Block], x: &[Block]) { assert_eq!(x.len() as u32, self.k); assert!(x.len() >= D); - let prp = Prp::new(self.seed); + let prp = Prp::new(seed); let batch_size = y.len() / 4; for i in 0..batch_size { @@ -203,13 +220,11 @@ mod tests { #[test] fn lpn_test() { - use crate::lpn::LpnEncoder; - use crate::prg::Prg; - use crate::Block; + use crate::{lpn::LpnEncoder, prg::Prg, Block}; let k = 20; let n = 200; - let lpn = LpnEncoder::<10>::new(Block::ZERO, k); + let lpn = LpnEncoder::<10>::new(k); let mut x = vec![Block::ONES; k as usize]; let mut y = vec![Block::ONES; n]; let mut prg = Prg::new(); @@ -217,8 +232,8 @@ mod tests { prg.random_blocks(&mut y); let mut z = y.clone(); - lpn.compute_naive(&mut y, &x); - lpn.compute(&mut z, &x); + lpn.compute_naive(Block::ZERO, &mut y, &x); + lpn.compute(Block::ZERO, &mut z, &x); assert_eq!(y, z); } diff --git a/crates/mpz-core/src/tkprp.rs b/crates/mpz-core/src/tkprp.rs index 0760aed6..2fedafa3 100644 --- a/crates/mpz-core/src/tkprp.rs +++ b/crates/mpz-core/src/tkprp.rs @@ -14,75 +14,37 @@ impl TwoKeyPrp { Self([AesEncryptor::new(seeds[0]), AesEncryptor::new(seeds[1])]) } - /// expand 1 to 2 - #[inline(always)] - pub(crate) fn expand_1to2(&self, children: &mut [Block], parent: Block) { - children[0] = parent; - children[1] = parent; - AesEncryptor::para_encrypt::<2, 1>(&self.0, children); - children[0] ^= parent; - children[1] ^= parent; - } - - /// expand 2 to 4 - // p[0] p[1] - // c[0] c[1] c[2] c[3] - // t[0] t[2] t[1] t[3] - #[inline(always)] - pub(crate) fn expand_2to4(&self, children: &mut [Block], parent: &[Block]) { - let mut tmp = [Block::ZERO; 4]; - children[3] = parent[1]; - children[2] = parent[1]; - children[1] = parent[0]; - children[0] = parent[0]; - - tmp[3] = parent[1]; - tmp[1] = parent[1]; - tmp[2] = parent[0]; - tmp[0] = parent[0]; - - AesEncryptor::para_encrypt::<2, 2>(&self.0, &mut tmp); - - children[3] ^= tmp[3]; - children[2] ^= tmp[1]; - children[1] ^= tmp[2]; - children[0] ^= tmp[0]; - } - - /// expand 4 to 8 - // p[0] p[1] p[2] p[3] - // c[0] c[1] c[2] c[3] c[4] c[5] c[6] c[7] - // t[0] t[4] t[1] t[5] t[2] t[6] t[3] t[7] - #[inline(always)] - pub(crate) fn expand_4to8(&self, children: &mut [Block], parent: &[Block]) { - let mut tmp = [Block::ZERO; 8]; - children[7] = parent[3]; - children[6] = parent[3]; - children[5] = parent[2]; - children[4] = parent[2]; - children[3] = parent[1]; - children[2] = parent[1]; - children[1] = parent[0]; - children[0] = parent[0]; - - tmp[7] = parent[3]; - tmp[3] = parent[3]; - tmp[6] = parent[2]; - tmp[2] = parent[2]; - tmp[5] = parent[1]; - tmp[1] = parent[1]; - tmp[4] = parent[0]; - tmp[0] = parent[0]; - - AesEncryptor::para_encrypt::<2, 4>(&self.0, &mut tmp); - - children[7] ^= tmp[7]; - children[6] ^= tmp[3]; - children[5] ^= tmp[6]; - children[4] ^= tmp[2]; - children[3] ^= tmp[5]; - children[2] ^= tmp[1]; - children[1] ^= tmp[4]; - children[0] ^= tmp[0]; + /// Expands inputs to the destination slice. + /// + /// Outputs are written to the destination slice in the same order as the + /// inputs. + /// + /// # Panics + /// + /// Panics if the destination slice is not twice the length of the input + /// slice. + /// + /// # Arguments + /// + /// * `inputs` - The input blocks to expand with the two-key PRP. + /// * `dest` - The destination slice to write the expanded blocks. + pub fn expand(&self, inputs: &[Block], dest: &mut [Block]) { + assert_eq!( + dest.len(), + inputs.len() * 2, + "dest should have twice the length of inputs" + ); + + inputs + .iter() + .zip(dest.chunks_exact_mut(2)) + .for_each(|(input, dest)| { + dest[1] = *input; + dest[0] = *input; + self.0[1].encrypt_block_inplace(&mut dest[1]); + self.0[0].encrypt_block_inplace(&mut dest[0]); + dest[1] ^= *input; + dest[0] ^= *input; + }); } } diff --git a/crates/mpz-core/src/utils.rs b/crates/mpz-core/src/utils.rs index f7cc53cc..ff06065d 100644 --- a/crates/mpz-core/src/utils.rs +++ b/crates/mpz-core/src/utils.rs @@ -6,3 +6,25 @@ pub fn blake3(data: &[u8]) -> [u8; 32] { hasher.update(data); hasher.finalize().into() } + +/// Returns non-overlapping slices of the given lengths. +pub fn slices_from_lengths<'a, T>(mut src: &'a [T], lengths: &[usize]) -> Vec<&'a [T]> { + let mut slices = Vec::with_capacity(lengths.len()); + for &length in lengths { + let (head, tail) = src.split_at(length); + slices.push(head); + src = tail; + } + slices +} + +/// Returns non-overlapping mutable slices of the given lengths. +pub fn slices_from_lengths_mut<'a, T>(mut src: &'a mut [T], lengths: &[usize]) -> Vec<&'a mut [T]> { + let mut slices = Vec::with_capacity(lengths.len()); + for &length in lengths { + let (head, tail) = src.split_at_mut(length); + slices.push(head); + src = tail; + } + slices +} diff --git a/crates/mpz-ot-core/Cargo.toml b/crates/mpz-ot-core/Cargo.toml index fe735622..e6303a21 100644 --- a/crates/mpz-ot-core/Cargo.toml +++ b/crates/mpz-ot-core/Cargo.toml @@ -24,7 +24,7 @@ tlsn-utils.workspace = true aes.workspace = true ctr.workspace = true -blake3.workspace = true +blake3 = { workspace = true, features = ["serde"] } cipher.workspace = true rand.workspace = true rand_core.workspace = true @@ -40,6 +40,7 @@ cfg-if.workspace = true bytemuck = { workspace = true, features = ["derive"] } enum-try-as-inner.workspace = true futures = { workspace = true } +tokio = { workspace = true, default-features = false, features = ["sync"] } [dev-dependencies] rstest.workspace = true diff --git a/crates/mpz-ot-core/benches/ot.rs b/crates/mpz-ot-core/benches/ot.rs index f3f770a4..58592c12 100644 --- a/crates/mpz-ot-core/benches/ot.rs +++ b/crates/mpz-ot-core/benches/ot.rs @@ -1,8 +1,11 @@ -use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use itybity::ToBits; -use mpz_core::Block; +use mpz_core::{lpn::LpnType, Block}; use mpz_ot_core::{ - chou_orlandi, kos, + chou_orlandi, + ferret::{self, FerretConfig}, + ideal::rcot::IdealRCOT, + kos, ot::{OTReceiver, OTSender}, rcot::{RCOTReceiver, RCOTSender}, }; @@ -12,6 +15,7 @@ use rand_chacha::ChaCha12Rng; fn chou_orlandi(c: &mut Criterion) { let mut group = c.benchmark_group("chou_orlandi"); for n in [128, 256, 1024] { + group.throughput(Throughput::Elements(n as u64)); group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &n| { let msgs = vec![[Block::ONES; 2]; n]; let mut rng = ChaCha12Rng::seed_from_u64(0); @@ -39,6 +43,7 @@ fn chou_orlandi(c: &mut Criterion) { fn kos(c: &mut Criterion) { let mut group = c.benchmark_group("kos"); for n in [1024, 262144] { + group.throughput(Throughput::Elements(n as u64)); group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &n| { let mut rng = ChaCha12Rng::seed_from_u64(0); let delta = Block::random(&mut rng); @@ -77,6 +82,55 @@ fn kos(c: &mut Criterion) { } } +fn ferret(c: &mut Criterion) { + let mut group = c.benchmark_group("ferret"); + for ty in [LpnType::Uniform, LpnType::Regular] { + let ty_str = match ty { + LpnType::Uniform => "uniform", + LpnType::Regular => "regular", + }; + + for n in [262144, 1_000_000] { + group.throughput(Throughput::Elements(n as u64)); + group.bench_with_input(BenchmarkId::new(ty_str, n), &n, |b, &n| { + let mut rng = ChaCha12Rng::seed_from_u64(0); + let delta = Block::random(&mut rng); + + let mut builder = FerretConfig::builder(); + builder.lpn_type(ty); + let config = builder.build().unwrap(); + + b.iter(|| { + let cot = IdealRCOT::new(rng.gen(), delta); + let mut sender = ferret::Sender::new(rng.gen(), config.clone(), cot.clone()); + let mut receiver = ferret::Receiver::new(rng.gen(), config.clone(), cot); + + let init = receiver.initialize().unwrap(); + sender.initialize(init).unwrap(); + sender.alloc_bootstrap().unwrap(); + receiver.alloc_bootstrap().unwrap(); + sender.acquire_cot().flush().unwrap(); + receiver.acquire_cot().flush().unwrap(); + sender.alloc(n).unwrap(); + receiver.alloc(n).unwrap(); + + while sender.wants_extend() && receiver.wants_extend() { + sender.start_extend().unwrap(); + let msg = receiver.start_extend().unwrap(); + let msg = sender.extend(msg).unwrap(); + let msg = receiver.extend(msg).unwrap(); + let msg = sender.check(msg).unwrap(); + receiver.finish_extend(msg).unwrap(); + sender.finish_extend().unwrap(); + } + + black_box((sender, receiver)); + }) + }); + } + } +} + criterion_group! { name = chou_orlandi_benches; config = Criterion::default().sample_size(50); @@ -89,4 +143,10 @@ criterion_group! { targets = kos } -criterion_main!(chou_orlandi_benches, kos_benches); +criterion_group! { + name = ferret_benches; + config = Criterion::default().sample_size(10); + targets = ferret +} + +criterion_main!(chou_orlandi_benches, kos_benches, ferret_benches); diff --git a/crates/mpz-ot-core/src/cot/derandomize.rs b/crates/mpz-ot-core/src/cot/derandomize.rs index 8f60a0d7..b16916e1 100644 --- a/crates/mpz-ot-core/src/cot/derandomize.rs +++ b/crates/mpz-ot-core/src/cot/derandomize.rs @@ -28,7 +28,8 @@ struct QueuedSend { #[derive(Debug)] pub struct DerandCOTSender { rcot: T, - /// Keys corresponding to the value 0 of the choice bits which need to be derandomized. + /// Keys corresponding to the value 0 of the choice bits which need to be + /// derandomized. adjust: Vec, queue: VecDeque, } @@ -167,7 +168,7 @@ struct QueuedReceive { pub struct DerandCOTReceiver { rcot: T, /// Choice bits from RCOT which need to be derandomized. - derandomize: BitVec, + derandomize: BitVec, queue: VecDeque, } diff --git a/crates/mpz-ot-core/src/ferret.rs b/crates/mpz-ot-core/src/ferret.rs new file mode 100644 index 00000000..0f1e0375 --- /dev/null +++ b/crates/mpz-ot-core/src/ferret.rs @@ -0,0 +1,124 @@ +//! An implementation of the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) protocol. + +mod config; +pub(crate) mod cuckoo; +pub(crate) mod mpcot; +mod receiver; +mod sender; +pub(crate) mod spcot; + +pub use config::{FerretConfig, FerretConfigBuilder, FerretConfigBuilderError}; +pub use receiver::{Receiver, ReceiverError}; +pub use sender::{Sender, SenderError}; + +use blake3::Hash; +use mpz_core::Block; +use serde::{Deserialize, Serialize}; + +use crate::Derandomize; + +/// Initialize message sent from receiver to sender. +#[derive(Debug, Serialize, Deserialize)] +pub struct Init { + seed: Block, +} + +/// Extend message sent from sender to receiver. +#[derive(Debug, Serialize, Deserialize)] +pub struct SenderExtend { + ms: Vec<[Block; 2]>, + sums: Vec, +} + +/// Check message sent from sender to receiver. +#[derive(Debug, Serialize, Deserialize)] +pub struct SenderCheck { + hashed_v: Hash, +} + +/// Extend message sent from receiver to sender. +#[derive(Debug, Serialize, Deserialize)] +pub struct ReceiverExtend { + derandomize: Derandomize, +} + +/// Check message sent from receiver to sender. +#[derive(Debug, Serialize, Deserialize)] +pub struct ReceiverCheck { + derandomize: Derandomize, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + ferret::config::TEST_PARAMS, + ideal::rcot::IdealRCOT, + rcot::{RCOTReceiver, RCOTReceiverOutput, RCOTSender, RCOTSenderOutput}, + test::assert_cot, + }; + use mpz_core::lpn::LpnType; + use rand::{rngs::StdRng, SeedableRng}; + use rstest::*; + + #[rstest] + #[case::uniform(LpnType::Uniform)] + #[case::regular(LpnType::Regular)] + fn test_ferret(#[case] lpn_type: LpnType) { + use rand::Rng; + + let mut rng = StdRng::seed_from_u64(0); + let delta = rng.gen(); + let cot = IdealRCOT::new(rng.gen(), delta); + + let mut builder = FerretConfig::builder(); + + builder.lpn_type(lpn_type); + builder.param_selector(|_, _, _| TEST_PARAMS); + + let config = builder.build().unwrap(); + let count = TEST_PARAMS.n * 2; + + let mut sender = Sender::new(rng.gen(), config.clone(), cot.clone()); + let mut receiver = Receiver::new(rng.gen(), config, cot); + + assert!(sender.wants_init()); + assert!(receiver.wants_init()); + + let init = receiver.initialize().unwrap(); + sender.initialize(init).unwrap(); + + assert!(!sender.wants_init()); + assert!(!receiver.wants_init()); + + assert!(sender.wants_bootstrap()); + assert!(receiver.wants_bootstrap()); + + sender.alloc_bootstrap().unwrap(); + receiver.alloc_bootstrap().unwrap(); + + sender.acquire_cot().flush().unwrap(); + receiver.acquire_cot().flush().unwrap(); + + sender.alloc(count).unwrap(); + receiver.alloc(count).unwrap(); + + while sender.wants_extend() && receiver.wants_extend() { + sender.start_extend().unwrap(); + let msg = receiver.start_extend().unwrap(); + let msg = sender.extend(msg).unwrap(); + let msg = receiver.extend(msg).unwrap(); + let msg = sender.check(msg).unwrap(); + receiver.finish_extend(msg).unwrap(); + sender.finish_extend().unwrap(); + } + + assert!(!sender.wants_extend()); + assert!(!receiver.wants_extend()); + + let RCOTSenderOutput { keys, .. } = sender.try_send_rcot(count).unwrap(); + let RCOTReceiverOutput { choices, msgs, .. } = receiver.try_recv_rcot(count).unwrap(); + + assert_cot(delta, &choices, &keys, &msgs); + } +} diff --git a/crates/mpz-ot-core/src/ferret/config.rs b/crates/mpz-ot-core/src/ferret/config.rs new file mode 100644 index 00000000..7ffe10f5 --- /dev/null +++ b/crates/mpz-ot-core/src/ferret/config.rs @@ -0,0 +1,188 @@ +use std::{fmt::Debug, sync::Arc}; + +use derive_builder::Builder; + +use mpz_core::lpn::{LpnParameters, LpnType}; + +use crate::ferret::cuckoo::HASH_NUM; + +/// Computational security parameter. +pub(crate) const CSP: usize = 128; + +#[cfg(test)] +pub(crate) const TEST_PARAMS: LpnParameters = LpnParameters { + n: 9600, + k: 1220, + t: 600, +}; + +/// Ferret configuration. +#[derive(Clone, Builder)] +pub struct FerretConfig { + /// LPN type. + #[builder(default = "LpnType::Uniform")] + lpn_type: LpnType, + /// Whether to reserve bootstrap COTs. + #[builder(default = "true")] + reserve_bootstrap: bool, + #[builder(setter(custom), default = "Arc::new(default_parameter_selector)")] + param_selector: Arc LpnParameters + Send + Sync + 'static>, +} + +impl Debug for FerretConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FerretConfig") + .field("lpn_type", &self.lpn_type) + .field("reserve_bootstrap", &self.reserve_bootstrap) + .finish_non_exhaustive() + } +} + +impl Default for FerretConfig { + fn default() -> Self { + Self { + lpn_type: LpnType::Uniform, + reserve_bootstrap: true, + param_selector: Arc::new(default_parameter_selector), + } + } +} + +impl FerretConfigBuilder { + /// Configures the LPN parameter selector. + /// + /// The provided function must have the following signature: + /// + /// `(LpnType, available, additional) -> LpnParameters` + /// + /// where `available` is the current number of available COTs and + /// `additional` is the number of COTs that still need to be generated. + pub fn param_selector(&mut self, f: F) -> &mut Self + where + F: Fn(LpnType, usize, usize) -> LpnParameters + Send + Sync + 'static, + { + self.param_selector = Some(Arc::new(f)); + self + } +} + +impl FerretConfig { + /// Returns a new `FerretConfigBuilder`. + pub fn builder() -> FerretConfigBuilder { + FerretConfigBuilder::default() + } + + /// Returns `true` if bootstrap COTs should be reserved. + pub fn reserve_bootstrap(&self) -> bool { + self.reserve_bootstrap + } + + /// Returns the LPN type. + pub fn lpn_type(&self) -> LpnType { + self.lpn_type + } + + /// Returns the cost of a bootstrap iteration. + pub(crate) fn bootstrap_cost(&self) -> usize { + match self.lpn_type { + LpnType::Uniform => iteration_cost(self.lpn_type, UNIFORM_PARAMS[0]), + LpnType::Regular => iteration_cost(self.lpn_type, REGULAR_PARAMS[0]), + } + } + + pub(crate) fn select_params(&self, available: usize, additional: usize) -> LpnParameters { + (self.param_selector)(self.lpn_type, available, additional) + } +} + +fn default_parameter_selector(ty: LpnType, available: usize, additional: usize) -> LpnParameters { + let params = match ty { + LpnType::Uniform => UNIFORM_PARAMS, + LpnType::Regular => REGULAR_PARAMS, + }; + + // *Assumes the parameters are in ascending order.* + for param in params { + let cost = iteration_cost(ty, *param); + let net = param.t - cost; + // If we don't have enough available we select the smallest parameters + // immediately. + if available <= cost { + return *param; + } else if net >= additional { + return *param; + } + } + + // If we reach here, we select the largest parameters. + *params.last().unwrap() +} + +/// Returns the number of COTs needed to execute an iteration with the given +/// parameters. +fn iteration_cost(ty: LpnType, params: LpnParameters) -> usize { + match ty { + // The number here is a rough estimation to ensure sufficient buffer. + // It is hard to precisely compute the number because of the Cuckoo hashes. + LpnType::Uniform => { + let m = (1.5 * (params.t as f32)).ceil() as usize; + m * ((2 * HASH_NUM as usize * params.n / m) + .checked_next_power_of_two() + .expect("The length should be less than usize::MAX / 2 - 1") + .ilog2() as usize) + + params.k + + CSP + } + // In our chosen paramters, we always set n is divided by t and n/t is a power of 2. + LpnType::Regular => { + assert!(params.n % params.t == 0 && (params.n / params.t).is_power_of_two()); + params.t * ((params.n / params.t).ilog2() as usize) + params.k + CSP + } + } +} + +static UNIFORM_PARAMS: &[LpnParameters] = &[ + LpnParameters { + n: 545_656, + k: 34_643, + t: 1_050, + }, + LpnParameters { + n: 1_071_888, + k: 40_800, + t: 1720, + }, + LpnParameters { + n: 5_324_800, + k: 240_000, + t: 1_300, + }, + LpnParameters { + n: 10_488_928, + k: 458_000, + t: 1280, + }, +]; + +static REGULAR_PARAMS: &[LpnParameters] = &[ + LpnParameters { + n: 518_656, + k: 34_643, + t: 1_013, + }, + LpnParameters { + n: 1_740_800, + k: 66_400, + t: 1700, + }, + LpnParameters { + n: 5_324_800, + k: 240_000, + t: 1_300, + }, + LpnParameters { + n: 10_485_760, + k: 458_000, + t: 1280, + }, +]; diff --git a/crates/mpz-ot-core/src/ferret/cuckoo.rs b/crates/mpz-ot-core/src/ferret/cuckoo.rs new file mode 100644 index 00000000..5398f667 --- /dev/null +++ b/crates/mpz-ot-core/src/ferret/cuckoo.rs @@ -0,0 +1,192 @@ +//! Implementation of Cuckoo hash. + +use std::array::from_fn; + +use mpz_core::{aes::AesEncryptor, Block}; + +pub(crate) const HASH_NUM: u32 = 3; +const TRIAL_NUM: usize = 100; + +/// Bucket index in the table. +type BucketIdx = usize; +/// Position of item in the bucket. +type BucketPos = usize; + +/// Cuckoo hash insertion error +#[derive(Debug, thiserror::Error)] +#[error("cycle detected in Cuckoo hashing")] +pub(crate) struct CuckooHashError; + +/// Item in Cuckoo hash table. +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +pub(crate) struct Item { + /// Value in the table. + pub(crate) value: u32, + /// Which hash function is used. + pub(crate) hash_idx: u32, +} + +/// Implementation of Cuckoo hash. See [here](https://eprint.iacr.org/2019/1084.pdf) for reference. +pub(crate) struct CuckooHash { + table: Vec>, +} + +impl CuckooHash { + /// Creates a Cuckoo hash table from the provided hashes and items. + pub(crate) fn new( + hashes: &[AesEncryptor; HASH_NUM as usize], + items: &[usize], + ) -> Result { + // Always sets m = 1.5 * t. t is the length of `items`. + let m = compute_table_length(items.len()); + + // Allocates table. + let mut table = vec![None; m]; + // Inserts each item. + for &value in items { + Self::hash(hashes, &mut table, value as u32)?; + } + + Ok(Self { table }) + } + + /// Returns an iterator over the table. + pub(crate) fn iter(&self) -> impl Iterator> { + self.table.iter() + } + + // Hash an element to a position with the current hash function. + #[inline] + fn hash( + hashes: &[AesEncryptor; HASH_NUM as usize], + table: &mut [Option], + value: u32, + ) -> Result<(), CuckooHashError> { + // The item consists of the value and hash index, starting from 0. + let mut item = Item { value, hash_idx: 0 }; + + for _ in 0..TRIAL_NUM { + // Computes the position of the value. + let pos = hash_to_index(&hashes[item.hash_idx as usize], table.len(), item.value); + + // Inserts the value to position `pos`. + let opt_item = table[pos].replace(item); + + // If position `pos` is not empty before the above insertion, iteratively + // inserts the obtained value. + if let Some(x) = opt_item { + item = x; + item.hash_idx = (item.hash_idx + 1) % HASH_NUM; + } else { + // If no value assigned to position `pos`, end the process. + return Ok(()); + } + } + Err(CuckooHashError) + } +} + +/// Implementation of Bucket. See step 3 in Figure 7. +#[derive(Debug)] +pub(crate) struct Buckets { + buckets: Vec, + /// Maps an index to the buckets it is in and the corresponding position in + /// each bucket. + items: Vec<[(BucketIdx, BucketPos); HASH_NUM as usize]>, +} + +impl Buckets { + /// Creates new buckets. + /// + /// # Arguments + /// + /// * `hashes` - Cuckoo hash functions. + /// * `count` - Number of indices that will be queried. + /// * `domain` - Domain of the indices, ie [0, len). + pub(crate) fn new( + hashes: &[AesEncryptor; HASH_NUM as usize], + count: usize, + domain: usize, + ) -> Self { + let m = compute_table_length(count); + + // NOTE: the sorted step in Step 3.c can be removed. + + let mut buckets = vec![0; m]; + let mut items = Vec::with_capacity(domain); + for value in 0..domain as u32 { + items.push(from_fn(|hash_idx| { + let hash = &hashes[hash_idx]; + let bucket_idx = hash_to_index(hash, m, value); + let pos = buckets[bucket_idx]; + buckets[bucket_idx] += 1; + (bucket_idx, pos) + })); + } + + Self { buckets, items } + } + + /// Returns the buckets and positions for the given index. + pub(crate) fn get(&self, idx: usize) -> &[(usize, usize); HASH_NUM as usize] { + &self.items[idx] + } + + /// Returns an iterator over the bucket lengths. + pub(crate) fn iter_buckets(&self) -> impl Iterator + '_ { + self.buckets.iter().copied() + } + + /// Returns an iterator over the bucket indices and item positions. + #[inline] + pub(crate) fn iter_items( + &self, + ) -> impl Iterator + '_ { + self.items.iter() + } +} + +// Always sets m = 1.5 * t. t is the length of `alphas`. See Section 7.1 +// Parameter Selection. +#[inline(always)] +fn compute_table_length(t: usize) -> usize { + (1.5 * (t as f32)).ceil() as usize +} + +// Hash the value into index using AES. +#[inline(always)] +fn hash_to_index(hash: &AesEncryptor, range: usize, value: u32) -> usize { + let mut blk: Block = bytemuck::cast::<_, Block>(value as u128); + blk = hash.encrypt_block(blk); + let res = u128::from_le_bytes(blk.to_bytes()); + (res as usize) % range +} + +#[cfg(test)] +mod tests { + use super::*; + use mpz_core::aes::AesEncryptor; + use rand::{rngs::StdRng, Rng, SeedableRng}; + + #[test] + fn test_cuckoo_buckets() { + let mut rng = StdRng::seed_from_u64(0); + const NUM: usize = 50; + + let hashes = from_fn(|_| AesEncryptor::new(rng.gen())); + + let input: [usize; NUM] = std::array::from_fn(|i| i); + let cuckoo = CuckooHash::new(&hashes, &input).unwrap(); + let buckets = Buckets::new(&hashes, NUM, 2 * NUM); + + for (bucket_idx, item) in cuckoo.table.iter().enumerate() { + if let Some(item) = item { + // Assert this item is in the corresponding bucket. + assert_eq!( + bucket_idx, + buckets.items[item.value as usize][item.hash_idx as usize].0 + ); + } + } + } +} diff --git a/crates/mpz-ot-core/src/ferret/mpcot.rs b/crates/mpz-ot-core/src/ferret/mpcot.rs new file mode 100644 index 00000000..7793ff4c --- /dev/null +++ b/crates/mpz-ot-core/src/ferret/mpcot.rs @@ -0,0 +1,46 @@ +mod receiver; +mod sender; + +pub(crate) use receiver::{state as receiver_state, MPCOTReceiver, MPCOTReceiverError}; +pub(crate) use sender::{state as sender_state, MPCOTSender, MPCOTSenderError}; + +#[cfg(test)] +mod tests { + use super::*; + use crate::ferret::spcot::spcot; + use mpz_core::lpn::{sample_error_indices, LpnType}; + use rand::{rngs::StdRng, Rng, SeedableRng}; + use rstest::*; + + #[rstest] + #[case::uniform(LpnType::Uniform)] + #[case::regular(LpnType::Regular)] + fn test_mpcot(#[case] lpn_type: LpnType) { + let mut rng = StdRng::seed_from_u64(0); + let delta = rng.gen(); + let cuckoo_seed = rng.gen(); + + let sender = MPCOTSender::new(cuckoo_seed, lpn_type); + let receiver = MPCOTReceiver::new(cuckoo_seed, lpn_type); + + let n = 10; + let indices = sample_error_indices(&mut rng, lpn_type, n, 5); + + let (sender, sender_lengths) = sender.start_extend(indices.len(), 10).unwrap(); + let (receiver, receiver_lengths, receiver_idxs) = + receiver.start_extend(&indices, n).unwrap(); + + assert_eq!(sender_lengths, receiver_lengths); + + let (vs, ws) = spcot(&mut rng, &sender_lengths, &receiver_idxs, delta); + + let sender_output = sender.extend(&vs).unwrap(); + let mut receiver_output = receiver.extend(&ws).unwrap(); + + for idx in indices { + receiver_output[idx] ^= delta; + } + + assert_eq!(sender_output, receiver_output); + } +} diff --git a/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs b/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs new file mode 100644 index 00000000..ce1753ed --- /dev/null +++ b/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs @@ -0,0 +1,249 @@ +use std::array::from_fn; + +use mpz_core::{aes::AesEncryptor, lpn::LpnType, prg::Prg, utils::slices_from_lengths, Block}; +use rand_core::SeedableRng; + +use crate::ferret::cuckoo::{Buckets, CuckooHash, CuckooHashError}; + +type Error = MPCOTReceiverError; +type Result = core::result::Result; + +/// MPCOT receiver. +#[derive(Debug, Default)] +pub(crate) struct MPCOTReceiver { + state: T, +} + +impl MPCOTReceiver { + /// Creates a new Receiver. + /// + /// # Arguments + /// + /// * `seed` - Seed for Cuckoo hashing. + /// * `lpn_type` - The LPN type. + pub(crate) fn new(seed: Block, lpn_type: LpnType) -> Self { + let mut prg = Prg::from_seed(seed); + let hashes = from_fn(|_| AesEncryptor::new(prg.random_block())); + + let state = match lpn_type { + LpnType::Uniform => Initialized::Uniform { hashes }, + LpnType::Regular => Initialized::Regular, + }; + + MPCOTReceiver { state } + } +} + +impl MPCOTReceiver { + /// Starts the MPCOT extension. + /// + /// Returns the SPCOT log2 lengths and indices, respectively. + /// + /// See Step 1 to Step 4 in Figure 7. + /// + /// # Arguments + /// + /// * `idxs` - The queried indices. + /// * `len` - Length of the vector. + pub(crate) fn start_extend( + self, + idxs: &[usize], + len: usize, + ) -> Result<(MPCOTReceiver, Vec, Vec)> { + if idxs.len() > len { + return Err(ErrorRepr::Params { + count: idxs.len(), + len: len as usize, + reason: "indices cannot exceed vector length".to_string(), + } + .into()); + } + + let (state, lengths, idxs) = match self.state { + Initialized::Uniform { hashes } => { + let cuckoo = CuckooHash::new(&hashes, idxs).map_err(ErrorRepr::from)?; + let buckets = Buckets::new(&hashes, idxs.len(), len); + + // Generates queries for SPCOT. + // See Step 4 in Figure 7. + let mut idxs = vec![]; + let mut spcot_log2_lengths = vec![]; + let mut spcot_lengths = vec![]; + for (item, bucket_length) in cuckoo.iter().zip(buckets.iter_buckets()) { + // pad to power of 2. + let power_of_two = (bucket_length + 1) + .checked_next_power_of_two() + .expect("bucket length should be less than usize::MAX / 2 - 1"); + + if let Some(x) = item { + let (_, pos) = buckets.get(x.value as usize)[x.hash_idx as usize]; + + idxs.push(pos); + } else { + idxs.push(bucket_length); + } + + spcot_log2_lengths.push(power_of_two.ilog2() as usize); + spcot_lengths.push(power_of_two); + } + + ( + Extension::Uniform { + len, + buckets, + spcot_lengths, + }, + spcot_log2_lengths, + idxs, + ) + } + Initialized::Regular => { + let count = idxs.len(); + let k = len / count; + if len % count != 0 { + return Err(ErrorRepr::Params { + count, + len, + reason: "len should be a multiple of count".to_string(), + } + .into()); + } else if !k.is_power_of_two() { + return Err(ErrorRepr::Params { + count, + len, + reason: "regular interval length must be a power of two".to_string(), + } + .into()); + } + + if !idxs + .iter() + .enumerate() + .all(|(i, &idx)| i * k <= idx && idx < (i + 1) * k) + { + return Err(ErrorRepr::NotRegular.into()); + } + + let log2_len = k.ilog2() as usize; + let spcot_log2_lengths = (0..count).map(|_| log2_len).collect(); + let idxs = idxs.iter().map(|&idx| idx % k).collect(); + + (Extension::Regular { len }, spcot_log2_lengths, idxs) + } + }; + + Ok((MPCOTReceiver { state }, lengths, idxs)) + } +} +impl MPCOTReceiver { + /// Performs MPCOT extension. + /// + /// See Step 5 in Figure 7. + /// + /// # Arguments + /// + /// * `spcot` - The output of SPCOT. + pub(crate) fn extend(self, ws: &[Block]) -> Result> { + match self.state { + Extension::Uniform { + len, + buckets, + spcot_lengths, + } => { + let spcot_len = spcot_lengths.iter().sum::(); + if ws.len() != spcot_len { + return Err(ErrorRepr::SPCOTLength { + expected: spcot_len, + actual: ws.len(), + } + .into()); + } + + let ws = slices_from_lengths(ws, &spcot_lengths); + let mut res = vec![Block::ZERO; len]; + for (x, &bucket_pos) in res.iter_mut().zip(buckets.iter_items()) { + for (bucket_idx, pos) in bucket_pos { + *x ^= ws[bucket_idx][pos]; + } + } + + Ok(res) + } + Extension::Regular { len } => { + if ws.len() != len { + return Err(ErrorRepr::SPCOTLength { + expected: len, + actual: ws.len(), + } + .into()); + } + + Ok(ws.to_vec()) + } + } + } +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub(crate) struct MPCOTReceiverError(#[from] ErrorRepr); + +#[derive(Debug, thiserror::Error)] +#[error("MPCOT sender error: {0}")] +enum ErrorRepr { + #[error("invalid parameters, count: {count}, len: {len}: {reason}")] + Params { + count: usize, + len: usize, + reason: String, + }, + #[error("input indices are not regular")] + NotRegular, + #[error("invalid length of SPCOT vector, expected: {expected}, actual: {actual}")] + SPCOTLength { expected: usize, actual: usize }, + #[error("cuckoo hash error: {0}")] + Cuckoo(#[from] CuckooHashError), +} + +pub(crate) mod state { + use crate::ferret::cuckoo::HASH_NUM; + + use super::*; + + mod sealed { + pub(crate) trait Sealed {} + + impl Sealed for super::Initialized {} + impl Sealed for super::Extension {} + } + + pub(crate) trait State: sealed::Sealed {} + + pub(crate) enum Initialized { + Uniform { + hashes: [AesEncryptor; HASH_NUM as usize], + }, + Regular, + } + + impl State for Initialized {} + + opaque_debug::implement!(Initialized); + + pub(crate) enum Extension { + Uniform { + len: usize, + buckets: Buckets, + spcot_lengths: Vec, + }, + Regular { + len: usize, + }, + } + + impl State for Extension {} + + opaque_debug::implement!(Extension); +} + +use state::{Extension, Initialized}; diff --git a/crates/mpz-ot-core/src/ferret/mpcot/sender.rs b/crates/mpz-ot-core/src/ferret/mpcot/sender.rs new file mode 100644 index 00000000..9d8aef3c --- /dev/null +++ b/crates/mpz-ot-core/src/ferret/mpcot/sender.rs @@ -0,0 +1,226 @@ +use std::array::from_fn; + +use mpz_core::{aes::AesEncryptor, lpn::LpnType, prg::Prg, utils::slices_from_lengths, Block}; +use rand_core::SeedableRng; + +use crate::ferret::cuckoo::Buckets; + +type Error = MPCOTSenderError; +type Result = core::result::Result; + +/// MPCOT sender. +#[derive(Debug, Default)] +pub(crate) struct MPCOTSender { + state: T, +} + +impl MPCOTSender { + /// Creates a new Sender. + /// + /// # Arguments. + /// + /// * `seed` - Seed for Cuckoo hash sent by the receiver. + /// * `lpn_type` - The LPN type. + pub(crate) fn new(seed: Block, lpn_type: LpnType) -> Self { + let state = match lpn_type { + LpnType::Uniform => { + let mut prg = Prg::from_seed(seed); + Initialized::Uniform { + hashes: from_fn(|_| AesEncryptor::new(prg.random_block())), + } + } + LpnType::Regular => Initialized::Regular, + }; + + MPCOTSender { state } + } +} + +impl MPCOTSender { + /// Starts the MPCOT extension. + /// + /// Returns the SPCOT log2 lengths. + /// + /// See Step 1 to Step 4 in Figure 7. + /// + /// # Arguments + /// + /// * `count` - Number of queried indices. + /// * `len` - Length of the vector. + pub(crate) fn start_extend( + self, + count: usize, + len: usize, + ) -> Result<(MPCOTSender, Vec)> { + if count > len { + return Err(ErrorRepr::Params { + count, + len, + reason: "indices cannot exceed vector length".to_string(), + } + .into()); + } + + let (state, bs) = match self.state { + Initialized::Uniform { hashes } => { + let buckets = Buckets::new(&hashes, count, len); + + // First pad (length + 1) to a pow of 2, then computes `log(length + 1)` of each + // bucket. + let mut bs = vec![]; + let mut spcot_lengths = vec![]; + for len in buckets.iter_buckets() { + let power_of_two = (len + 1) + .checked_next_power_of_two() + .expect("bucket length should be less than usize::MAX / 2 - 1"); + + bs.push(power_of_two.ilog2() as usize); + spcot_lengths.push(power_of_two); + } + + ( + Extension::Uniform { + len, + buckets, + spcot_lengths, + }, + bs, + ) + } + Initialized::Regular => { + let k = len / count; + if len % count != 0 { + return Err(ErrorRepr::Params { + count, + len, + reason: "len should be a multiple of count".to_string(), + } + .into()); + } else if !k.is_power_of_two() { + return Err(ErrorRepr::Params { + count, + len, + reason: "regular interval length must be a power of two".to_string(), + } + .into()); + } + + let log2_len = k.ilog2() as usize; + let spcot_log2_lengths = (0..count).map(|_| log2_len).collect(); + + (Extension::Regular { len }, spcot_log2_lengths) + } + }; + + Ok((MPCOTSender { state }, bs)) + } +} + +impl MPCOTSender { + /// Performs MPCOT extension. + /// + /// See Step 5 in Figure 7. + /// + /// # Arguments + /// + /// * `spcot` - The output of SPCOT. + pub(crate) fn extend(self, vs: &[Block]) -> Result> { + match self.state { + Extension::Uniform { + len, + buckets, + spcot_lengths, + } => { + let spcot_len = spcot_lengths.iter().sum::(); + if vs.len() != spcot_len { + return Err(ErrorRepr::SPCOTLength { + expected: spcot_len, + actual: vs.len(), + } + .into()); + } + + let vs = slices_from_lengths(vs, &spcot_lengths); + let mut res = vec![Block::ZERO; len]; + for (x, &bucket_pos) in res.iter_mut().zip(buckets.iter_items()) { + for (bucket_idx, pos) in bucket_pos { + *x ^= vs[bucket_idx][pos]; + } + } + + Ok(res) + } + Extension::Regular { len } => { + if vs.len() != len { + return Err(ErrorRepr::SPCOTLength { + expected: len, + actual: vs.len(), + } + .into()); + } + + Ok(vs.to_vec()) + } + } + } +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub(crate) struct MPCOTSenderError(#[from] ErrorRepr); + +#[derive(Debug, thiserror::Error)] +#[error("MPCOT sender error: {0}")] +enum ErrorRepr { + #[error("invalid parameters, count: {count}, len: {len}: {reason}")] + Params { + count: usize, + len: usize, + reason: String, + }, + #[error("invalid length of SPCOT vector, expected: {expected}, actual: {actual}")] + SPCOTLength { expected: usize, actual: usize }, +} + +pub(crate) mod state { + use crate::ferret::cuckoo::HASH_NUM; + + use super::*; + + mod sealed { + pub(crate) trait Sealed {} + + impl Sealed for super::Initialized {} + impl Sealed for super::Extension {} + } + + pub(crate) trait State: sealed::Sealed {} + + pub(crate) enum Initialized { + Uniform { + hashes: [AesEncryptor; HASH_NUM as usize], + }, + Regular, + } + + impl State for Initialized {} + + opaque_debug::implement!(Initialized); + + pub(crate) enum Extension { + Uniform { + len: usize, + buckets: Buckets, + spcot_lengths: Vec, + }, + Regular { + len: usize, + }, + } + + impl State for Extension {} + + opaque_debug::implement!(Extension); +} + +use state::{Extension, Initialized}; diff --git a/crates/mpz-ot-core/src/ferret/receiver.rs b/crates/mpz-ot-core/src/ferret/receiver.rs new file mode 100644 index 00000000..d66c1f23 --- /dev/null +++ b/crates/mpz-ot-core/src/ferret/receiver.rs @@ -0,0 +1,441 @@ +use std::{collections::VecDeque, sync::Arc}; + +use rand::{Rng, SeedableRng}; +use tokio::sync::{Mutex, OwnedMutexGuard}; + +use mpz_common::future::{new_output, MaybeDone, Sender as OutputSender}; +use mpz_core::{ + lpn::{sample_error_indices, LpnEncoder, LpnParameters}, + prg::Prg, + Block, +}; + +use crate::{ + ferret::{ + config::CSP, + mpcot::{receiver_state as mpcot_state, MPCOTReceiver, MPCOTReceiverError}, + spcot::{SPCOTReceiver, SPCOTReceiverError}, + FerretConfig, Init, ReceiverCheck, ReceiverExtend, SenderCheck, SenderExtend, + }, + rcot::{RCOTReceiver, RCOTReceiverOutput}, + TransferId, +}; + +type Error = ReceiverError; +type Result = core::result::Result; + +#[derive(Debug)] +struct Queued { + count: usize, + sender: OutputSender>, +} + +/// Ferret receiver. +#[derive(Debug)] +pub struct Receiver { + cot: Arc>, + alloc: usize, + queue: VecDeque, + transfer_id: TransferId, + prg: Prg, + config: FerretConfig, + macs: Vec, + choices: Vec, + state: State, + spcot: SPCOTReceiver, +} + +impl Receiver +where + COT: RCOTReceiver, +{ + /// Creates a new receiver. + pub fn new(seed: Block, config: FerretConfig, cot: COT) -> Self { + Self { + cot: Arc::new(Mutex::new(cot)), + alloc: 0, + queue: VecDeque::new(), + transfer_id: TransferId::default(), + prg: Prg::from_seed(seed), + config, + macs: Vec::new(), + choices: Vec::new(), + state: State::Init, + spcot: SPCOTReceiver::new(), + } + } + + /// Returns a lock on the inner COT sender. + pub fn acquire_cot(&self) -> OwnedMutexGuard { + Mutex::try_lock_owned(self.cot.clone()).unwrap() + } + + /// Returns `true` if the receiver wants to initialize. + pub fn wants_init(&self) -> bool { + matches!(self.state, State::Init) + } + + /// Returns `true` if the receiver wants to bootstrap. + pub fn wants_bootstrap(&self) -> bool { + self.macs.is_empty() + } + + /// Returns `true` if the receiver wants to extend. + pub fn wants_extend(&self) -> bool { + self.alloc > 0 + } + + /// Initializes the receiver. + pub fn initialize(&mut self) -> Result { + let State::Init = self.state.take() else { + return Err(ErrorRepr::State("not in initialize state".to_string()).into()); + }; + + let seed = self.prg.gen(); + + self.state = State::Extend(Extend { + public_prg: Prg::from_seed(seed), + }); + + Ok(Init { seed }) + } + + /// Allocates COTs for bootstrapping. + pub fn alloc_bootstrap(&self) -> Result<()> { + let cost = self.config.bootstrap_cost(); + self.cot + .try_lock() + .map_err(|_| ErrorRepr::MutexLocked)? + .alloc(cost) + .map_err(Error::bootstrap)?; + + Ok(()) + } + + /// Starts extension. + pub fn start_extend(&mut self) -> Result { + let State::Extend(Extend { mut public_prg }) = self.state.take() else { + return Err(ErrorRepr::State("not in extend state".to_string()).into()); + }; + + // If COTs are empty we haven't bootstrapped from inner COT yet. + if self.macs.is_empty() { + let RCOTReceiverOutput { + msgs: macs, + choices, + .. + } = self + .cot + .try_lock() + .map_err(|_| ErrorRepr::MutexLocked)? + .try_recv_rcot(self.config.bootstrap_cost()) + .map_err(|e| ErrorRepr::Bootstrap(Box::new(e)))?; + + self.macs.extend_from_slice(&macs); + self.choices.extend_from_slice(&choices); + } + + let lpn_type = self.config.lpn_type(); + let params = self.config.select_params(self.macs.len(), self.alloc); + + let err = sample_error_indices(&mut self.prg, lpn_type, params.n, params.t); + + let (mpcot, spcot_lengths, spcot_idxs) = + MPCOTReceiver::new(public_prg.gen(), lpn_type).start_extend(&err, params.n)?; + + let spcot_count: usize = spcot_lengths.iter().sum(); + let masks = &self.choices[self.choices.len() - spcot_count..]; + let derandomize = self.spcot.derandomize(&spcot_lengths, &spcot_idxs, masks)?; + + // Drop used COT choices. + self.choices.truncate(self.choices.len() - spcot_count); + + self.state = State::Extending(Extending { + public_prg, + start: self.macs.len(), + params, + err, + mpcot, + spcot_count, + spcot_lengths, + spcot_idxs, + }); + + Ok(ReceiverExtend { derandomize }) + } + + /// Performs extension. + /// + /// # Arguments + /// + /// * `msg` - The sender's extend message. + pub fn extend(&mut self, msg: SenderExtend) -> Result { + let SenderExtend { ms, sums } = msg; + + let State::Extending(Extending { + public_prg, + start, + params, + err: e, + mpcot, + spcot_count, + spcot_lengths, + spcot_idxs, + }) = self.state.take() + else { + return Err(ErrorRepr::State("not in extending state".to_string()).into()); + }; + + let macs = &self.macs[self.macs.len() - spcot_count..]; + let ws = self + .spcot + .extend(&spcot_lengths, &spcot_idxs, macs, &ms, &sums)?; + + // Drop used COTs. + self.macs.truncate(self.macs.len() - spcot_count); + + let r = mpcot.extend(ws)?; + + let macs = &self.macs[self.macs.len() - CSP..]; + let masks = &self.choices[self.choices.len() - CSP..]; + let derandomize = self.spcot.start_check(macs, masks)?; + + // Drop used COTs. + self.macs.truncate(self.macs.len() - CSP); + self.choices.truncate(self.choices.len() - CSP); + + self.state = State::Finish(Finish { + public_prg, + start, + params, + err: e, + r, + }); + + Ok(ReceiverCheck { derandomize }) + } + + /// Finishes extension. + /// + /// # Arguments + /// + /// * `msg` - The sender's check message. + pub fn finish_extend(&mut self, msg: SenderCheck) -> Result<()> { + let SenderCheck { hashed_v } = msg; + + let State::Finish(Finish { + mut public_prg, + start, + params, + err, + r, + }) = self.state.take() + else { + return Err(ErrorRepr::State("not in finish state".to_string()).into()); + }; + + self.spcot.check(hashed_v)?; + + let encoder = LpnEncoder::<10>::new(params.k as u32); + let lpn_seed = public_prg.gen(); + + // Compute z = A * w + r. + let w = &self.macs[self.macs.len() - params.k..]; + let mut z = r; + encoder.compute(lpn_seed, &mut z, w); + + self.macs.truncate(self.macs.len() - params.k); + + // Compute x = A * u + e. + let u: Vec<_> = self.choices[self.choices.len() - params.k..] + .iter() + .map(|x| if *x { Block::ONE } else { Block::ZERO }) + .collect(); + let mut x = vec![Block::ZERO; params.n]; + for &idx in &err { + x[idx] = Block::ONE; + } + + encoder.compute(lpn_seed, &mut x, &u); + + self.choices.truncate(self.choices.len() - params.k); + + let x: Vec<_> = x.iter().map(|x| x.lsb()).collect(); + + self.macs.extend_from_slice(&z); + self.choices.extend_from_slice(&x); + + self.alloc = self.alloc.saturating_sub(self.macs.len() - start); + if self.alloc == 0 { + // We've finished extending. + self.process_queue(); + } + + self.state = State::Extend(Extend { public_prg }); + + Ok(()) + } + + fn process_queue(&mut self) { + while let Some(next) = self.queue.pop_front() { + if self.available() < next.count { + self.queue.push_front(next); + return; + } + + let id = self.transfer_id.next(); + let macs = self.macs.split_off(self.macs.len() - next.count); + let choices = self.choices.split_off(self.choices.len() - next.count); + + _ = next.sender.send(RCOTReceiverOutput { + id, + msgs: macs, + choices, + }); + } + } +} + +impl RCOTReceiver for Receiver +where + COT: RCOTReceiver, +{ + type Error = ReceiverError; + type Future = MaybeDone>; + + fn alloc(&mut self, count: usize) -> Result<()> { + self.alloc += count; + Ok(()) + } + + fn available(&self) -> usize { + if self.config.reserve_bootstrap() { + self.macs.len().saturating_sub(self.config.bootstrap_cost()) + } else { + self.macs.len() + } + } + + fn try_recv_rcot(&mut self, count: usize) -> Result> { + if self.available() < count { + return Err(ErrorRepr::InsufficientCOTs { + expected: count, + actual: self.available(), + } + .into()); + } + + let choices = self.choices.split_off(self.choices.len() - count); + let keys = self.macs.split_off(self.macs.len() - count); + + Ok(RCOTReceiverOutput { + id: self.transfer_id.next(), + choices, + msgs: keys, + }) + } + + fn queue_recv_rcot(&mut self, count: usize) -> Result { + if self.available() >= count { + let output = self.try_recv_rcot(count)?; + let (sender, recv) = new_output(); + sender.send(output); + + return Ok(recv); + } else { + let (sender, recv) = new_output(); + + self.queue.push_back(Queued { count, sender }); + + return Ok(recv); + } + } +} + +enum State { + Init, + Extend(Extend), + Extending(Extending), + Finish(Finish), + Error, +} + +opaque_debug::implement!(State); + +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, State::Error) + } +} + +struct Extend { + public_prg: Prg, +} + +struct Extending { + public_prg: Prg, + start: usize, + params: LpnParameters, + err: Vec, + mpcot: MPCOTReceiver, + spcot_count: usize, + spcot_lengths: Vec, + spcot_idxs: Vec, +} + +struct Finish { + public_prg: Prg, + start: usize, + params: LpnParameters, + err: Vec, + r: Vec, +} + +/// Ferret receiver error. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct ReceiverError(ErrorRepr); + +impl ReceiverError { + fn bootstrap(err: E) -> Self + where + E: Into>, + { + Self(ErrorRepr::Bootstrap(err.into())) + } +} + +#[derive(Debug, thiserror::Error)] +#[error("ferret receiver error: {0}")] +enum ErrorRepr { + #[error("invalid state: {0}")] + State(String), + #[error("bootstrap COT mutex is still locked")] + MutexLocked, + #[error("bootstrap COT error: {0}")] + Bootstrap(Box), + #[error("SPCOT receiver error: {0}")] + SPCOT(SPCOTReceiverError), + #[error("MPCOT receiver error: {0}")] + MPCOT(MPCOTReceiverError), + #[error("insufficient COTs: expected {expected}, actual {actual}")] + InsufficientCOTs { expected: usize, actual: usize }, +} + +impl From for ReceiverError { + fn from(repr: ErrorRepr) -> Self { + Self(repr) + } +} + +impl From for ReceiverError { + fn from(err: SPCOTReceiverError) -> Self { + Self(ErrorRepr::SPCOT(err)) + } +} + +impl From for ReceiverError { + fn from(err: MPCOTReceiverError) -> Self { + Self(ErrorRepr::MPCOT(err)) + } +} diff --git a/crates/mpz-ot-core/src/ferret/sender.rs b/crates/mpz-ot-core/src/ferret/sender.rs new file mode 100644 index 00000000..07c0a8f4 --- /dev/null +++ b/crates/mpz-ot-core/src/ferret/sender.rs @@ -0,0 +1,425 @@ +use std::{collections::VecDeque, sync::Arc}; + +use rand::{Rng, SeedableRng}; +use tokio::sync::{Mutex, OwnedMutexGuard}; + +use mpz_common::future::{new_output, MaybeDone, Sender as OutputSender}; +use mpz_core::{ + lpn::{LpnEncoder, LpnParameters}, + prg::Prg, + Block, +}; + +use crate::{ + ferret::{ + config::CSP, + mpcot::{sender_state as mpcot_state, MPCOTSender, MPCOTSenderError}, + spcot::{SPCOTSender, SPCOTSenderError}, + FerretConfig, Init, ReceiverCheck, ReceiverExtend, SenderCheck, SenderExtend, + }, + rcot::{RCOTSender, RCOTSenderOutput}, + TransferId, +}; + +type Error = SenderError; +type Result = core::result::Result; + +#[derive(Debug)] +struct Queued { + count: usize, + sender: OutputSender>, +} + +/// Ferret sender. +#[derive(Debug)] +pub struct Sender { + cot: Arc>, + alloc: usize, + queue: VecDeque, + transfer_id: TransferId, + prg: Prg, + delta: Block, + config: FerretConfig, + keys: Vec, + state: State, + spcot: SPCOTSender, +} + +impl Sender +where + COT: RCOTSender, +{ + /// Creates a new sender. + pub fn new(seed: Block, config: FerretConfig, cot: COT) -> Self { + let delta = cot.delta(); + Self { + cot: Arc::new(Mutex::new(cot)), + alloc: 0, + queue: VecDeque::new(), + transfer_id: TransferId::default(), + prg: Prg::from_seed(seed), + delta, + config, + keys: Vec::new(), + state: State::Init, + spcot: SPCOTSender::new(delta), + } + } + + /// Returns a lock on the inner COT sender. + pub fn acquire_cot(&self) -> OwnedMutexGuard { + Mutex::try_lock_owned(self.cot.clone()).unwrap() + } + + /// Returns `true` if the sender wants to initialize. + pub fn wants_init(&self) -> bool { + matches!(self.state, State::Init) + } + + /// Returns `true` if the sender wants to bootstrap. + pub fn wants_bootstrap(&self) -> bool { + self.keys.is_empty() + } + + /// Returns `true` if the sender wants to extend. + pub fn wants_extend(&self) -> bool { + self.alloc > 0 + } + + /// Initializes the sender, receiving message from the receiver. + pub fn initialize(&mut self, init: Init) -> Result<()> { + let State::Init = self.state.take() else { + return Err(ErrorRepr::State("not in initialize state".to_string()).into()); + }; + + let Init { seed } = init; + + self.state = State::Extend(Extend { + public_prg: Prg::from_seed(seed), + }); + + Ok(()) + } + + /// Allocates COTs for bootstrapping. + pub fn alloc_bootstrap(&self) -> Result<()> { + let cost = self.config.bootstrap_cost(); + self.cot + .try_lock() + .map_err(|_| ErrorRepr::MutexLocked)? + .alloc(cost) + .map_err(Error::bootstrap)?; + + Ok(()) + } + + /// Starts extension. + pub fn start_extend(&mut self) -> Result<()> { + let State::Extend(Extend { mut public_prg }) = self.state.take() else { + return Err(ErrorRepr::State("not in extend state".to_string()).into()); + }; + + // If COTs are empty we haven't bootstrapped from inner COT yet. + if self.keys.is_empty() { + let RCOTSenderOutput { keys, .. } = self + .cot + .try_lock() + .map_err(|_| ErrorRepr::MutexLocked)? + .try_send_rcot(self.config.bootstrap_cost()) + .map_err(|e| ErrorRepr::Bootstrap(Box::new(e)))?; + + self.keys.extend_from_slice(&keys); + } + + let params = self.config.select_params(self.keys.len(), self.alloc); + + let (mpcot, spcot_lengths) = MPCOTSender::new(public_prg.gen(), self.config.lpn_type()) + .start_extend(params.t, params.n)?; + + self.state = State::Extending(Extending { + public_prg, + start: self.keys.len(), + params, + mpcot, + spcot_lengths, + }); + + Ok(()) + } + + /// Performs extension. + /// + /// # Arguments + /// + /// * `msg` - Receiver extend message. + pub fn extend(&mut self, msg: ReceiverExtend) -> Result { + let ReceiverExtend { derandomize } = msg; + + let State::Extending(Extending { + public_prg, + start, + params, + mpcot, + spcot_lengths, + }) = self.state.take() + else { + return Err(ErrorRepr::State("not in extending state".to_string()).into()); + }; + + let spcot_count: usize = spcot_lengths.iter().sum(); + let spcot_keys = &self.keys[self.keys.len() - spcot_count..]; + + let (vs, ms, sums) = self.spcot.extend( + &mut self.prg, + &spcot_lengths, + &spcot_keys, + &derandomize.flip, + )?; + + // Drop used keys. + self.keys.truncate(self.keys.len() - spcot_count); + + let s = mpcot.extend(vs)?; + + self.state = State::Check(Check { + public_prg, + start, + params, + s, + }); + + Ok(SenderExtend { ms, sums }) + } + + /// Performs the SPCOT consistency check. + /// + /// # Arguments + /// + /// * `msg` - Receiver check message. + pub fn check(&mut self, msg: ReceiverCheck) -> Result { + let ReceiverCheck { derandomize } = msg; + + let State::Check(Check { + public_prg, + start, + params, + s, + }) = self.state.take() + else { + return Err(ErrorRepr::State("not in check state".to_string()).into()); + }; + + let check_keys = &self.keys[self.keys.len() - CSP..]; + let hashed_v = self.spcot.check(check_keys, &derandomize.flip)?; + + // Drop used keys. + self.keys.truncate(self.keys.len() - CSP); + + self.state = State::Finish(Finish { + public_prg, + start, + params, + s, + }); + + Ok(SenderCheck { hashed_v }) + } + + /// Finishes the extension. + pub fn finish_extend(&mut self) -> Result<()> { + let State::Finish(Finish { + mut public_prg, + start, + params, + s, + }) = self.state.take() + else { + return Err(ErrorRepr::State("not in finish state".to_string()).into()); + }; + + let encoder = LpnEncoder::<10>::new(params.k as u32); + let lpn_seed = public_prg.gen(); + + // Compute y = A * v + s + let v = &self.keys[self.keys.len() - params.k..]; + let mut y = s; + encoder.compute(lpn_seed, &mut y, &v); + + self.keys.truncate(self.keys.len() - params.k); + self.keys.extend_from_slice(&y); + + self.alloc = self.alloc.saturating_sub(self.keys.len() - start); + if self.alloc == 0 { + // We've finished extending. + self.process_queue(); + } + + self.state = State::Extend(Extend { public_prg }); + + Ok(()) + } + + fn process_queue(&mut self) { + while let Some(next) = self.queue.pop_front() { + if self.available() < next.count { + self.queue.push_front(next); + return; + } + + let id = self.transfer_id.next(); + let keys = self.keys.split_off(self.keys.len() - next.count); + + _ = next.sender.send(RCOTSenderOutput { id, keys }); + } + } +} + +impl RCOTSender for Sender +where + COT: RCOTSender, +{ + type Error = SenderError; + type Future = MaybeDone>; + + fn alloc(&mut self, count: usize) -> Result<()> { + self.alloc += count; + Ok(()) + } + + fn available(&self) -> usize { + if self.config.reserve_bootstrap() { + self.keys.len().saturating_sub(self.config.bootstrap_cost()) + } else { + self.keys.len() + } + } + + fn delta(&self) -> Block { + self.delta + } + + fn try_send_rcot(&mut self, count: usize) -> Result, Self::Error> { + if self.available() < count { + return Err(ErrorRepr::InsufficientCOTs { + expected: count, + actual: self.available(), + } + .into()); + } + + let keys = self.keys.split_off(self.keys.len() - count); + + Ok(RCOTSenderOutput { + id: self.transfer_id.next(), + keys, + }) + } + + fn queue_send_rcot(&mut self, count: usize) -> Result { + if self.available() >= count { + let output = self.try_send_rcot(count)?; + let (sender, recv) = new_output(); + sender.send(output); + + return Ok(recv); + } else { + let (sender, recv) = new_output(); + + self.queue.push_back(Queued { count, sender }); + + return Ok(recv); + } + } +} + +enum State { + Init, + Extend(Extend), + Extending(Extending), + Check(Check), + Finish(Finish), + Error, +} + +opaque_debug::implement!(State); + +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, State::Error) + } +} + +struct Extend { + public_prg: Prg, +} + +struct Extending { + public_prg: Prg, + start: usize, + params: LpnParameters, + mpcot: MPCOTSender, + spcot_lengths: Vec, +} + +struct Check { + public_prg: Prg, + start: usize, + params: LpnParameters, + s: Vec, +} + +struct Finish { + public_prg: Prg, + start: usize, + params: LpnParameters, + s: Vec, +} + +/// Ferret sender error. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct SenderError(ErrorRepr); + +impl SenderError { + fn bootstrap(err: E) -> Self + where + E: Into>, + { + Self(ErrorRepr::Bootstrap(err.into())) + } +} + +#[derive(Debug, thiserror::Error)] +#[error("ferret sender error: {0}")] +enum ErrorRepr { + #[error("invalid state: {0}")] + State(String), + #[error("bootstrap COT mutex is still locked")] + MutexLocked, + #[error("bootstrap COT error: {0}")] + Bootstrap(Box), + #[error("SPCOT sender error: {0}")] + SPCOT(SPCOTSenderError), + #[error("MPCOT sender error: {0}")] + MPCOT(MPCOTSenderError), + #[error("insufficient COTs: expected {expected}, actual {actual}")] + InsufficientCOTs { expected: usize, actual: usize }, +} + +impl From for SenderError { + fn from(repr: ErrorRepr) -> Self { + Self(repr) + } +} + +impl From for SenderError { + fn from(e: SPCOTSenderError) -> Self { + Self(ErrorRepr::SPCOT(e)) + } +} + +impl From for SenderError { + fn from(e: MPCOTSenderError) -> Self { + Self(ErrorRepr::MPCOT(e)) + } +} diff --git a/crates/mpz-ot-core/src/ferret/spcot.rs b/crates/mpz-ot-core/src/ferret/spcot.rs new file mode 100644 index 00000000..3e8d9da6 --- /dev/null +++ b/crates/mpz-ot-core/src/ferret/spcot.rs @@ -0,0 +1,151 @@ +//! Implementation of the Single-Point COT (spcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. + +mod receiver; +mod sender; + +pub(crate) use receiver::{SPCOTReceiver, SPCOTReceiverError}; +pub(crate) use sender::{SPCOTSender, SPCOTSenderError}; + +#[cfg(test)] +use mpz_core::Block; + +#[cfg(test)] +/// Generates ideal SPCOT outputs. +/// +/// Returns the sender and receiver outputs, respectively. +pub(crate) fn spcot( + rng: &mut R, + lengths: &[usize], + idxs: &[usize], + delta: Block, +) -> (Vec, Vec) { + assert_eq!(lengths.len(), idxs.len()); + + let total_length = lengths.iter().map(|length| 1 << length).sum(); + let vs: Vec = (0..total_length).map(|_| rng.gen()).collect(); + let mut ws = vs.clone(); + + let mut i = 0; + for (&idx, &length) in idxs.iter().zip(lengths) { + ws[i + idx] ^= delta; + i += 1 << length; + } + + (vs, ws) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + ferret::config::CSP, + ideal::rcot::IdealRCOT, + rcot::{RCOTReceiverOutput, RCOTSenderOutput}, + test::assert_spcot, + }; + use mpz_core::utils::slices_from_lengths; + use rand::{rngs::StdRng, Rng, SeedableRng}; + + fn execute( + rng: &mut R, + sender: &mut SPCOTSender, + receiver: &mut SPCOTReceiver, + lengths: &[usize], + idxs: &[usize], + ) -> (Vec, Vec) { + let len_sum: usize = lengths.iter().sum(); + + let mut cot = IdealRCOT::new(rng.gen(), sender.delta()); + cot.alloc(len_sum + CSP); + cot.flush().unwrap(); + + let ( + RCOTSenderOutput { keys, .. }, + RCOTReceiverOutput { + choices: masks, + msgs: macs, + .. + }, + ) = cot.transfer(len_sum).unwrap(); + + let derandomize = receiver.derandomize(&lengths, &idxs, &masks).unwrap(); + + let (vs, ms, sums) = sender + .extend(rng, &lengths, &keys, &derandomize.flip) + .unwrap(); + let ws = receiver.extend(&lengths, &idxs, &macs, &ms, &sums).unwrap(); + + let vs = vs.to_vec(); + let ws = ws.to_vec(); + + let spcot_lengths = lengths.iter().map(|length| 1 << length).collect::>(); + for ((v, w), &idx) in slices_from_lengths(&vs, &spcot_lengths) + .into_iter() + .zip(slices_from_lengths(&ws, &spcot_lengths)) + .zip(idxs) + { + assert_spcot(sender.delta(), &w, idx, &v); + } + + assert!(sender.wants_check()); + assert!(receiver.wants_check()); + + let ( + RCOTSenderOutput { keys, .. }, + RCOTReceiverOutput { + choices: masks, + msgs: macs, + .. + }, + ) = cot.transfer(CSP).unwrap(); + + let derandomize = receiver.start_check(&macs, &masks).unwrap(); + let hashed_v = sender.check(&keys, &derandomize.flip).unwrap(); + receiver.check(hashed_v).unwrap(); + + assert!(!sender.wants_check()); + assert!(!receiver.wants_check()); + + (vs, ws) + } + + #[test] + fn test_spcot() { + let mut rng = StdRng::seed_from_u64(0); + let delta = rng.gen(); + + let mut sender = SPCOTSender::new(delta); + let mut receiver = SPCOTReceiver::new(); + + // Execute twice. + for _ in 0..2 { + let lengths: Vec = (1..8).collect(); + let idxs: Vec = (1..8).map(|n| rng.gen_range(0..1 << n)).collect(); + execute(&mut rng, &mut sender, &mut receiver, &lengths, &idxs); + } + } + + #[test] + fn test_ideal_spcot() { + let mut rng = StdRng::seed_from_u64(0); + let delta = rng.gen(); + + let idxs: Vec<_> = (0..8).map(|n| rng.gen_range(0..1 << n)).collect(); + let lengths: Vec<_> = (0..8).collect(); + + let (vs, ws) = spcot(&mut rng, &lengths, &idxs, delta); + + assert_eq!(vs.len(), ws.len()); + + let mut i = 0; + for (&idx, &length) in idxs.iter().zip(&lengths) { + let length = 1 << length; + let v = &vs[i..i + length]; + let w = &ws[i..i + length]; + + assert_spcot(delta, w, idx, v); + + i += length; + } + } +} diff --git a/crates/mpz-ot-core/src/ferret/spcot/receiver.rs b/crates/mpz-ot-core/src/ferret/spcot/receiver.rs new file mode 100644 index 00000000..6631d661 --- /dev/null +++ b/crates/mpz-ot-core/src/ferret/spcot/receiver.rs @@ -0,0 +1,292 @@ +use blake3::{hash, Hash, Hasher}; +use cfg_if::cfg_if; +use itybity::ToBits; +use rand::SeedableRng; +#[cfg(feature = "rayon")] +use rayon::prelude::*; + +use mpz_core::{ + aes::FIXED_KEY_AES, + bitvec::BitVec, + ggm::GgmTree, + prg::Prg, + utils::{slices_from_lengths, slices_from_lengths_mut}, + Block, +}; + +use crate::{ferret::config::CSP, Derandomize}; + +type Error = SPCOTReceiverError; +type Result = core::result::Result; + +#[derive(Debug)] +struct Check { + z: Block, + chis: Vec, +} + +#[derive(Debug)] +pub(crate) struct SPCOTReceiver { + ws: Vec, + lengths: Vec, + indices: Vec, + check: Option, + counter: u128, + transcript: Hasher, +} + +impl SPCOTReceiver { + /// Creates a new SPCOT receiver. + pub(crate) fn new() -> Self { + Self { + ws: Vec::new(), + lengths: Vec::new(), + indices: Vec::new(), + check: None, + counter: 0, + transcript: Hasher::new(), + } + } + + #[cfg(test)] + pub(crate) fn wants_check(&self) -> bool { + !self.ws.is_empty() + } + + /// Derandomizes OT messages for SPCOTs. + /// + /// # Arguments + /// + /// * `log2_lengths` - log2 length of the SPCOT vectors. + /// * `idxs` - Chosen SPCOT indices. + /// * `masks` - Random COT choice masks. + pub(crate) fn derandomize( + &mut self, + log2_lengths: &[usize], + idxs: &[usize], + masks: &[bool], + ) -> Result { + let sum: usize = log2_lengths.iter().sum(); + if idxs.len() != log2_lengths.len() { + return Err(ErrorRepr::IndexCount { + expected: log2_lengths.len(), + actual: idxs.len(), + } + .into()); + } else if masks.len() != sum { + return Err(ErrorRepr::MaskCount { + expected: sum, + actual: masks.len(), + } + .into()); + } + + let flip = BitVec::from_iter( + idxs.iter() + .zip(log2_lengths) + .flat_map(|(idx, length)| idx.iter_msb0().skip(usize::BITS as usize - length)) + .zip(masks) + .map(|(b, m)| !b ^ m), + ); + + self.transcript.update(flip.as_raw_slice()); + + Ok(Derandomize { flip }) + } + + /// Computes multiple SPCOTs. + /// + /// Returns the SPCOT vectors. + /// + /// # Arguments + /// + /// * `log2_lengths` - log2 length of the SPCOT vectors. + /// * `idxs` - Chosen SPCOT indices. + /// * `macs` - COT MACs used to decrypt OT messages. + /// * `ms` - OT messages. + /// * `sums` - SPCOT sums. + pub(crate) fn extend( + &mut self, + log2_lengths: &[usize], + idxs: &[usize], + macs: &[Block], + ms: &[[Block; 2]], + sums: &[Block], + ) -> Result<&[Block]> { + let len_sum: usize = log2_lengths.iter().sum(); + if idxs.len() != log2_lengths.len() { + return Err(ErrorRepr::IndexCount { + expected: log2_lengths.len(), + actual: idxs.len(), + } + .into()); + } else if macs.len() != len_sum { + return Err(ErrorRepr::MacCount { + expected: len_sum, + actual: macs.len(), + } + .into()); + } else if ms.len() != len_sum { + return Err(ErrorRepr::MsgCount { + expected: len_sum, + actual: ms.len(), + } + .into()); + } else if sums.len() != log2_lengths.len() { + return Err(ErrorRepr::SumCount { + expected: log2_lengths.len(), + actual: sums.len(), + } + .into()); + } + + let cipher = &(*FIXED_KEY_AES); + let ggm_sums: Vec = ms + .iter() + .zip(macs) + .zip( + idxs.iter() + .zip(log2_lengths) + .flat_map(|(idx, length)| idx.iter_msb0().skip(usize::BITS as usize - length)), + ) + .enumerate() + .map(|(i, (([m0, m1], &t), b))| { + let tweak = Block::from((self.counter + i as u128).to_be_bytes()); + if !b { + cipher.tccr(tweak, t) ^ m1 + } else { + cipher.tccr(tweak, t) ^ m0 + } + }) + .collect(); + + // Allocate space for the outputs. + let len: usize = log2_lengths.iter().map(|length| 1 << length).sum(); + let start = self.ws.len(); + self.ws.resize_with(start + len, || Block::ZERO); + + let spcot_lengths: Vec<_> = log2_lengths.iter().map(|length| 1 << length).collect(); + let ggm_sums = slices_from_lengths(&ggm_sums, log2_lengths); + let ws = slices_from_lengths_mut(&mut self.ws[start..], &spcot_lengths); + + let iter = { + cfg_if! { + if #[cfg(feature = "rayon")] { + ws.into_par_iter() + } else { + ws.into_iter() + } + } + }; + + iter.zip(ggm_sums) + .zip(sums) + .zip(log2_lengths) + .zip(idxs) + .for_each(|((((w, sums), sum), &length), &idx)| { + GgmTree::new_partial(length, sums, idx, w); + + w[idx] = w.iter().fold(*sum, |acc, &x| acc ^ x); + }); + + self.transcript.update(Block::array_as_flattened_bytes(ms)); + self.transcript.update(Block::as_flattened_bytes(sums)); + self.lengths.extend_from_slice(log2_lengths); + self.indices.extend_from_slice(idxs); + self.counter += len_sum as u128; + + Ok(&self.ws[start..]) + } + + pub(crate) fn start_check(&mut self, macs: &[Block], masks: &[bool]) -> Result { + if self.check.is_some() { + return Err(ErrorRepr::State("check already started".to_string()).into()); + } else if macs.len() != CSP { + return Err(ErrorRepr::MacCount { + expected: CSP, + actual: macs.len(), + } + .into()); + } else if masks.len() != CSP { + return Err(ErrorRepr::MaskCount { + expected: CSP, + actual: masks.len(), + } + .into()); + } + + let seed = *self.transcript.finalize().as_bytes(); + let mut prg = Prg::from_seed(Block::try_from(&seed[0..16]).unwrap()); + + // The sum of all the chi[alpha]. + let mut sum_chi_alpha = Block::ZERO; + + let mut chis = vec![Block::ZERO; self.ws.len()]; + prg.random_blocks(&mut chis); + + let mut i = 0; + for (length, idx) in self.lengths.iter().zip(&self.indices) { + sum_chi_alpha ^= chis[i + idx]; + i += 1 << length; + } + + let x_prime = BitVec::from_iter( + sum_chi_alpha + .iter_lsb0() + .zip(masks) + .map(|(x, &x_star)| x != x_star), + ); + + let z = Block::inn_prdt_red(macs, &Block::MONOMIAL); + + self.check = Some(Check { z, chis }); + + Ok(Derandomize { flip: x_prime }) + } + + pub(crate) fn check(&mut self, hashed_v: Hash) -> Result<()> { + let Some(Check { z, chis }) = self.check.take() else { + return Err(ErrorRepr::State("check not started".to_string()).into()); + }; + + // Computes W. + let w = z ^ Block::inn_prdt_red(&chis, &self.ws); + + // Computes H'(W) + let hashed_w = hash(&w.to_bytes()); + + if hashed_v != hashed_w { + return Err(ErrorRepr::Check.into()); + } + + self.ws.clear(); + self.lengths.clear(); + self.indices.clear(); + self.transcript.reset(); + + Ok(()) + } +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub(crate) struct SPCOTReceiverError(#[from] ErrorRepr); + +#[derive(Debug, thiserror::Error)] +#[error("SPCOT receiver error: {0}")] +enum ErrorRepr { + #[error("invalid state: {0}")] + State(String), + #[error("incorrect index count, expected: {expected}, actual: {actual}")] + IndexCount { expected: usize, actual: usize }, + #[error("incorrect COT MAC count, expected: {expected}, actual: {actual}")] + MacCount { expected: usize, actual: usize }, + #[error("incorrect COT mask count, expected: {expected}, actual: {actual}")] + MaskCount { expected: usize, actual: usize }, + #[error("incorrect OT message count, expected: {expected}, actual: {actual}")] + MsgCount { expected: usize, actual: usize }, + #[error("incorrect SPCOT sum count, expected: {expected}, actual: {actual}")] + SumCount { expected: usize, actual: usize }, + #[error("invalid consistency check")] + Check, +} diff --git a/crates/mpz-ot-core/src/ferret/spcot/sender.rs b/crates/mpz-ot-core/src/ferret/spcot/sender.rs new file mode 100644 index 00000000..0b723f2a --- /dev/null +++ b/crates/mpz-ot-core/src/ferret/spcot/sender.rs @@ -0,0 +1,205 @@ +use blake3::{hash, Hash, Hasher}; +use cfg_if::cfg_if; +use rand::{Rng, SeedableRng}; +#[cfg(feature = "rayon")] +use rayon::prelude::*; + +use mpz_core::{ + aes::FIXED_KEY_AES, bitvec::BitVec, ggm::GgmTree, prg::Prg, utils::slices_from_lengths_mut, + Block, +}; + +use crate::ferret::config::CSP; + +type Error = SPCOTSenderError; +type Result = core::result::Result; + +#[derive(Debug)] +pub(crate) struct SPCOTSender { + vs: Vec, + delta: Block, + counter: u128, + transcript: Hasher, +} + +impl SPCOTSender { + /// Creates a new SPCOT sender. + pub(crate) fn new(delta: Block) -> Self { + Self { + vs: Vec::new(), + delta, + counter: 0, + transcript: Hasher::new(), + } + } + + #[cfg(test)] + pub(crate) fn delta(&self) -> Block { + self.delta + } + + #[cfg(test)] + pub(crate) fn wants_check(&self) -> bool { + !self.vs.is_empty() + } + + /// Computes multiple SPCOTs. + /// + /// Returns the SPCOT vectors, OT messages and SPCOT sums. + /// + /// # Arguments + /// + /// * `log2_lengths` - log2 length of the SPCOT vectors. + /// * `keys` - COT keys. + /// * `masks` - Derandomized COT choices bits from the receiver. + pub(crate) fn extend( + &mut self, + rng: &mut R, + log2_lengths: &[usize], + keys: &[Block], + masks: &BitVec, + ) -> Result<(&[Block], Vec<[Block; 2]>, Vec)> { + let len_sum: usize = log2_lengths.iter().sum(); + if keys.len() != len_sum { + return Err(ErrorRepr::KeyCount { + expected: len_sum, + actual: keys.len(), + } + .into()); + } else if masks.len() != len_sum { + return Err(ErrorRepr::MaskCount { + expected: len_sum, + actual: masks.len(), + } + .into()); + } + + // Compute OT keys. + let cipher = &(*FIXED_KEY_AES); + let mut ms: Vec<_> = keys + .iter() + .zip(masks.iter().by_vals()) + .enumerate() + .map(|(i, (key, b))| { + let mut m = if b { + [key ^ self.delta, *key] + } else { + [*key, key ^ self.delta] + }; + let tweak = Block::from((self.counter + i as u128).to_be_bytes()); + cipher.tccr_many(&[tweak, tweak], &mut m); + m + }) + .collect(); + + // Allocate space for the outputs. + let len: usize = log2_lengths.iter().map(|length| 1 << length).sum(); + let start = self.vs.len(); + self.vs.resize_with(start + len, || Block::ZERO); + + let spcot_lengths: Vec<_> = log2_lengths.iter().map(|length| 1 << length).collect(); + let seeds: Vec = (0..log2_lengths.len()).map(|_| rng.gen()).collect(); + let vs = slices_from_lengths_mut(&mut self.vs[start..], &spcot_lengths); + let ks = slices_from_lengths_mut(&mut ms, log2_lengths); + + let iter = { + cfg_if! { + if #[cfg(feature = "rayon")] { + vs.into_par_iter() + } else { + vs.into_iter() + } + } + }; + + let sums: Vec<_> = iter + .zip(ks) + .zip(log2_lengths) + .zip(seeds) + .map(|(((v, ks), &depth), seed)| { + // Generate the SPCOT vector from GGM leaves. + let tree = GgmTree::new_from_seed(depth, seed, v); + + // Encrypt the OT messages. + tree.layer_sums().zip(ks).for_each(|(sums, ks)| { + ks[0] ^= sums[0]; + ks[1] ^= sums[1]; + }); + + // Compute the sum of the leaves. + tree.leaves().iter().fold(self.delta, |acc, x| acc ^ x) + }) + .collect(); + + self.transcript.update(masks.as_raw_slice()); + self.transcript.update(Block::array_as_flattened_bytes(&ms)); + self.transcript.update(Block::as_flattened_bytes(&sums)); + self.counter += len_sum as u128; + + Ok((&self.vs[start..], ms, sums)) + } + + /// Performs the SPCOT consistency check. + /// + /// # Arguments + /// + /// * `keys` - COT keys. + /// * `masks` - Derandomized COT choice bits from the receiver. + pub(crate) fn check(&mut self, keys: &[Block], masks: &BitVec) -> Result { + if keys.len() != CSP { + return Err(ErrorRepr::KeyCount { + expected: CSP, + actual: keys.len(), + } + .into()); + } else if masks.len() != CSP { + return Err(ErrorRepr::MaskCount { + expected: CSP, + actual: masks.len(), + } + .into()); + } + + // Step 8 in Figure 6. + + // Computes y = y_star + x' * Delta + let y: Vec = keys + .iter() + .zip(masks.iter().by_vals()) + .map(|(&y, x)| if x { y ^ self.delta } else { y }) + .collect(); + + // Computes Y + let mut v = Block::inn_prdt_red(&y, &Block::MONOMIAL); + + // Computes V + let seed = *self.transcript.finalize().as_bytes(); + let mut prg = Prg::from_seed(Block::try_from(&seed[0..16]).unwrap()); + + let mut chis = vec![Block::ZERO; self.vs.len()]; + prg.random_blocks(&mut chis); + + v ^= Block::inn_prdt_red(&chis, &self.vs); + + // Computes H'(V) + let hashed_v = hash(&v.to_bytes()); + + self.vs.clear(); + self.transcript.reset(); + + Ok(hashed_v) + } +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub(crate) struct SPCOTSenderError(#[from] ErrorRepr); + +#[derive(Debug, thiserror::Error)] +#[error("SPCOT sender error: {0}")] +enum ErrorRepr { + #[error("incorrect key count, expected: {expected}, actual: {actual}")] + KeyCount { expected: usize, actual: usize }, + #[error("incorrect mask count, expected: {expected}, actual: {actual}")] + MaskCount { expected: usize, actual: usize }, +} diff --git a/crates/mpz-ot-core/src/lib.rs b/crates/mpz-ot-core/src/lib.rs index 98c0c51c..10f27f13 100644 --- a/crates/mpz-ot-core/src/lib.rs +++ b/crates/mpz-ot-core/src/lib.rs @@ -26,6 +26,7 @@ use serde::{Deserialize, Serialize}; pub mod chou_orlandi; pub mod cot; +pub mod ferret; pub mod ideal; pub mod kos; pub mod ot; @@ -66,5 +67,5 @@ impl TransferId { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct Derandomize { /// Correction bits - pub flip: BitVec, + pub flip: BitVec, } diff --git a/crates/mpz-ot-core/src/test.rs b/crates/mpz-ot-core/src/test.rs index 74c1a282..54c83e2c 100644 --- a/crates/mpz-ot-core/src/test.rs +++ b/crates/mpz-ot-core/src/test.rs @@ -17,15 +17,15 @@ pub fn assert_ot(choices: &[bool], msgs: &[[Block; 2]], received: &[Block]) { } /// Asserts the correctness of correlated oblivious transfer. -pub fn assert_cot(delta: Block, choices: &[bool], msgs: &[Block], received: &[Block]) { +pub fn assert_cot(delta: Block, choices: &[bool], keys: &[Block], macs: &[Block]) { assert!(choices .iter() - .zip(msgs.iter().zip(received)) - .all(|(&choice, (&msg, &received))| { + .zip(keys.iter().zip(macs)) + .all(|(&choice, (&key, &mac))| { if choice { - received == msg ^ delta + mac == key ^ delta } else { - received == msg + mac == key } })); } @@ -43,3 +43,14 @@ pub fn assert_rot(choices: &[bool], msgs: &[[T; 2]], receiv } })); } + +/// Asserts the correctness of single-point correlated oblivious transfer. +pub fn assert_spcot(delta: Block, keys: &[Block], idx: usize, received: &[Block]) { + assert_eq!(received.len(), keys.len()); + + assert_eq!( + keys.iter().fold(delta, |x_acc, x| x_acc ^ x), + received.iter().fold(Block::ZERO, |x_acc, x| x_acc ^ x) + ); + assert_eq!(keys[idx] ^ delta, received[idx]); +} diff --git a/crates/mpz-ot/src/ferret.rs b/crates/mpz-ot/src/ferret.rs new file mode 100644 index 00000000..9a17e28a --- /dev/null +++ b/crates/mpz-ot/src/ferret.rs @@ -0,0 +1,48 @@ +//! An implementation of the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) protocol. + +mod receiver; +mod sender; + +pub use receiver::{Receiver, ReceiverError}; +pub use sender::{Sender, SenderError}; + +pub use mpz_core::lpn::LpnType; +pub use mpz_ot_core::ferret::{FerretConfig, FerretConfigBuilder, FerretConfigBuilderError}; + +#[cfg(test)] +mod tests { + use super::*; + use crate::ideal::rcot::ideal_rcot; + use mpz_core::lpn::LpnParameters; + use rand::{rngs::StdRng, Rng, SeedableRng}; + use rstest::*; + + #[rstest] + #[case::uniform(LpnType::Uniform)] + #[case::regular(LpnType::Regular)] + #[tokio::test] + async fn test_ferret(#[case] lpn_type: LpnType) { + use crate::test::test_rcot; + + let mut rng = StdRng::seed_from_u64(0); + let delta = rng.gen(); + + let (cot_sender, cot_receiver) = ideal_rcot(rng.gen(), delta); + + let mut builder = FerretConfig::builder(); + + builder.lpn_type(lpn_type); + builder.param_selector(|_, _, _| LpnParameters { + n: 9600, + k: 1220, + t: 600, + }); + + let config = builder.build().unwrap(); + + let sender = Sender::new(config.clone(), rng.gen(), cot_sender); + let receiver = Receiver::new(config, rng.gen(), cot_receiver); + + test_rcot(sender, receiver, 20_000, 2).await; + } +} diff --git a/crates/mpz-ot/src/ferret/receiver.rs b/crates/mpz-ot/src/ferret/receiver.rs new file mode 100644 index 00000000..1d30a1dd --- /dev/null +++ b/crates/mpz-ot/src/ferret/receiver.rs @@ -0,0 +1,141 @@ +use async_trait::async_trait; +use serio::{stream::IoStreamExt, SinkExt}; + +use mpz_common::{future::MaybeDone, Context, Flush}; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::{FerretConfig, Receiver as Core, ReceiverError as CoreError}, + rcot::{RCOTReceiver, RCOTReceiverOutput}, +}; + +type Error = ReceiverError; + +/// Ferret receiver. +#[derive(Debug)] +pub struct Receiver { + core: Core, +} + +impl Receiver +where + COT: RCOTReceiver, +{ + /// Creates a new Receiver. + /// + /// # Arguments + /// + /// * `config` - Receiver's configuration. + /// * `seed` - Receiver's PRG seed. + /// * `cot` - COT used for bootstrapping. + pub fn new(config: FerretConfig, seed: Block, cot: COT) -> Self { + Self { + core: Core::new(seed, config, cot), + } + } +} + +impl RCOTReceiver for Receiver +where + COT: RCOTReceiver, +{ + type Error = Error; + type Future = MaybeDone>; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.core.alloc(count).map_err(Error::from) + } + + fn available(&self) -> usize { + self.core.available() + } + + fn try_recv_rcot( + &mut self, + count: usize, + ) -> Result, Self::Error> { + self.core.try_recv_rcot(count).map_err(Error::from) + } + + fn queue_recv_rcot(&mut self, count: usize) -> Result { + self.core.queue_recv_rcot(count).map_err(Error::from) + } +} + +#[async_trait] +impl Flush for Receiver +where + Ctx: Context, + COT: RCOTReceiver + Flush + Send, +{ + type Error = Error; + + fn wants_flush(&self) -> bool { + self.core.wants_init() || self.core.wants_extend() + } + + async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + if self.core.wants_init() { + let msg = self.core.initialize()?; + ctx.io_mut().send(msg).await?; + } + + // TODO: Run this concurrently with the above. + if self.core.wants_bootstrap() { + self.core.alloc_bootstrap()?; + self.core + .acquire_cot() + .flush(ctx) + .await + .map_err(Error::bootstrap)?; + } + + while self.core.wants_extend() { + let msg = self.core.start_extend()?; + ctx.io_mut().send(msg).await?; + let msg = ctx.io_mut().expect_next().await?; + let msg = self.core.extend(msg)?; + ctx.io_mut().send(msg).await?; + let msg = ctx.io_mut().expect_next().await?; + self.core.finish_extend(msg)?; + } + + Ok(()) + } +} + +/// Ferret receiver error. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct ReceiverError(#[from] ErrorRepr); + +impl ReceiverError { + fn bootstrap(err: E) -> Self + where + E: Into>, + { + Self(ErrorRepr::Bootstrap(err.into())) + } +} + +#[derive(Debug, thiserror::Error)] +#[error("ferret receiver error: {0}")] +enum ErrorRepr { + #[error("core error: {0}")] + Core(CoreError), + #[error("bootstrap COT error: {0}")] + Bootstrap(Box), + #[error("io error: {0}")] + Io(std::io::Error), +} + +impl From for ReceiverError { + fn from(e: CoreError) -> Self { + Self(ErrorRepr::Core(e)) + } +} + +impl From for ReceiverError { + fn from(e: std::io::Error) -> Self { + Self(ErrorRepr::Io(e)) + } +} diff --git a/crates/mpz-ot/src/ferret/sender.rs b/crates/mpz-ot/src/ferret/sender.rs new file mode 100644 index 00000000..5f224362 --- /dev/null +++ b/crates/mpz-ot/src/ferret/sender.rs @@ -0,0 +1,142 @@ +use async_trait::async_trait; +use mpz_common::{future::MaybeDone, Context, Flush}; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::{FerretConfig, Sender as Core, SenderError as CoreError}, + rcot::{RCOTSender, RCOTSenderOutput}, +}; +use serio::{stream::IoStreamExt, SinkExt}; + +type Error = SenderError; + +/// Ferret sender. +#[derive(Debug)] +pub struct Sender { + core: Core, +} + +impl Sender +where + COT: RCOTSender, +{ + /// Creates a new Sender. + /// + /// # Arguments + /// + /// * `config` - Sender's configuration. + /// * `seed` - Sender's PRG seed. + /// * `cot` - COT used for bootstrapping. + pub fn new(config: FerretConfig, seed: Block, cot: COT) -> Self { + Self { + core: Core::new(seed, config, cot), + } + } +} + +impl RCOTSender for Sender +where + COT: RCOTSender, +{ + type Error = Error; + type Future = MaybeDone>; + + fn alloc(&mut self, count: usize) -> Result<(), Self::Error> { + self.core.alloc(count).map_err(Error::from) + } + + fn available(&self) -> usize { + self.core.available() + } + + fn delta(&self) -> Block { + self.core.delta() + } + + fn try_send_rcot(&mut self, count: usize) -> Result, Self::Error> { + self.core.try_send_rcot(count).map_err(Error::from) + } + + fn queue_send_rcot(&mut self, count: usize) -> Result { + self.core.queue_send_rcot(count).map_err(Error::from) + } +} + +#[async_trait] +impl Flush for Sender +where + Ctx: Context, + COT: RCOTSender + Flush + Send, +{ + type Error = Error; + + fn wants_flush(&self) -> bool { + self.core.wants_init() || self.core.wants_extend() + } + + async fn flush(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + if self.core.wants_init() { + let init = ctx.io_mut().expect_next().await?; + self.core.initialize(init)?; + } + + // TODO: Run this concurrently with the above. + if self.core.wants_bootstrap() { + self.core.alloc_bootstrap()?; + self.core + .acquire_cot() + .flush(ctx) + .await + .map_err(Error::bootstrap)?; + } + + while self.core.wants_extend() { + self.core.start_extend()?; + let msg = ctx.io_mut().expect_next().await?; + let msg = self.core.extend(msg)?; + ctx.io_mut().send(msg).await?; + let msg = ctx.io_mut().expect_next().await?; + let msg = self.core.check(msg)?; + ctx.io_mut().send(msg).await?; + self.core.finish_extend()?; + } + + Ok(()) + } +} + +/// Ferret sender error. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct SenderError(#[from] ErrorRepr); + +impl SenderError { + fn bootstrap(err: E) -> Self + where + E: Into>, + { + Self(ErrorRepr::Bootstrap(err.into())) + } +} + +#[derive(Debug, thiserror::Error)] +#[error("ferret sender error: {0}")] +enum ErrorRepr { + #[error("core error: {0}")] + Core(CoreError), + #[error("bootstrap COT error: {0}")] + Bootstrap(Box), + #[error("io error: {0}")] + Io(std::io::Error), +} + +impl From for SenderError { + fn from(e: CoreError) -> Self { + Self(ErrorRepr::Core(e)) + } +} + +impl From for SenderError { + fn from(e: std::io::Error) -> Self { + Self(ErrorRepr::Io(e)) + } +} diff --git a/crates/mpz-ot/src/ideal/rcot.rs b/crates/mpz-ot/src/ideal/rcot.rs index 2bc776bf..3c62f769 100644 --- a/crates/mpz-ot/src/ideal/rcot.rs +++ b/crates/mpz-ot/src/ideal/rcot.rs @@ -142,6 +142,6 @@ mod tests { async fn test_ideal_rcot() { let mut rng = StdRng::seed_from_u64(0); let (sender, receiver) = ideal_rcot(rng.gen(), rng.gen()); - test_rcot(sender, receiver, 8).await; + test_rcot(sender, receiver, 128, 8).await; } } diff --git a/crates/mpz-ot/src/kos.rs b/crates/mpz-ot/src/kos.rs index 99d90cc7..62a2b542 100644 --- a/crates/mpz-ot/src/kos.rs +++ b/crates/mpz-ot/src/kos.rs @@ -1,10 +1,10 @@ -//! Correlated random oblivious transfer extension protocol with leakage based on -//! [`KOS15`](https://eprint.iacr.org/archive/2015/546/1433798896.pdf). +//! Correlated random oblivious transfer extension protocol with leakage based +//! on [`KOS15`](https://eprint.iacr.org/archive/2015/546/1433798896.pdf). //! //! # Warning //! -//! The user of this protocol must carefully consider if the leakage introduced in this protocol -//! is acceptable for their specific application. +//! The user of this protocol must carefully consider if the leakage introduced +//! in this protocol is acceptable for their specific application. mod receiver; mod sender; @@ -34,6 +34,6 @@ mod tests { let sender = Sender::new(SenderConfig::default(), delta, base_receiver); let receiver = Receiver::new(ReceiverConfig::default(), base_sender); - test_rcot(sender, receiver, 1).await; + test_rcot(sender, receiver, 128, 1).await; } } diff --git a/crates/mpz-ot/src/lib.rs b/crates/mpz-ot/src/lib.rs index 11887b0c..0f657103 100644 --- a/crates/mpz-ot/src/lib.rs +++ b/crates/mpz-ot/src/lib.rs @@ -11,6 +11,7 @@ pub mod chou_orlandi; pub mod cot; +pub mod ferret; #[cfg(any(test, feature = "ideal"))] pub mod ideal; pub mod kos; diff --git a/crates/mpz-ot/src/test.rs b/crates/mpz-ot/src/test.rs index 4c784711..6a9b2463 100644 --- a/crates/mpz-ot/src/test.rs +++ b/crates/mpz-ot/src/test.rs @@ -48,14 +48,13 @@ where } /// Tests RCOT functionality. -pub async fn test_rcot(mut sender: S, mut receiver: R, cycles: usize) +pub async fn test_rcot(mut sender: S, mut receiver: R, count: usize, cycles: usize) where S: RCOTSender + Flush, R: RCOTReceiver + Flush, { let (mut sender_ctx, mut receiver_ctx) = test_st_executor(8); - let count = 128; for _ in 0..cycles { let ( RCOTSenderOutput {