Skip to content

Commit

Permalink
Add secure powers generation for simd
Browse files Browse the repository at this point in the history
  • Loading branch information
Gali-StarkWare committed Dec 10, 2024
1 parent 060f0e4 commit 9bed75c
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion crates/prover/src/core/backend/simd/utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
use std::simd::Swizzle;

use itertools::Itertools;

use crate::core::backend::simd::m31::N_LANES;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::fields::qm31::SecureField;
use crate::core::utils::generate_secure_powers;

/// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors.
pub struct InterleaveEvens;

Expand Down Expand Up @@ -51,11 +58,35 @@ impl<T> UnsafeConst<T> {
unsafe impl<T> Send for UnsafeConst<T> {}
unsafe impl<T> Sync for UnsafeConst<T> {}

// TODO(Gali): Remove #[allow(dead_code)].
#[allow(dead_code)]
pub fn generate_secure_powers_simd(felt: SecureField, n_powers: usize) -> Vec<SecureField> {
let base_arr = generate_secure_powers(felt, N_LANES).try_into().unwrap();
let base_packed_felt = PackedSecureField::from_array(base_arr);

let step = base_arr[N_LANES - 1] * felt;
let step = PackedSecureField::broadcast(step);

let packed_size = (n_powers + N_LANES - 1) / N_LANES;

(0..packed_size)
.scan(base_packed_felt, |acc, _| {
let res = *acc;
*acc *= step;
Some(res)
})
.flat_map(|x| x.to_array())
.take(n_powers)
.collect_vec()
}

#[cfg(test)]
mod tests {
use std::simd::{u32x4, Swizzle};

use super::{InterleaveEvens, InterleaveOdds};
use super::{generate_secure_powers_simd, InterleaveEvens, InterleaveOdds};
use crate::core::utils::generate_secure_powers;
use crate::qm31;

#[test]
fn interleave_evens() {
Expand All @@ -76,4 +107,14 @@ mod tests {

assert_eq!(res, u32x4::from_array([1, 5, 3, 7]));
}

#[test]
fn test_generate_secure_powers_simd() {
let felt = qm31!(1, 2, 3, 4);
let n_powers = 10;

let cpu_powers = generate_secure_powers(felt, n_powers);
let powers = generate_secure_powers_simd(felt, n_powers);
assert_eq!(powers, cpu_powers);
}
}

0 comments on commit 9bed75c

Please sign in to comment.