-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add multiplexer stdlib
- Loading branch information
Showing
6 changed files
with
197 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
use std::comparator; | ||
|
||
/// Multiplies two vectors of the same length and returns the accumulated sum (dot product). | ||
/// | ||
/// # Parameters | ||
/// - `lhs`: A vector (array) of `Field` elements. | ||
/// - `rhs`: A vector (array) of `Field` elements. | ||
/// | ||
/// # Returns | ||
/// - `Field`: The accumulated sum resulting from the element-wise multiplication of `lhs` and `rhs`. | ||
/// | ||
/// # Panics | ||
/// - The function assumes that `lhs` and `rhs` have the same length, `LEN`. | ||
/// | ||
/// # Example | ||
/// ``` | ||
/// let lhs = [1, 2, 3]; | ||
/// let rhs = [4, 5, 6]; | ||
/// let result = escalar_product(lhs, rhs); | ||
/// result should be 1*4 + 2*5 + 3*6 = 32 | ||
/// ``` | ||
fn escalar_product(lhs: [Field; LEN], rhs: [Field; LEN]) -> Field { | ||
let mut lc = 0; | ||
for idx in 0..LEN { | ||
lc = lc + (lhs[idx] * rhs[idx]); | ||
} | ||
return lc; | ||
} | ||
|
||
/// Generates a selector array of a given length `LEN` with all zeros except for a one at the specified `target_idx`. | ||
/// | ||
/// # Parameters | ||
/// - `LEN`: The length of the output array. | ||
/// - `target_idx`: The index where the value should be 1. The rest of the array will be filled with zeros. | ||
/// | ||
/// # Returns | ||
/// - `[Field; LEN]`: An array of length `LEN` where all elements are zero except for a single `1` at `target_idx`. | ||
/// | ||
/// # Panics | ||
/// - This function asserts that there is exactly one `1` in the generated array, ensuring `target_idx` is within bounds. | ||
/// | ||
/// # Example | ||
/// ``` | ||
/// let selector = gen_selector_arr(5, 2); | ||
/// `selector` should be [0, 0, 1, 0, 0] | ||
/// ``` | ||
fn gen_selector_arr(const LEN: Field, target_idx: Field) -> [Field; LEN] { | ||
let mut selector = [0; LEN]; | ||
let mut lc = 0; | ||
let one = 1; | ||
let zero = 0; | ||
|
||
for idx in 0..LEN { | ||
selector[idx] = if idx == target_idx { one } else { zero }; | ||
lc = lc + selector[idx]; | ||
} | ||
|
||
// Ensures there is exactly one '1' in the range of LEN. | ||
assert(lc == 1); | ||
|
||
return selector; | ||
} | ||
|
||
/// Selects an element from a 2D array based on a `target_idx` and returns a vector of length `WIDLEN`. | ||
/// | ||
/// # Parameters | ||
/// - `arr`: A 2D array of dimensions `[ARRLEN][WIDLEN]` containing `Field` elements. | ||
/// - `target_idx`: The index that determines which row of `arr` to select. | ||
/// | ||
/// # Returns | ||
/// - `[Field; WIDLEN]`: A vector representing the selected row from `arr`. | ||
/// | ||
/// # Algorithm | ||
/// 1. Generate a selector array using `gen_selector_arr` that has a `1` at `target_idx` and `0`s elsewhere. | ||
/// 2. For each column index `idx` of the 2D array: | ||
/// - Extract the `idx`-th element from each row into a temporary array. | ||
/// - Use `escalar_product` with the temporary array and the selector array to `select` the value corresponding to `target_idx`. | ||
/// 3. Reset the temporary array for the next iteration. | ||
/// 4. Return the vector containing the selected row. | ||
/// | ||
/// # Example | ||
/// ``` | ||
/// let arr = [[1, 2], [3, 4], [5, 6]]; | ||
/// let result = select_element(arr, 1); | ||
/// `result` should be [3, 4] as it selects the second row (index 1). | ||
/// ``` | ||
fn select_element(arr: [[Field; WIDLEN]; ARRLEN], target_idx: Field) -> [Field; WIDLEN] { | ||
let mut result = [0; WIDLEN]; | ||
|
||
let selector_arr = gen_selector_arr(ARRLEN, target_idx); | ||
let mut one_len_arr = [0; ARRLEN]; | ||
|
||
for idx in 0..WIDLEN { | ||
for jdx in 0..ARRLEN { | ||
one_len_arr[jdx] = arr[jdx][idx]; | ||
} | ||
// Only one element in `selector_arr` is `1`, so the result is the element in `one_len_arr` | ||
// at the same index as the `1` in `selector_arr`. | ||
result[idx] = escalar_product(one_len_arr, selector_arr); | ||
|
||
// Reset the temporary array for the next column. | ||
one_len_arr = [0; ARRLEN]; | ||
} | ||
return result; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
mod comparator; | ||
mod multiplexer; | ||
|
||
use std::{path::Path, str::FromStr}; | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
use crate::error::{self}; | ||
|
||
use super::test_stdlib; | ||
use error::Result; | ||
use rstest::rstest; | ||
|
||
#[rstest] | ||
#[case(r#"{"xx": [["0", "1", "2"], ["3", "4", "5"], ["6", "7", "8"]]}"#, r#"{"sel": "1"}"#, vec!["3", "4", "5"])] | ||
fn test_in_range( | ||
#[case] public_inputs: &str, | ||
#[case] private_inputs: &str, | ||
#[case] expected_output: Vec<&str>, | ||
) -> Result<()> { | ||
test_stdlib( | ||
"multiplexer/select_element/main.no", | ||
Some("multiplexer/select_element/main.asm"), | ||
public_inputs, | ||
private_inputs, | ||
expected_output, | ||
)?; | ||
|
||
Ok(()) | ||
} | ||
|
||
// require the select idx to be in range | ||
#[rstest] | ||
#[case(r#"{"xx": [["0", "1", "2"], ["3", "4", "5"], ["6", "7", "8"]]}"#, r#"{"sel": "3"}"#, vec![])] | ||
fn test_out_range( | ||
#[case] public_inputs: &str, | ||
#[case] private_inputs: &str, | ||
#[case] expected_output: Vec<&str>, | ||
) -> Result<()> { | ||
use crate::error::ErrorKind; | ||
|
||
let err = test_stdlib( | ||
"multiplexer/select_element/main.no", | ||
Some("multiplexer/select_element/main.asm"), | ||
public_inputs, | ||
private_inputs, | ||
expected_output, | ||
) | ||
.err() | ||
.expect("Expected error"); | ||
|
||
assert!(matches!(err.kind, ErrorKind::InvalidWitness(..))); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
@ noname.0.7.0 | ||
@ public inputs: 12 | ||
|
||
0 == (v_14) * (1) | ||
v_16 == (v_15) * (v_13 + -1 * v_14) | ||
-1 * v_17 + 1 == (v_16) * (1) | ||
v_18 == (v_17) * (v_13 + -1 * v_14) | ||
0 == (v_18) * (1) | ||
1 == (v_19) * (1) | ||
v_21 == (v_20) * (v_13 + -1 * v_19) | ||
-1 * v_22 + 1 == (v_21) * (1) | ||
v_23 == (v_22) * (v_13 + -1 * v_19) | ||
0 == (v_23) * (1) | ||
2 == (v_24) * (1) | ||
v_26 == (v_25) * (v_13 + -1 * v_24) | ||
-1 * v_27 + 1 == (v_26) * (1) | ||
v_28 == (v_27) * (v_13 + -1 * v_24) | ||
0 == (v_28) * (1) | ||
1 == (v_29) * (1) | ||
v_31 == (v_30) * (-1 * v_17 + -1 * v_22 + -1 * v_27 + v_29) | ||
-1 * v_32 + 1 == (v_31) * (1) | ||
v_33 == (v_32) * (-1 * v_17 + -1 * v_22 + -1 * v_27 + v_29) | ||
0 == (v_33) * (1) | ||
1 == (v_32) * (1) | ||
v_34 == (v_4) * (v_17) | ||
v_35 == (v_7) * (v_22) | ||
v_36 == (v_10) * (v_27) | ||
v_37 == (v_5) * (v_17) | ||
v_38 == (v_8) * (v_22) | ||
v_39 == (v_11) * (v_27) | ||
v_40 == (v_6) * (v_17) | ||
v_41 == (v_9) * (v_22) | ||
v_42 == (v_12) * (v_27) | ||
v_34 + v_35 + v_36 == (v_1) * (1) | ||
v_37 + v_38 + v_39 == (v_2) * (1) | ||
v_40 + v_41 + v_42 == (v_3) * (1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
use std::multiplexer; | ||
|
||
fn main(pub xx: [[Field; 3]; 3], sel: Field) -> [Field; 3] { | ||
let chosen_elements = multiplexer::select_element(xx, sel); | ||
return chosen_elements; | ||
} |