Skip to content

Commit

Permalink
Merge pull request #186 from gio54321/for-refactor
Browse files Browse the repository at this point in the history
Refactor for loops and iterators in the same statement structure
  • Loading branch information
katat authored Sep 23, 2024
2 parents b9bd50d + eef4f16 commit 1711beb
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 204 deletions.
143 changes: 76 additions & 67 deletions src/circuit_writer/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
error::{ErrorKind, Result},
imports::FnKind,
parser::{
types::{FunctionDef, Stmt, StmtKind, TyKind},
types::{ForLoopArgument, FunctionDef, Stmt, StmtKind, TyKind},
Expr, ExprKind, Op2,
},
syntax::is_type,
Expand Down Expand Up @@ -136,83 +136,92 @@ impl<B: Backend> CircuitWriter<B> {
self.add_local_var(fn_env, lhs.value.clone(), var_info);
}

StmtKind::ForLoop { var, range, body } => {
// compute the start and end of the range
let start_bg: BigUint = self
.compute_expr(fn_env, &range.start)?
.ok_or_else(|| {
self.error(ErrorKind::CannotComputeExpression, range.start.span)
})?
.constant()
.expect("expected constant")
.into();
let start: u32 = start_bg
.try_into()
.map_err(|_| self.error(ErrorKind::InvalidRangeSize, range.start.span))?;

let end_bg: BigUint = self
.compute_expr(fn_env, &range.end)?
.ok_or_else(|| self.error(ErrorKind::CannotComputeExpression, range.end.span))?
.constant()
.expect("expected constant")
.into();
let end: u32 = end_bg
.try_into()
.map_err(|_| self.error(ErrorKind::InvalidRangeSize, range.end.span))?;

// compute for the for loop block
for ii in start..end {
fn_env.nest();

let cst_var = Var::new_constant(ii.into(), var.span);
let var_info =
VarInfo::new(cst_var, false, Some(TyKind::Field { constant: true }));
self.add_local_var(fn_env, var.value.clone(), var_info);

self.compile_block(fn_env, body)?;

fn_env.pop();
}
}
StmtKind::IteratorLoop {
StmtKind::ForLoop {
var,
iterator,
argument,
body,
} => {
let iterator_var = self
.compute_expr(fn_env, iterator)?
.expect("array access on non-array");
match argument {
ForLoopArgument::Range(range) => {
// compute the start and end of the range
let start_bg: BigUint = self
.compute_expr(fn_env, &range.start)?
.ok_or_else(|| {
self.error(ErrorKind::CannotComputeExpression, range.start.span)
})?
.constant()
.expect("expected constant")
.into();
let start: u32 = start_bg.try_into().map_err(|_| {
self.error(ErrorKind::InvalidRangeSize, range.start.span)
})?;

let array_typ = self
.expr_type(iterator)
.cloned()
.expect("cannot find type of array");
let end_bg: BigUint = self
.compute_expr(fn_env, &range.end)?
.ok_or_else(|| {
self.error(ErrorKind::CannotComputeExpression, range.end.span)
})?
.constant()
.expect("expected constant")
.into();
let end: u32 = end_bg
.try_into()
.map_err(|_| self.error(ErrorKind::InvalidRangeSize, range.end.span))?;

// compute for the for loop block
for ii in start..end {
fn_env.nest();

let cst_var = Var::new_constant(ii.into(), var.span);
let var_info = VarInfo::new(
cst_var,
false,
Some(TyKind::Field { constant: true }),
);
self.add_local_var(fn_env, var.value.clone(), var_info);

self.compile_block(fn_env, body)?;

fn_env.pop();
}
}
ForLoopArgument::Iterator(iterator) => {
let iterator_var = self
.compute_expr(fn_env, iterator)?
.expect("array access on non-array");

let (elem_type, array_len) = match array_typ {
TyKind::Array(ty, array_len) => (ty, array_len),
_ => panic!("expected array"),
};
let array_typ = self
.expr_type(iterator)
.cloned()
.expect("cannot find type of array");

// compute the size of each element in the array
let len = self.size_of(&elem_type);
let (elem_type, array_len) = match array_typ {
TyKind::Array(ty, array_len) => (ty, array_len),
_ => panic!("expected array"),
};

// compute the size of each element in the array
let len = self.size_of(&elem_type);

for idx in 0..array_len {
// compute the real index
let idx = idx as usize;
let start = idx * len;
for idx in 0..array_len {
// compute the real index
let idx = idx as usize;
let start = idx * len;

fn_env.nest();
fn_env.nest();

// add the variable to the inner enviroment corresponding
// to iterator[idx]
let indexed_var = iterator_var.narrow(start, len).value(self, fn_env);
let var_info =
VarInfo::new(indexed_var.clone(), false, Some(*elem_type.clone()));
self.add_local_var(fn_env, var.value.clone(), var_info);
// add the variable to the inner enviroment corresponding
// to iterator[idx]
let indexed_var = iterator_var.narrow(start, len).value(self, fn_env);
let var_info =
VarInfo::new(indexed_var.clone(), false, Some(*elem_type.clone()));
self.add_local_var(fn_env, var.value.clone(), var_info);

self.compile_block(fn_env, body)?;
self.compile_block(fn_env, body)?;

fn_env.pop();
fn_env.pop();
}
}
}
}
StmtKind::Expr(expr) => {
Expand Down
100 changes: 44 additions & 56 deletions src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
error::{Error, ErrorKind, Result},
imports::FnKind,
parser::{
types::{FnArg, FnSig, Range, Stmt, StmtKind, Symbolic, Ty, TyKind},
types::{FnArg, FnSig, ForLoopArgument, Range, Stmt, StmtKind, Symbolic, Ty, TyKind},
CustomType, Expr, ExprKind, FunctionDef, Op2,
},
syntax::{is_generic_parameter, is_type},
Expand Down Expand Up @@ -911,74 +911,62 @@ pub fn monomorphize_stmt<B: Backend>(

Some((stmt_mono, None))
}
StmtKind::ForLoop { var, range, body } => {
StmtKind::ForLoop {
var,
argument,
body,
} => {
// enter a new scope
mono_fn_env.nest();

mono_fn_env.store_type(
&var.value,
// because we don't unroll the loop in the monomorphized AST,
// there is no constant value to propagate.
&MTypeInfo::new(&TyKind::Field { constant: true }, var.span, None),
)?;
let argument_mono = match argument {
ForLoopArgument::Range(range) => {
mono_fn_env.store_type(
&var.value,
// because we don't unroll the loop in the monomorphized AST,
// there is no constant value to propagate.
&MTypeInfo::new(&TyKind::Field { constant: true }, var.span, None),
)?;

let start_mono = monomorphize_expr(ctx, &range.start, mono_fn_env)?;
let end_mono = monomorphize_expr(ctx, &range.end, mono_fn_env)?;
let start_mono = monomorphize_expr(ctx, &range.start, mono_fn_env)?;
let end_mono = monomorphize_expr(ctx, &range.end, mono_fn_env)?;

if start_mono.constant.is_none() || end_mono.constant.is_none() {
return Err(error(ErrorKind::InvalidRangeSize, stmt.span));
}
if start_mono.constant.is_none() || end_mono.constant.is_none() {
return Err(error(ErrorKind::InvalidRangeSize, stmt.span));
}

if start_mono.constant.unwrap() > end_mono.constant.unwrap() {
return Err(error(ErrorKind::InvalidRangeSize, stmt.span));
}
if start_mono.constant.unwrap() > end_mono.constant.unwrap() {
return Err(error(ErrorKind::InvalidRangeSize, stmt.span));
}

let range_mono = Range {
start: start_mono.expr,
end: end_mono.expr,
span: range.span,
let range_mono = Range {
start: start_mono.expr,
end: end_mono.expr,
span: range.span,
};
ForLoopArgument::Range(range_mono)
}
ForLoopArgument::Iterator(iterator) => {
let iterator_mono = monomorphize_expr(ctx, iterator, mono_fn_env)?;
let typ = iterator_mono.typ.as_ref().expect("expected a type");
let array_element_type = match typ {
TyKind::Array(t, _) => t,
_ => panic!("expected an array"),
};

mono_fn_env.store_type(
&var.value,
&MTypeInfo::new(array_element_type, var.span, None),
)?;
ForLoopArgument::Iterator(Box::new(iterator_mono.expr))
}
};

let (stmts_mono, _) = monomorphize_block(ctx, mono_fn_env, body, None)?;
let loop_stmt_mono = Stmt {
kind: StmtKind::ForLoop {
var: var.clone(),
range: range_mono,
body: stmts_mono,
},
span: stmt.span,
};

// exit the scope
mono_fn_env.pop();

Some((loop_stmt_mono, None))
}
StmtKind::IteratorLoop {
var,
iterator,
body,
} => {
// enter a new scope
mono_fn_env.nest();

let iterator_mono = monomorphize_expr(ctx, iterator, mono_fn_env)?;
let typ = iterator_mono.typ.as_ref().expect("expected a type");
let array_element_type = match typ {
TyKind::Array(t, _) => t,
_ => panic!("expected an array"),
};

mono_fn_env.store_type(
&var.value,
&MTypeInfo::new(array_element_type, var.span, None),
)?;

let (stmts_mono, _) = monomorphize_block(ctx, mono_fn_env, body, None)?;
let loop_stmt_mono = Stmt {
kind: StmtKind::IteratorLoop {
var: var.clone(),
iterator: Box::new(iterator_mono.expr),
argument: argument_mono,
body: stmts_mono,
},
span: stmt.span,
Expand Down
18 changes: 7 additions & 11 deletions src/name_resolution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
constants::Span,
error::{Error, ErrorKind, Result},
parser::{
types::{FnArg, FnSig, FuncOrMethod, ModulePath, Stmt, StmtKind, TyKind},
types::{FnArg, FnSig, ForLoopArgument, FuncOrMethod, ModulePath, Stmt, StmtKind, TyKind},
ConstDef, CustomType, FunctionDef, StructDef, UsePath,
},
};
Expand Down Expand Up @@ -186,19 +186,15 @@ impl NameResCtx {
StmtKind::Comment(_) => (),
StmtKind::ForLoop {
var: _,
range: _,
argument,
body,
} => {
for stmt in body {
self.resolve_stmt(stmt)?;
// if the argument of the for loop is an iterator, resolve it
if let ForLoopArgument::Iterator(iterator) = argument {
self.resolve_expr(iterator)?;
}
}
StmtKind::IteratorLoop {
var: _,
iterator,
body,
} => {
self.resolve_expr(iterator)?;

// resolve the body of the for loop
for stmt in body {
self.resolve_stmt(stmt)?;
}
Expand Down
35 changes: 10 additions & 25 deletions src/parser/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1243,17 +1243,10 @@ pub enum StmtKind {
Return(Box<Expr>),
Comment(String),

// `for var in 0..10 { <body> }`
// `for var in 0..10 { <body> }` or `for var in vec { <body> }`
ForLoop {
var: Ident,
range: Range,
body: Vec<Stmt>,
},

// `for item in vec { <body> }`
IteratorLoop {
var: Ident,
iterator: Box<Expr>,
argument: ForLoopArgument,
body: Vec<Stmt>,
},
}
Expand Down Expand Up @@ -1394,22 +1387,14 @@ impl Stmt {
let statement = Stmt::parse(ctx, tokens)?;
body.push(statement);
}

//
match argument {
ForLoopArgument::Range(range) => Ok(Stmt {
kind: StmtKind::ForLoop { var, range, body },
span,
}),
ForLoopArgument::Iterator(iterator) => Ok(Stmt {
kind: StmtKind::IteratorLoop {
var,
iterator,
body,
},
span,
}),
}
Ok(Stmt {
kind: StmtKind::ForLoop {
var,
argument,
body,
},
span,
})
}

// if/else
Expand Down
Loading

0 comments on commit 1711beb

Please sign in to comment.