Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Make nargo::ops::transform_program idempotent #6695

Merged
269 changes: 264 additions & 5 deletions acvm-repo/acvm/src/compiler/transformers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
use std::collections::BTreeSet;

use acir::{
circuit::{brillig::BrilligOutputs, Circuit, ExpressionWidth, Opcode},
circuit::{
self,
brillig::{BrilligInputs, BrilligOutputs},
opcodes::{BlackBoxFuncCall, FunctionInput, MemOp},
Circuit, ExpressionWidth, Opcode,
},
native_types::{Expression, Witness},
AcirField,
};
Expand Down Expand Up @@ -79,8 +86,6 @@ pub(super) fn transform_internal<F: AcirField>(
&mut next_witness_index,
);

// Update next_witness counter
next_witness_index += (intermediate_variables.len() - len) as u32;
vezenovm marked this conversation as resolved.
Show resolved Hide resolved
let mut new_opcodes = Vec::new();
for (g, (norm, w)) in intermediate_variables.iter().skip(len) {
// de-normalize
Expand Down Expand Up @@ -160,13 +165,267 @@ pub(super) fn transform_internal<F: AcirField>(
let mut merge_optimizer = MergeExpressionsOptimizer::new();
let (opcodes, new_acir_opcode_positions) =
merge_optimizer.eliminate_intermediate_variable(&acir, new_acir_opcode_positions);
// n.b. we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less.
let acir = Circuit {

// n.b. if we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less.
let mut acir = Circuit {
current_witness_index,
expression_width,
opcodes,
// The optimizer does not add new public inputs
..acir
};

// After the elimination of intermediate variables the `current_witness_index` is potentially higher than it needs to be,
// which would cause gaps if we ran the optimization a second time, making it look like new variables were added.
// Here we figure out what is the final state of witnesses by visiting each opcode.
let witnesses = WitnessCollector::collect_from_circuit(&acir);
if let Some(max_witness) = witnesses.last() {
acir.current_witness_index = max_witness.0;
}
aakoshh marked this conversation as resolved.
Show resolved Hide resolved

(acir, new_acir_opcode_positions)
}

/// Collect all witnesses in a circuit.
#[derive(Default, Clone, Debug)]
struct WitnessCollector {
witnesses: BTreeSet<Witness>,
}

impl WitnessCollector {
/// Collect all witnesses in a circuit.
fn collect_from_circuit<F: AcirField>(circuit: &Circuit<F>) -> BTreeSet<Witness> {
let mut collector = Self::default();
collector.extend_from_circuit(circuit);
collector.witnesses
}

fn add(&mut self, witness: Witness) {
self.witnesses.insert(witness);
}

fn add_many(&mut self, witnesses: &[Witness]) {
self.witnesses.extend(witnesses);
}

/// Add all witnesses from the circuit.
fn extend_from_circuit<F: AcirField>(&mut self, circuit: &Circuit<F>) {
self.witnesses.extend(&circuit.private_parameters);
self.witnesses.extend(&circuit.public_parameters.0);
self.witnesses.extend(&circuit.return_values.0);
for opcode in &circuit.opcodes {
self.extend_from_opcode(opcode);
}
}

/// Add witnesses from the opcode.
fn extend_from_opcode<F: AcirField>(&mut self, opcode: &Opcode<F>) {
match opcode {
Opcode::AssertZero(expr) => {
self.extend_from_expr(expr);
}
Opcode::BlackBoxFuncCall(call) => self.extend_from_blackbox(call),
Opcode::MemoryOp { block_id: _, op, predicate } => {
let MemOp { operation, index, value } = op;
self.extend_from_expr(operation);
self.extend_from_expr(index);
self.extend_from_expr(value);
if let Some(pred) = predicate {
self.extend_from_expr(pred);
}
}
Opcode::MemoryInit { block_id: _, init, block_type: _ } => {
for w in init {
self.add(*w);
}
}
// We keep the display for a BrilligCall and circuit Call separate as they
// are distinct in their functionality and we should maintain this separation for debugging.
Opcode::BrilligCall { id: _, inputs, outputs, predicate } => {
if let Some(pred) = predicate {
self.extend_from_expr(pred);
}
self.extend_from_brillig_inputs(inputs);
self.extend_from_brillig_outputs(outputs);
}
Opcode::Call { id: _, inputs, outputs, predicate } => {
if let Some(pred) = predicate {
self.extend_from_expr(pred);
}
self.add_many(inputs);
self.add_many(outputs);
}
}
}

fn extend_from_expr<F: AcirField>(&mut self, expr: &Expression<F>) {
for i in &expr.mul_terms {
self.add(i.1);
self.add(i.2);
}
for i in &expr.linear_combinations {
self.add(i.1);
}
}

fn extend_from_brillig_inputs<F: AcirField>(&mut self, inputs: &[BrilligInputs<F>]) {
for input in inputs {
match input {
BrilligInputs::Single(expr) => {
self.extend_from_expr(expr);
}
BrilligInputs::Array(exprs) => {
for expr in exprs {
self.extend_from_expr(expr);
}
}
BrilligInputs::MemoryArray(_) => {}
}
}
}

fn extend_from_brillig_outputs(&mut self, outputs: &[BrilligOutputs]) {
for output in outputs {
match output {
BrilligOutputs::Simple(w) => {
self.add(*w);
}
BrilligOutputs::Array(ws) => self.add_many(ws),
}
}
}

fn extend_from_blackbox<F: AcirField>(&mut self, call: &BlackBoxFuncCall<F>) {
match call {
BlackBoxFuncCall::AES128Encrypt { inputs, iv, key, outputs } => {
self.extend_from_function_inputs(inputs.as_slice());
self.extend_from_function_inputs(iv.as_slice());
self.extend_from_function_inputs(key.as_slice());
self.add_many(outputs);
}
BlackBoxFuncCall::AND { lhs, rhs, output } => {
self.extend_from_function_input(lhs);
self.extend_from_function_input(rhs);
self.add(*output);
}
BlackBoxFuncCall::XOR { lhs, rhs, output } => {
self.extend_from_function_input(lhs);
self.extend_from_function_input(rhs);
self.add(*output);
}
BlackBoxFuncCall::RANGE { input } => {
self.extend_from_function_input(input);
}
BlackBoxFuncCall::Blake2s { inputs, outputs } => {
self.extend_from_function_inputs(inputs.as_slice());
self.add_many(outputs.as_slice());
}
BlackBoxFuncCall::Blake3 { inputs, outputs } => {
self.extend_from_function_inputs(inputs.as_slice());
self.add_many(outputs.as_slice());
}
BlackBoxFuncCall::SchnorrVerify {
public_key_x,
public_key_y,
signature,
message,
output,
} => {
self.extend_from_function_input(public_key_x);
self.extend_from_function_input(public_key_y);
self.extend_from_function_inputs(signature.as_slice());
self.extend_from_function_inputs(message.as_slice());
self.add(*output);
}
BlackBoxFuncCall::EcdsaSecp256k1 {
public_key_x,
public_key_y,
signature,
hashed_message,
output,
} => {
self.extend_from_function_inputs(public_key_x.as_slice());
self.extend_from_function_inputs(public_key_y.as_slice());
self.extend_from_function_inputs(signature.as_slice());
self.extend_from_function_inputs(hashed_message.as_slice());
self.add(*output);
}
BlackBoxFuncCall::EcdsaSecp256r1 {
public_key_x,
public_key_y,
signature,
hashed_message,
output,
} => {
self.extend_from_function_inputs(public_key_x.as_slice());
self.extend_from_function_inputs(public_key_y.as_slice());
self.extend_from_function_inputs(signature.as_slice());
self.extend_from_function_inputs(hashed_message.as_slice());
self.add(*output);
}
BlackBoxFuncCall::MultiScalarMul { points, scalars, outputs } => {
self.extend_from_function_inputs(points.as_slice());
self.extend_from_function_inputs(scalars.as_slice());
let (x, y, i) = outputs;
self.add(*x);
self.add(*y);
self.add(*i);
}
BlackBoxFuncCall::EmbeddedCurveAdd { input1, input2, outputs } => {
self.extend_from_function_inputs(input1.as_slice());
self.extend_from_function_inputs(input2.as_slice());
let (x, y, i) = outputs;
self.add(*x);
self.add(*y);
self.add(*i);
}
BlackBoxFuncCall::Keccakf1600 { inputs, outputs } => {
self.extend_from_function_inputs(inputs.as_slice());
self.add_many(outputs.as_slice());
}
BlackBoxFuncCall::RecursiveAggregation {
verification_key,
proof,
public_inputs,
key_hash,
proof_type: _,
} => {
self.extend_from_function_inputs(verification_key.as_slice());
self.extend_from_function_inputs(proof.as_slice());
self.extend_from_function_inputs(public_inputs.as_slice());
self.extend_from_function_input(key_hash);
}
BlackBoxFuncCall::BigIntAdd { .. }
| BlackBoxFuncCall::BigIntSub { .. }
| BlackBoxFuncCall::BigIntMul { .. }
| BlackBoxFuncCall::BigIntDiv { .. } => {}
BlackBoxFuncCall::BigIntFromLeBytes { inputs, modulus: _, output: _ } => {
self.extend_from_function_inputs(inputs.as_slice());
}
BlackBoxFuncCall::BigIntToLeBytes { input: _, outputs } => {
self.add_many(outputs.as_slice());
}
BlackBoxFuncCall::Poseidon2Permutation { inputs, outputs, len: _ } => {
self.extend_from_function_inputs(inputs.as_slice());
self.add_many(outputs.as_slice());
}
BlackBoxFuncCall::Sha256Compression { inputs, hash_values, outputs } => {
self.extend_from_function_inputs(inputs.as_slice());
self.extend_from_function_inputs(hash_values.as_slice());
self.add_many(outputs.as_slice());
}
}
}

fn extend_from_function_input<F: AcirField>(&mut self, input: &FunctionInput<F>) {
if let circuit::opcodes::ConstantOrWitnessEnum::Witness(witness) = input.input() {
self.add(witness);
}
}

fn extend_from_function_inputs<F: AcirField>(&mut self, inputs: &[FunctionInput<F>]) {
for input in inputs {
self.extend_from_function_input(input);
}
}
}
Loading