Skip to content

Commit

Permalink
Merge pull request halide#2184 from halide/fix_wrapper_leak
Browse files Browse the repository at this point in the history
Fix wrapper leak
  • Loading branch information
abadams authored Jul 17, 2017
2 parents 5a6b3ec + 2aa759d commit ec778f9
Show file tree
Hide file tree
Showing 21 changed files with 440 additions and 334 deletions.
3 changes: 1 addition & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ SOURCE_FILES = \
Debug.cpp \
DebugArguments.cpp \
DebugToFile.cpp \
DeepCopy.cpp \
Definition.cpp \
Deinterleave.cpp \
DeviceArgument.cpp \
Expand Down Expand Up @@ -424,7 +423,6 @@ HEADER_FILES = \
Debug.h \
DebugArguments.h \
DebugToFile.h \
DeepCopy.h \
Definition.h \
Deinterleave.h \
DeviceArgument.h \
Expand All @@ -441,6 +439,7 @@ HEADER_FILES = \
Float16.h \
Func.h \
Function.h \
FunctionPtr.h \
FuseGPUThreadLoops.h \
FuzzFloatStores.h \
Generator.h \
Expand Down
26 changes: 12 additions & 14 deletions src/Associativity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ class ConvertSelfRef : public IRMutator {
internal_assert(op);

if ((op->call_type == Call::Halide) && (func == op->name)) {
internal_assert(!op->func.defined())
<< "Func should not have been defined for a self-reference\n";
internal_assert(args.size() == op->args.size())
<< "Self-reference should have the same number of args as the original\n";
for (size_t i = 0; i < op->args.size(); i++) {
Expand Down Expand Up @@ -543,7 +541,7 @@ void associativity_test() {
Expr x = Variable::make(t, "x");
Expr y = Variable::make(t, "y");
Expr x_idx = Variable::make(Int(32), "x_idx");
Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, nullptr, 0);
Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, FunctionPtr(), 0);

// f(x) = uint8(uint16(x + y), 255)
check_associativity("f", {x_idx}, {Cast::make(UInt(8), min(Cast::make(UInt(16), y + f_call_0), make_const(t, 255)))},
Expand Down Expand Up @@ -588,7 +586,7 @@ void associativity_test() {
Expr x = Variable::make(t, "x");
Expr y = Variable::make(t, "y");
Expr x_idx = Variable::make(Int(32), "x_idx");
Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, nullptr, 0);
Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, FunctionPtr(), 0);

// f(x) = y && f(x)
check_associativity("f", {x_idx}, {And::make(y, f_call_0)},
Expand Down Expand Up @@ -622,11 +620,11 @@ void associativity_test() {
zs[i] = Variable::make(t, "z" + std::to_string(i));
}

Expr f_call_0 = Call::make(t, "f", {x}, Call::CallType::Halide, nullptr, 0);
Expr f_call_1 = Call::make(t, "f", {x}, Call::CallType::Halide, nullptr, 1);
Expr f_call_2 = Call::make(t, "f", {x}, Call::CallType::Halide, nullptr, 2);
Expr g_call_0 = Call::make(t, "g", {rx}, Call::CallType::Halide, nullptr, 0);
Expr g_call_1 = Call::make(t, "g", {rx}, Call::CallType::Halide, nullptr, 1);
Expr f_call_0 = Call::make(t, "f", {x}, Call::CallType::Halide, FunctionPtr(), 0);
Expr f_call_1 = Call::make(t, "f", {x}, Call::CallType::Halide, FunctionPtr(), 1);
Expr f_call_2 = Call::make(t, "f", {x}, Call::CallType::Halide, FunctionPtr(), 2);
Expr g_call_0 = Call::make(t, "g", {rx}, Call::CallType::Halide, FunctionPtr(), 0);
Expr g_call_1 = Call::make(t, "g", {rx}, Call::CallType::Halide, FunctionPtr(), 1);

// f(x) = f(x)
check_associativity("f", {x}, {f_call_0},
Expand Down Expand Up @@ -759,11 +757,11 @@ void associativity_test() {

{
Expr ry = Variable::make(t, "ry");
Expr f_xy_call_0 = Call::make(t, "f", {x, y}, Call::CallType::Halide, nullptr, 0);
Expr f_xy_call_1 = Call::make(t, "f", {x, y}, Call::CallType::Halide, nullptr, 1);
Expr f_xy_call_2 = Call::make(t, "f", {x, y}, Call::CallType::Halide, nullptr, 2);
Expr f_xy_call_3 = Call::make(t, "f", {x, y}, Call::CallType::Halide, nullptr, 3);
Expr g_xy_call_0 = Call::make(t, "g", {rx, ry}, Call::CallType::Halide, nullptr, 0);
Expr f_xy_call_0 = Call::make(t, "f", {x, y}, Call::CallType::Halide, FunctionPtr(), 0);
Expr f_xy_call_1 = Call::make(t, "f", {x, y}, Call::CallType::Halide, FunctionPtr(), 1);
Expr f_xy_call_2 = Call::make(t, "f", {x, y}, Call::CallType::Halide, FunctionPtr(), 2);
Expr f_xy_call_3 = Call::make(t, "f", {x, y}, Call::CallType::Halide, FunctionPtr(), 3);
Expr g_xy_call_0 = Call::make(t, "g", {rx, ry}, Call::CallType::Halide, FunctionPtr(), 0);

// 2D argmin + trivial update:
// f(x, y) = Tuple(min(f(x, y)[0], g(r.x, r.y)[0]),
Expand Down
3 changes: 1 addition & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ set(HEADER_FILES
Debug.h
DebugArguments.h
DebugToFile.h
DeepCopy.h
Definition.h
Deinterleave.h
DeviceArgument.h
Expand All @@ -302,6 +301,7 @@ set(HEADER_FILES
Float16.h
Func.h
Function.h
FunctionPtr.h
FuseGPUThreadLoops.h
FuzzFloatStores.h
Generator.h
Expand Down Expand Up @@ -438,7 +438,6 @@ add_library(Halide ${HALIDE_LIBRARY_TYPE}
Debug.cpp
DebugArguments.cpp
DebugToFile.cpp
DeepCopy.cpp
Definition.cpp
Deinterleave.cpp
DeviceArgument.cpp
Expand Down
53 changes: 0 additions & 53 deletions src/DeepCopy.cpp

This file was deleted.

24 changes: 0 additions & 24 deletions src/DeepCopy.h

This file was deleted.

23 changes: 19 additions & 4 deletions src/FindCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using std::vector;
using std::string;
using std::pair;

namespace {
/* Find all the internal halide calls in an expr */
class FindCalls : public IRVisitor {
public:
Expand Down Expand Up @@ -38,7 +39,8 @@ class FindCalls : public IRVisitor {
}
};

void populate_environment(Function f, map<string, Function> &env, bool recursive = true) {
void populate_environment_helper(Function f, map<string, Function> &env,
bool recursive = true, bool include_wrappers = false) {
map<string, Function>::const_iterator iter = env.find(f.name());
if (iter != env.end()) {
user_assert(iter->second.same_as(f))
Expand All @@ -58,26 +60,39 @@ void populate_environment(Function f, map<string, Function> &env, bool recursive
}
}

if (include_wrappers) {
for (auto it : f.schedule().wrappers()) {
Function g(it.second);
calls.calls[g.name()] = g;
}
}

if (!recursive) {
env.insert(calls.calls.begin(), calls.calls.end());
} else {
env[f.name()] = f;

for (pair<string, Function> i : calls.calls) {
populate_environment(i.second, env);
populate_environment_helper(i.second, env, recursive, include_wrappers);
}
}
}

}

void populate_environment(Function f, map<string, Function> &env) {
populate_environment_helper(f, env, true, true);
}

map<string, Function> find_transitive_calls(Function f) {
map<string, Function> res;
populate_environment(f, res, true);
populate_environment_helper(f, res, true, false);
return res;
}

map<string, Function> find_direct_calls(Function f) {
map<string, Function> res;
populate_environment(f, res, false);
populate_environment_helper(f, res, false, false);
return res;
}

Expand Down
4 changes: 4 additions & 0 deletions src/FindCalls.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ std::map<std::string, Function> find_direct_calls(Function f);
*/
std::map<std::string, Function> find_transitive_calls(Function f);

/** Find all Functions transitively referenced by f in any way and add
* them to the given map. */
void populate_environment(Function f, std::map<std::string, Function> &env);

}
}

Expand Down
22 changes: 10 additions & 12 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,6 @@ class SubstituteSelfReference : public IRMutator {
internal_assert(c);

if ((c->call_type == Call::Halide) && (func == c->name)) {
internal_assert(!c->func.defined())
<< "func should not have been defined for a self-reference\n";
debug(4) << "...Replace call to Func \"" << c->name << "\" with "
<< "\"" << substitute.name() << "\"\n";
vector<Expr> args;
Expand Down Expand Up @@ -872,7 +870,7 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {
if (!prover_result.xs[i].var.empty()) {
Expr prev_val = Call::make(intm.output_types()[i], func_name,
f_store_args, Call::CallType::Halide,
nullptr, i);
FunctionPtr(), i);
replacements.emplace(prover_result.xs[i].var, prev_val);
} else {
user_warning << "Update definition of " << stage_name << " at index " << i
Expand Down Expand Up @@ -1742,16 +1740,16 @@ void Func::invalidate_cache() {
Func Func::in(const Func &f) {
invalidate_cache();
user_assert(name() != f.name()) << "Cannot call 'in()' on itself\n";
const map<string, IntrusivePtr<FunctionContents>> &wrappers = func.wrappers();
const map<string, FunctionPtr> &wrappers = func.wrappers();
const auto &iter = wrappers.find(f.name());
if (iter == wrappers.end()) {
Func wrapper(name() + "_in_" + f.name());
Func wrapper(func.new_function_in_same_group(name() + "_in_" + f.name()));
wrapper(args()) = (*this)(args());
func.add_wrapper(f.name(), wrapper.func);
return wrapper;
}

IntrusivePtr<FunctionContents> wrapper_contents = iter->second;
FunctionPtr wrapper_contents = iter->second;
internal_assert(wrapper_contents.defined());

// Make sure that no other Func shares the same wrapper as 'f'
Expand All @@ -1776,7 +1774,7 @@ Func Func::in(const vector<Func>& fs) {

// Either all Funcs have the same wrapper or they don't already have any wrappers.
// Otherwise, throw an error.
const map<string, IntrusivePtr<FunctionContents>> &wrappers = func.wrappers();
const map<string, FunctionPtr> &wrappers = func.wrappers();

const auto &iter = wrappers.find(fs[0].name());
if (iter == wrappers.end()) {
Expand All @@ -1786,7 +1784,7 @@ Func Func::in(const vector<Func>& fs) {
<< "Cannot define the wrapper since " << fs[i].name()
<< " already has a wrapper while " << fs[0].name() << " doesn't \n";
}
Func wrapper(name() + "_wrapper");
Func wrapper(func.new_function_in_same_group(name() + "_wrapper"));
wrapper(args()) = (*this)(args());
for (const Func &f : fs) {
user_assert(name() != f.name()) << "Cannot call 'in()' on itself\n";
Expand All @@ -1795,7 +1793,7 @@ Func Func::in(const vector<Func>& fs) {
return wrapper;
}

IntrusivePtr<FunctionContents> wrapper_contents = iter->second;
FunctionPtr wrapper_contents = iter->second;
internal_assert(wrapper_contents.defined());

// Make sure all the other Funcs in 'fs' share the same wrapper and no other
Expand Down Expand Up @@ -1825,16 +1823,16 @@ Func Func::in(const vector<Func>& fs) {

Func Func::in() {
invalidate_cache();
const map<string, IntrusivePtr<FunctionContents>> &wrappers = func.wrappers();
const map<string, FunctionPtr> &wrappers = func.wrappers();
const auto &iter = wrappers.find("");
if (iter == wrappers.end()) {
Func wrapper(name() + "_global_wrapper");
Func wrapper(func.new_function_in_same_group(name() + "_global_wrapper"));
wrapper(args()) = (*this)(args());
func.add_wrapper("", wrapper.func);
return wrapper;
}

IntrusivePtr<FunctionContents> wrapper_contents = iter->second;
FunctionPtr wrapper_contents = iter->second;
internal_assert(wrapper_contents.defined());
Function wrapper(wrapper_contents);
internal_assert(wrapper.frozen());
Expand Down
Loading

0 comments on commit ec778f9

Please sign in to comment.