diff --git a/Makefile b/Makefile index ffd90a836fb0..38600a22c4c1 100644 --- a/Makefile +++ b/Makefile @@ -294,7 +294,6 @@ SOURCE_FILES = \ Debug.cpp \ DebugArguments.cpp \ DebugToFile.cpp \ - DeepCopy.cpp \ Definition.cpp \ Deinterleave.cpp \ DeviceArgument.cpp \ @@ -424,7 +423,6 @@ HEADER_FILES = \ Debug.h \ DebugArguments.h \ DebugToFile.h \ - DeepCopy.h \ Definition.h \ Deinterleave.h \ DeviceArgument.h \ @@ -441,6 +439,7 @@ HEADER_FILES = \ Float16.h \ Func.h \ Function.h \ + FunctionPtr.h \ FuseGPUThreadLoops.h \ FuzzFloatStores.h \ Generator.h \ diff --git a/src/Associativity.cpp b/src/Associativity.cpp index f52bb06a499e..5e8bde5e2bd5 100644 --- a/src/Associativity.cpp +++ b/src/Associativity.cpp @@ -56,8 +56,6 @@ class ConvertSelfRef : public IRMutator { internal_assert(op); if ((op->call_type == Call::Halide) && (func == op->name)) { - internal_assert(!op->func.defined()) - << "Func should not have been defined for a self-reference\n"; internal_assert(args.size() == op->args.size()) << "Self-reference should have the same number of args as the original\n"; for (size_t i = 0; i < op->args.size(); i++) { @@ -543,7 +541,7 @@ void associativity_test() { Expr x = Variable::make(t, "x"); Expr y = Variable::make(t, "y"); Expr x_idx = Variable::make(Int(32), "x_idx"); - Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, nullptr, 0); + Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, FunctionPtr(), 0); // f(x) = uint8(uint16(x + y), 255) check_associativity("f", {x_idx}, {Cast::make(UInt(8), min(Cast::make(UInt(16), y + f_call_0), make_const(t, 255)))}, @@ -588,7 +586,7 @@ void associativity_test() { Expr x = Variable::make(t, "x"); Expr y = Variable::make(t, "y"); Expr x_idx = Variable::make(Int(32), "x_idx"); - Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, nullptr, 0); + Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, FunctionPtr(), 0); // f(x) = y && f(x) check_associativity("f", {x_idx}, {And::make(y, f_call_0)}, @@ -622,11 +620,11 @@ void associativity_test() { zs[i] = Variable::make(t, "z" + std::to_string(i)); } - Expr f_call_0 = Call::make(t, "f", {x}, Call::CallType::Halide, nullptr, 0); - Expr f_call_1 = Call::make(t, "f", {x}, Call::CallType::Halide, nullptr, 1); - Expr f_call_2 = Call::make(t, "f", {x}, Call::CallType::Halide, nullptr, 2); - Expr g_call_0 = Call::make(t, "g", {rx}, Call::CallType::Halide, nullptr, 0); - Expr g_call_1 = Call::make(t, "g", {rx}, Call::CallType::Halide, nullptr, 1); + Expr f_call_0 = Call::make(t, "f", {x}, Call::CallType::Halide, FunctionPtr(), 0); + Expr f_call_1 = Call::make(t, "f", {x}, Call::CallType::Halide, FunctionPtr(), 1); + Expr f_call_2 = Call::make(t, "f", {x}, Call::CallType::Halide, FunctionPtr(), 2); + Expr g_call_0 = Call::make(t, "g", {rx}, Call::CallType::Halide, FunctionPtr(), 0); + Expr g_call_1 = Call::make(t, "g", {rx}, Call::CallType::Halide, FunctionPtr(), 1); // f(x) = f(x) check_associativity("f", {x}, {f_call_0}, @@ -759,11 +757,11 @@ void associativity_test() { { Expr ry = Variable::make(t, "ry"); - Expr f_xy_call_0 = Call::make(t, "f", {x, y}, Call::CallType::Halide, nullptr, 0); - Expr f_xy_call_1 = Call::make(t, "f", {x, y}, Call::CallType::Halide, nullptr, 1); - Expr f_xy_call_2 = Call::make(t, "f", {x, y}, Call::CallType::Halide, nullptr, 2); - Expr f_xy_call_3 = Call::make(t, "f", {x, y}, Call::CallType::Halide, nullptr, 3); - Expr g_xy_call_0 = Call::make(t, "g", {rx, ry}, Call::CallType::Halide, nullptr, 0); + Expr f_xy_call_0 = Call::make(t, "f", {x, y}, Call::CallType::Halide, FunctionPtr(), 0); + Expr f_xy_call_1 = Call::make(t, "f", {x, y}, Call::CallType::Halide, FunctionPtr(), 1); + Expr f_xy_call_2 = Call::make(t, "f", {x, y}, Call::CallType::Halide, FunctionPtr(), 2); + Expr f_xy_call_3 = Call::make(t, "f", {x, y}, Call::CallType::Halide, FunctionPtr(), 3); + Expr g_xy_call_0 = Call::make(t, "g", {rx, ry}, Call::CallType::Halide, FunctionPtr(), 0); // 2D argmin + trivial update: // f(x, y) = Tuple(min(f(x, y)[0], g(r.x, r.y)[0]), diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b9ef72e9b857..c597e3d26502 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -285,7 +285,6 @@ set(HEADER_FILES Debug.h DebugArguments.h DebugToFile.h - DeepCopy.h Definition.h Deinterleave.h DeviceArgument.h @@ -302,6 +301,7 @@ set(HEADER_FILES Float16.h Func.h Function.h + FunctionPtr.h FuseGPUThreadLoops.h FuzzFloatStores.h Generator.h @@ -438,7 +438,6 @@ add_library(Halide ${HALIDE_LIBRARY_TYPE} Debug.cpp DebugArguments.cpp DebugToFile.cpp - DeepCopy.cpp Definition.cpp Deinterleave.cpp DeviceArgument.cpp diff --git a/src/DeepCopy.cpp b/src/DeepCopy.cpp deleted file mode 100644 index f8762b4865dd..000000000000 --- a/src/DeepCopy.cpp +++ /dev/null @@ -1,53 +0,0 @@ -#include "DeepCopy.h" - -namespace Halide{ -namespace Internal { - -using std::map; -using std::pair; -using std::string; -using std::vector; - -pair, map> deep_copy( - const vector &outputs, const map &env) { - vector copy_outputs; - map copy_env; - - // Create empty deep-copies of all Functions in 'env' - map copied_map; // Original Function -> Deep-copy - for (const auto &iter : env) { - copied_map[iter.second] = Function(iter.second.name()); - } - - // Deep copy all Functions in 'env' into their corresponding empty copies - for (const auto &iter : env) { - iter.second.deep_copy(copied_map[iter.second], copied_map); - } - - // Need to substitute-in all old Function references in all Exprs referenced - // within the Function with the deep-copy versions - for (auto &iter : copied_map) { - iter.second.substitute_calls(copied_map); - } - - // Populate the env with the deep-copy version - for (const auto &iter : copied_map) { - copy_env.emplace(iter.first.name(), iter.second); - } - - for (const auto &func : outputs) { - const auto &iter = copied_map.find(func); - if (iter != copied_map.end()) { - debug(4) << "Adding deep-copied version to outputs: " << func.name() << "\n"; - copy_outputs.push_back(iter->second); - } else { - debug(4) << "Adding original version to outputs: " << func.name() << "\n"; - copy_outputs.push_back(func); - } - } - - return { copy_outputs, copy_env }; -} - -} -} diff --git a/src/DeepCopy.h b/src/DeepCopy.h deleted file mode 100644 index a65fdc5db547..000000000000 --- a/src/DeepCopy.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef HALIDE_DEEP_COPY_H -#define HALIDE_DEEP_COPY_H - -/** \file - * - * Defines pass to create deep-copies of all Functions in 'env'. - */ - -#include - -#include "IR.h" - -namespace Halide { -namespace Internal { - -/** Create deep-copies of all Functions in 'env'. This returns a pair of the - * deep-copied versions of 'outputs' and 'env' */ -std::pair, std::map> deep_copy( - const std::vector &outputs, const std::map &env); - -} -} - -#endif diff --git a/src/FindCalls.cpp b/src/FindCalls.cpp index 2703e95a50a2..667569235722 100644 --- a/src/FindCalls.cpp +++ b/src/FindCalls.cpp @@ -9,6 +9,7 @@ using std::vector; using std::string; using std::pair; +namespace { /* Find all the internal halide calls in an expr */ class FindCalls : public IRVisitor { public: @@ -38,7 +39,8 @@ class FindCalls : public IRVisitor { } }; -void populate_environment(Function f, map &env, bool recursive = true) { +void populate_environment_helper(Function f, map &env, + bool recursive = true, bool include_wrappers = false) { map::const_iterator iter = env.find(f.name()); if (iter != env.end()) { user_assert(iter->second.same_as(f)) @@ -58,26 +60,39 @@ void populate_environment(Function f, map &env, bool recursive } } + if (include_wrappers) { + for (auto it : f.schedule().wrappers()) { + Function g(it.second); + calls.calls[g.name()] = g; + } + } + if (!recursive) { env.insert(calls.calls.begin(), calls.calls.end()); } else { env[f.name()] = f; for (pair i : calls.calls) { - populate_environment(i.second, env); + populate_environment_helper(i.second, env, recursive, include_wrappers); } } } +} + +void populate_environment(Function f, map &env) { + populate_environment_helper(f, env, true, true); +} + map find_transitive_calls(Function f) { map res; - populate_environment(f, res, true); + populate_environment_helper(f, res, true, false); return res; } map find_direct_calls(Function f) { map res; - populate_environment(f, res, false); + populate_environment_helper(f, res, false, false); return res; } diff --git a/src/FindCalls.h b/src/FindCalls.h index e1b420dd6be9..9b462116b611 100644 --- a/src/FindCalls.h +++ b/src/FindCalls.h @@ -28,6 +28,10 @@ std::map find_direct_calls(Function f); */ std::map find_transitive_calls(Function f); +/** Find all Functions transitively referenced by f in any way and add + * them to the given map. */ +void populate_environment(Function f, std::map &env); + } } diff --git a/src/Func.cpp b/src/Func.cpp index d08a48bacbaa..793bc7ab9a4b 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -365,8 +365,6 @@ class SubstituteSelfReference : public IRMutator { internal_assert(c); if ((c->call_type == Call::Halide) && (func == c->name)) { - internal_assert(!c->func.defined()) - << "func should not have been defined for a self-reference\n"; debug(4) << "...Replace call to Func \"" << c->name << "\" with " << "\"" << substitute.name() << "\"\n"; vector args; @@ -872,7 +870,7 @@ Func Stage::rfactor(vector> preserved) { if (!prover_result.xs[i].var.empty()) { Expr prev_val = Call::make(intm.output_types()[i], func_name, f_store_args, Call::CallType::Halide, - nullptr, i); + FunctionPtr(), i); replacements.emplace(prover_result.xs[i].var, prev_val); } else { user_warning << "Update definition of " << stage_name << " at index " << i @@ -1742,16 +1740,16 @@ void Func::invalidate_cache() { Func Func::in(const Func &f) { invalidate_cache(); user_assert(name() != f.name()) << "Cannot call 'in()' on itself\n"; - const map> &wrappers = func.wrappers(); + const map &wrappers = func.wrappers(); const auto &iter = wrappers.find(f.name()); if (iter == wrappers.end()) { - Func wrapper(name() + "_in_" + f.name()); + Func wrapper(func.new_function_in_same_group(name() + "_in_" + f.name())); wrapper(args()) = (*this)(args()); func.add_wrapper(f.name(), wrapper.func); return wrapper; } - IntrusivePtr wrapper_contents = iter->second; + FunctionPtr wrapper_contents = iter->second; internal_assert(wrapper_contents.defined()); // Make sure that no other Func shares the same wrapper as 'f' @@ -1776,7 +1774,7 @@ Func Func::in(const vector& fs) { // Either all Funcs have the same wrapper or they don't already have any wrappers. // Otherwise, throw an error. - const map> &wrappers = func.wrappers(); + const map &wrappers = func.wrappers(); const auto &iter = wrappers.find(fs[0].name()); if (iter == wrappers.end()) { @@ -1786,7 +1784,7 @@ Func Func::in(const vector& fs) { << "Cannot define the wrapper since " << fs[i].name() << " already has a wrapper while " << fs[0].name() << " doesn't \n"; } - Func wrapper(name() + "_wrapper"); + Func wrapper(func.new_function_in_same_group(name() + "_wrapper")); wrapper(args()) = (*this)(args()); for (const Func &f : fs) { user_assert(name() != f.name()) << "Cannot call 'in()' on itself\n"; @@ -1795,7 +1793,7 @@ Func Func::in(const vector& fs) { return wrapper; } - IntrusivePtr wrapper_contents = iter->second; + FunctionPtr wrapper_contents = iter->second; internal_assert(wrapper_contents.defined()); // Make sure all the other Funcs in 'fs' share the same wrapper and no other @@ -1825,16 +1823,16 @@ Func Func::in(const vector& fs) { Func Func::in() { invalidate_cache(); - const map> &wrappers = func.wrappers(); + const map &wrappers = func.wrappers(); const auto &iter = wrappers.find(""); if (iter == wrappers.end()) { - Func wrapper(name() + "_global_wrapper"); + Func wrapper(func.new_function_in_same_group(name() + "_global_wrapper")); wrapper(args()) = (*this)(args()); func.add_wrapper("", wrapper.func); return wrapper; } - IntrusivePtr wrapper_contents = iter->second; + FunctionPtr wrapper_contents = iter->second; internal_assert(wrapper_contents.defined()); Function wrapper(wrapper_contents); internal_assert(wrapper.frozen()); diff --git a/src/Function.cpp b/src/Function.cpp index d8cc011540f7..a021b5b3545f 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include "IR.h" #include "IRMutator.h" @@ -21,18 +22,40 @@ using std::vector; using std::string; using std::set; using std::map; +using std::pair; -typedef map, IntrusivePtr> DeepCopyMap; -ExternFuncArgument deep_copy_extern_func_argument_helper(const ExternFuncArgument &src, - DeepCopyMap &copied_map); -void deep_copy_function_contents_helper(const IntrusivePtr &src, - IntrusivePtr &dst, - DeepCopyMap &copied_map); -IntrusivePtr deep_copy_function_contents_helper( - const IntrusivePtr &src, DeepCopyMap &copied_map); +typedef map DeepCopyMap; + +struct FunctionContents; + +namespace { +// Weaken all the references to a particular Function to break +// reference cycles. Also count the number of references found. +class WeakenFunctionPtrs : public IRMutator { + using IRMutator::visit; + + void visit(const Call *c) { + IRMutator::visit(c); + c = expr.as(); + internal_assert(c); + if (c->func.defined() && + c->func.get() == func) { + FunctionPtr ptr = c->func; + ptr.weaken(); + expr = Call::make(c->type, c->name, c->args, c->call_type, + ptr, c->value_index, + c->image, c->param); + count++; + } + } + FunctionContents *func; +public: + int count = 0; + WeakenFunctionPtrs(FunctionContents *f) : func(f) {} +}; +} struct FunctionContents { - mutable RefCount ref_count; std::string name; std::vector output_types; @@ -49,21 +72,13 @@ struct FunctionContents { std::vector extern_arguments; std::string extern_function_name; - NameMangling extern_mangling; - DeviceAPI extern_function_device_api; - bool extern_uses_old_buffer_t; + NameMangling extern_mangling = NameMangling::Default; + DeviceAPI extern_function_device_api = DeviceAPI::Host; + bool extern_uses_old_buffer_t = false; - bool trace_loads, trace_stores, trace_realizations; + bool trace_loads = false, trace_stores = false, trace_realizations = false; - bool frozen; - - FunctionContents() : extern_mangling(NameMangling::Default), - extern_function_device_api(DeviceAPI::Host), - extern_uses_old_buffer_t(false), - trace_loads(false), - trace_stores(false), - trace_realizations(false), - frozen(false) {} + bool frozen = false; void accept(IRVisitor *visitor) const { func_schedule.accept(visitor); @@ -119,13 +134,22 @@ struct FunctionContents { } }; +struct FunctionGroup { + mutable RefCount ref_count; + vector members; +}; + +FunctionContents *FunctionPtr::get() const { + return &(group()->members[idx]); +} + template<> -EXPORT RefCount &ref_count(const FunctionContents *f) { +EXPORT RefCount &ref_count(const FunctionGroup *f) { return f->ref_count; } template<> -EXPORT void destroy(const FunctionContents *f) { +EXPORT void destroy(const FunctionGroup *f) { delete f; } @@ -205,28 +229,6 @@ struct CheckVars : public IRGraphVisitor { } }; -struct DeleteSelfReferences : public IRMutator { - IntrusivePtr func; - - // Also count the number of self references so we know if a Func - // has a recursive definition. - int count = 0; - - using IRMutator::visit; - - void visit(const Call *c) { - IRMutator::visit(c); - c = expr.as(); - internal_assert(c); - if (c->func.same_as(func)) { - expr = Call::make(c->type, c->name, c->args, c->call_type, - nullptr, c->value_index, - c->image, c->param); - count++; - } - } -}; - // Mark all functions found in an expr as frozen. class FreezeFunctions : public IRGraphVisitor { using IRGraphVisitor::visit; @@ -251,21 +253,24 @@ namespace { static std::atomic rand_counter; } -Function::Function() : contents(new FunctionContents) { +Function::Function() { } -Function::Function(const IntrusivePtr &ptr) : contents(ptr) { +Function::Function(const FunctionPtr &ptr) : contents(ptr) { + contents.strengthen(); internal_assert(ptr.defined()) << "Can't construct Function from undefined FunctionContents ptr\n"; } -Function::Function(const std::string &n) : contents(new FunctionContents) { +Function::Function(const std::string &n) { for (size_t i = 0; i < n.size(); i++) { user_assert(n[i] != '.') << "Func name \"" << n << "\" is invalid. " << "Func names may not contain the character '.', " << "as it is used internally by Halide as a separator\n"; } + contents.strong = new FunctionGroup; + contents.strong->members.resize(1); contents->name = n; } @@ -287,96 +292,52 @@ ExternFuncArgument deep_copy_extern_func_argument_helper( // If the FunctionContents has already been deep-copied previously, i.e. // it's in the 'copied_map', use the deep-copied version from the map instead // of creating a new deep-copy - IntrusivePtr &copied_func = copied_map[src.func]; - if (copied_func.defined()) { - copy.func = copied_func; - } else { - copy.func = deep_copy_function_contents_helper(src.func, copied_map); - copied_map[src.func] = copy.func; - } + FunctionPtr &copied_func = copied_map[src.func]; + internal_assert(copied_func.defined()); + copy.func = copied_func; return copy; } -// Return a deep-copy of FunctionContents 'src' -IntrusivePtr deep_copy_function_contents_helper( - const IntrusivePtr &src, DeepCopyMap &copied_map) { - - IntrusivePtr copy(new FunctionContents); - deep_copy_function_contents_helper(src, copy, copied_map); - return copy; -} +void Function::deep_copy(FunctionPtr copy, DeepCopyMap &copied_map) const { + internal_assert(copy.defined() && contents.defined()) + << "Cannot deep-copy undefined Function\n"; -// Return a deep-copy of FunctionContents 'src' -void deep_copy_function_contents_helper(const IntrusivePtr &src, - IntrusivePtr &dst, - DeepCopyMap &copied_map) { - debug(4) << "Deep-copy function contents: \"" << src->name << "\"\n"; - - internal_assert(dst.defined() && src.defined()) << "Cannot deep-copy undefined FunctionContents\n"; - - dst->name = src->name; - dst->output_types = src->output_types; - dst->debug_file = src->debug_file; - dst->extern_function_name = src->extern_function_name; - dst->extern_mangling = src->extern_mangling; - dst->extern_function_device_api = src->extern_function_device_api; - dst->extern_uses_old_buffer_t = src->extern_uses_old_buffer_t; - dst->trace_loads = src->trace_loads; - dst->trace_stores = src->trace_stores; - dst->trace_realizations = src->trace_realizations; - dst->frozen = src->frozen; - dst->output_buffers = src->output_buffers; - dst->func_schedule = src->func_schedule.deep_copy(copied_map); + // Add reference to this Function's deep-copy to the map in case of + // self-reference, e.g. self-reference in an Definition. + copied_map[contents] = copy; + + debug(4) << "Deep-copy function contents: \"" << contents->name << "\"\n"; + + copy->name = contents->name; + copy->output_types = contents->output_types; + copy->debug_file = contents->debug_file; + copy->extern_function_name = contents->extern_function_name; + copy->extern_mangling = contents->extern_mangling; + copy->extern_function_device_api = contents->extern_function_device_api; + copy->extern_uses_old_buffer_t = contents->extern_uses_old_buffer_t; + copy->trace_loads = contents->trace_loads; + copy->trace_stores = contents->trace_stores; + copy->trace_realizations = contents->trace_realizations; + copy->frozen = contents->frozen; + copy->output_buffers = contents->output_buffers; + copy->func_schedule = contents->func_schedule.deep_copy(copied_map); // Copy the pure definition - dst->init_def = src->init_def.get_copy(); - internal_assert(dst->init_def.is_init()); - internal_assert(dst->init_def.schedule().rvars().empty()) + copy->init_def = contents->init_def.get_copy(); + internal_assert(copy->init_def.is_init()); + internal_assert(copy->init_def.schedule().rvars().empty()) << "Init definition shouldn't have reduction domain\n"; - for (const Definition &def : src->updates) { + for (const Definition &def : contents->updates) { internal_assert(!def.is_init()); Definition def_copy = def.get_copy(); internal_assert(!def_copy.is_init()); - dst->updates.push_back(std::move(def_copy)); + copy->updates.push_back(std::move(def_copy)); } - for (const ExternFuncArgument &e : src->extern_arguments) { + for (const ExternFuncArgument &e : contents->extern_arguments) { ExternFuncArgument e_copy = deep_copy_extern_func_argument_helper(e, copied_map); - dst->extern_arguments.push_back(std::move(e_copy)); - } -} - -void Function::deep_copy(Function ©, - std::map &copied_map) const { - internal_assert(copy.contents.defined() && contents.defined()) - << "Cannot deep-copy undefined Function\n"; - // Need to copy over the contents of Functions in 'copied_map' since - // deep_copy_function_contents_helper() takes a map of - // (DeepCopyMap) - DeepCopyMap copied_funcs_map; - for (const auto &iter : copied_map) { - copied_funcs_map[iter.first.contents] = iter.second.contents; - } - // Add reference to this Function's deep-copy to the map in case of - // self-reference, e.g. self-reference in an Definition. - copied_funcs_map[contents] = copy.contents; - - // Perform the deep-copies - deep_copy_function_contents_helper(contents, copy.contents, copied_funcs_map); - - // Copy over all new deep-copies of FunctionContents into 'copied_map'. - for (const auto &iter : copied_funcs_map) { - Function old_func = Function(iter.first); - if (copied_map.count(old_func)) { - // Need to make sure that deep_copy_function_contents_helper() uses - // the already existing deep-copy of FunctionContents instead of - // creating a new deep-copy - internal_assert(copied_map[old_func].contents.same_as(iter.second)) - << old_func.name() << " is deep-copied twice\n"; - continue; - } - copied_map[old_func] = Function(iter.second); + copy->extern_arguments.push_back(std::move(e_copy)); } } @@ -438,7 +399,8 @@ void Function::define(const vector &args, vector values) { << "Reduction domain referenced in pure function definition.\n"; if (!contents.defined()) { - contents = new FunctionContents; + contents.strong = new FunctionGroup; + contents.strong->members.resize(1); contents->name = unique_name('f'); } @@ -623,18 +585,16 @@ void Function::define_update(const vector &_args, vector values) { // The update value and args probably refer back to the // function itself, introducing circular references and hence // memory leaks. We need to break these cycles. - DeleteSelfReferences deleter; - deleter.func = contents; - deleter.count = 0; + WeakenFunctionPtrs weakener(contents.get()); for (size_t i = 0; i < args.size(); i++) { - args[i] = deleter.mutate(args[i]); + args[i] = weakener.mutate(args[i]); } for (size_t i = 0; i < values.size(); i++) { - values[i] = deleter.mutate(values[i]); + values[i] = weakener.mutate(values[i]); } if (check.reduction_domain.defined()) { check.reduction_domain.set_predicate( - deleter.mutate(check.reduction_domain.predicate())); + weakener.mutate(check.reduction_domain.predicate())); } Definition r(args, values, check.reduction_domain, false); @@ -676,7 +636,7 @@ void Function::define_update(const vector &_args, vector values) { // the args are pure, then this definition completely hides // earlier ones! if (!check.reduction_domain.defined() && - deleter.count == 0 && + weakener.count == 0 && pure) { user_warning << "In update definition " << update_idx << " of Func \"" << name() << "\":\n" @@ -737,6 +697,10 @@ void Function::accept(IRVisitor *visitor) const { contents->accept(visitor); } +void Function::mutate(IRMutator *mutator) { + contents->mutate(mutator); +} + const std::string &Function::name() const { return contents->name; } @@ -895,13 +859,31 @@ bool Function::frozen() const { return contents->frozen; } -const map> &Function::wrappers() const { +const map &Function::wrappers() const { return contents->func_schedule.wrappers(); } +Function Function::new_function_in_same_group(const std::string &f) { + int group_size = (int)(contents.group()->members.size()); + contents.group()->members.resize(group_size+1); + contents.group()->members[group_size].name = f; + FunctionPtr ptr; + ptr.strong = contents.group(); + ptr.idx = group_size; + return Function(ptr); +} + void Function::add_wrapper(const std::string &f, Function &wrapper) { wrapper.freeze(); - contents->func_schedule.add_wrapper(f, wrapper.contents); + FunctionPtr ptr = wrapper.contents; + + // Weaken the pointer from the function to its wrapper + ptr.weaken(); + contents->func_schedule.add_wrapper(f, ptr); + + // Weaken the pointer from the wrapper back to the function. + WeakenFunctionPtrs weakener(contents.get()); + wrapper.mutate(&weakener); } namespace { @@ -910,22 +892,27 @@ namespace { class SubstituteCalls : public IRMutator { using IRMutator::visit; - map substitutions; + map substitutions; void visit(const Call *c) { IRMutator::visit(c); c = expr.as(); internal_assert(c); - if ((c->call_type == Call::Halide) && c->func.defined() && substitutions.count(Function(c->func))) { - const Function &subs = substitutions[Function(c->func)]; + if ((c->call_type == Call::Halide) && + c->func.defined() && + substitutions.count(c->func)) { + FunctionPtr subs = substitutions[c->func]; + internal_assert(subs.defined()) << "Function not in environment: " << subs->name << "\n"; debug(4) << "...Replace call to Func \"" << c->name << "\" with " - << "\"" << subs.name() << "\"\n"; - expr = Call::make(subs, c->args, c->value_index); + << "\"" << subs->name << "\"\n"; + expr = Call::make(c->type, subs->name, c->args, c->call_type, + subs, c->value_index, + c->image, c->param); } } public: - SubstituteCalls(const map &substitutions) + SubstituteCalls(const map &substitutions) : substitutions(substitutions) {} }; @@ -946,9 +933,8 @@ class SubstituteScheduleParamExprs : public IRMutator { } // anonymous namespace -Function &Function::substitute_calls(const map &substitutions) { +Function &Function::substitute_calls(const map &substitutions) { debug(4) << "Substituting calls in " << name() << "\n"; - if (substitutions.empty()) { return *this; } @@ -958,8 +944,8 @@ Function &Function::substitute_calls(const map &sub } Function &Function::substitute_calls(const Function &orig, const Function &substitute) { - map substitutions; - substitutions.emplace(orig, substitute); + map substitutions; + substitutions.emplace(orig.get_contents(), substitute.get_contents()); return substitute_calls(substitutions); } @@ -969,5 +955,59 @@ Function &Function::substitute_schedule_param_exprs() { return *this; } +// Deep copy an entire Function DAG. +pair, map> deep_copy( + const vector &outputs, const map &env) { + vector copy_outputs; + map copy_env; + + // Create empty deep-copies of all Functions in 'env' + DeepCopyMap copied_map; // Original Function -> Deep-copy + IntrusivePtr group(new FunctionGroup); + group->members.resize(env.size()); + int i = 0; + for (const auto &iter : env) { + // Make a weak pointer to the function to use for within-group references. + FunctionPtr ptr; + ptr.weak = group.get(); + ptr.idx = i; + ptr->name = iter.second.name(); + copied_map[iter.second.get_contents()] = ptr; + i++; + } + + // Deep copy all Functions in 'env' into their corresponding empty copies + for (const auto &iter : env) { + iter.second.deep_copy(copied_map[iter.second.get_contents()], copied_map); + } + + // Need to substitute-in all old Function references in all Exprs referenced + // within the Function with the deep-copy versions + for (auto &iter : copied_map) { + Function(iter.second).substitute_calls(copied_map); + } + + // Populate the env with the deep-copy version + for (const auto &iter : copied_map) { + FunctionPtr ptr = iter.second; + copy_env.emplace(iter.first->name, Function(ptr)); + } + + for (const auto &func : outputs) { + const auto &iter = copied_map.find(func.get_contents()); + if (iter != copied_map.end()) { + FunctionPtr ptr = iter->second; + debug(4) << "Adding deep-copied version to outputs: " << func.name() << "\n"; + copy_outputs.push_back(Function(ptr)); + } else { + debug(4) << "Adding original version to outputs: " << func.name() << "\n"; + copy_outputs.push_back(func); + } + } + + return { copy_outputs, copy_env }; +} + + } } diff --git a/src/Function.h b/src/Function.h index 14447d43cfad..6485a9cf6211 100644 --- a/src/Function.h +++ b/src/Function.h @@ -7,6 +7,7 @@ #include "Expr.h" #include "IntrusivePtr.h" +#include "FunctionPtr.h" #include "Parameter.h" #include "Schedule.h" #include "Reduction.h" @@ -17,21 +18,17 @@ namespace Halide { -namespace Internal { -struct FunctionContents; -} - /** An argument to an extern-defined Func. May be a Function, Buffer, * ImageParam or Expr. */ struct ExternFuncArgument { enum ArgType {UndefinedArg = 0, FuncArg, BufferArg, ExprArg, ImageParamArg}; ArgType arg_type; - Internal::IntrusivePtr func; + Internal::FunctionPtr func; Buffer<> buffer; Expr expr; Internal::Parameter image_param; - ExternFuncArgument(Internal::IntrusivePtr f): arg_type(FuncArg), func(f) {} + ExternFuncArgument(Internal::FunctionPtr f): arg_type(FuncArg), func(f) {} template ExternFuncArgument(Buffer b): arg_type(BufferArg), buffer(b) {} @@ -66,7 +63,7 @@ namespace Internal { * syntactic sugar to help with definitions. */ class Function { - IntrusivePtr contents; + FunctionPtr contents; public: /** This lets you use a Function as a key in a map of the form @@ -88,11 +85,11 @@ class Function { EXPORT explicit Function(const std::string &n); /** Construct a Function from an existing FunctionContents pointer. Must be non-null */ - EXPORT explicit Function(const IntrusivePtr &); + EXPORT explicit Function(const FunctionPtr &); /** Get a handle on the halide function contents that this Function * represents. */ - IntrusivePtr get_contents() const { + FunctionPtr get_contents() const { return contents; } @@ -104,7 +101,7 @@ class Function { * creating a new deep-copy to avoid creating deep-copies of the same Function * multiple times. */ - EXPORT void deep_copy(Function ©, std::map &copied_map) const; + EXPORT void deep_copy(FunctionPtr copy, std::map &copied_map) const; /** Add a pure definition to this function. It may not already * have a definition. All the free variables in 'value' must @@ -126,6 +123,10 @@ class Function { * of this function. */ EXPORT void accept(IRVisitor *visitor) const; + /** Accept a mutator to mutator all of the definitions and + * arguments of this function. */ + EXPORT void mutate(IRMutator *visitor); + /** Get the name of the function. */ EXPORT const std::string &name() const; @@ -263,6 +264,12 @@ class Function { * add new definitions. */ EXPORT bool frozen() const; + /** Make a new Function with the same lifetime as this one, and + * return a strong reference to it. Useful to create Functions which + * have circular references to this one - e.g. the wrappers + * produced by Func::in. */ + Function new_function_in_same_group(const std::string &); + /** Mark calls of this function by 'f' to be replaced with its wrapper * during the lowering stage. If the string 'f' is empty, it means replace * all calls to this function by all other functions (excluding itself) in @@ -271,14 +278,14 @@ class Function { * See \ref Func::in for more details. */ // @{ EXPORT void add_wrapper(const std::string &f, Function &wrapper); - EXPORT const std::map> &wrappers() const; + EXPORT const std::map &wrappers() const; // @} /** Replace every call to Functions in 'substitutions' keys by all Exprs * referenced in this Function to call to their substitute Functions (i.e. * the corresponding values in 'substitutions' map). */ // @{ - EXPORT Function &substitute_calls(const std::map &substitutions); + EXPORT Function &substitute_calls(const std::map &substitutions); EXPORT Function &substitute_calls(const Function &orig, const Function &substitute); // @} @@ -287,6 +294,11 @@ class Function { EXPORT Function &substitute_schedule_param_exprs(); }; +/** Deep copy an entire Function DAG. */ +std::pair, std::map> deep_copy( + const std::vector &outputs, + const std::map &env); + }} #endif diff --git a/src/FunctionPtr.h b/src/FunctionPtr.h new file mode 100644 index 000000000000..a20998759207 --- /dev/null +++ b/src/FunctionPtr.h @@ -0,0 +1,93 @@ +#ifndef HALIDE_FUNCTION_PTR_H +#define HALIDE_FUNCTION_PTR_H + +#include "IntrusivePtr.h" + +namespace Halide { +namespace Internal { + +/** Functions are allocated in groups for memory management. Each + * group has a ref count associated with it. All within-group + * references must be weak. If there are any references from outside + * the group, at least one must be strong. Within-group references + * may form cycles, but there may not be reference cycles that span + * multiple groups. These rules are not enforced automatically. */ +struct FunctionGroup; + +/** The opaque struct describing a Halide function. Wrap it in a + * Function object to access it. */ +struct FunctionContents; + +/** A possibly-weak pointer to a Halide function. Take care to follow + * the rules mentioned above. Preserves weakness/strength on copy. + * + * Note that Function objects are always strong pointers to Halide + * functions. + */ +struct FunctionPtr { + /** A strong and weak pointer to the group. Only one of these + * should be non-zero. */ + // @{ + IntrusivePtr strong; + FunctionGroup *weak = nullptr; + // @} + + /** The index of the function within the group. */ + int idx = 0; + + /** Get a pointer to the group this Function belongs to. */ + FunctionGroup *group() const { + return weak ? weak : strong.get(); + } + + /** Get the opaque FunctionContents object this pointer refers + * to. Wrap it in a Function to do anything interesting with it. */ + // @{ + FunctionContents *get() const; + + FunctionContents &operator*() const { + return *get(); + } + + FunctionContents *operator->() const { + return get(); + } + // @} + + /** Convert from a strong reference to a weak reference. Does + * nothing if the pointer is undefined, or if the reference is + * already weak. */ + void weaken() { + weak = group(); + strong = nullptr; + } + + /** Convert from a weak reference to a strong reference. Does + * nothing if the pointer is undefined, or if the reference is + * already strong. */ + void strengthen() { + strong = group(); + weak = nullptr; + } + + /** Check if the reference is defined. */ + bool defined() const { + return weak || strong.defined(); + } + + /** Check if two FunctionPtrs refer to the same Function. */ + bool same_as(const FunctionPtr &other) const { + return idx == other.idx && group() == other.group(); + } + + /** Pointer comparison, for using FunctionPtrs as keys in maps and + * sets. */ + bool operator<(const FunctionPtr &other) const { + return get() < other.get(); + } +}; + +} +} + +#endif diff --git a/src/IR.cpp b/src/IR.cpp index b8be68d409ed..cf9db3a88c9d 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -535,7 +535,7 @@ Expr Call::make(Function func, const std::vector &args, int idx) { } Expr Call::make(Type type, const std::string &name, const std::vector &args, CallType call_type, - IntrusivePtr func, int value_index, + FunctionPtr func, int value_index, Buffer<> image, Parameter param) { if (name == Call::prefetch && call_type == Call::Intrinsic) { internal_assert(args.size() % 2 == 0) diff --git a/src/IR.h b/src/IR.h index 9a79961c0a93..ee79a00a91dc 100644 --- a/src/IR.h +++ b/src/IR.h @@ -540,10 +540,8 @@ struct Call : public ExprNode { trace; // If it's a call to another halide function, this call node holds - // onto a pointer to that function for the purposes of reference - // counting only. Self-references in update definitions do not - // have this set, to avoid cycles. - IntrusivePtr func; + // a possibly-weak reference to that function. + FunctionPtr func; // If that function has multiple values, which value does this // call node refer to? @@ -558,7 +556,7 @@ struct Call : public ExprNode { Parameter param; EXPORT static Expr make(Type type, const std::string &name, const std::vector &args, CallType call_type, - IntrusivePtr func = nullptr, int value_index = 0, + FunctionPtr func = FunctionPtr(), int value_index = 0, Buffer<> image = Buffer<>(), Parameter param = Parameter()); /** Convenience constructor for calls to other halide functions */ @@ -566,12 +564,12 @@ struct Call : public ExprNode { /** Convenience constructor for loads from concrete images */ static Expr make(Buffer<> image, const std::vector &args) { - return make(image.type(), image.name(), args, Image, nullptr, 0, image, Parameter()); + return make(image.type(), image.name(), args, Image, FunctionPtr(), 0, image, Parameter()); } /** Convenience constructor for loads from images parameters */ static Expr make(Parameter param, const std::vector &args) { - return make(param.type(), param.name(), args, Image, nullptr, 0, Buffer<>(), param); + return make(param.type(), param.name(), args, Image, FunctionPtr(), 0, Buffer<>(), param); } /** Check if a call node is pure within a pipeline, meaning that diff --git a/src/InjectImageIntrinsics.cpp b/src/InjectImageIntrinsics.cpp index 309038641439..d4f31e95c872 100644 --- a/src/InjectImageIntrinsics.cpp +++ b/src/InjectImageIntrinsics.cpp @@ -119,7 +119,7 @@ class InjectImageIntrinsics : public IRMutator { Call::image_load, args, Call::PureIntrinsic, - nullptr, + FunctionPtr(), 0, call->image, call->param); diff --git a/src/InjectOpenGLIntrinsics.cpp b/src/InjectOpenGLIntrinsics.cpp index d61b67a4ce32..80cb66d47c69 100644 --- a/src/InjectOpenGLIntrinsics.cpp +++ b/src/InjectOpenGLIntrinsics.cpp @@ -71,7 +71,7 @@ class InjectOpenGLIntrinsics : public IRMutator { expr = Call::make(call->type, Call::glsl_texture_load, vector(&args[0], &args[5]), - Call::Intrinsic, nullptr, 0, + Call::Intrinsic, FunctionPtr(), 0, call->image, call->param); } else if (call->is_intrinsic(Call::image_store)) { user_assert(call->args.size() == 6) diff --git a/src/Lower.cpp b/src/Lower.cpp index 4429bf14adb6..f019b7e6ac96 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -15,7 +15,6 @@ #include "Debug.h" #include "DebugArguments.h" #include "DebugToFile.h" -#include "DeepCopy.h" #include "Deinterleave.h" #include "EarlyFree.h" #include "FindCalls.h" @@ -83,8 +82,7 @@ Module lower(const vector &output_funcs, const string &pipeline_name, // Compute an environment map env; for (Function f : output_funcs) { - map more_funcs = find_transitive_calls(f); - env.insert(more_funcs.begin(), more_funcs.end()); + populate_environment(f, env); } // Create a deep-copy of the entire graph of Funcs. @@ -379,6 +377,26 @@ Module lower(const vector &output_funcs, const string &pipeline_name, } } + // We're about to drop the environment and outputs vector, which + // contain the only strong refs to Functions that may still be + // pointed to by the IR. So make those refs strong. + class StrengthenRefs : public IRMutator { + using IRMutator::visit; + void visit(const Call *c) { + IRMutator::visit(c); + c = expr.as(); + internal_assert(c); + if (c->func.defined()) { + FunctionPtr ptr = c->func; + ptr.strengthen(); + expr = Call::make(c->type, c->name, c->args, c->call_type, + ptr, c->value_index, + c->image, c->param); + } + } + }; + s = StrengthenRefs().mutate(s); + LoweredFunc main_func(pipeline_name, public_args, s, linkage_type); // If we're in debug mode, add code that prints the args. diff --git a/src/PrintLoopNest.cpp b/src/PrintLoopNest.cpp index 0ef1c6680785..0299a3cfe4fb 100644 --- a/src/PrintLoopNest.cpp +++ b/src/PrintLoopNest.cpp @@ -1,5 +1,4 @@ #include "PrintLoopNest.h" -#include "DeepCopy.h" #include "FindCalls.h" #include "Function.h" #include "Func.h" @@ -158,8 +157,7 @@ string print_loop_nest(const vector &output_funcs) { // Compute an environment map env; for (Function f : output_funcs) { - map more_funcs = find_transitive_calls(f); - env.insert(more_funcs.begin(), more_funcs.end()); + populate_environment(f, env); } // Create a deep-copy of the entire graph of Funcs. @@ -171,6 +169,11 @@ string print_loop_nest(const vector &output_funcs) { Func(f).compute_root().store_root(); } + // Ensure that all ScheduleParams become well-defined constant Exprs. + for (auto &f : env) { + f.second.substitute_schedule_param_exprs(); + } + // Substitute in wrapper Funcs env = wrap_func_calls(env); diff --git a/src/Schedule.cpp b/src/Schedule.cpp index 68e8162c9de8..0b9a790f7d6a 100644 --- a/src/Schedule.cpp +++ b/src/Schedule.cpp @@ -113,11 +113,7 @@ bool LoopLevel::operator==(const LoopLevel &other) const { namespace Internal { -typedef std::map, IntrusivePtr> DeepCopyMap; - -IntrusivePtr deep_copy_function_contents_helper( - const IntrusivePtr &src, - DeepCopyMap &copied_map); +typedef std::map DeepCopyMap; /** A schedule for a halide function, which defines where, when, and * how it should be evaluated. */ @@ -127,7 +123,7 @@ struct FuncScheduleContents { LoopLevel store_level, compute_level; std::vector storage_dims; std::vector bounds; - std::map> wrappers; + std::map wrappers; bool memoized; FuncScheduleContents() : @@ -215,7 +211,7 @@ EXPORT void destroy(const StageScheduleContents *p) { FuncSchedule::FuncSchedule() : contents(new FuncScheduleContents) {} FuncSchedule FuncSchedule::deep_copy( - std::map, IntrusivePtr> &copied_map) const { + std::map &copied_map) const { internal_assert(contents.defined()) << "Cannot deep-copy undefined FuncSchedule\n"; FuncSchedule copy; @@ -225,17 +221,11 @@ FuncSchedule FuncSchedule::deep_copy( copy.contents->bounds = contents->bounds; copy.contents->memoized = contents->memoized; - // Deep-copy wrapper functions. If function has already been deep-copied before, - // i.e. it's in the 'copied_map', use the deep-copied version from the map instead - // of creating a new deep-copy + // Deep-copy wrapper functions. for (const auto &iter : contents->wrappers) { - IntrusivePtr &copied_func = copied_map[iter.second]; - if (copied_func.defined()) { - copy.contents->wrappers[iter.first] = copied_func; - } else { - copy.contents->wrappers[iter.first] = deep_copy_function_contents_helper(iter.second, copied_map); - copied_map[iter.second] = copy.contents->wrappers[iter.first]; - } + FunctionPtr &copied_func = copied_map[iter.second]; + internal_assert(copied_func.defined()) << Function(iter.second).name() << "\n"; + copy.contents->wrappers[iter.first] = copied_func; } internal_assert(copy.contents->wrappers.size() == contents->wrappers.size()); return copy; @@ -265,16 +255,16 @@ const std::vector &FuncSchedule::bounds() const { return contents->bounds; } -std::map> &FuncSchedule::wrappers() { +std::map &FuncSchedule::wrappers() { return contents->wrappers; } -const std::map> &FuncSchedule::wrappers() const { +const std::map &FuncSchedule::wrappers() const { return contents->wrappers; } void FuncSchedule::add_wrapper(const std::string &f, - const IntrusivePtr &wrapper) { + const Internal::FunctionPtr &wrapper) { if (contents->wrappers.count(f)) { if (f.empty()) { user_warning << "Replacing previous definition of global wrapper in function \"" diff --git a/src/Schedule.h b/src/Schedule.h index c0bee3140b8c..9aed9d998c96 100644 --- a/src/Schedule.h +++ b/src/Schedule.h @@ -7,6 +7,7 @@ #include "Expr.h" #include "Parameter.h" +#include "FunctionPtr.h" #include @@ -258,7 +259,7 @@ class FuncSchedule { * same FunctionContents multiple times. */ EXPORT FuncSchedule deep_copy( - std::map, IntrusivePtr> &copied_map) const; + std::map &copied_map) const; /** This flag is set to true if the schedule is memoized. */ // @{ @@ -288,10 +289,10 @@ class FuncSchedule { * all calls to the function by all other functions (excluding itself) in * the pipeline with the wrapper. See \ref Func::in for more details. */ // @{ - const std::map> &wrappers() const; - std::map> &wrappers(); + const std::map &wrappers() const; + std::map &wrappers(); EXPORT void add_wrapper(const std::string &f, - const IntrusivePtr &wrapper); + const Internal::FunctionPtr &wrapper); // @} /** At what sites should we inject the allocation and the diff --git a/src/Tracing.cpp b/src/Tracing.cpp index bef4c59ed7f8..ed90a7026e1c 100644 --- a/src/Tracing.cpp +++ b/src/Tracing.cpp @@ -73,7 +73,9 @@ class InjectTracing : public IRMutator { bool trace_it = false; Expr trace_parent; if (op->call_type == Call::Halide) { - Function f = env.find(op->name)->second; + auto it = env.find(op->name); + internal_assert(it != env.end()) << op->name << " not in environment\n"; + Function f = it->second; internal_assert(!f.can_be_inlined() || !f.schedule().compute_level().is_inline()); trace_it = f.is_tracing_loads() || trace_all_loads; diff --git a/src/WrapCalls.cpp b/src/WrapCalls.cpp index b7f1661e56b2..9deecb6d49cf 100644 --- a/src/WrapCalls.cpp +++ b/src/WrapCalls.cpp @@ -11,30 +11,36 @@ using std::set; using std::string; using std::vector; -typedef map SubstitutionMap; +typedef map SubstitutionMap; namespace { -void insert_func_wrapper_helper(map &func_wrappers_map, - const Function &in_func, const Function &wrapped_func, - const Function &wrapper) { - internal_assert(in_func.get_contents().defined() && wrapped_func.get_contents().defined() && - wrapper.get_contents().defined()); +void insert_func_wrapper_helper(map &func_wrappers_map, + FunctionPtr in_func, + FunctionPtr wrapped_func, + FunctionPtr wrapper) { + internal_assert(in_func.defined() && + wrapped_func.defined() && + wrapper.defined()); internal_assert(func_wrappers_map[in_func].count(wrapped_func) == 0) << "Should only have one wrapper for each function call in a Func\n"; SubstitutionMap &wrappers_map = func_wrappers_map[in_func]; for (auto iter = wrappers_map.begin(); iter != wrappers_map.end(); ++iter) { if (iter->second.same_as(wrapped_func)) { - debug(4) << "Merging wrapper of " << in_func.name() << " [" << iter->first.name() - << ", " << iter->second.name() << "] with [" << wrapped_func.name() << ", " - << wrapper.name() << "]\n"; + debug(4) << "Merging wrapper of " << Function(in_func).name() + << " [" << Function(iter->first).name() + << ", " << Function(iter->second).name() + << "] with [" << Function(wrapped_func).name() << ", " + << Function(wrapper).name() << "]\n"; iter->second = wrapper; return; } else if (wrapper.same_as(iter->first)) { - debug(4) << "Merging wrapper of " << in_func.name() << " [" << wrapped_func.name() - << ", " << wrapper.name() << "] with [" << iter->first.name() << ", " - << iter->second.name() << "]\n"; + debug(4) << "Merging wrapper of " << Function(in_func).name() + << " [" << Function(wrapped_func).name() + << ", " << Function(wrapper).name() + << "] with [" << Function(iter->first).name() + << ", " << Function(iter->second).name() << "]\n"; wrappers_map.emplace(wrapped_func, iter->second); wrappers_map.erase(iter); return; @@ -48,17 +54,17 @@ void insert_func_wrapper_helper(map wrap_func_calls(const map &env) { map wrapped_env; - map func_wrappers_map; // In Func -> [wrapped Func -> wrapper] + map func_wrappers_map; // In Func -> [wrapped Func -> wrapper] for (const auto &iter : env) { wrapped_env.emplace(iter.first, iter.second); - func_wrappers_map[iter.second]; + func_wrappers_map[iter.second.get_contents()]; } for (const auto &it : env) { string wrapped_fname = it.first; - const Function &wrapped_func = it.second; - const auto &wrappers = wrapped_func.schedule().wrappers(); + FunctionPtr wrapped_func = it.second.get_contents(); + const auto &wrappers = it.second.schedule().wrappers(); // Put the names of all wrappers of this Function into the set for // faster comparison during the substitution. @@ -69,17 +75,19 @@ map wrap_func_calls(const map &env) { for (const auto &iter : wrappers) { string in_func = iter.first; - const Function &wrapper = Function(iter.second); // This is already the deep-copy version + FunctionPtr wrapper = iter.second; if (in_func.empty()) { // Global wrapper for (const auto &wrapped_env_iter : wrapped_env) { in_func = wrapped_env_iter.first; - if ((wrapper.name() == in_func) || (all_func_wrappers.find(in_func) != all_func_wrappers.end())) { + if ((wrapped_fname == in_func) || + (all_func_wrappers.find(in_func) != all_func_wrappers.end())) { // The wrapper should still call the original function, // so we don't want to rewrite the calls done by the // wrapper. We also shouldn't rewrite the original // function itself. - debug(4) << "Skip over replacing \"" << in_func << "\" with \"" << wrapper.name() << "\"\n"; + debug(4) << "Skip over replacing \"" << in_func + << "\" with \"" << Function(wrapper).name() << "\"\n"; continue; } if (wrappers.count(in_func)) { @@ -90,13 +98,15 @@ map wrap_func_calls(const map &env) { } debug(4) << "Global wrapper: replacing reference of \"" << wrapped_fname << "\" in \"" << in_func - << "\" with \"" << wrapper.name() << "\"\n"; - insert_func_wrapper_helper(func_wrappers_map, wrapped_env_iter.second, wrapped_func, wrapper); + << "\" with \"" << Function(wrapper).name() << "\"\n"; + insert_func_wrapper_helper(func_wrappers_map, + wrapped_env_iter.second.get_contents(), + wrapped_func, wrapper); } } else { // Custom wrapper debug(4) << "Custom wrapper: replacing reference of \"" << wrapped_fname << "\" in \"" << in_func << "\" with \"" - << wrapper.name() << "\"\n"; + << Function(wrapper).name() << "\"\n"; const auto &in_func_iter = wrapped_env.find(in_func); if (in_func_iter == wrapped_env.end()) { @@ -111,17 +121,20 @@ map wrap_func_calls(const map &env) { // f.in(g); // f.realize(..); debug(4) << " skip custom wrapper for " << in_func << " [" << wrapped_fname - << " -> " << wrapper.name() << "] since it's not in the pipeline\n"; + << " -> " << Function(wrapper).name() << "] since it's not in the pipeline\n"; continue; } - insert_func_wrapper_helper(func_wrappers_map, wrapped_env[in_func], wrapped_func, wrapper); + insert_func_wrapper_helper(func_wrappers_map, + wrapped_env[in_func].get_contents(), + wrapped_func, + wrapper); } } } // Perform the substitution for (auto &iter : wrapped_env) { - const auto &substitutions = func_wrappers_map[iter.second]; + const auto &substitutions = func_wrappers_map[iter.second.get_contents()]; if (!substitutions.empty()) { iter.second.substitute_calls(substitutions); }