From 9bed75ca887a307937d127779aad3acb2bcc2cee Mon Sep 17 00:00:00 2001 From: Gali Michlevich Date: Tue, 10 Dec 2024 13:16:46 +0200 Subject: [PATCH] Add secure powers generation for simd --- crates/prover/src/core/backend/simd/utils.rs | 43 +++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/crates/prover/src/core/backend/simd/utils.rs b/crates/prover/src/core/backend/simd/utils.rs index d5f53a22b..f72f4ee73 100644 --- a/crates/prover/src/core/backend/simd/utils.rs +++ b/crates/prover/src/core/backend/simd/utils.rs @@ -1,5 +1,12 @@ use std::simd::Swizzle; +use itertools::Itertools; + +use crate::core::backend::simd::m31::N_LANES; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::fields::qm31::SecureField; +use crate::core::utils::generate_secure_powers; + /// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. pub struct InterleaveEvens; @@ -51,11 +58,35 @@ impl UnsafeConst { unsafe impl Send for UnsafeConst {} unsafe impl Sync for UnsafeConst {} +// TODO(Gali): Remove #[allow(dead_code)]. +#[allow(dead_code)] +pub fn generate_secure_powers_simd(felt: SecureField, n_powers: usize) -> Vec { + let base_arr = generate_secure_powers(felt, N_LANES).try_into().unwrap(); + let base_packed_felt = PackedSecureField::from_array(base_arr); + + let step = base_arr[N_LANES - 1] * felt; + let step = PackedSecureField::broadcast(step); + + let packed_size = (n_powers + N_LANES - 1) / N_LANES; + + (0..packed_size) + .scan(base_packed_felt, |acc, _| { + let res = *acc; + *acc *= step; + Some(res) + }) + .flat_map(|x| x.to_array()) + .take(n_powers) + .collect_vec() +} + #[cfg(test)] mod tests { use std::simd::{u32x4, Swizzle}; - use super::{InterleaveEvens, InterleaveOdds}; + use super::{generate_secure_powers_simd, InterleaveEvens, InterleaveOdds}; + use crate::core::utils::generate_secure_powers; + use crate::qm31; #[test] fn interleave_evens() { @@ -76,4 +107,14 @@ mod tests { assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); } + + #[test] + fn test_generate_secure_powers_simd() { + let felt = qm31!(1, 2, 3, 4); + let n_powers = 10; + + let cpu_powers = generate_secure_powers(felt, n_powers); + let powers = generate_secure_powers_simd(felt, n_powers); + assert_eq!(powers, cpu_powers); + } }