From 16af2d832d84a415f0b27b0bf81d7d7fc3190e3a Mon Sep 17 00:00:00 2001 From: John Wu Date: Sat, 10 Aug 2024 17:06:35 -0400 Subject: [PATCH 1/2] Implemented sumcheck protocol using the field elements and a custom built implementation of a multivariate polynomial --- .gitignore | 2 + Cargo.lock | 110 ++++ Cargo.toml | 4 +- src/algebra/field/prime/mod.rs | 9 +- src/kzg/mod.rs | 2 + src/lib.rs | 2 + src/polynomial/mod.rs | 1 + src/polynomial/multivariate_polynomial.rs | 489 ++++++++++++++++++ src/polynomial/tests.rs | 288 +++++++++++ src/random/mod.rs | 59 +++ src/sumcheck/boolean_array.rs | 37 ++ src/sumcheck/mod.rs | 604 ++++++++++++++++++++++ src/sumcheck/tests.rs | 151 ++++++ src/sumcheck/to_bytes.rs | 3 + 14 files changed, 1759 insertions(+), 2 deletions(-) create mode 100644 src/polynomial/multivariate_polynomial.rs create mode 100644 src/random/mod.rs create mode 100644 src/sumcheck/boolean_array.rs create mode 100644 src/sumcheck/mod.rs create mode 100644 src/sumcheck/tests.rs create mode 100644 src/sumcheck/to_bytes.rs diff --git a/.gitignore b/.gitignore index b1521060..237315a6 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,8 @@ cgt* # ignore instatiations of my test template # don't leak secret env vars .env +.history/* + # exclude compiled files and binaries debug/ target/ diff --git a/Cargo.lock b/Cargo.lock index e2d68e5b..cda390d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -29,6 +29,15 @@ version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + [[package]] name = "ark-crypto-primitives" version = "0.4.0" @@ -206,6 +215,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bytemuck" +version = "1.16.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "102087e286b4677862ea56cf8fc58bb2cdfa8725c40ffb80fe3a008eb7f2fc83" + [[package]] name = "cfg-if" version = "1.0.0" @@ -510,12 +525,49 @@ version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memchr" version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "nalgebra" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d506eb7e08d6329505faa8a3a00a5dcc6de9f76e0c77e4b75763ae3c770831ff" +dependencies = [ + "approx", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01fcc0b8149b4632adc89ac3b7b31a12fb6099a0317a4eb2ebff574ef7de7218" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -526,6 +578,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-integer" version = "0.1.46" @@ -535,6 +596,16 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -641,6 +712,12 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "regex" version = "1.10.5" @@ -686,6 +763,7 @@ dependencies = [ "des", "hex", "itertools", + "nalgebra", "pretty_assertions", "rand", "rstest", @@ -731,6 +809,15 @@ dependencies = [ "semver", ] +[[package]] +name = "safe_arch" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3460605018fdc9612bce72735cba0d27efbcd9904780d44c7e3a9948f96148a" +dependencies = [ + "bytemuck", +] + [[package]] name = "semver" version = "1.0.23" @@ -748,6 +835,19 @@ dependencies = [ "digest", ] +[[package]] +name = "simba" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0b7840f121a46d63066ee7a99fc81dcabbc6105e437cae43528cea199b5a05f" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", + "wide", +] + [[package]] name = "slab" version = "0.4.9" @@ -861,6 +961,16 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wide" +version = "0.7.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "901e8597c777fa042e9e245bd56c0dc4418c5db3f845b6ff94fbac732c6a0692" +dependencies = [ + "bytemuck", + "safe_arch", +] + [[package]] name = "winnow" version = "0.5.40" diff --git a/Cargo.toml b/Cargo.toml index b26431d7..f26190ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,13 +4,15 @@ description="""ronkathon""" edition ="2021" license ="Apache2.0 OR MIT" name ="ronkathon" -repository ="https://github.com/pluto/ronkathon" + +repository ="https://github.com/wu-s-john/ronkathon" version ="0.1.0" [dependencies] rand ="0.8.5" itertools="0.13.0" hex ="0.4.3" +nalgebra = "0.29" [dev-dependencies] rstest ="0.22.0" diff --git a/src/algebra/field/prime/mod.rs b/src/algebra/field/prime/mod.rs index 8794f80f..b3d29084 100644 --- a/src/algebra/field/prime/mod.rs +++ b/src/algebra/field/prime/mod.rs @@ -8,7 +8,7 @@ use std::{fmt, str::FromStr}; use rand::{distributions::Standard, prelude::Distribution, Rng}; use super::*; -use crate::algebra::Finite; +use crate::{algebra::Finite, random::Random}; mod arithmetic; @@ -41,6 +41,13 @@ pub struct PrimeField { pub(crate) value: usize, } +impl Random for PlutoBaseField { + fn random(rng: &mut R) -> Self { + let value = rng.gen_range(0..PlutoPrime::Base as usize); + PlutoBaseField::new(value) + } +} + impl PrimeField

{ /// Creates a new element of the [`PrimeField`] and will automatically compute the modulus and /// return a congruent element between 0 and `P`. Given the `const fn is_prime`, a program that diff --git a/src/kzg/mod.rs b/src/kzg/mod.rs index 85cafee5..11eefe53 100644 --- a/src/kzg/mod.rs +++ b/src/kzg/mod.rs @@ -6,3 +6,5 @@ pub mod setup; pub use setup::*; use super::*; + + diff --git a/src/lib.rs b/src/lib.rs index 40ae5dc5..20616b37 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,6 +33,8 @@ pub mod hashes; pub mod kzg; pub mod polynomial; pub mod tree; +pub mod random; +pub mod sumcheck; use core::{ fmt::{self, Display, Formatter}, diff --git a/src/polynomial/mod.rs b/src/polynomial/mod.rs index 49d9a740..7223aa6f 100644 --- a/src/polynomial/mod.rs +++ b/src/polynomial/mod.rs @@ -23,6 +23,7 @@ use super::*; use crate::algebra::field::FiniteField; pub mod arithmetic; +pub mod multivariate_polynomial; #[cfg(test)] mod tests; // https://people.inf.ethz.ch/gander/papers/changing.pdf diff --git a/src/polynomial/multivariate_polynomial.rs b/src/polynomial/multivariate_polynomial.rs new file mode 100644 index 00000000..9dadd92c --- /dev/null +++ b/src/polynomial/multivariate_polynomial.rs @@ -0,0 +1,489 @@ +//! Represents a multivariate polynomial over a finite field. +//! +//! This implementation uses a novel and highly efficient representation for multivariate polynomials. +//! Each term in the polynomial is represented as a key-value pair in a HashMap, where: +//! - The key is a BTreeMap mapping variable indices to their exponents. +//! - The value is the coefficient of the term. +//! +//! This representation offers several significant advantages: +//! 1. Space Efficiency: Only non-zero terms are stored, making it ideal for sparse polynomials. +//! 2. Fast Term Lookup: The use of BTreeMap for exponents allows for quick term identification and manipulation. +//! 3. Ordered Operations: BTreeMap's ordered nature facilitates efficient polynomial arithmetic. +//! 4. Memory Optimization: By using indices instead of full variable objects, we reduce memory usage. +//! 5. Flexible Degree Handling: This structure naturally accommodates polynomials of arbitrary degree. +//! 6. Efficient Iteration: Easy to iterate over terms, useful for various algorithms and transformations. +//! +//! While this representation may have a slight overhead for very small polynomials, +//! its benefits become increasingly apparent as the polynomial's complexity grows, +//! making it an excellent choice for a wide range of cryptographic and algebraic applications. + +use std::collections::{HashMap, BTreeMap}; +use std::hash::Hash; +use std::ops::{Add, Mul}; +use itertools::Itertools; + +use crate::algebra::field::FiniteField; +use super::*; +use super::{Monomial, Polynomial}; + +use std::ops::Sub; + + +/// Represents a multivariate polynomial over a finite field. +/// +/// The polynomial is stored as a collection of terms, where each term is represented by: +/// - A `BTreeMap` as the key, mapping variable indices to their exponents. +/// This allows for efficient storage and manipulation of sparse polynomials. +/// - An `F` value as the coefficient, where `F` is a finite field. +/// +/// The use of `HashMap` for `terms` provides: +/// 1. O(1) average-case complexity for term lookup and insertion. +/// 2. Efficient storage for sparse polynomials, as only non-zero terms are stored. +/// +/// The use of `BTreeMap` for exponents provides: +/// 1. Ordered storage of variable exponents, facilitating polynomial arithmetic. +/// 2. Efficient comparison and manipulation of terms. +/// 3. Memory efficiency by using indices instead of full variable objects. +/// +/// This representation is particularly effective for large, sparse multivariate polynomials +/// commonly encountered in cryptographic and algebraic applications. +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct MultivariatePolynomial { + terms: HashMap, F>, +} + +impl MultivariatePolynomial { + /// Constructs a new `MultivariatePolynomial` representing the zero polynomial. + /// + /// This function creates an empty `MultivariatePolynomial`, which is equivalent to the zero polynomial. + /// The zero polynomial has no terms, and evaluates to zero for all inputs. + pub fn new() -> Self { + Self { terms: HashMap::new() } + } + + + /// Creates a new `MultivariatePolynomial` from a vector of `MultivariateTerm`s. + /// + /// This is the preferred way to create a multivariate polynomial, as it allows + /// for a more intuitive representation of the polynomial's terms. + /// + /// # Arguments + /// + /// * `terms` - A vector of `MultivariateTerm`s representing the polynomial. + /// + /// # Returns + /// + /// A new `MultivariatePolynomial` instance. + /// + /// # Example + /// + /// ``` + /// use your_crate::{MultivariatePolynomial, MultivariateTerm, MultivariateVariable, PlutoBaseField}; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// MultivariateTerm::new( + /// vec![ + /// MultivariateVariable { index: 0, exponent: 2 }, + /// MultivariateVariable { index: 1, exponent: 1 } + /// ], + /// PlutoBaseField::new(3) + /// ), + /// MultivariateTerm::new( + /// vec![MultivariateVariable { index: 0, exponent: 1 }], + /// PlutoBaseField::new(2) + /// ), + /// MultivariateTerm::new( + /// vec![], + /// PlutoBaseField::new(1) + /// ) + /// ]); + /// + /// // This creates the polynomial: 3x_0^2*x_1 + 2x_0 + 1 + /// ``` + pub fn from_terms(terms: Vec>) -> Self { + let mut poly = MultivariatePolynomial::new(); + for term in terms { + let mut btree_map = BTreeMap::new(); + for var in term.variables { + btree_map.insert(var.index, var.exponent); + } + poly.insert_term(btree_map, term.coefficient); + } + poly + } + + fn insert_term(&mut self, exponents: BTreeMap, coefficient: F) { + if coefficient != F::ZERO { + let entry = self.terms.entry(exponents.clone()).or_insert(F::ZERO); + *entry += coefficient; + if *entry == F::ZERO { + self.terms.remove(&exponents); + } + } + } + + /// Returns the coefficient of the term with the given exponents. + /// + /// # Arguments + /// + /// * `exponents` - A `BTreeMap` where the keys are variable indices and the values are their exponents. + /// + /// # Returns + /// + /// * `Some(&F)` if a term with the given exponents exists in the polynomial, where `F` is the coefficient. + /// * `None` if no term with the given exponents exists in the polynomial. + /// + /// # Example + /// + /// ``` + /// use std::collections::BTreeMap; + /// use your_crate::{MultivariatePolynomial, PlutoBaseField}; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// // ... (terms as in the previous example) + /// ]); + /// + /// let mut exponents = BTreeMap::new(); + /// exponents.insert(0, 2); + /// exponents.insert(1, 1); + /// + /// assert_eq!(poly.coefficient(&exponents), Some(&PlutoBaseField::new(3))); + /// ``` + pub fn coefficient(&self, exponents: &BTreeMap) -> Option<&F> { + self.terms.get(exponents) + } + + /// Evaluates the multivariate polynomial at the given points. + /// + /// # Arguments + /// + /// * `points` - A slice of tuples where each tuple contains: + /// - The index of the variable (usize) + /// - The value to evaluate the variable at (F) + /// + /// # Returns + /// + /// * The result of evaluating the polynomial (F) + /// + /// # Example + /// + /// ``` + /// use your_crate::{MultivariatePolynomial, PlutoBaseField}; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// // x^2 + 2xy + 3z + /// // ... (terms definition) + /// ]); + /// + /// let points = vec![(0, PlutoBaseField::new(2)), (1, PlutoBaseField::new(3)), (2, PlutoBaseField::new(1))]; + /// let result = poly.evaluate(&points); + /// // result will be the evaluation of x^2 + 2xy + 3z at x=2, y=3, z=1 + /// ``` + pub fn evaluate(&self, points: &[(usize, F)]) -> F { + self.terms.iter().map(|(exponents, coeff)| { + let term_value = exponents.iter().map(|(&var, &exp)| { + points.iter() + .find(|&&(v, _)| v == var) + .map(|&(_, value)| value.pow(exp)) + .unwrap_or(F::ONE) + }).product::(); + *coeff * term_value + }).sum() + } + + /// Applies the given variable assignments to the polynomial, reducing its degree. + /// + /// This method substitutes the specified variables with their corresponding values, + /// effectively reducing the polynomial's degree for those variables. The resulting + /// polynomial will have fewer variables if any were fully substituted. + /// + /// # Arguments + /// + /// * `variables` - A slice of tuples, where each tuple contains: + /// - The index of the variable to substitute (usize) + /// - The value to substitute for that variable (F) + /// + /// # Returns + /// + /// A new `MultivariatePolynomial` with the specified variables substituted. + /// + /// # Example + /// + /// ``` + /// use your_crate::{MultivariatePolynomial, PlutoBaseField}; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// // x^2y + 3xyz + 2z + /// // ... (terms definition) + /// ]); + /// + /// let assignments = vec![(1, PlutoBaseField::new(2))]; // y = 2 + /// let reduced_poly = poly.apply_variables(&assignments); + /// // The resulting polynomial will be of the form: 2x^2 + 6xz + 2z + /// ``` + pub fn apply_variables(&self, variables: &[(usize, F)]) -> Self { + let mut result = MultivariatePolynomial::new(); + + for (exponents, coeff) in &self.terms { + let mut new_exponents = exponents.clone(); + let mut new_coeff = *coeff; + + for &(var, value) in variables { + if let Some(exp) = new_exponents.get(&var) { + new_coeff *= value.pow(*exp); + new_exponents.remove(&var); + } + } + + if !new_exponents.is_empty() { + result.insert_term(new_exponents, new_coeff); + } else { + result.insert_term(BTreeMap::new(), new_coeff); + } + } + + result + } + + + /// Calculates the total degree of the multivariate polynomial. + /// + /// The total degree of a multivariate polynomial is the maximum sum of exponents + /// across all terms in the polynomial. + /// + /// # Returns + /// + /// * `usize` - The total degree of the polynomial. + /// + /// # Example + /// + /// ``` + /// use your_crate::MultivariatePolynomial; + /// use your_crate::FiniteField; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// // x^2y + 3xyz^2 + 2z + /// // ... (terms definition) + /// ]); + /// + /// assert_eq!(poly.degree(), 4); // The term xyz^2 has the highest total degree of 4 + /// ``` + pub fn degree(&self) -> usize { + self.terms.keys() + .map(|exponents| exponents.values().sum::()) + .max() + .unwrap_or(0) + } + + /// Returns a vector of all variables present in the polynomial. + /// + /// This method collects all unique variables (represented by their indices) + /// that appear in any term of the polynomial. + /// + /// # Returns + /// + /// * `Vec` - A vector containing the indices of all variables in the polynomial. + /// + /// # Example + /// + /// ``` + /// use your_crate::MultivariatePolynomial; + /// use your_crate::FiniteField; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// // x^2y + 3xyz^2 + 2z + /// // ... (terms definition) + /// ]); + /// + /// let vars = poly.variables(); + /// assert_eq!(vars, vec![0, 1, 2]); // Assuming x, y, z are represented by 0, 1, 2 respectively + /// ``` + pub fn variables(&self) -> Vec { + self.terms.keys() + .flat_map(|exponents| exponents.keys().cloned()) + .collect::>() + .into_iter() + .collect() + } +} + +impl Add for MultivariatePolynomial { + type Output = Self; + + fn add(mut self, rhs: Self) -> Self::Output { + for (exponents, coeff) in rhs.terms { + self.insert_term(exponents, coeff); + } + self + } +} + +impl Sub for MultivariatePolynomial { + type Output = Self; + + fn sub(mut self, rhs: Self) -> Self::Output { + for (exponents, coeff) in rhs.terms { + // Negate the coefficient and insert + self.insert_term(exponents, -coeff); + } + self + } +} + +impl Mul for MultivariatePolynomial { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + let mut result = MultivariatePolynomial::new(); + for (exp1, coeff1) in &self.terms { + for (exp2, coeff2) in &rhs.terms { + let mut new_exp = exp1.clone(); + for (&var, &exp) in exp2 { + *new_exp.entry(var).or_insert(0) += exp; + } + result.insert_term(new_exp, *coeff1 * *coeff2); + } + } + result + } +} + +impl Display for MultivariatePolynomial { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut first = true; + for (exponents, coeff) in self.terms.iter().sorted_by(|(a_exp, _), (b_exp, _)| { + a_exp.iter() + .zip(b_exp.iter()) + .find(|((a_var, a_pow), (b_var, b_pow))| { + a_var.cmp(b_var).then_with(|| b_pow.cmp(a_pow)).is_ne() + }) + .map_or(std::cmp::Ordering::Equal, |(_, _)| std::cmp::Ordering::Less) + }) { + if !first { + write!(f, " + ")?; + } + first = false; + + if *coeff != F::ONE || exponents.is_empty() { + write!(f, "{}", coeff)?; + } + + let mut first_var = true; + for (&var, &exp) in exponents { + if exp > 0 { + if !first_var || *coeff != F::ONE { + write!(f, "*")?; + } + write!(f, "x_{}", var)?; + if exp > 1 { + write!(f, "^{}", exp)?; + } + first_var = false; + } + } + } + + if first { + write!(f, "0")?; + } + + Ok(()) + } +} + +// Implement From for univariate polynomials +impl From> for MultivariatePolynomial { + fn from(poly: Polynomial) -> Self { + let mut result = MultivariatePolynomial::new(); + for (i, &coeff) in poly.coefficients.iter().enumerate() { + if coeff != F::ZERO { + let mut exponents = BTreeMap::new(); + exponents.insert(0, i); + result.insert_term(exponents, coeff); + } + } + result + } +} + +// Extend Polynomial to support conversion to multivariate +impl Polynomial { + /// Converts a univariate polynomial to a multivariate polynomial. + /// + /// This method transforms the current univariate polynomial into a multivariate polynomial + /// where all terms use the same variable, specified by `variable_index`. + /// + /// # Arguments + /// + /// * `variable_index` - The index of the variable to use in the resulting multivariate polynomial. + /// + /// # Returns + /// + /// A `MultivariatePolynomial` equivalent to the original univariate polynomial. + /// + /// # Example + /// + /// ``` + /// let univariate = Polynomial::new([F::ONE, F::TWO, F::THREE]); // x^2 + 2x + 1 + /// let multivariate = univariate.to_multivariate(0); + /// // Result: x_0^2 + 2*x_0 + 1 + /// ``` + pub fn to_multivariate(self, variable_index: usize) -> MultivariatePolynomial { + let mut result = MultivariatePolynomial::new(); + for (i, &coeff) in self.coefficients.iter().enumerate() { + if coeff != F::ZERO { + let mut exponents = BTreeMap::new(); + exponents.insert(variable_index, i); + result.insert_term(exponents, coeff); + } + } + result + } +} + + +/// Represents a variable with an exponent in a multivariate polynomial. +/// Each variable is uniquely identified by its index and has an associated exponent. +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct MultivariateVariable { + /// The unique identifier for the variable. + /// This index distinguishes one variable from another in the polynomial. + pub index: usize, + + /// The power to which the variable is raised. + /// For example, if exponent is 2, it represents x^2 for the variable x. + pub exponent: usize, +} + +impl MultivariateVariable { + /// Creates a new multivariate variable with the given index and exponent. + pub fn new(index: usize, exponent: usize) -> Self { + MultivariateVariable { index, exponent } + } +} + +/// Represents a term in a multivariate polynomial. +/// +/// # Fields +/// +/// * `variables` - A vector of `MultivariateVariable`s representing the variables in this term. +/// * `coefficient` - The coefficient of this term, represented as a finite field element. +#[derive(PartialEq, Eq)] +pub struct MultivariateTerm { + /// A vector of `MultivariateVariable`s representing the variables in this term. + /// Each `MultivariateVariable` contains an index and an exponent. + pub variables: Vec, + + /// The coefficient of this term, represented as a finite field element. + /// This value multiplies the product of the variables in the term. + pub coefficient: F, +} + +/// Represents a term in a multivariate polynomial. +/// A term consists of a coefficient and a collection of variables with their exponents. +impl MultivariateTerm { + /// Creates a new multivariate term with the given variables and coefficient. + pub fn new(variables: Vec, coefficient: F) -> Self { + MultivariateTerm { variables, coefficient } + } +} diff --git a/src/polynomial/tests.rs b/src/polynomial/tests.rs index 7761e588..78b000a4 100644 --- a/src/polynomial/tests.rs +++ b/src/polynomial/tests.rs @@ -1,4 +1,5 @@ use super::*; +use crate::polynomial::multivariate_polynomial::{MultivariatePolynomial, MultivariateTerm, MultivariateVariable}; #[fixture] fn poly() -> Polynomial { @@ -126,3 +127,290 @@ fn dft(poly: Polynomial) { // Polynomial::::new(vec![PlutoBaseField::ZERO, // PlutoBaseField::ZERO]); assert_eq!(poly.coefficients, [PlutoBaseField::ZERO]); } + +#[test] +fn test_multivariate_polynomial_creation() { + let poly = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(3) + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }], + PlutoBaseField::new(2) + ), + MultivariateTerm::new( + vec![], + PlutoBaseField::new(1) + ) + ]); + + assert_eq!(poly.degree(), 3); + + assert_eq!( + poly.variables().into_iter().collect::>(), + vec![0, 1].into_iter().collect::>() + ); +} + +#[test] +fn test_multivariate_polynomial_addition() { + let poly1 = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }], + PlutoBaseField::new(1) + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(2) + ) + ]); + + let poly2 = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }], + PlutoBaseField::new(3) + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 2, exponent: 1 }], + PlutoBaseField::new(4) + ) + ]); + + let result = poly1 + poly2; + + println!("Addition Result polynomial: {}", result); + + let expected_result = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }], + PlutoBaseField::new(4) + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(2) + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 2, exponent: 1 }], + PlutoBaseField::new(4) + ) + ]); + + assert_eq!(result, expected_result); +} + +#[test] +fn test_multivariate_polynomial_multiplication() { + let poly1 = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }], + PlutoBaseField::new(2) + ), + MultivariateTerm::new( + vec![], + PlutoBaseField::new(1) + ) + ]); + + let poly2 = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(3) + ), + MultivariateTerm::new( + vec![], + PlutoBaseField::new(1) + ) + ]); + + let result = poly1 * poly2; + + println!("Multiplication Result polynomial: {}", result); + + let expected_result = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }, MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(6) + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }], + PlutoBaseField::new(2) + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(3) + ), + MultivariateTerm::new( + vec![], + PlutoBaseField::new(1) + ) + ]); + + assert_eq!(result, expected_result); +} + +#[test] +fn test_multivariate_polynomial_evaluation() { + let poly = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(3) + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }], + PlutoBaseField::new(2) + ), + MultivariateTerm::new( + vec![], + PlutoBaseField::new(1) + ) + ]); + + let points = vec![(0, PlutoBaseField::new(2)), (1, PlutoBaseField::new(3))]; + + println!("{}", poly); + + let result = poly.evaluate(&points); + // 3*(2^2)*(3) + 2*(2) + 1 = 3*4*3 + 4 + 1 = 36 + 4 + 1 = 41 + assert_eq!(result, PlutoBaseField::new(41)); +} + +#[test] +fn test_apply_variables_single_variable() { + let poly = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }], + PlutoBaseField::new(3) + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }], + PlutoBaseField::new(2) + ), + MultivariateTerm::new( + vec![], + PlutoBaseField::new(1) + ) + ]); + + let variables = vec![(0, PlutoBaseField::new(2))]; + let result = poly.apply_variables(&variables); + + let expected_result = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![], + PlutoBaseField::new(17) + ) + ]); + + assert_eq!(result, expected_result); +} + +#[test] +fn test_apply_variables_multiple_variables() { + let poly = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(3) + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }], + PlutoBaseField::new(2) + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(4) + ), + MultivariateTerm::new( + vec![], + PlutoBaseField::new(1) + ) + ]); + + println!("Apply Multiple Variables Polynomial: {}", poly); + + let variables = vec![(0, PlutoBaseField::new(2)), (1, PlutoBaseField::new(3))]; + let result = poly.apply_variables(&variables); + + println!("Reduced Multiple Variable Polynomial: {}", result); + + let expected_result = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![], + PlutoBaseField::new(53) + ) + ]); + + assert_eq!(result, expected_result); +} + +#[test] +fn test_apply_variables_partial_application() { + let poly = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![ + MultivariateVariable { index: 0, exponent: 2 }, + MultivariateVariable { index: 1, exponent: 1 }, + MultivariateVariable { index: 2, exponent: 1 } + ], + PlutoBaseField::new(3) + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }, MultivariateVariable { index: 2, exponent: 1 }], + PlutoBaseField::new(2) + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(4) + ) + ]); + + let variables = vec![(0, PlutoBaseField::new(2)), (1, PlutoBaseField::new(3))]; + let applied_poly = poly.apply_variables(&variables); + + let expected_result = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 2, exponent: 1 }], + PlutoBaseField::new(40) + ), + MultivariateTerm::new( + vec![], + PlutoBaseField::new(12) + ) + ]); + + assert_eq!(applied_poly, expected_result); +} + +#[test] +fn test_apply_variables_no_effect() { + let poly = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(3) + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 2, exponent: 1 }], + PlutoBaseField::new(2) + ) + ]); + + let variables = vec![(3, PlutoBaseField::new(5))]; + let result = poly.apply_variables(&variables); + + assert_eq!(result, poly); +} + +#[test] +fn test_apply_variables_empty() { + let poly = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(3) + ) + ]); // 3x_0^2*x_1 + + let variables = vec![]; + let result = poly.apply_variables(&variables); + + assert_eq!(result, poly); +} diff --git a/src/random/mod.rs b/src/random/mod.rs new file mode 100644 index 00000000..5f6af5d3 --- /dev/null +++ b/src/random/mod.rs @@ -0,0 +1,59 @@ +//! # Random Number Generation and Random Oracle Functionality +//! +//! This module provides traits and utilities for random number generation and +//! random oracle functionality, which are essential for various cryptographic +//! operations and protocols. +//! +//! ## Key Components +//! +//! - `Random`: A trait for types that can be randomly generated. +//! - `RandomOracle`: A trait for types that can be generated using a random oracle approach. +//! +//! These traits allow for flexible and secure generation of random instances +//! for implementing types, supporting both standard random generation and +//! more complex random oracle-based generation. +//! +//! The module is designed to work seamlessly with the `rand` crate's `Rng` trait, +//! providing a consistent interface for random number generation across the library. + + +use rand::Rng; + +/// A trait for types that can be randomly generated. +/// +/// Types implementing this trait can create random instances of themselves +/// using a provided random number generator. +pub trait Random { + /// Generates a random instance of the implementing type. + /// + /// # Arguments + /// + /// * `rng` - A mutable reference to a random number generator. + /// + /// # Returns + /// + /// A randomly generated instance of the implementing type. + fn random(rng: &mut R) -> Self; +} + +/// A trait for types that can be generated using a random oracle. +/// +/// Types implementing this trait can create instances of themselves +/// using a provided random number generator and an input byte slice, +/// simulating a random oracle functionality. +pub trait RandomOracle: Random { + /// Generates an instance of the implementing type using a random oracle approach. + /// + /// This method takes both a random number generator and an input byte slice, + /// allowing for deterministic yet unpredictable output based on the input. + /// + /// # Arguments + /// + /// * `rng` - A mutable reference to a random number generator. + /// * `input` - A byte slice used as input to the random oracle. + /// + /// # Returns + /// + /// An instance of the implementing type, generated using the random oracle approach. + fn random_oracle(rng: &mut R, input: &[u8]) -> Self; +} diff --git a/src/sumcheck/boolean_array.rs b/src/sumcheck/boolean_array.rs new file mode 100644 index 00000000..5b1e6c3b --- /dev/null +++ b/src/sumcheck/boolean_array.rs @@ -0,0 +1,37 @@ +struct BooleanArrayIter { + current: Vec, + done: bool, +} + +impl Iterator for BooleanArrayIter { + type Item = Vec; + + fn next(&mut self) -> Option { + if self.done { + return None; + } + + let result = self.current.clone(); + + // Generate next array + for i in 0..self.current.len() { + if self.current[i] { + self.current[i] = false; + } else { + self.current[i] = true; + return Some(result); + } + } + + // If we've reached here, we've generated all arrays + self.done = true; + Some(result) + } +} + +pub fn get_all_possible_boolean_values(length: usize) -> impl Iterator> { + BooleanArrayIter { + current: vec![false; length], + done: false, + } +} \ No newline at end of file diff --git a/src/sumcheck/mod.rs b/src/sumcheck/mod.rs new file mode 100644 index 00000000..9f46ca8b --- /dev/null +++ b/src/sumcheck/mod.rs @@ -0,0 +1,604 @@ +//! # Sumcheck Protocol Implementation +//! +//! This module implements the sumcheck protocol, a powerful interactive proof system +//! used in various zero-knowledge proof constructions. +//! +//! ## Overview of the Sumcheck Protocol +//! +//! The sumcheck protocol allows a prover to convince a verifier that the sum of a +//! multivariate polynomial over all boolean inputs (i.e., the boolean hypercube) is +//! equal to a claimed value, without the verifier having to compute the sum directly. +//! +//! The protocol proceeds in rounds, where in each round: +//! 1. The prover sends a univariate polynomial. +//! 2. The verifier checks certain properties of this polynomial and sends a random challenge. +//! 3. This process reduces the multivariate polynomial to a univariate one in each round. +//! +//! ## Implementation Details +//! +//! This implementation provides both interactive and non-interactive versions of the sumcheck protocol. +//! +//! ### Prover Implementation +//! +//! The prover's implementation is split into several methods: +//! +//! - `prove_first_sumcheck_round`: Computes the claimed sum and the first univariate polynomial. +//! - `prove_sumcheck_round_i`: Generates the univariate polynomial for intermediate rounds. +//! - `prove_sumcheck_last_round`: Handles the final round of the protocol. +//! - `compute_univariate_polynomial`: A helper method to compute the univariate polynomial for each round. +//! +//! This structure allows for a clear separation of concerns and follows the round-based +//! nature of the sumcheck protocol. +//! +//! ### Verifier Implementation +//! +//! The verifier's implementation is divided into separate functions for each stage: +//! +//! - `verify_sumcheck_first_round`: Verifies the first round, checking the claimed sum. +//! - `verify_sumcheck_univariate_poly_sum`: Verifies intermediate rounds. +//! - `verify_sumcheck_last_round`: Performs the final verification step. +//! +//! This separation allows for clear and modular verification logic, closely following +//! the structure of the sumcheck protocol. +//! +//! ### Non-Interactive Version +//! +//! The module also provides non-interactive versions of the protocol: +//! +//! - `non_interactive_sumcheck_prove`: Generates a complete proof in one step. +//! - `non_interactive_sumcheck_verify`: Verifies the complete proof. +//! +//! These functions use a random oracle model to simulate the interactive challenges, +//! making the protocol suitable for non-interactive scenarios. +//! +//! ## Correctness and Efficiency +//! +//! The implementation correctly follows the sumcheck protocol: +//! +//! 1. It reduces the multivariate polynomial to univariate polynomials in each round. +//! 2. It uses random challenges to ensure the prover cannot predict the verification path. +//! 3. The final verification step ties the protocol back to the original multivariate polynomial. +//! +//! The use of `MultivariatePolynomial` and efficient polynomial operations ensures that +//! the implementation is both correct and computationally efficient. +//! +//! ## Usage +//! +//! To use this implementation, create a `MultivariatePolynomial`, then use the `non_interactive_sumcheck_prove` +//! function to generate a proof, and `non_interactive_sumcheck_verify` to verify it. +//! +//! For more fine-grained control, you can use the individual prover and verifier functions +//! to implement an interactive version of the protocol. + +use std::{fmt::Display, hash::{Hash, Hasher}}; + +use rand::{Rng, SeedableRng}; +use crate::{algebra::field::FiniteField, polynomial::multivariate_polynomial::MultivariatePolynomial, random::{Random, RandomOracle}}; + +mod boolean_array; +mod to_bytes; +#[cfg(test)] mod tests; + +use self::{boolean_array::get_all_possible_boolean_values, to_bytes::ToBytes}; + +impl RandomOracle for F { + fn random_oracle(_rng: &mut R, input: &[u8]) -> Self { + // This is a simplified example. In a real implementation, + // you'd want to use a cryptographic hash function here. + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + input.hash(&mut hasher); + let hash = hasher.finish(); + + // Use the hash to seed a new RNG + let mut seeded_rng = rand::rngs::StdRng::seed_from_u64(hash); + + // Generate a random field element using the seeded RNG + Self::random(&mut seeded_rng) + } +} + +impl MultivariatePolynomial { + + + /// Proves the first round of the sumcheck protocol for this multivariate polynomial. + /// + /// This function is crucial for initiating the sumcheck protocol because: + /// 1. It computes the total sum of the polynomial over all boolean inputs, which is the + /// claimed sum that the prover wants to prove. + /// 2. It generates the first univariate polynomial g_1(X1), which is the first step in + /// reducing the multivariate sumcheck to a series of univariate sumchecks. + /// + /// The sumcheck protocol is essential for efficiently verifying the sum of a multivariate + /// polynomial over a boolean hypercube without evaluating every point, which would be + /// exponential in the number of variables. This function sets up the foundation for the + /// entire protocol. + /// + /// # Returns + /// - A tuple containing: + /// 1. The claimed sum (F): The total sum of the polynomial over all boolean inputs. + /// 2. The first univariate polynomial (MultivariatePolynomial): g_1(X1), which is + /// actually univariate despite the type name. + pub fn prove_first_sumcheck_round(&self) -> (F, MultivariatePolynomial) { + let variables = self.variables(); + let num_variables = variables.len(); + + let sum = get_all_possible_boolean_values(num_variables) + .map(|bool_values| { + let assignment: Vec<(usize, F)> = variables.iter().enumerate() + .map(|(i, &var)| (var, if bool_values[i] { F::ONE } else { F::ZERO })) + .collect(); + self.evaluate(&assignment) + }) + .sum(); + + // Compute the univariate polynomial g_1(X1) + let univariate_poly = self.compute_univariate_polynomial(0, vec![]); + + (sum, univariate_poly) + } + + /// Proves the i-th round of the sumcheck protocol for this multivariate polynomial. + /// + /// This function is a key part of the sumcheck protocol, generating the univariate polynomial + /// for the i-th round based on the partial assignment from previous rounds. + /// + /// # Arguments + /// + /// * `i` - The current round number (0-indexed). + /// * `partial_assignment` - A vector of field elements representing the values chosen by the + /// verifier in previous rounds. + /// + /// # Returns + /// + /// * `MultivariatePolynomial` - The univariate polynomial g_i(X_i) for the i-th round. + /// Despite the type name, this polynomial is univariate in X_i. + /// + /// # Properties and Equalities + /// + /// 1. Degree Preservation: The degree of g_i(X_i) in X_i is at most the degree of the original + /// polynomial in X_i. + /// + /// 2. Sum Consistency: The sum of g_i(X_i) over {0,1} equals g_{i-1}(r_{i-1}), where r_{i-1} + /// is the random challenge from the previous round. + /// + /// 3. Randomized Reduction: g_i(X_i) reduces the sum check for i variables to a sum check + /// for i-1 variables when a random point is chosen. + /// + /// 4. Partial Evaluation: g_i(X_i) can be seen as a partial evaluation of the original + /// polynomial, with the first i-1 variables fixed to the values in partial_assignment. + /// + /// These properties ensure the soundness and completeness of the sumcheck protocol, + /// allowing for efficient verification of the claimed sum. + /// + /// # Note + /// + /// This function relies on `compute_univariate_polynomial` to perform the actual computation + /// of the univariate polynomial for the current round. + pub fn prove_sumcheck_round_i( + &self, + i: usize, + partial_assignment: Vec, + ) -> MultivariatePolynomial { + return self.compute_univariate_polynomial(i, partial_assignment); + } + + /// Proves the last round of the sumcheck protocol for this multivariate polynomial. + /// + /// This function is similar to `prove_sumcheck_round_i`, but specifically handles the last round + /// of the sumcheck protocol. It generates the final univariate polynomial based on all previous + /// assignments. + /// + /// # Arguments + /// + /// * `i` - The index of the last round (should be equal to the number of variables minus 1). + /// * `partial_assignment` - A vector of field elements representing all values chosen by the + /// verifier in previous rounds. + /// + /// # Returns + /// + /// * `MultivariatePolynomial` - The final univariate polynomial for the last round. + /// This polynomial is univariate in the last remaining variable. + /// + /// # Note + /// + /// This function relies on `compute_univariate_polynomial` to perform the actual computation + /// of the univariate polynomial for the last round. The result of this function is crucial + /// for the final verification step in the sumcheck protocol. + pub fn prove_sumcheck_last_round( + &self, + i: usize, + partial_assignment: Vec, + ) -> MultivariatePolynomial { + return self.compute_univariate_polynomial(i, partial_assignment); + } + + fn compute_univariate_polynomial( + &self, + round: usize, + partial_assignment: Vec, + ) -> MultivariatePolynomial { + let variables = self.variables(); + let num_variables = variables.len(); + + // Create a polynomial to store the result + // First create a partial evaluation + let partial_poly = self.apply_variables( + &partial_assignment.iter().enumerate().map(|(i, &v)| (i, v)).collect::>(), + ); + + let result_polynomial = get_all_possible_boolean_values(num_variables - round - 1) + .map(|bool_values| { + let further_assignments: Vec = + bool_values.iter().map(|&b| if b { F::ONE } else { F::ZERO }).collect(); + let further_variables = ((round + 1)..num_variables).zip(further_assignments).collect::>(); + let poly = partial_poly.clone().apply_variables( + &further_variables, + ); + poly + }) + .fold(MultivariatePolynomial::new(), |acc, poly| acc + poly); + + // Assert that the resulting polynomial has only one variable + assert!( + result_polynomial.variables().len() <= 1, + "The univariate polynomial should have at most one variable" + ); + + result_polynomial + } + +} + +/// Verifies the first round of the sumcheck protocol. +/// +/// This function is crucial for initiating the verification process in the sumcheck protocol. +/// The verifier needs these components to ensure the correctness of the prover's claim: +/// +/// 1. `claimed_sum`: The total sum claimed by the prover. This is the value that the +/// verifier wants to check without computing the entire sum themselves. +/// +/// 2. `univariate_poly`: The first univariate polynomial g_1(X_1) provided by the prover. +/// This polynomial is supposed to represent the sum over all but the first variable. +/// +/// The verification process involves: +/// +/// 1. Checking that the provided polynomial is indeed univariate. This ensures that +/// the prover is following the protocol correctly by reducing one variable at a time. +/// +/// 2. Verifying that g_1(0) + g_1(1) equals the claimed sum. This check is fundamental +/// because it connects the univariate polynomial to the original multivariate sum. +/// If this equality holds, it suggests that the prover has correctly computed the +/// univariate polynomial for the first round. +/// +/// 3. Generating a random challenge. This challenge will be used in subsequent rounds +/// and is crucial for the security of the protocol. It ensures that the prover +/// cannot predict or manipulate future rounds. +/// +/// # Arguments +/// +/// * `claimed_sum`: The sum claimed by the prover. +/// * `univariate_poly`: The univariate polynomial for the first round. +/// +/// # Returns +/// +/// A tuple containing: +/// - A boolean indicating whether the verification passed (true) or failed (false). +/// - The random challenge generated for the next round. +/// +/// # Type Parameters +/// +/// * `F`: A type that implements both `FiniteField` and `Random` traits. +pub fn verify_sumcheck_first_round( + claimed_sum: F, + univariate_poly: &MultivariatePolynomial +) -> (bool, F) { + // Step 1: Verify that the polynomial is univariate (has only one variable) + if univariate_poly.variables().len() != 1 { + return (false, F::ZERO); + } + + // Step 2: Verify that g(0) + g(1) = claimed_sum + let var = 0; + let sum_at_endpoints = univariate_poly.evaluate(&[(var, F::ZERO)]) + univariate_poly.evaluate(&[(var, F::ONE)]); + + if sum_at_endpoints != claimed_sum { + return (false, F::ZERO); + } + + // Step 3: Generate a random challenge + let mut rng = rand::thread_rng(); + let challenge: F = F::random(&mut rng); + + // Return true (verification passed) and the evaluation at the challenge point + (true, challenge) +} + +/// Verify the i-th round of the sumcheck protocol +/// +/// This function is crucial for verifying the correctness of each intermediate step in the sumcheck protocol. +/// It ensures that the prover is following the protocol correctly and not deviating from the expected behavior. +/// +/// # Arguments +/// +/// * `round`: The current round number of the sumcheck protocol. This is needed to keep track of which variable +/// is being eliminated in the current round. +/// * `challenge`: The random challenge from the previous round. This is used to evaluate the previous round's +/// polynomial and connect it to the current round. +/// * `previous_univariate_poly`: The univariate polynomial from the previous round. This is needed to verify +/// the consistency between rounds. +/// * `current_univariate_poly`: The univariate polynomial for the current round. This is the polynomial that +/// the prover claims represents the sum over the current variable. +/// +/// # Returns +/// +/// A tuple containing: +/// - A boolean indicating whether the verification passed (true) or failed (false). +/// - The new random challenge generated for the next round. +/// +/// # Why these parameters are needed +/// +/// 1. `round`: Keeps track of the protocol's progress and ensures variables are eliminated in order. +/// 2. `challenge`: Connects the current round to the previous one, preventing the prover from deviating. +/// 3. `previous_univariate_poly`: Used to verify consistency between rounds. +/// 4. `current_univariate_poly`: The polynomial to be verified in the current round. +/// +/// These parameters allow the verifier to check: +/// - The univariate nature of the current polynomial (ensuring one variable is eliminated per round). +/// - The consistency between rounds (g_i(r_{i-1}) = g_{i-1}(0) + g_{i-1}(1)). +/// - Generate a new random challenge for the next round, maintaining the protocol's unpredictability. +pub fn verify_sumcheck_univariate_poly_sum( + round: usize, + challenge: F, + previous_univariate_poly: &MultivariatePolynomial, + current_univariate_poly: &MultivariatePolynomial, +) -> (bool, F) { + // Step 1: Verify that the current polynomial is univariate + if current_univariate_poly.variables().len() > 1 { + return (false, F::ZERO); + } + + // Step 2: Verify that g_i(r_{i-1}) = g_{i-1}(0) + g_{i-1}(1) + let prev_var = round - 1; + let sum_at_endpoints = previous_univariate_poly.evaluate(&[(prev_var, challenge)]); + + let eval_at_previous_challenge = current_univariate_poly.evaluate(&[(round, F::ZERO)]) + + current_univariate_poly.evaluate(&[(round, F::ONE)]); + if eval_at_previous_challenge != sum_at_endpoints { + return (false, F::ZERO); + } + + // Step 3: Generate a new random challenge + let mut rng = rand::thread_rng(); + let new_challenge: F = F::random(&mut rng); + + // Return true (verification passed) and the evaluation at the new challenge point + (true, new_challenge) +} + +/// Verifies the final round of the sumcheck protocol. +/// +/// This function is crucial for the verifier to ensure the prover's honesty in the final round +/// of the sumcheck protocol. It checks if the claimed univariate polynomial is consistent with +/// the original multivariate polynomial when all challenges are applied. +/// +/// # Arguments +/// +/// * `challenges`: A vector of all previous challenges from earlier rounds. +/// * `univariate_poly`: The final univariate polynomial claimed by the prover. +/// * `poly`: The original multivariate polynomial. +/// +/// # Returns +/// +/// A boolean indicating whether the verification passed (true) or failed (false). +/// +/// # Why the verifier needs this information +/// +/// 1. `challenges`: The verifier needs all previous challenges to reconstruct the point at which +/// the original polynomial should be evaluated. This ensures consistency across all rounds. +/// +/// 2. `univariate_poly`: This is the prover's final claim about the polynomial after all but one +/// variable have been fixed. The verifier needs to check if this claim is consistent with the +/// original polynomial. +/// +/// 3. `poly`: The original multivariate polynomial is necessary to independently compute the +/// correct evaluation and compare it with the prover's claim. +/// +/// By comparing the evaluation of the original polynomial at the challenge point with the +/// evaluation of the claimed univariate polynomial at a random point, the verifier can detect +/// any dishonesty from the prover with high probability. +pub fn verify_sumcheck_last_round( + challenges: Vec, + univariate_poly: &MultivariatePolynomial, + poly: &MultivariatePolynomial, +) -> bool { + // Step 1: Apply all challenges to the original polynomial + let mut challenges_with_indices = Vec::new(); + for (i, challenge) in challenges.iter().enumerate() { + challenges_with_indices.push((i, *challenge)); + } + let poly_evaluation = poly.evaluate(&challenges_with_indices); + + // Step 2: Generate a random challenge for the last variable + let mut rng = rand::thread_rng(); + let last_challenge: F = F::random(&mut rng); + + // Step 3: Evaluate the univariate polynomial at the last challenge + let last_var = challenges.len(); + let univariate_evaluation = univariate_poly.evaluate(&[(last_var, last_challenge)]); + + // Step 4: Compare the evaluations + poly_evaluation == univariate_evaluation +} + +impl ToBytes for F { + fn to_bytes(&self) -> Vec { + // Implement this based on how your field elements are represented + // This is just an example: + self.to_string().into_bytes() + } +} + +impl ToBytes for MultivariatePolynomial { + fn to_bytes(&self) -> Vec { + // Implement this based on how your polynomials are represented + // This is just an example: + self.to_string().into_bytes() + } +} + +/// Represents a proof for the sumcheck protocol over a finite field. +/// +/// This struct contains all the components necessary for verifying a sumcheck proof, +/// including the claimed sum, round polynomials, challenges, evaluations, and final results. +/// +/// # Type Parameters +/// +/// * `F`: A type that implements the `FiniteField` trait, representing the field over which +/// the sumcheck protocol is performed. +pub struct SumcheckProof { + /// The claimed sum of the polynomial over all boolean inputs. + pub claimed_sum: F, + + /// Vector of univariate polynomials, one for each round of the protocol. + pub round_polynomials: Vec>, + + /// Vector of challenges generated during the protocol. + pub challenges: Vec, + + /// Vector of evaluations of the round polynomials at the challenge points. + pub round_evaluations: Vec, + + /// The final evaluation point, consisting of all challenges combined. + pub final_point: Vec, + + /// The final evaluation of the original multivariate polynomial at the final point. + pub final_evaluation: F, +} + + +/// Generates a non-interactive sumcheck proof for a given multivariate polynomial. +/// +/// This function implements the prover's side of the non-interactive sumcheck protocol. +/// It generates a proof that the sum of the polynomial over all boolean inputs equals +/// the claimed sum, without requiring interaction with the verifier. +/// +/// The non-interactive nature is achieved by using a random oracle to generate challenges, +/// which both the prover and verifier can compute independently. +/// +/// # Arguments +/// +/// * `polynomial` - The multivariate polynomial for which to generate the sumcheck proof. +/// +/// # Returns +/// +/// Returns a `SumcheckProof` containing all necessary components for verification: +/// - The claimed sum +/// - Univariate polynomials for each round +/// - Challenges generated using the random oracle +/// - Evaluations of the round polynomials at the challenge points +/// - The final evaluation point and the polynomial's evaluation at that point +/// +/// # Type Parameters +/// +/// * `F` - A finite field type that implements necessary traits for arithmetic, +/// random number generation, conversion to bytes, and display. +pub fn non_interactive_sumcheck_prove( + polynomial: &MultivariatePolynomial +) -> SumcheckProof { + let num_variables = polynomial.variables().len(); + let mut challenges = Vec::new(); + let mut round_polynomials = Vec::new(); + let mut round_evaluations = Vec::new(); + + // First round: compute the claimed sum and the first univariate polynomial + let (claimed_sum, first_univariate_poly) = polynomial.prove_first_sumcheck_round(); + round_polynomials.push(first_univariate_poly.clone()); + + // Generate the first challenge using the random oracle + let mut rng = rand::thread_rng(); + let challenge: F = F::random_oracle(&mut rng, &claimed_sum.to_bytes()); + challenges.push(challenge); + round_evaluations.push(first_univariate_poly.evaluate(&[(0, challenge)])); + + let mut previous_univariate_poly = first_univariate_poly; + + // Intermediate rounds: generate univariate polynomials and challenges + for i in 1..num_variables { + let univariate_poly = polynomial.prove_sumcheck_round_i(i, challenges.clone()); + round_polynomials.push(univariate_poly.clone()); + + // Generate challenge for this round using the random oracle + let challenge: F = F::random_oracle(&mut rng, &previous_univariate_poly.to_bytes()); + challenges.push(challenge); + round_evaluations.push(univariate_poly.evaluate(&[(i, challenge)])); + + previous_univariate_poly = univariate_poly; + } + + // Final evaluation: evaluate the original polynomial at the challenge point + let final_point = challenges.clone(); + let final_evaluation = polynomial.evaluate(&final_point.iter().cloned().enumerate().collect::>()); + + // Construct and return the proof + SumcheckProof { + claimed_sum, + round_polynomials, + round_evaluations, + challenges, + final_point, + final_evaluation, + } +} + +/// Verifies a non-interactive sumcheck proof. +/// +/// This function allows a verifier to be easily convinced of the correctness of a sumcheck proof +/// without interacting with the prover. The verifier can be convinced by the following steps: +/// +/// 1. Check the consistency of the first round's claimed sum with the provided univariate polynomial. +/// 2. Verify the consistency between consecutive rounds' univariate polynomials. +/// 3. Confirm that the final evaluation matches the original multivariate polynomial at the challenge point. +/// +/// The non-interactive nature of this proof system comes from the use of a random oracle to generate +/// challenges, which both the prover and verifier can compute independently. +/// +/// # Arguments +/// +/// * `proof` - The `SumcheckProof` provided by the prover. +/// * `polynomial` - The original multivariate polynomial being summed over. +/// +/// # Returns +/// +/// Returns `true` if the proof is valid, `false` otherwise. +pub fn non_interactive_sumcheck_verify( + proof: &SumcheckProof, + polynomial: &MultivariatePolynomial +) -> bool { + let num_variables = polynomial.variables().len(); + + // Verify first round + let (valid, _) = verify_sumcheck_first_round(proof.claimed_sum, &proof.round_polynomials[0]); + if !valid { + return false; + } + + // Verify intermediate rounds + for i in 1..num_variables { + let (valid, _) = verify_sumcheck_univariate_poly_sum( + i, + proof.challenges[i-1], + &proof.round_polynomials[i-1], + &proof.round_polynomials[i], + ); + if !valid { + return false; + } + } + + // Verify last round + verify_sumcheck_last_round( + proof.final_point.clone(), + &proof.round_polynomials.last().unwrap(), + polynomial, + ) +} \ No newline at end of file diff --git a/src/sumcheck/tests.rs b/src/sumcheck/tests.rs new file mode 100644 index 00000000..90984561 --- /dev/null +++ b/src/sumcheck/tests.rs @@ -0,0 +1,151 @@ +use crate::{ + algebra::field::{ + prime::PlutoBaseField, + Field, + }, + polynomial::multivariate_polynomial::{ + MultivariatePolynomial, MultivariateTerm, MultivariateVariable, + }, + sumcheck::{ + verify_sumcheck_first_round, verify_sumcheck_last_round, verify_sumcheck_univariate_poly_sum, + }, +}; + +#[test] +#[test] +fn test_full_sumcheck_protocol() { + // This test demonstrates the full sumcheck protocol for the polynomial: + // f(x0, x1, x2) = x0 * (x1 + x2) - (x1 * x2) + // We'll prove and verify the sum of this polynomial over the boolean hypercube {0,1}^3. + + // The sumcheck protocol is used to prove the sum of a multivariate polynomial over a boolean + // hypercube without explicitly computing all 2^n evaluations. This is particularly useful for + // large n, where computing all evaluations would be computationally infeasible. + + // Step 1: Define the polynomial + // We start with a multivariate polynomial because the sumcheck protocol is designed to work + // with functions over boolean inputs, which are naturally represented as multivariate + // polynomials. + let poly = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }, MultivariateVariable { + index: 1, + exponent: 1, + }], + PlutoBaseField::ONE, + ), // x0 * x1 + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }, MultivariateVariable { + index: 2, + exponent: 1, + }], + PlutoBaseField::ONE, + ), // x0 * x2 + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }, MultivariateVariable { + index: 2, + exponent: 1, + }], + -PlutoBaseField::ONE, + ), // -x1 * x2 + ]); + + // Step 2: First round of the sumcheck protocol + // The prover computes the actual sum over all boolean inputs and generates the first univariate + // polynomial. This step is crucial because it reduces the n-variate polynomial to a univariate + // polynomial in x0, while maintaining the property that its sum over {0,1} equals the original + // sum. + let (claimed_sum, univariate_poly1) = poly.prove_first_sumcheck_round(); + + // The verifier checks the first round + // This check ensures that the sum of the univariate polynomial over {0,1} equals the claimed sum. + // It's a key step in verifying the prover's claim without computing the full sum. + let (valid, _challenge) = verify_sumcheck_first_round(claimed_sum, &univariate_poly1); + assert!(valid, "First round verification failed"); + + // Verify that the first univariate polynomial is correct: f0(x0) = 4x0 - 1 + // This check confirms that the prover correctly computed the univariate polynomial. + let expected_poly1 = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }], + PlutoBaseField::new(4), + ), + MultivariateTerm::new(vec![], -PlutoBaseField::ONE), + ]); + assert_eq!(univariate_poly1, expected_poly1, "First round polynomial is incorrect"); + + println!("Claimed sum: {:?}", claimed_sum); + println!("First univariate polynomial: {}", univariate_poly1); + + // Step 3: Second round of the sumcheck protocol + // The verifier sends a challenge. This challenge is used to reduce the problem further, + // from proving a statement about a sum to proving a statement about a single evaluation. + let random_challenge1 = PlutoBaseField::new(4); + + // The prover generates the second univariate polynomial + // This polynomial represents the partial evaluation of the original polynomial with x0 fixed to + // the challenge value. + let univariate_poly2 = poly.prove_sumcheck_round_i(1, vec![random_challenge1]); + println!("Round 2 univariate polynomial: {}", univariate_poly2); + + // Verify that the second univariate polynomial is correct: f1(x1) = 7x1 + 4 + // This check ensures that the prover correctly computed the second univariate polynomial. + let expected_poly2 = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(7), + ), + MultivariateTerm::new(vec![], PlutoBaseField::new(4)), + ]); + assert_eq!(univariate_poly2, expected_poly2, "Second round polynomial is incorrect"); + + // The verifier checks the second round + // This check ensures that the evaluation of the first univariate polynomial at the challenge + // point equals the sum of the second univariate polynomial over {0,1}. + let (valid, _challenge) = + verify_sumcheck_univariate_poly_sum(1, random_challenge1, &univariate_poly1, &univariate_poly2); + assert!(valid, "Second round verification failed"); + + // Step 4: Third (final) round of the sumcheck protocol + // The process continues, further reducing the problem to a single point evaluation of the + // original polynomial. + let random_challenge2 = PlutoBaseField::new(4); + + // The prover generates the final univariate polynomial + let univariate_poly3 = + poly.prove_sumcheck_last_round(2, vec![random_challenge1, random_challenge2]); + println!("Round 3 univariate polynomial: {}", univariate_poly3); + + // Verify that the final univariate polynomial is correct: f2(x2) = 16 + // This check confirms that the prover correctly computed the final univariate polynomial. + let expected_poly3 = + MultivariatePolynomial::::from_terms(vec![MultivariateTerm::new( + vec![], + PlutoBaseField::new(16), + )]); + assert_eq!(univariate_poly3, expected_poly3, "Final round polynomial is incorrect"); + + // The verifier checks the final round + // This check ensures that the evaluation of the second univariate polynomial at the challenge + // point equals the sum of the third univariate polynomial over {0,1}. + let (valid, _final_challenge) = + verify_sumcheck_univariate_poly_sum(2, random_challenge2, &univariate_poly2, &univariate_poly3); + assert!(valid, "Final round verification failed"); + + // Step 5: Final verification + // The verifier sends a final challenge and checks the entire protocol + // This step verifies that the final point evaluation claimed by the prover + // matches the evaluation of the original polynomial at the challenge points. + let random_challenge3 = PlutoBaseField::new(4); + let valid = verify_sumcheck_last_round( + vec![random_challenge1, random_challenge2, random_challenge3], + &univariate_poly3, + &poly, + ); + assert!(valid, "Overall sumcheck protocol verification failed"); + + // If we reach this point, the entire sumcheck protocol has been successfully demonstrated + // The verifier is convinced that the prover knows the correct sum without having to compute it + // directly. + println!("Sumcheck protocol successfully verified!"); +} diff --git a/src/sumcheck/to_bytes.rs b/src/sumcheck/to_bytes.rs new file mode 100644 index 00000000..306a5fc8 --- /dev/null +++ b/src/sumcheck/to_bytes.rs @@ -0,0 +1,3 @@ +pub trait ToBytes { + fn to_bytes(&self) -> Vec; +} \ No newline at end of file From e09bf7dbf86949f94f75d1d1ec1d702aa6522604 Mon Sep 17 00:00:00 2001 From: John Wu Date: Sun, 11 Aug 2024 14:25:19 -0400 Subject: [PATCH 2/2] format code --- src/kzg/mod.rs | 2 - src/lib.rs | 2 +- src/polynomial/multivariate_polynomial.rs | 788 +++++++++++----------- src/polynomial/tests.rs | 175 +++-- src/random/mod.rs | 49 +- src/sumcheck/boolean_array.rs | 49 +- src/sumcheck/mod.rs | 467 +++++++------ src/sumcheck/tests.rs | 6 +- src/sumcheck/to_bytes.rs | 4 +- 9 files changed, 777 insertions(+), 765 deletions(-) diff --git a/src/kzg/mod.rs b/src/kzg/mod.rs index 11eefe53..85cafee5 100644 --- a/src/kzg/mod.rs +++ b/src/kzg/mod.rs @@ -6,5 +6,3 @@ pub mod setup; pub use setup::*; use super::*; - - diff --git a/src/lib.rs b/src/lib.rs index 20616b37..e102dd79 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,9 +32,9 @@ pub mod encryption; pub mod hashes; pub mod kzg; pub mod polynomial; -pub mod tree; pub mod random; pub mod sumcheck; +pub mod tree; use core::{ fmt::{self, Display, Formatter}, diff --git a/src/polynomial/multivariate_polynomial.rs b/src/polynomial/multivariate_polynomial.rs index 9dadd92c..e02492f6 100644 --- a/src/polynomial/multivariate_polynomial.rs +++ b/src/polynomial/multivariate_polynomial.rs @@ -1,39 +1,42 @@ //! Represents a multivariate polynomial over a finite field. //! -//! This implementation uses a novel and highly efficient representation for multivariate polynomials. -//! Each term in the polynomial is represented as a key-value pair in a HashMap, where: +//! This implementation uses a novel and highly efficient representation for multivariate +//! polynomials. Each term in the polynomial is represented as a key-value pair in a HashMap, where: //! - The key is a BTreeMap mapping variable indices to their exponents. //! - The value is the coefficient of the term. //! //! This representation offers several significant advantages: //! 1. Space Efficiency: Only non-zero terms are stored, making it ideal for sparse polynomials. -//! 2. Fast Term Lookup: The use of BTreeMap for exponents allows for quick term identification and manipulation. +//! 2. Fast Term Lookup: The use of BTreeMap for exponents allows for quick term identification and +//! manipulation. //! 3. Ordered Operations: BTreeMap's ordered nature facilitates efficient polynomial arithmetic. -//! 4. Memory Optimization: By using indices instead of full variable objects, we reduce memory usage. -//! 5. Flexible Degree Handling: This structure naturally accommodates polynomials of arbitrary degree. -//! 6. Efficient Iteration: Easy to iterate over terms, useful for various algorithms and transformations. +//! 4. Memory Optimization: By using indices instead of full variable objects, we reduce memory +//! usage. +//! 5. Flexible Degree Handling: This structure naturally accommodates polynomials of arbitrary +//! degree. +//! 6. Efficient Iteration: Easy to iterate over terms, useful for various algorithms and +//! transformations. //! //! While this representation may have a slight overhead for very small polynomials, //! its benefits become increasingly apparent as the polynomial's complexity grows, //! making it an excellent choice for a wide range of cryptographic and algebraic applications. -use std::collections::{HashMap, BTreeMap}; -use std::hash::Hash; -use std::ops::{Add, Mul}; +use std::{ + collections::{BTreeMap, HashMap}, + hash::Hash, + ops::{Add, Mul, Sub}, +}; + use itertools::Itertools; +use super::{Monomial, Polynomial, *}; use crate::algebra::field::FiniteField; -use super::*; -use super::{Monomial, Polynomial}; - -use std::ops::Sub; - /// Represents a multivariate polynomial over a finite field. /// /// The polynomial is stored as a collection of terms, where each term is represented by: -/// - A `BTreeMap` as the key, mapping variable indices to their exponents. -/// This allows for efficient storage and manipulation of sparse polynomials. +/// - A `BTreeMap` as the key, mapping variable indices to their exponents. This +/// allows for efficient storage and manipulation of sparse polynomials. /// - An `F` value as the coefficient, where `F` is a finite field. /// /// The use of `HashMap` for `terms` provides: @@ -49,441 +52,446 @@ use std::ops::Sub; /// commonly encountered in cryptographic and algebraic applications. #[derive(PartialEq, Eq, Debug, Clone)] pub struct MultivariatePolynomial { - terms: HashMap, F>, + terms: HashMap, F>, } impl MultivariatePolynomial { - /// Constructs a new `MultivariatePolynomial` representing the zero polynomial. - /// - /// This function creates an empty `MultivariatePolynomial`, which is equivalent to the zero polynomial. - /// The zero polynomial has no terms, and evaluates to zero for all inputs. - pub fn new() -> Self { - Self { terms: HashMap::new() } + /// Constructs a new `MultivariatePolynomial` representing the zero polynomial. + /// + /// This function creates an empty `MultivariatePolynomial`, which is equivalent to the zero + /// polynomial. The zero polynomial has no terms, and evaluates to zero for all inputs. + pub fn new() -> Self { Self { terms: HashMap::new() } } + + /// Creates a new `MultivariatePolynomial` from a vector of `MultivariateTerm`s. + /// + /// This is the preferred way to create a multivariate polynomial, as it allows + /// for a more intuitive representation of the polynomial's terms. + /// + /// # Arguments + /// + /// * `terms` - A vector of `MultivariateTerm`s representing the polynomial. + /// + /// # Returns + /// + /// A new `MultivariatePolynomial` instance. + /// + /// # Example + /// + /// ``` + /// use your_crate::{ + /// MultivariatePolynomial, MultivariateTerm, MultivariateVariable, PlutoBaseField, + /// }; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// MultivariateTerm::new( + /// vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { + /// index: 1, + /// exponent: 1, + /// }], + /// PlutoBaseField::new(3), + /// ), + /// MultivariateTerm::new( + /// vec![MultivariateVariable { index: 0, exponent: 1 }], + /// PlutoBaseField::new(2), + /// ), + /// MultivariateTerm::new(vec![], PlutoBaseField::new(1)), + /// ]); + /// + /// // This creates the polynomial: 3x_0^2*x_1 + 2x_0 + 1 + /// ``` + pub fn from_terms(terms: Vec>) -> Self { + let mut poly = MultivariatePolynomial::new(); + for term in terms { + let mut btree_map = BTreeMap::new(); + for var in term.variables { + btree_map.insert(var.index, var.exponent); + } + poly.insert_term(btree_map, term.coefficient); } - - - /// Creates a new `MultivariatePolynomial` from a vector of `MultivariateTerm`s. - /// - /// This is the preferred way to create a multivariate polynomial, as it allows - /// for a more intuitive representation of the polynomial's terms. - /// - /// # Arguments - /// - /// * `terms` - A vector of `MultivariateTerm`s representing the polynomial. - /// - /// # Returns - /// - /// A new `MultivariatePolynomial` instance. - /// - /// # Example - /// - /// ``` - /// use your_crate::{MultivariatePolynomial, MultivariateTerm, MultivariateVariable, PlutoBaseField}; - /// - /// let poly = MultivariatePolynomial::::from_terms(vec![ - /// MultivariateTerm::new( - /// vec![ - /// MultivariateVariable { index: 0, exponent: 2 }, - /// MultivariateVariable { index: 1, exponent: 1 } - /// ], - /// PlutoBaseField::new(3) - /// ), - /// MultivariateTerm::new( - /// vec![MultivariateVariable { index: 0, exponent: 1 }], - /// PlutoBaseField::new(2) - /// ), - /// MultivariateTerm::new( - /// vec![], - /// PlutoBaseField::new(1) - /// ) - /// ]); - /// - /// // This creates the polynomial: 3x_0^2*x_1 + 2x_0 + 1 - /// ``` - pub fn from_terms(terms: Vec>) -> Self { - let mut poly = MultivariatePolynomial::new(); - for term in terms { - let mut btree_map = BTreeMap::new(); - for var in term.variables { - btree_map.insert(var.index, var.exponent); - } - poly.insert_term(btree_map, term.coefficient); - } - poly + poly + } + + fn insert_term(&mut self, exponents: BTreeMap, coefficient: F) { + if coefficient != F::ZERO { + let entry = self.terms.entry(exponents.clone()).or_insert(F::ZERO); + *entry += coefficient; + if *entry == F::ZERO { + self.terms.remove(&exponents); + } } - - fn insert_term(&mut self, exponents: BTreeMap, coefficient: F) { - if coefficient != F::ZERO { - let entry = self.terms.entry(exponents.clone()).or_insert(F::ZERO); - *entry += coefficient; - if *entry == F::ZERO { - self.terms.remove(&exponents); - } + } + + /// Returns the coefficient of the term with the given exponents. + /// + /// # Arguments + /// + /// * `exponents` - A `BTreeMap` where the keys are variable indices and the values are their + /// exponents. + /// + /// # Returns + /// + /// * `Some(&F)` if a term with the given exponents exists in the polynomial, where `F` is the + /// coefficient. + /// * `None` if no term with the given exponents exists in the polynomial. + /// + /// # Example + /// + /// ``` + /// use std::collections::BTreeMap; + /// use your_crate::{MultivariatePolynomial, PlutoBaseField}; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// // ... (terms as in the previous example) + /// ]); + /// + /// let mut exponents = BTreeMap::new(); + /// exponents.insert(0, 2); + /// exponents.insert(1, 1); + /// + /// assert_eq!(poly.coefficient(&exponents), Some(&PlutoBaseField::new(3))); + /// ``` + pub fn coefficient(&self, exponents: &BTreeMap) -> Option<&F> { + self.terms.get(exponents) + } + + /// Evaluates the multivariate polynomial at the given points. + /// + /// # Arguments + /// + /// * `points` - A slice of tuples where each tuple contains: + /// - The index of the variable (usize) + /// - The value to evaluate the variable at (F) + /// + /// # Returns + /// + /// * The result of evaluating the polynomial (F) + /// + /// # Example + /// + /// ``` + /// use your_crate::{MultivariatePolynomial, PlutoBaseField}; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// // x^2 + 2xy + 3z + /// // ... (terms definition) + /// ]); + /// + /// let points = vec![(0, PlutoBaseField::new(2)), (1, PlutoBaseField::new(3)), (2, PlutoBaseField::new(1))]; + /// let result = poly.evaluate(&points); + /// // result will be the evaluation of x^2 + 2xy + 3z at x=2, y=3, z=1 + /// ``` + pub fn evaluate(&self, points: &[(usize, F)]) -> F { + self + .terms + .iter() + .map(|(exponents, coeff)| { + let term_value = exponents + .iter() + .map(|(&var, &exp)| { + points + .iter() + .find(|&&(v, _)| v == var) + .map(|&(_, value)| value.pow(exp)) + .unwrap_or(F::ONE) + }) + .product::(); + *coeff * term_value + }) + .sum() + } + + /// Applies the given variable assignments to the polynomial, reducing its degree. + /// + /// This method substitutes the specified variables with their corresponding values, + /// effectively reducing the polynomial's degree for those variables. The resulting + /// polynomial will have fewer variables if any were fully substituted. + /// + /// # Arguments + /// + /// * `variables` - A slice of tuples, where each tuple contains: + /// - The index of the variable to substitute (usize) + /// - The value to substitute for that variable (F) + /// + /// # Returns + /// + /// A new `MultivariatePolynomial` with the specified variables substituted. + /// + /// # Example + /// + /// ``` + /// use your_crate::{MultivariatePolynomial, PlutoBaseField}; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// // x^2y + 3xyz + 2z + /// // ... (terms definition) + /// ]); + /// + /// let assignments = vec![(1, PlutoBaseField::new(2))]; // y = 2 + /// let reduced_poly = poly.apply_variables(&assignments); + /// // The resulting polynomial will be of the form: 2x^2 + 6xz + 2z + /// ``` + pub fn apply_variables(&self, variables: &[(usize, F)]) -> Self { + let mut result = MultivariatePolynomial::new(); + + for (exponents, coeff) in &self.terms { + let mut new_exponents = exponents.clone(); + let mut new_coeff = *coeff; + + for &(var, value) in variables { + if let Some(exp) = new_exponents.get(&var) { + new_coeff *= value.pow(*exp); + new_exponents.remove(&var); } - } - - /// Returns the coefficient of the term with the given exponents. - /// - /// # Arguments - /// - /// * `exponents` - A `BTreeMap` where the keys are variable indices and the values are their exponents. - /// - /// # Returns - /// - /// * `Some(&F)` if a term with the given exponents exists in the polynomial, where `F` is the coefficient. - /// * `None` if no term with the given exponents exists in the polynomial. - /// - /// # Example - /// - /// ``` - /// use std::collections::BTreeMap; - /// use your_crate::{MultivariatePolynomial, PlutoBaseField}; - /// - /// let poly = MultivariatePolynomial::::from_terms(vec![ - /// // ... (terms as in the previous example) - /// ]); - /// - /// let mut exponents = BTreeMap::new(); - /// exponents.insert(0, 2); - /// exponents.insert(1, 1); - /// - /// assert_eq!(poly.coefficient(&exponents), Some(&PlutoBaseField::new(3))); - /// ``` - pub fn coefficient(&self, exponents: &BTreeMap) -> Option<&F> { - self.terms.get(exponents) - } + } - /// Evaluates the multivariate polynomial at the given points. - /// - /// # Arguments - /// - /// * `points` - A slice of tuples where each tuple contains: - /// - The index of the variable (usize) - /// - The value to evaluate the variable at (F) - /// - /// # Returns - /// - /// * The result of evaluating the polynomial (F) - /// - /// # Example - /// - /// ``` - /// use your_crate::{MultivariatePolynomial, PlutoBaseField}; - /// - /// let poly = MultivariatePolynomial::::from_terms(vec![ - /// // x^2 + 2xy + 3z - /// // ... (terms definition) - /// ]); - /// - /// let points = vec![(0, PlutoBaseField::new(2)), (1, PlutoBaseField::new(3)), (2, PlutoBaseField::new(1))]; - /// let result = poly.evaluate(&points); - /// // result will be the evaluation of x^2 + 2xy + 3z at x=2, y=3, z=1 - /// ``` - pub fn evaluate(&self, points: &[(usize, F)]) -> F { - self.terms.iter().map(|(exponents, coeff)| { - let term_value = exponents.iter().map(|(&var, &exp)| { - points.iter() - .find(|&&(v, _)| v == var) - .map(|&(_, value)| value.pow(exp)) - .unwrap_or(F::ONE) - }).product::(); - *coeff * term_value - }).sum() - } - - /// Applies the given variable assignments to the polynomial, reducing its degree. - /// - /// This method substitutes the specified variables with their corresponding values, - /// effectively reducing the polynomial's degree for those variables. The resulting - /// polynomial will have fewer variables if any were fully substituted. - /// - /// # Arguments - /// - /// * `variables` - A slice of tuples, where each tuple contains: - /// - The index of the variable to substitute (usize) - /// - The value to substitute for that variable (F) - /// - /// # Returns - /// - /// A new `MultivariatePolynomial` with the specified variables substituted. - /// - /// # Example - /// - /// ``` - /// use your_crate::{MultivariatePolynomial, PlutoBaseField}; - /// - /// let poly = MultivariatePolynomial::::from_terms(vec![ - /// // x^2y + 3xyz + 2z - /// // ... (terms definition) - /// ]); - /// - /// let assignments = vec![(1, PlutoBaseField::new(2))]; // y = 2 - /// let reduced_poly = poly.apply_variables(&assignments); - /// // The resulting polynomial will be of the form: 2x^2 + 6xz + 2z - /// ``` - pub fn apply_variables(&self, variables: &[(usize, F)]) -> Self { - let mut result = MultivariatePolynomial::new(); - - for (exponents, coeff) in &self.terms { - let mut new_exponents = exponents.clone(); - let mut new_coeff = *coeff; - - for &(var, value) in variables { - if let Some(exp) = new_exponents.get(&var) { - new_coeff *= value.pow(*exp); - new_exponents.remove(&var); - } - } - - if !new_exponents.is_empty() { - result.insert_term(new_exponents, new_coeff); - } else { - result.insert_term(BTreeMap::new(), new_coeff); - } - } - - result - } - - - /// Calculates the total degree of the multivariate polynomial. - /// - /// The total degree of a multivariate polynomial is the maximum sum of exponents - /// across all terms in the polynomial. - /// - /// # Returns - /// - /// * `usize` - The total degree of the polynomial. - /// - /// # Example - /// - /// ``` - /// use your_crate::MultivariatePolynomial; - /// use your_crate::FiniteField; - /// - /// let poly = MultivariatePolynomial::::from_terms(vec![ - /// // x^2y + 3xyz^2 + 2z - /// // ... (terms definition) - /// ]); - /// - /// assert_eq!(poly.degree(), 4); // The term xyz^2 has the highest total degree of 4 - /// ``` - pub fn degree(&self) -> usize { - self.terms.keys() - .map(|exponents| exponents.values().sum::()) - .max() - .unwrap_or(0) + if !new_exponents.is_empty() { + result.insert_term(new_exponents, new_coeff); + } else { + result.insert_term(BTreeMap::new(), new_coeff); + } } - /// Returns a vector of all variables present in the polynomial. - /// - /// This method collects all unique variables (represented by their indices) - /// that appear in any term of the polynomial. - /// - /// # Returns - /// - /// * `Vec` - A vector containing the indices of all variables in the polynomial. - /// - /// # Example - /// - /// ``` - /// use your_crate::MultivariatePolynomial; - /// use your_crate::FiniteField; - /// - /// let poly = MultivariatePolynomial::::from_terms(vec![ - /// // x^2y + 3xyz^2 + 2z - /// // ... (terms definition) - /// ]); - /// - /// let vars = poly.variables(); - /// assert_eq!(vars, vec![0, 1, 2]); // Assuming x, y, z are represented by 0, 1, 2 respectively - /// ``` - pub fn variables(&self) -> Vec { - self.terms.keys() - .flat_map(|exponents| exponents.keys().cloned()) - .collect::>() - .into_iter() - .collect() - } + result + } + + /// Calculates the total degree of the multivariate polynomial. + /// + /// The total degree of a multivariate polynomial is the maximum sum of exponents + /// across all terms in the polynomial. + /// + /// # Returns + /// + /// * `usize` - The total degree of the polynomial. + /// + /// # Example + /// + /// ``` + /// use your_crate::MultivariatePolynomial; + /// use your_crate::FiniteField; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// // x^2y + 3xyz^2 + 2z + /// // ... (terms definition) + /// ]); + /// + /// assert_eq!(poly.degree(), 4); // The term xyz^2 has the highest total degree of 4 + /// ``` + pub fn degree(&self) -> usize { + self.terms.keys().map(|exponents| exponents.values().sum::()).max().unwrap_or(0) + } + + /// Returns a vector of all variables present in the polynomial. + /// + /// This method collects all unique variables (represented by their indices) + /// that appear in any term of the polynomial. + /// + /// # Returns + /// + /// * `Vec` - A vector containing the indices of all variables in the polynomial. + /// + /// # Example + /// + /// ``` + /// use your_crate::MultivariatePolynomial; + /// use your_crate::FiniteField; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// // x^2y + 3xyz^2 + 2z + /// // ... (terms definition) + /// ]); + /// + /// let vars = poly.variables(); + /// assert_eq!(vars, vec![0, 1, 2]); // Assuming x, y, z are represented by 0, 1, 2 respectively + /// ``` + pub fn variables(&self) -> Vec { + self + .terms + .keys() + .flat_map(|exponents| exponents.keys().cloned()) + .collect::>() + .into_iter() + .collect() + } } impl Add for MultivariatePolynomial { - type Output = Self; + type Output = Self; - fn add(mut self, rhs: Self) -> Self::Output { - for (exponents, coeff) in rhs.terms { - self.insert_term(exponents, coeff); - } - self + fn add(mut self, rhs: Self) -> Self::Output { + for (exponents, coeff) in rhs.terms { + self.insert_term(exponents, coeff); } + self + } } impl Sub for MultivariatePolynomial { - type Output = Self; + type Output = Self; - fn sub(mut self, rhs: Self) -> Self::Output { - for (exponents, coeff) in rhs.terms { - // Negate the coefficient and insert - self.insert_term(exponents, -coeff); - } - self + fn sub(mut self, rhs: Self) -> Self::Output { + for (exponents, coeff) in rhs.terms { + // Negate the coefficient and insert + self.insert_term(exponents, -coeff); } + self + } } impl Mul for MultivariatePolynomial { - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - let mut result = MultivariatePolynomial::new(); - for (exp1, coeff1) in &self.terms { - for (exp2, coeff2) in &rhs.terms { - let mut new_exp = exp1.clone(); - for (&var, &exp) in exp2 { - *new_exp.entry(var).or_insert(0) += exp; - } - result.insert_term(new_exp, *coeff1 * *coeff2); - } + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + let mut result = MultivariatePolynomial::new(); + for (exp1, coeff1) in &self.terms { + for (exp2, coeff2) in &rhs.terms { + let mut new_exp = exp1.clone(); + for (&var, &exp) in exp2 { + *new_exp.entry(var).or_insert(0) += exp; } - result + result.insert_term(new_exp, *coeff1 * *coeff2); + } } + result + } } impl Display for MultivariatePolynomial { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let mut first = true; - for (exponents, coeff) in self.terms.iter().sorted_by(|(a_exp, _), (b_exp, _)| { - a_exp.iter() - .zip(b_exp.iter()) - .find(|((a_var, a_pow), (b_var, b_pow))| { - a_var.cmp(b_var).then_with(|| b_pow.cmp(a_pow)).is_ne() - }) - .map_or(std::cmp::Ordering::Equal, |(_, _)| std::cmp::Ordering::Less) - }) { - if !first { - write!(f, " + ")?; - } - first = false; - - if *coeff != F::ONE || exponents.is_empty() { - write!(f, "{}", coeff)?; - } - - let mut first_var = true; - for (&var, &exp) in exponents { - if exp > 0 { - if !first_var || *coeff != F::ONE { - write!(f, "*")?; - } - write!(f, "x_{}", var)?; - if exp > 1 { - write!(f, "^{}", exp)?; - } - first_var = false; - } - } - } - - if first { - write!(f, "0")?; + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut first = true; + for (exponents, coeff) in self.terms.iter().sorted_by(|(a_exp, _), (b_exp, _)| { + a_exp + .iter() + .zip(b_exp.iter()) + .find(|((a_var, a_pow), (b_var, b_pow))| { + a_var.cmp(b_var).then_with(|| b_pow.cmp(a_pow)).is_ne() + }) + .map_or(std::cmp::Ordering::Equal, |(..)| std::cmp::Ordering::Less) + }) { + if !first { + write!(f, " + ")?; + } + first = false; + + if *coeff != F::ONE || exponents.is_empty() { + write!(f, "{}", coeff)?; + } + + let mut first_var = true; + for (&var, &exp) in exponents { + if exp > 0 { + if !first_var || *coeff != F::ONE { + write!(f, "*")?; + } + write!(f, "x_{}", var)?; + if exp > 1 { + write!(f, "^{}", exp)?; + } + first_var = false; } + } + } - Ok(()) + if first { + write!(f, "0")?; } + + Ok(()) + } } // Implement From for univariate polynomials -impl From> for MultivariatePolynomial { - fn from(poly: Polynomial) -> Self { - let mut result = MultivariatePolynomial::new(); - for (i, &coeff) in poly.coefficients.iter().enumerate() { - if coeff != F::ZERO { - let mut exponents = BTreeMap::new(); - exponents.insert(0, i); - result.insert_term(exponents, coeff); - } - } - result +impl From> + for MultivariatePolynomial +{ + fn from(poly: Polynomial) -> Self { + let mut result = MultivariatePolynomial::new(); + for (i, &coeff) in poly.coefficients.iter().enumerate() { + if coeff != F::ZERO { + let mut exponents = BTreeMap::new(); + exponents.insert(0, i); + result.insert_term(exponents, coeff); + } } + result + } } // Extend Polynomial to support conversion to multivariate impl Polynomial { - /// Converts a univariate polynomial to a multivariate polynomial. - /// - /// This method transforms the current univariate polynomial into a multivariate polynomial - /// where all terms use the same variable, specified by `variable_index`. - /// - /// # Arguments - /// - /// * `variable_index` - The index of the variable to use in the resulting multivariate polynomial. - /// - /// # Returns - /// - /// A `MultivariatePolynomial` equivalent to the original univariate polynomial. - /// - /// # Example - /// - /// ``` - /// let univariate = Polynomial::new([F::ONE, F::TWO, F::THREE]); // x^2 + 2x + 1 - /// let multivariate = univariate.to_multivariate(0); - /// // Result: x_0^2 + 2*x_0 + 1 - /// ``` - pub fn to_multivariate(self, variable_index: usize) -> MultivariatePolynomial { - let mut result = MultivariatePolynomial::new(); - for (i, &coeff) in self.coefficients.iter().enumerate() { - if coeff != F::ZERO { - let mut exponents = BTreeMap::new(); - exponents.insert(variable_index, i); - result.insert_term(exponents, coeff); - } - } - result + /// Converts a univariate polynomial to a multivariate polynomial. + /// + /// This method transforms the current univariate polynomial into a multivariate polynomial + /// where all terms use the same variable, specified by `variable_index`. + /// + /// # Arguments + /// + /// * `variable_index` - The index of the variable to use in the resulting multivariate + /// polynomial. + /// + /// # Returns + /// + /// A `MultivariatePolynomial` equivalent to the original univariate polynomial. + /// + /// # Example + /// + /// ``` + /// let univariate = Polynomial::new([F::ONE, F::TWO, F::THREE]); // x^2 + 2x + 1 + /// let multivariate = univariate.to_multivariate(0); + /// // Result: x_0^2 + 2*x_0 + 1 + /// ``` + pub fn to_multivariate(self, variable_index: usize) -> MultivariatePolynomial { + let mut result = MultivariatePolynomial::new(); + for (i, &coeff) in self.coefficients.iter().enumerate() { + if coeff != F::ZERO { + let mut exponents = BTreeMap::new(); + exponents.insert(variable_index, i); + result.insert_term(exponents, coeff); + } } + result + } } - /// Represents a variable with an exponent in a multivariate polynomial. /// Each variable is uniquely identified by its index and has an associated exponent. #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct MultivariateVariable { - /// The unique identifier for the variable. - /// This index distinguishes one variable from another in the polynomial. - pub index: usize, + /// The unique identifier for the variable. + /// This index distinguishes one variable from another in the polynomial. + pub index: usize, - /// The power to which the variable is raised. - /// For example, if exponent is 2, it represents x^2 for the variable x. - pub exponent: usize, + /// The power to which the variable is raised. + /// For example, if exponent is 2, it represents x^2 for the variable x. + pub exponent: usize, } impl MultivariateVariable { - /// Creates a new multivariate variable with the given index and exponent. - pub fn new(index: usize, exponent: usize) -> Self { - MultivariateVariable { index, exponent } - } + /// Creates a new multivariate variable with the given index and exponent. + pub fn new(index: usize, exponent: usize) -> Self { MultivariateVariable { index, exponent } } } /// Represents a term in a multivariate polynomial. -/// +/// /// # Fields -/// +/// /// * `variables` - A vector of `MultivariateVariable`s representing the variables in this term. /// * `coefficient` - The coefficient of this term, represented as a finite field element. #[derive(PartialEq, Eq)] pub struct MultivariateTerm { - /// A vector of `MultivariateVariable`s representing the variables in this term. - /// Each `MultivariateVariable` contains an index and an exponent. - pub variables: Vec, + /// A vector of `MultivariateVariable`s representing the variables in this term. + /// Each `MultivariateVariable` contains an index and an exponent. + pub variables: Vec, - /// The coefficient of this term, represented as a finite field element. - /// This value multiplies the product of the variables in the term. - pub coefficient: F, + /// The coefficient of this term, represented as a finite field element. + /// This value multiplies the product of the variables in the term. + pub coefficient: F, } /// Represents a term in a multivariate polynomial. /// A term consists of a coefficient and a collection of variables with their exponents. impl MultivariateTerm { - /// Creates a new multivariate term with the given variables and coefficient. - pub fn new(variables: Vec, coefficient: F) -> Self { - MultivariateTerm { variables, coefficient } - } + /// Creates a new multivariate term with the given variables and coefficient. + pub fn new(variables: Vec, coefficient: F) -> Self { + MultivariateTerm { variables, coefficient } + } } diff --git a/src/polynomial/tests.rs b/src/polynomial/tests.rs index 78b000a4..a35b12c4 100644 --- a/src/polynomial/tests.rs +++ b/src/polynomial/tests.rs @@ -1,5 +1,7 @@ use super::*; -use crate::polynomial::multivariate_polynomial::{MultivariatePolynomial, MultivariateTerm, MultivariateVariable}; +use crate::polynomial::multivariate_polynomial::{ + MultivariatePolynomial, MultivariateTerm, MultivariateVariable, +}; #[fixture] fn poly() -> Polynomial { @@ -132,17 +134,17 @@ fn dft(poly: Polynomial) { fn test_multivariate_polynomial_creation() { let poly = MultivariatePolynomial::::from_terms(vec![ MultivariateTerm::new( - vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { index: 1, exponent: 1 }], - PlutoBaseField::new(3) + vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { + index: 1, + exponent: 1, + }], + PlutoBaseField::new(3), ), MultivariateTerm::new( vec![MultivariateVariable { index: 0, exponent: 1 }], - PlutoBaseField::new(2) + PlutoBaseField::new(2), ), - MultivariateTerm::new( - vec![], - PlutoBaseField::new(1) - ) + MultivariateTerm::new(vec![], PlutoBaseField::new(1)), ]); assert_eq!(poly.degree(), 3); @@ -158,23 +160,23 @@ fn test_multivariate_polynomial_addition() { let poly1 = MultivariatePolynomial::::from_terms(vec![ MultivariateTerm::new( vec![MultivariateVariable { index: 0, exponent: 2 }], - PlutoBaseField::new(1) + PlutoBaseField::new(1), ), MultivariateTerm::new( vec![MultivariateVariable { index: 1, exponent: 1 }], - PlutoBaseField::new(2) - ) + PlutoBaseField::new(2), + ), ]); let poly2 = MultivariatePolynomial::::from_terms(vec![ MultivariateTerm::new( vec![MultivariateVariable { index: 0, exponent: 2 }], - PlutoBaseField::new(3) + PlutoBaseField::new(3), ), MultivariateTerm::new( vec![MultivariateVariable { index: 2, exponent: 1 }], - PlutoBaseField::new(4) - ) + PlutoBaseField::new(4), + ), ]); let result = poly1 + poly2; @@ -184,16 +186,16 @@ fn test_multivariate_polynomial_addition() { let expected_result = MultivariatePolynomial::::from_terms(vec![ MultivariateTerm::new( vec![MultivariateVariable { index: 0, exponent: 2 }], - PlutoBaseField::new(4) + PlutoBaseField::new(4), ), MultivariateTerm::new( vec![MultivariateVariable { index: 1, exponent: 1 }], - PlutoBaseField::new(2) + PlutoBaseField::new(2), ), MultivariateTerm::new( vec![MultivariateVariable { index: 2, exponent: 1 }], - PlutoBaseField::new(4) - ) + PlutoBaseField::new(4), + ), ]); assert_eq!(result, expected_result); @@ -204,23 +206,17 @@ fn test_multivariate_polynomial_multiplication() { let poly1 = MultivariatePolynomial::::from_terms(vec![ MultivariateTerm::new( vec![MultivariateVariable { index: 0, exponent: 1 }], - PlutoBaseField::new(2) + PlutoBaseField::new(2), ), - MultivariateTerm::new( - vec![], - PlutoBaseField::new(1) - ) + MultivariateTerm::new(vec![], PlutoBaseField::new(1)), ]); let poly2 = MultivariatePolynomial::::from_terms(vec![ MultivariateTerm::new( vec![MultivariateVariable { index: 1, exponent: 1 }], - PlutoBaseField::new(3) + PlutoBaseField::new(3), ), - MultivariateTerm::new( - vec![], - PlutoBaseField::new(1) - ) + MultivariateTerm::new(vec![], PlutoBaseField::new(1)), ]); let result = poly1 * poly2; @@ -229,21 +225,21 @@ fn test_multivariate_polynomial_multiplication() { let expected_result = MultivariatePolynomial::::from_terms(vec![ MultivariateTerm::new( - vec![MultivariateVariable { index: 0, exponent: 1 }, MultivariateVariable { index: 1, exponent: 1 }], - PlutoBaseField::new(6) + vec![MultivariateVariable { index: 0, exponent: 1 }, MultivariateVariable { + index: 1, + exponent: 1, + }], + PlutoBaseField::new(6), ), MultivariateTerm::new( vec![MultivariateVariable { index: 0, exponent: 1 }], - PlutoBaseField::new(2) + PlutoBaseField::new(2), ), MultivariateTerm::new( vec![MultivariateVariable { index: 1, exponent: 1 }], - PlutoBaseField::new(3) + PlutoBaseField::new(3), ), - MultivariateTerm::new( - vec![], - PlutoBaseField::new(1) - ) + MultivariateTerm::new(vec![], PlutoBaseField::new(1)), ]); assert_eq!(result, expected_result); @@ -253,17 +249,17 @@ fn test_multivariate_polynomial_multiplication() { fn test_multivariate_polynomial_evaluation() { let poly = MultivariatePolynomial::::from_terms(vec![ MultivariateTerm::new( - vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { index: 1, exponent: 1 }], - PlutoBaseField::new(3) + vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { + index: 1, + exponent: 1, + }], + PlutoBaseField::new(3), ), MultivariateTerm::new( vec![MultivariateVariable { index: 0, exponent: 1 }], - PlutoBaseField::new(2) + PlutoBaseField::new(2), ), - MultivariateTerm::new( - vec![], - PlutoBaseField::new(1) - ) + MultivariateTerm::new(vec![], PlutoBaseField::new(1)), ]); let points = vec![(0, PlutoBaseField::new(2)), (1, PlutoBaseField::new(3))]; @@ -280,27 +276,23 @@ fn test_apply_variables_single_variable() { let poly = MultivariatePolynomial::::from_terms(vec![ MultivariateTerm::new( vec![MultivariateVariable { index: 0, exponent: 2 }], - PlutoBaseField::new(3) + PlutoBaseField::new(3), ), MultivariateTerm::new( vec![MultivariateVariable { index: 0, exponent: 1 }], - PlutoBaseField::new(2) + PlutoBaseField::new(2), ), - MultivariateTerm::new( - vec![], - PlutoBaseField::new(1) - ) + MultivariateTerm::new(vec![], PlutoBaseField::new(1)), ]); let variables = vec![(0, PlutoBaseField::new(2))]; let result = poly.apply_variables(&variables); - let expected_result = MultivariatePolynomial::::from_terms(vec![ - MultivariateTerm::new( + let expected_result = + MultivariatePolynomial::::from_terms(vec![MultivariateTerm::new( vec![], - PlutoBaseField::new(17) - ) - ]); + PlutoBaseField::new(17), + )]); assert_eq!(result, expected_result); } @@ -309,21 +301,21 @@ fn test_apply_variables_single_variable() { fn test_apply_variables_multiple_variables() { let poly = MultivariatePolynomial::::from_terms(vec![ MultivariateTerm::new( - vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { index: 1, exponent: 1 }], - PlutoBaseField::new(3) + vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { + index: 1, + exponent: 1, + }], + PlutoBaseField::new(3), ), MultivariateTerm::new( vec![MultivariateVariable { index: 0, exponent: 1 }], - PlutoBaseField::new(2) + PlutoBaseField::new(2), ), MultivariateTerm::new( vec![MultivariateVariable { index: 1, exponent: 1 }], - PlutoBaseField::new(4) + PlutoBaseField::new(4), ), - MultivariateTerm::new( - vec![], - PlutoBaseField::new(1) - ) + MultivariateTerm::new(vec![], PlutoBaseField::new(1)), ]); println!("Apply Multiple Variables Polynomial: {}", poly); @@ -333,12 +325,11 @@ fn test_apply_variables_multiple_variables() { println!("Reduced Multiple Variable Polynomial: {}", result); - let expected_result = MultivariatePolynomial::::from_terms(vec![ - MultivariateTerm::new( + let expected_result = + MultivariatePolynomial::::from_terms(vec![MultivariateTerm::new( vec![], - PlutoBaseField::new(53) - ) - ]); + PlutoBaseField::new(53), + )]); assert_eq!(result, expected_result); } @@ -348,20 +339,23 @@ fn test_apply_variables_partial_application() { let poly = MultivariatePolynomial::::from_terms(vec![ MultivariateTerm::new( vec![ - MultivariateVariable { index: 0, exponent: 2 }, - MultivariateVariable { index: 1, exponent: 1 }, - MultivariateVariable { index: 2, exponent: 1 } + MultivariateVariable { index: 0, exponent: 2 }, + MultivariateVariable { index: 1, exponent: 1 }, + MultivariateVariable { index: 2, exponent: 1 }, ], - PlutoBaseField::new(3) + PlutoBaseField::new(3), ), MultivariateTerm::new( - vec![MultivariateVariable { index: 0, exponent: 1 }, MultivariateVariable { index: 2, exponent: 1 }], - PlutoBaseField::new(2) + vec![MultivariateVariable { index: 0, exponent: 1 }, MultivariateVariable { + index: 2, + exponent: 1, + }], + PlutoBaseField::new(2), ), MultivariateTerm::new( vec![MultivariateVariable { index: 1, exponent: 1 }], - PlutoBaseField::new(4) - ) + PlutoBaseField::new(4), + ), ]); let variables = vec![(0, PlutoBaseField::new(2)), (1, PlutoBaseField::new(3))]; @@ -370,12 +364,9 @@ fn test_apply_variables_partial_application() { let expected_result = MultivariatePolynomial::::from_terms(vec![ MultivariateTerm::new( vec![MultivariateVariable { index: 2, exponent: 1 }], - PlutoBaseField::new(40) + PlutoBaseField::new(40), ), - MultivariateTerm::new( - vec![], - PlutoBaseField::new(12) - ) + MultivariateTerm::new(vec![], PlutoBaseField::new(12)), ]); assert_eq!(applied_poly, expected_result); @@ -385,13 +376,16 @@ fn test_apply_variables_partial_application() { fn test_apply_variables_no_effect() { let poly = MultivariatePolynomial::::from_terms(vec![ MultivariateTerm::new( - vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { index: 1, exponent: 1 }], - PlutoBaseField::new(3) + vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { + index: 1, + exponent: 1, + }], + PlutoBaseField::new(3), ), MultivariateTerm::new( vec![MultivariateVariable { index: 2, exponent: 1 }], - PlutoBaseField::new(2) - ) + PlutoBaseField::new(2), + ), ]); let variables = vec![(3, PlutoBaseField::new(5))]; @@ -402,12 +396,13 @@ fn test_apply_variables_no_effect() { #[test] fn test_apply_variables_empty() { - let poly = MultivariatePolynomial::::from_terms(vec![ - MultivariateTerm::new( - vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { index: 1, exponent: 1 }], - PlutoBaseField::new(3) - ) - ]); // 3x_0^2*x_1 + let poly = MultivariatePolynomial::::from_terms(vec![MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { + index: 1, + exponent: 1, + }], + PlutoBaseField::new(3), + )]); // 3x_0^2*x_1 let variables = vec![]; let result = poly.apply_variables(&variables); diff --git a/src/random/mod.rs b/src/random/mod.rs index 5f6af5d3..76db00e3 100644 --- a/src/random/mod.rs +++ b/src/random/mod.rs @@ -16,7 +16,6 @@ //! The module is designed to work seamlessly with the `rand` crate's `Rng` trait, //! providing a consistent interface for random number generation across the library. - use rand::Rng; /// A trait for types that can be randomly generated. @@ -24,16 +23,16 @@ use rand::Rng; /// Types implementing this trait can create random instances of themselves /// using a provided random number generator. pub trait Random { - /// Generates a random instance of the implementing type. - /// - /// # Arguments - /// - /// * `rng` - A mutable reference to a random number generator. - /// - /// # Returns - /// - /// A randomly generated instance of the implementing type. - fn random(rng: &mut R) -> Self; + /// Generates a random instance of the implementing type. + /// + /// # Arguments + /// + /// * `rng` - A mutable reference to a random number generator. + /// + /// # Returns + /// + /// A randomly generated instance of the implementing type. + fn random(rng: &mut R) -> Self; } /// A trait for types that can be generated using a random oracle. @@ -42,18 +41,18 @@ pub trait Random { /// using a provided random number generator and an input byte slice, /// simulating a random oracle functionality. pub trait RandomOracle: Random { - /// Generates an instance of the implementing type using a random oracle approach. - /// - /// This method takes both a random number generator and an input byte slice, - /// allowing for deterministic yet unpredictable output based on the input. - /// - /// # Arguments - /// - /// * `rng` - A mutable reference to a random number generator. - /// * `input` - A byte slice used as input to the random oracle. - /// - /// # Returns - /// - /// An instance of the implementing type, generated using the random oracle approach. - fn random_oracle(rng: &mut R, input: &[u8]) -> Self; + /// Generates an instance of the implementing type using a random oracle approach. + /// + /// This method takes both a random number generator and an input byte slice, + /// allowing for deterministic yet unpredictable output based on the input. + /// + /// # Arguments + /// + /// * `rng` - A mutable reference to a random number generator. + /// * `input` - A byte slice used as input to the random oracle. + /// + /// # Returns + /// + /// An instance of the implementing type, generated using the random oracle approach. + fn random_oracle(rng: &mut R, input: &[u8]) -> Self; } diff --git a/src/sumcheck/boolean_array.rs b/src/sumcheck/boolean_array.rs index 5b1e6c3b..d50fa9ce 100644 --- a/src/sumcheck/boolean_array.rs +++ b/src/sumcheck/boolean_array.rs @@ -1,37 +1,34 @@ struct BooleanArrayIter { - current: Vec, - done: bool, + current: Vec, + done: bool, } impl Iterator for BooleanArrayIter { - type Item = Vec; + type Item = Vec; - fn next(&mut self) -> Option { - if self.done { - return None; - } - - let result = self.current.clone(); + fn next(&mut self) -> Option { + if self.done { + return None; + } - // Generate next array - for i in 0..self.current.len() { - if self.current[i] { - self.current[i] = false; - } else { - self.current[i] = true; - return Some(result); - } - } + let result = self.current.clone(); - // If we've reached here, we've generated all arrays - self.done = true; - Some(result) + // Generate next array + for i in 0..self.current.len() { + if self.current[i] { + self.current[i] = false; + } else { + self.current[i] = true; + return Some(result); + } } + + // If we've reached here, we've generated all arrays + self.done = true; + Some(result) + } } pub fn get_all_possible_boolean_values(length: usize) -> impl Iterator> { - BooleanArrayIter { - current: vec![false; length], - done: false, - } -} \ No newline at end of file + BooleanArrayIter { current: vec![false; length], done: false } +} diff --git a/src/sumcheck/mod.rs b/src/sumcheck/mod.rs index 9f46ca8b..1cd0336f 100644 --- a/src/sumcheck/mod.rs +++ b/src/sumcheck/mod.rs @@ -16,7 +16,8 @@ //! //! ## Implementation Details //! -//! This implementation provides both interactive and non-interactive versions of the sumcheck protocol. +//! This implementation provides both interactive and non-interactive versions of the sumcheck +//! protocol. //! //! ### Prover Implementation //! @@ -25,7 +26,8 @@ //! - `prove_first_sumcheck_round`: Computes the claimed sum and the first univariate polynomial. //! - `prove_sumcheck_round_i`: Generates the univariate polynomial for intermediate rounds. //! - `prove_sumcheck_last_round`: Handles the final round of the protocol. -//! - `compute_univariate_polynomial`: A helper method to compute the univariate polynomial for each round. +//! - `compute_univariate_polynomial`: A helper method to compute the univariate polynomial for each +//! round. //! //! This structure allows for a clear separation of concerns and follows the round-based //! nature of the sumcheck protocol. @@ -64,49 +66,56 @@ //! //! ## Usage //! -//! To use this implementation, create a `MultivariatePolynomial`, then use the `non_interactive_sumcheck_prove` -//! function to generate a proof, and `non_interactive_sumcheck_verify` to verify it. +//! To use this implementation, create a `MultivariatePolynomial`, then use the +//! `non_interactive_sumcheck_prove` function to generate a proof, and +//! `non_interactive_sumcheck_verify` to verify it. //! //! For more fine-grained control, you can use the individual prover and verifier functions //! to implement an interactive version of the protocol. -use std::{fmt::Display, hash::{Hash, Hasher}}; +use std::{ + fmt::Display, + hash::{Hash, Hasher}, +}; use rand::{Rng, SeedableRng}; -use crate::{algebra::field::FiniteField, polynomial::multivariate_polynomial::MultivariatePolynomial, random::{Random, RandomOracle}}; + +use crate::{ + algebra::field::FiniteField, + polynomial::multivariate_polynomial::MultivariatePolynomial, + random::{Random, RandomOracle}, +}; mod boolean_array; -mod to_bytes; #[cfg(test)] mod tests; +mod to_bytes; use self::{boolean_array::get_all_possible_boolean_values, to_bytes::ToBytes}; impl RandomOracle for F { - fn random_oracle(_rng: &mut R, input: &[u8]) -> Self { - // This is a simplified example. In a real implementation, - // you'd want to use a cryptographic hash function here. - let mut hasher = std::collections::hash_map::DefaultHasher::new(); - input.hash(&mut hasher); - let hash = hasher.finish(); - - // Use the hash to seed a new RNG - let mut seeded_rng = rand::rngs::StdRng::seed_from_u64(hash); - - // Generate a random field element using the seeded RNG - Self::random(&mut seeded_rng) - } + fn random_oracle(_rng: &mut R, input: &[u8]) -> Self { + // This is a simplified example. In a real implementation, + // you'd want to use a cryptographic hash function here. + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + input.hash(&mut hasher); + let hash = hasher.finish(); + + // Use the hash to seed a new RNG + let mut seeded_rng = rand::rngs::StdRng::seed_from_u64(hash); + + // Generate a random field element using the seeded RNG + Self::random(&mut seeded_rng) + } } impl MultivariatePolynomial { - - /// Proves the first round of the sumcheck protocol for this multivariate polynomial. /// /// This function is crucial for initiating the sumcheck protocol because: - /// 1. It computes the total sum of the polynomial over all boolean inputs, which is the - /// claimed sum that the prover wants to prove. - /// 2. It generates the first univariate polynomial g_1(X1), which is the first step in - /// reducing the multivariate sumcheck to a series of univariate sumchecks. + /// 1. It computes the total sum of the polynomial over all boolean inputs, which is the claimed + /// sum that the prover wants to prove. + /// 2. It generates the first univariate polynomial g_1(X1), which is the first step in reducing + /// the multivariate sumcheck to a series of univariate sumchecks. /// /// The sumcheck protocol is essential for efficiently verifying the sum of a multivariate /// polynomial over a boolean hypercube without evaluating every point, which would be @@ -116,15 +125,17 @@ impl MultivariatePolynomial { /// # Returns /// - A tuple containing: /// 1. The claimed sum (F): The total sum of the polynomial over all boolean inputs. - /// 2. The first univariate polynomial (MultivariatePolynomial): g_1(X1), which is - /// actually univariate despite the type name. + /// 2. The first univariate polynomial (MultivariatePolynomial): g_1(X1), which is actually + /// univariate despite the type name. pub fn prove_first_sumcheck_round(&self) -> (F, MultivariatePolynomial) { let variables = self.variables(); let num_variables = variables.len(); let sum = get_all_possible_boolean_values(num_variables) .map(|bool_values| { - let assignment: Vec<(usize, F)> = variables.iter().enumerate() + let assignment: Vec<(usize, F)> = variables + .iter() + .enumerate() .map(|(i, &var)| (var, if bool_values[i] { F::ONE } else { F::ZERO })) .collect(); self.evaluate(&assignment) @@ -150,19 +161,19 @@ impl MultivariatePolynomial { /// /// # Returns /// - /// * `MultivariatePolynomial` - The univariate polynomial g_i(X_i) for the i-th round. - /// Despite the type name, this polynomial is univariate in X_i. + /// * `MultivariatePolynomial` - The univariate polynomial g_i(X_i) for the i-th round. Despite + /// the type name, this polynomial is univariate in X_i. /// /// # Properties and Equalities /// /// 1. Degree Preservation: The degree of g_i(X_i) in X_i is at most the degree of the original /// polynomial in X_i. /// - /// 2. Sum Consistency: The sum of g_i(X_i) over {0,1} equals g_{i-1}(r_{i-1}), where r_{i-1} - /// is the random challenge from the previous round. + /// 2. Sum Consistency: The sum of g_i(X_i) over {0,1} equals g_{i-1}(r_{i-1}), where r_{i-1} is + /// the random challenge from the previous round. /// - /// 3. Randomized Reduction: g_i(X_i) reduces the sum check for i variables to a sum check - /// for i-1 variables when a random point is chosen. + /// 3. Randomized Reduction: g_i(X_i) reduces the sum check for i variables to a sum check for i-1 + /// variables when a random point is chosen. /// /// 4. Partial Evaluation: g_i(X_i) can be seen as a partial evaluation of the original /// polynomial, with the first i-1 variables fixed to the values in partial_assignment. @@ -177,7 +188,7 @@ impl MultivariatePolynomial { pub fn prove_sumcheck_round_i( &self, i: usize, - partial_assignment: Vec, + partial_assignment: Vec, ) -> MultivariatePolynomial { return self.compute_univariate_polynomial(i, partial_assignment); } @@ -196,8 +207,8 @@ impl MultivariatePolynomial { /// /// # Returns /// - /// * `MultivariatePolynomial` - The final univariate polynomial for the last round. - /// This polynomial is univariate in the last remaining variable. + /// * `MultivariatePolynomial` - The final univariate polynomial for the last round. This + /// polynomial is univariate in the last remaining variable. /// /// # Note /// @@ -230,10 +241,9 @@ impl MultivariatePolynomial { .map(|bool_values| { let further_assignments: Vec = bool_values.iter().map(|&b| if b { F::ONE } else { F::ZERO }).collect(); - let further_variables = ((round + 1)..num_variables).zip(further_assignments).collect::>(); - let poly = partial_poly.clone().apply_variables( - &further_variables, - ); + let further_variables = + ((round + 1)..num_variables).zip(further_assignments).collect::>(); + let poly = partial_poly.clone().apply_variables(&further_variables); poly }) .fold(MultivariatePolynomial::new(), |acc, poly| acc + poly); @@ -246,7 +256,6 @@ impl MultivariatePolynomial { result_polynomial } - } /// Verifies the first round of the sumcheck protocol. @@ -254,25 +263,25 @@ impl MultivariatePolynomial { /// This function is crucial for initiating the verification process in the sumcheck protocol. /// The verifier needs these components to ensure the correctness of the prover's claim: /// -/// 1. `claimed_sum`: The total sum claimed by the prover. This is the value that the -/// verifier wants to check without computing the entire sum themselves. +/// 1. `claimed_sum`: The total sum claimed by the prover. This is the value that the verifier wants +/// to check without computing the entire sum themselves. /// -/// 2. `univariate_poly`: The first univariate polynomial g_1(X_1) provided by the prover. -/// This polynomial is supposed to represent the sum over all but the first variable. +/// 2. `univariate_poly`: The first univariate polynomial g_1(X_1) provided by the prover. This +/// polynomial is supposed to represent the sum over all but the first variable. /// /// The verification process involves: /// -/// 1. Checking that the provided polynomial is indeed univariate. This ensures that -/// the prover is following the protocol correctly by reducing one variable at a time. +/// 1. Checking that the provided polynomial is indeed univariate. This ensures that the prover is +/// following the protocol correctly by reducing one variable at a time. /// -/// 2. Verifying that g_1(0) + g_1(1) equals the claimed sum. This check is fundamental -/// because it connects the univariate polynomial to the original multivariate sum. -/// If this equality holds, it suggests that the prover has correctly computed the -/// univariate polynomial for the first round. +/// 2. Verifying that g_1(0) + g_1(1) equals the claimed sum. This check is fundamental because it +/// connects the univariate polynomial to the original multivariate sum. If this equality holds, +/// it suggests that the prover has correctly computed the univariate polynomial for the first +/// round. /// -/// 3. Generating a random challenge. This challenge will be used in subsequent rounds -/// and is crucial for the security of the protocol. It ensures that the prover -/// cannot predict or manipulate future rounds. +/// 3. Generating a random challenge. This challenge will be used in subsequent rounds and is +/// crucial for the security of the protocol. It ensures that the prover cannot predict or +/// manipulate future rounds. /// /// # Arguments /// @@ -289,45 +298,47 @@ impl MultivariatePolynomial { /// /// * `F`: A type that implements both `FiniteField` and `Random` traits. pub fn verify_sumcheck_first_round( - claimed_sum: F, - univariate_poly: &MultivariatePolynomial + claimed_sum: F, + univariate_poly: &MultivariatePolynomial, ) -> (bool, F) { - // Step 1: Verify that the polynomial is univariate (has only one variable) - if univariate_poly.variables().len() != 1 { - return (false, F::ZERO); - } + // Step 1: Verify that the polynomial is univariate (has only one variable) + if univariate_poly.variables().len() != 1 { + return (false, F::ZERO); + } - // Step 2: Verify that g(0) + g(1) = claimed_sum - let var = 0; - let sum_at_endpoints = univariate_poly.evaluate(&[(var, F::ZERO)]) + univariate_poly.evaluate(&[(var, F::ONE)]); + // Step 2: Verify that g(0) + g(1) = claimed_sum + let var = 0; + let sum_at_endpoints = + univariate_poly.evaluate(&[(var, F::ZERO)]) + univariate_poly.evaluate(&[(var, F::ONE)]); - if sum_at_endpoints != claimed_sum { - return (false, F::ZERO); - } + if sum_at_endpoints != claimed_sum { + return (false, F::ZERO); + } - // Step 3: Generate a random challenge - let mut rng = rand::thread_rng(); - let challenge: F = F::random(&mut rng); + // Step 3: Generate a random challenge + let mut rng = rand::thread_rng(); + let challenge: F = F::random(&mut rng); - // Return true (verification passed) and the evaluation at the challenge point - (true, challenge) + // Return true (verification passed) and the evaluation at the challenge point + (true, challenge) } /// Verify the i-th round of the sumcheck protocol /// -/// This function is crucial for verifying the correctness of each intermediate step in the sumcheck protocol. -/// It ensures that the prover is following the protocol correctly and not deviating from the expected behavior. +/// This function is crucial for verifying the correctness of each intermediate step in the sumcheck +/// protocol. It ensures that the prover is following the protocol correctly and not deviating from +/// the expected behavior. /// /// # Arguments /// -/// * `round`: The current round number of the sumcheck protocol. This is needed to keep track of which variable -/// is being eliminated in the current round. -/// * `challenge`: The random challenge from the previous round. This is used to evaluate the previous round's -/// polynomial and connect it to the current round. -/// * `previous_univariate_poly`: The univariate polynomial from the previous round. This is needed to verify -/// the consistency between rounds. -/// * `current_univariate_poly`: The univariate polynomial for the current round. This is the polynomial that -/// the prover claims represents the sum over the current variable. +/// * `round`: The current round number of the sumcheck protocol. This is needed to keep track of +/// which variable is being eliminated in the current round. +/// * `challenge`: The random challenge from the previous round. This is used to evaluate the +/// previous round's polynomial and connect it to the current round. +/// * `previous_univariate_poly`: The univariate polynomial from the previous round. This is needed +/// to verify the consistency between rounds. +/// * `current_univariate_poly`: The univariate polynomial for the current round. This is the +/// polynomial that the prover claims represents the sum over the current variable. /// /// # Returns /// @@ -337,42 +348,46 @@ pub fn verify_sumcheck_first_round( /// /// # Why these parameters are needed /// -/// 1. `round`: Keeps track of the protocol's progress and ensures variables are eliminated in order. -/// 2. `challenge`: Connects the current round to the previous one, preventing the prover from deviating. +/// 1. `round`: Keeps track of the protocol's progress and ensures variables are eliminated in +/// order. +/// 2. `challenge`: Connects the current round to the previous one, preventing the prover from +/// deviating. /// 3. `previous_univariate_poly`: Used to verify consistency between rounds. /// 4. `current_univariate_poly`: The polynomial to be verified in the current round. /// /// These parameters allow the verifier to check: -/// - The univariate nature of the current polynomial (ensuring one variable is eliminated per round). +/// - The univariate nature of the current polynomial (ensuring one variable is eliminated per +/// round). /// - The consistency between rounds (g_i(r_{i-1}) = g_{i-1}(0) + g_{i-1}(1)). -/// - Generate a new random challenge for the next round, maintaining the protocol's unpredictability. +/// - Generate a new random challenge for the next round, maintaining the protocol's +/// unpredictability. pub fn verify_sumcheck_univariate_poly_sum( - round: usize, - challenge: F, - previous_univariate_poly: &MultivariatePolynomial, - current_univariate_poly: &MultivariatePolynomial, + round: usize, + challenge: F, + previous_univariate_poly: &MultivariatePolynomial, + current_univariate_poly: &MultivariatePolynomial, ) -> (bool, F) { - // Step 1: Verify that the current polynomial is univariate - if current_univariate_poly.variables().len() > 1 { - return (false, F::ZERO); - } + // Step 1: Verify that the current polynomial is univariate + if current_univariate_poly.variables().len() > 1 { + return (false, F::ZERO); + } - // Step 2: Verify that g_i(r_{i-1}) = g_{i-1}(0) + g_{i-1}(1) - let prev_var = round - 1; - let sum_at_endpoints = previous_univariate_poly.evaluate(&[(prev_var, challenge)]); + // Step 2: Verify that g_i(r_{i-1}) = g_{i-1}(0) + g_{i-1}(1) + let prev_var = round - 1; + let sum_at_endpoints = previous_univariate_poly.evaluate(&[(prev_var, challenge)]); - let eval_at_previous_challenge = current_univariate_poly.evaluate(&[(round, F::ZERO)]) - + current_univariate_poly.evaluate(&[(round, F::ONE)]); - if eval_at_previous_challenge != sum_at_endpoints { - return (false, F::ZERO); - } + let eval_at_previous_challenge = current_univariate_poly.evaluate(&[(round, F::ZERO)]) + + current_univariate_poly.evaluate(&[(round, F::ONE)]); + if eval_at_previous_challenge != sum_at_endpoints { + return (false, F::ZERO); + } - // Step 3: Generate a new random challenge - let mut rng = rand::thread_rng(); - let new_challenge: F = F::random(&mut rng); + // Step 3: Generate a new random challenge + let mut rng = rand::thread_rng(); + let new_challenge: F = F::random(&mut rng); - // Return true (verification passed) and the evaluation at the new challenge point - (true, new_challenge) + // Return true (verification passed) and the evaluation at the new challenge point + (true, new_challenge) } /// Verifies the final round of the sumcheck protocol. @@ -407,43 +422,43 @@ pub fn verify_sumcheck_univariate_poly_sum( /// evaluation of the claimed univariate polynomial at a random point, the verifier can detect /// any dishonesty from the prover with high probability. pub fn verify_sumcheck_last_round( - challenges: Vec, - univariate_poly: &MultivariatePolynomial, - poly: &MultivariatePolynomial, + challenges: Vec, + univariate_poly: &MultivariatePolynomial, + poly: &MultivariatePolynomial, ) -> bool { - // Step 1: Apply all challenges to the original polynomial - let mut challenges_with_indices = Vec::new(); - for (i, challenge) in challenges.iter().enumerate() { - challenges_with_indices.push((i, *challenge)); - } - let poly_evaluation = poly.evaluate(&challenges_with_indices); + // Step 1: Apply all challenges to the original polynomial + let mut challenges_with_indices = Vec::new(); + for (i, challenge) in challenges.iter().enumerate() { + challenges_with_indices.push((i, *challenge)); + } + let poly_evaluation = poly.evaluate(&challenges_with_indices); - // Step 2: Generate a random challenge for the last variable - let mut rng = rand::thread_rng(); - let last_challenge: F = F::random(&mut rng); + // Step 2: Generate a random challenge for the last variable + let mut rng = rand::thread_rng(); + let last_challenge: F = F::random(&mut rng); - // Step 3: Evaluate the univariate polynomial at the last challenge - let last_var = challenges.len(); - let univariate_evaluation = univariate_poly.evaluate(&[(last_var, last_challenge)]); + // Step 3: Evaluate the univariate polynomial at the last challenge + let last_var = challenges.len(); + let univariate_evaluation = univariate_poly.evaluate(&[(last_var, last_challenge)]); - // Step 4: Compare the evaluations - poly_evaluation == univariate_evaluation + // Step 4: Compare the evaluations + poly_evaluation == univariate_evaluation } impl ToBytes for F { - fn to_bytes(&self) -> Vec { - // Implement this based on how your field elements are represented - // This is just an example: - self.to_string().into_bytes() - } + fn to_bytes(&self) -> Vec { + // Implement this based on how your field elements are represented + // This is just an example: + self.to_string().into_bytes() + } } impl ToBytes for MultivariatePolynomial { - fn to_bytes(&self) -> Vec { - // Implement this based on how your polynomials are represented - // This is just an example: - self.to_string().into_bytes() - } + fn to_bytes(&self) -> Vec { + // Implement this based on how your polynomials are represented + // This is just an example: + self.to_string().into_bytes() + } } /// Represents a proof for the sumcheck protocol over a finite field. @@ -453,29 +468,28 @@ impl ToBytes for MultivariatePolynomial { /// /// # Type Parameters /// -/// * `F`: A type that implements the `FiniteField` trait, representing the field over which -/// the sumcheck protocol is performed. +/// * `F`: A type that implements the `FiniteField` trait, representing the field over which the +/// sumcheck protocol is performed. pub struct SumcheckProof { - /// The claimed sum of the polynomial over all boolean inputs. - pub claimed_sum: F, + /// The claimed sum of the polynomial over all boolean inputs. + pub claimed_sum: F, - /// Vector of univariate polynomials, one for each round of the protocol. - pub round_polynomials: Vec>, + /// Vector of univariate polynomials, one for each round of the protocol. + pub round_polynomials: Vec>, - /// Vector of challenges generated during the protocol. - pub challenges: Vec, + /// Vector of challenges generated during the protocol. + pub challenges: Vec, - /// Vector of evaluations of the round polynomials at the challenge points. - pub round_evaluations: Vec, + /// Vector of evaluations of the round polynomials at the challenge points. + pub round_evaluations: Vec, - /// The final evaluation point, consisting of all challenges combined. - pub final_point: Vec, + /// The final evaluation point, consisting of all challenges combined. + pub final_point: Vec, - /// The final evaluation of the original multivariate polynomial at the final point. - pub final_evaluation: F, + /// The final evaluation of the original multivariate polynomial at the final point. + pub final_evaluation: F, } - /// Generates a non-interactive sumcheck proof for a given multivariate polynomial. /// /// This function implements the prover's side of the non-interactive sumcheck protocol. @@ -500,54 +514,57 @@ pub struct SumcheckProof { /// /// # Type Parameters /// -/// * `F` - A finite field type that implements necessary traits for arithmetic, -/// random number generation, conversion to bytes, and display. -pub fn non_interactive_sumcheck_prove( - polynomial: &MultivariatePolynomial +/// * `F` - A finite field type that implements necessary traits for arithmetic, random number +/// generation, conversion to bytes, and display. +pub fn non_interactive_sumcheck_prove< + F: FiniteField + Random + RandomOracle + Display + ToBytes, +>( + polynomial: &MultivariatePolynomial, ) -> SumcheckProof { - let num_variables = polynomial.variables().len(); - let mut challenges = Vec::new(); - let mut round_polynomials = Vec::new(); - let mut round_evaluations = Vec::new(); - - // First round: compute the claimed sum and the first univariate polynomial - let (claimed_sum, first_univariate_poly) = polynomial.prove_first_sumcheck_round(); - round_polynomials.push(first_univariate_poly.clone()); - - // Generate the first challenge using the random oracle - let mut rng = rand::thread_rng(); - let challenge: F = F::random_oracle(&mut rng, &claimed_sum.to_bytes()); + let num_variables = polynomial.variables().len(); + let mut challenges = Vec::new(); + let mut round_polynomials = Vec::new(); + let mut round_evaluations = Vec::new(); + + // First round: compute the claimed sum and the first univariate polynomial + let (claimed_sum, first_univariate_poly) = polynomial.prove_first_sumcheck_round(); + round_polynomials.push(first_univariate_poly.clone()); + + // Generate the first challenge using the random oracle + let mut rng = rand::thread_rng(); + let challenge: F = F::random_oracle(&mut rng, &claimed_sum.to_bytes()); + challenges.push(challenge); + round_evaluations.push(first_univariate_poly.evaluate(&[(0, challenge)])); + + let mut previous_univariate_poly = first_univariate_poly; + + // Intermediate rounds: generate univariate polynomials and challenges + for i in 1..num_variables { + let univariate_poly = polynomial.prove_sumcheck_round_i(i, challenges.clone()); + round_polynomials.push(univariate_poly.clone()); + + // Generate challenge for this round using the random oracle + let challenge: F = F::random_oracle(&mut rng, &previous_univariate_poly.to_bytes()); challenges.push(challenge); - round_evaluations.push(first_univariate_poly.evaluate(&[(0, challenge)])); + round_evaluations.push(univariate_poly.evaluate(&[(i, challenge)])); - let mut previous_univariate_poly = first_univariate_poly; - - // Intermediate rounds: generate univariate polynomials and challenges - for i in 1..num_variables { - let univariate_poly = polynomial.prove_sumcheck_round_i(i, challenges.clone()); - round_polynomials.push(univariate_poly.clone()); - - // Generate challenge for this round using the random oracle - let challenge: F = F::random_oracle(&mut rng, &previous_univariate_poly.to_bytes()); - challenges.push(challenge); - round_evaluations.push(univariate_poly.evaluate(&[(i, challenge)])); - - previous_univariate_poly = univariate_poly; - } + previous_univariate_poly = univariate_poly; + } - // Final evaluation: evaluate the original polynomial at the challenge point - let final_point = challenges.clone(); - let final_evaluation = polynomial.evaluate(&final_point.iter().cloned().enumerate().collect::>()); - - // Construct and return the proof - SumcheckProof { - claimed_sum, - round_polynomials, - round_evaluations, - challenges, - final_point, - final_evaluation, - } + // Final evaluation: evaluate the original polynomial at the challenge point + let final_point = challenges.clone(); + let final_evaluation = + polynomial.evaluate(&final_point.iter().cloned().enumerate().collect::>()); + + // Construct and return the proof + SumcheckProof { + claimed_sum, + round_polynomials, + round_evaluations, + challenges, + final_point, + final_evaluation, + } } /// Verifies a non-interactive sumcheck proof. @@ -555,12 +572,14 @@ pub fn non_interactive_sumcheck_prove( - proof: &SumcheckProof, - polynomial: &MultivariatePolynomial + proof: &SumcheckProof, + polynomial: &MultivariatePolynomial, ) -> bool { - let num_variables = polynomial.variables().len(); + let num_variables = polynomial.variables().len(); - // Verify first round - let (valid, _) = verify_sumcheck_first_round(proof.claimed_sum, &proof.round_polynomials[0]); - if !valid { - return false; - } + // Verify first round + let (valid, _) = verify_sumcheck_first_round(proof.claimed_sum, &proof.round_polynomials[0]); + if !valid { + return false; + } - // Verify intermediate rounds - for i in 1..num_variables { - let (valid, _) = verify_sumcheck_univariate_poly_sum( - i, - proof.challenges[i-1], - &proof.round_polynomials[i-1], - &proof.round_polynomials[i], - ); - if !valid { - return false; - } + // Verify intermediate rounds + for i in 1..num_variables { + let (valid, _) = verify_sumcheck_univariate_poly_sum( + i, + proof.challenges[i - 1], + &proof.round_polynomials[i - 1], + &proof.round_polynomials[i], + ); + if !valid { + return false; } + } - // Verify last round - verify_sumcheck_last_round( - proof.final_point.clone(), - &proof.round_polynomials.last().unwrap(), - polynomial, - ) -} \ No newline at end of file + // Verify last round + verify_sumcheck_last_round( + proof.final_point.clone(), + &proof.round_polynomials.last().unwrap(), + polynomial, + ) +} diff --git a/src/sumcheck/tests.rs b/src/sumcheck/tests.rs index 90984561..186084bb 100644 --- a/src/sumcheck/tests.rs +++ b/src/sumcheck/tests.rs @@ -1,8 +1,5 @@ use crate::{ - algebra::field::{ - prime::PlutoBaseField, - Field, - }, + algebra::field::{prime::PlutoBaseField, Field}, polynomial::multivariate_polynomial::{ MultivariatePolynomial, MultivariateTerm, MultivariateVariable, }, @@ -11,7 +8,6 @@ use crate::{ }, }; -#[test] #[test] fn test_full_sumcheck_protocol() { // This test demonstrates the full sumcheck protocol for the polynomial: diff --git a/src/sumcheck/to_bytes.rs b/src/sumcheck/to_bytes.rs index 306a5fc8..2d5e3b13 100644 --- a/src/sumcheck/to_bytes.rs +++ b/src/sumcheck/to_bytes.rs @@ -1,3 +1,3 @@ pub trait ToBytes { - fn to_bytes(&self) -> Vec; -} \ No newline at end of file + fn to_bytes(&self) -> Vec; +}