Skip to content

Commit

Permalink
Change Eval Framework Copy requirement to Clone. (#834)
Browse files Browse the repository at this point in the history
<!-- Reviewable:start -->
This change is [<img src="https://reviewable.io/review_button.svg" height="34" align="absmiddle" alt="Reviewable"/>](https://reviewable.io/reviews/starkware-libs/stwo/834)
<!-- Reviewable:end -->
  • Loading branch information
Alon-Ti authored Sep 25, 2024
1 parent ee8a9c6 commit 6e649fc
Show file tree
Hide file tree
Showing 21 changed files with 276 additions and 192 deletions.
22 changes: 11 additions & 11 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ impl<E: EvalAtRow> LogupAtRow<E> {

pub fn write_frac(&mut self, eval: &mut E, fraction: Fraction<E::EF, E::EF>) {
// Add a constraint that num / denom = diff.
if let Some(cur_frac) = self.cur_frac {
let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0];
let diff = cur_cumsum - self.prev_col_cumsum;
if let Some(cur_frac) = self.cur_frac.clone() {
let [cur_cumsum] = eval.next_extension_interaction_mask(self.interaction, [0]);
let diff = cur_cumsum.clone() - self.prev_col_cumsum.clone();
self.prev_col_cumsum = cur_cumsum;
eval.add_constraint(diff * cur_frac.denominator - cur_frac.numerator);
}
Expand All @@ -76,7 +76,7 @@ impl<E: EvalAtRow> LogupAtRow<E> {
pub fn finalize(mut self, eval: &mut E) {
assert!(!self.is_finalized, "LogupAtRow was already finalized");

let frac = self.cur_frac.unwrap();
let frac = self.cur_frac.clone().unwrap();

// TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted offset
// from the is_first column when constant columns are supported.
Expand All @@ -89,7 +89,7 @@ impl<E: EvalAtRow> LogupAtRow<E> {
);

// Constrain that the claimed_sum in case that it is not equal to the total_sum.
eval.add_constraint((claimed_cumsum - claimed_sum) * self.is_first);
eval.add_constraint((claimed_cumsum - claimed_sum) * self.is_first.clone());
(cur_cumsum, prev_row_cumsum)
}
None => {
Expand All @@ -99,8 +99,8 @@ impl<E: EvalAtRow> LogupAtRow<E> {
}
};
// Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row.
let fixed_prev_row_cumsum = prev_row_cumsum - self.is_first * self.total_sum;
let diff = cur_cumsum - fixed_prev_row_cumsum - self.prev_col_cumsum;
let fixed_prev_row_cumsum = prev_row_cumsum - self.is_first.clone() * self.total_sum;
let diff = cur_cumsum - fixed_prev_row_cumsum - self.prev_col_cumsum.clone();

eval.add_constraint(diff * frac.denominator - frac.numerator);

Expand Down Expand Up @@ -138,9 +138,9 @@ impl<const N: usize> LookupElements<N> {
alpha_powers,
}
}
pub fn combine<F: Copy, EF>(&self, values: &[F]) -> EF
pub fn combine<F: Clone, EF>(&self, values: &[F]) -> EF
where
EF: Copy + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<EF, Output = EF>,
EF: Clone + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<EF, Output = EF>,
{
assert!(
self.alpha_powers.len() >= values.len(),
Expand All @@ -149,8 +149,8 @@ impl<const N: usize> LookupElements<N> {
values
.iter()
.zip(self.alpha_powers)
.fold(EF::zero(), |acc, (&value, power)| {
acc + EF::from(power) * value
.fold(EF::zero(), |acc, (value, power)| {
acc + EF::from(power) * value.clone()
})
- EF::from(self.z)
}
Expand Down
11 changes: 7 additions & 4 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub trait EvalAtRow {
/// constraints. It might be [BaseField] packed types, or even [SecureField], when evaluating
/// the columns out of domain.
type F: FieldExpOps
+ Copy
+ Clone
+ Debug
+ Zero
+ Neg<Output = Self::F>
Expand All @@ -48,7 +48,7 @@ pub trait EvalAtRow {
/// A field type representing the closure of `F` with multiplying by [SecureField]. Constraints
/// usually get multiplied by [SecureField] values for security.
type EF: One
+ Copy
+ Clone
+ Debug
+ Zero
+ From<Self::F>
Expand Down Expand Up @@ -84,8 +84,11 @@ pub trait EvalAtRow {
interaction: usize,
offsets: [isize; N],
) -> [Self::EF; N] {
let res_col_major = array::from_fn(|_| self.next_interaction_mask(interaction, offsets));
array::from_fn(|i| Self::combine_ef(res_col_major.map(|c| c[i])))
let mut res_col_major =
array::from_fn(|_| self.next_interaction_mask(interaction, offsets).into_iter());
array::from_fn(|_| {
Self::combine_ef(res_col_major.each_mut().map(|iter| iter.next().unwrap()))
})
}

/// Adds a constraint to the component.
Expand Down
8 changes: 4 additions & 4 deletions crates/prover/src/core/backend/simd/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ impl SimdBackend {
);

let mut product = F::one();
for &num in mappings.iter() {
for num in mappings.iter() {
if index & 1 == 1 {
product *= num;
product *= *num;
}
index >>= 1;
if index == 0 {
Expand Down Expand Up @@ -108,8 +108,8 @@ impl SimdBackend {
.iter()
.skip(1)
.zip(denom_inverses.iter())
.for_each(|(&m, &d)| {
steps.push(m * d);
.for_each(|(m, d)| {
steps.push(*m * *d);
});
steps.push(F::one());
steps
Expand Down
48 changes: 35 additions & 13 deletions crates/prover/src/core/backend/simd/very_packed_m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ use crate::core::fields::FieldExpOps;
pub const LOG_N_VERY_PACKED_ELEMS: u32 = 1;
pub const N_VERY_PACKED_ELEMS: usize = 1 << LOG_N_VERY_PACKED_ELEMS;

#[derive(Copy, Clone, Debug)]
#[derive(Clone, Debug, Copy)]
#[repr(transparent)]
pub struct Vectorized<A, const N: usize>(pub [A; N]);
pub struct Vectorized<A: Copy, const N: usize>(pub [A; N]);

impl<A, const N: usize> Vectorized<A, N> {
impl<A: Copy, const N: usize> Vectorized<A, N> {
pub fn from_fn<F>(cb: F) -> Self
where
F: FnMut(usize) -> A,
Expand All @@ -27,17 +27,18 @@ impl<A, const N: usize> Vectorized<A, N> {
}
}

impl<A, const N: usize> From<[A; N]> for Vectorized<A, N> {
impl<A: Copy, const N: usize> From<[A; N]> for Vectorized<A, N> {
fn from(array: [A; N]) -> Self {
Vectorized(array)
}
}

unsafe impl<A, const N: usize> Zeroable for Vectorized<A, N> {
unsafe impl<A: Copy, const N: usize> Zeroable for Vectorized<A, N> {
fn zeroed() -> Self {
unsafe { core::mem::zeroed() }
}
}

unsafe impl<A: Pod, const N: usize> Pod for Vectorized<A, N> {}

pub type VeryPackedM31 = Vectorized<PackedM31, N_VERY_PACKED_ELEMS>;
Expand Down Expand Up @@ -121,47 +122,65 @@ impl Scalar for PackedM31 {}
impl Scalar for PackedCM31 {}
impl Scalar for PackedQM31 {}

impl<A: Add<B> + Copy, B: Copy, const N: usize> Add<Vectorized<B, N>> for Vectorized<A, N> {
impl<A: Add<B> + Copy, B: Copy, const N: usize> Add<Vectorized<B, N>> for Vectorized<A, N>
where
<A as Add<B>>::Output: Copy,
{
type Output = Vectorized<A::Output, N>;

fn add(self, other: Vectorized<B, N>) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] + other.0[i])
}
}

impl<A: Add<B> + Copy, B: Scalar + Copy, const N: usize> Add<B> for Vectorized<A, N> {
impl<A: Add<B> + Copy, B: Scalar + Copy, const N: usize> Add<B> for Vectorized<A, N>
where
<A as Add<B>>::Output: Copy,
{
type Output = Vectorized<A::Output, N>;

fn add(self, other: B) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] + other)
}
}

impl<A: Sub<B> + Copy, B: Copy, const N: usize> Sub<Vectorized<B, N>> for Vectorized<A, N> {
impl<A: Sub<B> + Copy, B: Copy, const N: usize> Sub<Vectorized<B, N>> for Vectorized<A, N>
where
<A as Sub<B>>::Output: Copy,
{
type Output = Vectorized<A::Output, N>;

fn sub(self, other: Vectorized<B, N>) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] - other.0[i])
}
}

impl<A: Sub<B> + Copy, B: Scalar + Copy, const N: usize> Sub<B> for Vectorized<A, N> {
impl<A: Sub<B> + Copy, B: Scalar + Copy, const N: usize> Sub<B> for Vectorized<A, N>
where
<A as Sub<B>>::Output: Copy,
{
type Output = Vectorized<A::Output, N>;

fn sub(self, other: B) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] - other)
}
}

impl<A: Mul<B> + Copy, B: Copy, const N: usize> Mul<Vectorized<B, N>> for Vectorized<A, N> {
impl<A: Mul<B> + Copy, B: Copy, const N: usize> Mul<Vectorized<B, N>> for Vectorized<A, N>
where
<A as Mul<B>>::Output: Copy,
{
type Output = Vectorized<A::Output, N>;

fn mul(self, other: Vectorized<B, N>) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] * other.0[i])
}
}

impl<A: Mul<B> + Copy, B: Scalar + Copy, const N: usize> Mul<B> for Vectorized<A, N> {
impl<A: Mul<B> + Copy, B: Scalar + Copy, const N: usize> Mul<B> for Vectorized<A, N>
where
<A as Mul<B>>::Output: Copy,
{
type Output = Vectorized<A::Output, N>;

fn mul(self, other: B) -> Self::Output {
Expand Down Expand Up @@ -197,7 +216,10 @@ impl<A: MulAssign<B> + Copy, B: Copy, const N: usize> MulAssign<Vectorized<B, N>
}
}

impl<A: Neg + Copy, const N: usize> Neg for Vectorized<A, N> {
impl<A: Neg + Copy, const N: usize> Neg for Vectorized<A, N>
where
<A as Neg>::Output: Copy,
{
type Output = Vectorized<A::Output, N>;

#[inline(always)]
Expand All @@ -222,7 +244,7 @@ impl<A: One + Copy, const N: usize> One for Vectorized<A, N> {
}
}

impl<A: FieldExpOps + Zero, const N: usize> FieldExpOps for Vectorized<A, N> {
impl<A: FieldExpOps + Zero + Copy, const N: usize> FieldExpOps for Vectorized<A, N> {
fn inverse(&self) -> Self {
Vectorized::from_fn(|i| {
assert!(!self.0[i].is_zero(), "0 has no inverse");
Expand Down
26 changes: 13 additions & 13 deletions crates/prover/src/core/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>
}

pub fn double(&self) -> Self {
*self + *self
self.clone() + self.clone()
}

/// Applies the circle's x-coordinate doubling map.
Expand All @@ -40,7 +40,7 @@ impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>
/// ```
pub fn double_x(x: F) -> F {
let sx = x.square();
sx + sx - F::one()
sx.clone() + sx - F::one()
}

/// Returns the log order of a point.
Expand All @@ -61,7 +61,7 @@ impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>
// we only need the x-coordinate to check order since the only point
// with x=1 is the circle's identity
let mut res = 0;
let mut cur = self.x;
let mut cur = self.x.clone();
while cur != F::one() {
cur = Self::double_x(cur);
res += 1;
Expand All @@ -71,10 +71,10 @@ impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>

pub fn mul(&self, mut scalar: u128) -> CirclePoint<F> {
let mut res = Self::zero();
let mut cur = *self;
let mut cur = self.clone();
while scalar > 0 {
if scalar & 1 == 1 {
res = res + cur;
res = res + cur.clone();
}
cur = cur.double();
scalar >>= 1;
Expand All @@ -83,7 +83,7 @@ impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>
}

pub fn repeated_double(&self, n: u32) -> Self {
let mut res = *self;
let mut res = self.clone();
for _ in 0..n {
res = res.double();
}
Expand All @@ -92,22 +92,22 @@ impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>

pub fn conjugate(&self) -> CirclePoint<F> {
Self {
x: self.x,
y: -self.y,
x: self.x.clone(),
y: -self.y.clone(),
}
}

pub fn antipode(&self) -> CirclePoint<F> {
Self {
x: -self.x,
y: -self.y,
x: -self.x.clone(),
y: -self.y.clone(),
}
}

pub fn into_ef<EF: From<F>>(&self) -> CirclePoint<EF> {
CirclePoint {
x: self.x.into(),
y: self.y.into(),
x: self.x.clone().into(),
y: self.y.clone().into(),
}
}

Expand All @@ -126,7 +126,7 @@ impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>
type Output = Self;

fn add(self, rhs: Self) -> Self::Output {
let x = self.x * rhs.x - self.y * rhs.y;
let x = self.x.clone() * rhs.x.clone() - self.y.clone() * rhs.y.clone();
let y = self.x * rhs.y + self.y * rhs.x;
Self { x, y }
}
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::fields::m31::BaseField;

pub fn butterfly<F>(v0: &mut F, v1: &mut F, twid: BaseField)
where
F: Copy + AddAssign<F> + Sub<F, Output = F> + Mul<BaseField, Output = F>,
F: AddAssign<F> + Sub<F, Output = F> + Mul<BaseField, Output = F> + Copy,
{
let tmp = *v1 * twid;
*v1 = *v0 - tmp;
Expand All @@ -13,7 +13,7 @@ where

pub fn ibutterfly<F>(v0: &mut F, v1: &mut F, itwid: BaseField)
where
F: Copy + AddAssign<F> + Add<F, Output = F> + Sub<F, Output = F> + Mul<BaseField, Output = F>,
F: AddAssign<F> + Add<F, Output = F> + Sub<F, Output = F> + Mul<BaseField, Output = F> + Copy,
{
let tmp = *v0;
*v0 = tmp + *v1;
Expand Down
12 changes: 6 additions & 6 deletions crates/prover/src/core/fields/m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,12 @@ macro_rules! m31 {
/// assert_eq!(pow2147483645(v), v.pow(2147483645));
/// ```
pub fn pow2147483645<T: FieldExpOps>(v: T) -> T {
let t0 = sqn::<2, T>(v) * v;
let t1 = sqn::<1, T>(t0) * t0;
let t2 = sqn::<3, T>(t1) * t0;
let t3 = sqn::<1, T>(t2) * t0;
let t4 = sqn::<8, T>(t3) * t3;
let t5 = sqn::<8, T>(t4) * t3;
let t0 = sqn::<2, T>(v.clone()) * v.clone();
let t1 = sqn::<1, T>(t0.clone()) * t0.clone();
let t2 = sqn::<3, T>(t1.clone()) * t0.clone();
let t3 = sqn::<1, T>(t2.clone()) * t0.clone();
let t4 = sqn::<8, T>(t3.clone()) * t3.clone();
let t5 = sqn::<8, T>(t4.clone()) * t3.clone();
sqn::<7, T>(t5) * t2
}

Expand Down
Loading

0 comments on commit 6e649fc

Please sign in to comment.