Skip to content

Commit ef4a9fe

Browse files
author
enpsi
committed
feat: GKR2 verifier
1 parent 00c0cf9 commit ef4a9fe

File tree

6 files changed

+322
-39
lines changed

6 files changed

+322
-39
lines changed

gkr/src/prover/gkr_square.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pub fn gkr_square_prove<C: GKRConfig, T: Transcript<C::ChallengeField>>(
2626
for _i in 0..circuit.layers.last().unwrap().output_var_num {
2727
rz0.push(transcript.generate_challenge_field_element());
2828
}
29+
log::trace!("Initial rz0: {:?}", rz0);
2930

3031
let mut r_simd = vec![];
3132
for _i in 0..C::get_field_pack_size().trailing_zeros() {
@@ -74,7 +75,7 @@ pub fn gkr_square_prove<C: GKRConfig, T: Transcript<C::ChallengeField>>(
7475
log::trace!("Layer {} proved", i);
7576
log::trace!("rz0.0: {:?}", rz0[0]);
7677
log::trace!("rz0.1: {:?}", rz0[1]);
77-
log::trace!("rz0.2: {:?}", rz0[2]);
78+
// log::trace!("rz0.2: {:?}", rz0[2]);
7879
}
7980

8081
end_timer!(timer);

gkr/src/tests/gkr_correctness.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ fn do_prove_verify<C: GKRConfig>(config: Config<C>, circuit: &mut Circuit<C>) {
307307
let (mut claimed_v, proof) = prover.prove(circuit);
308308

309309
// Verify
310-
// let verifier = Verifier::new(&config);
311-
// let public_input = vec![];
312-
// assert!(verifier.verify(circuit, &public_input, &mut claimed_v, &proof))
310+
let verifier = Verifier::new(&config);
311+
let public_input = vec![];
312+
assert!(verifier.verify(circuit, &public_input, &mut claimed_v, &proof))
313313
}

gkr/src/verifier.rs

+74-31
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::{
66
use arith::{Field, FieldSerde};
77
use ark_std::{end_timer, start_timer};
88
use circuit::{Circuit, CircuitLayer};
9-
use config::{Config, FiatShamirHashType, GKRConfig, PolynomialCommitmentType};
9+
use config::{Config, FiatShamirHashType, GKRConfig, GKRScheme, PolynomialCommitmentType};
1010
use sumcheck::{GKRVerifierHelper, VerifierScratchPad};
1111
use transcript::{
1212
BytesHashTranscript, FieldHashTranscript, Keccak256hasher, MIMCHasher, Proof, SHA256hasher,
@@ -17,6 +17,9 @@ use transcript::{
1717
use crate::grind;
1818
use crate::RawCommitment;
1919

20+
mod gkr_square;
21+
pub use gkr_square::gkr_square_verify;
22+
2023
#[inline(always)]
2124
fn verify_sumcheck_step<C: GKRConfig, T: Transcript<C::ChallengeField>>(
2225
mut proof_reader: impl Read,
@@ -31,17 +34,27 @@ fn verify_sumcheck_step<C: GKRConfig, T: Transcript<C::ChallengeField>>(
3134
ps.push(C::ChallengeField::deserialize_from(&mut proof_reader).unwrap());
3235
transcript.append_field_element(&ps[i]);
3336
}
37+
log::trace!("ps {:?}", ps);
3438

3539
let r = transcript.generate_challenge_field_element();
40+
log::trace!("r {:?}", r);
3641
randomness_vec.push(r);
3742

3843
let verified = (ps[0] + ps[1]) == *claimed_sum;
44+
log::trace!("verified {:?}", verified);
45+
log::trace!("claimed_sum {:?}", claimed_sum);
46+
log::trace!("ps[0] + ps[1] {:?}", ps[0] + ps[1]);
3947

4048
if degree == 2 {
4149
*claimed_sum = GKRVerifierHelper::degree_2_eval(&ps, r, sp);
4250
} else if degree == 3 {
4351
*claimed_sum = GKRVerifierHelper::degree_3_eval(&ps, r, sp);
52+
} else if degree == 6 {
53+
*claimed_sum = GKRVerifierHelper::degree_6_eval(&ps, r, sp);
54+
} else {
55+
panic!("unsupported degree");
4456
}
57+
log::trace!("next claimed_sum {:?}", claimed_sum);
4558

4659
verified
4760
}
@@ -287,38 +300,68 @@ impl<C: GKRConfig> Verifier<C> {
287300

288301
circuit.fill_rnd_coefs(transcript);
289302

290-
let (mut verified, rz0, rz1, r_simd, r_mpi, claimed_v0, claimed_v1) = gkr_verify(
291-
&self.config,
292-
circuit,
293-
public_input,
294-
claimed_v,
295-
transcript,
296-
&mut cursor,
297-
);
298-
299-
log::info!("GKR verification: {}", verified);
300-
301-
match self.config.polynomial_commitment_type {
302-
PolynomialCommitmentType::Raw => {
303-
// for Raw, no need to load from proof
304-
log::trace!("rz0.size() = {}", rz0.len());
305-
log::trace!("Poly_vals.size() = {}", commitment.poly_vals.len());
306-
307-
let v1 = commitment.mpi_verify(&rz0, &r_simd, &r_mpi, claimed_v0);
308-
verified &= v1;
309-
310-
if rz1.is_some() {
311-
let v2 = commitment.mpi_verify(
312-
rz1.as_ref().unwrap(),
313-
&r_simd,
314-
&r_mpi,
315-
claimed_v1.unwrap(),
316-
);
317-
verified &= v2;
303+
let verified = match self.config.gkr_scheme {
304+
GKRScheme::Vanilla => {
305+
let (mut verified, rz0, rz1, r_simd, r_mpi, claimed_v0, claimed_v1) = gkr_verify(
306+
&self.config,
307+
circuit,
308+
public_input,
309+
claimed_v,
310+
transcript,
311+
&mut cursor,
312+
);
313+
314+
log::info!("GKR verification: {}", verified);
315+
316+
match self.config.polynomial_commitment_type {
317+
PolynomialCommitmentType::Raw => {
318+
// for Raw, no need to load from proof
319+
log::trace!("rz0.size() = {}", rz0.len());
320+
log::trace!("Poly_vals.size() = {}", commitment.poly_vals.len());
321+
322+
let v1 = commitment.mpi_verify(&rz0, &r_simd, &r_mpi, claimed_v0);
323+
verified &= v1;
324+
325+
if rz1.is_some() {
326+
let v2 = commitment.mpi_verify(
327+
rz1.as_ref().unwrap(),
328+
&r_simd,
329+
&r_mpi,
330+
claimed_v1.unwrap(),
331+
);
332+
verified &= v2;
333+
}
334+
}
335+
_ => todo!(),
318336
}
337+
verified
319338
}
320-
_ => todo!(),
321-
}
339+
GKRScheme::GkrSquare => {
340+
let (mut verified, rz, r_simd, r_mpi, claimed_v) = gkr_square_verify(
341+
&self.config,
342+
circuit,
343+
public_input,
344+
claimed_v,
345+
transcript,
346+
&mut cursor,
347+
);
348+
349+
log::info!("GKR verification: {}", verified);
350+
351+
match self.config.polynomial_commitment_type {
352+
PolynomialCommitmentType::Raw => {
353+
// for Raw, no need to load from proof
354+
log::trace!("rz.size() = {}", rz.len());
355+
log::trace!("Poly_vals.size() = {}", commitment.poly_vals.len());
356+
357+
let v1 = commitment.mpi_verify(&rz, &r_simd, &r_mpi, claimed_v);
358+
verified &= v1;
359+
}
360+
_ => todo!(),
361+
}
362+
verified
363+
}
364+
};
322365

323366
end_timer!(timer);
324367

gkr/src/verifier/gkr_square.rs

+162
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
use super::verify_sumcheck_step;
2+
use arith::{Field, FieldSerde};
3+
use ark_std::{end_timer, start_timer};
4+
use circuit::{Circuit, CircuitLayer};
5+
use config::{Config, GKRConfig};
6+
use std::{io::Read, vec};
7+
use sumcheck::{GKRVerifierHelper, VerifierScratchPad};
8+
use transcript::Transcript;
9+
10+
#[allow(clippy::type_complexity)]
11+
pub fn gkr_square_verify<C: GKRConfig, T: Transcript<C::ChallengeField>>(
12+
config: &Config<C>,
13+
circuit: &Circuit<C>,
14+
public_input: &[C::SimdCircuitField],
15+
claimed_v: &C::ChallengeField,
16+
transcript: &mut T,
17+
mut proof_reader: impl Read,
18+
) -> (
19+
bool,
20+
Vec<C::ChallengeField>,
21+
Vec<C::ChallengeField>,
22+
Vec<C::ChallengeField>,
23+
C::ChallengeField,
24+
) {
25+
let timer = start_timer!(|| "gkr verify");
26+
let mut sp = VerifierScratchPad::<C>::new(config, circuit);
27+
28+
let layer_num = circuit.layers.len();
29+
let mut rz = vec![];
30+
let mut r_simd = vec![];
31+
let mut r_mpi = vec![];
32+
33+
for _ in 0..circuit.layers.last().unwrap().output_var_num {
34+
rz.push(transcript.generate_challenge_field_element());
35+
}
36+
log::trace!("rz {:?}", rz);
37+
38+
for _ in 0..C::get_field_pack_size().trailing_zeros() {
39+
r_simd.push(transcript.generate_challenge_field_element());
40+
}
41+
log::trace!("r_simd {:?}", r_simd);
42+
43+
// TODO: MPI support
44+
assert_eq!(
45+
config.mpi_config.world_size().trailing_zeros(),
46+
0,
47+
"MPI not supported yet"
48+
);
49+
for _ in 0..config.mpi_config.world_size().trailing_zeros() {
50+
r_mpi.push(transcript.generate_challenge_field_element());
51+
}
52+
53+
let mut verified = true;
54+
let mut current_claim = *claimed_v;
55+
log::trace!("Starting claim: {:?}", current_claim);
56+
for i in (0..layer_num).rev() {
57+
let cur_verified;
58+
(cur_verified, rz, r_simd, r_mpi, current_claim) = sumcheck_verify_gkr_square_layer(
59+
config,
60+
&circuit.layers[i],
61+
public_input,
62+
&rz,
63+
&r_simd,
64+
&r_mpi,
65+
current_claim,
66+
&mut proof_reader,
67+
transcript,
68+
&mut sp,
69+
i == layer_num - 1,
70+
);
71+
verified &= cur_verified;
72+
}
73+
end_timer!(timer);
74+
(verified, rz, r_simd, r_mpi, current_claim)
75+
}
76+
77+
#[allow(clippy::too_many_arguments)]
78+
#[allow(clippy::type_complexity)]
79+
#[allow(clippy::unnecessary_unwrap)]
80+
fn sumcheck_verify_gkr_square_layer<C: GKRConfig, T: Transcript<C::ChallengeField>>(
81+
config: &Config<C>,
82+
layer: &CircuitLayer<C>,
83+
public_input: &[C::SimdCircuitField],
84+
rz: &[C::ChallengeField],
85+
r_simd: &Vec<C::ChallengeField>,
86+
r_mpi: &Vec<C::ChallengeField>,
87+
current_claim: C::ChallengeField,
88+
mut proof_reader: impl Read,
89+
transcript: &mut T,
90+
sp: &mut VerifierScratchPad<C>,
91+
is_output_layer: bool,
92+
) -> (
93+
bool,
94+
Vec<C::ChallengeField>,
95+
Vec<C::ChallengeField>,
96+
Vec<C::ChallengeField>,
97+
C::ChallengeField,
98+
) {
99+
// GKR2 with Power5 gate has degree 6 polynomial
100+
let degree = 6;
101+
102+
GKRVerifierHelper::prepare_layer(layer, &None, rz, &None, r_simd, r_mpi, sp, is_output_layer);
103+
104+
let var_num = layer.input_var_num;
105+
let mut sum = current_claim;
106+
sum -= GKRVerifierHelper::eval_cst(&layer.const_, public_input, sp);
107+
108+
let mut rx = vec![];
109+
let mut r_simd_var = vec![];
110+
let mut r_mpi_var = vec![];
111+
let mut verified = true;
112+
113+
for i_var in 0..var_num {
114+
verified &= verify_sumcheck_step::<C, T>(
115+
&mut proof_reader,
116+
degree,
117+
transcript,
118+
&mut sum,
119+
&mut rx,
120+
sp,
121+
);
122+
log::trace!("x {} var, verified? {}", i_var, verified);
123+
}
124+
GKRVerifierHelper::set_rx(&rx, sp);
125+
126+
for i_var in 0..C::get_field_pack_size().trailing_zeros() {
127+
verified &= verify_sumcheck_step::<C, T>(
128+
&mut proof_reader,
129+
degree,
130+
transcript,
131+
&mut sum,
132+
&mut r_simd_var,
133+
sp,
134+
);
135+
log::trace!("simd {} var, verified? {}", i_var, verified);
136+
}
137+
GKRVerifierHelper::set_r_simd_xy(&r_simd_var, sp);
138+
139+
// TODO: nontrivial MPI support
140+
for _i_var in 0..config.mpi_config.world_size().trailing_zeros() {
141+
verified &= verify_sumcheck_step::<C, T>(
142+
&mut proof_reader,
143+
3,
144+
transcript,
145+
&mut sum,
146+
&mut r_mpi_var,
147+
sp,
148+
);
149+
// println!("{} mpi var, verified? {}", _i_var, verified);
150+
}
151+
GKRVerifierHelper::set_r_mpi_xy(&r_mpi_var, sp);
152+
153+
let v_claim = C::ChallengeField::deserialize_from(&mut proof_reader).unwrap();
154+
155+
sum -= v_claim * GKRVerifierHelper::eval_pow_1(&layer.uni, sp)
156+
+ v_claim.exp(5) * GKRVerifierHelper::eval_pow_5(&layer.uni, sp);
157+
transcript.append_field_element(&v_claim);
158+
159+
verified &= sum == C::ChallengeField::ZERO;
160+
161+
(verified, rx, r_simd_var, r_mpi_var, v_claim)
162+
}

sumcheck/src/scratch_pad.rs

+31
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ pub struct VerifierScratchPad<C: GKRConfig> {
9191
pub gf2_deg2_eval_coef: C::ChallengeField, // 1 / x(x - 1)
9292
pub deg3_eval_at: [C::ChallengeField; 4],
9393
pub deg3_lag_denoms_inv: [C::ChallengeField; 4],
94+
// ====== for deg6 eval ======
95+
pub deg6_eval_at: [C::ChallengeField; 7],
96+
pub deg6_lag_denoms_inv: [C::ChallengeField; 7],
9497
}
9598

9699
impl<C: GKRConfig> VerifierScratchPad<C> {
@@ -143,6 +146,32 @@ impl<C: GKRConfig> VerifierScratchPad<C> {
143146
deg3_lag_denoms_inv[i] = denominator.inv().unwrap();
144147
}
145148

149+
let deg6_eval_at = if C::FIELD_TYPE == FieldType::GF2 {
150+
panic!("GF2 not supported yet");
151+
} else {
152+
[
153+
C::ChallengeField::ZERO,
154+
C::ChallengeField::ONE,
155+
C::ChallengeField::from(2),
156+
C::ChallengeField::from(3),
157+
C::ChallengeField::from(4),
158+
C::ChallengeField::from(5),
159+
C::ChallengeField::from(6),
160+
]
161+
};
162+
163+
let mut deg6_lag_denoms_inv = [C::ChallengeField::ZERO; 7];
164+
for i in 0..7 {
165+
let mut denominator = C::ChallengeField::ONE;
166+
for j in 0..7 {
167+
if j == i {
168+
continue;
169+
}
170+
denominator *= deg6_eval_at[i] - deg6_eval_at[j];
171+
}
172+
deg6_lag_denoms_inv[i] = denominator.inv().unwrap();
173+
}
174+
146175
Self {
147176
eq_evals_at_rz0: vec![C::ChallengeField::zero(); max_io_size],
148177
eq_evals_at_r_simd: vec![C::ChallengeField::zero(); simd_size],
@@ -162,6 +191,8 @@ impl<C: GKRConfig> VerifierScratchPad<C> {
162191
gf2_deg2_eval_coef,
163192
deg3_eval_at,
164193
deg3_lag_denoms_inv,
194+
deg6_eval_at,
195+
deg6_lag_denoms_inv,
165196
}
166197
}
167198
}

0 commit comments

Comments
 (0)