Skip to content

Commit

Permalink
Support multiplexer (#207)
Browse files Browse the repository at this point in the history
* add multiplexer stdlib
  • Loading branch information
katat authored Oct 28, 2024
1 parent 4cfb0c5 commit eb1def1
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/stdlib/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pub fn init_stdlib_dep<B: Backend>(
path_prefix: &str,
) -> usize {
// list the stdlib dependency in order
let libs = vec!["bits", "comparator", "int"];
let libs = vec!["bits", "comparator", "multiplexer", "int"];

let mut node_id = node_id;

Expand Down
105 changes: 105 additions & 0 deletions src/stdlib/native/multiplexer/lib.no
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use std::comparator;

/// Multiplies two vectors of the same length and returns the accumulated sum (dot product).
///
/// # Parameters
/// - `lhs`: A vector (array) of `Field` elements.
/// - `rhs`: A vector (array) of `Field` elements.
///
/// # Returns
/// - `Field`: The accumulated sum resulting from the element-wise multiplication of `lhs` and `rhs`.
///
/// # Panics
/// - The function assumes that `lhs` and `rhs` have the same length, `LEN`.
///
/// # Example
/// ```
/// let lhs = [1, 2, 3];
/// let rhs = [4, 5, 6];
/// let result = escalar_product(lhs, rhs);
/// result should be 1*4 + 2*5 + 3*6 = 32
/// ```
fn escalar_product(lhs: [Field; LEN], rhs: [Field; LEN]) -> Field {
let mut lc = 0;
for idx in 0..LEN {
lc = lc + (lhs[idx] * rhs[idx]);
}
return lc;
}

/// Generates a selector array of a given length `LEN` with all zeros except for a one at the specified `target_idx`.
///
/// # Parameters
/// - `LEN`: The length of the output array.
/// - `target_idx`: The index where the value should be 1. The rest of the array will be filled with zeros.
///
/// # Returns
/// - `[Field; LEN]`: An array of length `LEN` where all elements are zero except for a single `1` at `target_idx`.
///
/// # Panics
/// - This function asserts that there is exactly one `1` in the generated array, ensuring `target_idx` is within bounds.
///
/// # Example
/// ```
/// let selector = gen_selector_arr(5, 2);
/// `selector` should be [0, 0, 1, 0, 0]
/// ```
fn gen_selector_arr(const LEN: Field, target_idx: Field) -> [Field; LEN] {
let mut selector = [0; LEN];
let mut lc = 0;
let one = 1;
let zero = 0;

for idx in 0..LEN {
selector[idx] = if idx == target_idx { one } else { zero };
lc = lc + selector[idx];
}

// Ensures there is exactly one '1' in the range of LEN.
assert(lc == 1);

return selector;
}

/// Selects an element from a 2D array based on a `target_idx` and returns a vector of length `WIDLEN`.
///
/// # Parameters
/// - `arr`: A 2D array of dimensions `[ARRLEN][WIDLEN]` containing `Field` elements.
/// - `target_idx`: The index that determines which row of `arr` to select.
///
/// # Returns
/// - `[Field; WIDLEN]`: A vector representing the selected row from `arr`.
///
/// # Algorithm
/// 1. Generate a selector array using `gen_selector_arr` that has a `1` at `target_idx` and `0`s elsewhere.
/// 2. For each column index `idx` of the 2D array:
/// - Extract the `idx`-th element from each row into a temporary array.
/// - Use `escalar_product` with the temporary array and the selector array to `select` the value corresponding to `target_idx`.
/// 3. Reset the temporary array for the next iteration.
/// 4. Return the vector containing the selected row.
///
/// # Example
/// ```
/// let arr = [[1, 2], [3, 4], [5, 6]];
/// let result = select_element(arr, 1);
/// `result` should be [3, 4] as it selects the second row (index 1).
/// ```
fn select_element(arr: [[Field; WIDLEN]; ARRLEN], target_idx: Field) -> [Field; WIDLEN] {
let mut result = [0; WIDLEN];

let selector_arr = gen_selector_arr(ARRLEN, target_idx);
let mut one_len_arr = [0; ARRLEN];

for idx in 0..WIDLEN {
for jdx in 0..ARRLEN {
one_len_arr[jdx] = arr[jdx][idx];
}
// Only one element in `selector_arr` is `1`, so the result is the element in `one_len_arr`
// at the same index as the `1` in `selector_arr`.
result[idx] = escalar_product(one_len_arr, selector_arr);

// Reset the temporary array for the next column.
one_len_arr = [0; ARRLEN];
}
return result;
}
1 change: 1 addition & 0 deletions src/tests/stdlib/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod comparator;
mod multiplexer;

use std::{path::Path, str::FromStr};

Expand Down
48 changes: 48 additions & 0 deletions src/tests/stdlib/multiplexer/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use crate::error::{self};

use super::test_stdlib;
use error::Result;
use rstest::rstest;

#[rstest]
#[case(r#"{"xx": [["0", "1", "2"], ["3", "4", "5"], ["6", "7", "8"]]}"#, r#"{"sel": "1"}"#, vec!["3", "4", "5"])]
fn test_in_range(
#[case] public_inputs: &str,
#[case] private_inputs: &str,
#[case] expected_output: Vec<&str>,
) -> Result<()> {
test_stdlib(
"multiplexer/select_element/main.no",
Some("multiplexer/select_element/main.asm"),
public_inputs,
private_inputs,
expected_output,
)?;

Ok(())
}

// require the select idx to be in range
#[rstest]
#[case(r#"{"xx": [["0", "1", "2"], ["3", "4", "5"], ["6", "7", "8"]]}"#, r#"{"sel": "3"}"#, vec![])]
fn test_out_range(
#[case] public_inputs: &str,
#[case] private_inputs: &str,
#[case] expected_output: Vec<&str>,
) -> Result<()> {
use crate::error::ErrorKind;

let err = test_stdlib(
"multiplexer/select_element/main.no",
Some("multiplexer/select_element/main.asm"),
public_inputs,
private_inputs,
expected_output,
)
.err()
.expect("Expected error");

assert!(matches!(err.kind, ErrorKind::InvalidWitness(..)));

Ok(())
}
36 changes: 36 additions & 0 deletions src/tests/stdlib/multiplexer/select_element/main.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
@ noname.0.7.0
@ public inputs: 12

0 == (v_14) * (1)
v_16 == (v_15) * (v_13 + -1 * v_14)
-1 * v_17 + 1 == (v_16) * (1)
v_18 == (v_17) * (v_13 + -1 * v_14)
0 == (v_18) * (1)
1 == (v_19) * (1)
v_21 == (v_20) * (v_13 + -1 * v_19)
-1 * v_22 + 1 == (v_21) * (1)
v_23 == (v_22) * (v_13 + -1 * v_19)
0 == (v_23) * (1)
2 == (v_24) * (1)
v_26 == (v_25) * (v_13 + -1 * v_24)
-1 * v_27 + 1 == (v_26) * (1)
v_28 == (v_27) * (v_13 + -1 * v_24)
0 == (v_28) * (1)
1 == (v_29) * (1)
v_31 == (v_30) * (-1 * v_17 + -1 * v_22 + -1 * v_27 + v_29)
-1 * v_32 + 1 == (v_31) * (1)
v_33 == (v_32) * (-1 * v_17 + -1 * v_22 + -1 * v_27 + v_29)
0 == (v_33) * (1)
1 == (v_32) * (1)
v_34 == (v_4) * (v_17)
v_35 == (v_7) * (v_22)
v_36 == (v_10) * (v_27)
v_37 == (v_5) * (v_17)
v_38 == (v_8) * (v_22)
v_39 == (v_11) * (v_27)
v_40 == (v_6) * (v_17)
v_41 == (v_9) * (v_22)
v_42 == (v_12) * (v_27)
v_34 + v_35 + v_36 == (v_1) * (1)
v_37 + v_38 + v_39 == (v_2) * (1)
v_40 + v_41 + v_42 == (v_3) * (1)
6 changes: 6 additions & 0 deletions src/tests/stdlib/multiplexer/select_element/main.no
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
use std::multiplexer;

fn main(pub xx: [[Field; 3]; 3], sel: Field) -> [Field; 3] {
let chosen_elements = multiplexer::select_element(xx, sel);
return chosen_elements;
}

0 comments on commit eb1def1

Please sign in to comment.