diff --git a/examples/fixture/asm/kimchi/generic_assert_eq.asm b/examples/fixture/asm/kimchi/generic_assert_eq.asm new file mode 100644 index 000000000..00a1e9781 --- /dev/null +++ b/examples/fixture/asm/kimchi/generic_assert_eq.asm @@ -0,0 +1,33 @@ +@ noname.0.7.0 +@ public inputs: 3 + +DoubleGeneric<1> +DoubleGeneric<1> +DoubleGeneric<1> +DoubleGeneric<1,0,0,0,-3> +DoubleGeneric<1,0,0,0,-3> +DoubleGeneric<1,0,0,0,-1> +DoubleGeneric<1,0,0,0,-2> +DoubleGeneric<1,0,-1,0,2> +DoubleGeneric<1,0,0,0,-5> +DoubleGeneric<1,0,0,0,-1> +DoubleGeneric<1,0,0,0,-2> +DoubleGeneric<1,0,-1,0,2> +DoubleGeneric<0,0,-1,1> +DoubleGeneric<1,1,-1> +DoubleGeneric<1,1,-1> +DoubleGeneric<1,0,0,0,-5> +DoubleGeneric<1,0,0,0,-1> +DoubleGeneric<1,0,0,0,-1> +DoubleGeneric<1,0,0,0,-2> +DoubleGeneric<1,0,0,0,-3> +DoubleGeneric<1,0,0,0,-4> +DoubleGeneric<1,0,0,0,-5> +(0,0) -> (5,0) -> (9,0) -> (12,1) -> (13,0) -> (16,0) -> (17,0) +(1,0) -> (6,0) -> (10,0) -> (14,0) -> (18,0) +(2,0) -> (3,0) -> (4,0) -> (7,0) -> (11,0) -> (12,0) -> (14,1) -> (19,0) +(7,2) -> (8,0) +(11,2) -> (15,0) +(12,2) -> (13,1) +(13,2) -> (20,0) +(14,2) -> (21,0) diff --git a/examples/fixture/asm/r1cs/generic_assert_eq.asm b/examples/fixture/asm/r1cs/generic_assert_eq.asm new file mode 100644 index 000000000..a58c2bdb4 --- /dev/null +++ b/examples/fixture/asm/r1cs/generic_assert_eq.asm @@ -0,0 +1,18 @@ +@ noname.0.7.0 +@ public inputs: 3 + +3 == (v_3) * (1) +3 == (v_3) * (1) +1 == (v_1) * (1) +2 == (v_2) * (1) +5 == (v_3 + 2) * (1) +1 == (v_1) * (1) +2 == (v_2) * (1) +v_4 == (v_3) * (v_1) +5 == (v_3 + 2) * (1) +1 == (v_1) * (1) +1 == (v_1) * (1) +2 == (v_2) * (1) +3 == (v_3) * (1) +4 == (v_1 + v_4) * (1) +5 == (v_2 + v_3) * (1) diff --git a/examples/functions.no b/examples/functions.no index d69bf196a..2a66c2cc0 100644 --- a/examples/functions.no +++ b/examples/functions.no @@ -10,6 +10,8 @@ fn main(pub one: Field) { let four = add(one, 3); assert_eq(four, 4); + // double() should not be folded to return 8 + // the asm test will catch the missing constraint if it is folded let eight = double(4); assert_eq(eight, double(four)); } diff --git a/examples/generic_assert_eq.no b/examples/generic_assert_eq.no new file mode 100644 index 000000000..b1ac68e4c --- /dev/null +++ b/examples/generic_assert_eq.no @@ -0,0 +1,64 @@ +const size = 2; +struct Thing { + xx: Field, + yy: [Field; 2], +} + +struct Nestedthing { + xx: Field, + another: [Another; 2], +} + +struct Another { + aa: Field, + bb: [Field; 2], +} + +fn init_arr(element: Field, const LEN: Field) -> [Field; LEN] { + let arr = [element; LEN]; + return arr; +} + +fn main(pub public_arr: [Field; 2], pub public_input: Field) { + let generic_arr = init_arr(public_input, size); + let arr = [3, 3]; + + assert_eq(generic_arr, arr); + let mut concrete_arr = [1, 2]; + + // instead of the following: + // assert_eq(public_arr[0], concrete_arr[0]); + // assert_eq(public_arr[1], concrete_arr[1]); + // we can write: + assert_eq(public_arr, concrete_arr); + + let thing = Thing { xx: 5, yy: [1, 2] }; + let other_thing = Thing { xx: generic_arr[0] + 2, yy: public_arr }; + + // instead of the following: + // assert_eq(thing.xx, other_thing.xx); + // assert_eq(thing.yy[0], other_thing.yy[0]); + // assert_eq(thing.yy[1], other_thing.yy[1]); + // we can write: + assert_eq(thing, other_thing); + + let nested_thing = Nestedthing { xx: 5, another: [ + Another { aa: public_arr[0], bb: [1, 2] }, + Another { aa: generic_arr[1], bb: [4, 5] } + ] }; + let other_nested_thing = Nestedthing { xx: generic_arr[0] + 2, another: [ + Another { aa: 1, bb: public_arr }, + Another { aa: 3, bb: [public_arr[0] + (public_input * public_arr[0]), public_arr[1] + public_input] } + ] }; + + // instead of the following: + // assert_eq(nested_thing.xx, other_nested_thing.xx); + // assert_eq(nested_thing.another[0].aa, other_nested_thing.another[0].aa); + // assert_eq(nested_thing.another[0].bb[0], other_nested_thing.another[0].bb[0]); + // assert_eq(nested_thing.another[0].bb[1], other_nested_thing.another[0].bb[1]); + // assert_eq(nested_thing.another[1].aa, other_nested_thing.another[1].aa); + // assert_eq(nested_thing.another[1].bb[0], other_nested_thing.another[1].bb[0]); + // assert_eq(nested_thing.another[1].bb[1], other_nested_thing.another[1].bb[1]); + // we can write: + assert_eq(nested_thing, other_nested_thing); +} diff --git a/examples/log.no b/examples/log.no new file mode 100644 index 000000000..c9072c268 --- /dev/null +++ b/examples/log.no @@ -0,0 +1,23 @@ + +struct Thing { + xx: Field, + yy: Field +} + +fn main(pub public_input: Field) -> Field { + + log(1234); + log(true); + + let arr = [1,2,3]; + log(arr); + + let thing = Thing { xx : public_input , yy: public_input + 1}; + + log(thing); + + let tup = (1 , true , thing); + log("formatted string with a number {} boolean {} arr {} tuple {} struct {}" , 1234 , true, arr, tup, thing); + + return public_input + 1; +} \ No newline at end of file diff --git a/examples/tuple.no b/examples/tuple.no new file mode 100644 index 000000000..867c91998 --- /dev/null +++ b/examples/tuple.no @@ -0,0 +1,47 @@ +struct Thing { + xx: Field, + tuple_field: (Field,Bool) +} + +// return tuples from functions +fn Thing.new(xx: Field , tup: (Field,Bool)) -> (Thing , (Field,Bool)) { + return ( + Thing { + xx: xx, + tuple_field:tup + }, + tup + ); +} + +fn generic_array_tuple_test(var : ([[Field;NN];LEN],Bool)) -> (Field , [Field;NN]) { + let zero = 0; + let result = if var[1] {var[0][LEN - 1][NN - 1]} else { var[0][LEN - 2][NN - 2] }; + return (result , var[0][LEN - 1]); +} + +// xx should be 0 +fn main(pub xx: [Field; 2]) -> Field { + // creation of new tuple with different types + let tup = (1, true); + + // create nested tuples + let nested_tup = ((false, [1,2,3]), 1); + log(nested_tup); // (1, (true , [1,2,3])) + + let incr = nested_tup[1]; // 1 + + // tuples can be input to function + let mut thing = Thing.new(xx[1] , (xx[0] , xx[0] == 0)); + + // you can access a tuple type just like you access a array + thing[0].tuple_field[0] += incr; + log(thing[0].tuple_field[0]); + let new_allocation = [xx,xx]; + let ret = generic_array_tuple_test((new_allocation, true)); + + assert_eq(thing[0].tuple_field[0] , 1); + log(ret[1]); // logs xx i.e [0,123] + + return ret[0]; +} \ No newline at end of file diff --git a/src/backends/kimchi/mod.rs b/src/backends/kimchi/mod.rs index 9605eff39..b448b4949 100644 --- a/src/backends/kimchi/mod.rs +++ b/src/backends/kimchi/mod.rs @@ -22,7 +22,7 @@ use crate::{ backends::kimchi::asm::parse_coeffs, circuit_writer::{ writer::{AnnotatedCell, Cell, PendingGate}, - DebugInfo, Gate, GateKind, Wiring, + DebugInfo, Gate, GateKind, VarInfo, Wiring, }, compiler::Sources, constants::Span, @@ -128,6 +128,9 @@ pub struct KimchiVesta { /// Indexes used by the private inputs /// (this is useful to check that they appear in the circuit) pub(crate) private_input_cell_vars: Vec, + + /// Log information + pub(crate) log_info: Vec<(Span, VarInfo)>, } impl Witness { @@ -174,6 +177,7 @@ impl KimchiVesta { finalized: false, public_input_size: 0, private_input_cell_vars: vec![], + log_info: vec![], } } @@ -428,11 +432,11 @@ impl Backend for KimchiVesta { self.compute_val(env, &val.0, var.index) } - fn generate_witness( + fn generate_witness( &self, witness_env: &mut WitnessEnv, sources: &Sources, - _typed: &Mast, + typed: &Mast, ) -> Result { if !self.finalized { unreachable!("the circuit must be finalized before generating a witness"); @@ -481,6 +485,7 @@ impl Backend for KimchiVesta { } public_outputs.push(val); } + self.print_log(witness_env, &self.log_info, sources, typed)?; // sanity check the witness for (row, (gate, witness_row, debug_info)) in @@ -809,9 +814,8 @@ impl Backend for KimchiVesta { fn log_var( &mut self, var: &crate::circuit_writer::VarInfo, - msg: String, span: Span, ) { - println!("todo: implement log_var for kimchi backend"); + self.log_info.push((span, var.clone())); } } diff --git a/src/backends/mod.rs b/src/backends/mod.rs index 100e82c49..14d5df27e 100644 --- a/src/backends/mod.rs +++ b/src/backends/mod.rs @@ -12,6 +12,8 @@ use crate::{ error::{Error, ErrorKind, Result}, helpers::PrettyField, imports::FnHandle, + parser::types::TyKind, + utils::{log_array_or_tuple_type, log_custom_type, log_string_type}, var::{ConstOrCell, Value, Var}, witness::WitnessEnv, }; @@ -404,15 +406,133 @@ pub trait Backend: Clone { ) -> Result<()>; /// Generate the witness for a backend. - fn generate_witness( + fn generate_witness( &self, witness_env: &mut WitnessEnv, sources: &Sources, - typed: &Mast, + typed: &Mast, ) -> Result; /// Generate the asm for a backend. fn generate_asm(&self, sources: &Sources, debug: bool) -> String; - fn log_var(&mut self, var: &VarInfo, msg: String, span: Span); + fn log_var(&mut self, var: &VarInfo, span: Span); + + /// print the log given the log_info + fn print_log( + &self, + witness_env: &mut WitnessEnv, + logs: &[(Span, VarInfo)], + sources: &Sources, + typed: &Mast, + ) -> Result<()> { + let mut logs_iter = logs.into_iter(); + while let Some((span, var_info)) = logs_iter.next() { + let (filename, source) = sources.get(&span.filename_id).unwrap(); + let (line, _, _) = crate::utils::find_exact_line(source, *span); + let dbg_msg = format!("[{filename}:{line}] -> "); + + match &var_info.typ { + // Field + Some(TyKind::Field { .. }) => match &var_info.var[0] { + ConstOrCell::Const(cst) => { + println!("{dbg_msg}{}", cst.pretty()); + } + ConstOrCell::Cell(cell) => { + let val = self.compute_var(witness_env, cell)?; + println!("{dbg_msg}{}", val.pretty()); + } + }, + + // Bool + Some(TyKind::Bool) => match &var_info.var[0] { + ConstOrCell::Const(cst) => { + let val = *cst == Self::Field::one(); + println!("{dbg_msg}{}", val); + } + ConstOrCell::Cell(cell) => { + let val = self.compute_var(witness_env, cell)? == Self::Field::one(); + println!("{dbg_msg}{}", val); + } + }, + + // Array + Some(TyKind::Array(b, s)) => { + let mut typs = Vec::with_capacity(*s as usize); + for _ in 0..(*s) { + typs.push((**b).clone()); + } + let (output, remaining) = log_array_or_tuple_type( + self, + &var_info.var.cvars, + &typs, + *s, + witness_env, + typed, + span, + false, + )?; + assert!(remaining.is_empty()); + println!("{dbg_msg}{}", output); + } + + // Custom types + Some(TyKind::Custom { + module, + name: struct_name, + }) => { + let mut string_vec = Vec::new(); + let (output, remaining) = log_custom_type( + self, + module, + struct_name, + typed, + &var_info.var.cvars, + witness_env, + span, + &mut string_vec, + )?; + assert!(remaining.is_empty()); + println!("{dbg_msg}{}{}", struct_name, output); + } + + // GenericSizedArray + Some(TyKind::GenericSizedArray(_, _)) => { + unreachable!("GenericSizedArray should be monomorphized") + } + + Some(TyKind::String(s)) => { + let output = + log_string_type(self, &mut logs_iter, s, witness_env, typed, span)?; + println!("{dbg_msg}{}", output); + } + + Some(TyKind::Tuple(typs)) => { + let len = typs.len(); + let (output, remaining) = log_array_or_tuple_type( + self, + &var_info.var.cvars, + &typs, + len as u32, + witness_env, + typed, + span, + true, + ) + .unwrap(); + assert!(remaining.is_empty()); + println!("{dbg_msg}{}", output); + } + None => { + return Err(Error::new( + "log", + ErrorKind::UnexpectedError("No type info for logging"), + *span, + )) + } + } + } + + Ok(()) + } } diff --git a/src/backends/r1cs/mod.rs b/src/backends/r1cs/mod.rs index 0fa9952d5..5a2461f1e 100644 --- a/src/backends/r1cs/mod.rs +++ b/src/backends/r1cs/mod.rs @@ -1,16 +1,16 @@ pub mod arkworks; pub mod builtin; pub mod snarkjs; - -use std::collections::{HashMap, HashSet}; - +use crate::helpers::PrettyField; use circ::cfg::{CircCfg, CircOpt}; use circ_fields::FieldV; use itertools::{izip, Itertools as _}; use kimchi::o1_utils::FieldHelpers; use num_bigint::BigUint; +use num_traits::One; use rug::Integer; use serde::{Deserialize, Serialize}; +use std::collections::{HashMap, HashSet}; use crate::circuit_writer::VarInfo; use crate::compiler::Sources; @@ -271,7 +271,7 @@ where /// Debug information for each constraint. debug_info: Vec, /// Debug information for var info. - log_info: Vec<(String, Span, VarInfo>)>, + log_info: Vec<(Span, VarInfo>)>, /// Record the public inputs for reordering the witness vector public_inputs: Vec, /// Record the private inputs for checking @@ -493,11 +493,11 @@ where /// Generate the witnesses /// This process should check if the constraints are satisfied. - fn generate_witness( + fn generate_witness( &self, witness_env: &mut crate::witness::WitnessEnv, sources: &Sources, - typed: &Mast, + typed: &Mast, ) -> crate::error::Result { assert!(self.finalized, "the circuit is not finalized yet!"); @@ -523,78 +523,7 @@ where witness[var.index] = val; } - // print out the log info - for (_, span, var_info) in &self.log_info { - let (filename, source) = sources.get(&span.filename_id).unwrap(); - let (line, _, line_str) = crate::utils::find_exact_line(source, *span); - let line_str = line_str.trim_start(); - let dbg_msg = format!("[{filename}:{line}] `{line_str}` -> "); - - match &var_info.typ { - // Field - Some(TyKind::Field { .. }) => match &var_info.var[0] { - ConstOrCell::Const(cst) => { - println!("{dbg_msg}{}", cst.pretty()); - } - ConstOrCell::Cell(cell) => { - let val = cell.evaluate(&witness); - println!("{dbg_msg}{}", val.pretty()); - } - }, - - // Bool - Some(TyKind::Bool) => match &var_info.var[0] { - ConstOrCell::Const(cst) => { - let val = *cst == F::one(); - println!("{dbg_msg}{}", val); - } - ConstOrCell::Cell(cell) => { - let val = cell.evaluate(&witness) == F::one(); - println!("{dbg_msg}{}", val); - } - }, - - // Array - Some(TyKind::Array(b, s)) => { - let (output, remaining) = - log_array_type(&var_info.var.cvars, b, *s, &witness, typed, span); - assert!(remaining.is_empty()); - println!("{dbg_msg}{}", output); - } - - // Custom types - Some(TyKind::Custom { - module, - name: struct_name, - }) => { - let mut string_vec = Vec::new(); - let (output, remaining) = log_custom_type( - module, - struct_name, - typed, - &var_info.var.cvars, - &witness, - span, - &mut string_vec, - ); - assert!(remaining.is_empty()); - println!("{dbg_msg}{}{}", struct_name, output); - } - - // GenericSizedArray - Some(TyKind::GenericSizedArray(_, _)) => { - unreachable!("GenericSizedArray should be monomorphized") - } - - None => { - return Err(Error::new( - "log", - ErrorKind::UnexpectedError("No type info for logging"), - *span, - )) - } - } - } + self.print_log(witness_env, &self.log_info, sources, typed)?; for (index, (constraint, debug_info)) in izip!(&self.constraints, &self.debug_info).enumerate() @@ -762,175 +691,8 @@ where var } - fn log_var(&mut self, var: &VarInfo, msg: String, span: Span) { - self.log_info.push((msg, span, var.clone())); - } -} - -fn log_custom_type( - module: &ModulePath, - struct_name: &String, - typed: &Mast, - var_info_var: &[ConstOrCell>], - witness: &[F], - span: &Span, - string_vec: &mut Vec, -) -> (String, Vec>>) { - let qualified = FullyQualified::new(module, struct_name); - let struct_info = typed - .struct_info(&qualified) - .ok_or( - typed - .0 - .error(ErrorKind::UnexpectedError("struct not found"), *span), - ) - .unwrap(); - - let mut remaining = var_info_var.to_vec(); - - for (field_name, field_typ) in &struct_info.fields { - let len = typed.size_of(field_typ); - match field_typ { - TyKind::Field { .. } => match &remaining[0] { - ConstOrCell::Const(cst) => { - string_vec.push(format!("{field_name}: {}", cst.pretty())); - remaining = remaining[len..].to_vec(); - } - ConstOrCell::Cell(cell) => { - let val = cell.evaluate(witness); - string_vec.push(format!("{field_name}: {}", val.pretty())); - remaining = remaining[len..].to_vec(); - } - }, - - TyKind::Bool => match &remaining[0] { - ConstOrCell::Const(cst) => { - let val = *cst == F::one(); - string_vec.push(format!("{field_name}: {}", val)); - remaining = remaining[len..].to_vec(); - } - ConstOrCell::Cell(cell) => { - let val = cell.evaluate(witness) == F::one(); - string_vec.push(format!("{field_name}: {}", val)); - remaining = remaining[len..].to_vec(); - } - }, - - TyKind::Array(b, s) => { - let (output, new_remaining) = - log_array_type(&remaining, b, *s, witness, typed, span); - string_vec.push(format!("{field_name}: {}", output)); - remaining = new_remaining; - } - - TyKind::Custom { - module, - name: struct_name, - } => { - let mut custom_string_vec = Vec::new(); - let (output, new_remaining) = log_custom_type( - module, - struct_name, - typed, - &remaining, - witness, - span, - &mut custom_string_vec, - ); - string_vec.push(format!("{}: {}{}", field_name, struct_name, output)); - remaining = new_remaining; - } - - TyKind::GenericSizedArray(_, _) => { - unreachable!("GenericSizedArray should be monomorphized") - } - } - } - - (format!("{{ {} }}", string_vec.join(", ")), remaining) -} - -fn log_array_type( - var_info_var: &[ConstOrCell>], - base_type: &TyKind, - size: u32, - witness: &[F], - typed: &Mast, - span: &Span, -) -> (String, Vec>>) { - match base_type { - TyKind::Field { .. } => { - let values: Vec = var_info_var - .iter() - .take(size as usize) - .map(|cvar| match cvar { - ConstOrCell::Const(cst) => cst.pretty(), - ConstOrCell::Cell(cell) => cell.evaluate(witness).pretty(), - }) - .collect(); - - let remaining = var_info_var[size as usize..].to_vec(); - (format!("[{}]", values.join(", ")), remaining) - } - - TyKind::Bool => { - let values: Vec = var_info_var - .iter() - .take(size as usize) - .map(|cvar| match cvar { - ConstOrCell::Const(cst) => { - let val = *cst == F::one(); - val.to_string() - } - ConstOrCell::Cell(cell) => { - let val = cell.evaluate(witness) == F::one(); - val.to_string() - } - }) - .collect(); - - let remaining = var_info_var[size as usize..].to_vec(); - (format!("[{}]", values.join(", ")), remaining) - } - - TyKind::Array(inner_type, inner_size) => { - let mut nested_result = Vec::new(); - let mut remaining = var_info_var.to_vec(); - for _ in 0..size { - let (chunk_result, new_remaining) = - log_array_type(&remaining, inner_type, *inner_size, witness, typed, span); - nested_result.push(chunk_result); - remaining = new_remaining; - } - (format!("[{}]", nested_result.join(", ")), remaining) - } - - TyKind::Custom { - module, - name: struct_name, - } => { - let mut nested_result = Vec::new(); - let mut remaining = var_info_var.to_vec(); - for _ in 0..size { - let mut string_vec = Vec::new(); - let (output, new_remaining) = log_custom_type( - module, - struct_name, - typed, - &remaining, - witness, - span, - &mut string_vec, - ); - nested_result.push(format!("{}{}", struct_name, output)); - remaining = new_remaining; - } - (format!("[{}]", nested_result.join(", ")), remaining) - } - - TyKind::GenericSizedArray(_, _) => { - unreachable!("GenericSizedArray should be monomorphized") - } + fn log_var(&mut self, var: &VarInfo, span: Span) { + self.log_info.push((span, var.clone())); } } diff --git a/src/circuit_writer/ir.rs b/src/circuit_writer/ir.rs index c0d836cec..b079e0f2d 100644 --- a/src/circuit_writer/ir.rs +++ b/src/circuit_writer/ir.rs @@ -1137,6 +1137,17 @@ impl IRWriter { Ok(Some(VarOrRef::Var(v))) } + ExprKind::StringLiteral(s) => { + // chars as field elements from asci;; + let fr: Vec = s.chars().map(|char| B::Field::from(char as u8)).collect(); + let cvars = fr + .iter() + .map(|&f| leaf_term(Op::new_const(Value::Field(f.to_circ_field())))) + .collect(); + + Ok(Some(VarOrRef::Var(Var::new(cvars, expr.span)))) + } + ExprKind::Variable { module, name } => { // if it's a type we return nothing // (most likely what follows is a static method call) @@ -1159,11 +1170,11 @@ impl IRWriter { Ok(Some(res)) } - ExprKind::ArrayAccess { array, idx } => { - // retrieve var of array + ExprKind::ArrayOrTupleAccess { container, idx } => { + // retrieve var of container let var = self - .compute_expr(fn_env, array)? - .expect("array access on non-array"); + .compute_expr(fn_env, container)? + .expect("container access on non-container"); // compute the index let idx_var = self @@ -1174,10 +1185,15 @@ impl IRWriter { .ok_or_else(|| self.error(ErrorKind::ExpectedConstant, expr.span))?; let idx: usize = idx.try_into().unwrap(); - // retrieve the type of the elements in the array - let array_typ = self.expr_type(array).expect("cannot find type of array"); + // retrieve the type of the elements in the container + let container_typ = self + .expr_type(container) + .expect("cannot find type of container"); - let elem_type = match array_typ { + // actual starting index for narrowing the var depends on the cotainer + // for arrays it is just idx * elem_size as all elements are of same size + // while for tuples we have to sum the sizes of all types up to that index + let (start, len) = match container_typ { TyKind::Array(ty, array_len) => { if idx >= (*array_len as usize) { return Err(self.error( @@ -1185,18 +1201,25 @@ impl IRWriter { expr.span, )); } - ty + let len = self.size_of(ty); + let start = idx * self.size_of(ty); + (start, len) + } + + TyKind::Tuple(typs) => { + let mut starting_idx = 0; + for i in 0..idx { + starting_idx += self.size_of(&typs[i]); + } + (starting_idx, self.size_of(&typs[idx])) } _ => Err(Error::new( "compute-expr", - ErrorKind::UnexpectedError("expected array"), + ErrorKind::UnexpectedError("expected container"), expr.span, ))?, }; - // compute the size of each element in the array - let len = self.size_of(elem_type); - // compute the real index let start = idx * len; @@ -1261,6 +1284,20 @@ impl IRWriter { let var = VarOrRef::Var(Var::new(cvars, expr.span)); Ok(Some(var)) } + + ExprKind::TupleDeclaration(items) => { + let mut cvars = vec![]; + + for item in items { + let var = self.compute_expr(fn_env, item)?.unwrap(); + let to_extend = var.value(self, fn_env).cvars.clone(); + cvars.extend(to_extend); + } + + let var = VarOrRef::Var(Var::new(cvars, expr.span)); + + Ok(Some(var)) + } } } diff --git a/src/circuit_writer/writer.rs b/src/circuit_writer/writer.rs index 21ccde7a7..a79aa60e3 100644 --- a/src/circuit_writer/writer.rs +++ b/src/circuit_writer/writer.rs @@ -331,6 +331,16 @@ impl CircuitWriter { TyKind::GenericSizedArray(_, _) => { unreachable!("generic array should have been resolved") } + TyKind::String(_) => todo!("String type is not supported for constraints"), + TyKind::Tuple(types) => { + let mut offset = 0; + for ty in types { + let size = self.size_of(ty); + let slice = &input[offset..(offset + size)]; + self.constrain_inputs_to_main(slice, input_typ, span)?; + offset += size; + } + } }; Ok(()) } @@ -700,6 +710,13 @@ impl CircuitWriter { Ok(Some(res)) } + ExprKind::StringLiteral(s) => { + let chars_in_ff: Vec = + s.chars().map(|char| B::Field::from(char as u8)).collect(); + let cvars = chars_in_ff.iter().map(|&f| ConstOrCell::Const(f)).collect(); + Ok(Some(VarOrRef::Var(Var::new(cvars, expr.span)))) + } + ExprKind::Variable { module, name } => { // if it's a type we return nothing // (most likely what follows is a static method call) @@ -722,11 +739,11 @@ impl CircuitWriter { Ok(Some(res)) } - ExprKind::ArrayAccess { array, idx } => { - // retrieve var of array + ExprKind::ArrayOrTupleAccess { container, idx } => { + // retrieve var of container let var = self - .compute_expr(fn_env, array)? - .expect("array access on non-array"); + .compute_expr(fn_env, container)? + .expect("container access on non-container"); // compute the index let idx_var = self @@ -738,10 +755,15 @@ impl CircuitWriter { let idx: BigUint = idx.into(); let idx: usize = idx.try_into().unwrap(); - // retrieve the type of the elements in the array - let array_typ = self.expr_type(array).expect("cannot find type of array"); + // retrieve the type of the elements in the container + let container_typ = self + .expr_type(container) + .expect("cannot find type of container"); - let elem_type = match array_typ { + // actual starting index for narrowing the var depends on the cotainer + // for arrays it is just idx * elem_size as all elements are of same size + // while for tuples we have to sum the sizes of all types up to that index + let (start, len) = match container_typ { TyKind::Array(ty, array_len) => { if idx >= (*array_len as usize) { return Err(self.error( @@ -749,21 +771,25 @@ impl CircuitWriter { expr.span, )); } - ty + let len = self.size_of(ty); + let start = idx * self.size_of(ty); + (start, len) + } + + TyKind::Tuple(typs) => { + let mut start = 0; + for i in 0..idx { + start += self.size_of(&typs[i]); + } + (start, self.size_of(&typs[idx])) } _ => Err(Error::new( "compute-expr", - ErrorKind::UnexpectedError("expected array"), + ErrorKind::UnexpectedError("expected container"), expr.span, ))?, }; - // compute the size of each element in the array - let len = self.size_of(elem_type); - - // compute the real index - let start = idx * len; - // out-of-bound checks if start >= var.len() || start + len > var.len() { return Err(self.error( @@ -826,6 +852,21 @@ impl CircuitWriter { let var = VarOrRef::Var(Var::new(cvars, expr.span)); Ok(Some(var)) } + // exact copy of Array Declaration there is nothing really different at when looking it from a expression level + // as both of them are just `Vec` + ExprKind::TupleDeclaration(items) => { + let mut cvars = vec![]; + + for item in items { + let var = self.compute_expr(fn_env, item)?.unwrap(); + let to_extend = var.value(self, fn_env).cvars.clone(); + cvars.extend(to_extend); + } + + let var = VarOrRef::Var(Var::new(cvars, expr.span)); + + Ok(Some(var)) + } } } diff --git a/src/error.rs b/src/error.rs index a4542e4a2..9d086fc2e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,4 @@ +use clap::error; use miette::Diagnostic; use thiserror::Error; @@ -44,6 +45,8 @@ pub enum ErrorKind { AssignmentToImmutableVariable, #[error("the {0} of assert_eq must be of type Field or BigInt. It was of type {1}")] AssertTypeMismatch(&'static str, TyKind), + #[error("the types in assert_eq don't match: expected {0} but got {1}")] + AssertEqTypeMismatch(TyKind, TyKind), #[error( "the dependency `{0}` does not appear to be listed in your manifest file `Noname.toml`" )] @@ -265,6 +268,9 @@ pub enum ErrorKind { #[error("array accessed at index {0} is out of bounds (max allowed index is {1})")] ArrayIndexOutOfBounds(usize, usize), + #[error("tuple accessed at index {0} is out of bounds (max allowed index is {1})")] + TupleIndexOutofBounds(usize, usize), + #[error( "one-letter variables or types are not allowed. Best practice is to use descriptive names" )] @@ -324,8 +330,8 @@ pub enum ErrorKind { #[error("field access can only be applied on custom structs")] FieldAccessOnNonCustomStruct, - #[error("array access can only be performed on arrays")] - ArrayAccessOnNonArray, + #[error("array like access can only be performed on arrays or tuples")] + AccessOnNonCollection, #[error("struct `{0}` does not exist (are you sure it is defined?)")] UndefinedStruct(String), @@ -365,4 +371,10 @@ pub enum ErrorKind { #[error("division by zero")] DivisionByZero, + + #[error("lhs `{0}` is less than rhs `{1}`")] + NegativeLhsLessThanRhs(String, String), + + #[error("Not enough variables provided to fill placeholders in the formatted string")] + InsufficientVariables, } diff --git a/src/inputs.rs b/src/inputs.rs index d4dbaee5f..9994fa44f 100644 --- a/src/inputs.rs +++ b/src/inputs.rs @@ -151,6 +151,22 @@ impl CompiledCircuit { Ok(res) } + // parsing for tuple function inputs from json + (TyKind::Tuple(types), Value::Array(values)) => { + if values.len() != types.len() { + Err(ParsingError::ArraySizeMismatch( + values.len(), + types.len() as usize, + ))? + } + // making a vec with capacity allows for less number of reallocations + let mut res = Vec::with_capacity(values.len()); + for (ty, val) in types.iter().zip(values) { + let el = self.parse_single_input(val, ty)?; + res.extend(el); + } + Ok(res) + } (expected, observed) => { return Err(ParsingError::MismatchJsonArgument( expected.clone(), diff --git a/src/lexer/mod.rs b/src/lexer/mod.rs index 765fb5864..a25a64e8e 100644 --- a/src/lexer/mod.rs +++ b/src/lexer/mod.rs @@ -126,45 +126,45 @@ impl Display for Keyword { #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub enum TokenKind { - Keyword(Keyword), // reserved keywords - Identifier(String), // [a-zA-Z](A-Za-z0-9_)* - BigUInt(BigUint), // (0-9)* - Dot, // . - DoubleDot, // .. - Comma, // , - Colon, // : - DoubleColon, // :: - LeftParen, // ( - RightParen, // ) - LeftBracket, // [ - RightBracket, // ] - LeftCurlyBracket, // { - RightCurlyBracket, // } - SemiColon, // ; - Slash, // / - Percent, // % - LeftDoubleArrow, // << - Comment(String), // // comment - Greater, // > - Less, // < - Equal, // = - DoubleEqual, // == - NotEqual, // != - Plus, // + - Minus, // - - RightArrow, // -> - Star, // * - DoubleStar, // ** - Ampersand, // & - DoubleAmpersand, // && - Pipe, // | - DoublePipe, // || - Exclamation, // ! - Question, // ? - PlusEqual, // += - MinusEqual, // -= - StarEqual, // *= - // Literal, // "thing" + Keyword(Keyword), // reserved keywords + Identifier(String), // [a-zA-Z](A-Za-z0-9_)* + BigUInt(BigUint), // (0-9)* + Dot, // . + DoubleDot, // .. + Comma, // , + Colon, // : + DoubleColon, // :: + LeftParen, // ( + RightParen, // ) + LeftBracket, // [ + RightBracket, // ] + LeftCurlyBracket, // { + RightCurlyBracket, // } + SemiColon, // ; + Slash, // / + Percent, // % + LeftDoubleArrow, // << + Comment(String), // // comment + Greater, // > + Less, // < + Equal, // = + DoubleEqual, // == + NotEqual, // != + Plus, // + + Minus, // - + RightArrow, // -> + Star, // * + DoubleStar, // ** + Ampersand, // & + DoubleAmpersand, // && + Pipe, // | + DoublePipe, // || + Exclamation, // ! + Question, // ? + PlusEqual, // += + MinusEqual, // -= + StarEqual, // *= + StringLiteral(String), // "thing" } impl Display for TokenKind { @@ -211,7 +211,7 @@ impl Display for TokenKind { PlusEqual => "`+=`", MinusEqual => "`-=`", StarEqual => "`*=`", - // TokenType::Literal => "`\"something\"", + StringLiteral(_) => "`\"something\"", }; write!(f, "{}", desc) @@ -478,6 +478,12 @@ impl Token { '?' => { tokens.push(TokenKind::Question.new_token(ctx, 1)); } + '"' => { + //TODO: Add error handling if qoute not closed + let literal: String = chars.by_ref().take_while(|&char| char != '"').collect(); + let len = literal.len(); + tokens.push(TokenKind::StringLiteral(literal).new_token(ctx, len)); + } ' ' => ctx.offset += 1, _ => { return Err(ctx.error( diff --git a/src/mast/mod.rs b/src/mast/mod.rs index d07d9bf2c..ca8e34e38 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -230,6 +230,10 @@ impl FnSig { val.to_u32().expect("array size exceeded u32"), ) } + TyKind::Tuple(typs) => { + let typs: Vec = typs.iter().map(|ty| self.resolve_type(ty, ctx)).collect(); + TyKind::Tuple(typs) + } _ => typ.clone(), } } @@ -251,6 +255,18 @@ impl FnSig { observed_arg.expr.span, )?; } + // if generics in tuple + (TyKind::Tuple(sig_arg_typs), TyKind::Tuple(observed_arg_typs)) => { + for (sig_arg_typ, observed_arg_typ) in + sig_arg_typs.iter().zip(observed_arg_typs) + { + self.resolve_generic_array( + &sig_arg_typ, + &observed_arg_typ, + observed_arg.expr.span, + )?; + } + } // const NN: Field _ => { if is_generic_parameter(sig_arg.name.value.as_str()) { @@ -404,6 +420,7 @@ impl Symbolic { } Symbolic::Generic(g) => gens.get(&g.value), Symbolic::Add(a, b) => a.eval(gens, tast) + b.eval(gens, tast), + Symbolic::Sub(a, b) => a.eval(gens, tast) - b.eval(gens, tast), Symbolic::Mul(a, b) => a.eval(gens, tast) * b.eval(gens, tast), } } @@ -814,20 +831,42 @@ fn monomorphize_expr( let ExprMonoInfo { expr: rhs_expr, .. } = rhs_mono; // fold constants - let cst = match (&lhs_expr.kind, &rhs_expr.kind) { - (ExprKind::BigUInt(lhs), ExprKind::BigUInt(rhs)) => match op { - Op2::Addition => Some(lhs + rhs), - Op2::Subtraction => Some(lhs - rhs), - Op2::Multiplication => Some(lhs * rhs), - Op2::Division => Some(lhs / rhs), - _ => None, - }, + let cst = match (&lhs_mono.constant, &rhs_mono.constant) { + (Some(PropagatedConstant::Single(lhs)), Some(PropagatedConstant::Single(rhs))) => { + match op { + Op2::Addition => Some(lhs + rhs), + Op2::Subtraction => { + if lhs < rhs { + // throw error + return Err(error( + ErrorKind::NegativeLhsLessThanRhs( + lhs.to_string(), + rhs.to_string(), + ), + expr.span, + )); + } + Some(lhs - rhs) + } + Op2::Multiplication => Some(lhs * rhs), + Op2::Division => Some(lhs / rhs), + _ => None, + } + } _ => None, }; match cst { Some(v) => { - let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(v.clone())); + let mexpr = expr.to_mast( + ctx, + &ExprKind::BinaryOp { + op: op.clone(), + protected: *protected, + lhs: Box::new(lhs_expr), + rhs: Box::new(rhs_expr), + }, + ); ExprMonoInfo::new(mexpr, typ, Some(PropagatedConstant::from(v))) } @@ -881,6 +920,19 @@ fn monomorphize_expr( ExprMonoInfo::new(mexpr, Some(TyKind::Bool), None) } + ExprKind::StringLiteral(inner) => { + let mexpr = expr.to_mast(ctx, &ExprKind::StringLiteral(inner.clone())); + let string_literal_val: Vec = inner + .chars() + .map(|char| PropagatedConstant::Single(BigUint::from(char as u8))) + .collect(); + ExprMonoInfo::new( + mexpr, + Some(TyKind::String(inner.clone())), + Some(PropagatedConstant::Array(string_literal_val)), + ) + } + // mod::path.of.var // it could be also a generic variable ExprKind::Variable { module, name } => { @@ -935,18 +987,26 @@ fn monomorphize_expr( res } - ExprKind::ArrayAccess { array, idx } => { + ExprKind::ArrayOrTupleAccess { container, idx } => { // get type of lhs - let array_mono = monomorphize_expr(ctx, array, mono_fn_env)?; + let array_mono = monomorphize_expr(ctx, container, mono_fn_env)?; let id_mono = monomorphize_expr(ctx, idx, mono_fn_env)?; // get type of element let el_typ = match array_mono.typ { Some(TyKind::Array(typkind, _)) => Some(*typkind), + Some(TyKind::Tuple(typs)) => match &idx.kind { + ExprKind::BigUInt(index) => Some(typs[index.to_usize().unwrap()].clone()), + _ => Err(Error::new( + "Non constant container access", + ErrorKind::ExpectedConstant, + expr.span, + ))?, + }, _ => Err(Error::new( - "Array Access", + "Container Access", ErrorKind::UnexpectedError( - "Attempting to access array when type is not an array", + "Attempting to access container when type is not an container", ), expr.span, ))?, @@ -954,8 +1014,8 @@ fn monomorphize_expr( let mexpr = expr.to_mast( ctx, - &ExprKind::ArrayAccess { - array: Box::new(array_mono.expr), + &ExprKind::ArrayOrTupleAccess { + container: Box::new(array_mono.expr), idx: Box::new(id_mono.expr), }, ); @@ -1129,6 +1189,30 @@ fn monomorphize_expr( return Err(error(ErrorKind::InvalidArraySize, expr.span)); } } + + ExprKind::TupleDeclaration(items) => { + // checking the size of the tuple + let _: u32 = items.len().try_into().expect("tuple too large"); + + let items_mono: Vec = items + .iter() + .map(|item| monomorphize_expr(ctx, item, mono_fn_env).unwrap()) + .collect(); + + let typs: Vec = items_mono + .iter() + .cloned() + .map(|item_mono| item_mono.typ.unwrap()) + .collect(); + + let mexpr = expr.to_mast( + ctx, + &ExprKind::ArrayDeclaration(items_mono.into_iter().map(|e| e.expr).collect()), + ); + + let typ = TyKind::Tuple(typs); + ExprMonoInfo::new(mexpr, Some(typ), None) + } }; if let Some(typ) = &expr_mono.typ { @@ -1324,12 +1408,19 @@ pub fn instantiate_fn_call( // canonicalize the arguments depending on method call or not let expected: Vec<_> = fn_sig.arguments.iter().collect(); + let ignore_arg_types = match fn_info.kind { + FnKind::BuiltIn(_, _, ignore) => ignore, + FnKind::Native(_) => false, + }; + // check argument length if expected.len() != args.len() { - return Err(error( - ErrorKind::MismatchFunctionArguments(args.len(), expected.len()), - span, - )); + if !ignore_arg_types { + return Err(error( + ErrorKind::MismatchFunctionArguments(args.len(), expected.len()), + span, + )); + } } // create a context for the function call diff --git a/src/name_resolution/context.rs b/src/name_resolution/context.rs index d4d196eea..23314adc1 100644 --- a/src/name_resolution/context.rs +++ b/src/name_resolution/context.rs @@ -133,6 +133,12 @@ impl NameResCtx { TyKind::Array(typ_kind, _) => self.resolve_typ_kind(typ_kind)?, TyKind::GenericSizedArray(typ_kind, _) => self.resolve_typ_kind(typ_kind)?, TyKind::Bool => (), + TyKind::String { .. } => (), + TyKind::Tuple(typs) => { + typs.iter_mut() + .for_each(|typ| self.resolve_typ_kind(typ).unwrap()); + () + } }; Ok(()) diff --git a/src/name_resolution/expr.rs b/src/name_resolution/expr.rs index d2630fe00..b48b644b6 100644 --- a/src/name_resolution/expr.rs +++ b/src/name_resolution/expr.rs @@ -71,8 +71,8 @@ impl NameResCtx { ExprKind::Variable { module, name: _ } => { self.resolve(module, false)?; } - ExprKind::ArrayAccess { array, idx } => { - self.resolve_expr(array)?; + ExprKind::ArrayOrTupleAccess { container, idx } => { + self.resolve_expr(container)?; self.resolve_expr(idx)?; } ExprKind::ArrayDeclaration(items) => { @@ -99,11 +99,17 @@ impl NameResCtx { } } ExprKind::Bool(_) => {} + ExprKind::StringLiteral(_) => {} ExprKind::IfElse { cond, then_, else_ } => { self.resolve_expr(cond)?; self.resolve_expr(then_)?; self.resolve_expr(else_)?; } + ExprKind::TupleDeclaration(items) => { + for expr in items { + self.resolve_expr(expr)?; + } + } }; Ok(()) diff --git a/src/parser/expr.rs b/src/parser/expr.rs index b6e630300..42cf2709b 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -97,9 +97,14 @@ pub enum ExprKind { // TODO: change to `identifier` or `path`? Variable { module: ModulePath, name: Ident }, - /// An array access, for example: + /// An array or tuple access, for example: /// `lhs[idx]` - ArrayAccess { array: Box, idx: Box }, + /// As both almost work identical to each other expersion level we handle the cases for each container in the + /// circuit writers and typecheckers + ArrayOrTupleAccess { + container: Box, + idx: Box, + }, /// `[ ... ]` ArrayDeclaration(Vec), @@ -122,6 +127,11 @@ pub enum ExprKind { then_: Box, else_: Box, }, + /// Any string literal + StringLiteral(String), + + ///Tuple Declaration + TupleDeclaration(Vec), } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -210,6 +220,40 @@ impl Expr { // parenthesis TokenKind::LeftParen => { let mut expr = Expr::parse(ctx, tokens)?; + + // check if it is a tuple declaration + let second_token = tokens.peek(); + + match second_token { + // this means a tuple declaration + Some(Token { + kind: TokenKind::Comma, + span: _, + }) => { + let mut items = vec![expr]; + let last_span: Span; + loop { + let token = tokens.bump_err(ctx, ErrorKind::InvalidEndOfLine)?; + match token.kind { + TokenKind::RightParen => { + last_span = token.span; + break; + } + TokenKind::Comma => (), + _ => return Err(ctx.error(ErrorKind::InvalidEndOfLine, token.span)), + } + let item = Expr::parse(ctx, tokens)?; + items.push(item); + } + return Ok(Expr::new( + ctx, + ExprKind::TupleDeclaration(items), + span.merge_with(last_span), + )); + } + _ => (), + } + tokens.bump_expected(ctx, TokenKind::RightParen)?; if let ExprKind::BinaryOp { protected, .. } = &mut expr.kind { @@ -246,7 +290,8 @@ impl Expr { | ExprKind::Bool { .. } | ExprKind::BigUInt { .. } | ExprKind::FieldAccess { .. } - | ExprKind::ArrayAccess { .. } + | ExprKind::ArrayOrTupleAccess { .. } + | ExprKind::StringLiteral { .. } ) { Err(Error::new( "parse - if keyword", @@ -279,7 +324,8 @@ impl Expr { | ExprKind::Bool { .. } | ExprKind::BigUInt { .. } | ExprKind::FieldAccess { .. } - | ExprKind::ArrayAccess { .. } + | ExprKind::ArrayOrTupleAccess { .. } + | ExprKind::StringLiteral { .. } ) { Err(Error::new( "parse - if keyword", @@ -414,6 +460,7 @@ impl Expr { fn_call } + TokenKind::StringLiteral(s) => Expr::new(ctx, ExprKind::StringLiteral(s), span), // unrecognized pattern _ => { @@ -448,7 +495,7 @@ impl Expr { if !matches!( &self.kind, ExprKind::Variable { .. } - | ExprKind::ArrayAccess { .. } + | ExprKind::ArrayOrTupleAccess { .. } | ExprKind::FieldAccess { .. }, ) { return Err(ctx.error( @@ -603,7 +650,7 @@ impl Expr { parse_type_declaration(ctx, tokens, ident)? } - // array access + // array or tuple access Some(Token { kind: TokenKind::LeftBracket, .. @@ -615,7 +662,7 @@ impl Expr { self.kind, ExprKind::Variable { .. } | ExprKind::FieldAccess { .. } - | ExprKind::ArrayAccess { .. } + | ExprKind::ArrayOrTupleAccess { .. } ) { Err(Error::new( "parse_rhs - left bracket", @@ -635,8 +682,8 @@ impl Expr { Expr::new( ctx, - ExprKind::ArrayAccess { - array: Box::new(self), + ExprKind::ArrayOrTupleAccess { + container: Box::new(self), idx: Box::new(idx), }, span, @@ -689,7 +736,7 @@ impl Expr { &self.kind, ExprKind::FieldAccess { .. } | ExprKind::Variable { .. } - | ExprKind::ArrayAccess { .. } + | ExprKind::ArrayOrTupleAccess { .. } ) { let span = self.span.merge_with(period.span); return Err(ctx.error(ErrorKind::InvalidFieldAccessExpression, span)); diff --git a/src/parser/types.rs b/src/parser/types.rs index 4b79db677..9a3ee98ef 100644 --- a/src/parser/types.rs +++ b/src/parser/types.rs @@ -188,6 +188,7 @@ pub enum Symbolic { /// Generic parameter Generic(Ident), Add(Box, Box), + Sub(Box, Box), Mul(Box, Box), } @@ -198,6 +199,7 @@ impl Display for Symbolic { Symbolic::Constant(ident) => write!(f, "{}", ident.value), Symbolic::Generic(ident) => write!(f, "{}", ident.value), Symbolic::Add(lhs, rhs) => write!(f, "{} + {}", lhs, rhs), + Symbolic::Sub(lhs, rhs) => write!(f, "{} - {}", lhs, rhs), Symbolic::Mul(lhs, rhs) => write!(f, "{} * {}", lhs, rhs), } } @@ -219,7 +221,7 @@ impl Symbolic { Symbolic::Generic(ident) => { generics.insert(ident.value.clone()); } - Symbolic::Add(lhs, rhs) | Symbolic::Mul(lhs, rhs) => { + Symbolic::Add(lhs, rhs) | Symbolic::Mul(lhs, rhs) | Symbolic::Sub(lhs, rhs) => { generics.extend(lhs.extract_generics()); generics.extend(rhs.extract_generics()); } @@ -251,6 +253,7 @@ impl Symbolic { // no protected flags are needed, as this is based on expression nodes which already ordered the operations match op { Op2::Addition => Ok(Symbolic::Add(Box::new(lhs), Box::new(rhs?))), + Op2::Subtraction => Ok(Symbolic::Sub(Box::new(lhs), Box::new(rhs?))), Op2::Multiplication => Ok(Symbolic::Mul(Box::new(lhs), Box::new(rhs?))), _ => Err(Error::new( "mast", @@ -281,7 +284,9 @@ pub enum TyKind { /// A boolean (`true` or `false`). Bool, - // Tuple(Vec), + + /// A tuple is a data structure which contains many types + Tuple(Vec), // Bool, // U8, // U16, @@ -291,6 +296,9 @@ pub enum TyKind { /// This is an intermediate type. /// After monomorphized, it will be converted to `Array`. GenericSizedArray(Box, Symbolic), + + /// A string type current purpose it to pass around for logging + String(String), } impl TyKind { @@ -311,6 +319,7 @@ impl TyKind { /// - If `no_generic_allowed` is `true`, the function returns `false`. /// - If `no_generic_allowed` is `false`, the function compares the element types. /// - For `Custom` types, it compares the `module` and `name` fields for equality. + /// - For tuples, it matches type of each element i.e `self[i] == expected[i]` for every i /// - For other types, it uses a basic equality check. pub fn match_expected(&self, expected: &TyKind, no_generic_allowed: bool) -> bool { match (self, expected) { @@ -334,7 +343,21 @@ impl TyKind { name: expected_name, }, ) => module == expected_module && name == expected_name, + (TyKind::String { .. }, TyKind::String { .. }) => true, (x, y) if x == y => true, + (TyKind::Tuple(lhs), TyKind::Tuple(rhs)) => { + // if length does not match then they are of different type + if lhs.len() == rhs.len() { + let match_items = lhs + .iter() + .zip(rhs) + .filter(|&(l, r)| l.match_expected(r, no_generic_allowed)) + .count(); + lhs.len() == match_items + } else { + false + } + } _ => false, } } @@ -356,6 +379,13 @@ impl TyKind { generics.extend(ty.extract_generics()); generics.extend(sym.extract_generics()); } + TyKind::String { .. } => (), + // for the time when (([Field;N])) + TyKind::Tuple(typs) => { + for ty in typs { + generics.extend(ty.extract_generics()); + } + } } generics @@ -395,6 +425,18 @@ impl Display for TyKind { TyKind::Array(ty, size) => write!(f, "[{}; {}]", ty, size), TyKind::Bool => write!(f, "Bool"), TyKind::GenericSizedArray(ty, size) => write!(f, "[{}; {}]", ty, size), + TyKind::String(s) => write!(f, "String({})", s), + TyKind::Tuple(types) => { + write!( + f, + "({})", + types + .iter() + .map(|ty| ty.to_string()) + .collect::>() + .join(",") + ) + } } } } @@ -499,6 +541,34 @@ impl Ty { } } + // tuple type return + TokenKind::LeftParen => { + let mut typs = vec![]; + loop { + // parse the type + let ty = Ty::parse(ctx, tokens)?; + typs.push(ty.kind); + + // if next token is RightParen then return the type + let token = tokens.peek(); + match token { + Some(token) => match token.kind { + TokenKind::RightParen => { + tokens.bump(ctx); + return Ok(Ty { + kind: TyKind::Tuple(typs), + span: token.span, + }); + } + // if there is a comma just bump the tokens so we are on the type + TokenKind::Comma => _ = tokens.bump(ctx), + _ => return Err(ctx.error(ErrorKind::InvalidEndOfLine, token.span)), + }, + _ => (), + } + } + } + // unrecognized _ => Err(ctx.error(ErrorKind::InvalidType, token.span)), } @@ -547,6 +617,15 @@ impl FnSig { generics.add(name); } } + // extracts generics from interior of tuple + TyKind::Tuple(typs) => { + for ty in typs { + let extracted = ty.extract_generics(); + for name in extracted { + generics.add(name); + } + } + } _ => (), } } @@ -602,6 +681,7 @@ impl FnSig { /// Either: /// - `const NN: Field` or `[[Field; NN]; MM]` /// - `[Field; cst]`, where cst is a constant variable. We also monomorphize generic array with a constant var as its size. + /// - `([Field; cst])` when tuple type returns a generic pub fn require_monomorphization(&self) -> bool { let has_arg_cst = self .arguments @@ -626,6 +706,7 @@ impl FnSig { self.has_constant(ty) } TyKind::Array(ty, _) => self.has_constant(ty), + TyKind::Tuple(typs) => typs.iter().any(|ty| self.has_constant(ty)), _ => false, } } @@ -635,6 +716,12 @@ impl FnSig { pub fn monomorphized_name(&self) -> Ident { let mut name = self.name.clone(); + // check if it contains # in the name + if name.value.contains('#') { + // if so, then it is already monomorphized + return name; + } + if self.require_monomorphization() { let mut generics = self.generics.parameters.iter().collect::>(); generics.sort_by(|a, b| a.0.cmp(b.0)); @@ -900,6 +987,15 @@ impl FnArg { generics.insert(name); } } + // extract generics for inner type + TyKind::Tuple(typs) => { + for ty in typs { + let extracted = self.typ.kind.extract_generics(); + for name in extracted { + generics.insert(name); + } + } + } _ => (), } diff --git a/src/stdlib/bits.rs b/src/stdlib/bits.rs index ac5ca3d01..8d87fe537 100644 --- a/src/stdlib/bits.rs +++ b/src/stdlib/bits.rs @@ -12,10 +12,7 @@ use crate::{ var::{ConstOrCell, Value, Var}, }; -use super::{FnInfoType, Module}; - -const NTH_BIT_FN: &str = "nth_bit(val: Field, const nth: Field) -> Field"; -const CHECK_FIELD_SIZE_FN: &str = "check_field_size(cmp: Field)"; +use super::{builtins::Builtin, FnInfoType, Module}; pub struct BitsLib {} @@ -24,81 +21,95 @@ impl Module for BitsLib { fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { vec![ - (NTH_BIT_FN, nth_bit, false), - (CHECK_FIELD_SIZE_FN, check_field_size, false), + (NthBitFn::SIGNATURE, NthBitFn::builtin, false), + ( + CheckFieldSizeFn::SIGNATURE, + CheckFieldSizeFn::builtin, + false, + ), ] } } -fn nth_bit( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // should be two input vars - assert_eq!(vars.len(), 2); - - // these should be type checked already, unless it is called by other low level functions - // eg. builtins - let var_info = &vars[0]; - let val = &var_info.var; - assert_eq!(val.len(), 1); - - let var_info = &vars[1]; - let nth = &var_info.var; - assert_eq!(nth.len(), 1); - - let nth: usize = match &nth[0] { - ConstOrCell::Cell(_) => unreachable!("nth should be a constant"), - ConstOrCell::Const(cst) => cst.to_u64() as usize, - }; - - let val = match &val[0] { - ConstOrCell::Cell(cvar) => cvar.clone(), - ConstOrCell::Const(cst) => { - // directly return the nth bit without adding symbolic value as it doesn't depend on a cell var - let bit = cst.to_bits(); - return Ok(Some(Var::new_cvar( - ConstOrCell::Const(B::Field::from(bit[nth])), - span, - ))); - } - }; +struct NthBitFn {} +struct CheckFieldSizeFn {} + +impl Builtin for NthBitFn { + const SIGNATURE: &'static str = "nth_bit(val: Field, const nth: Field) -> Field"; + + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + // should be two input vars + assert_eq!(vars.len(), 2); + + // these should be type checked already, unless it is called by other low level functions + // eg. builtins + let var_info = &vars[0]; + let val = &var_info.var; + assert_eq!(val.len(), 1); + + let var_info = &vars[1]; + let nth = &var_info.var; + assert_eq!(nth.len(), 1); + + let nth: usize = match &nth[0] { + ConstOrCell::Cell(_) => unreachable!("nth should be a constant"), + ConstOrCell::Const(cst) => cst.to_u64() as usize, + }; + + let val = match &val[0] { + ConstOrCell::Cell(cvar) => cvar.clone(), + ConstOrCell::Const(cst) => { + // directly return the nth bit without adding symbolic value as it doesn't depend on a cell var + let bit = cst.to_bits(); + return Ok(Some(Var::new_cvar( + ConstOrCell::Const(B::Field::from(bit[nth])), + span, + ))); + } + }; - let bit = compiler - .backend - .new_internal_var(Value::NthBit(val.clone(), nth), span); + let bit = compiler + .backend + .new_internal_var(Value::NthBit(val.clone(), nth), span); - Ok(Some(Var::new(vec![ConstOrCell::Cell(bit)], span))) + Ok(Some(Var::new(vec![ConstOrCell::Cell(bit)], span))) + } } -// Ensure that the field size is not exceeded -fn check_field_size( - _compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - let var = &vars[0].var[0]; - let bit_len = B::Field::MODULUS_BIT_SIZE as u64; - - match var { - ConstOrCell::Const(cst) => { - let to_cmp = cst.to_u64(); - if to_cmp >= bit_len { - return Err(Error::new( - "constraint-generation", - ErrorKind::AssertionFailed, - span, - )); +impl Builtin for CheckFieldSizeFn { + const SIGNATURE: &'static str = "check_field_size(cmp: Field)"; + + fn builtin( + _compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + let var = &vars[0].var[0]; + let bit_len = B::Field::MODULUS_BIT_SIZE as u64; + + match var { + ConstOrCell::Const(cst) => { + let to_cmp = cst.to_u64(); + if to_cmp >= bit_len { + return Err(Error::new( + "constraint-generation", + ErrorKind::AssertionFailed, + span, + )); + } + Ok(None) } - Ok(None) + ConstOrCell::Cell(_) => Err(Error::new( + "constraint-generation", + ErrorKind::ExpectedConstant, + span, + )), } - ConstOrCell::Cell(_) => Err(Error::new( - "constraint-generation", - ErrorKind::ExpectedConstant, - span, - )), } } diff --git a/src/stdlib/builtins.rs b/src/stdlib/builtins.rs index 65e607cdd..2c16e62fe 100644 --- a/src/stdlib/builtins.rs +++ b/src/stdlib/builtins.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use ark_ff::{One, Zero}; use kimchi::o1_utils::FieldHelpers; use num_bigint::BigUint; +use regex::Regex; use crate::{ backends::Backend, @@ -12,7 +13,8 @@ use crate::{ constants::Span, error::{Error, ErrorKind, Result}, helpers::PrettyField, - parser::types::{GenericParameters, TyKind}, + parser::types::{GenericParameters, ModulePath, TyKind}, + type_checker::FullyQualified, var::{ConstOrCell, Value, Var}, }; @@ -21,10 +23,6 @@ use super::{FnInfoType, Module}; pub const QUALIFIED_BUILTINS: &str = "std/builtins"; pub const BUILTIN_FN_NAMES: [&str; 3] = ["assert", "assert_eq", "log"]; -const ASSERT_FN: &str = "assert(condition: Bool)"; -const ASSERT_EQ_FN: &str = "assert_eq(lhs: Field, rhs: Field)"; -const LOG_FN: &str = "log(var: Field)"; - pub struct BuiltinsLib {} impl Module for BuiltinsLib { @@ -32,29 +30,177 @@ impl Module for BuiltinsLib { fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { vec![ - (ASSERT_FN, assert_fn, false), - (ASSERT_EQ_FN, assert_eq_fn, false), + (AssertFn::SIGNATURE, AssertFn::builtin, false), + (AssertEqFn::SIGNATURE, AssertEqFn::builtin, true), // true -> skip argument type checking for log - (LOG_FN, log_fn, true), + (LogFn::SIGNATURE, LogFn::builtin, true), ] } } -/// Asserts that two vars are equal. -fn assert_eq_fn( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], +/// Represents a comparison that needs to be made +enum Comparison { + /// Compare two variables + Vars(B::Var, B::Var), + /// Compare a variable with a constant + VarConst(B::Var, B::Field), + /// Compare two constants + Constants(B::Field, B::Field), +} + +/// Helper function to generate all comparisons +fn assert_eq_values( + compiler: &CircuitWriter, + lhs_info: &VarInfo, + rhs_info: &VarInfo, + typ: &TyKind, span: Span, -) -> Result>> { - // we get two vars - assert_eq!(vars.len(), 2); - let lhs_info = &vars[0]; - let rhs_info = &vars[1]; - - // they are both of type field - if !matches!(lhs_info.typ, Some(TyKind::Field { .. })) { - let lhs = lhs_info.typ.clone().ok_or_else(|| { +) -> Vec> { + let mut comparisons = Vec::new(); + + match typ { + // Field and Bool has the same logic + TyKind::Field { .. } | TyKind::Bool | TyKind::String(..) => { + let lhs_var = &lhs_info.var[0]; + let rhs_var = &rhs_info.var[0]; + match (lhs_var, rhs_var) { + (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { + comparisons.push(Comparison::Constants(a.clone(), b.clone())); + } + (ConstOrCell::Const(cst), ConstOrCell::Cell(cvar)) + | (ConstOrCell::Cell(cvar), ConstOrCell::Const(cst)) => { + comparisons.push(Comparison::VarConst(cvar.clone(), cst.clone())); + } + (ConstOrCell::Cell(lhs), ConstOrCell::Cell(rhs)) => { + comparisons.push(Comparison::Vars(lhs.clone(), rhs.clone())); + } + } + } + + // Arrays (fixed size) + TyKind::Array(element_type, size) => { + let size = *size as usize; + let element_size = compiler.size_of(element_type); + + // compare each element recursively + for i in 0..size { + let start = i * element_size; + let mut element_comparisons = assert_eq_values( + compiler, + &VarInfo::new( + Var::new(lhs_info.var.range(start, element_size).to_vec(), span), + false, + Some(*element_type.clone()), + ), + &VarInfo::new( + Var::new(rhs_info.var.range(start, element_size).to_vec(), span), + false, + Some(*element_type.clone()), + ), + element_type, + span, + ); + comparisons.append(&mut element_comparisons); + } + } + + // Custom types (structs) + TyKind::Custom { module, name } => { + let qualified = FullyQualified::new(module, name); + let struct_info = compiler.struct_info(&qualified).expect("struct not found"); + + // compare each field recursively + let mut offset = 0; + for (_, field_type) in &struct_info.fields { + let field_size = compiler.size_of(field_type); + let mut field_comparisons = assert_eq_values( + compiler, + &VarInfo::new( + Var::new(lhs_info.var.range(offset, field_size).to_vec(), span), + false, + Some(field_type.clone()), + ), + &VarInfo::new( + Var::new(rhs_info.var.range(offset, field_size).to_vec(), span), + false, + Some(field_type.clone()), + ), + field_type, + span, + ); + comparisons.append(&mut field_comparisons); + offset += field_size; + } + } + + // GenericSizedArray should be monomorphized to Array before reaching here + // no need to handle it seperately + TyKind::GenericSizedArray(_, _) => { + unreachable!("GenericSizedArray should be monomorphized") + } + + TyKind::String(_) => todo!("String is not implemented yet"), + + TyKind::Tuple(typs) => { + let mut offset = 0; + for ty in typs { + let element_size = compiler.size_of(ty); + let mut element_comparisions = assert_eq_values( + compiler, + &VarInfo::new( + Var::new(lhs_info.var.range(offset, element_size).to_vec(), span), + false, + Some(ty.clone()), + ), + &VarInfo::new( + Var::new(rhs_info.var.range(offset, element_size).to_vec(), span), + false, + Some(ty.clone()), + ), + ty, + span, + ); + comparisons.append(&mut element_comparisions); + offset += element_size; + } + } + } + + comparisons +} + +pub trait Builtin { + const SIGNATURE: &'static str; + + fn builtin( + compiler: &mut CircuitWriter, + generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>>; +} + +struct AssertEqFn {} +struct AssertFn {} +struct LogFn {} + +impl Builtin for AssertEqFn { + const SIGNATURE: &'static str = "assert_eq(lhs: Field, rhs: Field)"; + + /// Asserts that two vars are equal. + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + // we get two vars + assert_eq!(vars.len(), 2); + let lhs_info = &vars[0]; + let rhs_info = &vars[1]; + + // get types of both arguments + let lhs_type = lhs_info.typ.as_ref().ok_or_else(|| { Error::new( "constraint-generation", ErrorKind::UnexpectedError("No type info for lhs of assertion"), @@ -62,15 +208,7 @@ fn assert_eq_fn( ) })?; - Err(Error::new( - "constraint-generation", - ErrorKind::AssertTypeMismatch("rhs", lhs), - span, - ))? - } - - if !matches!(rhs_info.typ, Some(TyKind::Field { .. })) { - let rhs = rhs_info.typ.clone().ok_or_else(|| { + let rhs_type = rhs_info.typ.as_ref().ok_or_else(|| { Error::new( "constraint-generation", ErrorKind::UnexpectedError("No type info for rhs of assertion"), @@ -78,90 +216,96 @@ fn assert_eq_fn( ) })?; - Err(Error::new( - "constraint-generation", - ErrorKind::AssertTypeMismatch("rhs", rhs), - span, - ))? - } + // they have the same type + if !lhs_type.match_expected(rhs_type, false) { + return Err(Error::new( + "constraint-generation", + ErrorKind::AssertEqTypeMismatch(lhs_type.clone(), rhs_type.clone()), + span, + )); + } - // retrieve the values - let lhs_var = &lhs_info.var; - assert_eq!(lhs_var.len(), 1); - let lhs_cvar = &lhs_var[0]; - - let rhs_var = &rhs_info.var; - assert_eq!(rhs_var.len(), 1); - let rhs_cvar = &rhs_var[0]; - - match (lhs_cvar, rhs_cvar) { - // two constants - (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { - if a != b { - Err(Error::new( - "constraint-generation", - ErrorKind::AssertionFailed, - span, - ))? + // first collect all comparisons needed + let comparisons = assert_eq_values(compiler, lhs_info, rhs_info, lhs_type, span); + + // then add all the constraints + for comparison in comparisons { + match comparison { + Comparison::Vars(lhs, rhs) => { + compiler.backend.assert_eq_var(&lhs, &rhs, span); + } + Comparison::VarConst(var, constant) => { + compiler.backend.assert_eq_const(&var, constant, span); + } + Comparison::Constants(a, b) => { + if a != b { + return Err(Error::new( + "constraint-generation", + ErrorKind::AssertionFailed, + span, + )); + } + } } } - // a const and a var - (ConstOrCell::Const(cst), ConstOrCell::Cell(cvar)) - | (ConstOrCell::Cell(cvar), ConstOrCell::Const(cst)) => { - compiler.backend.assert_eq_const(cvar, *cst, span) - } - (ConstOrCell::Cell(lhs), ConstOrCell::Cell(rhs)) => { - compiler.backend.assert_eq_var(lhs, rhs, span) - } + Ok(None) } - - Ok(None) } -/// Asserts that a condition is true. -fn assert_fn( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // we get a single var - assert_eq!(vars.len(), 1); - - // of type bool - let var_info = &vars[0]; - assert!(matches!(var_info.typ, Some(TyKind::Bool))); - - // of only one field element - let var = &var_info.var; - assert_eq!(var.len(), 1); - let cond = &var[0]; - - match cond { - ConstOrCell::Const(cst) => { - assert!(cst.is_one()); - } - ConstOrCell::Cell(cvar) => { - let one = B::Field::one(); - compiler.backend.assert_eq_const(cvar, one, span); +impl Builtin for AssertFn { + const SIGNATURE: &'static str = "assert(condition: Bool)"; + + /// Asserts that a condition is true. + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result::Field, ::Var>>> { + // we get a single var + assert_eq!(vars.len(), 1); + + // of type bool + let var_info = &vars[0]; + assert!(matches!(var_info.typ, Some(TyKind::Bool))); + + // of only one field element + let var = &var_info.var; + assert_eq!(var.len(), 1); + let cond = &var[0]; + + match cond { + ConstOrCell::Const(cst) => { + assert!(cst.is_one()); + } + ConstOrCell::Cell(cvar) => { + let one = B::Field::one(); + compiler.backend.assert_eq_const(cvar, one, span); + } } - } - Ok(None) + Ok(None) + } } -/// Logging -fn log_fn( - compiler: &mut CircuitWriter, - generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - for var in vars { - // todo: will need to support string argument in order to customize msg - compiler.backend.log_var(var, "log".to_owned(), span); - } +impl Builtin for LogFn { + // todo: currently only supports a single field var + // to support all the types, we can bypass the type check for this log function for now + const SIGNATURE: &'static str = "log(var: Field)"; - Ok(None) + /// Logging + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + for var in vars { + // todo: will need to support string argument in order to customize msg + compiler.backend.log_var(var, span); + } + + Ok(None) + } } diff --git a/src/stdlib/crypto.rs b/src/stdlib/crypto.rs index 66113cddd..13ff91a86 100644 --- a/src/stdlib/crypto.rs +++ b/src/stdlib/crypto.rs @@ -1,14 +1,27 @@ -use super::{FnInfoType, Module}; +use super::{builtins::Builtin, FnInfoType, Module}; use crate::backends::Backend; -const POSEIDON_FN: &str = "poseidon(input: [Field; 2]) -> [Field; 3]"; - pub struct CryptoLib {} impl Module for CryptoLib { const MODULE: &'static str = "crypto"; fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { - vec![(POSEIDON_FN, B::poseidon(), false)] + vec![(PoseidonFn::SIGNATURE, PoseidonFn::builtin, false)] + } +} + +struct PoseidonFn {} + +impl Builtin for PoseidonFn { + const SIGNATURE: &'static str = "poseidon(input: [Field; 2]) -> [Field; 3]"; + + fn builtin( + compiler: &mut crate::circuit_writer::CircuitWriter, + generics: &crate::parser::types::GenericParameters, + vars: &[crate::circuit_writer::VarInfo], + span: crate::constants::Span, + ) -> crate::error::Result>> { + B::poseidon()(compiler, generics, vars, span) } } diff --git a/src/stdlib/int.rs b/src/stdlib/int.rs index 03c574890..94f40ff89 100644 --- a/src/stdlib/int.rs +++ b/src/stdlib/int.rs @@ -11,9 +11,7 @@ use crate::{ var::{ConstOrCell, Value, Var}, }; -use super::{FnInfoType, Module}; - -const DIVMOD_FN: &str = "divmod(dividend: Field, divisor: Field) -> [Field; 2]"; +use super::{builtins::Builtin, FnInfoType, Module}; pub struct IntLib {} @@ -21,57 +19,63 @@ impl Module for IntLib { const MODULE: &'static str = "int"; fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { - vec![(DIVMOD_FN, divmod_fn, false)] + vec![(DivmodFn::SIGNATURE, DivmodFn::builtin, false)] } } /// Divides two field elements and returns the quotient and remainder. -fn divmod_fn( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // we get two vars - let dividend_info = &vars[0]; - let divisor_info = &vars[1]; - - // retrieve the values - let dividend_var = ÷nd_info.var[0]; - let divisor_var = &divisor_info.var[0]; - - match (dividend_var, divisor_var) { - // two constants - (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { - // convert to bigints - let a = a.to_biguint(); - let b = b.to_biguint(); - - let quotient = a.clone() / b.clone(); - let remainder = a % b; - - // convert back to fields - let quotient = B::Field::from_biguint("ient).unwrap(); - let remainder = B::Field::from_biguint(&remainder).unwrap(); - - Ok(Some(Var::new( - vec![ConstOrCell::Const(quotient), ConstOrCell::Const(remainder)], - span, - ))) - } +struct DivmodFn {} + +impl Builtin for DivmodFn { + const SIGNATURE: &'static str = "divmod(dividend: Field, divisor: Field) -> [Field; 2]"; + + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + // we get two vars + let dividend_info = &vars[0]; + let divisor_info = &vars[1]; + + // retrieve the values + let dividend_var = ÷nd_info.var[0]; + let divisor_var = &divisor_info.var[0]; + + match (dividend_var, divisor_var) { + // two constants + (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { + // convert to bigints + let a = a.to_biguint(); + let b = b.to_biguint(); + + let quotient = a.clone() / b.clone(); + let remainder = a % b; + + // convert back to fields + let quotient = B::Field::from_biguint("ient).unwrap(); + let remainder = B::Field::from_biguint(&remainder).unwrap(); + + Ok(Some(Var::new( + vec![ConstOrCell::Const(quotient), ConstOrCell::Const(remainder)], + span, + ))) + } + + _ => { + let quotient = compiler + .backend + .new_internal_var(Value::Div(dividend_var.clone(), divisor_var.clone()), span); + let remainder = compiler + .backend + .new_internal_var(Value::Mod(dividend_var.clone(), divisor_var.clone()), span); - _ => { - let quotient = compiler - .backend - .new_internal_var(Value::Div(dividend_var.clone(), divisor_var.clone()), span); - let remainder = compiler - .backend - .new_internal_var(Value::Mod(dividend_var.clone(), divisor_var.clone()), span); - - Ok(Some(Var::new( - vec![ConstOrCell::Cell(quotient), ConstOrCell::Cell(remainder)], - span, - ))) + Ok(Some(Var::new( + vec![ConstOrCell::Cell(quotient), ConstOrCell::Cell(remainder)], + span, + ))) + } } } } diff --git a/src/tests/examples.rs b/src/tests/examples.rs index e86d85874..99003152c 100644 --- a/src/tests/examples.rs +++ b/src/tests/examples.rs @@ -839,6 +839,25 @@ fn test_generic_array_nested(#[case] backend: BackendKind) -> miette::Result<()> Ok(()) } +#[rstest] +#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))] +#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))] +fn test_generic_assert_eq(#[case] backend: BackendKind) -> miette::Result<()> { + let public_inputs = r#"{"public_arr": ["1", "2"], "public_input": "3"}"#; + let private_inputs = r#"{}"#; + + test_file( + "generic_assert_eq", + public_inputs, + private_inputs, + vec![], + backend, + DEFAULT_OPTIONS, + )?; + + Ok(()) +} + #[rstest] #[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))] #[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))] diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index 0f12178a3..9db30594a 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -293,16 +293,16 @@ impl TypeChecker { name.value.clone() } - // `array[idx] = ` - ExprKind::ArrayAccess { array, idx } => { - // get variable behind array - let array_node = self - .compute_type(array, typed_fn_env)? - .expect("type-checker bug: array access on an empty var"); - - array_node + // `array[idx] = ` or `tuple[idx] = rhs` + ExprKind::ArrayOrTupleAccess { container, idx } => { + // get variable behind container + let cotainer_node = self + .compute_type(container, typed_fn_env)? + .expect("type-checker bug: array or tuple access on an empty var"); + + cotainer_node .var_name - .expect("anonymous array access cannot be mutated") + .expect("anonymous array or tuple access cannot be mutated") } // `struct.field = ` @@ -443,6 +443,8 @@ impl TypeChecker { ExprKind::Bool(_) => Some(ExprTyInfo::new_anon(TyKind::Bool)), + ExprKind::StringLiteral(s) => Some(ExprTyInfo::new_anon(TyKind::String(s.clone()))), + // mod::path.of.var ExprKind::Variable { module, name } => { let qualified = FullyQualified::new(module, &name.value); @@ -478,13 +480,16 @@ impl TypeChecker { } } - ExprKind::ArrayAccess { array, idx } => { + ExprKind::ArrayOrTupleAccess { container, idx } => { // get type of lhs - let typ = self.compute_type(array, typed_fn_env)?.unwrap(); + let typ = self.compute_type(container, typed_fn_env)?.unwrap(); - // check that it is an array - if !matches!(typ.typ, TyKind::Array(..) | TyKind::GenericSizedArray(..)) { - Err(self.error(ErrorKind::ArrayAccessOnNonArray, expr.span))? + // check that it is an array or tuple + if !matches!( + typ.typ, + TyKind::Array(..) | TyKind::GenericSizedArray(..) | TyKind::Tuple(..) + ) { + Err(self.error(ErrorKind::AccessOnNonCollection, expr.span))? } // check that expression is a bigint @@ -498,6 +503,19 @@ impl TypeChecker { let el_typ = match typ.typ { TyKind::Array(typkind, _) => *typkind, TyKind::GenericSizedArray(typkind, _) => *typkind, + TyKind::Tuple(typs) => match &idx.kind { + ExprKind::BigUInt(index) => { + let idx = index.to_usize().unwrap(); + if idx >= typs.len() { + return Err(self.error( + ErrorKind::TupleIndexOutofBounds(idx, typs.len()), + expr.span, + )); + } + typs[idx].clone() + } + _ => return Err(self.error(ErrorKind::ExpectedConstant, expr.span)), + }, _ => Err(self.error(ErrorKind::UnexpectedError("not an array"), expr.span))?, }; @@ -532,6 +550,20 @@ impl TypeChecker { let res = ExprTyInfo::new_anon(TyKind::Array(Box::new(tykind), len)); Some(res) } + ExprKind::TupleDeclaration(items) => { + // restricting tuple len as array len + let _: u32 = items.len().try_into().expect("tuple too large"); + let typs: Vec = items + .iter() + .map(|item| { + self.compute_type(item, typed_fn_env) + .unwrap() + .expect("expected some val") + .typ + }) + .collect(); + Some(ExprTyInfo::new_anon(TyKind::Tuple(typs))) + } ExprKind::IfElse { cond, then_, else_ } => { // cond can only be a boolean @@ -550,7 +582,7 @@ impl TypeChecker { &then_.kind, ExprKind::Variable { .. } | ExprKind::FieldAccess { .. } - | ExprKind::ArrayAccess { .. } + | ExprKind::ArrayOrTupleAccess { .. } ) { return Err(self.error(ErrorKind::IfElseInvalidIfBranch(), then_.span)); } @@ -559,7 +591,7 @@ impl TypeChecker { &else_.kind, ExprKind::Variable { .. } | ExprKind::FieldAccess { .. } - | ExprKind::ArrayAccess { .. } + | ExprKind::ArrayOrTupleAccess { .. } ) { return Err(self.error(ErrorKind::IfElseInvalidElseBranch(), else_.span)); } @@ -888,10 +920,12 @@ impl TypeChecker { // check argument length if expected.len() != observed.len() { - return Err(self.error( - ErrorKind::MismatchFunctionArguments(observed.len(), expected.len()), - span, - )); + if !ignore_arg_types { + return Err(self.error( + ErrorKind::MismatchFunctionArguments(observed.len(), expected.len()), + span, + )); + } } // skip argument type checking if ignore_arg_types is true diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index bbe14c8ad..47dd62009 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -147,6 +147,8 @@ impl TypeChecker { unreachable!("generic arrays should have been resolved") } TyKind::Bool => 1, + TyKind::String(s) => s.len(), + TyKind::Tuple(typs) => typs.iter().map(|ty| self.size_of(ty)).sum(), } } @@ -514,6 +516,7 @@ impl TypeChecker { TyKind::Field { constant: false } | TyKind::Custom { .. } | TyKind::Array(_, _) + | TyKind::Tuple(_) | TyKind::Bool => { typed_fn_env.store_type( "public_output".to_string(), @@ -522,6 +525,9 @@ impl TypeChecker { } TyKind::Field { constant: true } => unreachable!(), TyKind::GenericSizedArray(_, _) => unreachable!(), + TyKind::String(_) => { + todo!("String Type for circuits is not implemented") + } } } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 6f212f307..e59f4bae3 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,4 +1,20 @@ +use crate::{ + backends::Backend, + circuit_writer::VarInfo, + constants::Span, + error::{Error, ErrorKind, Result}, + helpers::PrettyField, + mast::Mast, + parser::types::{ModulePath, TyKind}, + type_checker::FullyQualified, + var::ConstOrCell, + witness::WitnessEnv, +}; use std::fmt::Write; +use std::slice::Iter; + +use num_traits::One; +use regex::{Captures, Regex}; pub fn noname_version() -> String { format!("@ noname.{}\n", env!("CARGO_PKG_VERSION")) @@ -86,3 +102,353 @@ yz ); } } + +// for failable replacer this is the recommended approach by the author of regex lib https://github.com/rust-lang/regex/issues/648#issuecomment-590072186 +// I have made Fn -> FnMut because our replace mutates the iterator by moving it forward +fn replace_all( + re: &Regex, + haystack: &str, + mut replacement: impl FnMut(&Captures) -> Result, +) -> Result { + let mut new = String::with_capacity(haystack.len()); + let mut last_match = 0; + for caps in re.captures_iter(haystack) { + let m = caps.get(0).unwrap(); + new.push_str(&haystack[last_match..m.start()]); + new.push_str(&replacement(&caps)?); + last_match = m.end(); + } + new.push_str(&haystack[last_match..]); + Ok(new) +} + +pub fn log_string_type( + backend: &B, + logs_iter: &mut Iter<'_, (Span, VarInfo)>, + str: &str, + witness: &mut WitnessEnv, + typed: &Mast, + span: &Span, +) -> Result { + let re = Regex::new(r"\{\s*\}").unwrap(); + let replacer = |_: &Captures| { + let (span, var) = match logs_iter.next() { + Some((span, var)) => (span, var), + None => return Err(Error::new("log", ErrorKind::InsufficientVariables, *span)), + }; + let replacement = match &var.typ { + Some(TyKind::Field { .. }) => match &var.var[0] { + ConstOrCell::Const(cst) => Ok(cst.pretty()), + ConstOrCell::Cell(cell) => { + let val = backend.compute_var(witness, cell).unwrap(); + Ok(val.pretty()) + } + }, + // Bool + Some(TyKind::Bool) => match &var.var[0] { + ConstOrCell::Const(cst) => { + let val = *cst == B::Field::one(); + Ok(val.to_string()) + } + ConstOrCell::Cell(cell) => { + let val = backend.compute_var(witness, cell)? == B::Field::one(); + Ok(val.to_string()) + } + }, + + // Array + Some(TyKind::Array(b, s)) => { + let mut typs = Vec::with_capacity(*s as usize); + for _ in 0..(*s) { + typs.push((**b).clone()); + } + let (output, remaining) = log_array_or_tuple_type( + backend, + &var.var.cvars, + &typs[..], + *s, + witness, + typed, + span, + false, + ) + .unwrap(); + assert!(remaining.is_empty()); + Ok(output) + } + + // Custom types + Some(TyKind::Custom { + module, + name: struct_name, + }) => { + let mut string_vec = Vec::new(); + let (output, remaining) = log_custom_type( + backend, + module, + struct_name, + typed, + &var.var.cvars, + witness, + span, + &mut string_vec, + ) + .unwrap(); + assert!(remaining.is_empty()); + Ok(output) + } + + // GenericSizedArray + Some(TyKind::GenericSizedArray(_, _)) => { + unreachable!("GenericSizedArray should be monomorphized") + } + Some(TyKind::String(_)) => todo!("String cannot be in circuit yet"), + + Some(TyKind::Tuple(typs)) => { + println!("{:?}", typs); + let len = typs.len(); + let (output, remaining) = log_array_or_tuple_type( + backend, + &var.var.cvars, + &typs, + len as u32, + witness, + typed, + span, + true, + ) + .unwrap(); + assert!(remaining.is_empty()); + Ok(output) + } + None => { + return Err(Error::new( + "log", + ErrorKind::UnexpectedError("No type info for logging"), + *span, + )) + } + }; + replacement + }; + replace_all(&re, str, replacer) +} + +pub fn log_array_or_tuple_type( + backend: &B, + var_info_var: &[ConstOrCell], + typs: &[TyKind], + size: u32, + witness: &mut WitnessEnv, + typed: &Mast, + span: &Span, + is_tuple: bool, +) -> Result<(String, Vec>)> { + let mut remaining = var_info_var.to_vec(); + let mut nested_result = Vec::new(); + + for i in 0..size { + let base_type = &typs[i as usize]; + let (chunk_result, new_remaining) = match base_type { + TyKind::Field { .. } => { + let value = match &remaining[0] { + ConstOrCell::Const(cst) => cst.pretty(), + ConstOrCell::Cell(cell) => { + let val = backend.compute_var(witness, cell).unwrap(); + val.pretty() + } + }; + (value, remaining[1..].to_vec()) + } + // Bool + TyKind::Bool => { + let value = match &remaining[0] { + ConstOrCell::Const(cst) => { + let val = *cst == B::Field::one(); + val.to_string() + } + ConstOrCell::Cell(cell) => { + let val = backend.compute_var(witness, cell)? == B::Field::one(); + val.to_string() + } + }; + (value, remaining[1..].to_vec()) + } + TyKind::Array(inner_type, inner_size) => { + let mut vec_inner_type = Vec::with_capacity(remaining.len()); + for _ in 0..remaining.len() { + vec_inner_type.push((**inner_type).clone()); + } + let is_tuple = match **inner_type { + TyKind::Tuple(_) => true, + _ => false, + }; + log_array_or_tuple_type( + backend, + &remaining, + &vec_inner_type[..], + *inner_size, + witness, + typed, + span, + is_tuple, + )? + } + + // Custom types + TyKind::Custom { + module, + name: struct_name, + } => { + let mut string_vec = Vec::new(); + let (output, new_remaining) = log_custom_type( + backend, + module, + struct_name, + typed, + &remaining, + witness, + span, + &mut string_vec, + )?; + (format!("{}{}", struct_name, output), new_remaining) + } + + // GenericSizedArray + TyKind::GenericSizedArray(_, _) => { + unreachable!("GenericSizedArray should be monomorphized") + } + TyKind::String(_) => todo!("String cannot be in circuit yet"), + + TyKind::Tuple(inner_typs) => { + let inner_size = inner_typs.len(); + log_array_or_tuple_type( + backend, + &remaining, + &inner_typs, + inner_size as u32, + witness, + typed, + span, + true, + )? + } + }; + nested_result.push(chunk_result); + remaining = new_remaining; + } + + if is_tuple { + Ok((format!("({})", nested_result.join(",")), remaining)) + } else { + Ok((format!("[{}]", nested_result.join(",")), remaining)) + } +} +pub fn log_custom_type( + backend: &B, + module: &ModulePath, + struct_name: &String, + typed: &Mast, + var_info_var: &[ConstOrCell], + witness: &mut WitnessEnv, + span: &Span, + string_vec: &mut Vec, +) -> Result<(String, Vec>)> { + let qualified = FullyQualified::new(module, struct_name); + let struct_info = typed + .struct_info(&qualified) + .ok_or( + typed + .0 + .error(ErrorKind::UnexpectedError("struct not found"), *span), + ) + .unwrap(); + + let mut remaining = var_info_var.to_vec(); + + for (field_name, field_typ) in &struct_info.fields { + let len = typed.size_of(field_typ); + match field_typ { + TyKind::Field { .. } => match &remaining[0] { + ConstOrCell::Const(cst) => { + string_vec.push(format!("{field_name}: {}", cst.pretty())); + remaining = remaining[len..].to_vec(); + } + ConstOrCell::Cell(cell) => { + let val = backend.compute_var(witness, cell)?; + string_vec.push(format!("{field_name}: {}", val.pretty())); + remaining = remaining[len..].to_vec(); + } + }, + + TyKind::Bool => match &remaining[0] { + ConstOrCell::Const(cst) => { + let val = *cst == B::Field::one(); + string_vec.push(format!("{field_name}: {}", val)); + remaining = remaining[len..].to_vec(); + } + ConstOrCell::Cell(cell) => { + let val = backend.compute_var(witness, cell)? == B::Field::one(); + string_vec.push(format!("{field_name}: {}", val)); + remaining = remaining[len..].to_vec(); + } + }, + + TyKind::Array(b, s) => { + let len = remaining.len(); + let mut typs: Vec = Vec::with_capacity(len); + typs.push((**b).clone()); + + let (output, new_remaining) = log_array_or_tuple_type( + backend, + &remaining, + &typs[..], + *s, + witness, + typed, + span, + false, + )?; + string_vec.push(format!("{field_name}: {}", output)); + remaining = new_remaining; + } + + TyKind::Custom { + module, + name: struct_name, + } => { + let mut custom_string_vec = Vec::new(); + let (output, new_remaining) = log_custom_type( + backend, + module, + struct_name, + typed, + &remaining, + witness, + span, + &mut custom_string_vec, + )?; + string_vec.push(format!("{}: {}{}", field_name, struct_name, output)); + remaining = new_remaining; + } + + TyKind::GenericSizedArray(_, _) => { + unreachable!("GenericSizedArray should be monomorphized") + } + TyKind::String(s) => { + todo!("String cannot be a type for customs it is only for logging") + } + TyKind::Tuple(typs) => { + let len = typs.len(); + let (output, new_remaining) = log_array_or_tuple_type( + backend, &remaining, &typs, len as u32, witness, typed, span, true, + ) + .unwrap(); + string_vec.push(format!("{field_name}: {}", output)); + remaining = new_remaining; + } + } + } + + Ok((format!("{{ {} }}", string_vec.join(", ")), remaining)) +}