diff --git a/vm/src/hint_processor/cairo_1_hint_processor/circuit.rs b/vm/src/hint_processor/cairo_1_hint_processor/circuit.rs index 08a41d4bf4..dff4e5a42a 100644 --- a/vm/src/hint_processor/cairo_1_hint_processor/circuit.rs +++ b/vm/src/hint_processor/cairo_1_hint_processor/circuit.rs @@ -1,7 +1,7 @@ // Most of the `EvalCircuit` implementation is derived from the `cairo-lang-runner` crate. // https://github.com/starkware-libs/cairo/blob/main/crates/cairo-lang-runner/src/casm_run/circuit.rs -use core::{array, ops::Deref}; +use core::ops::Deref; use ark_ff::{One, Zero}; use num_bigint::{BigInt, BigUint, ToBigInt}; @@ -12,7 +12,10 @@ use starknet_types_core::felt::Felt; use crate::{ stdlib::boxed::Box, types::relocatable::{MaybeRelocatable, Relocatable}, - vm::{errors::hint_errors::HintError, vm_core::VirtualMachine}, + vm::{ + errors::{hint_errors::HintError, memory_errors::MemoryError}, + vm_core::VirtualMachine, + }, }; // A gate is defined by 3 offsets, the first two are the inputs and the third is the output. @@ -31,79 +34,97 @@ struct Circuit<'a> { } impl Circuit<'_> { - fn read_add_mod_value(&mut self, offset: usize) -> Option { - self.read_circuit_value((self.add_mod_offsets + offset).unwrap()) + fn read_add_mod_value(&mut self, offset: usize) -> Result, MemoryError> { + self.read_circuit_value((self.add_mod_offsets + offset)?) } - fn read_mul_mod_value(&mut self, offset: usize) -> Option { - self.read_circuit_value((self.mul_mod_offsets + offset).unwrap()) + fn read_mul_mod_value(&mut self, offset: usize) -> Result, MemoryError> { + self.read_circuit_value((self.mul_mod_offsets + offset)?) } - fn read_circuit_value(&mut self, offset: Relocatable) -> Option { - let value_ptr = self.get_value_ptr(offset); - read_circuit_value(self.vm, value_ptr) + fn read_circuit_value(&mut self, offset: Relocatable) -> Result, MemoryError> { + let value_ptr = self.get_value_ptr(offset)?; + Ok(read_circuit_value(self.vm, value_ptr)?) } - fn write_add_mod_value(&mut self, offset: usize, value: BigUint) { - self.write_circuit_value((self.add_mod_offsets + offset).unwrap(), value); + fn write_add_mod_value(&mut self, offset: usize, value: BigUint) -> Result<(), MemoryError> { + self.write_circuit_value((self.add_mod_offsets + offset)?, value)?; + + Ok(()) } - fn write_mul_mod_value(&mut self, offset: usize, value: BigUint) { - self.write_circuit_value((self.mul_mod_offsets + offset).unwrap(), value); + fn write_mul_mod_value(&mut self, offset: usize, value: BigUint) -> Result<(), MemoryError> { + self.write_circuit_value((self.mul_mod_offsets + offset)?, value)?; + + Ok(()) } - fn write_circuit_value(&mut self, offset: Relocatable, value: BigUint) { - let value_ptr = self.get_value_ptr(offset); - write_circuit_value(self.vm, value_ptr, value); + fn write_circuit_value( + &mut self, + offset: Relocatable, + value: BigUint, + ) -> Result<(), MemoryError> { + let value_ptr = self.get_value_ptr(offset)?; + write_circuit_value(self.vm, value_ptr, value)?; + + Ok(()) } - fn get_value_ptr(&self, address: Relocatable) -> Relocatable { - (self.values_ptr + self.vm.get_integer(address).unwrap().as_ref()).unwrap() + fn get_value_ptr(&self, address: Relocatable) -> Result { + (self.values_ptr + self.vm.get_integer(address)?.as_ref()).map_err(|e| MemoryError::Math(e)) } } -fn read_circuit_value(vm: &mut VirtualMachine, add: Relocatable) -> Option { +fn read_circuit_value( + vm: &mut VirtualMachine, + add: Relocatable, +) -> Result, MemoryError> { let mut res = BigUint::zero(); for l in (0..LIMBS_COUNT).rev() { - let add_l = (add + l).unwrap(); + let add_l = (add + l)?; match vm.get_maybe(&add_l) { Some(MaybeRelocatable::Int(limb)) => res = (res << 96) + limb.to_biguint(), - _ => return None, + _ => return Ok(None), } } - Some(res) + Ok(Some(res)) } -fn write_circuit_value(vm: &mut VirtualMachine, add: Relocatable, mut value: BigUint) { +fn write_circuit_value( + vm: &mut VirtualMachine, + add: Relocatable, + mut value: BigUint, +) -> Result<(), MemoryError> { for l in 0..LIMBS_COUNT { // get the nth limb from a circuit value let (new_value, rem) = value.div_rem(&(BigUint::one() << 96u8)); - vm.insert_value((add + l).unwrap(), Felt::from(rem)) - .unwrap(); + vm.insert_value((add + l)?, Felt::from(rem))?; value = new_value; } + + Ok(()) } // Finds the inverse of a value. // // If the value has no inverse, find a nullifier so that: // value * nullifier = 0 (mod modulus) -fn find_inverse(value: BigUint, modulus: &BigUint) -> (bool, BigUint) { +fn find_inverse(value: BigUint, modulus: &BigUint) -> Result<(bool, BigUint), HintError> { let ex_gcd = value .to_bigint() - .unwrap() - .extended_gcd(&modulus.to_bigint().unwrap()); + .ok_or(HintError::BigUintToBigIntFail)? + .extended_gcd(&modulus.to_bigint().ok_or(HintError::BigUintToBigIntFail)?); let gcd = ex_gcd.gcd.to_biguint().unwrap(); if gcd.is_one() { - return (true, get_modulus(&ex_gcd.x, modulus)); + return Ok((true, get_modulus(&ex_gcd.x, modulus))); } let nullifier = modulus / gcd; - (false, nullifier) + Ok((false, nullifier)) } fn get_modulus(value: &BigInt, modulus: &BigUint) -> BigUint { @@ -124,7 +145,7 @@ fn compute_gates( n_mul_mods: usize, modulus_ptr: Relocatable, ) -> Result { - let modulus = read_circuit_value(vm, modulus_ptr).unwrap(); + let modulus = read_circuit_value(vm, modulus_ptr)?.unwrap(); let mut circuit = Circuit { vm, values_ptr, @@ -141,21 +162,21 @@ fn compute_gates( loop { while addmod_idx < n_add_mods { - let lhs = circuit.read_add_mod_value(3 * addmod_idx); - let rhs = circuit.read_add_mod_value(3 * addmod_idx + 1); + let lhs = circuit.read_add_mod_value(3 * addmod_idx)?; + let rhs = circuit.read_add_mod_value(3 * addmod_idx + 1)?; match (lhs, rhs) { (Some(l), Some(r)) => { let res = (l + r) % &circuit.modulus; - circuit.write_add_mod_value(3 * addmod_idx + 2, res); + circuit.write_add_mod_value(3 * addmod_idx + 2, res)?; } // sub gate: lhs = res - rhs (None, Some(r)) => { - let Some(res) = circuit.read_add_mod_value(3 * addmod_idx + 2) else { + let Some(res) = circuit.read_add_mod_value(3 * addmod_idx + 2)? else { break; }; let value = (res + &circuit.modulus - r) % &circuit.modulus; - circuit.write_add_mod_value(3 * addmod_idx, value); + circuit.write_add_mod_value(3 * addmod_idx, value)?; } _ => break, } @@ -167,18 +188,18 @@ fn compute_gates( break; } - let lhs = circuit.read_mul_mod_value(3 * mulmod_idx); - let rhs = circuit.read_mul_mod_value(3 * mulmod_idx + 1); + let lhs = circuit.read_mul_mod_value(3 * mulmod_idx)?; + let rhs = circuit.read_mul_mod_value(3 * mulmod_idx + 1)?; match (lhs, rhs) { (Some(l), Some(r)) => { let res = (l * r) % &circuit.modulus; - circuit.write_mul_mod_value(3 * mulmod_idx + 2, res); + circuit.write_mul_mod_value(3 * mulmod_idx + 2, res)?; } // inverse gate: lhs = 1 / rhs (None, Some(r)) => { - let (success, res) = find_inverse(r, &circuit.modulus); - circuit.write_mul_mod_value(3 * mulmod_idx, res); + let (success, res) = find_inverse(r, &circuit.modulus)?; + circuit.write_mul_mod_value(3 * mulmod_idx, res)?; if !success { first_failure_idx = mulmod_idx; @@ -209,7 +230,7 @@ fn fill_instances( mut offsets_ptr: Relocatable, ) -> Result<(), HintError> { for i in 0..n_instances { - let instance_ptr = (built_ptr + i * MOD_BUILTIN_INSTACE_SIZE).unwrap(); + let instance_ptr = (built_ptr + i * MOD_BUILTIN_INSTACE_SIZE)?; for (idx, value) in modulus.iter().enumerate() { vm.insert_value((instance_ptr + idx)?, *value)?; @@ -256,8 +277,12 @@ pub fn eval_circuit( modulus_ptr, )?; - let modulus: [Felt; 4] = - array::from_fn(|l| *vm.get_integer((modulus_ptr + l).unwrap()).unwrap().deref()); + let modulus: [Felt; 4] = [ + *vm.get_integer(modulus_ptr)?.deref(), + *vm.get_integer((modulus_ptr + 1)?)?.deref(), + *vm.get_integer((modulus_ptr + 2)?)?.deref(), + *vm.get_integer((modulus_ptr + 3)?)?.deref(), + ]; fill_instances( vm, diff --git a/vm/src/vm/errors/hint_errors.rs b/vm/src/vm/errors/hint_errors.rs index 639ce9aaf2..849d38288d 100644 --- a/vm/src/vm/errors/hint_errors.rs +++ b/vm/src/vm/errors/hint_errors.rs @@ -154,6 +154,8 @@ pub enum HintError { BigintToU32Fail, #[error("BigInt to BigUint failed, BigInt is negative")] BigIntToBigUintFail, + #[error("BigUint to BigInt failed")] + BigUintToBigIntFail, #[error("Assertion failed, 0 <= ids.a % PRIME < range_check_builtin.bound \n a = {0} is out of range")] ValueOutOfRange(Box), #[error("Assertion failed, 0 <= ids.a % PRIME < range_check_builtin.bound \n a = {0} is out of range")]