Skip to content

Commit

Permalink
feat: loop keyword in runtime and comptime code (#7096)
Browse files Browse the repository at this point in the history
  • Loading branch information
asterite authored Jan 17, 2025
1 parent df71bde commit c4f183c
Show file tree
Hide file tree
Showing 18 changed files with 231 additions and 28 deletions.
20 changes: 14 additions & 6 deletions compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,11 +310,13 @@ impl Loop {
// simplified to a simple jump.
return None;
}
assert_eq!(
instructions.len(),
1,
"The header should just compare the induction variable and jump"
);

if instructions.len() != 1 {
// The header should just compare the induction variable and jump.
// If that's not the case, this might be a `loop` and not a `for` loop.
return None;
}

match &function.dfg[instructions[0]] {
Instruction::Binary(Binary { lhs: _, operator: BinaryOp::Lt, rhs }) => {
function.dfg.get_numeric_constant(*rhs)
Expand Down Expand Up @@ -750,7 +752,13 @@ fn get_induction_variable(function: &Function, block: BasicBlockId) -> Result<Va
// block parameters. If that becomes the case we'll need to figure out which variable
// is generally constant and increasing to guess which parameter is the induction
// variable.
assert_eq!(arguments.len(), 1, "It is expected that a loop's induction variable is the only block parameter of the loop header");
if arguments.len() != 1 {
// It is expected that a loop's induction variable is the only block parameter of the loop header.
// If there's no variable this might be a `loop`.
let call_stack = function.dfg.get_call_stack(*location);
return Err(call_stack);
}

let value = arguments[0];
if function.dfg.get_numeric_constant(value).is_some() {
Ok(value)
Expand Down
12 changes: 4 additions & 8 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ pub(super) struct SharedContext {
#[derive(Copy, Clone)]
pub(super) struct Loop {
pub(super) loop_entry: BasicBlockId,
pub(super) loop_index: ValueId,
/// The loop index will be `Some` for a `for` and `None` for a `loop`
pub(super) loop_index: Option<ValueId>,
pub(super) loop_end: BasicBlockId,
}

Expand Down Expand Up @@ -1010,13 +1011,8 @@ impl<'a> FunctionContext<'a> {
}
}

pub(crate) fn enter_loop(
&mut self,
loop_entry: BasicBlockId,
loop_index: ValueId,
loop_end: BasicBlockId,
) {
self.loops.push(Loop { loop_entry, loop_index, loop_end });
pub(crate) fn enter_loop(&mut self, loop_: Loop) {
self.loops.push(loop_);
}

pub(crate) fn exit_loop(&mut self) {
Expand Down
45 changes: 41 additions & 4 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use acvm::AcirField;
use noirc_frontend::token::FmtStrFragment;
pub(crate) use program::Ssa;

use context::SharedContext;
use context::{Loop, SharedContext};
use iter_extended::{try_vecmap, vecmap};
use noirc_errors::Location;
use noirc_frontend::ast::{UnaryOp, Visibility};
Expand Down Expand Up @@ -152,6 +152,7 @@ impl<'a> FunctionContext<'a> {
Expression::Index(index) => self.codegen_index(index),
Expression::Cast(cast) => self.codegen_cast(cast),
Expression::For(for_expr) => self.codegen_for(for_expr),
Expression::Loop(block) => self.codegen_loop(block),
Expression::If(if_expr) => self.codegen_if(if_expr),
Expression::Tuple(tuple) => self.codegen_tuple(tuple),
Expression::ExtractTupleField(tuple, index) => {
Expand Down Expand Up @@ -557,7 +558,7 @@ impl<'a> FunctionContext<'a> {

// Remember the blocks and variable used in case there are break/continue instructions
// within the loop which need to jump to them.
self.enter_loop(loop_entry, loop_index, loop_end);
self.enter_loop(Loop { loop_entry, loop_index: Some(loop_index), loop_end });

// Set the location of the initial jmp instruction to the start range. This is the location
// used to issue an error if the start range cannot be determined at compile-time.
Expand Down Expand Up @@ -587,6 +588,38 @@ impl<'a> FunctionContext<'a> {
Ok(Self::unit_value())
}

/// Codegens a loop, creating three new blocks in the process.
/// The return value of a loop is always a unit literal.
///
/// For example, the loop `loop { body }` is codegen'd as:
///
/// ```text
/// br loop_body()
/// loop_body():
/// v3 = ... codegen body ...
/// br loop_body()
/// loop_end():
/// ... This is the current insert point after codegen_for finishes ...
/// ```
fn codegen_loop(&mut self, block: &Expression) -> Result<Values, RuntimeError> {
let loop_body = self.builder.insert_block();
let loop_end = self.builder.insert_block();

self.enter_loop(Loop { loop_entry: loop_body, loop_index: None, loop_end });

self.builder.terminate_with_jmp(loop_body, vec![]);

// Compile the loop body
self.builder.switch_to_block(loop_body);
self.codegen_expression(block)?;
self.builder.terminate_with_jmp(loop_body, vec![]);

// Finish by switching to the end of the loop
self.builder.switch_to_block(loop_end);
self.exit_loop();
Ok(Self::unit_value())
}

/// Codegens an if expression, handling the case of what to do if there is no 'else'.
///
/// For example, the expression `if cond { a } else { b }` is codegen'd as:
Expand Down Expand Up @@ -852,8 +885,12 @@ impl<'a> FunctionContext<'a> {
let loop_ = self.current_loop();

// Must remember to increment i before jumping
let new_loop_index = self.make_offset(loop_.loop_index, 1);
self.builder.terminate_with_jmp(loop_.loop_entry, vec![new_loop_index]);
if let Some(loop_index) = loop_.loop_index {
let new_loop_index = self.make_offset(loop_index, 1);
self.builder.terminate_with_jmp(loop_.loop_entry, vec![new_loop_index]);
} else {
self.builder.terminate_with_jmp(loop_.loop_entry, vec![]);
}
Self::unit_value()
}
}
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/elaborator/lints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ fn can_return_without_recursing(interner: &NodeInterner, func_id: FuncId, expr_i
HirStatement::Semi(e) => check(e),
// Rust doesn't seem to check the for loop body (it's bounds might mean it's never called).
HirStatement::For(e) => check(e.start_range) && check(e.end_range),
HirStatement::Loop(e) => check(e),
HirStatement::Constrain(_)
| HirStatement::Comptime(_)
| HirStatement::Break
Expand Down
20 changes: 17 additions & 3 deletions compiler/noirc_frontend/src/elaborator/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,25 @@ impl<'context> Elaborator<'context> {

pub(super) fn elaborate_loop(
&mut self,
_block: Expression,
block: Expression,
span: noirc_errors::Span,
) -> (HirStatement, Type) {
self.push_err(ResolverError::LoopNotYetSupported { span });
(HirStatement::Error, Type::Unit)
let in_constrained_function = self.in_constrained_function();
if in_constrained_function {
self.push_err(ResolverError::LoopInConstrainedFn { span });
}

self.nested_loops += 1;
self.push_scope();

let (block, _block_type) = self.elaborate_expression(block);

self.pop_scope();
self.nested_loops -= 1;

let statement = HirStatement::Loop(block);

(statement, Type::Unit)
}

fn elaborate_jump(&mut self, is_break: bool, span: noirc_errors::Span) -> (HirStatement, Type) {
Expand Down
13 changes: 12 additions & 1 deletion compiler/noirc_frontend/src/hir/comptime/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ pub enum InterpreterError {
GlobalsDependencyCycle {
location: Location,
},
LoopHaltedForUiResponsiveness {
location: Location,
},

// These cases are not errors, they are just used to prevent us from running more code
// until the loop can be resumed properly. These cases will never be displayed to users.
Expand Down Expand Up @@ -323,7 +326,8 @@ impl InterpreterError {
| InterpreterError::CannotSetFunctionBody { location, .. }
| InterpreterError::UnknownArrayLength { location, .. }
| InterpreterError::CannotInterpretFormatStringWithErrors { location }
| InterpreterError::GlobalsDependencyCycle { location } => *location,
| InterpreterError::GlobalsDependencyCycle { location }
| InterpreterError::LoopHaltedForUiResponsiveness { location } => *location,

InterpreterError::FailedToParseMacro { error, file, .. } => {
Location::new(error.span(), *file)
Expand Down Expand Up @@ -683,6 +687,13 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic {
let secondary = String::new();
CustomDiagnostic::simple_error(msg, secondary, location.span)
}
InterpreterError::LoopHaltedForUiResponsiveness { location } => {
let msg = "This loop took too much time to execute so it was halted for UI responsiveness"
.to_string();
let secondary =
"This error doesn't happen in normal executions of `nargo`".to_string();
CustomDiagnostic::simple_warning(msg, secondary, location.span)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ impl HirStatement {
block: for_stmt.block.to_display_ast(interner),
span,
}),
HirStatement::Loop(block) => StatementKind::Loop(block.to_display_ast(interner)),
HirStatement::Break => StatementKind::Break,
HirStatement::Continue => StatementKind::Continue,
HirStatement::Expression(expr) => {
Expand Down
29 changes: 29 additions & 0 deletions compiler/noirc_frontend/src/hir/comptime/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1550,6 +1550,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> {
HirStatement::Constrain(constrain) => self.evaluate_constrain(constrain),
HirStatement::Assign(assign) => self.evaluate_assign(assign),
HirStatement::For(for_) => self.evaluate_for(for_),
HirStatement::Loop(expression) => self.evaluate_loop(expression),
HirStatement::Break => self.evaluate_break(statement),
HirStatement::Continue => self.evaluate_continue(statement),
HirStatement::Expression(expression) => self.evaluate(expression),
Expand Down Expand Up @@ -1741,6 +1742,34 @@ impl<'local, 'interner> Interpreter<'local, 'interner> {
Ok(Value::Unit)
}

fn evaluate_loop(&mut self, expr: ExprId) -> IResult<Value> {
let was_in_loop = std::mem::replace(&mut self.in_loop, true);
let in_lsp = self.elaborator.interner.is_in_lsp_mode();
let mut counter = 0;

loop {
self.push_scope();

match self.evaluate(expr) {
Ok(_) => (),
Err(InterpreterError::Break) => break,
Err(InterpreterError::Continue) => continue,
Err(other) => return Err(other),
}

self.pop_scope();

counter += 1;
if in_lsp && counter == 10_000 {
let location = self.elaborator.interner.expr_location(&expr);
return Err(InterpreterError::LoopHaltedForUiResponsiveness { location });
}
}

self.in_loop = was_in_loop;
Ok(Value::Unit)
}

fn evaluate_break(&mut self, id: StmtId) -> IResult<Value> {
if self.in_loop {
Err(InterpreterError::Break)
Expand Down
9 changes: 9 additions & 0 deletions compiler/noirc_frontend/src/hir/resolution/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ pub enum ResolverError {
DependencyCycle { span: Span, item: String, cycle: String },
#[error("break/continue are only allowed in unconstrained functions")]
JumpInConstrainedFn { is_break: bool, span: Span },
#[error("loop is only allowed in unconstrained functions")]
LoopInConstrainedFn { span: Span },
#[error("break/continue are only allowed within loops")]
JumpOutsideLoop { is_break: bool, span: Span },
#[error("Only `comptime` globals can be mutable")]
Expand Down Expand Up @@ -434,6 +436,13 @@ impl<'a> From<&'a ResolverError> for Diagnostic {
*span,
)
},
ResolverError::LoopInConstrainedFn { span } => {
Diagnostic::simple_error(
"loop is only allowed in unconstrained functions".into(),
"Constrained code must always have a known number of loop iterations".into(),
*span,
)
},
ResolverError::JumpOutsideLoop { is_break, span } => {
let item = if *is_break { "break" } else { "continue" };
Diagnostic::simple_error(
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/hir_def/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub enum HirStatement {
Constrain(HirConstrainStatement),
Assign(HirAssignStatement),
For(HirForStatement),
Loop(ExprId),
Break,
Continue,
Expression(ExprId),
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/monomorphization/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub enum Expression {
Index(Index),
Cast(Cast),
For(For),
Loop(Box<Expression>),
If(If),
Tuple(Vec<Expression>),
ExtractTupleField(Box<Expression>, usize),
Expand Down
4 changes: 4 additions & 0 deletions compiler/noirc_frontend/src/monomorphization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,10 @@ impl<'interner> Monomorphizer<'interner> {
block,
}))
}
HirStatement::Loop(block) => {
let block = Box::new(self.expr(block)?);
Ok(ast::Expression::Loop(block))
}
HirStatement::Expression(expr) => self.expr(expr),
HirStatement::Semi(expr) => {
self.expr(expr).map(|expr| ast::Expression::Semi(Box::new(expr)))
Expand Down
10 changes: 10 additions & 0 deletions compiler/noirc_frontend/src/monomorphization/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ impl AstPrinter {
write!(f, " as {})", cast.r#type)
}
Expression::For(for_expr) => self.print_for(for_expr, f),
Expression::Loop(block) => self.print_loop(block, f),
Expression::If(if_expr) => self.print_if(if_expr, f),
Expression::Tuple(tuple) => self.print_tuple(tuple, f),
Expression::ExtractTupleField(expr, index) => {
Expand Down Expand Up @@ -209,6 +210,15 @@ impl AstPrinter {
write!(f, "}}")
}

fn print_loop(&mut self, block: &Expression, f: &mut Formatter) -> Result<(), std::fmt::Error> {
write!(f, "loop {{")?;
self.indent_level += 1;
self.print_expr_expect_block(block, f)?;
self.indent_level -= 1;
self.next_line(f)?;
write!(f, "}}")
}

fn print_if(
&mut self,
if_expr: &super::ast::If,
Expand Down
9 changes: 7 additions & 2 deletions compiler/noirc_frontend/src/parser/parser/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,13 @@ impl<'a> Parser<'a> {

/// LoopStatement = 'loop' Block
fn parse_loop(&mut self) -> Option<Expression> {
let start_span = self.current_token_span;
if !self.eat_keyword(Keyword::Loop) {
return None;
}

self.push_error(ParserErrorReason::ExperimentalFeature("loops"), start_span);

let block_start_span = self.current_token_span;
let block = if let Some(block) = self.parse_block() {
Expression {
Expand Down Expand Up @@ -819,7 +822,8 @@ mod tests {
#[test]
fn parses_empty_loop() {
let src = "loop { }";
let statement = parse_statement_no_errors(src);
let mut parser = Parser::for_str(src);
let statement = parser.parse_statement_or_error();
let StatementKind::Loop(block) = statement.kind else {
panic!("Expected loop");
};
Expand All @@ -832,7 +836,8 @@ mod tests {
#[test]
fn parses_loop_with_statements() {
let src = "loop { 1; 2 }";
let statement = parse_statement_no_errors(src);
let mut parser = Parser::for_str(src);
let statement = parser.parse_statement_or_error();
let StatementKind::Loop(block) = statement.kind else {
panic!("Expected loop");
};
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,7 @@ fn find_lambda_captures(stmts: &[StmtId], interner: &NodeInterner, result: &mut
HirStatement::Constrain(constr_stmt) => constr_stmt.0,
HirStatement::Semi(semi_expr) => semi_expr,
HirStatement::For(for_loop) => for_loop.block,
HirStatement::Loop(block) => block,
HirStatement::Error => panic!("Invalid HirStatement!"),
HirStatement::Break => panic!("Unexpected break"),
HirStatement::Continue => panic!("Unexpected continue"),
Expand Down
Loading

0 comments on commit c4f183c

Please sign in to comment.