Skip to content

Commit

Permalink
[Lang] Support TensorType for irpass::alg_simp() (#8225)
Browse files Browse the repository at this point in the history
Issue: #

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 411b19e</samp>

Add a template function `get_const_stmt_with_value` to
`taichi/ir/statements.h` that creates IR statements for constants of
different types. This is part of improving matrix support in the IR.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at 411b19e</samp>

* Add a new template function `get_const_stmt_with_value` to create
constant statements of different types
([link](https://github.com/taichi-dev/taichi/pull/8225/files?diff=unified&w=0#diff-917d9436dcaafa0f1e41ae9bad90273a303f036f00da94e417788a7fa1dc5260R2021-R2050))
  • Loading branch information
jim19930609 authored Jun 29, 2023
1 parent 56b5bc9 commit 3da2a41
Show file tree
Hide file tree
Showing 3 changed files with 411 additions and 132 deletions.
30 changes: 30 additions & 0 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -2033,4 +2033,34 @@ class MatrixInitStmt : public Stmt {
TI_DEFINE_ACCEPT_AND_CLONE
};

template <typename T>
std::vector<std::unique_ptr<Stmt>> get_const_stmt_with_value(DataType dt,
T value) {
if (dt->is<PrimitiveType>()) {
TypedConstant constant(dt, value);
auto const_stmt = std::make_unique<ConstStmt>(constant);

std::vector<std::unique_ptr<Stmt>> ret;
ret.push_back(std::move(const_stmt));
return ret;

} else if (dt->is<TensorType>()) {
DataType element_dt = dt.get_element_type();
std::vector<std::unique_ptr<Stmt>> stmts =
get_const_stmt_with_value(element_dt, value);

Stmt *elem_stmt = stmts.back().get();
std::vector<Stmt *> elem_stmts(dt->as<TensorType>()->get_num_elements(),
elem_stmt);

auto matrix_init_stmt = std::make_unique<MatrixInitStmt>(elem_stmts);
matrix_init_stmt->ret_type = dt;

stmts.push_back(std::move(matrix_init_stmt));
return stmts;
} else {
TI_NOT_IMPLEMENTED
}
}

} // namespace taichi::lang
Loading

0 comments on commit 3da2a41

Please sign in to comment.