Skip to content

Commit

Permalink
Merge pull request #231 from bufferhe4d/log-type-skip-228
Browse files Browse the repository at this point in the history
Add boolean flag to skip argument type checking of builtin functions.
  • Loading branch information
katat authored Dec 10, 2024
2 parents 321b047 + a2d31f7 commit c21739d
Show file tree
Hide file tree
Showing 13 changed files with 63 additions and 43 deletions.
2 changes: 1 addition & 1 deletion src/circuit_writer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ impl<B: Backend> CircuitWriter<B> {
let main_fn_info = circuit_writer.main_info()?;

let function = match &main_fn_info.kind {
crate::imports::FnKind::BuiltIn(_, _) => unreachable!(),
crate::imports::FnKind::BuiltIn(_, _, _) => unreachable!(),
crate::imports::FnKind::Native(fn_sig) => fn_sig.clone(),
};

Expand Down
2 changes: 1 addition & 1 deletion src/circuit_writer/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ impl<B: Backend> CircuitWriter<B> {

match &fn_info.kind {
// assert() <-- for example
FnKind::BuiltIn(sig, handle) => {
FnKind::BuiltIn(sig, handle, _) => {
let res = handle(self, &sig.generics, &vars, expr.span);
res.map(|r| r.map(VarOrRef::Var))
}
Expand Down
5 changes: 3 additions & 2 deletions src/imports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ where
B: Backend,
{
/// A built-in is just a handle to a function written in Rust.
BuiltIn(FnSig, FnHandle<B>),
/// The boolean flag indicates whether to skip argument type checking for builtins
BuiltIn(FnSig, FnHandle<B>, bool),

/// A native function is represented as an AST.
Native(FunctionDef),
Expand All @@ -96,7 +97,7 @@ mod fn_kind_serde {
S: Serializer,
{
let surrogate = match self {
FnKind::BuiltIn(sig, _handle) => {
FnKind::BuiltIn(sig, _handle, _) => {
FnKindSurrogate::BuiltIn(sig.clone(), "native function".to_string())
}
FnKind::Native(def) => FnKindSurrogate::Native(def.clone()),
Expand Down
6 changes: 3 additions & 3 deletions src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ impl<B: Backend> FnInfo<B> {
ctx: &mut MastCtx<B>,
) -> Result<FnSig> {
match self.kind {
FnKind::BuiltIn(ref mut sig, _) => {
FnKind::BuiltIn(ref mut sig, _, _) => {
sig.resolve_generic_values(observed_args, ctx)?;
}
FnKind::Native(ref mut func) => {
Expand Down Expand Up @@ -1363,9 +1363,9 @@ pub fn instantiate_fn_call<B: Backend>(

// construct the monomorphized function AST
let (func_def, mono_info) = match fn_info.kind {
FnKind::BuiltIn(_, handle) => (
FnKind::BuiltIn(_, handle, ignore_arg_types) => (
FnInfo {
kind: FnKind::BuiltIn(sig_typed, handle),
kind: FnKind::BuiltIn(sig_typed, handle, ignore_arg_types),
..fn_info
},
// todo: we will need to propagate the constant value from builtin function as well
Expand Down
2 changes: 1 addition & 1 deletion src/negative_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ fn test_hint_builtin_fn(qualified: &FullyQualified, code: &str) -> Result<usize>
}

let fn_info = FnInfo {
kind: FnKind::BuiltIn(sig, mocked_builtin_fn::<R1csBackend>),
kind: FnKind::BuiltIn(sig, mocked_builtin_fn::<R1csBackend>, false),
is_hint: true,
span: Span::default(),
};
Expand Down
6 changes: 3 additions & 3 deletions src/stdlib/bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ pub struct BitsLib {}
impl Module for BitsLib {
const MODULE: &'static str = "bits";

fn get_fns<B: Backend>() -> Vec<(&'static str, FnInfoType<B>)> {
fn get_fns<B: Backend>() -> Vec<(&'static str, FnInfoType<B>, bool)> {
vec![
(NTH_BIT_FN, nth_bit),
(CHECK_FIELD_SIZE_FN, check_field_size),
(NTH_BIT_FN, nth_bit, false),
(CHECK_FIELD_SIZE_FN, check_field_size, false),
]
}
}
Expand Down
11 changes: 5 additions & 6 deletions src/stdlib/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,19 @@ pub const BUILTIN_FN_NAMES: [&str; 3] = ["assert", "assert_eq", "log"];

const ASSERT_FN: &str = "assert(condition: Bool)";
const ASSERT_EQ_FN: &str = "assert_eq(lhs: Field, rhs: Field)";
// todo: currently only supports a single field var
// to support all the types, we can bypass the type check for this log function for now
const LOG_FN: &str = "log(var: Field)";

pub struct BuiltinsLib {}

impl Module for BuiltinsLib {
const MODULE: &'static str = "builtins";

fn get_fns<B: Backend>() -> Vec<(&'static str, FnInfoType<B>)> {
fn get_fns<B: Backend>() -> Vec<(&'static str, FnInfoType<B>, bool)> {
vec![
(ASSERT_FN, assert_fn),
(ASSERT_EQ_FN, assert_eq_fn),
(LOG_FN, log_fn),
(ASSERT_FN, assert_fn, false),
(ASSERT_EQ_FN, assert_eq_fn, false),
// true -> skip argument type checking for log
(LOG_FN, log_fn, true),
]
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/stdlib/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub struct CryptoLib {}
impl Module for CryptoLib {
const MODULE: &'static str = "crypto";

fn get_fns<B: Backend>() -> Vec<(&'static str, FnInfoType<B>)> {
vec![(POSEIDON_FN, B::poseidon())]
fn get_fns<B: Backend>() -> Vec<(&'static str, FnInfoType<B>, bool)> {
vec![(POSEIDON_FN, B::poseidon(), false)]
}
}
4 changes: 2 additions & 2 deletions src/stdlib/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ pub struct IntLib {}
impl Module for IntLib {
const MODULE: &'static str = "int";

fn get_fns<B: Backend>() -> Vec<(&'static str, FnInfoType<B>)> {
vec![(DIVMOD_FN, divmod_fn)]
fn get_fns<B: Backend>() -> Vec<(&'static str, FnInfoType<B>, bool)> {
vec![(DIVMOD_FN, divmod_fn, false)]
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/stdlib/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,18 @@ trait Module {
/// e.g. "crypto"
const MODULE: &'static str;

fn get_fns<B: Backend>() -> Vec<(&'static str, FnInfoType<B>)>;
fn get_fns<B: Backend>() -> Vec<(&'static str, FnInfoType<B>, bool)>;

fn get_parsed_fns<B: Backend>() -> Vec<FnInfo<B>> {
let fns = Self::get_fns();
let mut res = Vec::with_capacity(fns.len());
for (code, fn_handle) in fns {
for (code, fn_handle, ignore_arg_types) in fns {
let ctx = &mut ParserCtx::default();
// TODO: we should try to point to real noname files here (not 0)
let mut tokens = Token::parse(0, code).unwrap();
let sig = FnSig::parse(ctx, &mut tokens).unwrap();
res.push(FnInfo {
kind: FnKind::BuiltIn(sig, fn_handle),
kind: FnKind::BuiltIn(sig, fn_handle, ignore_arg_types),
is_hint: false,
span: Span::default(),
});
Expand Down
51 changes: 35 additions & 16 deletions src/type_checker/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@ use serde::{Deserialize, Serialize};

use crate::{
backends::Backend,
cli::packages::UserRepo,
constants::Span,
error::{ErrorKind, Result},
imports::FnKind,
parser::{
types::{
is_numeric, FnSig, ForLoopArgument, FunctionDef, Stmt, StmtKind, Symbolic, Ty, TyKind,
is_numeric, FnSig, ForLoopArgument, FunctionDef, ModulePath, Stmt, StmtKind, Symbolic,
Ty, TyKind,
},
CustomType, Expr, ExprKind, Op2,
},
stdlib::builtins::QUALIFIED_BUILTINS,
syntax::is_type,
};

Expand All @@ -38,7 +41,7 @@ where
impl<B: Backend> FnInfo<B> {
pub fn sig(&self) -> &FnSig {
match &self.kind {
FnKind::BuiltIn(sig, _) => sig,
FnKind::BuiltIn(sig, _, _) => sig,
FnKind::Native(func) => &func.sig,
}
}
Expand Down Expand Up @@ -833,6 +836,19 @@ impl<B: Backend> TypeChecker<B> {
None => (),
};

// get the ignore_arg_types flag from the function info if it's a builtin
let ignore_arg_types = match self
.fn_info(&FullyQualified::new(
&ModulePath::Absolute(UserRepo::new(QUALIFIED_BUILTINS)),
&fn_sig.name.value,
))
.map(|info| &info.kind)
{
// check builtin
Some(FnKind::BuiltIn(_, _, ignore)) => *ignore,
_ => false,
};

// canonicalize the arguments depending on method call or not
let expected: Vec<_> = if method_call {
fn_sig
Expand Down Expand Up @@ -862,21 +878,24 @@ impl<B: Backend> TypeChecker<B> {
));
}

// compare argument types with the function signature
for (sig_arg, (typ, span)) in expected.iter().zip(observed) {
// when const attribute presented, the argument must be a constant
if sig_arg.is_constant() && !matches!(typ, TyKind::Field { constant: true }) {
return Err(self.error(
ErrorKind::ArgumentTypeMismatch(sig_arg.typ.kind.clone(), typ),
span,
));
}
// skip argument type checking if ignore_arg_types is true
if !ignore_arg_types {
// compare argument types with the function signature
for (sig_arg, (typ, span)) in expected.iter().zip(observed) {
// when const attribute presented, the argument must be a constant
if sig_arg.is_constant() && !matches!(typ, TyKind::Field { constant: true }) {
return Err(self.error(
ErrorKind::ArgumentTypeMismatch(sig_arg.typ.kind.clone(), typ),
span,
));
}

if !typ.match_expected(&sig_arg.typ.kind, false) {
return Err(self.error(
ErrorKind::ArgumentTypeMismatch(sig_arg.typ.kind.clone(), typ),
span,
));
if !typ.match_expected(&sig_arg.typ.kind, false) {
return Err(self.error(
ErrorKind::ArgumentTypeMismatch(sig_arg.typ.kind.clone(), typ),
span,
));
}
}
}

Expand Down
5 changes: 3 additions & 2 deletions src/type_checker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ impl<B: Backend> TypeChecker<B> {

// check it is a builtin function
let fn_handle = match builtin_fn {
FnKind::BuiltIn(_, fn_handle) => fn_handle,
FnKind::BuiltIn(_, fn_handle, _) => fn_handle,
_ => {
return Err(Error::new(
"type-checker",
Expand All @@ -392,7 +392,8 @@ impl<B: Backend> TypeChecker<B> {
qualified,
FnInfo {
is_hint: true,
kind: FnKind::BuiltIn(function.sig.clone(), fn_handle),
// todo: is there a case where we want to ignore argument types for hint functions?
kind: FnKind::BuiltIn(function.sig.clone(), fn_handle, false),
span: function.span,
},
);
Expand Down
2 changes: 1 addition & 1 deletion src/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl<B: Backend> CompiledCircuit<B> {
// get info on main
let main_info = self.main_info();
let main_sig = match &main_info.kind {
crate::imports::FnKind::BuiltIn(_, _) => unreachable!(),
crate::imports::FnKind::BuiltIn(_, _, _) => unreachable!(),
crate::imports::FnKind::Native(fn_sig) => &fn_sig.sig,
};

Expand Down

0 comments on commit c21739d

Please sign in to comment.