diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp
index 1c057e3055b9f..cde76e75594f7 100644
--- a/taichi/transforms/auto_diff.cpp
+++ b/taichi/transforms/auto_diff.cpp
@@ -288,12 +288,32 @@ class PromoteSSA2LocalVar : public BasicStmtVisitor {
     execute_once_ = true;
   }
 
+  bool is_mutable_loop_index(Stmt *stmt) {
+    if (stmt->is<LoopIndexStmt>()) {
+      auto loop = stmt->as<LoopIndexStmt>()->loop;
+      // TODO: we assume that the loop indices of a top-level for are constants
+      // Note: StructForStmt and MeshForStmt are top-level fors
+      if (loop->parent->parent_stmt == nullptr) {
+        return false;
+      }
+      if (loop->is<RangeForStmt>()) {
+        auto range = loop->as<RangeForStmt>();
+        if (range->begin->is<ConstStmt>() && range->end->is<ConstStmt>()) {
+          return false;
+        }
+      }
+      return true;
+    }
+    return false;
+  }
+
   void visit(Stmt *stmt) override {
     if (execute_once_)
       return;
+    bool mutable_loop_index = is_mutable_loop_index(stmt);
     if (!(stmt->is<UnaryOpStmt>() || stmt->is<BinaryOpStmt>() ||
           stmt->is<TernaryOpStmt>() || stmt->is<BitExtractStmt>() ||
-          stmt->is<GlobalLoadStmt>())) {
+          stmt->is<GlobalLoadStmt>() || mutable_loop_index)) {
       // TODO: this list may be incomplete
       return;
     }
@@ -331,6 +351,42 @@ class PromoteSSA2LocalVar : public BasicStmtVisitor {
   }
 
   void visit(RangeForStmt *stmt) override {
+    bool mutable_loop_begin_index = is_mutable_loop_index(stmt->begin);
+    // TODO: to refine this, handle the case where the begin or end of the
+    // loop is read from an ArgLoadStmt
+    if (stmt->begin->is<GlobalLoadStmt>() || mutable_loop_begin_index) {
+      // Create an alloca
+      auto alloc = Stmt::make<AllocaStmt>(stmt->begin->ret_type);
+      auto alloc_ptr = alloc.get();
+      TI_ASSERT(alloca_block_);
+      alloca_block_->insert(std::move(alloc), 0);
+      auto load =
+          stmt->begin->insert_after_me(Stmt::make<LocalLoadStmt>(alloc_ptr));
+      Stmt *glb_load_stmt_backup = stmt->begin;
+      irpass::replace_all_usages_with(stmt->parent, stmt->begin, load);
+      // Create the load first so that the operand of the store won't get
+      // replaced
+      load->insert_before_me(
+          Stmt::make<LocalStoreStmt>(alloc_ptr, glb_load_stmt_backup));
+    }
+
+    bool mutable_loop_end_index = is_mutable_loop_index(stmt->end);
+    if (stmt->end->is<GlobalLoadStmt>() || mutable_loop_end_index) {
+      // Create an alloca
+      auto alloc = Stmt::make<AllocaStmt>(stmt->end->ret_type);
+      auto alloc_ptr = alloc.get();
+      TI_ASSERT(alloca_block_);
+      alloca_block_->insert(std::move(alloc), 0);
+      auto load =
+          stmt->end->insert_after_me(Stmt::make<LocalLoadStmt>(alloc_ptr));
+      Stmt *glb_load_stmt_backup = stmt->end;
+      irpass::replace_all_usages_with(stmt->parent, stmt->end, load);
+      // Create the load first so that the operand of the store won't get
+      // replaced
+      load->insert_before_me(
+          Stmt::make<LocalStoreStmt>(alloc_ptr, glb_load_stmt_backup));
+    }
+
     auto old_execute_once = execute_once_;
     execute_once_ = false;  // loop body may be executed many times
     stmt->body->accept(this);
@@ -439,6 +495,17 @@ class AdStackAllocaJudger : public BasicStmtVisitor {
     stmt->false_statements->accept(this);
   }
 
+  // Check whether the target alloca serves as the begin or end of a for loop
+  void visit(RangeForStmt *stmt) override {
+    if (is_stack_needed_)
+      return;
+    if (stmt->begin == target_alloca_ || stmt->end == target_alloca_) {
+      is_stack_needed_ = true;
+      return;
+    }
+    stmt->body->accept(this);
+  }
+
   static bool run(AllocaStmt *target_alloca) {
     AdStackAllocaJudger judger;
     judger.target_alloca_ = target_alloca;
@@ -969,7 +1036,18 @@ class MakeAdjoint : public ADTransform {
     auto new_for = for_stmt->clone();
     auto new_for_ptr = new_for->as<RangeForStmt>();
     new_for_ptr->reversed = !new_for_ptr->reversed;
-    insert_grad_stmt(std::move(new_for));
+
+    auto new_for_stmt = insert_grad_stmt(std::move(new_for));
+
+    if (new_for_ptr->begin->is<LocalLoadStmt>()) {
+      new_for_ptr->begin =
+          new_for_stmt->insert_before_me(new_for_ptr->begin->clone());
+    }
+    if (new_for_ptr->end->is<LocalLoadStmt>()) {
+      new_for_ptr->end =
+          new_for_stmt->insert_before_me(new_for_ptr->end->clone());
+    }
+
     const int len = new_for_ptr->body->size();
     for (int i = 0; i < len; i++) {
       new_for_ptr->body->erase(0);
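For context: the pattern this change targets is a range-for whose bound is the index of an enclosing non-top-level loop, so the bound mutates across iterations and cannot be treated as a constant when the loop is reversed for the backward pass. A minimal sketch of such a kernel, mirroring the test below (the `ti.init` arguments, field names, and sizes are illustrative, not taken from this PR):

    import taichi as ti

    ti.init(arch=ti.cpu)

    n = 4
    x = ti.field(dtype=float, shape=n, needs_grad=True)
    loss = ti.field(dtype=float, shape=(), needs_grad=True)

    @ti.kernel
    def nested():
        for i in range(n):  # top-level for: `i` is constant within a thread
            for j in range(i):  # serial inner for: `j` mutates across iterations
                # `range(j)` uses the mutable index `j` as its end bound; the
                # promotion above stores it to a local alloca so the reversed
                # loop in the adjoint pass can reload it
                for k in range(j):
                    loss[None] += x[j] * x[k]

    with ti.ad.Tape(loss=loss):
        nested()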
diff --git a/tests/python/test_ad_for.py b/tests/python/test_ad_for.py
index d68222d75751f..06a0750d1bce5 100644
--- a/tests/python/test_ad_for.py
+++ b/tests/python/test_ad_for.py
@@ -1010,3 +1010,40 @@ def compute():
     compute.grad()
     for i in range(N):
         assert a.grad[i] == i
+
+
+@test_utils.test(require=ti.extension.adstack)
+def test_mutable_loop_index():
+    NUM = 5
+    x = ti.field(dtype=float)
+    y = ti.field(dtype=float)
+    ti.root.dense(ti.i, NUM).place(x, y)
+    loss = ti.field(dtype=float, shape=())
+    ti.root.lazy_grad()
+
+    @ti.kernel
+    def initialize():
+        for i in x:
+            x[i] = i
+            y[i] = 0
+        loss[None] = 0
+
+    @ti.kernel
+    def compute_loss():
+        for i in range(NUM):
+            l = 0.0
+            for j in range(i):
+                for k in range(j):
+                    l += x[j] * x[k]
+            y[i] = l
+
+        for i in range(NUM):
+            loss[None] += y[i]
+
+    initialize()
+    with ti.ad.Tape(loss=loss):
+        compute_loss()
+
+    refs = [10., 7., 5., 3., 0.]
+    for i in range(NUM):
+        assert x.grad[i] == refs[i]
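For reference, the expected values `refs = [10., 7., 5., 3., 0.]` follow from `loss = sum_i sum_{k<j<i} x[j]*x[k]` with `x[i] = i`. A plain-Python brute-force check, independent of Taichi (all names are local to this snippet):

    # d(loss)/dx[m] accumulates the partner factor of every product x[j]*x[k]
    # in which x[m] appears
    NUM = 5
    x = [float(i) for i in range(NUM)]
    grad = [0.0] * NUM
    for i in range(NUM):
        for j in range(i):
            for k in range(j):
                grad[j] += x[k]  # d(x[j] * x[k]) / dx[j]
                grad[k] += x[j]  # d(x[j] * x[k]) / dx[k]
    print(grad)  # [10.0, 7.0, 5.0, 3.0, 0.0]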