Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Implement call as expression
Browse files Browse the repository at this point in the history
  • Loading branch information
alxkzmn committed Aug 28, 2024
1 parent 49fb609 commit cce3de0
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 72 deletions.
7 changes: 4 additions & 3 deletions src/compiler/abepi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ impl<F: From<u64> + TryInto<u32> + Clone + Debug, V: Clone + Debug> CompilationU
Statement::Transition(dsym, id, stmt) => {
self.compiler_statement_transition(dsym, id, *stmt)
}
Statement::HyperTransition(dsym, state, call) => {
self.compiler_statement_hyper_transition(dsym, state, *call)
Statement::HyperTransition(dsym, ids, call, state) => {
self.compiler_statement_hyper_transition(dsym, ids, call, state)
}
_ => vec![],
}
Expand Down Expand Up @@ -427,8 +427,9 @@ impl<F: From<u64> + TryInto<u32> + Clone + Debug, V: Clone + Debug> CompilationU
fn compiler_statement_hyper_transition(
&self,
_dsym: DebugSymRef,
_ids: Vec<V>,
_call: Expression<F, V>,
_state: V,
_call: Statement<F, V>,
) -> Vec<CompilationResult<F, V>> {
todo!("Compile expressions?")
}
Expand Down
17 changes: 0 additions & 17 deletions src/compiler/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -842,22 +842,5 @@ mod test {
let result = TLDeclsParser::new().parse(&debug_sym_ref_factory, circuit);

assert!(result.is_ok());

// Wrong transition operator
let circuit = "
machine caller (signal n) (signal b: field) {
a', b, c' <== fibo(d, e, f + g) --> final;
}
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let result = TLDeclsParser::new().parse(&debug_sym_ref_factory, circuit);

assert!(result.is_err());
let err = result.err().unwrap();
assert_eq!(
err.to_string(),
"Unrecognized token `-` found at 99:100\nExpected one of \"->\" or \";\""
);
}
}
16 changes: 8 additions & 8 deletions src/compiler/semantic/analyser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,16 +272,10 @@ impl Analyser {

Statement::SignalDecl(_, _) => {}
Statement::WGVarDecl(_, _) => {}
Statement::HyperTransition(_, state, call) => {
self.analyse_statement(*call);
Statement::HyperTransition(_, ids, call, state) => {
self.analyse_expression(call);
self.collect_id_usages(&[state]);
}
Statement::Call(_, ids, _machine, exprs) => {
// TODO analyze machine?
self.collect_id_usages(&ids);
exprs
.into_iter()
.for_each(|expr| self.analyse_expression(expr))
}
}
}
Expand Down Expand Up @@ -319,6 +313,12 @@ impl Analyser {
} => {
self.extract_usages_expression(&sub);
}
Expression::Call(_, fun, exprs) => {
self.collect_id_usages(&[fun]);
exprs
.into_iter()
.for_each(|expr| self.extract_usages_expression(&expr));
}
_ => {}
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/compiler/semantic/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ fn undeclared_rule(analyser: &mut Analyser, expr: &Expression<BigInt, Identifier
undeclared_rule(analyser, when_false);
}
Expression::Const(_, _) | Expression::True(_) | Expression::False(_) => {}
Expression::Call(_, _, args) => {
args.iter().for_each(|arg| undeclared_rule(analyser, arg));
}
}
}

Expand Down
3 changes: 1 addition & 2 deletions src/compiler/setup_inter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,7 @@ impl SetupInterpreter {

SignalAssignment(_, _, _) | WGAssignment(_, _, _) => vec![],
SignalDecl(_, _) | WGVarDecl(_, _) => vec![],
Call(_, _, _, _) => todo!(),
HyperTransition(_, _, _) => todo!(),
HyperTransition(_, _, _, _) => todo!(),
};

self.add_poly_constraints(result.into_iter().map(|cr| cr.anti_booly).collect());
Expand Down
1 change: 1 addition & 0 deletions src/interpreter/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pub(crate) fn eval_expr<F: Field + Hash, V: Identifiable>(
Const(_, v) => Ok(Value::Field(F::from_big_int(v))),
True(_) => Ok(Value::Bool(true)),
False(_) => Ok(Value::Bool(false)),
Call(_, _, _) => todo!(),
}
.map_err(|msg| Message::RuntimeErr {
msg,
Expand Down
3 changes: 1 addition & 2 deletions src/interpreter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,7 @@ impl<'a, F: Field + Hash> Interpreter<'a, F> {
Block(_, stmts) => self.exec_step_block(stmts),
Assert(_, _) => Ok(None),
StateDecl(_, _, _) => Ok(None),
Call(_, _, _, _) => todo!("execute call?"),
HyperTransition(_, _, _) => todo!("execute hypertransition?"),
HyperTransition(_, _, _, _) => todo!("execute hypertransition?"),
}
}

Expand Down
11 changes: 11 additions & 0 deletions src/parser/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ pub enum Expression<F, V> {
Const(DebugSymRef, F),
True(DebugSymRef),
False(DebugSymRef),
/// Function or machine call.
/// Tuple values:
/// - debug symbol reference;
/// - function/machine ID;
/// - call argument expressions vector.
Call(DebugSymRef, V, Vec<Expression<F, V>>),
}

// Shorthand for BigInt expression
Expand All @@ -217,6 +223,7 @@ impl<F, V> Expression<F, V> {
Const(_, _) => true,
True(_) => false,
False(_) => false,
Call(_, _, _) => todo!(),
}
}

Expand All @@ -234,6 +241,7 @@ impl<F, V> Expression<F, V> {

when_true.is_logic()
}
Expression::Call { .. } => todo!(),
_ => false,
}
}
Expand All @@ -247,6 +255,7 @@ impl<F, V> Expression<F, V> {
Expression::Const(dsym, _) => dsym,
Expression::True(dsym) => dsym,
Expression::False(dsym) => dsym,
Expression::Call(dsym, _, _) => dsym,
}
}

Expand All @@ -260,6 +269,7 @@ impl<F, V> Expression<F, V> {
Expression::Query(_, _) => false,
Expression::True(_) => false,
Expression::False(_) => false,
Expression::Call(_, _, _) => false,
}
}
}
Expand Down Expand Up @@ -315,6 +325,7 @@ impl<F: Debug, V: Debug> Debug for Expression<F, V> {

Expression::True(_) => write!(f, "true"),
Expression::False(_) => write!(f, "false"),
Expression::Call(_, fun, exprs) => write!(f, "{:?}({:?})", fun, exprs),
}
}
}
71 changes: 38 additions & 33 deletions src/parser/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,42 +13,42 @@ pub struct TypedIdDecl<V> {

#[derive(Clone)]
pub enum Statement<F, V> {
Assert(DebugSymRef, Expression<F, V>), // assert x;

SignalAssignment(DebugSymRef, Vec<V>, Vec<Expression<F, V>>), // x <-- y;
SignalAssignmentAssert(DebugSymRef, Vec<V>, Vec<Expression<F, V>>), // x <== y;
WGAssignment(DebugSymRef, Vec<V>, Vec<Expression<F, V>>), // x = y;

IfThen(DebugSymRef, Box<Expression<F, V>>, Box<Statement<F, V>>), // if x { y }
/// assert x;
Assert(DebugSymRef, Expression<F, V>),
/// x <-- y;
SignalAssignment(DebugSymRef, Vec<V>, Vec<Expression<F, V>>),
/// x <== y;
SignalAssignmentAssert(DebugSymRef, Vec<V>, Vec<Expression<F, V>>),
/// x = y;
WGAssignment(DebugSymRef, Vec<V>, Vec<Expression<F, V>>),
/// if x { y }
IfThen(DebugSymRef, Box<Expression<F, V>>, Box<Statement<F, V>>),
/// if x { y } else { z }
IfThenElse(
DebugSymRef,
Box<Expression<F, V>>,
Box<Statement<F, V>>,
Box<Statement<F, V>>,
), // if x { y } else { z }

SignalDecl(DebugSymRef, Vec<TypedIdDecl<V>>), // signal x;
WGVarDecl(DebugSymRef, Vec<TypedIdDecl<V>>), // var x;

StateDecl(DebugSymRef, V, Box<Statement<F, V>>), // state x { y }
),
/// signal x;
SignalDecl(DebugSymRef, Vec<TypedIdDecl<V>>),
/// var x;
WGVarDecl(DebugSymRef, Vec<TypedIdDecl<V>>),
/// state x { y }
StateDecl(DebugSymRef, V, Box<Statement<F, V>>),
/// Transition to another state.
Transition(DebugSymRef, V, Box<Statement<F, V>>), // -> x { y }

Block(DebugSymRef, Vec<Statement<F, V>>), // { x }
/// Call into another machine with assertion.
/// Tuple values:
/// - debug symbol reference;
/// - assigned/asserted ids vector;
/// - machine ID;
/// - call argument expressions vector.
Call(DebugSymRef, Vec<V>, V, Vec<Expression<F, V>>),
/// -> x { y }
Transition(DebugSymRef, V, Box<Statement<F, V>>),
/// { x }
Block(DebugSymRef, Vec<Statement<F, V>>),
/// Call into another machine with assertion and subsequent transition to another
/// state.
/// Tuple values:
/// - debug symbol reference;
/// - next state ID.
/// - machine call;
HyperTransition(DebugSymRef, V, Box<Statement<F, V>>),
/// - assigned signal IDs;
/// - call expression;
/// - next state ID;
HyperTransition(DebugSymRef, Vec<V>, Expression<F, V>, V),
}

impl<F: Debug> Debug for Statement<F, Identifier> {
Expand Down Expand Up @@ -98,11 +98,17 @@ impl<F: Debug> Debug for Statement<F, Identifier> {
.join(" ")
)
}
Statement::Call(_, ids, machine, exprs) => {
write!(f, "{:?} <== {} ({:?});", ids, machine.name(), exprs)
}
Statement::HyperTransition(_, state, call) => {
write!(f, "{:?} -> {:?};", call, state)
Statement::HyperTransition(_, ids, call, state) => {
write!(
f,
"{:?} <== {:?} -> {:?};",
ids.iter()
.map(|id| id.name())
.collect::<Vec<_>>()
.join(", "),
call,
state
)
}
}
}
Expand All @@ -122,8 +128,7 @@ impl<F, V> Statement<F, V> {
Statement::StateDecl(dsym, _, _) => dsym.clone(),
Statement::Transition(dsym, _, _) => dsym.clone(),
Statement::Block(dsym, _) => dsym.clone(),
Statement::Call(dsym, _, _, _) => dsym.clone(),
Statement::HyperTransition(dsym, _, _) => dsym.clone(),
Statement::HyperTransition(dsym, _, _, _) => dsym.clone(),
}
}
}
7 changes: 4 additions & 3 deletions src/parser/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,14 @@ pub fn build_transition<F>(
Statement::Transition(dsym, id, Box::new(block))
}

pub fn build_hyper_transition<F>(
pub fn build_hyper_transition<F: Clone>(
dsym: DebugSymRef,
ids: Vec<Identifier>,
call: Expression<F, Identifier>,
state: Identifier,
call: Statement<F, Identifier>,
) -> Statement<F, Identifier> {
match call {
Statement::Call(_, _, _, _) => Statement::HyperTransition(dsym, state, Box::new(call)),
Expression::Call(_, _, _) => Statement::HyperTransition(dsym, ids, call, state),
_ => unreachable!("Hyper transition must include a call statement"),
}
}
8 changes: 4 additions & 4 deletions src/parser/chiquito.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ StatementType: Statement<BigInt, Identifier> = {
ParseTransitionSimple,
ParseTransition,
HyperTransition,
Call,
}

AssertEq: Statement<BigInt, Identifier> = {
Expand Down Expand Up @@ -114,11 +113,11 @@ ParseTransition: Statement<BigInt, Identifier> = {
}

HyperTransition: Statement<BigInt, Identifier> = {
<l: @L> <call: Call> "->" <st: Identifier> <r: @R> => build_hyper_transition(dsym_factory.create(l,r), st, call),
<l: @L> <ids: ParseIdsList> "<==" <call:Expression> "->" <st: Identifier> <r: @R> => build_hyper_transition(dsym_factory.create(l,r), ids, call, st),
}

Call: Statement<BigInt, Identifier> = {
<l: @L> <ids: ParseIdsList> "<==" <machine: Identifier> "(" <es:ParseExpressionList> ")" <r: @R> => Statement::Call(dsym_factory.create(l,r), ids, machine, es),
Call: Expression<BigInt, Identifier> = {
<l: @L> <fun: Identifier> "(" <es:ParseExpressionList> ")" <r: @R> => Expression::Call(dsym_factory.create(l,r), fun, es),
}

ParseSignalDecl: Statement<BigInt, Identifier> = {
Expand Down Expand Up @@ -212,6 +211,7 @@ ExpressionTerm: Expr = {
<l: @L> "true" <r: @R> => Expression::True(dsym_factory.create(l,r)),
<l: @L> "false" <r: @R> => Expression::False(dsym_factory.create(l,r)),
<l: @L> "(" <e: Expression> ")" <r: @R> => e,
Call
}

ParseBinOp<Op, Next>: Expr = {
Expand Down

0 comments on commit cce3de0

Please sign in to comment.