diff --git a/examples/fixture/asm/kimchi/hint.asm b/examples/fixture/asm/kimchi/hint.asm index f514b61de..b3d4f0579 100644 --- a/examples/fixture/asm/kimchi/hint.asm +++ b/examples/fixture/asm/kimchi/hint.asm @@ -10,6 +10,8 @@ DoubleGeneric<1,-1> DoubleGeneric<1,-1> DoubleGeneric<1,-1> DoubleGeneric<1,0,0,0,-16> +DoubleGeneric<1,0,0,0,-32> +DoubleGeneric<1,0,0,0,-4> DoubleGeneric<1,-1> DoubleGeneric<1,0,0,0,-3> DoubleGeneric<1,0,-1,0,1> @@ -20,11 +22,11 @@ DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,0,0,0,-1> DoubleGeneric<1,0,0,0,-1> DoubleGeneric<1,-1> -(0,0) -> (18,0) -(1,0) -> (2,0) -> (9,1) +(0,0) -> (20,0) +(1,0) -> (2,0) -> (11,1) (4,0) -> (6,1) (4,2) -> (5,1) -(5,0) -> (7,1) -> (18,1) -(9,0) -> (11,0) -(14,1) -> (15,0) -(15,2) -> (16,0) +(5,0) -> (7,1) -> (20,1) +(11,0) -> (13,0) +(16,1) -> (17,0) +(17,2) -> (18,0) diff --git a/examples/fixture/asm/r1cs/hint.asm b/examples/fixture/asm/r1cs/hint.asm index fdcdbd324..8a40c1068 100644 --- a/examples/fixture/asm/r1cs/hint.asm +++ b/examples/fixture/asm/r1cs/hint.asm @@ -7,10 +7,12 @@ v_5 == (v_6) * (1) v_4 == (v_7) * (1) 16 == (v_8) * (1) -v_2 == (v_9) * (1) -3 == (v_10) * (1) -1 == (v_11) * (1) -1 == (v_12) * (1) -1 == (-1 * v_13 + 1) * (1) +32 == (v_9) * (1) +4 == (v_10) * (1) +v_2 == (v_11) * (1) +3 == (v_12) * (1) +1 == (v_13) * (1) 1 == (v_14) * (1) +1 == (-1 * v_15 + 1) * (1) +1 == (v_16) * (1) v_4 == (v_1) * (1) diff --git a/examples/hint.no b/examples/hint.no index bea559acf..9ee209696 100644 --- a/examples/hint.no +++ b/examples/hint.no @@ -3,6 +3,14 @@ struct Thing { pub yy: Field, } +fn init_arr(const LEN: Field) -> [Field; LEN] { + return [0; LEN]; +} + +hint fn cst_div(const lhs: Field, const rhs: Field) -> Field { + return lhs / rhs; +} + hint fn mul(lhs: Field, rhs: Field) -> Field { return lhs * rhs; } @@ -20,16 +28,20 @@ hint fn ite(lhs: Field, rhs: Field) -> Field { return if lhs != rhs { lhs } else { rhs }; } -hint fn exp(const EXP: Field, val: Field) -> Field { - let mut res = val; - - for num in 1..EXP { - res = res * val; - } +hint fn exp(const EXP: Field, base: Field) -> Field { + let res = base ** EXP; return res; } +hint fn lshift(val: Field, shift: Field) -> Field { + return val << shift; +} + +hint fn rem(lhs: Field, rhs: Field) -> Field { + return lhs % rhs; +} + hint fn sub(lhs: Field, rhs: Field) -> Field { return lhs - rhs; } @@ -68,6 +80,12 @@ fn main(pub public_input: Field, private_input: Field) -> Field { let kk = unsafe exp(4, public_input); assert_eq(kk, 16); + let k2 = unsafe lshift(public_input, 4); + assert_eq(k2, 32); + + let ll = unsafe rem(kk, 12); + assert_eq(ll, 4); + let thing = unsafe multiple_inputs_outputs([public_input, 3]); // have to include all the outputs from hint function, otherwise it throws vars not in circuit error. // this is because each individual element in the hint output maps to a separate cell var in noname. @@ -82,5 +100,9 @@ fn main(pub public_input: Field, private_input: Field) -> Field { assert(!oo[1]); assert(oo[2]); + // mast phase can fold the constant value using hint functions + let one = unsafe cst_div(2, 2); + let arr = init_arr(one); + assert_eq(arr[0], 0); return xx; } \ No newline at end of file diff --git a/src/backends/kimchi/mod.rs b/src/backends/kimchi/mod.rs index 74a998a64..b448b4949 100644 --- a/src/backends/kimchi/mod.rs +++ b/src/backends/kimchi/mod.rs @@ -89,7 +89,7 @@ pub struct KimchiVesta { /// This is how you compute the value of each variable during witness generation. /// It is created during circuit generation. - pub(crate) vars_to_value: HashMap>, + pub(crate) vars_to_value: HashMap, Span)>, /// The execution trace table with vars as placeholders. /// It is created during circuit generation, @@ -307,7 +307,7 @@ impl Backend for KimchiVesta { self.next_variable += 1; // store it in the circuit_writer - self.vars_to_value.insert(var.index, val); + self.vars_to_value.insert(var.index, (val, span)); var } @@ -365,6 +365,16 @@ impl Backend for KimchiVesta { for var in 0..self.next_variable { if !written_vars.contains(&var) && !disable_safety_check { + let (val, span) = self + .vars_to_value + .get(&var) + .expect("a var should be in vars_to_value"); + + if matches!(val, Value::HintIR(..)) { + println!("a HintIR value not used in the circuit: {:?}", span); + continue; + } + if let Some(private_cell_var) = self .private_input_cell_vars .iter() @@ -378,7 +388,7 @@ impl Backend for KimchiVesta { ); Err(err)?; } else { - Err(Error::new("contraint-finalization", ErrorKind::UnexpectedError("there's a bug in the circuit_writer, some cellvar does not end up being a cellvar in the circuit!"), Span::default()))?; + Err(Error::new("contraint-finalization", ErrorKind::UnexpectedError("there's a bug in the circuit_writer, some cellvar does not end up being a cellvar in the circuit!"), *span))?; } } } @@ -403,7 +413,7 @@ impl Backend for KimchiVesta { let var_idx = pub_var.cvar().unwrap().index; let prev = self .vars_to_value - .insert(var_idx, Value::PublicOutput(Some(ret_var))); + .insert(var_idx, (Value::PublicOutput(Some(ret_var)), ret_var.span)); assert!(prev.is_some()); } } @@ -419,7 +429,7 @@ impl Backend for KimchiVesta { var: &Self::Var, ) -> crate::error::Result { let val = self.vars_to_value.get(&var.index).unwrap(); - self.compute_val(env, val, var.index) + self.compute_val(env, &val.0, var.index) } fn generate_witness( @@ -446,7 +456,7 @@ impl Backend for KimchiVesta { // if it's a public output, defer it's computation if matches!( self.vars_to_value.get(&var.index), - Some(Value::PublicOutput(_)) + Some((Value::PublicOutput(_), ..)) ) { public_outputs_vars .entry(*var) diff --git a/src/backends/mod.rs b/src/backends/mod.rs index 0854c2b83..14d5df27e 100644 --- a/src/backends/mod.rs +++ b/src/backends/mod.rs @@ -1,13 +1,12 @@ use std::{fmt::Debug, str::FromStr}; use ::kimchi::o1_utils::FieldHelpers; -use ark_ff::{Field, One, PrimeField, Zero}; -use circ::ir::term::precomp::PreComp; +use ark_ff::{Field, One, Zero}; use fxhash::FxHashMap; use num_bigint::BigUint; use crate::{ - circuit_writer::VarInfo, + circuit_writer::{ir::IRWriter, VarInfo}, compiler::Sources, constants::Span, error::{Error, ErrorKind, Result}, @@ -206,14 +205,9 @@ pub trait Backend: Clone { Ok(res) } - Value::HintIR(t, named_vars) => { - let mut precomp = PreComp::new(); - // For hint evaluation purpose, precomp only has only one output and no connections with other parts, - // so just use a dummy output var name. - precomp.add_output("x".to_string(), t.clone()); - + Value::HintIR(t, named_vars, logs) => { // map the named vars to env - let env = named_vars + let ir_env = named_vars .iter() .map(|(name, var)| { let val = match var { @@ -227,24 +221,21 @@ pub trait Backend: Clone { }) .collect::>(); - // evaluate and get the only one output - let eval_map = precomp.eval(&env); - let value = eval_map.get("x").unwrap(); - // convert to field - let res = match value { - circ::ir::term::Value::Field(f) => { - let bytes = f.i().to_digits::(rug::integer::Order::Lsf); - Self::Field::from_le_bytes_mod_order(&bytes) - } - circ::ir::term::Value::Bool(b) => { - if *b { - Self::Field::one() - } else { - Self::Field::zero() - } - } - _ => panic!("unexpected output type"), - }; + // evaluate logs + for log in logs { + // check the cache on env and log, and only evaluate if not in cache + let res: Vec = IRWriter::::eval_ir(&ir_env, log); + // format and print out array + println!( + "log: {:#?}", + res.iter().map(|f| f.pretty()).collect::>() + ); + } + + // evaluate the term + let res = IRWriter::::eval_ir(&ir_env, t)[0]; + + env.cached_values.insert(cache_key, res); // cache Ok(res) } diff --git a/src/backends/r1cs/mod.rs b/src/backends/r1cs/mod.rs index 57fcaf170..5a2461f1e 100644 --- a/src/backends/r1cs/mod.rs +++ b/src/backends/r1cs/mod.rs @@ -267,7 +267,7 @@ where { /// Constraints in the r1cs. constraints: Vec>, - witness_vector: Vec>, + witness_vector: Vec<(Value, Span)>, /// Debug information for each constraint. debug_info: Vec, /// Debug information for var info. @@ -384,7 +384,7 @@ where span, }; - self.witness_vector.insert(var.index, val); + self.witness_vector.insert(var.index, (val, span)); LinearCombination::from(var) } @@ -419,8 +419,9 @@ where // replace the computation of the public output vars with the actual variables being returned here let var_idx = pub_var.cvar().unwrap().to_cell_var().index; let prev = &self.witness_vector[var_idx]; - assert!(matches!(prev, Value::PublicOutput(None))); - self.witness_vector[var_idx] = Value::PublicOutput(Some(ret_var)); + assert!(matches!(prev.0, Value::PublicOutput(None))); + self.witness_vector[var_idx] = + (Value::PublicOutput(Some(ret_var.clone())), ret_var.span); } } @@ -435,7 +436,7 @@ where } // check if every cell vars end up being a cell var in the circuit or public output - for (index, _) in self.witness_vector.iter().enumerate() { + for (index, (val, span)) in self.witness_vector.iter().enumerate() { // Skip the first var which is always 1 // - In a linear combination, each of the vars can be paired with a coefficient. // - The first var is assumed to be the factor of the constant of a linear combination. @@ -444,6 +445,11 @@ where } if !written_vars.contains(&index) && !disable_safety_check { + // ignore HintIR val + if let Value::HintIR(..) = val { + println!("a HintIR value not used in the circuit: {:?}", span); + continue; + } if let Some(private_cell_var) = self .private_input_cell_vars .iter() @@ -458,7 +464,7 @@ where Err(Error::new( "constraint-finalization", ErrorKind::UnexpectedError("there's a bug in the circuit_writer, some cellvar does not end up being a cellvar in the circuit!"), - Span::default(), + *span, ))? } } @@ -478,7 +484,7 @@ where for (var, factor) in &lc.terms { let var_val = self.witness_vector.get(var.index).unwrap(); - let calc = self.compute_val(env, var_val, var.index)? * factor; + let calc = self.compute_val(env, &var_val.0, var.index)? * factor; val += calc; } @@ -501,11 +507,11 @@ where .iter() .enumerate() .map(|(index, val)| { - match val { + match val.0 { // Defer calculation for output vars. // The reasoning behind this is to avoid deep recursion potentially triggered by the public output var at the beginning. Value::PublicOutput(_) => Ok(F::zero()), - _ => self.compute_val(witness_env, val, index), + _ => self.compute_val(witness_env, &val.0, index), } }) .collect::>>()?; @@ -735,7 +741,7 @@ mod tests { // first var should be initialized as 1 assert_eq!(r1cs.witness_vector.len(), 1); - match &r1cs.witness_vector[0] { + match &r1cs.witness_vector[0].0 { crate::var::Value::Constant(cst) => { assert_eq!(*cst, R1csBls12381Field::one()); } diff --git a/src/circuit_writer/ir.rs b/src/circuit_writer/ir.rs index bceb39ca3..837e3f0d1 100644 --- a/src/circuit_writer/ir.rs +++ b/src/circuit_writer/ir.rs @@ -1,8 +1,13 @@ -use ark_ff::Zero; +use ark_ff::{One, PrimeField, Zero}; use circ::{ - ir::term::{leaf_term, term, BoolNaryOp, Op, PfNaryOp, PfUnOp, Sort, Term, Value}, + ir::term::{ + leaf_term, precomp::PreComp, term, BoolNaryOp, BvBinOp, IntBinOp, IntBinPred, IntNaryOp, + Op, PfNaryOp, PfUnOp, Sort, Term, Value, + }, term, }; +use fxhash::FxHashMap; +use kimchi::o1_utils::FieldHelpers; use num_bigint::BigUint; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -12,13 +17,13 @@ use crate::{ constants::Span, error::{Error, ErrorKind, Result}, imports::FnKind, - mast::Mast, + mast::PropagatedConstant, parser::{ types::{ForLoopArgument, FunctionDef, Stmt, StmtKind, TyKind}, Expr, ExprKind, Op2, }, syntax::is_type, - type_checker::{ConstInfo, FnInfo, FullyQualified, StructInfo}, + type_checker::{ConstInfo, FnInfo, FullyQualified, StructInfo, TypeChecker}, }; /// Same as [crate::var::Var], but with Term instead of ConstOrCell. @@ -93,11 +98,10 @@ impl Var { } } - pub fn constant(&self) -> Option { + pub fn constant(&self) -> Option { if self.cvars.len() == 1 { - self.cvars[0] - .as_value_opt() - .map(|v| BigUint::from(v.as_pf().i().to_u64_wrapping())) + let env = fxhash::FxHashMap::default(); + Some(IRWriter::::eval_ir(&env, &self.cvars[0])[0].to_biguint()) } else { None } @@ -124,9 +128,9 @@ pub enum VarOrRef { } impl VarOrRef { - pub(crate) fn constant(&self) -> Option { + pub(crate) fn constant(&self) -> Option { match self { - VarOrRef::Var(var) => var.constant(), + VarOrRef::Var(var) => var.constant::(), VarOrRef::Ref { .. } => None, } } @@ -369,8 +373,9 @@ impl FnEnv { #[derive(Debug)] /// This converts the MAST to circ IR. /// Currently it is only for hint functions. -pub(crate) struct IRWriter { - pub(crate) typed: Mast, +pub struct IRWriter { + pub typed: TypeChecker, + pub logs: Vec, } impl IRWriter { @@ -406,9 +411,8 @@ impl IRWriter { .ok_or_else(|| { self.error(ErrorKind::CannotComputeExpression, range.start.span) })? - .constant() - .expect("expected constant") - .into(); + .constant::() + .expect("expected constant"); let start: u32 = start_bg.try_into().map_err(|_| { self.error(ErrorKind::InvalidRangeSize, range.start.span) })?; @@ -418,9 +422,8 @@ impl IRWriter { .ok_or_else(|| { self.error(ErrorKind::CannotComputeExpression, range.end.span) })? - .constant() - .expect("expected constant") - .into(); + .constant::() + .expect("expected constant"); let end: u32 = end_bg .try_into() .map_err(|_| self.error(ErrorKind::InvalidRangeSize, range.end.span))?; @@ -572,16 +575,151 @@ impl IRWriter { return Ok(vec![]); } - let res = ir.unwrap().cvars.into_iter().map(|v| { + let logs = self.logs.clone(); + let logs_terms: Vec = logs + .into_iter() + .map(|v| term(Op::Tuple, v.var.cvars)) + .collect(); + + self.logs.clear(); + + let res = ir.unwrap().cvars.into_iter().enumerate().map(|(i, v)| { // With the current setup to calculate symbolic values, the [compute_val] can only compute for one symbolic variable, - // thus it has to evaluate each symbolic variable separately from a hint function. + // it has to evaluate each symbolic variable separately from a hint function. // Thus, this could introduce some performance overhead if the hint returns multiple symbolic variables. - crate::var::Value::HintIR(v, named_args.clone()) + // Maybe this can be batched and cached in the [compute_val] function. + + // Each compiled IR can contain multiple terms, as hint function output could be array or struct. + // Each term needs to be evaluated separately. + // For logs, there could be multiple logs for a compiled hint function. + // To avoid redundant logs, here we only evaluate log terms once with the first term. + if i == 0 { + crate::var::Value::HintIR(v, named_args.clone(), logs_terms.clone()) + } else { + crate::var::Value::HintIR(v, named_args.clone(), Vec::new()) + } }); Ok(res.collect()) } + /// Evaluate a single IR term. + pub fn eval_ir( + env: &FxHashMap, + t: &circ::ir::term::Term, + ) -> Vec { + let mut precomp = PreComp::new(); + // For hint evaluation purpose, precomp only has only one output and no connections with other parts, + // so just use a dummy output var name. + precomp.add_output("x".to_string(), t.clone()); + // evaluate and get the only one output + let eval_map = precomp.eval(env); + let value = eval_map.get("x").unwrap(); + // convert to field + match value { + circ::ir::term::Value::Field(f) => { + let bytes = f.i().to_digits::(rug::integer::Order::Lsf); + // todo: should we allow field overflow in hint evaluation? + vec![B::Field::from_le_bytes_mod_order(&bytes)] + } + circ::ir::term::Value::Bool(b) => { + if *b { + vec![B::Field::one()] + } else { + vec![B::Field::zero()] + } + } + circ::ir::term::Value::BitVector(bv) => { + let bytes = bv.uint().to_digits::(rug::integer::Order::Lsf); + // todo: should we allow field overflow in hint evaluation? + vec![B::Field::from_le_bytes_mod_order(&bytes)] + } + circ::ir::term::Value::Int(int) => { + let bytes = int.to_digits::(rug::integer::Order::Lsf); + vec![B::Field::from_le_bytes_mod_order(&bytes)] + } + circ::ir::term::Value::Tuple(v) => { + let mut res = Vec::new(); + for v in v { + match v { + circ::ir::term::Value::Field(f) => { + let bytes = f.i().to_digits::(rug::integer::Order::Lsf); + res.push(B::Field::from_le_bytes_mod_order(&bytes)); + } + circ::ir::term::Value::Bool(b) => { + if *b { + res.push(B::Field::one()); + } else { + res.push(B::Field::zero()); + } + } + circ::ir::term::Value::BitVector(bv) => { + let bytes = bv.uint().to_digits::(rug::integer::Order::Lsf); + res.push(B::Field::from_le_bytes_mod_order(&bytes)); + } + circ::ir::term::Value::Int(int) => { + let bytes = int.to_digits::(rug::integer::Order::Lsf); + res.push(B::Field::from_le_bytes_mod_order(&bytes)); + } + circ::ir::term::Value::Tuple(_) => { + panic!("nested tuple is not supported"); + } + _ => panic!("unexpected output type"), + } + } + res + } + _ => panic!("unexpected output type"), + } + } + + /// This is used in MAST phase to fold constant values. + pub fn evaluate( + &mut self, + function: &FunctionDef, + args: Vec, + ) -> Result { + assert!(!function.is_main()); + + // create new fn_env + let fn_env = &mut FnEnv::new(); + + // set arguments + assert_eq!(function.sig.arguments.len(), args.len()); + + // create circ var terms for the arguments + for (arg, observed) in function.sig.arguments.iter().zip(args) { + let name = &arg.name.value; + match observed { + PropagatedConstant::Single(cst) => { + let f = B::Field::from(cst); + let cvar = leaf_term(Op::new_const(Value::Field(f.to_circ_field()))); + let var = Var::new(vec![cvar], arg.name.span); + let var_info = VarInfo::new(var, false, Some(arg.typ.kind.clone())); + self.add_local_var(fn_env, name.clone(), var_info).unwrap(); + } + _ => unimplemented!(), + } + } + + // compile it and potentially return a return value + let ir = self.compile_block(fn_env, &function.body)?; + + let res: Vec<_> = ir + .unwrap() + .cvars + .into_iter() + .flat_map(|f| { + // because all the arguments are assumed to be constants, + // so no need to pass the arguments in env + let env = fxhash::FxHashMap::default(); + Self::eval_ir(&env, &f) + }) + .collect(); + + Ok(PropagatedConstant::from(res[0].to_biguint())) + } + fn compile_native_function_call( &mut self, function: &FunctionDef, @@ -651,12 +789,20 @@ impl IRWriter { match &fn_info.kind { // assert() <-- for example - FnKind::BuiltIn(..) => Err(self.error( - ErrorKind::InvalidFnCall( - "builtin functions not allowed in hint functions.", - ), - expr.span, - )), + FnKind::BuiltIn(sig, ..) => { + if sig.name.value == "log" { + self.logs.push(vars[0].clone()); + + Ok(None) + } else { + Err(self.error( + ErrorKind::InvalidFnCall( + "builtin functions not allowed in hint functions.", + ), + expr.span, + )) + } + } // fn_name(args) // ^^^^^^^ FnKind::Native(func) => { @@ -896,15 +1042,67 @@ impl IRWriter { Var::new_cvar(t, expr.span) } Op2::Division => { - let t: Term = term![ - Op::PfNaryOp(PfNaryOp::Mul); lhs.cvars[0].clone(), - term![Op::PfUnOp(PfUnOp::Recip); rhs.cvars[0].clone()] - ]; + // convert to int + let a_int = term![Op::PfToInt; lhs.cvars[0].clone()]; + let b_int = term![Op::PfToInt; rhs.cvars[0].clone()]; + // division + let t = term![Op::IntBinOp(IntBinOp::Div); a_int, b_int]; + + // convert back to field + let t = term![Op::IntToPf(B::Field::to_circ_type()); t]; + Var::new_cvar(t, expr.span) } - _ => todo!(), - }; + Op2::Rem => { + let bit_len = B::Field::MODULUS_BIT_SIZE as usize; + let a_bv = term![Op::PfToBv(bit_len); lhs.cvars[0].clone()]; + let b_bv = term![Op::PfToBv(bit_len); rhs.cvars[0].clone()]; + let t = term![Op::BvBinOp(BvBinOp::Urem); a_bv.clone(), b_bv.clone()]; + let t = term![Op::UbvToPf(Box::new(B::Field::to_circ_type())); t]; + + Var::new_var(t, expr.span) + } + Op2::LShift => { + let bit_len = B::Field::MODULUS_BIT_SIZE as usize; + let a_bv = term![Op::PfToBv(bit_len); lhs.cvars[0].clone()]; + let b_bv = term![Op::PfToBv(bit_len); rhs.cvars[0].clone()]; + // if the shift result is larger than the bit length, it will be truncated: + // https://github.com/circify/circ/blob/4aa36e479fe15fb444cc9190e0cb5a1a493ee221/src/ir/term/bv.rs#L96 + // todo: should we allow field overflow in the hint calculation? + let t = term![Op::BvBinOp(BvBinOp::Shl); a_bv, b_bv]; + // convert back to field + let t = term![Op::UbvToPf(Box::new(B::Field::to_circ_type())); t]; + Var::new_var(t, expr.span) + } + Op2::LessThan => { + let a_int = term![Op::PfToInt; lhs.cvars[0].clone()]; + let b_int = term![Op::PfToInt; rhs.cvars[0].clone()]; + let t = term![Op::IntBinPred(IntBinPred::Lt); a_int, b_int]; + Var::new_var(t, expr.span) + } + Op2::Pow => { + let base_int = term![Op::PfToInt; lhs.cvars[0].clone()]; + let folded = circ::ir::opt::cfold::fold(&rhs.cvars[0].clone(), &[]); + let exp = match &folded.as_value_opt().unwrap() { + v => (**v).as_pf().i().to_u32().unwrap(), + _ => unreachable!(), + }; + let result = if exp == 0 { + let var = Var::new_constant(B::Field::from(1u32), expr.span); + term![Op::PfToInt; var.cvars[0].clone()] + } else { + let mut acc = base_int.clone(); + for _ in 1..exp { + acc = term![Op::IntNaryOp(IntNaryOp::Mul); acc, base_int.clone()]; + } + acc + }; + // convert back to field + let converted = term![Op::IntToPf(B::Field::to_circ_type()); result]; + Var::new_var(converted, expr.span) + } + }; Ok(Some(VarOrRef::Var(res))) } @@ -983,7 +1181,7 @@ impl IRWriter { .compute_expr(fn_env, idx)? .ok_or_else(|| self.error(ErrorKind::CannotComputeExpression, expr.span))?; let idx = idx_var - .constant() + .constant::() .ok_or_else(|| self.error(ErrorKind::ExpectedConstant, expr.span))?; let idx: usize = idx.try_into().unwrap(); @@ -1072,7 +1270,7 @@ impl IRWriter { .compute_expr(fn_env, size)? .ok_or_else(|| self.error(ErrorKind::CannotComputeExpression, expr.span))?; let size = size - .constant() + .constant::() .ok_or_else(|| self.error(ErrorKind::ExpectedConstant, expr.span))?; let size: usize = size.try_into().unwrap(); diff --git a/src/circuit_writer/mod.rs b/src/circuit_writer/mod.rs index 0f15db907..0a09481ae 100644 --- a/src/circuit_writer/mod.rs +++ b/src/circuit_writer/mod.rs @@ -134,7 +134,8 @@ impl CircuitWriter { backend, public_output: None, ir_writer: ir::IRWriter { - typed: typed.clone(), + typed: typed.0.clone(), + logs: vec![], }, } } diff --git a/src/circuit_writer/writer.rs b/src/circuit_writer/writer.rs index f5b8e2c61..10b5f3528 100644 --- a/src/circuit_writer/writer.rs +++ b/src/circuit_writer/writer.rs @@ -664,7 +664,11 @@ impl CircuitWriter { Op2::Inequality => field::not_equal(self, &lhs, &rhs, expr.span), Op2::BoolAnd => boolean::and(self, &lhs[0], &rhs[0], expr.span), Op2::BoolOr => boolean::or(self, &lhs[0], &rhs[0], expr.span), - Op2::Division => todo!(), + Op2::Division => unreachable!("/ is only supported in hint functions"), + Op2::Rem => unreachable!("% is only supported in hint functions"), + Op2::LShift => unreachable!("<< is only supported in hint functions"), + Op2::LessThan => unreachable!("< is only supported in hint functions"), + Op2::Pow => unreachable!("** is only supported in hint functions"), }; Ok(Some(VarOrRef::Var(res))) diff --git a/src/lexer/mod.rs b/src/lexer/mod.rs index 509fd4b07..a25a64e8e 100644 --- a/src/lexer/mod.rs +++ b/src/lexer/mod.rs @@ -142,6 +142,8 @@ pub enum TokenKind { RightCurlyBracket, // } SemiColon, // ; Slash, // / + Percent, // % + LeftDoubleArrow, // << Comment(String), // // comment Greater, // > Less, // < @@ -152,6 +154,7 @@ pub enum TokenKind { Minus, // - RightArrow, // -> Star, // * + DoubleStar, // ** Ampersand, // & DoubleAmpersand, // && Pipe, // | @@ -186,6 +189,8 @@ impl Display for TokenKind { RightCurlyBracket => "`}`", SemiColon => "`;`", Slash => "`/`", + Percent => "`%`", + LeftDoubleArrow => "`<<`", Comment(_) => "`//`", Greater => "`>`", Less => "`<`", @@ -196,6 +201,7 @@ impl Display for TokenKind { Minus => "`-`", RightArrow => "`->`", Star => "`*`", + DoubleStar => "`**`", Ampersand => "`&`", DoubleAmpersand => "`&&`", Pipe => "`|`", @@ -377,6 +383,18 @@ impl Token { tokens.push(TokenKind::Slash.new_token(ctx, 1)); } } + '%' => { + tokens.push(TokenKind::Percent.new_token(ctx, 1)); + } + '<' => { + let next_c = chars.peek(); + if matches!(next_c, Some(&'<')) { + tokens.push(TokenKind::LeftDoubleArrow.new_token(ctx, 2)); + chars.next(); + } else { + tokens.push(TokenKind::Less.new_token(ctx, 1)); + } + } '>' => { tokens.push(TokenKind::Greater.new_token(ctx, 1)); } @@ -422,6 +440,9 @@ impl Token { if matches!(next_c, Some(&'=')) { tokens.push(TokenKind::StarEqual.new_token(ctx, 2)); chars.next(); + } else if matches!(next_c, Some(&'*')) { + tokens.push(TokenKind::DoubleStar.new_token(ctx, 2)); + chars.next(); } else { tokens.push(TokenKind::Star.new_token(ctx, 1)); } diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 4e7120840..1c124a14a 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -269,13 +269,19 @@ impl FnSig { } // const NN: Field _ => { - let cst = observed_arg.constant.clone(); - if is_generic_parameter(sig_arg.name.value.as_str()) && cst.is_some() { - self.generics.assign( - &sig_arg.name.value, - cst.unwrap().as_single(), - observed_arg.expr.span, - )?; + if is_generic_parameter(sig_arg.name.value.as_str()) { + if let Some(cst) = &observed_arg.constant { + self.generics.assign( + &sig_arg.name.value, + cst.as_single(), + observed_arg.expr.span, + )?; + } else { + return Err(error( + ErrorKind::GenericValueExpected(sig_arg.name.value.clone()), + observed_arg.expr.span, + )); + } } } } @@ -358,7 +364,7 @@ impl MastCtx { &mut self, old_qualified: FullyQualified, new_qualified: FullyQualified, - fn_info: FnInfo, + fn_info: &FnInfo, ) { self.tast .add_monomorphized_fn(new_qualified.clone(), fn_info); @@ -450,33 +456,8 @@ impl Mast { self.0.fn_info(qualified) } - // TODO: might want to memoize that at some point - /// Returns the number of field elements contained in the given type. - pub(crate) fn size_of(&self, typ: &TyKind) -> usize { - match typ { - TyKind::Field { .. } => 1, - TyKind::Custom { module, name } => { - let qualified = FullyQualified::new(module, name); - let struct_info = self - .struct_info(&qualified) - .expect("bug in the mast: cannot find struct info"); - - let mut sum = 0; - - for (_, t, _) in &struct_info.fields { - sum += self.size_of(t); - } - - sum - } - TyKind::Array(typ, len) => (*len as usize) * self.size_of(typ), - TyKind::GenericSizedArray(_, _) => { - unreachable!("generic arrays should have been resolved") - } - TyKind::Bool => 1, - TyKind::String(s) => s.len(), - TyKind::Tuple(typs) => typs.iter().map(|ty| self.size_of(ty)).sum(), - } + pub fn size_of(&self, typ: &TyKind) -> usize { + self.0.size_of(typ) } } /// Monomorphize the main function. @@ -526,7 +507,7 @@ pub fn monomorphize(tast: TypeChecker) -> Result> { func_def.body = stmts; main_fn.kind = FnKind::Native(func_def); - ctx.tast.add_monomorphized_fn(qualified, main_fn.clone()); + ctx.tast.add_monomorphized_fn(qualified, &main_fn); ctx.clear_generic_fns(); @@ -616,7 +597,7 @@ fn monomorphize_expr( let mono_qualified = FullyQualified::new(module, &resolved_sig.name.value); // check if this function is already monomorphized - if ctx.functions_instantiated.contains_key(&mono_qualified) { + let res = if ctx.functions_instantiated.contains_key(&mono_qualified) { let mexpr = expr.to_mast( ctx, &ExprKind::FnCall { @@ -635,7 +616,7 @@ fn monomorphize_expr( // retrieve the constant value from the cache let cst = ctx.cst_fn_cache.get(&mono_qualified).cloned(); - ExprMonoInfo::new(mexpr, typ, cst) + (ExprMonoInfo::new(mexpr, typ, cst), fn_info) } else { // monomorphize the function call let (fn_info_mono, typ, cst) = @@ -658,9 +639,28 @@ fn monomorphize_expr( ); let new_qualified = FullyQualified::new(module, &fn_name_mono.value); - ctx.add_monomorphized_fn(old_qualified, new_qualified, fn_info_mono); + ctx.add_monomorphized_fn(old_qualified, new_qualified, &fn_info_mono); - ExprMonoInfo::new(mexpr, typ, cst) + (ExprMonoInfo::new(mexpr, typ, cst), fn_info_mono) + }; + + // check if all observed args are constants + let all_cst_args = observed.iter().all(|f| f.constant.is_some()); + // IR writer to evaluate hints that only require constant inputs + if res.1.is_hint && all_cst_args { + let mut ir = crate::circuit_writer::ir::IRWriter:: { + typed: ctx.tast.clone(), + logs: vec![], + }; + let cst_args = observed.into_iter().map(|f| f.constant.unwrap()).collect(); + let fn_def = res.1.native().unwrap(); + let val = ir.evaluate(fn_def, cst_args)?; + let mut mono_info: ExprMonoInfo = res.0; + + mono_info.constant = Some(val); + mono_info + } else { + res.0 } } @@ -815,10 +815,14 @@ fn monomorphize_expr( let typ = match op { Op2::Equality => Some(TyKind::Bool), Op2::Inequality => Some(TyKind::Bool), + Op2::LessThan => Some(TyKind::Bool), Op2::Addition | Op2::Subtraction | Op2::Multiplication | Op2::Division + | Op2::Rem + | Op2::LShift + | Op2::Pow | Op2::BoolAnd | Op2::BoolOr => lhs_mono.typ, }; diff --git a/src/negative_tests.rs b/src/negative_tests.rs index 50fb49257..66c2d745d 100644 --- a/src/negative_tests.rs +++ b/src/negative_tests.rs @@ -654,7 +654,7 @@ fn test_hint_builtin_fn(qualified: &FullyQualified, code: &str) -> Result // add the mocked builtin function // note that this should happen in the tast phase, instead of mast phase. // currently this function is the only way to mock a builtin function. - tast.add_monomorphized_fn(qualified.clone(), fn_info); + tast.add_monomorphized_fn(qualified.clone(), &fn_info); typecheck_next_file_inner( &mut tast, diff --git a/src/parser/expr.rs b/src/parser/expr.rs index 4cb1ce62e..42cf2709b 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -140,6 +140,10 @@ pub enum Op2 { Subtraction, Multiplication, Division, + Rem, + LShift, + LessThan, + Pow, Equality, Inequality, BoolAnd, @@ -554,7 +558,11 @@ impl Expr { TokenKind::Plus | TokenKind::Minus | TokenKind::Star + | TokenKind::DoubleStar | TokenKind::Slash + | TokenKind::Percent + | TokenKind::LeftDoubleArrow + | TokenKind::Less | TokenKind::DoubleEqual | TokenKind::NotEqual | TokenKind::DoubleAmpersand @@ -568,7 +576,11 @@ impl Expr { TokenKind::Plus => Op2::Addition, TokenKind::Minus => Op2::Subtraction, TokenKind::Star => Op2::Multiplication, + TokenKind::DoubleStar => Op2::Pow, TokenKind::Slash => Op2::Division, + TokenKind::Percent => Op2::Rem, + TokenKind::LeftDoubleArrow => Op2::LShift, + TokenKind::Less => Op2::LessThan, TokenKind::DoubleEqual => Op2::Equality, TokenKind::NotEqual => Op2::Inequality, TokenKind::DoubleAmpersand => Op2::BoolAnd, diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index 45f543e31..3168c6e3c 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -172,10 +172,13 @@ impl TypeChecker { fn_name.span, ) })?; + let is_hint = fn_info.is_hint; let fn_sig = fn_info.sig().clone(); + let all_constants = fn_sig.arguments.iter().all(|arg| arg.is_constant()); - // check if the function is a hint - if fn_info.is_hint && !unsafe_attr { + // check if the function is a hint. + // ignore the unsafe attribute if we are in a hint function. + if !typed_fn_env.is_in_hint_fn() && fn_info.is_hint && !unsafe_attr { return Err(self.error(ErrorKind::ExpectedUnsafeAttribute, expr.span)); } @@ -204,7 +207,16 @@ impl TypeChecker { // type check the function call let method_call = false; - let res = self.check_fn_call(typed_fn_env, method_call, fn_sig, args, expr.span)?; + let mut res = + self.check_fn_call(typed_fn_env, method_call, fn_sig, args, expr.span)?; + + // if it is a hint function and only accept constant arguments, then its return can be assumed to be constants + if is_hint && all_constants { + res = res.map(|ty| match ty { + TyKind::Field { constant: _ } => TyKind::Field { constant: true }, + _ => unimplemented!("only field return type is supported for now"), + }); + } res.map(ExprTyInfo::new_anon) } @@ -412,10 +424,14 @@ impl TypeChecker { let typ = match op { Op2::Equality => TyKind::Bool, Op2::Inequality => TyKind::Bool, + Op2::LessThan => TyKind::Bool, Op2::Addition | Op2::Subtraction | Op2::Multiplication | Op2::Division + | Op2::Rem + | Op2::LShift + | Op2::Pow | Op2::BoolAnd | Op2::BoolOr => lhs_node.typ, }; diff --git a/src/type_checker/fn_env.rs b/src/type_checker/fn_env.rs index 0e0643318..7bbff7e90 100644 --- a/src/type_checker/fn_env.rs +++ b/src/type_checker/fn_env.rs @@ -56,18 +56,22 @@ pub struct TypedFnEnv { /// Determines if forloop variables are allowed to be accessed. forbid_forloop_scope: bool, + /// Indicates if the function is a hint function. + in_hint_fn: bool, + /// The kind of function we're currently type checking current_fn_kind: FuncOrMethod, } impl TypedFnEnv { /// Creates a new TypeEnv with the given function kind - pub fn new(fn_kind: &FuncOrMethod) -> Self { + pub fn new(fn_kind: &FuncOrMethod, is_hint: bool) -> Self { Self { current_scope: 0, vars: HashMap::new(), forloop_scopes: Vec::new(), forbid_forloop_scope: false, + in_hint_fn: is_hint, current_fn_kind: fn_kind.clone(), } } @@ -114,6 +118,11 @@ impl TypedFnEnv { self.forloop_scopes.pop(); } + /// Returns whether it is in a hint function. + pub fn is_in_hint_fn(&self) -> bool { + self.in_hint_fn + } + /// Returns true if a scope is a prefix of our scope. pub fn is_in_scope(&self, prefix_scope: usize) -> bool { self.current_scope >= prefix_scope diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index d5ea9f0ad..9c54bc4ad 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -124,6 +124,34 @@ impl TypeChecker { self.constants.get(&qualified) } + /// Returns the number of field elements contained in the given type. + pub(crate) fn size_of(&self, typ: &TyKind) -> usize { + match typ { + TyKind::Field { .. } => 1, + TyKind::Custom { module, name } => { + let qualified = FullyQualified::new(module, name); + let struct_info = self + .struct_info(&qualified) + .expect("bug in the mast: cannot find struct info"); + + let mut sum = 0; + + for (_, t, _) in &struct_info.fields { + sum += self.size_of(t); + } + + sum + } + TyKind::Array(typ, len) => (*len as usize) * self.size_of(typ), + TyKind::GenericSizedArray(_, _) => { + unreachable!("generic arrays should have been resolved") + } + TyKind::Bool => 1, + TyKind::String(s) => s.len(), + TyKind::Tuple(typs) => typs.iter().map(|ty| self.size_of(ty)).sum(), + } + } + pub fn last_node_id(&self) -> usize { self.node_id } @@ -132,8 +160,8 @@ impl TypeChecker { self.node_id = node_id; } - pub fn add_monomorphized_fn(&mut self, qualified: FullyQualified, fn_info: FnInfo) { - self.functions.insert(qualified, fn_info); + pub fn add_monomorphized_fn(&mut self, qualified: FullyQualified, fn_info: &FnInfo) { + self.functions.insert(qualified, fn_info.clone()); } pub fn add_monomorphized_type(&mut self, node_id: usize, typ: TyKind) { @@ -329,7 +357,7 @@ impl TypeChecker { // `fn main() { ... }` RootKind::FunctionDef(function) => { // create a new typed fn environment to type check the function - let mut typed_fn_env = TypedFnEnv::new(&function.sig.kind); + let mut typed_fn_env = TypedFnEnv::new(&function.sig.kind, function.is_hint); // if we're expecting a library, this should not be the main function let is_main = function.is_main(); diff --git a/src/var.rs b/src/var.rs index b508d69cd..eb676d476 100644 --- a/src/var.rs +++ b/src/var.rs @@ -71,7 +71,11 @@ where PublicOutput(Option), /// Resulted IR term and the (name, variable) for arugments of the hint function - HintIR(Term, Vec<(String, ConstOrCell)>), + HintIR( + Term, + Vec<(String, ConstOrCell)>, + Vec, + ), } impl std::fmt::Debug for Value { @@ -88,7 +92,7 @@ impl std::fmt::Debug for Value { Value::PublicOutput(..) => write!(f, "PublicOutput"), Value::Scale(..) => write!(f, "Scaling"), Value::NthBit(_, _) => write!(f, "NthBit"), - Value::HintIR(..) => write!(f, "HintIR"), + Value::HintIR(t, args, logs) => write!(f, "HintIR: {}, {:?}, {:?}", t, args, logs), } } }