Skip to content

Commit 933784e

Browse files
Rust extf mul divmod (#167)
Co-authored-by: feltroidprime <[email protected]>
1 parent 2986b7b commit 933784e

File tree

8 files changed

+195
-84
lines changed

8 files changed

+195
-84
lines changed

hydra/garaga/hints/ecip.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,8 @@ def zk_ecip_hint(
107107
if ec_group_class == G1Point and use_rust:
108108
pts = []
109109
c_id = Bs[0].curve_id
110-
if c_id == CurveID.BLS12_381:
111-
nb = 48
112-
else:
113-
nb = 32
114110
for pt in Bs:
115-
pts.extend([pt.x.to_bytes(nb, "big"), pt.y.to_bytes(nb, "big")])
111+
pts.extend([pt.x, pt.y])
116112
field_type = get_field_type_from_ec_point(Bs[0])
117113
field = get_base_field(c_id.value, field_type)
118114

hydra/garaga/hints/extf_mul.py

+7-22
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import operator
22
from functools import reduce
33

4+
from garaga import garaga_rs
45
from garaga.algebra import ModuloCircuitElement, Polynomial, PyFelt
56
from garaga.definitions import (
67
direct_to_tower,
@@ -19,29 +20,13 @@ def nondeterministic_extension_field_mul_divmod(
1920
curve_id: int,
2021
extension_degree: int,
2122
) -> tuple[list[PyFelt], list[PyFelt]]:
22-
23-
Ps = [Polynomial(P) for P in Ps]
2423
field = get_base_field(curve_id)
25-
26-
P_irr = get_irreducible_poly(curve_id, extension_degree)
27-
28-
z_poly = reduce(operator.mul, Ps) # Π(Pi)
29-
z_polyq, z_polyr = divmod(z_poly, P_irr)
30-
31-
z_polyr_coeffs = z_polyr.get_coeffs()
32-
z_polyq_coeffs = z_polyq.get_coeffs()
33-
# assert len(z_polyq_coeffs) <= (
34-
# extension_degree - 1
35-
# ), f"len z_polyq_coeffs={len(z_polyq_coeffs)}, degree: {z_polyq.degree()}"
36-
assert (
37-
len(z_polyr_coeffs) <= extension_degree
38-
), f"len z_polyr_coeffs={len(z_polyr_coeffs)}, degree: {z_polyr.degree()}"
39-
40-
# Extend polynomials with 0 coefficients to match the expected lengths.
41-
# TODO : pass exact expected max degree when len(Ps)>2.
42-
z_polyq_coeffs += [field(0)] * (extension_degree - 1 - len(z_polyq_coeffs))
43-
z_polyr_coeffs += [field(0)] * (extension_degree - len(z_polyr_coeffs))
44-
24+
ps = [[c.value for c in P] for P in Ps]
25+
q, r = garaga_rs.nondeterministic_extension_field_mul_divmod(
26+
curve_id, extension_degree, ps
27+
)
28+
z_polyq_coeffs = [field(c) for c in q] if len(q) > 0 else [field.zero()]
29+
z_polyr_coeffs = [field(c) for c in r] if len(r) > 0 else [field.zero()]
4530
return (z_polyq_coeffs, z_polyr_coeffs)
4631

4732

tools/garaga_rs/src/ecip/core.rs

+21-50
Original file line numberDiff line numberDiff line change
@@ -3,118 +3,89 @@ use lambdaworks_math::elliptic_curve::short_weierstrass::curves::bls12_381::fiel
33
use lambdaworks_math::elliptic_curve::short_weierstrass::curves::bn_254::field_extension::BN254PrimeField;
44
use lambdaworks_math::field::element::FieldElement;
55
use lambdaworks_math::field::traits::IsPrimeField;
6-
use lambdaworks_math::traits::ByteConversion;
76

87
use crate::ecip::curve::{SECP256K1PrimeField, SECP256R1PrimeField, X25519PrimeField};
98
use crate::ecip::ff::FF;
109
use crate::ecip::g1point::G1Point;
1110
use crate::ecip::rational_function::FunctionFelt;
1211
use crate::ecip::rational_function::RationalFunction;
12+
use crate::io::parse_field_elements_from_list;
1313

1414
use num_bigint::{BigInt, BigUint, ToBigInt};
1515

1616
use super::curve::CurveParamsProvider;
1717

1818
pub fn zk_ecip_hint(
19-
list_bytes: Vec<Vec<u8>>,
19+
list_values: Vec<BigUint>,
2020
list_scalars: Vec<BigUint>,
2121
curve_id: usize,
2222
) -> Result<[Vec<String>; 5], String> {
2323
match curve_id {
2424
0 => {
25-
let list_felts: Vec<FieldElement<BN254PrimeField>> = list_bytes
26-
.into_iter()
27-
.map(|x| {
28-
FieldElement::<BN254PrimeField>::from_bytes_be(&x)
29-
.map_err(|e| format!("Byte conversion error: {:?}", e))
30-
})
31-
.collect::<Result<Vec<FieldElement<BN254PrimeField>>, _>>()?;
25+
let list_felts = parse_field_elements_from_list::<BN254PrimeField>(&list_values)?;
3226

3327
let points: Vec<G1Point<BN254PrimeField>> = list_felts
3428
.chunks(2)
3529
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
3630
.collect();
3731

38-
let scalars: Vec<Vec<i8>> = extract_scalars::<BN254PrimeField>(list_scalars);
39-
Ok(run_ecip::<BN254PrimeField>(points, scalars))
32+
let dss: Vec<Vec<i8>> = construct_digits_vectors::<BN254PrimeField>(list_scalars);
33+
Ok(run_ecip::<BN254PrimeField>(points, dss))
4034
}
4135
1 => {
42-
let list_felts: Vec<FieldElement<BLS12381PrimeField>> = list_bytes
43-
.into_iter()
44-
.map(|x| {
45-
FieldElement::<BLS12381PrimeField>::from_bytes_be(&x)
46-
.map_err(|e| format!("Byte conversion error: {:?}", e))
47-
})
48-
.collect::<Result<Vec<FieldElement<BLS12381PrimeField>>, _>>()?;
36+
let list_felts = parse_field_elements_from_list::<BLS12381PrimeField>(&list_values)?;
4937

5038
let points: Vec<G1Point<BLS12381PrimeField>> = list_felts
5139
.chunks(2)
5240
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
5341
.collect();
5442

55-
let scalars: Vec<Vec<i8>> = extract_scalars::<BLS12381PrimeField>(list_scalars);
56-
Ok(run_ecip::<BLS12381PrimeField>(points, scalars))
43+
let dss: Vec<Vec<i8>> = construct_digits_vectors::<BLS12381PrimeField>(list_scalars);
44+
Ok(run_ecip::<BLS12381PrimeField>(points, dss))
5745
}
5846
2 => {
59-
let list_felts: Vec<FieldElement<SECP256K1PrimeField>> = list_bytes
60-
.into_iter()
61-
.map(|x| {
62-
FieldElement::<SECP256K1PrimeField>::from_bytes_be(&x)
63-
.map_err(|e| format!("Byte conversion error: {:?}", e))
64-
})
65-
.collect::<Result<Vec<FieldElement<SECP256K1PrimeField>>, _>>()?;
47+
let list_felts = parse_field_elements_from_list::<SECP256K1PrimeField>(&list_values)?;
6648

6749
let points: Vec<G1Point<SECP256K1PrimeField>> = list_felts
6850
.chunks(2)
6951
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
7052
.collect();
7153

72-
let scalars: Vec<Vec<i8>> = extract_scalars::<SECP256K1PrimeField>(list_scalars);
73-
Ok(run_ecip::<SECP256K1PrimeField>(points, scalars))
54+
let dss: Vec<Vec<i8>> = construct_digits_vectors::<SECP256K1PrimeField>(list_scalars);
55+
Ok(run_ecip::<SECP256K1PrimeField>(points, dss))
7456
}
7557
3 => {
76-
let list_felts: Vec<FieldElement<SECP256R1PrimeField>> = list_bytes
77-
.into_iter()
78-
.map(|x| {
79-
FieldElement::<SECP256R1PrimeField>::from_bytes_be(&x)
80-
.map_err(|e| format!("Byte conversion error: {:?}", e))
81-
})
82-
.collect::<Result<Vec<FieldElement<SECP256R1PrimeField>>, _>>()?;
58+
let list_felts = parse_field_elements_from_list::<SECP256R1PrimeField>(&list_values)?;
8359

8460
let points: Vec<G1Point<SECP256R1PrimeField>> = list_felts
8561
.chunks(2)
8662
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
8763
.collect();
8864

89-
let scalars: Vec<Vec<i8>> = extract_scalars::<SECP256R1PrimeField>(list_scalars);
90-
Ok(run_ecip::<SECP256R1PrimeField>(points, scalars))
65+
let dss: Vec<Vec<i8>> = construct_digits_vectors::<SECP256R1PrimeField>(list_scalars);
66+
Ok(run_ecip::<SECP256R1PrimeField>(points, dss))
9167
}
9268
4 => {
93-
let list_felts: Vec<FieldElement<X25519PrimeField>> = list_bytes
94-
.into_iter()
95-
.map(|x| {
96-
FieldElement::<X25519PrimeField>::from_bytes_be(&x)
97-
.map_err(|e| format!("Byte conversion error: {:?}", e))
98-
})
99-
.collect::<Result<Vec<FieldElement<X25519PrimeField>>, _>>()?;
69+
let list_felts = parse_field_elements_from_list::<X25519PrimeField>(&list_values)?;
10070

10171
let points: Vec<G1Point<X25519PrimeField>> = list_felts
10272
.chunks(2)
10373
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
10474
.collect();
10575

106-
let scalars: Vec<Vec<i8>> = extract_scalars::<X25519PrimeField>(list_scalars);
107-
Ok(run_ecip::<X25519PrimeField>(points, scalars))
76+
let dss: Vec<Vec<i8>> = construct_digits_vectors::<X25519PrimeField>(list_scalars);
77+
Ok(run_ecip::<X25519PrimeField>(points, dss))
10878
}
10979
_ => Err(String::from("Invalid curve ID")),
11080
}
11181
}
11282

113-
fn extract_scalars<F: IsPrimeField + CurveParamsProvider<F>>(list: Vec<BigUint>) -> Vec<Vec<i8>> {
83+
fn construct_digits_vectors<F: IsPrimeField + CurveParamsProvider<F>>(
84+
list: Vec<BigUint>,
85+
) -> Vec<Vec<i8>> {
11486
let mut dss_ = Vec::new();
11587

116-
for i in 0..list.len() {
117-
let scalar_biguint = list[i].clone();
88+
for scalar_biguint in list {
11889
let neg_3_digits = neg_3_base_le(scalar_biguint);
11990
dss_.push(neg_3_digits);
12091
}

tools/garaga_rs/src/ecip/curve.rs

+28
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@ use lambdaworks_math::field::fields::montgomery_backed_prime_fields::{
55
IsModulus, MontgomeryBackendPrimeField,
66
};
77

8+
use crate::ecip::polynomial::Polynomial;
89
use lambdaworks_math::field::traits::IsPrimeField;
910
use lambdaworks_math::unsigned_integer::element::U256;
1011
use num_bigint::BigUint;
1112
use std::cmp::PartialEq;
13+
use std::collections::HashMap;
1214

1315
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1416
pub enum CurveID {
@@ -72,6 +74,21 @@ pub struct CurveParams<F: IsPrimeField> {
7274
pub g_y: FieldElement<F>,
7375
pub n: FieldElement<F>, // Order of the curve
7476
pub h: u32, // Cofactor
77+
pub irreducible_polys: HashMap<usize, &'static [i8]>,
78+
}
79+
80+
pub fn get_irreducible_poly<F: IsPrimeField + CurveParamsProvider<F>>(
81+
ext_degree: usize,
82+
) -> Polynomial<F> {
83+
let coeffs = (F::get_curve_params().irreducible_polys)[&ext_degree];
84+
fn lift<F: IsPrimeField>(c: i8) -> FieldElement<F> {
85+
if c >= 0 {
86+
FieldElement::from(c as u64)
87+
} else {
88+
-FieldElement::from(-c as u64)
89+
}
90+
}
91+
return Polynomial::new(coeffs.iter().map(|x| lift::<F>(*x)).collect());
7592
}
7693

7794
/// A trait that provides curve parameters for a specific field type.
@@ -99,6 +116,7 @@ impl CurveParamsProvider<SECP256K1PrimeField> for SECP256K1PrimeField {
99116
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141",
100117
),
101118
h: 1,
119+
irreducible_polys: HashMap::from([]), // Provide appropriate values here
102120
}
103121
}
104122
}
@@ -122,6 +140,7 @@ impl CurveParamsProvider<SECP256R1PrimeField> for SECP256R1PrimeField {
122140
"FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551",
123141
),
124142
h: 1,
143+
irreducible_polys: HashMap::from([]), // Provide appropriate values here
125144
}
126145
}
127146
}
@@ -143,6 +162,7 @@ impl CurveParamsProvider<X25519PrimeField> for X25519PrimeField {
143162
"1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED",
144163
),
145164
h: 8,
165+
irreducible_polys: HashMap::from([]), // Provide appropriate values here
146166
}
147167
}
148168
}
@@ -158,6 +178,10 @@ impl CurveParamsProvider<BN254PrimeField> for BN254PrimeField {
158178
g_y: FieldElement::from_hex_unchecked("2"), // Replace with actual 'g_y'
159179
n: FieldElement::from_hex_unchecked("1"), // Replace with actual 'n'
160180
h: 1, // Replace with actual 'h'
181+
irreducible_polys: HashMap::from([
182+
(6, [82, 0, 0, -18, 0, 0, 1].as_slice()),
183+
(12, [82, 0, 0, 0, 0, 0, -18, 0, 0, 0, 0, 0, 1].as_slice()),
184+
]),
161185
}
162186
}
163187
}
@@ -173,6 +197,10 @@ impl CurveParamsProvider<BLS12381PrimeField> for BLS12381PrimeField {
173197
g_y: FieldElement::from_hex_unchecked("2"), // Replace with actual 'g_y'
174198
n: FieldElement::from_hex_unchecked("1"), // Replace with actual 'n'
175199
h: 1, // Replace with actual 'h'
200+
irreducible_polys: HashMap::from([
201+
(6, [2, 0, 0, -2, 0, 0, 1].as_slice()),
202+
(12, [2, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 1].as_slice()),
203+
]),
176204
}
177205
}
178206
}

tools/garaga_rs/src/ecip/polynomial.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ impl<F: IsPrimeField> Polynomial<F> {
8080
Polynomial::new(vec![FieldElement::<F>::zero()])
8181
}
8282

83+
pub fn one() -> Self {
84+
Polynomial::new(vec![FieldElement::<F>::one()])
85+
}
86+
8387
pub fn mul_with_ref(&self, other: &Polynomial<F>) -> Polynomial<F> {
8488
if self.degree() == -1 || other.degree() == -1 {
8589
return Polynomial::zero();
@@ -142,7 +146,7 @@ impl<F: IsPrimeField> Polynomial<F> {
142146
for (i, coeff) in self.coefficients.iter().enumerate().skip(1) {
143147
let u_64 = i as u64;
144148
let degree = &FieldElement::<F>::from(u_64);
145-
new_coeffs[i - 1] = *(&coeff) * degree;
149+
new_coeffs[i - 1] = coeff * degree;
146150
}
147151
Polynomial::new(new_coeffs)
148152
}

tools/garaga_rs/src/extf_mul.rs

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
use crate::ecip::{
2+
curve::{get_irreducible_poly, CurveParamsProvider},
3+
polynomial::{pad_with_zero_coefficients_to_length, Polynomial},
4+
};
5+
use lambdaworks_math::field::traits::IsPrimeField;
6+
7+
// Returns (Q(X), R(X)) such that Π(Pi)(X) = Q(X) * P_irr(X) + R(X), for a given curve and extension degree.
8+
// R(X) is the result of the multiplication in the extension field.
9+
// Q(X) is used for verification.
10+
pub fn nondeterministic_extension_field_mul_divmod<F: IsPrimeField + CurveParamsProvider<F>>(
11+
ext_degree: usize,
12+
ps: Vec<Polynomial<F>>,
13+
) -> (Polynomial<F>, Polynomial<F>) {
14+
let mut z_poly = Polynomial::one();
15+
for poly in ps {
16+
z_poly = z_poly.mul_with_ref(&poly);
17+
}
18+
19+
let p_irr = get_irreducible_poly(ext_degree);
20+
21+
let (z_polyq, mut z_polyr) = z_poly.divmod(&p_irr);
22+
assert!(z_polyr.coefficients.len() <= ext_degree);
23+
24+
// Extend polynomial with 0 coefficients to match the expected length.
25+
if z_polyr.coefficients.len() < ext_degree {
26+
pad_with_zero_coefficients_to_length(&mut z_polyr, ext_degree);
27+
}
28+
29+
(z_polyq, z_polyr)
30+
}

tools/garaga_rs/src/io.rs

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
use lambdaworks_math::field::element::FieldElement;
2+
use lambdaworks_math::field::traits::IsPrimeField;
3+
use lambdaworks_math::traits::ByteConversion;
4+
use num_bigint::BigUint;
5+
6+
pub fn parse_field_elements_from_list<F: IsPrimeField>(
7+
coeffs: &[BigUint],
8+
) -> Result<Vec<FieldElement<F>>, String>
9+
where
10+
FieldElement<F>: ByteConversion,
11+
{
12+
let length = (F::field_bit_size() + 7) / 8;
13+
coeffs
14+
.iter()
15+
.map(|x| {
16+
let bytes = x.to_bytes_be();
17+
let pad_length = length.saturating_sub(bytes.len());
18+
let mut padded_bytes = vec![0u8; pad_length];
19+
padded_bytes.extend(bytes);
20+
FieldElement::from_bytes_be(&padded_bytes)
21+
.map_err(|e| format!("Byte conversion error: {:?}", e))
22+
})
23+
.collect()
24+
}

0 commit comments

Comments
 (0)