Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(experimental): try to infer lambda argument types inside calls #7088

Merged
merged 16 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions compiler/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,14 @@ impl ExpressionKind {
struct_type: None,
}))
}

pub fn is_lambda_without_type_annotations(&self) -> bool {
if let ExpressionKind::Lambda(lambda) = self {
lambda.parameters.iter().any(|(_, typ)| typ.typ.is_unspecified())
} else {
false
}
}
}

impl Recoverable for ExpressionKind {
Expand Down
4 changes: 4 additions & 0 deletions compiler/noirc_frontend/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,10 @@
| UnresolvedTypeData::Error => false,
}
}

pub fn is_unspecified(&self) -> bool {
matches!(self, UnresolvedTypeData::Unspecified)
}
}

#[derive(Debug, PartialEq, Eq, Copy, Clone, Hash, PartialOrd, Ord)]
Expand Down Expand Up @@ -601,7 +605,7 @@
Self::Public => write!(f, "pub"),
Self::Private => write!(f, "priv"),
Self::CallData(id) => write!(f, "calldata{id}"),
Self::ReturnData => write!(f, "returndata"),

Check warning on line 608 in compiler/noirc_frontend/src/ast/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (returndata)
}
}
}
193 changes: 159 additions & 34 deletions compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl<'context> Elaborator<'context> {
ExpressionKind::If(if_) => self.elaborate_if(*if_),
ExpressionKind::Variable(variable) => return self.elaborate_variable(variable),
ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple),
ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda),
ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda, None),
ExpressionKind::Parenthesized(expr) => return self.elaborate_expression(*expr),
ExpressionKind::Quote(quote) => self.elaborate_quote(quote, expr.span),
ExpressionKind::Comptime(comptime, _) => {
Expand Down Expand Up @@ -388,19 +388,49 @@ impl<'context> Elaborator<'context> {
fn elaborate_call(&mut self, call: CallExpression, span: Span) -> (HirExpression, Type) {
let (func, func_type) = self.elaborate_expression(*call.func);

let any_argument_has_lambda_without_type_annotations =
call.arguments.iter().any(|arg| arg.kind.is_lambda_without_type_annotations());

let mut arguments = Vec::with_capacity(call.arguments.len());
let args = vecmap(call.arguments, |arg| {
let span = arg.span;
let args: Vec<_> = call
.arguments
.into_iter()
.enumerate()
.map(|(arg_index, arg)| {
asterite marked this conversation as resolved.
Show resolved Hide resolved
let span = arg.span;

let (arg, typ) = if call.is_macro_call {
self.elaborate_in_comptime_context(|this| {
this.elaborate_call_argument_expression(
arg,
arg_index,
&func_type,
any_argument_has_lambda_without_type_annotations,
)
})
} else {
self.elaborate_call_argument_expression(
arg,
arg_index,
&func_type,
any_argument_has_lambda_without_type_annotations,
)
};
asterite marked this conversation as resolved.
Show resolved Hide resolved

let (arg, typ) = if call.is_macro_call {
self.elaborate_in_comptime_context(|this| this.elaborate_expression(arg))
} else {
self.elaborate_expression(arg)
};
if any_argument_has_lambda_without_type_annotations {
// Try to unify this argument type against the function's argument type
// so that a potential lambda following this argument can have more concrete types.
if let Type::Function(func_args, _, _, _) = &func_type {
if let Some(func_arg_type) = func_args.get(arg_index) {
let _ = func_arg_type.unify(&typ);
}
}
}
asterite marked this conversation as resolved.
Show resolved Hide resolved

arguments.push(arg);
(typ, arg, span)
});
arguments.push(arg);
(typ, arg, span)
})
.collect();

// Avoid cloning arguments unless this is a macro call
let mut comptime_args = Vec::new();
Expand Down Expand Up @@ -458,24 +488,69 @@ impl<'context> Elaborator<'context> {
None
};

let call_span = Span::from(object_span.start()..method_name_span.end());
let location = Location::new(call_span, self.file);

let (function_id, function_name) = method_ref.clone().into_function_id_and_name(
object_type.clone(),
generics.clone(),
location,
self.interner,
);

let func_type =
self.type_check_variable(function_name.clone(), function_id, generics.clone());
self.interner.push_expr_type(function_id, func_type.clone());

let any_argument_has_lambda_without_type_annotations = method_call
.arguments
.iter()
.any(|arg| arg.kind.is_lambda_without_type_annotations());

if any_argument_has_lambda_without_type_annotations {
// Try to unify the object type with the first argument of the function.
// The reason to do this is that many methods that take a lambda will yield `self` or part of `self`
// as a parameter. By unifying `self` with the first argument we'll potentially get more
// concrete types in the arguments that are function types, which will later be passed as
// lambda parameter hints.
if let Type::Function(args, _, _, _) = &func_type {
if !args.is_empty() {
let _ = args[0].unify(&object_type);
}
}
}

// These arguments will be given to the desugared function call.
// Compared to the method arguments, they also contain the object.
let mut function_args = Vec::with_capacity(method_call.arguments.len() + 1);
let mut arguments = Vec::with_capacity(method_call.arguments.len());

function_args.push((object_type.clone(), object, object_span));

for arg in method_call.arguments {
for (arg_index, arg) in method_call.arguments.into_iter().enumerate() {
let span = arg.span;
let (arg, typ) = self.elaborate_expression(arg);
let (arg, typ) = self.elaborate_call_argument_expression(
arg,
arg_index + 1,
&func_type,
any_argument_has_lambda_without_type_annotations,
);

if any_argument_has_lambda_without_type_annotations {
// Try to unify this argument type against the function's argument type
// so that a potential lambda following this argument can have more concrete types.
if let Type::Function(func_args, _, _, _) = &func_type {
if let Some(func_arg_type) = func_args.get(arg_index + 1) {
let _ = func_arg_type.unify(&typ);
}
}
}

arguments.push(arg);
function_args.push((typ, arg, span));
}

let call_span = Span::from(object_span.start()..method_name_span.end());
let location = Location::new(call_span, self.file);
let method = method_call.method_name;
let turbofish_generics = generics.clone();
let is_macro_call = method_call.is_macro_call;
let method_call =
HirMethodCallExpression { method, object, arguments, location, generics };
Expand All @@ -485,18 +560,9 @@ impl<'context> Elaborator<'context> {
// Desugar the method call into a normal, resolved function call
// so that the backend doesn't need to worry about methods
// TODO: update object_type here?
let ((function_id, function_name), function_call) = method_call.into_function_call(
method_ref,
object_type,
is_macro_call,
location,
self.interner,
);

let func_type =
self.type_check_variable(function_name, function_id, turbofish_generics);

self.interner.push_expr_type(function_id, func_type.clone());
let function_call =
method_call.into_function_call(function_id, is_macro_call, location);

self.interner
.add_function_reference(func_id, Location::new(method_name_span, self.file));
Expand All @@ -520,6 +586,40 @@ impl<'context> Elaborator<'context> {
}
}

/// Elaborates an expression taking into account that it's a call argument in a function
/// that has the given type, and `arg_index` is the index of that argument in that function type.
fn elaborate_call_argument_expression(
&mut self,
arg: Expression,
arg_index: usize,
func_type: &Type,
any_argument_has_lambda_without_type_annotations: bool,
asterite marked this conversation as resolved.
Show resolved Hide resolved
) -> (ExprId, Type) {
if !any_argument_has_lambda_without_type_annotations {
return self.elaborate_expression(arg);
}

let ExpressionKind::Lambda(lambda) = arg.kind else {
return self.elaborate_expression(arg);
};

let span = arg.span;
let type_hint = if let Type::Function(func_args, _, _, _) = func_type {
if let Some(Type::Function(func_args, _, _, _)) = func_args.get(arg_index) {
Some(func_args)
} else {
None
}
} else {
None
};
let (hir_expr, typ) = self.elaborate_lambda(*lambda, type_hint);
let id = self.interner.push_expr(hir_expr);
self.interner.push_expr_location(id, span, self.file);
self.interner.push_expr_type(id, typ.clone());
(id, typ)
}

fn check_method_call_visibility(&mut self, func_id: FuncId, object_type: &Type, name: &Ident) {
if !method_call_is_visible(
object_type,
Expand Down Expand Up @@ -846,19 +946,44 @@ impl<'context> Elaborator<'context> {
(HirExpression::Tuple(element_ids), Type::Tuple(element_types))
}

fn elaborate_lambda(&mut self, lambda: Lambda) -> (HirExpression, Type) {
/// For elaborating a lambda we might get `parameters_type_hints`. These come from a potential
/// call that has this lambda as the argument.
/// The parameter type hints will be the types of the function type corresponding to the lambda argument.
fn elaborate_lambda(
&mut self,
lambda: Lambda,
parameters_type_hints: Option<&Vec<Type>>,
) -> (HirExpression, Type) {
self.push_scope();
let scope_index = self.scopes.current_scope_index();

self.lambda_stack.push(LambdaContext { captures: Vec::new(), scope_index });

let mut arg_types = Vec::with_capacity(lambda.parameters.len());
let parameters = vecmap(lambda.parameters, |(pattern, typ)| {
let parameter = DefinitionKind::Local(None);
let typ = self.resolve_inferred_type(typ);
arg_types.push(typ.clone());
(self.elaborate_pattern(pattern, typ.clone(), parameter, true), typ)
});
let parameters: Vec<_> = lambda
.parameters
.into_iter()
.enumerate()
.map(|(index, (pattern, typ))| {
let parameter = DefinitionKind::Local(None);
let is_unspecified = matches!(typ.typ, UnresolvedTypeData::Unspecified);
let typ = self.resolve_inferred_type(typ);

if is_unspecified {
asterite marked this conversation as resolved.
Show resolved Hide resolved
// If there's a parameter type hint, use it to unify the argument type
if let Some(parameter_type_hint) =
parameters_type_hints.and_then(|hints| hints.get(index))
{
// We don't error here because eventually the lambda type will be checked against
// the call that contains it, which would then produce an error if this didn't unify.
let _ = typ.unify(parameter_type_hint);
}
}

arg_types.push(typ.clone());
(self.elaborate_pattern(pattern, typ.clone(), parameter, true), typ)
})
.collect();

let return_type = self.resolve_inferred_type(lambda.return_type);
let body_span = lambda.body.span;
Expand Down
37 changes: 20 additions & 17 deletions compiler/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,24 +225,15 @@ impl HirMethodReference {
}
}
}
}

impl HirMethodCallExpression {
/// Converts a method call into a function call
///
/// Returns ((func_var_id, func_var), call_expr)
pub fn into_function_call(
mut self,
method: HirMethodReference,
pub fn into_function_id_and_name(
self,
object_type: Type,
is_macro_call: bool,
generics: Option<Vec<Type>>,
location: Location,
interner: &mut NodeInterner,
) -> ((ExprId, HirIdent), HirCallExpression) {
let mut arguments = vec![self.object];
arguments.append(&mut self.arguments);

let (id, impl_kind) = match method {
) -> (ExprId, HirIdent) {
let (id, impl_kind) = match self {
HirMethodReference::FuncId(func_id) => {
(interner.function_definition_id(func_id), ImplKind::NotATraitMethod)
}
Expand All @@ -261,10 +252,22 @@ impl HirMethodCallExpression {
}
};
let func_var = HirIdent { location, id, impl_kind };
let func = interner.push_expr(HirExpression::Ident(func_var.clone(), self.generics));
let func = interner.push_expr(HirExpression::Ident(func_var.clone(), generics));
interner.push_expr_location(func, location.span, location.file);
let expr = HirCallExpression { func, arguments, location, is_macro_call };
((func, func_var), expr)
(func, func_var)
}
}

impl HirMethodCallExpression {
pub fn into_function_call(
mut self,
func: ExprId,
is_macro_call: bool,
location: Location,
) -> HirCallExpression {
let mut arguments = vec![self.object];
arguments.append(&mut self.arguments);
HirCallExpression { func, arguments, location, is_macro_call }
}
}

Expand Down
Loading
Loading