Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: inline simple functions #7160

Merged
merged 15 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result<Ss
Ok(builder
.run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (1st)")
.run_pass(Ssa::defunctionalize, "Defunctionalization")
.run_pass(Ssa::inline_simple_functions, "Inlining simple functions")
.run_pass(Ssa::remove_paired_rc, "Removing Paired rc_inc & rc_decs")
.run_pass(
|ssa| ssa.preprocess_functions(options.inliner_aggressiveness),
Expand Down
177 changes: 114 additions & 63 deletions compiler/noirc_evaluator/src/ssa/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,63 +64,61 @@ impl Ssa {
let inline_targets =
inline_infos.iter().filter_map(|(id, info)| info.is_inline_target().then_some(*id));

let should_inline_call = |callee: &Function| -> bool {
match callee.runtime() {
RuntimeType::Acir(_) => {
// If we have not already finished the flattening pass, functions marked
// to not have predicates should be preserved.
let preserve_function =
!inline_no_predicates_functions && callee.is_no_predicates();
!preserve_function
}
RuntimeType::Brillig(_) => {
// We inline inline if the function called wasn't ruled out as too costly or recursive.
InlineInfo::should_inline(inline_infos, callee.id())
}
}
};

// NOTE: Functions are processed independently of each other, with the final mapping replacing the original,
// instead of inlining the "leaf" functions, moving up towards the entry point.
self.functions = btree_map(inline_targets, |entry_point| {
let function = &self.functions[&entry_point];
let new_function =
function.inlined(&self, inline_no_predicates_functions, inline_infos);
let new_function = function.inlined(&self, &should_inline_call);
(entry_point, new_function)
});
self
}

pub(crate) fn inline_simple_functions(mut self: Ssa) -> Ssa {
let should_inline_call = |function: &Function| {
let entry_block_id = function.entry_block();
let entry_block = &function.dfg[entry_block_id];

// Only inline functions with a single block
if entry_block.successors().next().is_some() {
return false;
}

// Only inline functions with 0 or 1 instructions
entry_block.instructions().len() <= 1
};

self.functions = btree_map(self.functions.iter(), |(id, function)| {
(*id, function.inlined(&self, &should_inline_call))
});

self
}
}

impl Function {
/// Create a new function which has the functions called by this one inlined into its body.
pub(super) fn inlined(
&self,
ssa: &Ssa,
inline_no_predicates_functions: bool,
inline_infos: &InlineInfos,
should_inline_call: &impl Fn(&Function) -> bool,
) -> Function {
let caller_runtime = self.runtime();

let should_inline_call =
|_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool {
// Do not inline self-recursive functions on the top level.
// Inlining a self-recursive function works when there is something to inline into
// by importing all the recursive blocks, but for the entry function there is no wrapper.
if called_func_id == self.id() {
return false;
}
let callee = &ssa.functions[&called_func_id];

match callee.runtime() {
RuntimeType::Acir(inline_type) => {
// If the called function is acir, we inline if it's not an entry point

// If we have not already finished the flattening pass, functions marked
// to not have predicates should be preserved.
let preserve_function =
!inline_no_predicates_functions && callee.is_no_predicates();

!inline_type.is_entry_point() && !preserve_function
}
RuntimeType::Brillig(_) => {
if caller_runtime.is_acir() {
// We never inline a brillig function into an ACIR function.
return false;
}
// We inline inline if the function called wasn't ruled out as too costly or recursive.
inline_infos
.get(&called_func_id)
.map(|info| info.should_inline)
.unwrap_or_default()
}
}
};

InlineContext::new(ssa, self.id()).inline_all(ssa, &should_inline_call)
}
}
Expand All @@ -146,6 +144,9 @@ struct InlineContext {
/// inline into. The same goes for ValueIds, InstructionIds, and for storing other data like
/// parameter to argument mappings.
struct PerFunctionContext<'function> {
/// The function that we are inlining calls into.
entry_function: &'function Function,

/// The source function is the function we're currently inlining into the function being built.
source_function: &'function Function,

Expand Down Expand Up @@ -205,7 +206,7 @@ pub(super) struct InlineInfo {
is_brillig_entry_point: bool,
is_acir_entry_point: bool,
is_recursive: bool,
should_inline: bool,
pub(super) should_inline: bool,
weight: i64,
cost: i64,
}
Expand All @@ -218,6 +219,10 @@ impl InlineInfo {
|| self.is_recursive
|| !self.should_inline
}

pub(super) fn should_inline(inline_infos: &InlineInfos, called_func_id: FunctionId) -> bool {
inline_infos.get(&called_func_id).map(|info| info.should_inline).unwrap_or_default()
}
}

type InlineInfos = BTreeMap<FunctionId, InlineInfo>;
Expand Down Expand Up @@ -519,7 +524,7 @@ fn mark_brillig_functions_to_retain(
inline_no_predicates_functions: bool,
aggressiveness: i64,
times_called: &HashMap<FunctionId, usize>,
inline_infos: &mut BTreeMap<FunctionId, InlineInfo>,
inline_infos: &mut InlineInfos,
) {
let brillig_entry_points = inline_infos
.iter()
Expand Down Expand Up @@ -574,11 +579,12 @@ impl InlineContext {
fn inline_all(
mut self,
ssa: &Ssa,
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
should_inline_call: &impl Fn(&Function) -> bool,
) -> Function {
let entry_point = &ssa.functions[&self.entry_point];

let mut context = PerFunctionContext::new(&mut self, entry_point, &ssa.globals);
let mut context =
PerFunctionContext::new(&mut self, entry_point, entry_point, &ssa.globals);
context.inlining_entry = true;

for (_, value) in entry_point.dfg.globals.values_iter() {
Expand Down Expand Up @@ -617,7 +623,7 @@ impl InlineContext {
ssa: &Ssa,
id: FunctionId,
arguments: &[ValueId],
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
should_inline_call: &impl Fn(&Function) -> bool,
) -> Vec<ValueId> {
self.recursion_level += 1;

Expand All @@ -629,7 +635,8 @@ impl InlineContext {
);
}

let mut context = PerFunctionContext::new(self, source_function, &ssa.globals);
let entry_point = &ssa.functions[&self.entry_point];
let mut context = PerFunctionContext::new(self, entry_point, source_function, &ssa.globals);

let parameters = source_function.parameters();
assert_eq!(parameters.len(), arguments.len());
Expand All @@ -651,11 +658,13 @@ impl<'function> PerFunctionContext<'function> {
/// the arguments of the destination function.
fn new(
context: &'function mut InlineContext,
entry_function: &'function Function,
source_function: &'function Function,
globals: &'function Function,
) -> Self {
Self {
context,
entry_function,
source_function,
blocks: HashMap::default(),
values: HashMap::default(),
Expand Down Expand Up @@ -777,7 +786,7 @@ impl<'function> PerFunctionContext<'function> {
fn inline_blocks(
&mut self,
ssa: &Ssa,
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
should_inline_call: &impl Fn(&Function) -> bool,
) -> Vec<ValueId> {
let mut seen_blocks = HashSet::new();
let mut block_queue = VecDeque::new();
Expand Down Expand Up @@ -844,7 +853,7 @@ impl<'function> PerFunctionContext<'function> {
&mut self,
ssa: &Ssa,
block_id: BasicBlockId,
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
should_inline_call: &impl Fn(&Function) -> bool,
) {
let mut side_effects_enabled: Option<ValueId> = None;

Expand All @@ -853,19 +862,29 @@ impl<'function> PerFunctionContext<'function> {
match &self.source_function.dfg[*id] {
Instruction::Call { func, arguments } => match self.get_function(*func) {
Some(func_id) => {
if should_inline_call(self, ssa, func_id) {
self.inline_function(ssa, *id, func_id, arguments, should_inline_call);

// This is only relevant during handling functions with `InlineType::NoPredicates` as these
// can pollute the function they're being inlined into with `Instruction::EnabledSideEffects`,
// resulting in predicates not being applied properly.
//
// Note that this doesn't cover the case in which there exists an `Instruction::EnabledSideEffects`
// within the function being inlined whilst the source function has not encountered one yet.
// In practice this isn't an issue as the last `Instruction::EnabledSideEffects` in the
// function being inlined will be to turn off predicates rather than to create one.
if let Some(condition) = side_effects_enabled {
self.context.builder.insert_enable_side_effects_if(condition);
if let Some(callee) = self.should_inline_call(ssa, func_id) {
if should_inline_call(callee) {
self.inline_function(
ssa,
*id,
func_id,
arguments,
should_inline_call,
);

// This is only relevant during handling functions with `InlineType::NoPredicates` as these
// can pollute the function they're being inlined into with `Instruction::EnabledSideEffects`,
// resulting in predicates not being applied properly.
//
// Note that this doesn't cover the case in which there exists an `Instruction::EnabledSideEffects`
// within the function being inlined whilst the source function has not encountered one yet.
// In practice this isn't an issue as the last `Instruction::EnabledSideEffects` in the
// function being inlined will be to turn off predicates rather than to create one.
if let Some(condition) = side_effects_enabled {
self.context.builder.insert_enable_side_effects_if(condition);
}
} else {
self.push_instruction(*id);
}
} else {
self.push_instruction(*id);
Expand All @@ -882,14 +901,46 @@ impl<'function> PerFunctionContext<'function> {
}
}

fn should_inline_call<'a>(
&self,
ssa: &'a Ssa,
called_func_id: FunctionId,
) -> Option<&'a Function> {
// Do not inline self-recursive functions on the top level.
// Inlining a self-recursive function works when there is something to inline into
// by importing all the recursive blocks, but for the entry function there is no wrapper.
if self.entry_function.id() == called_func_id {
return None;
}

let callee = &ssa.functions[&called_func_id];

match callee.runtime() {
RuntimeType::Acir(inline_type) => {
// If the called function is acir, we inline if it's not an entry point
if inline_type.is_entry_point() {
return None;
}
}
RuntimeType::Brillig(_) => {
if self.entry_function.runtime().is_acir() {
// We never inline a brillig function into an ACIR function.
return None;
}
}
}

Some(callee)
}

/// Inline a function call and remember the inlined return values in the values map
fn inline_function(
&mut self,
ssa: &Ssa,
call_id: InstructionId,
function: FunctionId,
arguments: &[ValueId],
should_inline_call: &impl Fn(&PerFunctionContext, &Ssa, FunctionId) -> bool,
should_inline_call: &impl Fn(&Function) -> bool,
) {
let old_results = self.source_function.dfg.instruction_results(call_id);
let arguments = vecmap(arguments, |arg| self.translate_value(*arg));
Expand Down
22 changes: 19 additions & 3 deletions compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
//! Pre-process functions before inlining them into others.

use crate::ssa::Ssa;
use crate::ssa::{
ir::function::{Function, RuntimeType},
Ssa,
};

use super::inlining;
use super::inlining::{self, InlineInfo};

impl Ssa {
/// Run pre-processing steps on functions in isolation.
Expand All @@ -19,6 +22,19 @@ impl Ssa {
// Preliminary inlining decisions.
let inline_infos = inlining::compute_inline_infos(&self, false, aggressiveness);

let should_inline_call = |callee: &Function| -> bool {
match callee.runtime() {
RuntimeType::Acir(_) => {
// Functions marked to not have predicates should be preserved.
!callee.is_no_predicates()
}
RuntimeType::Brillig(_) => {
// We inline inline if the function called wasn't ruled out as too costly or recursive.
InlineInfo::should_inline(&inline_infos, callee.id())
}
}
};

for (id, (own_weight, transitive_weight)) in bottom_up {
// Skip preprocessing heavy functions that gained most of their weight from transitive accumulation.
// These can be processed later by the regular SSA passes.
Expand All @@ -34,7 +50,7 @@ impl Ssa {
}
let function = &self.functions[&id];
// Start with an inline pass.
let mut function = function.inlined(&self, false, &inline_infos);
let mut function = function.inlined(&self, &should_inline_call);
// Help unrolling determine bounds.
function.as_slice_optimization();
// Prepare for unrolling
Expand Down
Loading