Skip to content

Commit 58fc568

Browse files
areusch authored and AndrewZhaoLuo committed
[usmp] Also remap VarNode to USMP-allocated buffer (#12880)
Before this patch, ConvertPoolAllocationsToOffsets would generate TIR like the following:

    let dense_let: Pointer(global int32) = @tir.address_of(global_workspace_37_buffer_var[69952], dtype=handle)
    for (k.outer: int32, 0, 64) {
      @tir.call_extern("gemm_1x1x1_update_UKVNAEBL", ..., dense, ...)
    }
    T_multiply[ax1] = @tir.q_multiply_shift(((dense: Buffer(dense_let, int32, [10], [], align=32)[ax1], ...)

This caused CodegenSourceBase to later fail with this error:

    "src/target/source/codegen_source_base.cc", line 67
    Check failed: (it != var_idmap_.end()) is false: Find undefined Variable dense

After this patch, "dense" in the call_extern is changed to read "dense_let."
1 parent 5328c6b commit 58fc568

File tree

3 files changed

+114
-9
lines changed

3 files changed

+114
-9
lines changed

src/tir/usmp/analysis/extract_buffer_info.cc

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -429,15 +429,17 @@ void BufferInfoExtractor::VisitExpr_(const VarNode* op) {
429429

430430
Array<Var> static GetMatchedBuffers(const PrimFunc& func) {
431431
Array<Var> buffer_vars;
432-
for (unsigned int i = 0; i < func->params.size() - 1; i++) {
433-
Var param = func->params[i];
434-
buffer_vars.push_back(func->buffer_map[param]->data);
435-
}
436-
Var last_param = func->params.back();
437-
// Checks whether last var is present in the buffer map
438-
// because it could be the resource handle
439-
if (func->buffer_map.find(last_param) != func->buffer_map.end()) {
440-
buffer_vars.push_back(func->buffer_map[last_param]->data);
432+
if (func->params.size() > 0) {
433+
for (unsigned int i = 0; i < func->params.size() - 1; i++) {
434+
Var param = func->params[i];
435+
buffer_vars.push_back(func->buffer_map[param]->data);
436+
}
437+
Var last_param = func->params.back();
438+
// Checks whether last var is present in the buffer map
439+
// because it could be the resource handle
440+
if (func->buffer_map.find(last_param) != func->buffer_map.end()) {
441+
buffer_vars.push_back(func->buffer_map[last_param]->data);
442+
}
441443
}
442444
return buffer_vars;
443445
}

src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -96,6 +96,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
9696
private:
9797
PrimExpr VisitExpr_(const CallNode* op) override;
9898
Stmt VisitStmt_(const AllocateNode* op) override;
99+
PrimExpr VisitExpr_(const VarNode* op) override;
99100
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
100101
Stmt VisitStmt_(const BufferStoreNode* op) override;
101102

@@ -395,6 +396,15 @@ PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const BufferLoadNode* op) {
395396
return std::move(load);
396397
}
397398

399+
PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const VarNode* op) {
400+
auto it = allocate_var_to_let_var_.find(GetRef<Var>(op));
401+
if (it != allocate_var_to_let_var_.end()) {
402+
return (*it).second;
403+
}
404+
405+
return StmtExprMutator::VisitExpr_(op);
406+
}
407+
398408
Buffer PoolAllocationToOffsetConverter::GetRemappedBuffer(Buffer original) {
399409
{
400410
auto it = original_buf_to_let_buf_.find(original);

tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py

Lines changed: 93 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -600,5 +600,98 @@ def test_resnet_subgraph():
600600
tvm.ir.assert_structural_equal(actual_func, ref_func)
601601

602602

603+
@tvm.script.ir_module
604+
class TensorIntrinStructure:
605+
@T.prim_func
606+
def tensor_intrin_primfunc() -> None:
607+
dense_data = T.allocate([10], "int32", "global")
608+
T.evaluate(
609+
T.call_extern(
610+
"intrin_function",
611+
T.tvm_access_ptr(
612+
T.type_annotation(dtype="int32"), dense_data, 0, 1, 2, dtype="handle"
613+
),
614+
dtype="int32",
615+
)
616+
)
617+
618+
dense = T.buffer_decl([10], "int32", data=dense_data)
619+
dense[0] = T.q_multiply_shift(dense[0], 1608879842, 31, -7, dtype="int32")
620+
621+
@T.prim_func
622+
def __tvm_main__(input: T.handle, output: T.handle) -> None:
623+
T.evaluate(T.call_extern("tensor_intrin_primfunc", dtype="int32"))
624+
625+
626+
@tvm.script.ir_module
627+
class TensorIntrinStructurePlanned:
628+
@T.prim_func
629+
def tensor_intrin_primfunc(global_workspace_1_var: T.Ptr[T.uint8]) -> None:
630+
global_workspace_1_buffer_var = T.match_buffer(
631+
global_workspace_1_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16
632+
)
633+
T.preflattened_buffer(
634+
global_workspace_1_buffer_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16
635+
)
636+
dense_let = T.buffer_decl([10], "int32")
637+
with T.let(dense_let.data, T.address_of(global_workspace_1_buffer_var[0], dtype="handle")):
638+
T.evaluate(
639+
T.call_extern(
640+
"intrin_function",
641+
T.tvm_access_ptr(
642+
T.type_annotation(dtype="int32"), dense_let.data, 0, 1, 2, dtype="handle"
643+
),
644+
dtype="int32",
645+
)
646+
)
647+
dense_let[0] = T.q_multiply_shift(dense_let[0], 1608879842, 31, -7, dtype="int32")
648+
649+
@T.prim_func
650+
def __tvm_main__(
651+
input: T.handle, global_workspace_1_var: T.Ptr[T.uint8], output: T.handle
652+
) -> None:
653+
global_workspace_1_buffer_var = T.match_buffer(
654+
global_workspace_1_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16
655+
)
656+
T.evaluate(
657+
T.call_extern(
658+
"tensor_intrin_primfunc", global_workspace_1_buffer_var.data, dtype="int32"
659+
)
660+
)
661+
662+
663+
def test_tensor_intrin():
664+
target = Target("c")
665+
global_workspace_pool = WorkspacePoolInfo(
666+
"global_workspace",
667+
[target],
668+
)
669+
670+
tir_mod = TensorIntrinStructure
671+
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
672+
tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
673+
main_func = tir_mod["__tvm_main__"]
674+
buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
675+
buffer_info_map = buffer_analysis.buffer_info_stmts
676+
677+
fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
678+
buffer_info_arr = fcreate_array_bi(buffer_info_map)
679+
fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
680+
buffer_pool_allocations = fusmp_algo_greedy_by_size(
681+
buffer_info_arr, buffer_analysis.memory_pressure
682+
)
683+
fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
684+
pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations)
685+
tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
686+
pool_allocations, emit_tvmscript_printable=True
687+
)(tir_mod)
688+
689+
expected = TensorIntrinStructurePlanned
690+
691+
for gv, ref_func in expected.functions.items():
692+
actual_func = tir_mod_with_offsets[gv.name_hint]
693+
tvm.ir.assert_structural_equal(actual_func, ref_func)
694+
695+
603696
if __name__ == "__main__":
604697
pytest.main([__file__] + sys.argv[1:])

0 commit comments

Comments
 (0)