From f9dcde61acc9ab8d96f520e4fda70c345c0631fc Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 6 Jul 2017 15:45:16 -0700 Subject: [PATCH 1/3] Allow self-references in update definitions This makes the leaking worse. --- src/Associativity.cpp | 2 -- src/Func.cpp | 2 -- src/Function.cpp | 53 +++++++++++++++++-------------------------- src/Function.h | 4 ++++ src/Tracing.cpp | 4 +++- src/WrapCalls.cpp | 3 ++- 6 files changed, 30 insertions(+), 38 deletions(-) diff --git a/src/Associativity.cpp b/src/Associativity.cpp index f52bb06a499e..98fbbde6a224 100644 --- a/src/Associativity.cpp +++ b/src/Associativity.cpp @@ -56,8 +56,6 @@ class ConvertSelfRef : public IRMutator { internal_assert(op); if ((op->call_type == Call::Halide) && (func == op->name)) { - internal_assert(!op->func.defined()) - << "Func should not have been defined for a self-reference\n"; internal_assert(args.size() == op->args.size()) << "Self-reference should have the same number of args as the original\n"; for (size_t i = 0; i < op->args.size(); i++) { diff --git a/src/Func.cpp b/src/Func.cpp index d08a48bacbaa..1f84adf966d9 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -365,8 +365,6 @@ class SubstituteSelfReference : public IRMutator { internal_assert(c); if ((c->call_type == Call::Halide) && (func == c->name)) { - internal_assert(!c->func.defined()) - << "func should not have been defined for a self-reference\n"; debug(4) << "...Replace call to Func \"" << c->name << "\" with " << "\"" << substitute.name() << "\"\n"; vector args; diff --git a/src/Function.cpp b/src/Function.cpp index d8cc011540f7..6b03b61e3266 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -49,21 +49,14 @@ struct FunctionContents { std::vector extern_arguments; std::string extern_function_name; - NameMangling extern_mangling; - DeviceAPI extern_function_device_api; - bool extern_uses_old_buffer_t; + NameMangling extern_mangling = NameMangling::Default; + DeviceAPI extern_function_device_api = DeviceAPI::Host; + bool extern_uses_old_buffer_t = false; - bool trace_loads, trace_stores, trace_realizations; + bool trace_loads = false, trace_stores = false, trace_realizations = false; - bool frozen; + bool frozen = false; - FunctionContents() : extern_mangling(NameMangling::Default), - extern_function_device_api(DeviceAPI::Host), - extern_uses_old_buffer_t(false), - trace_loads(false), - trace_stores(false), - trace_realizations(false), - frozen(false) {} void accept(IRVisitor *visitor) const { func_schedule.accept(visitor); @@ -205,25 +198,20 @@ struct CheckVars : public IRGraphVisitor { } }; -struct DeleteSelfReferences : public IRMutator { +struct CountSelfReferences : public IRVisitor { IntrusivePtr func; - // Also count the number of self references so we know if a Func + // Count the number of self references so we know if a Func // has a recursive definition. int count = 0; - using IRMutator::visit; + using IRVisitor::visit; void visit(const Call *c) { - IRMutator::visit(c); - c = expr.as(); - internal_assert(c); if (c->func.same_as(func)) { - expr = Call::make(c->type, c->name, c->args, c->call_type, - nullptr, c->value_index, - c->image, c->param); count++; } + IRVisitor::visit(c); } }; @@ -620,21 +608,18 @@ void Function::define_update(const vector &_args, vector values) { check.reduction_domain.set_predicate(lower_random(check.reduction_domain.predicate(), free_vars, tag)); } - // The update value and args probably refer back to the - // function itself, introducing circular references and hence - // memory leaks. We need to break these cycles. - DeleteSelfReferences deleter; - deleter.func = contents; - deleter.count = 0; + // Check for self-references, so that we can emit a warning for + // shadowed definitions. + CountSelfReferences counter; + counter.func = contents; for (size_t i = 0; i < args.size(); i++) { - args[i] = deleter.mutate(args[i]); + args[i].accept(&counter); } for (size_t i = 0; i < values.size(); i++) { - values[i] = deleter.mutate(values[i]); + values[i].accept(&counter); } if (check.reduction_domain.defined()) { - check.reduction_domain.set_predicate( - deleter.mutate(check.reduction_domain.predicate())); + check.reduction_domain.predicate().accept(&counter); } Definition r(args, values, check.reduction_domain, false); @@ -676,7 +661,7 @@ void Function::define_update(const vector &_args, vector values) { // the args are pure, then this definition completely hides // earlier ones! if (!check.reduction_domain.defined() && - deleter.count == 0 && + counter.count == 0 && pure) { user_warning << "In update definition " << update_idx << " of Func \"" << name() << "\":\n" @@ -737,6 +722,10 @@ void Function::accept(IRVisitor *visitor) const { contents->accept(visitor); } +void Function::mutate(IRMutator *mutator) { + contents->mutate(mutator); +} + const std::string &Function::name() const { return contents->name; } diff --git a/src/Function.h b/src/Function.h index 14447d43cfad..e1204a231f83 100644 --- a/src/Function.h +++ b/src/Function.h @@ -126,6 +126,10 @@ class Function { * of this function. */ EXPORT void accept(IRVisitor *visitor) const; + /** Accept a mutator to mutator all of the definitions and + * arguments of this function. */ + EXPORT void mutate(IRMutator *visitor); + /** Get the name of the function. */ EXPORT const std::string &name() const; diff --git a/src/Tracing.cpp b/src/Tracing.cpp index bef4c59ed7f8..ed90a7026e1c 100644 --- a/src/Tracing.cpp +++ b/src/Tracing.cpp @@ -73,7 +73,9 @@ class InjectTracing : public IRMutator { bool trace_it = false; Expr trace_parent; if (op->call_type == Call::Halide) { - Function f = env.find(op->name)->second; + auto it = env.find(op->name); + internal_assert(it != env.end()) << op->name << " not in environment\n"; + Function f = it->second; internal_assert(!f.can_be_inlined() || !f.schedule().compute_level().is_inline()); trace_it = f.is_tracing_loads() || trace_all_loads; diff --git a/src/WrapCalls.cpp b/src/WrapCalls.cpp index b7f1661e56b2..f7f05e25277f 100644 --- a/src/WrapCalls.cpp +++ b/src/WrapCalls.cpp @@ -74,7 +74,8 @@ map wrap_func_calls(const map &env) { if (in_func.empty()) { // Global wrapper for (const auto &wrapped_env_iter : wrapped_env) { in_func = wrapped_env_iter.first; - if ((wrapper.name() == in_func) || (all_func_wrappers.find(in_func) != all_func_wrappers.end())) { + if ((wrapped_fname == in_func) || + (all_func_wrappers.find(in_func) != all_func_wrappers.end())) { // The wrapper should still call the original function, // so we don't want to rewrite the calls done by the // wrapper. We also shouldn't rewrite the original From 8639c1934c0c3b79f3ade3d8944cd96e7447652c Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 10 Jul 2017 18:07:25 -0700 Subject: [PATCH 2/3] Allocate Functions in groups that share a refcount This avoids reference cycles that cause memory leaks in master. I introduce a FunctionPtr, which may be a weak or a strong reference to a particular Function in a group (using the refcount of the entire group). The programmer must obey several rules, which are not currently mechanically checked. They are: 1) All references within a group must be weak. 2) If there are any references to the group from outside, at least one such reference must be strong. 3) There may be no reference cycles that span multiple groups. Func::in() and recursive update definitions take care to stay within one group and use weak references, to enforce 1 and 3. Calls from one Func to another form a DAG, and are strong to enforce 2. When the deep copy happens I put everything in one big group, because it was the simplest way to implement the deep copy. Deep copy therefore changes all references to weak to enforce rule 1. During lowering, the environment holds the only strong references to any Functions. At the end of lowering, before the environment is discarded, the resulting stmt is made to have strong references, enforcing rule 2. Transfer of Exprs from the Function dag to the lowered stmt is one-way, so rule 3 is also respected. --- Makefile | 3 +- src/Associativity.cpp | 24 +-- src/CMakeLists.txt | 3 +- src/DeepCopy.cpp | 53 ------ src/DeepCopy.h | 24 --- src/FindCalls.cpp | 23 ++- src/FindCalls.h | 4 + src/Func.cpp | 20 +-- src/Function.cpp | 309 +++++++++++++++++++-------------- src/Function.h | 32 ++-- src/FunctionPtr.h | 93 ++++++++++ src/IR.cpp | 2 +- src/IR.h | 12 +- src/InjectImageIntrinsics.cpp | 2 +- src/InjectOpenGLIntrinsics.cpp | 2 +- src/Lower.cpp | 24 ++- src/PrintLoopNest.cpp | 1 - src/Schedule.cpp | 30 ++-- src/Schedule.h | 9 +- src/WrapCalls.cpp | 60 ++++--- 20 files changed, 420 insertions(+), 310 deletions(-) delete mode 100644 src/DeepCopy.cpp delete mode 100644 src/DeepCopy.h create mode 100644 src/FunctionPtr.h diff --git a/Makefile b/Makefile index e3d3e876af36..9b58531d1bcf 100644 --- a/Makefile +++ b/Makefile @@ -292,7 +292,6 @@ SOURCE_FILES = \ Debug.cpp \ DebugArguments.cpp \ DebugToFile.cpp \ - DeepCopy.cpp \ Definition.cpp \ Deinterleave.cpp \ DeviceArgument.cpp \ @@ -422,7 +421,6 @@ HEADER_FILES = \ Debug.h \ DebugArguments.h \ DebugToFile.h \ - DeepCopy.h \ Definition.h \ Deinterleave.h \ DeviceArgument.h \ @@ -439,6 +437,7 @@ HEADER_FILES = \ Float16.h \ Func.h \ Function.h \ + FunctionPtr.h \ FuseGPUThreadLoops.h \ FuzzFloatStores.h \ Generator.h \ diff --git a/src/Associativity.cpp b/src/Associativity.cpp index 98fbbde6a224..5e8bde5e2bd5 100644 --- a/src/Associativity.cpp +++ b/src/Associativity.cpp @@ -541,7 +541,7 @@ void associativity_test() { Expr x = Variable::make(t, "x"); Expr y = Variable::make(t, "y"); Expr x_idx = Variable::make(Int(32), "x_idx"); - Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, nullptr, 0); + Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, FunctionPtr(), 0); // f(x) = uint8(uint16(x + y), 255) check_associativity("f", {x_idx}, {Cast::make(UInt(8), min(Cast::make(UInt(16), y + f_call_0), make_const(t, 255)))}, @@ -586,7 +586,7 @@ void associativity_test() { Expr x = Variable::make(t, "x"); Expr y = Variable::make(t, "y"); Expr x_idx = Variable::make(Int(32), "x_idx"); - Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, nullptr, 0); + Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, FunctionPtr(), 0); // f(x) = y && f(x) check_associativity("f", {x_idx}, {And::make(y, f_call_0)}, @@ -620,11 +620,11 @@ void associativity_test() { zs[i] = Variable::make(t, "z" + std::to_string(i)); } - Expr f_call_0 = Call::make(t, "f", {x}, Call::CallType::Halide, nullptr, 0); - Expr f_call_1 = Call::make(t, "f", {x}, Call::CallType::Halide, nullptr, 1); - Expr f_call_2 = Call::make(t, "f", {x}, Call::CallType::Halide, nullptr, 2); - Expr g_call_0 = Call::make(t, "g", {rx}, Call::CallType::Halide, nullptr, 0); - Expr g_call_1 = Call::make(t, "g", {rx}, Call::CallType::Halide, nullptr, 1); + Expr f_call_0 = Call::make(t, "f", {x}, Call::CallType::Halide, FunctionPtr(), 0); + Expr f_call_1 = Call::make(t, "f", {x}, Call::CallType::Halide, FunctionPtr(), 1); + Expr f_call_2 = Call::make(t, "f", {x}, Call::CallType::Halide, FunctionPtr(), 2); + Expr g_call_0 = Call::make(t, "g", {rx}, Call::CallType::Halide, FunctionPtr(), 0); + Expr g_call_1 = Call::make(t, "g", {rx}, Call::CallType::Halide, FunctionPtr(), 1); // f(x) = f(x) check_associativity("f", {x}, {f_call_0}, @@ -757,11 +757,11 @@ void associativity_test() { { Expr ry = Variable::make(t, "ry"); - Expr f_xy_call_0 = Call::make(t, "f", {x, y}, Call::CallType::Halide, nullptr, 0); - Expr f_xy_call_1 = Call::make(t, "f", {x, y}, Call::CallType::Halide, nullptr, 1); - Expr f_xy_call_2 = Call::make(t, "f", {x, y}, Call::CallType::Halide, nullptr, 2); - Expr f_xy_call_3 = Call::make(t, "f", {x, y}, Call::CallType::Halide, nullptr, 3); - Expr g_xy_call_0 = Call::make(t, "g", {rx, ry}, Call::CallType::Halide, nullptr, 0); + Expr f_xy_call_0 = Call::make(t, "f", {x, y}, Call::CallType::Halide, FunctionPtr(), 0); + Expr f_xy_call_1 = Call::make(t, "f", {x, y}, Call::CallType::Halide, FunctionPtr(), 1); + Expr f_xy_call_2 = Call::make(t, "f", {x, y}, Call::CallType::Halide, FunctionPtr(), 2); + Expr f_xy_call_3 = Call::make(t, "f", {x, y}, Call::CallType::Halide, FunctionPtr(), 3); + Expr g_xy_call_0 = Call::make(t, "g", {rx, ry}, Call::CallType::Halide, FunctionPtr(), 0); // 2D argmin + trivial update: // f(x, y) = Tuple(min(f(x, y)[0], g(r.x, r.y)[0]), diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d62011144980..f5939c60832f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -285,7 +285,6 @@ set(HEADER_FILES Debug.h DebugArguments.h DebugToFile.h - DeepCopy.h Definition.h Deinterleave.h DeviceArgument.h @@ -302,6 +301,7 @@ set(HEADER_FILES Float16.h Func.h Function.h + FunctionPtr.h FuseGPUThreadLoops.h FuzzFloatStores.h Generator.h @@ -438,7 +438,6 @@ add_library(Halide ${HALIDE_LIBRARY_TYPE} Debug.cpp DebugArguments.cpp DebugToFile.cpp - DeepCopy.cpp Definition.cpp Deinterleave.cpp DeviceArgument.cpp diff --git a/src/DeepCopy.cpp b/src/DeepCopy.cpp deleted file mode 100644 index f8762b4865dd..000000000000 --- a/src/DeepCopy.cpp +++ /dev/null @@ -1,53 +0,0 @@ -#include "DeepCopy.h" - -namespace Halide{ -namespace Internal { - -using std::map; -using std::pair; -using std::string; -using std::vector; - -pair, map> deep_copy( - const vector &outputs, const map &env) { - vector copy_outputs; - map copy_env; - - // Create empty deep-copies of all Functions in 'env' - map copied_map; // Original Function -> Deep-copy - for (const auto &iter : env) { - copied_map[iter.second] = Function(iter.second.name()); - } - - // Deep copy all Functions in 'env' into their corresponding empty copies - for (const auto &iter : env) { - iter.second.deep_copy(copied_map[iter.second], copied_map); - } - - // Need to substitute-in all old Function references in all Exprs referenced - // within the Function with the deep-copy versions - for (auto &iter : copied_map) { - iter.second.substitute_calls(copied_map); - } - - // Populate the env with the deep-copy version - for (const auto &iter : copied_map) { - copy_env.emplace(iter.first.name(), iter.second); - } - - for (const auto &func : outputs) { - const auto &iter = copied_map.find(func); - if (iter != copied_map.end()) { - debug(4) << "Adding deep-copied version to outputs: " << func.name() << "\n"; - copy_outputs.push_back(iter->second); - } else { - debug(4) << "Adding original version to outputs: " << func.name() << "\n"; - copy_outputs.push_back(func); - } - } - - return { copy_outputs, copy_env }; -} - -} -} diff --git a/src/DeepCopy.h b/src/DeepCopy.h deleted file mode 100644 index a65fdc5db547..000000000000 --- a/src/DeepCopy.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef HALIDE_DEEP_COPY_H -#define HALIDE_DEEP_COPY_H - -/** \file - * - * Defines pass to create deep-copies of all Functions in 'env'. - */ - -#include - -#include "IR.h" - -namespace Halide { -namespace Internal { - -/** Create deep-copies of all Functions in 'env'. This returns a pair of the - * deep-copied versions of 'outputs' and 'env' */ -std::pair, std::map> deep_copy( - const std::vector &outputs, const std::map &env); - -} -} - -#endif diff --git a/src/FindCalls.cpp b/src/FindCalls.cpp index 2703e95a50a2..667569235722 100644 --- a/src/FindCalls.cpp +++ b/src/FindCalls.cpp @@ -9,6 +9,7 @@ using std::vector; using std::string; using std::pair; +namespace { /* Find all the internal halide calls in an expr */ class FindCalls : public IRVisitor { public: @@ -38,7 +39,8 @@ class FindCalls : public IRVisitor { } }; -void populate_environment(Function f, map &env, bool recursive = true) { +void populate_environment_helper(Function f, map &env, + bool recursive = true, bool include_wrappers = false) { map::const_iterator iter = env.find(f.name()); if (iter != env.end()) { user_assert(iter->second.same_as(f)) @@ -58,26 +60,39 @@ void populate_environment(Function f, map &env, bool recursive } } + if (include_wrappers) { + for (auto it : f.schedule().wrappers()) { + Function g(it.second); + calls.calls[g.name()] = g; + } + } + if (!recursive) { env.insert(calls.calls.begin(), calls.calls.end()); } else { env[f.name()] = f; for (pair i : calls.calls) { - populate_environment(i.second, env); + populate_environment_helper(i.second, env, recursive, include_wrappers); } } } +} + +void populate_environment(Function f, map &env) { + populate_environment_helper(f, env, true, true); +} + map find_transitive_calls(Function f) { map res; - populate_environment(f, res, true); + populate_environment_helper(f, res, true, false); return res; } map find_direct_calls(Function f) { map res; - populate_environment(f, res, false); + populate_environment_helper(f, res, false, false); return res; } diff --git a/src/FindCalls.h b/src/FindCalls.h index e1b420dd6be9..9b462116b611 100644 --- a/src/FindCalls.h +++ b/src/FindCalls.h @@ -28,6 +28,10 @@ std::map find_direct_calls(Function f); */ std::map find_transitive_calls(Function f); +/** Find all Functions transitively referenced by f in any way and add + * them to the given map. */ +void populate_environment(Function f, std::map &env); + } } diff --git a/src/Func.cpp b/src/Func.cpp index 1f84adf966d9..793bc7ab9a4b 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -870,7 +870,7 @@ Func Stage::rfactor(vector> preserved) { if (!prover_result.xs[i].var.empty()) { Expr prev_val = Call::make(intm.output_types()[i], func_name, f_store_args, Call::CallType::Halide, - nullptr, i); + FunctionPtr(), i); replacements.emplace(prover_result.xs[i].var, prev_val); } else { user_warning << "Update definition of " << stage_name << " at index " << i @@ -1740,16 +1740,16 @@ void Func::invalidate_cache() { Func Func::in(const Func &f) { invalidate_cache(); user_assert(name() != f.name()) << "Cannot call 'in()' on itself\n"; - const map> &wrappers = func.wrappers(); + const map &wrappers = func.wrappers(); const auto &iter = wrappers.find(f.name()); if (iter == wrappers.end()) { - Func wrapper(name() + "_in_" + f.name()); + Func wrapper(func.new_function_in_same_group(name() + "_in_" + f.name())); wrapper(args()) = (*this)(args()); func.add_wrapper(f.name(), wrapper.func); return wrapper; } - IntrusivePtr wrapper_contents = iter->second; + FunctionPtr wrapper_contents = iter->second; internal_assert(wrapper_contents.defined()); // Make sure that no other Func shares the same wrapper as 'f' @@ -1774,7 +1774,7 @@ Func Func::in(const vector& fs) { // Either all Funcs have the same wrapper or they don't already have any wrappers. // Otherwise, throw an error. - const map> &wrappers = func.wrappers(); + const map &wrappers = func.wrappers(); const auto &iter = wrappers.find(fs[0].name()); if (iter == wrappers.end()) { @@ -1784,7 +1784,7 @@ Func Func::in(const vector& fs) { << "Cannot define the wrapper since " << fs[i].name() << " already has a wrapper while " << fs[0].name() << " doesn't \n"; } - Func wrapper(name() + "_wrapper"); + Func wrapper(func.new_function_in_same_group(name() + "_wrapper")); wrapper(args()) = (*this)(args()); for (const Func &f : fs) { user_assert(name() != f.name()) << "Cannot call 'in()' on itself\n"; @@ -1793,7 +1793,7 @@ Func Func::in(const vector& fs) { return wrapper; } - IntrusivePtr wrapper_contents = iter->second; + FunctionPtr wrapper_contents = iter->second; internal_assert(wrapper_contents.defined()); // Make sure all the other Funcs in 'fs' share the same wrapper and no other @@ -1823,16 +1823,16 @@ Func Func::in(const vector& fs) { Func Func::in() { invalidate_cache(); - const map> &wrappers = func.wrappers(); + const map &wrappers = func.wrappers(); const auto &iter = wrappers.find(""); if (iter == wrappers.end()) { - Func wrapper(name() + "_global_wrapper"); + Func wrapper(func.new_function_in_same_group(name() + "_global_wrapper")); wrapper(args()) = (*this)(args()); func.add_wrapper("", wrapper.func); return wrapper; } - IntrusivePtr wrapper_contents = iter->second; + FunctionPtr wrapper_contents = iter->second; internal_assert(wrapper_contents.defined()); Function wrapper(wrapper_contents); internal_assert(wrapper.frozen()); diff --git a/src/Function.cpp b/src/Function.cpp index 6b03b61e3266..a021b5b3545f 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include "IR.h" #include "IRMutator.h" @@ -21,18 +22,40 @@ using std::vector; using std::string; using std::set; using std::map; +using std::pair; -typedef map, IntrusivePtr> DeepCopyMap; -ExternFuncArgument deep_copy_extern_func_argument_helper(const ExternFuncArgument &src, - DeepCopyMap &copied_map); -void deep_copy_function_contents_helper(const IntrusivePtr &src, - IntrusivePtr &dst, - DeepCopyMap &copied_map); -IntrusivePtr deep_copy_function_contents_helper( - const IntrusivePtr &src, DeepCopyMap &copied_map); +typedef map DeepCopyMap; + +struct FunctionContents; + +namespace { +// Weaken all the references to a particular Function to break +// reference cycles. Also count the number of references found. +class WeakenFunctionPtrs : public IRMutator { + using IRMutator::visit; + + void visit(const Call *c) { + IRMutator::visit(c); + c = expr.as(); + internal_assert(c); + if (c->func.defined() && + c->func.get() == func) { + FunctionPtr ptr = c->func; + ptr.weaken(); + expr = Call::make(c->type, c->name, c->args, c->call_type, + ptr, c->value_index, + c->image, c->param); + count++; + } + } + FunctionContents *func; +public: + int count = 0; + WeakenFunctionPtrs(FunctionContents *f) : func(f) {} +}; +} struct FunctionContents { - mutable RefCount ref_count; std::string name; std::vector output_types; @@ -57,7 +80,6 @@ struct FunctionContents { bool frozen = false; - void accept(IRVisitor *visitor) const { func_schedule.accept(visitor); @@ -112,13 +134,22 @@ struct FunctionContents { } }; +struct FunctionGroup { + mutable RefCount ref_count; + vector members; +}; + +FunctionContents *FunctionPtr::get() const { + return &(group()->members[idx]); +} + template<> -EXPORT RefCount &ref_count(const FunctionContents *f) { +EXPORT RefCount &ref_count(const FunctionGroup *f) { return f->ref_count; } template<> -EXPORT void destroy(const FunctionContents *f) { +EXPORT void destroy(const FunctionGroup *f) { delete f; } @@ -198,23 +229,6 @@ struct CheckVars : public IRGraphVisitor { } }; -struct CountSelfReferences : public IRVisitor { - IntrusivePtr func; - - // Count the number of self references so we know if a Func - // has a recursive definition. - int count = 0; - - using IRVisitor::visit; - - void visit(const Call *c) { - if (c->func.same_as(func)) { - count++; - } - IRVisitor::visit(c); - } -}; - // Mark all functions found in an expr as frozen. class FreezeFunctions : public IRGraphVisitor { using IRGraphVisitor::visit; @@ -239,21 +253,24 @@ namespace { static std::atomic rand_counter; } -Function::Function() : contents(new FunctionContents) { +Function::Function() { } -Function::Function(const IntrusivePtr &ptr) : contents(ptr) { +Function::Function(const FunctionPtr &ptr) : contents(ptr) { + contents.strengthen(); internal_assert(ptr.defined()) << "Can't construct Function from undefined FunctionContents ptr\n"; } -Function::Function(const std::string &n) : contents(new FunctionContents) { +Function::Function(const std::string &n) { for (size_t i = 0; i < n.size(); i++) { user_assert(n[i] != '.') << "Func name \"" << n << "\" is invalid. " << "Func names may not contain the character '.', " << "as it is used internally by Halide as a separator\n"; } + contents.strong = new FunctionGroup; + contents.strong->members.resize(1); contents->name = n; } @@ -275,96 +292,52 @@ ExternFuncArgument deep_copy_extern_func_argument_helper( // If the FunctionContents has already been deep-copied previously, i.e. // it's in the 'copied_map', use the deep-copied version from the map instead // of creating a new deep-copy - IntrusivePtr &copied_func = copied_map[src.func]; - if (copied_func.defined()) { - copy.func = copied_func; - } else { - copy.func = deep_copy_function_contents_helper(src.func, copied_map); - copied_map[src.func] = copy.func; - } + FunctionPtr &copied_func = copied_map[src.func]; + internal_assert(copied_func.defined()); + copy.func = copied_func; return copy; } -// Return a deep-copy of FunctionContents 'src' -IntrusivePtr deep_copy_function_contents_helper( - const IntrusivePtr &src, DeepCopyMap &copied_map) { - - IntrusivePtr copy(new FunctionContents); - deep_copy_function_contents_helper(src, copy, copied_map); - return copy; -} +void Function::deep_copy(FunctionPtr copy, DeepCopyMap &copied_map) const { + internal_assert(copy.defined() && contents.defined()) + << "Cannot deep-copy undefined Function\n"; -// Return a deep-copy of FunctionContents 'src' -void deep_copy_function_contents_helper(const IntrusivePtr &src, - IntrusivePtr &dst, - DeepCopyMap &copied_map) { - debug(4) << "Deep-copy function contents: \"" << src->name << "\"\n"; - - internal_assert(dst.defined() && src.defined()) << "Cannot deep-copy undefined FunctionContents\n"; - - dst->name = src->name; - dst->output_types = src->output_types; - dst->debug_file = src->debug_file; - dst->extern_function_name = src->extern_function_name; - dst->extern_mangling = src->extern_mangling; - dst->extern_function_device_api = src->extern_function_device_api; - dst->extern_uses_old_buffer_t = src->extern_uses_old_buffer_t; - dst->trace_loads = src->trace_loads; - dst->trace_stores = src->trace_stores; - dst->trace_realizations = src->trace_realizations; - dst->frozen = src->frozen; - dst->output_buffers = src->output_buffers; - dst->func_schedule = src->func_schedule.deep_copy(copied_map); + // Add reference to this Function's deep-copy to the map in case of + // self-reference, e.g. self-reference in an Definition. + copied_map[contents] = copy; + + debug(4) << "Deep-copy function contents: \"" << contents->name << "\"\n"; + + copy->name = contents->name; + copy->output_types = contents->output_types; + copy->debug_file = contents->debug_file; + copy->extern_function_name = contents->extern_function_name; + copy->extern_mangling = contents->extern_mangling; + copy->extern_function_device_api = contents->extern_function_device_api; + copy->extern_uses_old_buffer_t = contents->extern_uses_old_buffer_t; + copy->trace_loads = contents->trace_loads; + copy->trace_stores = contents->trace_stores; + copy->trace_realizations = contents->trace_realizations; + copy->frozen = contents->frozen; + copy->output_buffers = contents->output_buffers; + copy->func_schedule = contents->func_schedule.deep_copy(copied_map); // Copy the pure definition - dst->init_def = src->init_def.get_copy(); - internal_assert(dst->init_def.is_init()); - internal_assert(dst->init_def.schedule().rvars().empty()) + copy->init_def = contents->init_def.get_copy(); + internal_assert(copy->init_def.is_init()); + internal_assert(copy->init_def.schedule().rvars().empty()) << "Init definition shouldn't have reduction domain\n"; - for (const Definition &def : src->updates) { + for (const Definition &def : contents->updates) { internal_assert(!def.is_init()); Definition def_copy = def.get_copy(); internal_assert(!def_copy.is_init()); - dst->updates.push_back(std::move(def_copy)); + copy->updates.push_back(std::move(def_copy)); } - for (const ExternFuncArgument &e : src->extern_arguments) { + for (const ExternFuncArgument &e : contents->extern_arguments) { ExternFuncArgument e_copy = deep_copy_extern_func_argument_helper(e, copied_map); - dst->extern_arguments.push_back(std::move(e_copy)); - } -} - -void Function::deep_copy(Function ©, - std::map &copied_map) const { - internal_assert(copy.contents.defined() && contents.defined()) - << "Cannot deep-copy undefined Function\n"; - // Need to copy over the contents of Functions in 'copied_map' since - // deep_copy_function_contents_helper() takes a map of - // (DeepCopyMap) - DeepCopyMap copied_funcs_map; - for (const auto &iter : copied_map) { - copied_funcs_map[iter.first.contents] = iter.second.contents; - } - // Add reference to this Function's deep-copy to the map in case of - // self-reference, e.g. self-reference in an Definition. - copied_funcs_map[contents] = copy.contents; - - // Perform the deep-copies - deep_copy_function_contents_helper(contents, copy.contents, copied_funcs_map); - - // Copy over all new deep-copies of FunctionContents into 'copied_map'. - for (const auto &iter : copied_funcs_map) { - Function old_func = Function(iter.first); - if (copied_map.count(old_func)) { - // Need to make sure that deep_copy_function_contents_helper() uses - // the already existing deep-copy of FunctionContents instead of - // creating a new deep-copy - internal_assert(copied_map[old_func].contents.same_as(iter.second)) - << old_func.name() << " is deep-copied twice\n"; - continue; - } - copied_map[old_func] = Function(iter.second); + copy->extern_arguments.push_back(std::move(e_copy)); } } @@ -426,7 +399,8 @@ void Function::define(const vector &args, vector values) { << "Reduction domain referenced in pure function definition.\n"; if (!contents.defined()) { - contents = new FunctionContents; + contents.strong = new FunctionGroup; + contents.strong->members.resize(1); contents->name = unique_name('f'); } @@ -608,18 +582,19 @@ void Function::define_update(const vector &_args, vector values) { check.reduction_domain.set_predicate(lower_random(check.reduction_domain.predicate(), free_vars, tag)); } - // Check for self-references, so that we can emit a warning for - // shadowed definitions. - CountSelfReferences counter; - counter.func = contents; + // The update value and args probably refer back to the + // function itself, introducing circular references and hence + // memory leaks. We need to break these cycles. + WeakenFunctionPtrs weakener(contents.get()); for (size_t i = 0; i < args.size(); i++) { - args[i].accept(&counter); + args[i] = weakener.mutate(args[i]); } for (size_t i = 0; i < values.size(); i++) { - values[i].accept(&counter); + values[i] = weakener.mutate(values[i]); } if (check.reduction_domain.defined()) { - check.reduction_domain.predicate().accept(&counter); + check.reduction_domain.set_predicate( + weakener.mutate(check.reduction_domain.predicate())); } Definition r(args, values, check.reduction_domain, false); @@ -661,7 +636,7 @@ void Function::define_update(const vector &_args, vector values) { // the args are pure, then this definition completely hides // earlier ones! if (!check.reduction_domain.defined() && - counter.count == 0 && + weakener.count == 0 && pure) { user_warning << "In update definition " << update_idx << " of Func \"" << name() << "\":\n" @@ -884,13 +859,31 @@ bool Function::frozen() const { return contents->frozen; } -const map> &Function::wrappers() const { +const map &Function::wrappers() const { return contents->func_schedule.wrappers(); } +Function Function::new_function_in_same_group(const std::string &f) { + int group_size = (int)(contents.group()->members.size()); + contents.group()->members.resize(group_size+1); + contents.group()->members[group_size].name = f; + FunctionPtr ptr; + ptr.strong = contents.group(); + ptr.idx = group_size; + return Function(ptr); +} + void Function::add_wrapper(const std::string &f, Function &wrapper) { wrapper.freeze(); - contents->func_schedule.add_wrapper(f, wrapper.contents); + FunctionPtr ptr = wrapper.contents; + + // Weaken the pointer from the function to its wrapper + ptr.weaken(); + contents->func_schedule.add_wrapper(f, ptr); + + // Weaken the pointer from the wrapper back to the function. + WeakenFunctionPtrs weakener(contents.get()); + wrapper.mutate(&weakener); } namespace { @@ -899,22 +892,27 @@ namespace { class SubstituteCalls : public IRMutator { using IRMutator::visit; - map substitutions; + map substitutions; void visit(const Call *c) { IRMutator::visit(c); c = expr.as(); internal_assert(c); - if ((c->call_type == Call::Halide) && c->func.defined() && substitutions.count(Function(c->func))) { - const Function &subs = substitutions[Function(c->func)]; + if ((c->call_type == Call::Halide) && + c->func.defined() && + substitutions.count(c->func)) { + FunctionPtr subs = substitutions[c->func]; + internal_assert(subs.defined()) << "Function not in environment: " << subs->name << "\n"; debug(4) << "...Replace call to Func \"" << c->name << "\" with " - << "\"" << subs.name() << "\"\n"; - expr = Call::make(subs, c->args, c->value_index); + << "\"" << subs->name << "\"\n"; + expr = Call::make(c->type, subs->name, c->args, c->call_type, + subs, c->value_index, + c->image, c->param); } } public: - SubstituteCalls(const map &substitutions) + SubstituteCalls(const map &substitutions) : substitutions(substitutions) {} }; @@ -935,9 +933,8 @@ class SubstituteScheduleParamExprs : public IRMutator { } // anonymous namespace -Function &Function::substitute_calls(const map &substitutions) { +Function &Function::substitute_calls(const map &substitutions) { debug(4) << "Substituting calls in " << name() << "\n"; - if (substitutions.empty()) { return *this; } @@ -947,8 +944,8 @@ Function &Function::substitute_calls(const map &sub } Function &Function::substitute_calls(const Function &orig, const Function &substitute) { - map substitutions; - substitutions.emplace(orig, substitute); + map substitutions; + substitutions.emplace(orig.get_contents(), substitute.get_contents()); return substitute_calls(substitutions); } @@ -958,5 +955,59 @@ Function &Function::substitute_schedule_param_exprs() { return *this; } +// Deep copy an entire Function DAG. +pair, map> deep_copy( + const vector &outputs, const map &env) { + vector copy_outputs; + map copy_env; + + // Create empty deep-copies of all Functions in 'env' + DeepCopyMap copied_map; // Original Function -> Deep-copy + IntrusivePtr group(new FunctionGroup); + group->members.resize(env.size()); + int i = 0; + for (const auto &iter : env) { + // Make a weak pointer to the function to use for within-group references. + FunctionPtr ptr; + ptr.weak = group.get(); + ptr.idx = i; + ptr->name = iter.second.name(); + copied_map[iter.second.get_contents()] = ptr; + i++; + } + + // Deep copy all Functions in 'env' into their corresponding empty copies + for (const auto &iter : env) { + iter.second.deep_copy(copied_map[iter.second.get_contents()], copied_map); + } + + // Need to substitute-in all old Function references in all Exprs referenced + // within the Function with the deep-copy versions + for (auto &iter : copied_map) { + Function(iter.second).substitute_calls(copied_map); + } + + // Populate the env with the deep-copy version + for (const auto &iter : copied_map) { + FunctionPtr ptr = iter.second; + copy_env.emplace(iter.first->name, Function(ptr)); + } + + for (const auto &func : outputs) { + const auto &iter = copied_map.find(func.get_contents()); + if (iter != copied_map.end()) { + FunctionPtr ptr = iter->second; + debug(4) << "Adding deep-copied version to outputs: " << func.name() << "\n"; + copy_outputs.push_back(Function(ptr)); + } else { + debug(4) << "Adding original version to outputs: " << func.name() << "\n"; + copy_outputs.push_back(func); + } + } + + return { copy_outputs, copy_env }; +} + + } } diff --git a/src/Function.h b/src/Function.h index e1204a231f83..6485a9cf6211 100644 --- a/src/Function.h +++ b/src/Function.h @@ -7,6 +7,7 @@ #include "Expr.h" #include "IntrusivePtr.h" +#include "FunctionPtr.h" #include "Parameter.h" #include "Schedule.h" #include "Reduction.h" @@ -17,21 +18,17 @@ namespace Halide { -namespace Internal { -struct FunctionContents; -} - /** An argument to an extern-defined Func. May be a Function, Buffer, * ImageParam or Expr. */ struct ExternFuncArgument { enum ArgType {UndefinedArg = 0, FuncArg, BufferArg, ExprArg, ImageParamArg}; ArgType arg_type; - Internal::IntrusivePtr func; + Internal::FunctionPtr func; Buffer<> buffer; Expr expr; Internal::Parameter image_param; - ExternFuncArgument(Internal::IntrusivePtr f): arg_type(FuncArg), func(f) {} + ExternFuncArgument(Internal::FunctionPtr f): arg_type(FuncArg), func(f) {} template ExternFuncArgument(Buffer b): arg_type(BufferArg), buffer(b) {} @@ -66,7 +63,7 @@ namespace Internal { * syntactic sugar to help with definitions. */ class Function { - IntrusivePtr contents; + FunctionPtr contents; public: /** This lets you use a Function as a key in a map of the form @@ -88,11 +85,11 @@ class Function { EXPORT explicit Function(const std::string &n); /** Construct a Function from an existing FunctionContents pointer. Must be non-null */ - EXPORT explicit Function(const IntrusivePtr &); + EXPORT explicit Function(const FunctionPtr &); /** Get a handle on the halide function contents that this Function * represents. */ - IntrusivePtr get_contents() const { + FunctionPtr get_contents() const { return contents; } @@ -104,7 +101,7 @@ class Function { * creating a new deep-copy to avoid creating deep-copies of the same Function * multiple times. */ - EXPORT void deep_copy(Function ©, std::map &copied_map) const; + EXPORT void deep_copy(FunctionPtr copy, std::map &copied_map) const; /** Add a pure definition to this function. It may not already * have a definition. All the free variables in 'value' must @@ -267,6 +264,12 @@ class Function { * add new definitions. */ EXPORT bool frozen() const; + /** Make a new Function with the same lifetime as this one, and + * return a strong reference to it. Useful to create Functions which + * have circular references to this one - e.g. the wrappers + * produced by Func::in. */ + Function new_function_in_same_group(const std::string &); + /** Mark calls of this function by 'f' to be replaced with its wrapper * during the lowering stage. If the string 'f' is empty, it means replace * all calls to this function by all other functions (excluding itself) in @@ -275,14 +278,14 @@ class Function { * See \ref Func::in for more details. */ // @{ EXPORT void add_wrapper(const std::string &f, Function &wrapper); - EXPORT const std::map> &wrappers() const; + EXPORT const std::map &wrappers() const; // @} /** Replace every call to Functions in 'substitutions' keys by all Exprs * referenced in this Function to call to their substitute Functions (i.e. * the corresponding values in 'substitutions' map). */ // @{ - EXPORT Function &substitute_calls(const std::map &substitutions); + EXPORT Function &substitute_calls(const std::map &substitutions); EXPORT Function &substitute_calls(const Function &orig, const Function &substitute); // @} @@ -291,6 +294,11 @@ class Function { EXPORT Function &substitute_schedule_param_exprs(); }; +/** Deep copy an entire Function DAG. */ +std::pair, std::map> deep_copy( + const std::vector &outputs, + const std::map &env); + }} #endif diff --git a/src/FunctionPtr.h b/src/FunctionPtr.h new file mode 100644 index 000000000000..a20998759207 --- /dev/null +++ b/src/FunctionPtr.h @@ -0,0 +1,93 @@ +#ifndef HALIDE_FUNCTION_PTR_H +#define HALIDE_FUNCTION_PTR_H + +#include "IntrusivePtr.h" + +namespace Halide { +namespace Internal { + +/** Functions are allocated in groups for memory management. Each + * group has a ref count associated with it. All within-group + * references must be weak. If there are any references from outside + * the group, at least one must be strong. Within-group references + * may form cycles, but there may not be reference cycles that span + * multiple groups. These rules are not enforced automatically. */ +struct FunctionGroup; + +/** The opaque struct describing a Halide function. Wrap it in a + * Function object to access it. */ +struct FunctionContents; + +/** A possibly-weak pointer to a Halide function. Take care to follow + * the rules mentioned above. Preserves weakness/strength on copy. + * + * Note that Function objects are always strong pointers to Halide + * functions. + */ +struct FunctionPtr { + /** A strong and weak pointer to the group. Only one of these + * should be non-zero. */ + // @{ + IntrusivePtr strong; + FunctionGroup *weak = nullptr; + // @} + + /** The index of the function within the group. */ + int idx = 0; + + /** Get a pointer to the group this Function belongs to. */ + FunctionGroup *group() const { + return weak ? weak : strong.get(); + } + + /** Get the opaque FunctionContents object this pointer refers + * to. Wrap it in a Function to do anything interesting with it. */ + // @{ + FunctionContents *get() const; + + FunctionContents &operator*() const { + return *get(); + } + + FunctionContents *operator->() const { + return get(); + } + // @} + + /** Convert from a strong reference to a weak reference. Does + * nothing if the pointer is undefined, or if the reference is + * already weak. */ + void weaken() { + weak = group(); + strong = nullptr; + } + + /** Convert from a weak reference to a strong reference. Does + * nothing if the pointer is undefined, or if the reference is + * already strong. */ + void strengthen() { + strong = group(); + weak = nullptr; + } + + /** Check if the reference is defined. */ + bool defined() const { + return weak || strong.defined(); + } + + /** Check if two FunctionPtrs refer to the same Function. */ + bool same_as(const FunctionPtr &other) const { + return idx == other.idx && group() == other.group(); + } + + /** Pointer comparison, for using FunctionPtrs as keys in maps and + * sets. */ + bool operator<(const FunctionPtr &other) const { + return get() < other.get(); + } +}; + +} +} + +#endif diff --git a/src/IR.cpp b/src/IR.cpp index b8be68d409ed..cf9db3a88c9d 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -535,7 +535,7 @@ Expr Call::make(Function func, const std::vector &args, int idx) { } Expr Call::make(Type type, const std::string &name, const std::vector &args, CallType call_type, - IntrusivePtr func, int value_index, + FunctionPtr func, int value_index, Buffer<> image, Parameter param) { if (name == Call::prefetch && call_type == Call::Intrinsic) { internal_assert(args.size() % 2 == 0) diff --git a/src/IR.h b/src/IR.h index 9a79961c0a93..ee79a00a91dc 100644 --- a/src/IR.h +++ b/src/IR.h @@ -540,10 +540,8 @@ struct Call : public ExprNode { trace; // If it's a call to another halide function, this call node holds - // onto a pointer to that function for the purposes of reference - // counting only. Self-references in update definitions do not - // have this set, to avoid cycles. - IntrusivePtr func; + // a possibly-weak reference to that function. + FunctionPtr func; // If that function has multiple values, which value does this // call node refer to? @@ -558,7 +556,7 @@ struct Call : public ExprNode { Parameter param; EXPORT static Expr make(Type type, const std::string &name, const std::vector &args, CallType call_type, - IntrusivePtr func = nullptr, int value_index = 0, + FunctionPtr func = FunctionPtr(), int value_index = 0, Buffer<> image = Buffer<>(), Parameter param = Parameter()); /** Convenience constructor for calls to other halide functions */ @@ -566,12 +564,12 @@ struct Call : public ExprNode { /** Convenience constructor for loads from concrete images */ static Expr make(Buffer<> image, const std::vector &args) { - return make(image.type(), image.name(), args, Image, nullptr, 0, image, Parameter()); + return make(image.type(), image.name(), args, Image, FunctionPtr(), 0, image, Parameter()); } /** Convenience constructor for loads from images parameters */ static Expr make(Parameter param, const std::vector &args) { - return make(param.type(), param.name(), args, Image, nullptr, 0, Buffer<>(), param); + return make(param.type(), param.name(), args, Image, FunctionPtr(), 0, Buffer<>(), param); } /** Check if a call node is pure within a pipeline, meaning that diff --git a/src/InjectImageIntrinsics.cpp b/src/InjectImageIntrinsics.cpp index 309038641439..d4f31e95c872 100644 --- a/src/InjectImageIntrinsics.cpp +++ b/src/InjectImageIntrinsics.cpp @@ -119,7 +119,7 @@ class InjectImageIntrinsics : public IRMutator { Call::image_load, args, Call::PureIntrinsic, - nullptr, + FunctionPtr(), 0, call->image, call->param); diff --git a/src/InjectOpenGLIntrinsics.cpp b/src/InjectOpenGLIntrinsics.cpp index d61b67a4ce32..80cb66d47c69 100644 --- a/src/InjectOpenGLIntrinsics.cpp +++ b/src/InjectOpenGLIntrinsics.cpp @@ -71,7 +71,7 @@ class InjectOpenGLIntrinsics : public IRMutator { expr = Call::make(call->type, Call::glsl_texture_load, vector(&args[0], &args[5]), - Call::Intrinsic, nullptr, 0, + Call::Intrinsic, FunctionPtr(), 0, call->image, call->param); } else if (call->is_intrinsic(Call::image_store)) { user_assert(call->args.size() == 6) diff --git a/src/Lower.cpp b/src/Lower.cpp index 4429bf14adb6..f019b7e6ac96 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -15,7 +15,6 @@ #include "Debug.h" #include "DebugArguments.h" #include "DebugToFile.h" -#include "DeepCopy.h" #include "Deinterleave.h" #include "EarlyFree.h" #include "FindCalls.h" @@ -83,8 +82,7 @@ Module lower(const vector &output_funcs, const string &pipeline_name, // Compute an environment map env; for (Function f : output_funcs) { - map more_funcs = find_transitive_calls(f); - env.insert(more_funcs.begin(), more_funcs.end()); + populate_environment(f, env); } // Create a deep-copy of the entire graph of Funcs. @@ -379,6 +377,26 @@ Module lower(const vector &output_funcs, const string &pipeline_name, } } + // We're about to drop the environment and outputs vector, which + // contain the only strong refs to Functions that may still be + // pointed to by the IR. So make those refs strong. + class StrengthenRefs : public IRMutator { + using IRMutator::visit; + void visit(const Call *c) { + IRMutator::visit(c); + c = expr.as(); + internal_assert(c); + if (c->func.defined()) { + FunctionPtr ptr = c->func; + ptr.strengthen(); + expr = Call::make(c->type, c->name, c->args, c->call_type, + ptr, c->value_index, + c->image, c->param); + } + } + }; + s = StrengthenRefs().mutate(s); + LoweredFunc main_func(pipeline_name, public_args, s, linkage_type); // If we're in debug mode, add code that prints the args. diff --git a/src/PrintLoopNest.cpp b/src/PrintLoopNest.cpp index 0ef1c6680785..bb4fa211efc5 100644 --- a/src/PrintLoopNest.cpp +++ b/src/PrintLoopNest.cpp @@ -1,5 +1,4 @@ #include "PrintLoopNest.h" -#include "DeepCopy.h" #include "FindCalls.h" #include "Function.h" #include "Func.h" diff --git a/src/Schedule.cpp b/src/Schedule.cpp index 68e8162c9de8..0b9a790f7d6a 100644 --- a/src/Schedule.cpp +++ b/src/Schedule.cpp @@ -113,11 +113,7 @@ bool LoopLevel::operator==(const LoopLevel &other) const { namespace Internal { -typedef std::map, IntrusivePtr> DeepCopyMap; - -IntrusivePtr deep_copy_function_contents_helper( - const IntrusivePtr &src, - DeepCopyMap &copied_map); +typedef std::map DeepCopyMap; /** A schedule for a halide function, which defines where, when, and * how it should be evaluated. */ @@ -127,7 +123,7 @@ struct FuncScheduleContents { LoopLevel store_level, compute_level; std::vector storage_dims; std::vector bounds; - std::map> wrappers; + std::map wrappers; bool memoized; FuncScheduleContents() : @@ -215,7 +211,7 @@ EXPORT void destroy(const StageScheduleContents *p) { FuncSchedule::FuncSchedule() : contents(new FuncScheduleContents) {} FuncSchedule FuncSchedule::deep_copy( - std::map, IntrusivePtr> &copied_map) const { + std::map &copied_map) const { internal_assert(contents.defined()) << "Cannot deep-copy undefined FuncSchedule\n"; FuncSchedule copy; @@ -225,17 +221,11 @@ FuncSchedule FuncSchedule::deep_copy( copy.contents->bounds = contents->bounds; copy.contents->memoized = contents->memoized; - // Deep-copy wrapper functions. If function has already been deep-copied before, - // i.e. it's in the 'copied_map', use the deep-copied version from the map instead - // of creating a new deep-copy + // Deep-copy wrapper functions. for (const auto &iter : contents->wrappers) { - IntrusivePtr &copied_func = copied_map[iter.second]; - if (copied_func.defined()) { - copy.contents->wrappers[iter.first] = copied_func; - } else { - copy.contents->wrappers[iter.first] = deep_copy_function_contents_helper(iter.second, copied_map); - copied_map[iter.second] = copy.contents->wrappers[iter.first]; - } + FunctionPtr &copied_func = copied_map[iter.second]; + internal_assert(copied_func.defined()) << Function(iter.second).name() << "\n"; + copy.contents->wrappers[iter.first] = copied_func; } internal_assert(copy.contents->wrappers.size() == contents->wrappers.size()); return copy; @@ -265,16 +255,16 @@ const std::vector &FuncSchedule::bounds() const { return contents->bounds; } -std::map> &FuncSchedule::wrappers() { +std::map &FuncSchedule::wrappers() { return contents->wrappers; } -const std::map> &FuncSchedule::wrappers() const { +const std::map &FuncSchedule::wrappers() const { return contents->wrappers; } void FuncSchedule::add_wrapper(const std::string &f, - const IntrusivePtr &wrapper) { + const Internal::FunctionPtr &wrapper) { if (contents->wrappers.count(f)) { if (f.empty()) { user_warning << "Replacing previous definition of global wrapper in function \"" diff --git a/src/Schedule.h b/src/Schedule.h index c0bee3140b8c..9aed9d998c96 100644 --- a/src/Schedule.h +++ b/src/Schedule.h @@ -7,6 +7,7 @@ #include "Expr.h" #include "Parameter.h" +#include "FunctionPtr.h" #include @@ -258,7 +259,7 @@ class FuncSchedule { * same FunctionContents multiple times. */ EXPORT FuncSchedule deep_copy( - std::map, IntrusivePtr> &copied_map) const; + std::map &copied_map) const; /** This flag is set to true if the schedule is memoized. */ // @{ @@ -288,10 +289,10 @@ class FuncSchedule { * all calls to the function by all other functions (excluding itself) in * the pipeline with the wrapper. See \ref Func::in for more details. */ // @{ - const std::map> &wrappers() const; - std::map> &wrappers(); + const std::map &wrappers() const; + std::map &wrappers(); EXPORT void add_wrapper(const std::string &f, - const IntrusivePtr &wrapper); + const Internal::FunctionPtr &wrapper); // @} /** At what sites should we inject the allocation and the diff --git a/src/WrapCalls.cpp b/src/WrapCalls.cpp index f7f05e25277f..9deecb6d49cf 100644 --- a/src/WrapCalls.cpp +++ b/src/WrapCalls.cpp @@ -11,30 +11,36 @@ using std::set; using std::string; using std::vector; -typedef map SubstitutionMap; +typedef map SubstitutionMap; namespace { -void insert_func_wrapper_helper(map &func_wrappers_map, - const Function &in_func, const Function &wrapped_func, - const Function &wrapper) { - internal_assert(in_func.get_contents().defined() && wrapped_func.get_contents().defined() && - wrapper.get_contents().defined()); +void insert_func_wrapper_helper(map &func_wrappers_map, + FunctionPtr in_func, + FunctionPtr wrapped_func, + FunctionPtr wrapper) { + internal_assert(in_func.defined() && + wrapped_func.defined() && + wrapper.defined()); internal_assert(func_wrappers_map[in_func].count(wrapped_func) == 0) << "Should only have one wrapper for each function call in a Func\n"; SubstitutionMap &wrappers_map = func_wrappers_map[in_func]; for (auto iter = wrappers_map.begin(); iter != wrappers_map.end(); ++iter) { if (iter->second.same_as(wrapped_func)) { - debug(4) << "Merging wrapper of " << in_func.name() << " [" << iter->first.name() - << ", " << iter->second.name() << "] with [" << wrapped_func.name() << ", " - << wrapper.name() << "]\n"; + debug(4) << "Merging wrapper of " << Function(in_func).name() + << " [" << Function(iter->first).name() + << ", " << Function(iter->second).name() + << "] with [" << Function(wrapped_func).name() << ", " + << Function(wrapper).name() << "]\n"; iter->second = wrapper; return; } else if (wrapper.same_as(iter->first)) { - debug(4) << "Merging wrapper of " << in_func.name() << " [" << wrapped_func.name() - << ", " << wrapper.name() << "] with [" << iter->first.name() << ", " - << iter->second.name() << "]\n"; + debug(4) << "Merging wrapper of " << Function(in_func).name() + << " [" << Function(wrapped_func).name() + << ", " << Function(wrapper).name() + << "] with [" << Function(iter->first).name() + << ", " << Function(iter->second).name() << "]\n"; wrappers_map.emplace(wrapped_func, iter->second); wrappers_map.erase(iter); return; @@ -48,17 +54,17 @@ void insert_func_wrapper_helper(map wrap_func_calls(const map &env) { map wrapped_env; - map func_wrappers_map; // In Func -> [wrapped Func -> wrapper] + map func_wrappers_map; // In Func -> [wrapped Func -> wrapper] for (const auto &iter : env) { wrapped_env.emplace(iter.first, iter.second); - func_wrappers_map[iter.second]; + func_wrappers_map[iter.second.get_contents()]; } for (const auto &it : env) { string wrapped_fname = it.first; - const Function &wrapped_func = it.second; - const auto &wrappers = wrapped_func.schedule().wrappers(); + FunctionPtr wrapped_func = it.second.get_contents(); + const auto &wrappers = it.second.schedule().wrappers(); // Put the names of all wrappers of this Function into the set for // faster comparison during the substitution. @@ -69,7 +75,7 @@ map wrap_func_calls(const map &env) { for (const auto &iter : wrappers) { string in_func = iter.first; - const Function &wrapper = Function(iter.second); // This is already the deep-copy version + FunctionPtr wrapper = iter.second; if (in_func.empty()) { // Global wrapper for (const auto &wrapped_env_iter : wrapped_env) { @@ -80,7 +86,8 @@ map wrap_func_calls(const map &env) { // so we don't want to rewrite the calls done by the // wrapper. We also shouldn't rewrite the original // function itself. - debug(4) << "Skip over replacing \"" << in_func << "\" with \"" << wrapper.name() << "\"\n"; + debug(4) << "Skip over replacing \"" << in_func + << "\" with \"" << Function(wrapper).name() << "\"\n"; continue; } if (wrappers.count(in_func)) { @@ -91,13 +98,15 @@ map wrap_func_calls(const map &env) { } debug(4) << "Global wrapper: replacing reference of \"" << wrapped_fname << "\" in \"" << in_func - << "\" with \"" << wrapper.name() << "\"\n"; - insert_func_wrapper_helper(func_wrappers_map, wrapped_env_iter.second, wrapped_func, wrapper); + << "\" with \"" << Function(wrapper).name() << "\"\n"; + insert_func_wrapper_helper(func_wrappers_map, + wrapped_env_iter.second.get_contents(), + wrapped_func, wrapper); } } else { // Custom wrapper debug(4) << "Custom wrapper: replacing reference of \"" << wrapped_fname << "\" in \"" << in_func << "\" with \"" - << wrapper.name() << "\"\n"; + << Function(wrapper).name() << "\"\n"; const auto &in_func_iter = wrapped_env.find(in_func); if (in_func_iter == wrapped_env.end()) { @@ -112,17 +121,20 @@ map wrap_func_calls(const map &env) { // f.in(g); // f.realize(..); debug(4) << " skip custom wrapper for " << in_func << " [" << wrapped_fname - << " -> " << wrapper.name() << "] since it's not in the pipeline\n"; + << " -> " << Function(wrapper).name() << "] since it's not in the pipeline\n"; continue; } - insert_func_wrapper_helper(func_wrappers_map, wrapped_env[in_func], wrapped_func, wrapper); + insert_func_wrapper_helper(func_wrappers_map, + wrapped_env[in_func].get_contents(), + wrapped_func, + wrapper); } } } // Perform the substitution for (auto &iter : wrapped_env) { - const auto &substitutions = func_wrappers_map[iter.second]; + const auto &substitutions = func_wrappers_map[iter.second.get_contents()]; if (!substitutions.empty()) { iter.second.substitute_calls(substitutions); } From 2aa759d18bc4e08e687e8c9564e27aa50140dae7 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 13 Jul 2017 10:52:40 -0700 Subject: [PATCH 3/3] Fix print_loop_nest --- src/PrintLoopNest.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/PrintLoopNest.cpp b/src/PrintLoopNest.cpp index bb4fa211efc5..0299a3cfe4fb 100644 --- a/src/PrintLoopNest.cpp +++ b/src/PrintLoopNest.cpp @@ -157,8 +157,7 @@ string print_loop_nest(const vector &output_funcs) { // Compute an environment map env; for (Function f : output_funcs) { - map more_funcs = find_transitive_calls(f); - env.insert(more_funcs.begin(), more_funcs.end()); + populate_environment(f, env); } // Create a deep-copy of the entire graph of Funcs. @@ -170,6 +169,11 @@ string print_loop_nest(const vector &output_funcs) { Func(f).compute_root().store_root(); } + // Ensure that all ScheduleParams become well-defined constant Exprs. + for (auto &f : env) { + f.second.substitute_schedule_param_exprs(); + } + // Substitute in wrapper Funcs env = wrap_func_calls(env);