diff --git a/crates/prover/src/core/backend/simd/utils.rs b/crates/prover/src/core/backend/simd/utils.rs index d5f53a22b..d0886cab6 100644 --- a/crates/prover/src/core/backend/simd/utils.rs +++ b/crates/prover/src/core/backend/simd/utils.rs @@ -1,5 +1,11 @@ use std::simd::Swizzle; +use num_traits::One; + +use crate::core::backend::simd::m31::N_LANES; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::fields::qm31::SecureField; + /// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. pub struct InterleaveEvens; @@ -51,11 +57,51 @@ impl UnsafeConst { unsafe impl Send for UnsafeConst {} unsafe impl Sync for UnsafeConst {} +// TODO(Gali): Remove #[allow(dead_code)]. +#[allow(dead_code)] +pub fn generate_secure_powers_for_simd(felt: SecureField, n_powers: usize) -> Vec { + let step_arr: [SecureField; N_LANES] = (0..N_LANES) + .scan(SecureField::one(), |acc, _| { + let res = *acc; + *acc *= felt; + Some(res) + }) + .collect::>() + .try_into() + .expect("Failed generating secure powers."); + let step_packed_felt = PackedSecureField::from_array(step_arr); + + let mut base_felt = SecureField::one(); + let step_felt = step_arr[N_LANES - 1] * felt; + + let mut packed_powers_vec = Vec::new(); + let mut curr_power: usize = 0; + + while curr_power < n_powers { + let base_packed_felt = PackedSecureField::from_array([base_felt; N_LANES]); + packed_powers_vec.push(base_packed_felt * step_packed_felt); + base_felt *= step_felt; + curr_power += N_LANES; + } + + let powers_vec: Vec = packed_powers_vec + .iter() + .flat_map(|x| x.to_array().to_vec()) + .collect(); + + powers_vec[0..n_powers].to_vec() +} + #[cfg(test)] mod tests { use std::simd::{u32x4, Swizzle}; - use super::{InterleaveEvens, InterleaveOdds}; + use num_traits::One; + + use super::{generate_secure_powers_for_simd, InterleaveEvens, InterleaveOdds}; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::FieldExpOps; + use crate::qm31; #[test] fn interleave_evens() { @@ -76,4 +122,27 @@ mod tests { assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); } + + #[test] + fn generate_secure_powers_works() { + let felt = qm31!(1, 2, 3, 4); + let n_powers = 10; + + let powers = generate_secure_powers_for_simd(felt, n_powers); + + assert_eq!(powers.len(), n_powers); + assert_eq!(powers[0], SecureField::one()); + assert_eq!(powers[1], felt); + assert_eq!(powers[7], felt.pow(7)); + } + + #[test] + fn generate_empty_secure_powers_works() { + let felt = qm31!(1, 2, 3, 4); + let max_log_size = 0; + + let powers = generate_secure_powers_for_simd(felt, max_log_size); + + assert_eq!(powers, vec![]); + } }