Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[autodiff] Handle mutable for loop index #7778

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 80 additions & 2 deletions taichi/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,32 @@ class PromoteSSA2LocalVar : public BasicStmtVisitor {
execute_once_ = true;
}

// Decide whether `stmt` is a loop index that must be treated as mutable.
//
// Returns false when `stmt` is not a LoopIndexStmt, when its loop's parent
// block has no parent statement, or when the loop is a RangeForStmt whose
// begin and end are both ConstStmt. Returns true for every other loop index.
bool is_mutable_loop_index(Stmt *stmt) {
  if (!stmt->is<LoopIndexStmt>()) {
    return false;
  }

  auto *owner_loop = stmt->as<LoopIndexStmt>()->loop;

  // TODO: we assume that the loop indices of a top-level for are constants
  // Note: the StructForStmt and MeshForStmt are top-level fors
  const bool is_top_level = owner_loop->parent->parent_stmt == nullptr;
  if (is_top_level) {
    return false;
  }

  // Any nested loop other than a range-for is considered mutable.
  if (!owner_loop->is<RangeForStmt>()) {
    return true;
  }

  // A nested range-for is immutable only if both bounds are constants.
  auto *range_for = owner_loop->as<RangeForStmt>();
  const bool has_const_bounds =
      range_for->begin->is<ConstStmt>() && range_for->end->is<ConstStmt>();
  return !has_const_bounds;
}

void visit(Stmt *stmt) override {
if (execute_once_)
return;
bool mutable_loop_index = is_mutable_loop_index(stmt);
if (!(stmt->is<UnaryOpStmt>() || stmt->is<BinaryOpStmt>() ||
stmt->is<TernaryOpStmt>() || stmt->is<GlobalLoadStmt>() ||
stmt->is<AllocaStmt>())) {
stmt->is<AllocaStmt>() || mutable_loop_index)) {
// TODO: this list may be incomplete
return;
}
Expand Down Expand Up @@ -331,6 +351,42 @@ class PromoteSSA2LocalVar : public BasicStmtVisitor {
}

void visit(RangeForStmt *stmt) override {
bool mutable_loop_begin_index = is_mutable_loop_index(stmt->begin);
// TODO: To refine this, handle the case that the begin or end of the loop
// is read from argload
if (stmt->begin->is<GlobalLoadStmt>() || mutable_loop_begin_index) {
// Create an alloca
auto alloc = Stmt::make<AllocaStmt>(stmt->begin->ret_type);
auto alloc_ptr = alloc.get();
TI_ASSERT(alloca_block_);
alloca_block_->insert(std::move(alloc), 0);
auto load =
stmt->begin->insert_after_me(Stmt::make<LocalLoadStmt>(alloc_ptr));
Stmt *glb_load_stmt_backup = stmt->begin;
irpass::replace_all_usages_with(stmt->parent, stmt->begin, load);
// Create the load first so that the operand of the store won't get
// replaced
load->insert_before_me(
Stmt::make<LocalStoreStmt>(alloc_ptr, glb_load_stmt_backup));
}

bool mutable_loop_end_index = is_mutable_loop_index(stmt->end);
if (stmt->end->is<GlobalLoadStmt>() || mutable_loop_end_index) {
// Create an alloca
auto alloc = Stmt::make<AllocaStmt>(stmt->end->ret_type);
auto alloc_ptr = alloc.get();
TI_ASSERT(alloca_block_);
alloca_block_->insert(std::move(alloc), 0);
auto load =
stmt->end->insert_after_me(Stmt::make<LocalLoadStmt>(alloc_ptr));
Stmt *glb_load_stmt_backup = stmt->end;
irpass::replace_all_usages_with(stmt->parent, stmt->end, load);
// Create the load first so that the operand of the store won't get
// replaced
load->insert_before_me(
Stmt::make<LocalStoreStmt>(alloc_ptr, glb_load_stmt_backup));
}

auto old_execute_once = execute_once_;
execute_once_ = false; // loop body may be executed many times
stmt->body->accept(this);
Expand Down Expand Up @@ -439,6 +495,17 @@ class AdStackAllocaJudger : public BasicStmtVisitor {
stmt->false_statements->accept(this);
}

// Flag the target alloca as needing a stack when it is used directly as the
// begin or end of a range-for; otherwise keep searching inside the loop body.
// Does nothing once the flag has already been set.
void visit(RangeForStmt *stmt) override {
  if (is_stack_needed_) {
    return;
  }

  const bool target_is_loop_bound =
      stmt->begin == target_alloca_ || stmt->end == target_alloca_;
  if (!target_is_loop_bound) {
    stmt->body->accept(this);
    return;
  }

  is_stack_needed_ = true;
}

static bool run(AllocaStmt *target_alloca) {
AdStackAllocaJudger judger;
judger.target_alloca_ = target_alloca;
Expand Down Expand Up @@ -969,7 +1036,18 @@ class MakeAdjoint : public ADTransform {
auto new_for = for_stmt->clone();
auto new_for_ptr = new_for->as<RangeForStmt>();
new_for_ptr->reversed = !new_for_ptr->reversed;
insert_grad_stmt(std::move(new_for));

auto new_for_stmt = insert_grad_stmt(std::move(new_for));

if (new_for_ptr->begin->is<AdStackLoadTopStmt>()) {
new_for_ptr->begin =
new_for_stmt->insert_before_me(new_for_ptr->begin->clone());
}
if (new_for_ptr->end->is<AdStackLoadTopStmt>()) {
new_for_ptr->end =
new_for_stmt->insert_before_me(new_for_ptr->end->clone());
}

const int len = new_for_ptr->body->size();

for (int i = 0; i < len; i++) {
Expand Down
37 changes: 37 additions & 0 deletions tests/python/test_ad_for.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,3 +1010,40 @@ def compute():
compute.grad()
for i in range(N):
assert a.grad[i] == i


@test_utils.test(require=ti.extension.adstack)
def test_mutable_loop_index():
    """Check gradients when inner loop bounds are the indices of outer loops."""
    NUM = 5
    x = ti.field(dtype=float)
    y = ti.field(dtype=float)
    ti.root.dense(ti.i, NUM).place(x, y)
    loss = ti.field(dtype=float, shape=())
    ti.root.lazy_grad()

    @ti.kernel
    def initialize():
        for i in x:
            x[i] = i
            y[i] = 0
        loss[None] = 0

    @ti.kernel
    def compute_loss():
        # The bounds of the j-loop and k-loop are the indices of their
        # enclosing loops.
        for i in range(NUM):
            acc = 0.0
            for j in range(i):
                for k in range(j):
                    acc += x[j] * x[k]
            y[i] = acc

        for i in range(NUM):
            loss[None] += y[i]

    initialize()
    with ti.ad.Tape(loss=loss):
        compute_loss()

    # Expected d(loss)/dx[i] for x = [0, 1, 2, 3, 4].
    refs = [10., 7., 5., 3., 0.]
    for i, expected in enumerate(refs):
        assert x.grad[i] == expected