dependencies
thor314 committed May 1, 2024
1 parent d8f688b commit 14de622
Showing 22 changed files with 2,344 additions and 10 deletions.
16 changes: 6 additions & 10 deletions Cargo.toml
@@ -1,11 +1,7 @@
-[package]
-authors = ["Pluto Authors"]
-description = """ronkathon"""
-edition = "2021"
-license = "Apache2.0 OR MIT"
-name = "ronkathon"
-repository = "https://github.com/thor314/ronkathon"
-version = "0.1.0"
+[workspace]
 
-[dependencies]
-anyhow = "1.0"
+members = [
+  "ronkathon",
+  "field",
+  "util"
+]
14 changes: 14 additions & 0 deletions field/Cargo.toml
@@ -0,0 +1,14 @@
[package]
name = "p3-field"
version = "0.1.0"
edition = "2021"
license = "MIT OR Apache-2.0"

[dependencies]
p3-util = { path = "../util" }
num-bigint = { version = "0.4.3", default-features = false }
num-traits = { version = "0.2.18", default-features = false }

itertools = "0.12.0"
rand = "0.8.5"
serde = { version = "1.0", default-features = false, features = ["derive"] }
203 changes: 203 additions & 0 deletions field/src/array.rs
@@ -0,0 +1,203 @@
use core::array;
use core::iter::{Product, Sum};
use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};

use crate::{AbstractField, Field};

#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FieldArray<F: Field, const N: usize>(pub [F; N]);

impl<F: Field, const N: usize> Default for FieldArray<F, N> {
    fn default() -> Self {
        Self::zero()
    }
}

impl<F: Field, const N: usize> From<F> for FieldArray<F, N> {
    fn from(val: F) -> Self {
        [val; N].into()
    }
}

impl<F: Field, const N: usize> From<[F; N]> for FieldArray<F, N> {
    fn from(arr: [F; N]) -> Self {
        Self(arr)
    }
}

impl<F: Field, const N: usize> AbstractField for FieldArray<F, N> {
    type F = F;

    fn zero() -> Self {
        FieldArray([F::zero(); N])
    }
    fn one() -> Self {
        FieldArray([F::one(); N])
    }
    fn two() -> Self {
        FieldArray([F::two(); N])
    }
    fn neg_one() -> Self {
        FieldArray([F::neg_one(); N])
    }

    #[inline]
    fn from_f(f: Self::F) -> Self {
        f.into()
    }

    fn from_bool(b: bool) -> Self {
        [F::from_bool(b); N].into()
    }

    fn from_canonical_u8(n: u8) -> Self {
        [F::from_canonical_u8(n); N].into()
    }

    fn from_canonical_u16(n: u16) -> Self {
        [F::from_canonical_u16(n); N].into()
    }

    fn from_canonical_u32(n: u32) -> Self {
        [F::from_canonical_u32(n); N].into()
    }

    fn from_canonical_u64(n: u64) -> Self {
        [F::from_canonical_u64(n); N].into()
    }

    fn from_canonical_usize(n: usize) -> Self {
        [F::from_canonical_usize(n); N].into()
    }

    fn from_wrapped_u32(n: u32) -> Self {
        [F::from_wrapped_u32(n); N].into()
    }

    fn from_wrapped_u64(n: u64) -> Self {
        [F::from_wrapped_u64(n); N].into()
    }

    fn generator() -> Self {
        [F::generator(); N].into()
    }
}

impl<F: Field, const N: usize> Add for FieldArray<F, N> {
    type Output = Self;

    #[inline]
    fn add(self, rhs: Self) -> Self::Output {
        array::from_fn(|i| self.0[i] + rhs.0[i]).into()
    }
}

impl<F: Field, const N: usize> Add<F> for FieldArray<F, N> {
    type Output = Self;

    #[inline]
    fn add(self, rhs: F) -> Self::Output {
        self.0.map(|x| x + rhs).into()
    }
}

impl<F: Field, const N: usize> AddAssign for FieldArray<F, N> {
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        self.0.iter_mut().zip(rhs.0).for_each(|(x, y)| *x += y);
    }
}

impl<F: Field, const N: usize> AddAssign<F> for FieldArray<F, N> {
    #[inline]
    fn add_assign(&mut self, rhs: F) {
        self.0.iter_mut().for_each(|x| *x += rhs);
    }
}

impl<F: Field, const N: usize> Sub for FieldArray<F, N> {
    type Output = Self;

    #[inline]
    fn sub(self, rhs: Self) -> Self::Output {
        array::from_fn(|i| self.0[i] - rhs.0[i]).into()
    }
}

impl<F: Field, const N: usize> Sub<F> for FieldArray<F, N> {
    type Output = Self;

    #[inline]
    fn sub(self, rhs: F) -> Self::Output {
        self.0.map(|x| x - rhs).into()
    }
}

impl<F: Field, const N: usize> SubAssign for FieldArray<F, N> {
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        self.0.iter_mut().zip(rhs.0).for_each(|(x, y)| *x -= y);
    }
}

impl<F: Field, const N: usize> SubAssign<F> for FieldArray<F, N> {
    #[inline]
    fn sub_assign(&mut self, rhs: F) {
        self.0.iter_mut().for_each(|x| *x -= rhs);
    }
}

impl<F: Field, const N: usize> Neg for FieldArray<F, N> {
    type Output = Self;

    #[inline]
    fn neg(self) -> Self::Output {
        self.0.map(|x| -x).into()
    }
}

impl<F: Field, const N: usize> Mul for FieldArray<F, N> {
    type Output = Self;

    #[inline]
    fn mul(self, rhs: Self) -> Self::Output {
        array::from_fn(|i| self.0[i] * rhs.0[i]).into()
    }
}

impl<F: Field, const N: usize> Mul<F> for FieldArray<F, N> {
    type Output = Self;

    #[inline]
    fn mul(self, rhs: F) -> Self::Output {
        self.0.map(|x| x * rhs).into()
    }
}

impl<F: Field, const N: usize> MulAssign for FieldArray<F, N> {
    #[inline]
    fn mul_assign(&mut self, rhs: Self) {
        self.0.iter_mut().zip(rhs.0).for_each(|(x, y)| *x *= y);
    }
}

impl<F: Field, const N: usize> MulAssign<F> for FieldArray<F, N> {
    #[inline]
    fn mul_assign(&mut self, rhs: F) {
        self.0.iter_mut().for_each(|x| *x *= rhs);
    }
}

impl<F: Field, const N: usize> Sum for FieldArray<F, N> {
    #[inline]
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.reduce(|lhs, rhs| lhs + rhs).unwrap_or(Self::zero())
    }
}

impl<F: Field, const N: usize> Product for FieldArray<F, N> {
    #[inline]
    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.reduce(|lhs, rhs| lhs * rhs).unwrap_or(Self::one())
    }
}
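
Since FieldArray<F, N> implements AbstractField itself, it can be substituted into generic field code to process N elements in lock-step. As a rough, hypothetical sketch (not part of this commit), a Horner-style polynomial evaluation written against these traits might look like the following; the Copy, Add, and Mul bounds are spelled out explicitly here rather than assumed to be supertraits of AbstractField.

use core::ops::{Add, Mul};

use crate::{AbstractField, Field};

/// Evaluate a polynomial (coefficients in ascending order, over the base field F) at a
/// point of type AF using Horner's rule. AF may be F itself, or FieldArray<F, N>, in
/// which case the element-wise impls above evaluate the polynomial at N points at once.
fn horner<F, AF>(coeffs: &[F], x: AF) -> AF
where
    F: Field + Copy,
    AF: AbstractField<F = F> + Copy + Add<Output = AF> + Mul<Output = AF>,
{
    coeffs
        .iter()
        .rev()
        .fold(AF::zero(), |acc, &c| acc * x + AF::from_f(c))
}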
99 changes: 99 additions & 0 deletions field/src/batch_inverse.rs
@@ -0,0 +1,99 @@
use alloc::vec;
use alloc::vec::Vec;

use crate::field::Field;

/// Batch multiplicative inverses via Montgomery's trick. At a high level, we invert the
/// product of the given field elements, then derive the individual inverses from that via
/// multiplication.
///
/// The usual Montgomery trick involves calculating an array of cumulative products,
/// resulting in a long dependency chain. To increase instruction-level parallelism, we
/// compute WIDTH separate cumulative product arrays that only meet at the end.
///
/// # Panics
/// Might panic if asserts or unwraps uncover a bug.
pub fn batch_multiplicative_inverse<F: Field>(x: &[F]) -> Vec<F> {
    // Higher WIDTH increases instruction-level parallelism, but too high a value will
    // cause us to run out of registers.
    const WIDTH: usize = 4;
    // JN note: WIDTH is 4. The code is specialized to this value and will need
    // modification if it is changed. I tried to make it more generic, but Rust's const
    // generics are not yet good enough.

    // Handle small cases explicitly. The code below is repetitive but concise, and the
    // branches should be very predictable.
    let n = x.len();
    if n == 0 {
        return Vec::new();
    } else if n == 1 {
        return vec![x[0].inverse()];
    } else if n == 2 {
        let x01 = x[0] * x[1];
        let x01inv = x01.inverse();
        return vec![x01inv * x[1], x01inv * x[0]];
    } else if n == 3 {
        let x01 = x[0] * x[1];
        let x012 = x01 * x[2];
        let x012inv = x012.inverse();
        let x01inv = x012inv * x[2];
        return vec![x01inv * x[1], x01inv * x[0], x012inv * x01];
    }
    debug_assert!(n >= WIDTH);

    // buf is reused for a few things to save allocations.
    // Fill buf with cumulative products of x, only taking every 4th value. Concretely, buf
    // will be [
    //   x[0], x[1], x[2], x[3],
    //   x[0] * x[4], x[1] * x[5], x[2] * x[6], x[3] * x[7],
    //   x[0] * x[4] * x[8], x[1] * x[5] * x[9], x[2] * x[6] * x[10], x[3] * x[7] * x[11],
    //   ...
    // ].
    // If n is not a multiple of WIDTH, the result is truncated from the end. For example,
    // for n == 5, we get [x[0], x[1], x[2], x[3], x[0] * x[4]].
    let mut buf: Vec<F> = Vec::with_capacity(n);
    // cumul_prod holds the last WIDTH elements of buf. This is redundant, but it's how we
    // convince LLVM to keep the values in registers.
    let mut cumul_prod: [F; WIDTH] = x[..WIDTH].try_into().unwrap();
    buf.extend(cumul_prod);
    for (i, &xi) in x[WIDTH..].iter().enumerate() {
        cumul_prod[i % WIDTH] *= xi;
        buf.push(cumul_prod[i % WIDTH]);
    }
    debug_assert_eq!(buf.len(), n);

    let mut a_inv = {
        // This is where the four dependency chains meet.
        // Take the last four elements of buf and invert them all.
        let c01 = cumul_prod[0] * cumul_prod[1];
        let c23 = cumul_prod[2] * cumul_prod[3];
        let c0123 = c01 * c23;
        let c0123inv = c0123.inverse();
        let c01inv = c0123inv * c23;
        let c23inv = c0123inv * c01;
        [
            c01inv * cumul_prod[1],
            c01inv * cumul_prod[0],
            c23inv * cumul_prod[3],
            c23inv * cumul_prod[2],
        ]
    };

    for i in (WIDTH..n).rev() {
        // buf[i - WIDTH] has not been written to by this loop, so it equals
        // x[i % WIDTH] * x[i % WIDTH + WIDTH] * ... * x[i - WIDTH].
        buf[i] = buf[i - WIDTH] * a_inv[i % WIDTH];
        // buf[i] now holds the inverse of x[i].
        a_inv[i % WIDTH] *= x[i];
    }
    for i in (0..WIDTH).rev() {
        buf[i] = a_inv[i];
    }

    for (&bi, &xi) in buf.iter().zip(x) {
        // Sanity check only.
        debug_assert_eq!(bi * xi, F::one());
    }

    buf
}
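
For comparison, the "usual" single-chain form of Montgomery's trick mentioned in the doc comment can be written as below. This is a minimal, hypothetical sketch (not part of this commit): one forward pass of prefix products, a single inversion, and one backward pass, roughly 3n multiplications in total, but with the prefix products forming one long serial dependency chain, which is exactly what the WIDTH-way split above avoids.

use alloc::vec;
use alloc::vec::Vec;

use crate::field::Field;

/// Single-chain Montgomery batch inversion. Like the function above, this assumes all
/// inputs are nonzero.
fn batch_multiplicative_inverse_serial<F: Field>(x: &[F]) -> Vec<F> {
    // prefix[i] = x[0] * x[1] * ... * x[i]
    let mut prefix = Vec::with_capacity(x.len());
    let mut acc = F::one();
    for &xi in x {
        acc *= xi;
        prefix.push(acc);
    }

    // Walk backwards maintaining suffix_inv = inverse(x[0] * ... * x[i]), so that
    // inverse(x[i]) = prefix[i - 1] * suffix_inv.
    let mut out = vec![F::zero(); x.len()];
    let mut suffix_inv = acc.inverse();
    for i in (0..x.len()).rev() {
        out[i] = if i == 0 { suffix_inv } else { prefix[i - 1] * suffix_inv };
        suffix_inv *= x[i];
    }
    out
}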