Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support hint function declaration to bind with builtin function #199

Merged
merged 8 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/circuit_writer/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ impl<B: Backend> CircuitWriter<B> {
module,
fn_name,
args,
..
} => {
// sanity check
if fn_name.value == "main" {
Expand Down
9 changes: 9 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ pub enum ErrorKind {
#[error("function `{0}` not present in scope (did you misspell it?)")]
UndefinedFunction(String),

#[error("hint function `{0}` signature is missing its corresponding builtin function")]
MissingHintMapping(String),

#[error("function name `{0}` is already in use by a variable present in the scope")]
FunctionNameInUsebyVariable(String),

Expand All @@ -246,6 +249,12 @@ pub enum ErrorKind {
#[error("attribute not recognized: `{0:?}`")]
InvalidAttribute(AttributeKind),

#[error("unsafe attribute is needed to call a hint function. eg: `unsafe fn foo()`")]
ExpectedUnsafeAttribute,

#[error("unsafe attribute should only be applied to hint function calls")]
UnexpectedUnsafeAttribute,

#[error("A return value is not used")]
UnusedReturnValue,

Expand Down
8 changes: 8 additions & 0 deletions src/lexer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ pub enum Keyword {
Use,
/// A function
Fn,
/// A hint function
Hint,
/// Attribute required for hint functions
Unsafe,
/// New variable
Let,
/// Public input
Expand Down Expand Up @@ -75,6 +79,8 @@ impl Keyword {
match s {
"use" => Some(Self::Use),
"fn" => Some(Self::Fn),
"hint" => Some(Self::Hint),
"unsafe" => Some(Self::Unsafe),
"let" => Some(Self::Let),
"pub" => Some(Self::Pub),
"return" => Some(Self::Return),
Expand All @@ -97,6 +103,8 @@ impl Display for Keyword {
let desc = match self {
Self::Use => "use",
Self::Fn => "fn",
Self::Hint => "hint",
Self::Unsafe => "unsafe",
Self::Let => "let",
Self::Pub => "pub",
Self::Return => "return",
Expand Down
7 changes: 7 additions & 0 deletions src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ fn monomorphize_expr<B: Backend>(
module,
fn_name,
args,
unsafe_attr,
} => {
// compute the observed arguments types
let mut observed = Vec::with_capacity(args.len());
Expand Down Expand Up @@ -548,6 +549,7 @@ fn monomorphize_expr<B: Backend>(
module: module.clone(),
fn_name: resolved_sig.name,
args: args_mono,
unsafe_attr: *unsafe_attr,
},
);
let resolved_sig = &fn_info.sig().generics.resolved_sig;
Expand All @@ -568,6 +570,7 @@ fn monomorphize_expr<B: Backend>(
module: module.clone(),
fn_name: fn_name_mono.clone(),
args: args_mono,
unsafe_attr: *unsafe_attr,
},
);

Expand Down Expand Up @@ -611,6 +614,7 @@ fn monomorphize_expr<B: Backend>(
let fn_kind = FnKind::Native(method_type.clone());
let mut fn_info = FnInfo {
kind: fn_kind,
is_hint: false,
span: method_type.span,
};

Expand Down Expand Up @@ -1215,6 +1219,7 @@ pub fn instantiate_fn_call<B: Backend>(
let func_def = match fn_info.kind {
FnKind::BuiltIn(_, handle) => FnInfo {
kind: FnKind::BuiltIn(sig_typed, handle),
is_hint: fn_info.is_hint,
span: fn_info.span,
},
FnKind::Native(fn_def) => {
Expand All @@ -1226,7 +1231,9 @@ pub fn instantiate_fn_call<B: Backend>(
sig: sig_typed,
body: stmts_typed,
span: fn_def.span,
is_hint: fn_def.is_hint,
}),
is_hint: fn_info.is_hint,
span: fn_info.span,
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/name_resolution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl NameResCtx {
}

pub(crate) fn resolve_fn_def(&self, fn_def: &mut FunctionDef) -> Result<()> {
let FunctionDef { sig, body, span: _ } = fn_def;
let FunctionDef { sig, body, .. } = fn_def;

//
// signature
Expand Down
1 change: 1 addition & 0 deletions src/name_resolution/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ impl NameResCtx {
module,
fn_name,
args,
unsafe_attr: _,
} => {
if matches!(module, ModulePath::Local)
&& BUILTIN_FN_NAMES.contains(&fn_name.value.as_str())
Expand Down
125 changes: 112 additions & 13 deletions src/negative_tests.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,41 @@
use crate::{
backends::r1cs::{R1csBn254Field, R1CS},
circuit_writer::CircuitWriter,
backends::{
r1cs::{R1csBn254Field, R1CS},
Backend,
},
circuit_writer::{CircuitWriter, VarInfo},
compiler::{get_nast, typecheck_next_file_inner, Sources},
constants::Span,
error::{ErrorKind, Result},
imports::FnKind,
lexer::Token,
mast::Mast,
name_resolution::NAST,
type_checker::TypeChecker,
parser::{
types::{FnSig, GenericParameters},
ParserCtx,
},
type_checker::{FnInfo, FullyQualified, TypeChecker},
var::Var,
witness::CompiledCircuit,
};

fn nast_pass(code: &str) -> Result<(NAST<R1CS<R1csBn254Field>>, usize)> {
type R1csBackend = R1CS<R1csBn254Field>;

fn nast_pass(code: &str) -> Result<(NAST<R1csBackend>, usize)> {
let mut source = Sources::new();
let res = get_nast(
get_nast(
None,
&mut source,
"example.no".to_string(),
code.to_string(),
0,
);

res
)
}

fn tast_pass(code: &str) -> (Result<usize>, TypeChecker<R1CS<R1csBn254Field>>, Sources) {
fn tast_pass(code: &str) -> (Result<usize>, TypeChecker<R1csBackend>, Sources) {
let mut source = Sources::new();
let mut tast = TypeChecker::<R1CS<R1csBn254Field>>::new();
let mut tast = TypeChecker::<R1csBackend>::new();
let res = typecheck_next_file_inner(
&mut tast,
None,
Expand All @@ -37,12 +48,12 @@ fn tast_pass(code: &str) -> (Result<usize>, TypeChecker<R1CS<R1csBn254Field>>, S
(res, tast, source)
}

fn mast_pass(code: &str) -> Result<Mast<R1CS<R1csBn254Field>>> {
fn mast_pass(code: &str) -> Result<Mast<R1csBackend>> {
let (_, tast, _) = tast_pass(code);
crate::mast::monomorphize(tast)
}

fn synthesizer_pass(code: &str) -> Result<CompiledCircuit<R1CS<R1csBn254Field>>> {
fn synthesizer_pass(code: &str) -> Result<CompiledCircuit<R1csBackend>> {
let mast = mast_pass(code);
CircuitWriter::generate_circuit(mast?, R1CS::new())
}
Expand Down Expand Up @@ -396,6 +407,94 @@ fn test_generic_missing_parenthesis() {
"#;

let res = nast_pass(code).err();
println!("{:?}", res);
assert!(matches!(res.unwrap().kind, ErrorKind::MissingParenthesis));
}
fn test_hint_builtin_fn(qualified: &FullyQualified, code: &str) -> Result<usize> {
let mut source = Sources::new();
let mut tast = TypeChecker::<R1csBackend>::new();
// mock a builtin function
let ctx = &mut ParserCtx::default();
let mut tokens = Token::parse(0, "calc(val: Field) -> Field;").unwrap();
let sig = FnSig::parse(ctx, &mut tokens).unwrap();

fn mocked_builtin_fn<B: Backend>(
_: &mut CircuitWriter<B>,
_: &GenericParameters,
_: &[VarInfo<B::Field, B::Var>],
_: Span,
) -> Result<Option<Var<B::Field, B::Var>>> {
Ok(None)
}

let fn_info = FnInfo {
kind: FnKind::BuiltIn(sig, mocked_builtin_fn::<R1csBackend>),
is_hint: true,
span: Span::default(),
};

// add the mocked builtin function
// note that this should happen in the tast phase, instead of mast phase.
// currently this function is the only way to mock a builtin function.
tast.add_monomorphized_fn(qualified.clone(), fn_info);

typecheck_next_file_inner(
&mut tast,
None,
&mut source,
"example.no".to_string(),
code.to_string(),
0,
)
}

#[test]
fn test_hint_call_missing_unsafe() {
let qualified = FullyQualified {
module: None,
name: "calc".to_string(),
};

let valid_code = r#"
hint fn calc(val: Field) -> Field;

fn main(pub xx: Field) {
let yy = unsafe calc(xx);
}
"#;

let res = test_hint_builtin_fn(&qualified, valid_code);
assert!(res.is_ok());

let invalid_code = r#"
hint fn calc(val: Field) -> Field;

fn main(pub xx: Field) {
let yy = calc(xx);
}
"#;

let res = test_hint_builtin_fn(&qualified, invalid_code);
assert!(matches!(
res.unwrap_err().kind,
ErrorKind::ExpectedUnsafeAttribute
));
}

#[test]
fn test_nonhint_call_with_unsafe() {
let code = r#"
fn calc(val: Field) -> Field {
return val + 1;
}

fn main(pub xx: Field) {
let yy = unsafe calc(xx);
}
"#;

let res = tast_pass(code).0;
assert!(matches!(
res.unwrap_err().kind,
ErrorKind::UnexpectedUnsafeAttribute
));
}
17 changes: 17 additions & 0 deletions src/parser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ pub enum ExprKind {
module: ModulePath,
fn_name: Ident,
args: Vec<Expr>,
unsafe_attr: bool,
},

/// `lhs.method_name(args)`
Expand Down Expand Up @@ -379,6 +380,21 @@ impl Expr {
}
}

TokenKind::Keyword(Keyword::Unsafe) => {
let mut fn_call = Expr::parse(ctx, tokens)?;
// should be FnCall
match &mut fn_call.kind {
ExprKind::FnCall { unsafe_attr, .. } => {
*unsafe_attr = true;
}
_ => {
return Err(ctx.error(ErrorKind::InvalidExpression, fn_call.span));
}
};

fn_call
}

// unrecognized pattern
_ => {
return Err(ctx.error(ErrorKind::InvalidExpression, token.span));
Expand Down Expand Up @@ -576,6 +592,7 @@ impl Expr {
module,
fn_name,
args,
unsafe_attr: false,
},
span,
)
Expand Down
20 changes: 20 additions & 0 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,26 @@ impl<B: Backend> AST<B> {
});
}

// `hint fn calc() { }`
TokenKind::Keyword(Keyword::Hint) => {
// expect fn token
tokens.bump_expected(ctx, TokenKind::Keyword(Keyword::Fn))?;

function_observed = true;

let func = FunctionDef::parse_hint(ctx, &mut tokens)?;

// expect ;, as the hint function is an empty function wired with a builtin.
// todo: later these hint functions will be migrated from builtins to native functions
// then it will expect a function block instead of ;
tokens.bump_expected(ctx, TokenKind::SemiColon)?;

ast.push(Root {
kind: RootKind::FunctionDef(func),
span: token.span,
});
}

// `struct Foo { a: Field, b: Field }`
TokenKind::Keyword(Keyword::Struct) => {
let s = StructDef::parse(ctx, &mut tokens)?;
Expand Down
31 changes: 30 additions & 1 deletion src/parser/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ impl Attribute {

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDef {
pub is_hint: bool,
pub sig: FnSig,
pub body: Vec<Stmt>,
pub span: Span,
Expand Down Expand Up @@ -1149,7 +1150,35 @@ impl FunctionDef {
));
}

let func = Self { sig, body, span };
let func = Self {
sig,
body,
span,
is_hint: false,
};

Ok(func)
}

/// Parse a hint function signature
pub fn parse_hint(ctx: &mut ParserCtx, tokens: &mut Tokens) -> Result<Self> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you think it'd make sense to merge this with the parse function above?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new parser function is for hint function without a body.
So I think once we support the native hints, we can just remove this function, while the native hint function can get parsed using parser function above without changes.

// parse signature
let sig = FnSig::parse(ctx, tokens)?;
let span = sig.name.span;

// make sure that it doesn't shadow a builtin
if BUILTIN_FN_NAMES.contains(&sig.name.value.as_ref()) {
return Err(ctx.error(ErrorKind::ShadowingBuiltIn(sig.name.value.clone()), span));
}

// for now the body is empty.
// this will be changed once the native hint is implemented.
let func = Self {
sig,
body: vec![],
span,
is_hint: true,
};

Ok(func)
}
Expand Down
Loading
Loading