Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: garble vm #191

Merged
merged 4 commits into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ clmul = { path = "crates/clmul" }
matrix-transpose = { path = "crates/matrix-transpose" }

tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "43995c5" }
tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "43995c5" }

# rand
rand_chacha = "0.3"
Expand Down Expand Up @@ -88,7 +87,6 @@ serde = "1.0"
serde_yaml = "0.9"
serde_arrays = "0.1"
bincode = "1.3.3"
prost-build = "0.9"
bytes = "1"
yamux = "0.10"
bytemuck = { version = "1.13", features = ["derive"] }
Expand All @@ -98,7 +96,6 @@ serio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "43995c5" }
uid-mux = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "43995c5" }

# testing
prost = "0.9"
rstest = "0.12"
pretty_assertions = "1"
criterion = "0.3"
Expand Down
6 changes: 3 additions & 3 deletions crates/matrix-transpose/src/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/// Unsafe matrix transpose
///
/// This function transposes a matrix of generic elements. This function is an implementation of
/// the byte-level transpose in
/// https://docs.rs/oblivious-transfer/latest/oblivious_transfer/extension/fn.transpose128.html
/// This function transposes a matrix of generic elements. This function is an
/// implementation of the byte-level transpose in
/// <https://docs.rs/oblivious-transfer/latest/oblivious_transfer/extension/fn.transpose128.html>
///
/// # Safety
///
Expand Down
76 changes: 49 additions & 27 deletions crates/mpz-circuits/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use crate::{
#[derive(Debug, thiserror::Error)]
#[allow(missing_docs)]
pub enum CircuitError {
#[error("Invalid number of wires: expected {0}, got {1}")]
InvalidWireCount(usize, usize),
#[error("Invalid number of inputs: expected {0}, got {1}")]
InvalidInputCount(usize, usize),
#[error("Invalid number of outputs: expected {0}, got {1}")]
Expand Down Expand Up @@ -41,6 +43,16 @@ impl Circuit {
&self.outputs
}

/// Returns the input length of the circuit in bits.
pub fn input_len(&self) -> usize {
self.inputs.iter().map(|input| input.len()).sum()
}

/// Returns the output length of the circuit in bits.
pub fn output_len(&self) -> usize {
self.outputs.iter().map(|output| output.len()).sum()
}

/// Returns a reference to the gates of the circuit.
pub fn gates(&self) -> &[Gate] {
&self.gates
Expand Down Expand Up @@ -109,6 +121,39 @@ impl Circuit {
self
}

/// Evaluate the circuit using the provided wires.
///
/// It is the callers responsibility to ensure the input wires are set.
pub fn evaluate_raw(&self, wires: &mut [bool]) -> Result<(), CircuitError> {
if wires.len() != self.feed_count {
return Err(CircuitError::InvalidWireCount(self.feed_count, wires.len()));
}

for gate in self.gates.iter() {
match gate {
Gate::Xor { x, y, z } => {
let x = wires[x.id];
let y = wires[y.id];

wires[z.id] = x ^ y;
}
Gate::And { x, y, z } => {
let x = wires[x.id];
let y = wires[y.id];

wires[z.id] = x & y;
}
Gate::Inv { x, z } => {
let x = wires[x.id];

wires[z.id] = !x;
}
}
}

Ok(())
}

/// Evaluate the circuit with the given inputs.
///
/// # Arguments
Expand All @@ -126,7 +171,7 @@ impl Circuit {
));
}

let mut feeds: Vec<Option<bool>> = vec![None; self.feed_count];
let mut feeds: Vec<bool> = vec![false; self.feed_count];

for (input, value) in self.inputs.iter().zip(values) {
if input.value_type() != value.value_type() {
Expand All @@ -137,41 +182,18 @@ impl Circuit {
}

for (node, bit) in input.iter().zip(value.clone().into_iter_lsb0()) {
feeds[node.id] = Some(bit);
feeds[node.id] = bit;
}
}

for gate in self.gates.iter() {
match gate {
Gate::Xor { x, y, z } => {
let x = feeds[x.id].expect("Feed should be set");
let y = feeds[y.id].expect("Feed should be set");

feeds[z.id] = Some(x ^ y);
}
Gate::And { x, y, z } => {
let x = feeds[x.id].expect("Feed should be set");
let y = feeds[y.id].expect("Feed should be set");

feeds[z.id] = Some(x & y);
}
Gate::Inv { x, z } => {
let x = feeds[x.id].expect("Feed should be set");

feeds[z.id] = Some(!x);
}
}
}
self.evaluate_raw(&mut feeds)?;

let outputs = self
.outputs
.iter()
.cloned()
.map(|output| {
let bits: Vec<bool> = output
.iter()
.map(|node| feeds[node.id].expect("Feed should be set"))
.collect();
let bits: Vec<bool> = output.iter().map(|node| feeds[node.id]).collect();

output
.from_bin_repr(&bits)
Expand Down
18 changes: 18 additions & 0 deletions crates/mpz-common/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub struct ContextError {
}

impl ContextError {
#[allow(dead_code)]
pub(crate) fn new<E: Into<Box<dyn std::error::Error + Send + Sync>>>(
kind: ErrorKind,
source: E,
Expand All @@ -29,6 +30,7 @@ impl ContextError {
}

#[derive(Debug)]
#[allow(dead_code)]
pub(crate) enum ErrorKind {
Mux,
Thread,
Expand Down Expand Up @@ -58,6 +60,22 @@ pub trait Context: Send + Sync {
/// Returns a mutable reference to the thread's I/O channel.
fn io_mut(&mut self) -> &mut Self::Io;

/// Executes a collection of tasks provided with a context.
///
/// If multi-threading is available, the tasks are load balanced across
/// threads. Otherwise, they are executed sequentially.
async fn map<'a, F, T, R, W>(
&'a mut self,
items: Vec<T>,
f: F,
weight: W,
) -> Result<Vec<R>, ContextError>
where
F: for<'b> Fn(&'b mut Self, T) -> ScopedBoxFuture<'static, 'b, R> + Clone + Send + 'static,
T: Send + 'static,
R: Send + 'static,
W: Fn(&T) -> usize + Send + 'static;

/// Forks the thread and executes the provided closures concurrently.
///
/// Implementations may not be able to fork, in which case the closures are
Expand Down
19 changes: 19 additions & 0 deletions crates/mpz-common/src/executor/dummy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,25 @@ impl Context for DummyExecutor {
&mut self.io
}

async fn map<'a, F, T, R, W>(
&'a mut self,
items: Vec<T>,
f: F,
_weight: W,
) -> Result<Vec<R>, ContextError>
where
F: for<'b> Fn(&'b mut Self, T) -> ScopedBoxFuture<'static, 'b, R> + Clone + Send + 'static,
T: Send + 'static,
R: Send + 'static,
W: Fn(&T) -> usize + Send + 'static,
{
let mut results = Vec::with_capacity(items.len());
for item in items {
results.push(f(self, item).await);
}
Ok(results)
}

async fn join<'a, A, B, RA, RB>(&'a mut self, a: A, b: B) -> Result<(RA, RB), ContextError>
where
A: for<'b> FnOnce(&'b mut Self) -> ScopedBoxFuture<'a, 'b, RA> + Send + 'a,
Expand Down
2 changes: 1 addition & 1 deletion crates/mpz-common/src/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ mod test_utils {
}

/// Test multi-threaded executor.
pub type TestMTExecutor = MTExecutor<TestFramedMux, MemoryDuplex>;
pub type TestMTExecutor = MTExecutor<TestFramedMux>;

/// Creates a pair of multi-threaded executors with multiplexed I/O
/// channels.
Expand Down
Loading