Skip to content

Commit

Permalink
feat/reed-solomon code (#79)
Browse files Browse the repository at this point in the history
* wip: encoding and decoding

* feat: basic decoding

* cleanup

* improve RS tests

* add README for `codes` module

* fix: broken CI
  • Loading branch information
Autoparallel authored May 30, 2024
1 parent d5397f7 commit efb0225
Show file tree
Hide file tree
Showing 6 changed files with 311 additions and 5 deletions.
24 changes: 20 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ repository ="https://github.com/thor314/ronkathon"
version ="0.1.0"

[dependencies]
rand ="0.8.5"
rand ="0.8.5"
itertools="0.13.0"

[dev-dependencies]
rstest ="0.19.0"
Expand Down
66 changes: 66 additions & 0 deletions src/codes/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Codes
In some cryptographic protocols, we need to do encoding and decoding.
Specifically, we implement the Reed-Solomon encoding and decoding in this module.

## Reed-Solomon Encoding
The Reed-Solomon encoding is a kind of error-correcting code that works by oversampling the input data and adding redundant symbols to the input data.
Our specific case takes a `Message<K, P>` that is a list of `K` field elements for the `PrimeField<P>`.
We can then call the `encode::<N>()` method on the `Message<K, P>` to get a `Codeword<N, K, P>` that is an encoded version of message with redundancy added.

First, we create a polynomial in `Monomial` form from our messsage by having each element of the message be a coefficient of the polynomial.
To do this encoding, we get the `N`th root of unity in the field and evaluate the polynomial at each of the powers of that root of unity.
We then store these evaluations in the codeword along with the point at which they were evaluated.
In effect, this is putting polynomial into the Lagrange basis where the node points are the `N`th roots of unity.

That is to say, we have a message $M = (m_0, m_1, \ldots, m_{K-1})$ and we encode it into a codeword $C = (c_0, c_1, \ldots, c_{N-1})$ where:
$$
\begin{align*}
c_i &= \sum_{j=0}^{K-1} m_j \cdot \omega^{ij}
\end{align*}
$$
where $\omega$ is the `N`th root of unity.

## Reed-Solomon Decoding
Given we have a `Codeword<M, K, P>`, we can call the `decode()` method to get a `Message<K, P>` that is the original message so long as the assertion `M>=N` holds.
Doing the decoding now just requires us to go from the Lagrange basis back to the monomial basis.
For example, for a degree 2 polynomial, we have:
$$
\begin{align*}
\ell_0(x) &= \frac{(x - \omega)(x - \omega^2)}{(1 - \omega)(1 - \omega^2)}\\
\ell_1(x) &= \frac{(x - 1)(x - \omega^2)}{(\omega - 1)(\omega - \omega^2)}\\
\ell_2(x) &= \frac{(x - 1)(x - \omega)}{(\omega^2- 1)(\omega^2 - \omega)}
\end{align*}
$$
where $\ell_i(x)$ is the $i$th Lagrange basis polynomial.

Effectively, we just need to expand out our codeword in this basis and collect terms:
$$
\begin{align*}
c_0 \ell_0(x) + c_1 \ell_1(x) + c_2 \ell_2(x) &= m_0 + m_1 x + m_2 x^2
\end{align*}
$$
where $m_i$ are the coefficients of the original message.
Note that we can pick any `N` points of the codeword to do this expansion, but we need at least `K` points to get the original message back.
For now, we just assume the code is the same length for the example.

Multiplying out the left hand side we get the constant coefficient as:
$$
\begin{align*}
m_0 = \frac{c_0 \omega \omega^2}{(\omega - 1)(\omega^2 - 1)} + \frac{c_1 (1) \omega^2}{(1 - \omega)(\omega^2 - \omega)} + \frac{c_2 (1) \omega}{(1 - \omega^2)(\omega - \omega^2)}
\end{align*}
$$
the linear coefficient as:
$$
\begin{align*}
-m_1 = \frac{c_0 (\omega + \omega^2)}{(\omega - 1)(\omega^2 - 1)} + \frac{c_1 (1 + \omega^2)}{(1 - \omega)(\omega^2 - \omega)} + \frac{c_2 (1 + \omega)}{(1 - \omega^2)(\omega - \omega^2)}
\end{align*}
$$
the quadratic coefficient as:
$$
\begin{align*}
m_2 = \frac{c_0 }{(\omega - 1)(\omega^2 - 1)} + \frac{c_1
}{(1 - \omega)(\omega^2 - \omega)} + \frac{c_2}{(1 - \omega^2)(\omega - \omega^2)}
\end{align*}
$$

This process was generalized in the `decode()` method.
6 changes: 6 additions & 0 deletions src/codes/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
//! This module contains the implementation of various error correction codes that are used in the
//! library. These codes can be used for other protocols.
use super::*;

pub mod reed_solomon;
216 changes: 216 additions & 0 deletions src/codes/reed_solomon.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
//! This contains an implementation of the Reed-Solomon error correction code.
use std::array;

use itertools::Itertools;

use super::*;

// TODO: We should allow for arbitrary data in the message so long as it can be
// converted into an element of a prime field and decoded the same way.

/// Represents a message that is to be encoded or decoded using the Reed-Solomon algorithm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Message<const K: usize, const P: usize> {
/// The data that is to be encoded.
pub data: [PrimeField<P>; K],
}

/// A [`Codeword`] is a message that has been encoded using the Reed-Solomon algorithm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Codeword<const N: usize, const K: usize, const P: usize> {
/// The data that has been encoded.
pub data: [Coordinate<N, P>; N],
}

/// A [`Coordinate`] represents a point on a polynomial curve with both the x and y coordinates.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Coordinate<const N: usize, const P: usize> {
/// The x-coordinate of the point.
pub x: PrimeField<P>,

/// The y-coordinate of the point.
pub y: PrimeField<P>,
}

impl<const K: usize, const P: usize> Message<K, P> {
/// Creates a new message from the given data.
pub fn new(data: [PrimeField<P>; K]) -> Self { Self { data } }

/// Encodes the message into a [`Codeword`].
pub fn encode<const N: usize>(self) -> Codeword<N, K, P> {
assert_ge::<N, K>();
let primitive_root = PrimeField::<P>::primitive_root_of_unity(N);
let polynomial = Polynomial::from(self);
Codeword {
data: array::from_fn(|pow| Coordinate {
x: primitive_root.pow(pow),
y: polynomial.evaluate(primitive_root.pow(pow)),
}),
}
}

/// Decodes the message from a [`Codeword`].
pub fn decode<const M: usize>(codeword: Codeword<M, K, P>) -> Self {
assert_ge::<M, K>();
let x_values: [PrimeField<P>; K] = {
let mut array = [PrimeField::<P>::ZERO; K];
for (i, x) in codeword.data.iter().map(|c| c.x).take(K).enumerate() {
array[i] = x;
}
array
};

let y_values: [PrimeField<P>; K] = {
let mut array = [PrimeField::<P>::ZERO; K];
for (i, y) in codeword.data.iter().map(|c| c.y).take(K).enumerate() {
array[i] = y;
}
array
};

let mut data = [PrimeField::<P>::ZERO; K];

#[allow(clippy::needless_range_loop)]
// i is the degree of the monomial.
for i in 0..K {
for j in 0..K {
let x_combinations: PrimeField<P> = if i % 2 == 1 {
PrimeField::<P>::ZERO - PrimeField::<P>::ONE
} else {
PrimeField::<P>::ONE
} * x_values
.iter()
.enumerate()
.filter(|&(index, _)| index != j)
.map(|(_, x)| x)
.combinations(K - 1 - i)
.map(|comb| comb.into_iter().copied().product::<PrimeField<P>>())
.sum::<PrimeField<P>>();
let y_combinations = y_values[j];
let numerator = x_combinations * y_combinations;

// this could be put into the x_combinations iter above.
let mut denominator = PrimeField::ONE; // x_values[i];
for k in 0..K {
if k == j {
continue;
}
denominator *= x_values[k] - x_values[j];
}

data[i] += numerator / denominator;
}
}
Message { data }
}
}

const fn assert_ge<const N: usize, const K: usize>() {
assert!(N >= K, "Code size must be greater than or equal to K");
}

impl<const K: usize, const P: usize> From<Message<K, P>> for Polynomial<Monomial, PrimeField<P>> {
fn from(message: Message<K, P>) -> Self { Polynomial::from(message.data) }
}

#[cfg(test)]
mod tests {

// NOTES: When we encode a message to same length, we get the first index correct when we decode.
// Otherwise we are getting the last correct.
use super::*;

// A mersenne prime because why not.
const P: usize = 127;

// Message size.
const K: usize = 3;

// Codeword size which satisfies (127-1) % 7 == 0, so we have roots of unity.
const N: usize = 7;

#[test]
fn encode_same_size() {
// Creat the message from an array using our constants above.
let mut arr = [PrimeField::<P>::ZERO; K];
arr[0] = PrimeField::<P>::new(1);
arr[1] = PrimeField::<P>::new(2);
arr[2] = PrimeField::<P>::new(3);
let message = Message::new(arr);

// Build the codeword from the message.
let codeword = message.encode::<K>();
assert_eq!(codeword.data.len(), K);
assert_eq!(codeword.data[0].x, PrimeField::<P>::new(1));
assert_eq!(codeword.data[1].x, PrimeField::<P>::new(107));
assert_eq!(codeword.data[2].x, PrimeField::<P>::new(19));
assert_eq!(codeword.data[0].y, PrimeField::<P>::new(6));
assert_eq!(codeword.data[1].y, PrimeField::<P>::new(18));
assert_eq!(codeword.data[2].y, PrimeField::<P>::new(106));
}

#[test]
fn encode_larger_size() {
// Creat the message from an array using our constants above.
let mut arr = [PrimeField::<P>::ZERO; K];
arr[0] = PrimeField::<P>::new(1);
arr[1] = PrimeField::<P>::new(2);
arr[2] = PrimeField::<P>::new(3);
let message = Message::new(arr);

// Build the codeword from the message.
let codeword = message.encode::<K>();
assert_eq!(codeword.data.len(), K);
assert_eq!(codeword.data[0].x, PrimeField::<P>::new(1));
assert_eq!(codeword.data[1].x, PrimeField::<P>::new(107));
assert_eq!(codeword.data[2].x, PrimeField::<P>::new(19));
assert_eq!(codeword.data[0].y, PrimeField::<P>::new(6));
assert_eq!(codeword.data[1].y, PrimeField::<P>::new(18));
assert_eq!(codeword.data[2].y, PrimeField::<P>::new(106));
}

#[test]
fn decoding() {
// Creat the message from an array using our constants above.
let mut arr = [PrimeField::<P>::ZERO; K];
arr[0] = PrimeField::<P>::new(1);
arr[1] = PrimeField::<P>::new(2);
arr[2] = PrimeField::<P>::new(3);
let message = Message::new(arr);

// Build the codeword from the message.
let codeword = message.encode::<N>();

// Decode the codeword back into a message.
let decoded = Message::decode::<N>(codeword);

assert_eq!(decoded.data[0], PrimeField::<P>::new(1));
assert_eq!(decoded.data[1], PrimeField::<P>::new(2));
assert_eq!(decoded.data[2], PrimeField::<P>::new(3));
}

#[test]
fn decoding_longer_message() {
// Creat the message from an array using our constants above.
let mut arr = [PrimeField::<P>::ZERO; 5];
arr[0] = PrimeField::<P>::new(1);
arr[1] = PrimeField::<P>::new(2);
arr[2] = PrimeField::<P>::new(3);
arr[3] = PrimeField::<P>::new(4);
arr[4] = PrimeField::<P>::new(5);
let message = Message::new(arr);

// Build the codeword from the message.
let codeword = message.encode::<N>();

// Decode the codeword back into a message.
let decoded = Message::decode::<N>(codeword);

assert_eq!(decoded.data[0], PrimeField::<P>::new(1));
assert_eq!(decoded.data[1], PrimeField::<P>::new(2));
assert_eq!(decoded.data[2], PrimeField::<P>::new(3));
assert_eq!(decoded.data[3], PrimeField::<P>::new(4));
assert_eq!(decoded.data[4], PrimeField::<P>::new(5));
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#![feature(generic_const_exprs)]
#![warn(missing_docs)]

pub mod codes;
pub mod curve;
pub mod ecdsa;
pub mod field;
Expand Down

0 comments on commit efb0225

Please sign in to comment.