From be434418ee0b4160c9893658b37d54057cebb5b0 Mon Sep 17 00:00:00 2001 From: devloper Date: Wed, 1 May 2024 18:11:08 -0700 Subject: [PATCH] Fix modular arithmetic (#13) * Fix modular arithmetic * cargo f --------- Co-authored-by: devloper <3347622+devloper@users.noreply.github.com> Co-authored-by: Thor Kampefner --- src/field.rs | 251 +++++++++++++++++++++++---------------------------- src/main.rs | 7 +- 2 files changed, 114 insertions(+), 144 deletions(-) diff --git a/src/field.rs b/src/field.rs index 695278ee..868a679d 100644 --- a/src/field.rs +++ b/src/field.rs @@ -1,153 +1,122 @@ -use p3_field::{AbstractField, Field, Packable, halve_u64}; -use core::iter::{Product, Sum}; -use core::hash::{Hash, Hasher}; -use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; -use serde::{Deserialize, Serialize}; +use core::{ + hash::{Hash, Hasher}, + iter::{Product, Sum}, + ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}, +}; use std::fmt; + use num_bigint::BigUint; +use p3_field::{halve_u64, AbstractField, Field, Packable}; +use serde::{Deserialize, Serialize}; const PLUTO_FIELD_PRIME: u64 = 101; #[derive(Copy, Clone, Default, Serialize, Deserialize, Debug)] pub struct PlutoField { - value: u64, + value: u64, } impl PlutoField { - const ORDER_U64: u64 = PLUTO_FIELD_PRIME; + pub const ORDER_U64: u64 = PLUTO_FIELD_PRIME; - pub fn new(value: u64) -> Self { - Self { value } - } + pub fn new(value: u64) -> Self { Self { value } } } impl PartialEq for PlutoField { - fn eq(&self, other: &Self) -> bool { - // TODO: removed canonicalization - self.value == other.value - //self.as_canonical_u64() == other.as_canonical_u64() - } + fn eq(&self, other: &Self) -> bool { + // TODO: removed canonicalization + self.value == other.value + // self.as_canonical_u64() == other.as_canonical_u64() + } } impl fmt::Display for PlutoField { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.value) - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.value) } } impl Eq for PlutoField {} impl Packable for PlutoField {} impl Div for PlutoField { - type Output = Self; + type Output = Self; - #[allow(clippy::suspicious_arithmetic_impl)] - fn div(self, rhs: Self) -> Self { - self * rhs.inverse() - } + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: Self) -> Self { self * rhs.inverse() } } impl Field for PlutoField { - // TODO: Add cfg-guarded Packing for AVX2, NEON, etc. - type Packing = Self; + // TODO: Add cfg-guarded Packing for AVX2, NEON, etc. + type Packing = Self; - fn is_zero(&self) -> bool { - self.value == 0 || self.value == Self::ORDER_U64 - } + fn is_zero(&self) -> bool { self.value == 0 || self.value == Self::ORDER_U64 } - #[inline] - fn exp_u64_generic>(val: AF, _power: u64) -> AF { - // TODO: Fix exponentiation - val - // match power { - // 10540996611094048183 => exp_10540996611094048183(val), // used to compute x^{1/7} - // _ => exp_u64_by_squaring(val, power), - // } - } + #[inline] + fn exp_u64_generic>(val: AF, _power: u64) -> AF { + // TODO: Fix exponentiation + val + // match power { + // 10540996611094048183 => exp_10540996611094048183(val), // used to compute x^{1/7} + // _ => exp_u64_by_squaring(val, power), + // } + } - fn try_inverse(&self) -> Option { - // TODO: Fix inverse - Some(Self::new(1)) - } + fn try_inverse(&self) -> Option { + // TODO: Fix inverse + Some(Self::new(1)) + } - #[inline] - fn halve(&self) -> Self { - PlutoField::new(halve_u64::(self.value)) - } + #[inline] + fn halve(&self) -> Self { PlutoField::new(halve_u64::(self.value)) } - #[inline] - fn order() -> BigUint { - PLUTO_FIELD_PRIME.into() - } + #[inline] + fn order() -> BigUint { PLUTO_FIELD_PRIME.into() } } - impl AbstractField for PlutoField { - type F = Self; + type F = Self; - fn zero() -> Self { - Self::new(0) - } - fn one() -> Self { - Self::new(1) - } - fn two() -> Self { - Self::new(2) - } - fn neg_one() -> Self { - Self::new(Self::ORDER_U64 - 1) - } + fn zero() -> Self { Self::new(0) } - #[inline] - fn from_f(f: Self::F) -> Self { - f - } + fn one() -> Self { Self::new(1) } - fn from_bool(b: bool) -> Self { - Self::new(u64::from(b)) - } + fn two() -> Self { Self::new(2) } - fn from_canonical_u8(n: u8) -> Self { - Self::new(u64::from(n)) - } + fn neg_one() -> Self { Self::new(Self::ORDER_U64 - 1) } - fn from_canonical_u16(n: u16) -> Self { - Self::new(u64::from(n)) - } + #[inline] + fn from_f(f: Self::F) -> Self { f } - fn from_canonical_u32(n: u32) -> Self { - Self::new(u64::from(n)) - } + fn from_bool(b: bool) -> Self { Self::new(u64::from(b)) } - fn from_canonical_u64(n: u64) -> Self { - Self::new(n) - } + fn from_canonical_u8(n: u8) -> Self { Self::new(u64::from(n)) } - fn from_canonical_usize(n: usize) -> Self { - Self::new(n as u64) - } + fn from_canonical_u16(n: u16) -> Self { Self::new(u64::from(n)) } - fn from_wrapped_u32(n: u32) -> Self { - // A u32 must be canonical, plus we don't store canonical encodings anyway, so there's no - // need for a reduction. - Self::new(u64::from(n)) - } + fn from_canonical_u32(n: u32) -> Self { Self::new(u64::from(n)) } - fn from_wrapped_u64(n: u64) -> Self { - // There's no need to reduce `n` to canonical form, as our internal encoding is - // non-canonical, so there's no need for a reduction. - Self::new(n) - } + fn from_canonical_u64(n: u64) -> Self { Self::new(n) } - // Sage: GF(2^64 - 2^32 + 1).multiplicative_generator() - fn generator() -> Self { - Self::new(7) - } + fn from_canonical_usize(n: usize) -> Self { Self::new(n as u64) } + + fn from_wrapped_u32(n: u32) -> Self { + // A u32 must be canonical, plus we don't store canonical encodings anyway, so there's no + // need for a reduction. + Self::new(u64::from(n)) + } + + fn from_wrapped_u64(n: u64) -> Self { + // There's no need to reduce `n` to canonical form, as our internal encoding is + // non-canonical, so there's no need for a reduction. + Self::new(n) + } + + // Sage: GF(2^64 - 2^32 + 1).multiplicative_generator() + fn generator() -> Self { Self::new(7) } } impl Hash for PlutoField { - fn hash(&self, state: &mut H) { - state.write_u64(self.value); - // state.write_u64(self.as_canonical_u64()); - } + fn hash(&self, state: &mut H) { + state.write_u64(self.value); + // state.write_u64(self.as_canonical_u64()); + } } // impl PrimeField for PlutoField { @@ -175,66 +144,68 @@ impl Hash for PlutoField { // } impl Mul for PlutoField { - type Output = Self; + type Output = Self; - fn mul(self, rhs: Self) -> Self { - // reduce128(u128::from(self.value) * u128::from(rhs.value)) - let mul = self.value * rhs.value; - Self::new(mul) - } + fn mul(self, rhs: Self) -> Self { + // reduce128(u128::from(self.value) * u128::from(rhs.value)) + let mul = self.value * rhs.value; + Self::new(mul) + } } impl Product for PlutoField { - fn product>(iter: I) -> Self { - iter.reduce(|x, y| x * y).unwrap_or(Self::one()) - } + fn product>(iter: I) -> Self { + iter.reduce(|x, y| x * y).unwrap_or(Self::one()) + } } impl SubAssign for PlutoField { - fn sub_assign(&mut self, rhs: Self) { - *self = *self - rhs; - } + fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; } } impl AddAssign for PlutoField { - fn add_assign(&mut self, rhs: Self) { - *self = *self + rhs; - } + fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; } } impl MulAssign for PlutoField { - fn mul_assign(&mut self, rhs: Self) { - *self = *self * rhs; - } + fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } } impl Neg for PlutoField { - type Output = Self; + type Output = Self; - fn neg(self) -> Self::Output { - Self::new(Self::ORDER_U64 - self.value) - // Self::new(Self::ORDER_U64 - self.as_canonical_u64()) - } + fn neg(self) -> Self::Output { + Self::new(Self::ORDER_U64 - self.value) + // Self::new(Self::ORDER_U64 - self.as_canonical_u64()) + } } impl Add for PlutoField { - type Output = Self; - fn add(self, rhs: Self) -> Self { - Self::new(self.value + rhs.value) - } + type Output = Self; + + fn add(self, rhs: Self) -> Self { + let mut sum = self.value + rhs.value; + let (corr_sum, over) = sum.overflowing_sub(PLUTO_FIELD_PRIME); + if !over { + sum = corr_sum; + } + Self { value: sum } + } } impl Sum for PlutoField { - fn sum>(iter: I) -> Self { - iter.reduce(|x, y| x + y).unwrap_or(Self::zero()) - } + fn sum>(iter: I) -> Self { + iter.reduce(|x, y| x + y).unwrap_or(Self::zero()) + } } impl Sub for PlutoField { - type Output = Self; - - fn sub(self, rhs: Self) -> Self { - let diff = self.value-rhs.value; - Self::new(diff) - } -} \ No newline at end of file + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + let (mut diff, over) = self.value.overflowing_sub(rhs.value); + let corr = if over { PLUTO_FIELD_PRIME } else { 0 }; + diff = diff.wrapping_add(corr); + Self::new(diff) + } +} diff --git a/src/main.rs b/src/main.rs index f0bc6870..04d124fc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,6 @@ - use ronkathon::field::PlutoField; fn main() { - let f = PlutoField::new(1); - println!("hello field={:?}", f); -} \ No newline at end of file + let f = PlutoField::new(1); + println!("hello field={:?}", f); +}