diff --git a/crates/prover/src/core/backend/simd/utils.rs b/crates/prover/src/core/backend/simd/utils.rs
index d5f53a22b..b253180ec 100644
--- a/crates/prover/src/core/backend/simd/utils.rs
+++ b/crates/prover/src/core/backend/simd/utils.rs
@@ -1,5 +1,12 @@
 use std::simd::Swizzle;
 
+use itertools::Itertools;
+
+use crate::core::backend::simd::m31::N_LANES;
+use crate::core::backend::simd::qm31::PackedSecureField;
+use crate::core::fields::qm31::SecureField;
+use crate::core::utils::generate_secure_powers;
+
 /// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors.
 pub struct InterleaveEvens;
 
@@ -51,11 +58,51 @@ impl UnsafeConst {
 unsafe impl Send for UnsafeConst {}
 unsafe impl Sync for UnsafeConst {}
 
+/// SIMD variant of [`generate_secure_powers`]: returns `n_powers` elements for `felt`.
+///
+/// The first `N_LANES` elements come from `generate_secure_powers(felt, N_LANES)` and are packed
+/// into one [`PackedSecureField`]. Each following packed block is the previous block multiplied
+/// by a broadcast of `base_arr[N_LANES - 1] * felt`. The blocks are flattened and truncated to
+/// `n_powers` elements.
+///
+/// NOTE(review): expected to equal `generate_secure_powers(felt, n_powers)`; the unit test
+/// checks this for a single input only.
+///
+/// # Panics
+///
+/// Panics if converting `generate_secure_powers(felt, N_LANES)` into a `[SecureField; N_LANES]`
+/// array fails.
+// TODO(Gali): Remove #[allow(dead_code)].
+#[allow(dead_code)]
+pub fn generate_secure_powers_simd(felt: SecureField, n_powers: usize) -> Vec<SecureField> {
+    let base_arr: [SecureField; N_LANES] =
+        generate_secure_powers(felt, N_LANES).try_into().unwrap();
+    let base_packed_felt = PackedSecureField::from_array(base_arr);
+
+    let step_felt = base_arr[N_LANES - 1] * felt;
+    let step_packed_felt = PackedSecureField::broadcast(step_felt);
+
+    // Number of packed blocks needed to cover `n_powers` elements (ceiling division).
+    let packed_powers_vec_size = n_powers.div_ceil(N_LANES);
+
+    (0..packed_powers_vec_size)
+        .scan(base_packed_felt, |acc, _| {
+            let res = *acc;
+            *acc *= step_packed_felt;
+            Some(res)
+        })
+        .flat_map(|x| x.to_array())
+        .take(n_powers)
+        .collect_vec()
+}
+
 #[cfg(test)]
 mod tests {
     use std::simd::{u32x4, Swizzle};
 
-    use super::{InterleaveEvens, InterleaveOdds};
+    use super::{generate_secure_powers_simd, InterleaveEvens, InterleaveOdds};
+    use crate::core::utils::generate_secure_powers;
+    use crate::qm31;
 
     #[test]
     fn interleave_evens() {
@@ -76,4 +123,14 @@ mod tests {
 
         assert_eq!(res, u32x4::from_array([1, 5, 3, 7]));
     }
+
+    #[test]
+    fn generate_secure_powers_simd_returns_the_same_as_cpu() {
+        let felt = qm31!(1, 2, 3, 4);
+        let n_powers = 30;
+
+        let cpu_powers = generate_secure_powers(felt, n_powers);
+        let powers = generate_secure_powers_simd(felt, n_powers);
+        assert_eq!(powers, cpu_powers);
+    }
 }