From 2b0e2da5773845ee4480e89d8a3cb49113f642f2 Mon Sep 17 00:00:00 2001 From: "jtriley.eth" Date: Thu, 4 Jul 2024 22:19:28 -0500 Subject: [PATCH] binfield: usize -> enum --- src/field/binary_towers/extension.rs | 9 +++-- src/field/binary_towers/mod.rs | 57 +++++++++++++++++++--------- src/field/binary_towers/tests.rs | 12 +++--- 3 files changed, 51 insertions(+), 27 deletions(-) diff --git a/src/field/binary_towers/extension.rs b/src/field/binary_towers/extension.rs index 9387893f..ea588bdc 100644 --- a/src/field/binary_towers/extension.rs +++ b/src/field/binary_towers/extension.rs @@ -300,7 +300,7 @@ pub(super) fn multiply(a: &[BinaryField], b: &[BinaryField], k: usize) -> Vec Vec { pub(super) fn to_bool_vec(mut num: u64, length: usize) -> Vec { let mut result = Vec::new(); while num > 0 { - result.push(BinaryField::new(((num & 1) != 0) as u8)); + result.push(match num & 1 { + 0 => BinaryField::Zero, + _ => BinaryField::One, + }); num >>= 1; } - result.extend(std::iter::repeat(BinaryField::new(0)).take(length - result.len())); + result.extend(std::iter::repeat(BinaryField::Zero).take(length - result.len())); result } diff --git a/src/field/binary_towers/mod.rs b/src/field/binary_towers/mod.rs index 0149b04a..1e9b7550 100644 --- a/src/field/binary_towers/mod.rs +++ b/src/field/binary_towers/mod.rs @@ -13,41 +13,51 @@ pub use extension::BinaryTowers; /// binary field containing element `{0,1}` #[derive(Debug, Default, Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub struct BinaryField(u8); - -impl BinaryField { - /// create new binary field element - pub const fn new(value: u8) -> Self { - debug_assert!(value < 2, "value should be less than 2"); - Self(value) - } +pub enum BinaryField { + /// binary field element `0` + #[default] + Zero, + /// binary field element `1` + One, } impl FiniteField for BinaryField { - const ONE: Self = BinaryField(1); + const ONE: Self = BinaryField::One; const ORDER: usize = 2; const PRIMITIVE_ELEMENT: Self = Self::ONE; - const ZERO: Self = BinaryField(0); + const ZERO: Self = BinaryField::Zero; fn inverse(&self) -> Option { - if *self == Self::ZERO { - return None; + match *self { + Self::Zero => None, + Self::One => Some(Self::One), } - Some(*self) } fn pow(self, _: usize) -> Self { self } } impl From for BinaryField { - fn from(value: usize) -> Self { Self::new(value as u8) } + fn from(value: usize) -> Self { + match value { + 0 => BinaryField::Zero, + 1 => BinaryField::One, + _ => panic!("Invalid `usize` value. Must be 0 or 1."), + } + } } impl Add for BinaryField { type Output = Self; #[allow(clippy::suspicious_arithmetic_impl)] - fn add(self, rhs: Self) -> Self::Output { BinaryField::new(self.0 ^ rhs.0) } + fn add(self, rhs: Self) -> Self::Output { + if self == rhs { + Self::ZERO + } else { + Self::ONE + } + } } impl AddAssign for BinaryField { @@ -64,7 +74,13 @@ impl Sub for BinaryField { type Output = Self; #[allow(clippy::suspicious_arithmetic_impl)] - fn sub(self, rhs: Self) -> Self::Output { BinaryField(self.0 ^ rhs.0) } + fn sub(self, rhs: Self) -> Self::Output { + if self == rhs { + Self::ZERO + } else { + Self::ONE + } + } } impl SubAssign for BinaryField { @@ -81,7 +97,12 @@ impl Mul for BinaryField { type Output = Self; #[allow(clippy::suspicious_arithmetic_impl)] - fn mul(self, rhs: Self) -> Self::Output { BinaryField(self.0 & rhs.0) } + fn mul(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (Self::One, Self::One) => Self::ONE, + _ => Self::ZERO, + } + } } impl MulAssign for BinaryField { @@ -98,7 +119,7 @@ impl Div for BinaryField { type Output = Self; #[allow(clippy::suspicious_arithmetic_impl)] - fn div(self, rhs: Self) -> Self::Output { self * rhs.inverse().unwrap() } + fn div(self, rhs: Self) -> Self::Output { self * rhs.inverse().expect("divide by zero") } } impl DivAssign for BinaryField { diff --git a/src/field/binary_towers/tests.rs b/src/field/binary_towers/tests.rs index eaaf0091..0db5bc08 100644 --- a/src/field/binary_towers/tests.rs +++ b/src/field/binary_towers/tests.rs @@ -133,7 +133,7 @@ pub(super) fn num_digits(n: u64) -> usize { fn from_bool_vec(num: Vec) -> u64 { let mut result: u64 = 0; for (i, &bit) in num.iter().rev().enumerate() { - if bit.0 == 1 { + if bit == BinaryField::One { result |= 1 << (num.len() - 1 - i); } } @@ -146,20 +146,20 @@ fn from_bool_vec(num: Vec) -> u64 { #[should_panic] #[case(1, 0)] fn binary_field_arithmetic(#[case] a: usize, #[case] b: usize) { - let arg1 = BinaryField::new(a as u8); - let arg2 = BinaryField::new(b as u8); + let arg1 = BinaryField::from(a); + let arg2 = BinaryField::from(b); let a_test = TestBinaryField::new(a); let b_test = TestBinaryField::new(b); - assert_eq!((arg1 + arg2).0, (a_test + b_test).value as u8); + assert_eq!((arg1 + arg2), BinaryField::from((a_test + b_test).value)); assert_eq!(arg1 - arg2, arg1 + arg2); - assert_eq!((arg1 * arg2).0, (a_test * b_test).value as u8); + assert_eq!((arg1 * arg2), BinaryField::from((a_test * b_test).value)); let inv_res = arg2.inverse(); assert!(inv_res.is_some()); assert_eq!(inv_res.unwrap(), arg2); - assert_eq!((arg1 / arg2).0, (a_test / b_test).value as u8); + assert_eq!((arg1 / arg2), BinaryField::from((a_test / b_test).value)); } #[rstest]