diff --git a/src/circuit_writer/mod.rs b/src/circuit_writer/mod.rs index efd4b0d1f..47b93f0d7 100644 --- a/src/circuit_writer/mod.rs +++ b/src/circuit_writer/mod.rs @@ -152,7 +152,7 @@ impl CircuitWriter { let main_fn_info = circuit_writer.main_info()?; let function = match &main_fn_info.kind { - crate::imports::FnKind::BuiltIn(_, _) => unreachable!(), + crate::imports::FnKind::BuiltIn(_, _, _) => unreachable!(), crate::imports::FnKind::Native(fn_sig) => fn_sig.clone(), }; diff --git a/src/circuit_writer/writer.rs b/src/circuit_writer/writer.rs index a2a56ca9b..8f72a699a 100644 --- a/src/circuit_writer/writer.rs +++ b/src/circuit_writer/writer.rs @@ -424,7 +424,7 @@ impl CircuitWriter { match &fn_info.kind { // assert() <-- for example - FnKind::BuiltIn(sig, handle) => { + FnKind::BuiltIn(sig, handle, _) => { let res = handle(self, &sig.generics, &vars, expr.span); res.map(|r| r.map(VarOrRef::Var)) } diff --git a/src/imports.rs b/src/imports.rs index 6dcf94daf..323d0ce72 100644 --- a/src/imports.rs +++ b/src/imports.rs @@ -74,7 +74,8 @@ where B: Backend, { /// A built-in is just a handle to a function written in Rust. - BuiltIn(FnSig, FnHandle), + /// The boolean flag indicates whether to skip argument type checking for builtins + BuiltIn(FnSig, FnHandle, bool), /// A native function is represented as an AST. Native(FunctionDef), @@ -96,7 +97,7 @@ mod fn_kind_serde { S: Serializer, { let surrogate = match self { - FnKind::BuiltIn(sig, _handle) => { + FnKind::BuiltIn(sig, _handle, _) => { FnKindSurrogate::BuiltIn(sig.clone(), "native function".to_string()) } FnKind::Native(def) => FnKindSurrogate::Native(def.clone()), diff --git a/src/mast/mod.rs b/src/mast/mod.rs index d955bc485..2bbedfd5a 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -188,7 +188,7 @@ impl FnInfo { ctx: &mut MastCtx, ) -> Result { match self.kind { - FnKind::BuiltIn(ref mut sig, _) => { + FnKind::BuiltIn(ref mut sig, _, _) => { sig.resolve_generic_values(observed_args, ctx)?; } FnKind::Native(ref mut func) => { @@ -1363,9 +1363,9 @@ pub fn instantiate_fn_call( // construct the monomorphized function AST let (func_def, mono_info) = match fn_info.kind { - FnKind::BuiltIn(_, handle) => ( + FnKind::BuiltIn(_, handle, ignore_arg_types) => ( FnInfo { - kind: FnKind::BuiltIn(sig_typed, handle), + kind: FnKind::BuiltIn(sig_typed, handle, ignore_arg_types), ..fn_info }, // todo: we will need to propagate the constant value from builtin function as well diff --git a/src/negative_tests.rs b/src/negative_tests.rs index 1796fe0c8..dd557cf71 100644 --- a/src/negative_tests.rs +++ b/src/negative_tests.rs @@ -623,7 +623,7 @@ fn test_hint_builtin_fn(qualified: &FullyQualified, code: &str) -> Result } let fn_info = FnInfo { - kind: FnKind::BuiltIn(sig, mocked_builtin_fn::), + kind: FnKind::BuiltIn(sig, mocked_builtin_fn::, false), is_hint: true, span: Span::default(), }; diff --git a/src/stdlib/bits.rs b/src/stdlib/bits.rs index 6460547e9..c9f55c20d 100644 --- a/src/stdlib/bits.rs +++ b/src/stdlib/bits.rs @@ -22,10 +22,10 @@ pub struct BitsLib {} impl Module for BitsLib { const MODULE: &'static str = "bits"; - fn get_fns() -> Vec<(&'static str, FnInfoType)> { + fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { vec![ - (NTH_BIT_FN, nth_bit), - (CHECK_FIELD_SIZE_FN, check_field_size), + (NTH_BIT_FN, nth_bit, false), + (CHECK_FIELD_SIZE_FN, check_field_size, false), ] } } diff --git a/src/stdlib/builtins.rs b/src/stdlib/builtins.rs index d0599fdda..65e607cdd 100644 --- a/src/stdlib/builtins.rs +++ b/src/stdlib/builtins.rs @@ -23,8 +23,6 @@ pub const BUILTIN_FN_NAMES: [&str; 3] = ["assert", "assert_eq", "log"]; const ASSERT_FN: &str = "assert(condition: Bool)"; const ASSERT_EQ_FN: &str = "assert_eq(lhs: Field, rhs: Field)"; -// todo: currently only supports a single field var -// to support all the types, we can bypass the type check for this log function for now const LOG_FN: &str = "log(var: Field)"; pub struct BuiltinsLib {} @@ -32,11 +30,12 @@ pub struct BuiltinsLib {} impl Module for BuiltinsLib { const MODULE: &'static str = "builtins"; - fn get_fns() -> Vec<(&'static str, FnInfoType)> { + fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { vec![ - (ASSERT_FN, assert_fn), - (ASSERT_EQ_FN, assert_eq_fn), - (LOG_FN, log_fn), + (ASSERT_FN, assert_fn, false), + (ASSERT_EQ_FN, assert_eq_fn, false), + // true -> skip argument type checking for log + (LOG_FN, log_fn, true), ] } } diff --git a/src/stdlib/crypto.rs b/src/stdlib/crypto.rs index 857c8ec20..66113cddd 100644 --- a/src/stdlib/crypto.rs +++ b/src/stdlib/crypto.rs @@ -8,7 +8,7 @@ pub struct CryptoLib {} impl Module for CryptoLib { const MODULE: &'static str = "crypto"; - fn get_fns() -> Vec<(&'static str, FnInfoType)> { - vec![(POSEIDON_FN, B::poseidon())] + fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { + vec![(POSEIDON_FN, B::poseidon(), false)] } } diff --git a/src/stdlib/int.rs b/src/stdlib/int.rs index 7c5b8bfc2..03c574890 100644 --- a/src/stdlib/int.rs +++ b/src/stdlib/int.rs @@ -20,8 +20,8 @@ pub struct IntLib {} impl Module for IntLib { const MODULE: &'static str = "int"; - fn get_fns() -> Vec<(&'static str, FnInfoType)> { - vec![(DIVMOD_FN, divmod_fn)] + fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { + vec![(DIVMOD_FN, divmod_fn, false)] } } diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index 32b437f3f..1c8dde950 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -71,18 +71,18 @@ trait Module { /// e.g. "crypto" const MODULE: &'static str; - fn get_fns() -> Vec<(&'static str, FnInfoType)>; + fn get_fns() -> Vec<(&'static str, FnInfoType, bool)>; fn get_parsed_fns() -> Vec> { let fns = Self::get_fns(); let mut res = Vec::with_capacity(fns.len()); - for (code, fn_handle) in fns { + for (code, fn_handle, ignore_arg_types) in fns { let ctx = &mut ParserCtx::default(); // TODO: we should try to point to real noname files here (not 0) let mut tokens = Token::parse(0, code).unwrap(); let sig = FnSig::parse(ctx, &mut tokens).unwrap(); res.push(FnInfo { - kind: FnKind::BuiltIn(sig, fn_handle), + kind: FnKind::BuiltIn(sig, fn_handle, ignore_arg_types), is_hint: false, span: Span::default(), }); diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index cf8dda009..150072bdd 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -5,15 +5,18 @@ use serde::{Deserialize, Serialize}; use crate::{ backends::Backend, + cli::packages::UserRepo, constants::Span, error::{ErrorKind, Result}, imports::FnKind, parser::{ types::{ - is_numeric, FnSig, ForLoopArgument, FunctionDef, Stmt, StmtKind, Symbolic, Ty, TyKind, + is_numeric, FnSig, ForLoopArgument, FunctionDef, ModulePath, Stmt, StmtKind, Symbolic, + Ty, TyKind, }, CustomType, Expr, ExprKind, Op2, }, + stdlib::builtins::QUALIFIED_BUILTINS, syntax::is_type, }; @@ -38,7 +41,7 @@ where impl FnInfo { pub fn sig(&self) -> &FnSig { match &self.kind { - FnKind::BuiltIn(sig, _) => sig, + FnKind::BuiltIn(sig, _, _) => sig, FnKind::Native(func) => &func.sig, } } @@ -833,6 +836,19 @@ impl TypeChecker { None => (), }; + // get the ignore_arg_types flag from the function info if it's a builtin + let ignore_arg_types = match self + .fn_info(&FullyQualified::new( + &ModulePath::Absolute(UserRepo::new(QUALIFIED_BUILTINS)), + &fn_sig.name.value, + )) + .map(|info| &info.kind) + { + // check builtin + Some(FnKind::BuiltIn(_, _, ignore)) => *ignore, + _ => false, + }; + // canonicalize the arguments depending on method call or not let expected: Vec<_> = if method_call { fn_sig @@ -862,21 +878,24 @@ impl TypeChecker { )); } - // compare argument types with the function signature - for (sig_arg, (typ, span)) in expected.iter().zip(observed) { - // when const attribute presented, the argument must be a constant - if sig_arg.is_constant() && !matches!(typ, TyKind::Field { constant: true }) { - return Err(self.error( - ErrorKind::ArgumentTypeMismatch(sig_arg.typ.kind.clone(), typ), - span, - )); - } + // skip argument type checking if ignore_arg_types is true + if !ignore_arg_types { + // compare argument types with the function signature + for (sig_arg, (typ, span)) in expected.iter().zip(observed) { + // when const attribute presented, the argument must be a constant + if sig_arg.is_constant() && !matches!(typ, TyKind::Field { constant: true }) { + return Err(self.error( + ErrorKind::ArgumentTypeMismatch(sig_arg.typ.kind.clone(), typ), + span, + )); + } - if !typ.match_expected(&sig_arg.typ.kind, false) { - return Err(self.error( - ErrorKind::ArgumentTypeMismatch(sig_arg.typ.kind.clone(), typ), - span, - )); + if !typ.match_expected(&sig_arg.typ.kind, false) { + return Err(self.error( + ErrorKind::ArgumentTypeMismatch(sig_arg.typ.kind.clone(), typ), + span, + )); + } } } diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index b904feb5b..baa5ee9fa 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -377,7 +377,7 @@ impl TypeChecker { // check it is a builtin function let fn_handle = match builtin_fn { - FnKind::BuiltIn(_, fn_handle) => fn_handle, + FnKind::BuiltIn(_, fn_handle, _) => fn_handle, _ => { return Err(Error::new( "type-checker", @@ -392,7 +392,8 @@ impl TypeChecker { qualified, FnInfo { is_hint: true, - kind: FnKind::BuiltIn(function.sig.clone(), fn_handle), + // todo: is there a case where we want to ignore argument types for hint functions? + kind: FnKind::BuiltIn(function.sig.clone(), fn_handle, false), span: function.span, }, ); diff --git a/src/witness.rs b/src/witness.rs index 396f79fd5..94eefa117 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -66,7 +66,7 @@ impl CompiledCircuit { // get info on main let main_info = self.main_info(); let main_sig = match &main_info.kind { - crate::imports::FnKind::BuiltIn(_, _) => unreachable!(), + crate::imports::FnKind::BuiltIn(_, _, _) => unreachable!(), crate::imports::FnKind::Native(fn_sig) => &fn_sig.sig, };