From cce3de0bb719de2321cc2ea785a81ba844d0e2fe Mon Sep 17 00:00:00 2001 From: Alex Kuzmin Date: Wed, 28 Aug 2024 13:05:30 +0800 Subject: [PATCH] Implement call as expression --- src/compiler/abepi.rs | 7 +-- src/compiler/compiler.rs | 17 -------- src/compiler/semantic/analyser.rs | 16 +++---- src/compiler/semantic/rules.rs | 3 ++ src/compiler/setup_inter.rs | 3 +- src/interpreter/expr.rs | 1 + src/interpreter/mod.rs | 3 +- src/parser/ast/expression.rs | 11 +++++ src/parser/ast/statement.rs | 71 +++++++++++++++++-------------- src/parser/build.rs | 7 +-- src/parser/chiquito.lalrpop | 8 ++-- 11 files changed, 75 insertions(+), 72 deletions(-) diff --git a/src/compiler/abepi.rs b/src/compiler/abepi.rs index b58685ae..dfefab56 100644 --- a/src/compiler/abepi.rs +++ b/src/compiler/abepi.rs @@ -65,8 +65,8 @@ impl + TryInto + Clone + Debug, V: Clone + Debug> CompilationU Statement::Transition(dsym, id, stmt) => { self.compiler_statement_transition(dsym, id, *stmt) } - Statement::HyperTransition(dsym, state, call) => { - self.compiler_statement_hyper_transition(dsym, state, *call) + Statement::HyperTransition(dsym, ids, call, state) => { + self.compiler_statement_hyper_transition(dsym, ids, call, state) } _ => vec![], } @@ -427,8 +427,9 @@ impl + TryInto + Clone + Debug, V: Clone + Debug> CompilationU fn compiler_statement_hyper_transition( &self, _dsym: DebugSymRef, + _ids: Vec, + _call: Expression, _state: V, - _call: Statement, ) -> Vec> { todo!("Compile expressions?") } diff --git a/src/compiler/compiler.rs b/src/compiler/compiler.rs index 09e708d5..8e9d51be 100644 --- a/src/compiler/compiler.rs +++ b/src/compiler/compiler.rs @@ -842,22 +842,5 @@ mod test { let result = TLDeclsParser::new().parse(&debug_sym_ref_factory, circuit); assert!(result.is_ok()); - - // Wrong transition operator - let circuit = " - machine caller (signal n) (signal b: field) { - a', b, c' <== fibo(d, e, f + g) --> final; - } - "; - - let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); - let result = TLDeclsParser::new().parse(&debug_sym_ref_factory, circuit); - - assert!(result.is_err()); - let err = result.err().unwrap(); - assert_eq!( - err.to_string(), - "Unrecognized token `-` found at 99:100\nExpected one of \"->\" or \";\"" - ); } } diff --git a/src/compiler/semantic/analyser.rs b/src/compiler/semantic/analyser.rs index 28011d1a..e5291ea4 100644 --- a/src/compiler/semantic/analyser.rs +++ b/src/compiler/semantic/analyser.rs @@ -272,16 +272,10 @@ impl Analyser { Statement::SignalDecl(_, _) => {} Statement::WGVarDecl(_, _) => {} - Statement::HyperTransition(_, state, call) => { - self.analyse_statement(*call); + Statement::HyperTransition(_, ids, call, state) => { + self.analyse_expression(call); self.collect_id_usages(&[state]); - } - Statement::Call(_, ids, _machine, exprs) => { - // TODO analyze machine? self.collect_id_usages(&ids); - exprs - .into_iter() - .for_each(|expr| self.analyse_expression(expr)) } } } @@ -319,6 +313,12 @@ impl Analyser { } => { self.extract_usages_expression(&sub); } + Expression::Call(_, fun, exprs) => { + self.collect_id_usages(&[fun]); + exprs + .into_iter() + .for_each(|expr| self.extract_usages_expression(&expr)); + } _ => {} } } diff --git a/src/compiler/semantic/rules.rs b/src/compiler/semantic/rules.rs index 5bf846e8..818162c1 100644 --- a/src/compiler/semantic/rules.rs +++ b/src/compiler/semantic/rules.rs @@ -43,6 +43,9 @@ fn undeclared_rule(analyser: &mut Analyser, expr: &Expression {} + Expression::Call(_, _, args) => { + args.iter().for_each(|arg| undeclared_rule(analyser, arg)); + } } } diff --git a/src/compiler/setup_inter.rs b/src/compiler/setup_inter.rs index 2b2d4f18..79598772 100644 --- a/src/compiler/setup_inter.rs +++ b/src/compiler/setup_inter.rs @@ -250,8 +250,7 @@ impl SetupInterpreter { SignalAssignment(_, _, _) | WGAssignment(_, _, _) => vec![], SignalDecl(_, _) | WGVarDecl(_, _) => vec![], - Call(_, _, _, _) => todo!(), - HyperTransition(_, _, _) => todo!(), + HyperTransition(_, _, _, _) => todo!(), }; self.add_poly_constraints(result.into_iter().map(|cr| cr.anti_booly).collect()); diff --git a/src/interpreter/expr.rs b/src/interpreter/expr.rs index afc85830..608ee813 100644 --- a/src/interpreter/expr.rs +++ b/src/interpreter/expr.rs @@ -82,6 +82,7 @@ pub(crate) fn eval_expr( Const(_, v) => Ok(Value::Field(F::from_big_int(v))), True(_) => Ok(Value::Bool(true)), False(_) => Ok(Value::Bool(false)), + Call(_, _, _) => todo!(), } .map_err(|msg| Message::RuntimeErr { msg, diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs index 8dfd421d..a2b7938d 100644 --- a/src/interpreter/mod.rs +++ b/src/interpreter/mod.rs @@ -206,8 +206,7 @@ impl<'a, F: Field + Hash> Interpreter<'a, F> { Block(_, stmts) => self.exec_step_block(stmts), Assert(_, _) => Ok(None), StateDecl(_, _, _) => Ok(None), - Call(_, _, _, _) => todo!("execute call?"), - HyperTransition(_, _, _) => todo!("execute hypertransition?"), + HyperTransition(_, _, _, _) => todo!("execute hypertransition?"), } } diff --git a/src/parser/ast/expression.rs b/src/parser/ast/expression.rs index 61556d11..d18c2aff 100644 --- a/src/parser/ast/expression.rs +++ b/src/parser/ast/expression.rs @@ -193,6 +193,12 @@ pub enum Expression { Const(DebugSymRef, F), True(DebugSymRef), False(DebugSymRef), + /// Function or machine call. + /// Tuple values: + /// - debug symbol reference; + /// - function/machine ID; + /// - call argument expressions vector. + Call(DebugSymRef, V, Vec>), } // Shorthand for BigInt expression @@ -217,6 +223,7 @@ impl Expression { Const(_, _) => true, True(_) => false, False(_) => false, + Call(_, _, _) => todo!(), } } @@ -234,6 +241,7 @@ impl Expression { when_true.is_logic() } + Expression::Call { .. } => todo!(), _ => false, } } @@ -247,6 +255,7 @@ impl Expression { Expression::Const(dsym, _) => dsym, Expression::True(dsym) => dsym, Expression::False(dsym) => dsym, + Expression::Call(dsym, _, _) => dsym, } } @@ -260,6 +269,7 @@ impl Expression { Expression::Query(_, _) => false, Expression::True(_) => false, Expression::False(_) => false, + Expression::Call(_, _, _) => false, } } } @@ -315,6 +325,7 @@ impl Debug for Expression { Expression::True(_) => write!(f, "true"), Expression::False(_) => write!(f, "false"), + Expression::Call(_, fun, exprs) => write!(f, "{:?}({:?})", fun, exprs), } } } diff --git a/src/parser/ast/statement.rs b/src/parser/ast/statement.rs index 82398485..896cd50e 100644 --- a/src/parser/ast/statement.rs +++ b/src/parser/ast/statement.rs @@ -13,42 +13,42 @@ pub struct TypedIdDecl { #[derive(Clone)] pub enum Statement { - Assert(DebugSymRef, Expression), // assert x; - - SignalAssignment(DebugSymRef, Vec, Vec>), // x <-- y; - SignalAssignmentAssert(DebugSymRef, Vec, Vec>), // x <== y; - WGAssignment(DebugSymRef, Vec, Vec>), // x = y; - - IfThen(DebugSymRef, Box>, Box>), // if x { y } + /// assert x; + Assert(DebugSymRef, Expression), + /// x <-- y; + SignalAssignment(DebugSymRef, Vec, Vec>), + /// x <== y; + SignalAssignmentAssert(DebugSymRef, Vec, Vec>), + /// x = y; + WGAssignment(DebugSymRef, Vec, Vec>), + /// if x { y } + IfThen(DebugSymRef, Box>, Box>), + /// if x { y } else { z } IfThenElse( DebugSymRef, Box>, Box>, Box>, - ), // if x { y } else { z } - - SignalDecl(DebugSymRef, Vec>), // signal x; - WGVarDecl(DebugSymRef, Vec>), // var x; - - StateDecl(DebugSymRef, V, Box>), // state x { y } + ), + /// signal x; + SignalDecl(DebugSymRef, Vec>), + /// var x; + WGVarDecl(DebugSymRef, Vec>), + /// state x { y } + StateDecl(DebugSymRef, V, Box>), /// Transition to another state. - Transition(DebugSymRef, V, Box>), // -> x { y } - - Block(DebugSymRef, Vec>), // { x } - /// Call into another machine with assertion. - /// Tuple values: - /// - debug symbol reference; - /// - assigned/asserted ids vector; - /// - machine ID; - /// - call argument expressions vector. - Call(DebugSymRef, Vec, V, Vec>), + /// -> x { y } + Transition(DebugSymRef, V, Box>), + /// { x } + Block(DebugSymRef, Vec>), /// Call into another machine with assertion and subsequent transition to another /// state. /// Tuple values: /// - debug symbol reference; - /// - next state ID. - /// - machine call; - HyperTransition(DebugSymRef, V, Box>), + /// - assigned signal IDs; + /// - call expression; + /// - next state ID; + HyperTransition(DebugSymRef, Vec, Expression, V), } impl Debug for Statement { @@ -98,11 +98,17 @@ impl Debug for Statement { .join(" ") ) } - Statement::Call(_, ids, machine, exprs) => { - write!(f, "{:?} <== {} ({:?});", ids, machine.name(), exprs) - } - Statement::HyperTransition(_, state, call) => { - write!(f, "{:?} -> {:?};", call, state) + Statement::HyperTransition(_, ids, call, state) => { + write!( + f, + "{:?} <== {:?} -> {:?};", + ids.iter() + .map(|id| id.name()) + .collect::>() + .join(", "), + call, + state + ) } } } @@ -122,8 +128,7 @@ impl Statement { Statement::StateDecl(dsym, _, _) => dsym.clone(), Statement::Transition(dsym, _, _) => dsym.clone(), Statement::Block(dsym, _) => dsym.clone(), - Statement::Call(dsym, _, _, _) => dsym.clone(), - Statement::HyperTransition(dsym, _, _) => dsym.clone(), + Statement::HyperTransition(dsym, _, _, _) => dsym.clone(), } } } diff --git a/src/parser/build.rs b/src/parser/build.rs index b5cee63a..e179f875 100644 --- a/src/parser/build.rs +++ b/src/parser/build.rs @@ -62,13 +62,14 @@ pub fn build_transition( Statement::Transition(dsym, id, Box::new(block)) } -pub fn build_hyper_transition( +pub fn build_hyper_transition( dsym: DebugSymRef, + ids: Vec, + call: Expression, state: Identifier, - call: Statement, ) -> Statement { match call { - Statement::Call(_, _, _, _) => Statement::HyperTransition(dsym, state, Box::new(call)), + Expression::Call(_, _, _) => Statement::HyperTransition(dsym, ids, call, state), _ => unreachable!("Hyper transition must include a call statement"), } } diff --git a/src/parser/chiquito.lalrpop b/src/parser/chiquito.lalrpop index d1ee42f2..e05882f2 100644 --- a/src/parser/chiquito.lalrpop +++ b/src/parser/chiquito.lalrpop @@ -54,7 +54,6 @@ StatementType: Statement = { ParseTransitionSimple, ParseTransition, HyperTransition, - Call, } AssertEq: Statement = { @@ -114,11 +113,11 @@ ParseTransition: Statement = { } HyperTransition: Statement = { - "->" => build_hyper_transition(dsym_factory.create(l,r), st, call), + "<==" "->" => build_hyper_transition(dsym_factory.create(l,r), ids, call, st), } -Call: Statement = { - "<==" "(" ")" => Statement::Call(dsym_factory.create(l,r), ids, machine, es), +Call: Expression = { + "(" ")" => Expression::Call(dsym_factory.create(l,r), fun, es), } ParseSignalDecl: Statement = { @@ -212,6 +211,7 @@ ExpressionTerm: Expr = { "true" => Expression::True(dsym_factory.create(l,r)), "false" => Expression::False(dsym_factory.create(l,r)), "(" ")" => e, + Call } ParseBinOp: Expr = {