
Commit fa5045b

[Metaschedule] MultiLevelTiling for wide vector architectures (#12845)
* [Metaschedule] Introduce MultiLevelTiling for wide vector architecture
* update test
* format
* cpplint
1 parent 52dbf10 commit fa5045b

7 files changed: +307 -12 lines changed

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 15 additions & 0 deletions
@@ -187,6 +187,21 @@ class ScheduleRule : public runtime::ObjectRef {
       Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
       Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);

+  /*!
+   * \brief Extension of MultiLevelTiling for backends with wide vectors.
+   * The loop over the innermost spatial axis of the output buffer is always vectorized with the
+   * maximum vector length.
+   * \param structure The tiling structure. 'SSRSRS' is recommended.
+   * \param vector_length_in_bits The length of a vector register in bits.
+   * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
+   * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
+   * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
+   * \return The schedule rule created
+   */
+  TVM_DLL static ScheduleRule MultiLevelTilingWideVector(
+      String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
+      Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
+
   /*!
    * \brief Create a rule: add-rfactor to some blocks if needed
    * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the

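The effective vectorization factor follows from vector_length_in_bits and the width of the output dtype (see the SplitLoop override in the new rule implementation later in this commit). A minimal sketch of that arithmetic, assuming a 1024-bit vector unit such as Hexagon HVX:

# A minimal sketch (not from the commit) of how vector_length_in_bits maps to the
# vectorization factor: the same register width yields a different number of lanes
# depending on the output dtype.
def vector_lanes(vector_length_in_bits: int, out_dtype_bits: int) -> int:
    return vector_length_in_bits // out_dtype_bits

assert vector_lanes(1024, 16) == 64  # e.g. Hexagon HVX register, float16 output
assert vector_lanes(1024, 32) == 32  # same register width, float32 output
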
python/tvm/meta_schedule/schedule_rule/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
     MultiLevelTilingWithIntrin,
     ReuseType,
     MultiLevelTilingTensorCore,
+    MultiLevelTilingWideVector,
 )
 from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
 from .random_compute_location import RandomComputeLocation

python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py

Lines changed: 37 additions & 0 deletions
@@ -187,3 +187,40 @@ def __init__(
             reuse_write.as_dict() if reuse_write is not None else None,
             use_software_pipeline,
         )
+
+
+@register_object("meta_schedule.MultiLevelTilingWideVector")
+class MultiLevelTilingWideVector(ScheduleRule):
+    """Extension of MultiLevelTiling for backends with wide vectors. The loop over the innermost
+    spatial axis of the output buffer is always vectorized with the maximum vector length.
+
+    Parameters
+    ----------
+    structure : str
+        The tiling structure. 'SSRSRS' is recommended.
+    vector_length_in_bits: int
+        The length of a vector register in bits.
+    max_innermost_factor : Optional[int]
+        The maximum size of the innermost factor. None means no limit
+    reuse_read : Optional[ReuseType]
+        Data reuse configuration for reading. None means no reuse.
+    reuse_write : Optional[ReuseType]
+        Data reuse configuration for writing. None means no reuse.
+    """
+
+    def __init__(
+        self,
+        structure: str,
+        vector_length_in_bits: int,
+        max_innermost_factor: Optional[int] = None,
+        reuse_read: Optional[ReuseType] = None,
+        reuse_write: Optional[ReuseType] = None,
+    ) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.ScheduleRuleMultiLevelTilingWideVector,  # type: ignore # pylint: disable=no-member
+            structure,
+            vector_length_in_bits,
+            max_innermost_factor,
+            reuse_read.as_dict() if reuse_read is not None else None,
+            reuse_write.as_dict() if reuse_write is not None else None,
+        )

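A minimal usage sketch for the class above (not from the commit; the 1024-bit figure matches the Hexagon test added at the end of this commit, and the reuse_write value is an illustrative ReuseType, not a recommended setting):

from tvm import meta_schedule as ms

# Illustrative parameters: 1024-bit vectors as in the Hexagon test below.
# The reuse_write configuration is a hypothetical example of ReuseType usage.
rule = ms.schedule_rule.MultiLevelTilingWideVector(
    structure="SSRSRS",
    vector_length_in_bits=1024,
    max_innermost_factor=64,
    reuse_read=None,
    reuse_write=ms.schedule_rule.ReuseType(req="may", levels=[1, 2], scope="global"),
)
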
src/meta_schedule/schedule_rule/multi_level_tiling.cc

Lines changed: 24 additions & 11 deletions
@@ -166,6 +166,17 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
   return results;
 }

+Array<tir::LoopRV> MultiLevelTilingNode::SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop,
+                                                   int n_tiles) const {
+  Array<tir::ExprRV> factors = sch->SamplePerfectTile(
+      /*loop=*/loop,
+      /*n=*/n_tiles,
+      /*max_innermost_factor=*/max_innermost_factor);
+  Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop,
+                                         /*factors=*/{factors.begin(), factors.end()});
+  return splits;
+}
+
 std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
   Schedule& sch = state->sch;
   const BlockRV& block_rv = state->block_rv;
@@ -179,6 +190,7 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
   for (int i = 0, n = loops.size(); i < n; ++i) {
     LoopRV loop = loops[i];
     const std::vector<int>* idx = nullptr;
+
     if (iter_types[i] == IterVarType::kDataPar) {
       idx = &s_indices_;
       if (spatial_loop_product != -1) {
@@ -193,17 +205,18 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
     } else {
       continue;
     }
-    // Do the split
-    int n_tiles = idx->size();
-    Array<tir::ExprRV> factors = sch->SamplePerfectTile(
-        /*loop=*/loop,
-        /*n=*/n_tiles,
-        /*max_innermost_factor=*/max_innermost_factor);
-    Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop,
-                                           /*factors=*/{factors.begin(), factors.end()});
-    // Put every tile to its slot
-    for (int j = 0; j < n_tiles; ++j) {
-      tiles[idx->at(j)].push_back(splits[j]);
+
+    const int n_tiles = idx->size();
+
+    if (n_tiles == 1) {
+      tiles[idx->at(0)].push_back(loop);
+    } else {
+      auto splits = SplitLoop(sch, block_rv, loop, n_tiles);
+
+      // Put every tile to its slot
+      for (int j = 0; j < n_tiles; ++j) {
+        tiles[idx->at(j)].push_back(splits[j]);
+      }
     }
   }
   // Step 3. Reorder to organize the tiles

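The refactor above only moves the existing sample-and-split logic into a virtual SplitLoop hook (and skips the split entirely when a slot holds a single tile), so subclasses can override how one loop is tiled. For orientation, the default SplitLoop corresponds to sample-perfect-tile-then-split at the TensorIR schedule level; a minimal sketch on a toy workload (the workload and tile count are illustrative, not from the commit):

from tvm import te, tir

# Toy 1-D workload, only to show what the default SplitLoop amounts to:
# sample a perfect tiling of the loop, then split it into that many tiles.
A = te.placeholder((1024,), name="A")
B = te.compute((1024,), lambda i: A[i] + 1.0, name="B")
sch = tir.Schedule(te.create_prim_func([A, B]))
(i,) = sch.get_loops(sch.get_block("B"))

factors = sch.sample_perfect_tile(i, n=4, max_innermost_factor=64)  # e.g. [4, 8, 8, 4]
i0, i1, i2, i3 = sch.split(i, factors=factors)
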
src/meta_schedule/schedule_rule/multi_level_tiling.h

Lines changed: 3 additions & 0 deletions
@@ -161,6 +161,9 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
  protected:
   virtual std::vector<State> ApplySubRules(std::vector<State> states);

+  virtual Array<tir::LoopRV> SplitLoop(const tir::Schedule& sch, tir::BlockRV block,
+                                       tir::LoopRV loop, int n_tiles) const;
+
   // Annotate a block to use cooperative fetching
   void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const;

src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../../tir/schedule/analysis.h"
+#include "../../tir/schedule/transform.h"
+#include "../utils.h"
+#include "multi_level_tiling.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+using tir::BlockRV;
+using tir::LoopRV;
+using tir::Schedule;
+
+/*!
+ * \brief Extension of MultiLevelTiling for backends with wide vectors.
+ * The loop over the innermost spatial axis of the output buffer is always vectorized with the
+ * maximum vector length.
+ */
+class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode {
+ public:
+  size_t vector_length_in_bits;
+
+  static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWideVector";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWideVectorNode, MultiLevelTilingNode);
+
+ protected:
+  Array<tir::LoopRV> SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const;
+};
+
+Array<tir::LoopRV> MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, BlockRV block_rv,
+                                                             LoopRV loop_rv, int n_tiles) const {
+  const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv));
+  const tir::StmtSRef block_sref = sch->GetSRef(block_rv);
+  const tir::BlockNode* block_node = block_sref->StmtAs<tir::BlockNode>();
+  const tir::BlockRealize block_realize = tir::GetBlockRealize(sch->state(), block_sref);
+  ICHECK(block_node && block_node->writes.size() == 1);
+
+  const auto out_dtype = block_node->writes[0]->buffer->dtype;
+  const int vec_len = vector_length_in_bits / out_dtype.bits();
+
+  // Determine if this loop is over the innermost axis of the output buffer.
+  // In the example below, we look for a loop whose loop var is bound to the axis co.
+
+  // for (i0, 0, 1) {
+  //   for (i1, 0, 56) {
+  //     for (i2, 0, 56) {
+  //       for (i3, 0, 64) {
+  //         for (i4, 0, 3) {
+  //           for (i5, 0, 3) {
+  //             for (i6, 0, 64) {
+  //               block conv2d_nhwc(...) {
+  //                 ...
+  //                 bind(co, i3)
+  //                 ...
+  //                 writes([conv2d_nhwc[n, h, w, co]])
+  //                 ...
+  //                 conv2d_nhwc[n, h, w, co] = ...
+  //               }
+  const size_t innermost_axis = block_node->writes[0]->region.size() - 1;
+  const PrimExpr innermost_iter_value = block_realize->iter_values[innermost_axis];
+
+  if (!arith::Analyzer().CanProve(loop->loop_var == innermost_iter_value)) {
+    // If this is not the innermost spatial loop, split the loop in the normal way.
+    return MultiLevelTilingNode::SplitLoop(sch, block_rv, loop_rv, n_tiles);
+  } else {
+    // We split the innermost spatial loop in a way that always uses the maximum vector length.
+    const int64_t* extent_int = tir::GetLoopIntExtent(loop);
+    if (extent_int && *extent_int > vec_len) {
+      Array<tir::LoopRV> inner_splits = sch->Split(/*loop=*/loop_rv,
+                                                   /*factors=*/{NullOpt, PrimExpr(vec_len)});
+      Array<tir::ExprRV> outer_factors = sch->SamplePerfectTile(
+          /*loop=*/inner_splits[0],
+          /*n=*/n_tiles - 1,
+          /*max_innermost_factor=*/max_innermost_factor);
+      Array<tir::LoopRV> outer_splits = sch->Split(
+          /*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()});
+      outer_splits.push_back(inner_splits[1]);
+      return outer_splits;
+    } else {
+      Array<tir::ExprRV> factors(n_tiles - 1, PrimExpr(1));
+      factors.push_back(loop->extent);
+      return sch->Split(/*loop=*/loop_rv,
+                        /*factors=*/{factors.begin(), factors.end()});
+    }
+  }
+}
+
+ScheduleRule ScheduleRule::MultiLevelTilingWideVector(
+    String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
+    Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
+  auto node = MultiLevelTilingInitCommon<MultiLevelTilingWideVectorNode>(
+      structure, NullOpt, max_innermost_factor, NullOpt, reuse_read, reuse_write);
+  node->vector_length_in_bits = vector_length_in_bits->value;
+  return ScheduleRule(node);
+}
+
+TVM_REGISTER_NODE_TYPE(MultiLevelTilingWideVectorNode);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWideVector")
+    .set_body_typed(ScheduleRule::MultiLevelTilingWideVector);
+
+}  // namespace meta_schedule
+}  // namespace tvm

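At the schedule level, the override above first fixes an inner loop of vec_len elements and only then perfect-tiles the remaining outer extent; when the extent does not exceed vec_len, the factors are fixed to [1, ..., 1, extent] with no sampling. A minimal sketch assuming 1024-bit vectors, float16 output, and a loop extent of 256 (the workload is illustrative, not from the commit):

from tvm import te, tir

# Assumed numbers: 1024-bit vectors and float16 output give 64 lanes.
vector_length_in_bits = 1024
out_dtype_bits = 16
vec_len = vector_length_in_bits // out_dtype_bits  # 64

A = te.placeholder((256,), dtype="float16", name="A")
B = te.compute((256,), lambda i: A[i] + A[i], name="B")
sch = tir.Schedule(te.create_prim_func([A, B]))
(i,) = sch.get_loops(sch.get_block("B"))

# Step 1: fix the innermost split to the full vector length (extent 256 > 64).
outer, inner = sch.split(i, factors=[None, vec_len])
# Step 2: sample-tile only the outer remainder (extent 4) into the remaining tiles.
outer_factors = sch.sample_perfect_tile(outer, n=2)
o0, o1 = sch.split(outer, factors=outer_factors)
# The tiles handed back to MultiLevelTiling are then [o0, o1, inner], with inner always vec_len.
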
tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py

Lines changed: 107 additions & 1 deletion
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 from tvm import meta_schedule as ms
-from tvm import te
+from tvm import te, target
 from tvm.meta_schedule.testing import te_workload
 from tvm.meta_schedule.testing.schedule_rule import get_rules
 from tvm.meta_schedule.testing.space_generation import check_sketches
@@ -521,9 +521,115 @@ def sum_with_trivial_block_iter(
     assert not sch.trace.simplified(remove_postproc=True).insts


+def test_multi_level_tiling_hexagon():
+    @T.prim_func
+    def cpu_conv2d_nhwc(
+        inputs: T.Buffer[(1, 56, 56, 64), "float16"],
+        weight: T.Buffer[(3, 3, 64, 64), "float16"],
+        conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float16"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        PadInput = T.alloc_buffer([1, 58, 58, 64], dtype="float16")
+        for i0, i1, i2, i3 in T.grid(1, 58, 58, 64):
+            with T.block("PadInput"):
+                i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1])
+                T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
+                PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
+                    1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57,
+                    inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1],
+                    T.float16(0),
+                    dtype="float16",
+                )
+        for (
+            i0_0,
+            i1_0,
+            i2_0,
+            i3_0,
+            i4_0,
+            i5_0,
+            i6_0,
+            i0_1_1,
+            i1_1_1,
+            i2_1_1,
+            i3_1_1,
+            i4_1,
+            i5_1,
+            i6_1,
+            i0_2,
+            i1_2,
+            i2_2,
+            i3_2,
+        ) in T.grid(1, 1, 2, 1, 3, 3, 16, 1, 14, 2, 1, 1, 1, 4, 1, 4, 14, 64):
+            with T.block("conv2d_nhwc"):
+                n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_0)
+                h = T.axis.spatial(56, i1_0 * 56 + i1_1_1 * 4 + i1_2)
+                w = T.axis.spatial(56, i2_0 * 28 + i2_1_1 * 14 + i2_2)
+                co = T.axis.spatial(64, i3_0 * 64 + i3_1_1 * 64 + i3_2)
+                rh = T.axis.reduce(3, i4_1 + i4_0)
+                rw = T.axis.reduce(3, i5_0 + i5_1)
+                rc = T.axis.reduce(64, i6_0 * 4 + i6_1)
+                T.reads(PadInput[n, h + rh, w + rw, co // 64 * 64 + rc], weight[rh, rw, rc, co])
+                T.writes(conv2d_nhwc[n, h, w, co])
+                T.block_attr({"meta_schedule.tiling_structure": "SRSRS"})
+                with T.init():
+                    conv2d_nhwc[n, h, w, co] = T.float16(0)
+                conv2d_nhwc[n, h, w, co] = (
+                    conv2d_nhwc[n, h, w, co]
+                    + PadInput[n, h + rh, w + rw, co // 64 * 64 + rc] * weight[rh, rw, rc, co]
+                )
+
+    target_hexagon = target.hexagon("v69", num_cores=4)
+
+    I = 64
+    O = 64
+    H = 56
+    W = 56
+
+    mod = te.create_prim_func(
+        te_workload.conv2d_nhwc(1, H, W, I, O, 3, 1, 1, 1, in_dtype="float16", out_dtype="float16")
+    )
+
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target(target_hexagon, host=target_hexagon),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[
+            ms.schedule_rule.MultiLevelTilingWideVector(
+                structure="SRSRS",
+                vector_length_in_bits=1024,
+                max_innermost_factor=64,
+                reuse_read=None,
+                reuse_write=None,
+            )
+        ],
+        task_name="test",
+    ).generate_design_space()
+
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 1]),
+        ("SamplePerfectTile", [1, 14, 4]),
+        ("SamplePerfectTile", [2, 2, 14]),
+        ("SamplePerfectTile", [3, 1]),
+        ("SamplePerfectTile", [3, 1]),
+        ("SamplePerfectTile", [16, 4]),
+    ]
+
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[cpu_conv2d_nhwc],
+        expected_decisions=[decision_0],
+    )
+
+
 if __name__ == "__main__":
     test_cpu_matmul()
     test_cpu_matmul_relu()
     test_cuda_matmul()
     test_cuda_matmul_relu()
     test_cuda_sum_with_trivial_block_iter()
+    test_multi_level_tiling_hexagon()

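A reading of the expected output above (interpretation, not part of the commit): with structure="SRSRS" each spatial loop is split into three tiles and each reduction loop into two, so the first three SamplePerfectTile decisions are the n/h/w splits ([1, 1, 1], [1, 14, 4], [2, 2, 14]) and the last three are the rh/rw/rc splits ([3, 1], [3, 1], [16, 4]). The co axis contributes no decision because it is the innermost spatial axis of the output buffer, which the wide-vector rule splits deterministically; that is why the expected TIR binds co = i3_0 * 64 + i3_1_1 * 64 + i3_2 with both outer loops of extent 1. A small sketch of the arithmetic, using the numbers from the test above:

# Why co is never sampled in decision_0 (numbers taken from the test above):
vec_len = 1024 // 16      # 64 float16 lanes per 1024-bit vector
co_extent = 64            # innermost spatial extent of conv2d_nhwc
assert not co_extent > vec_len
# Since the extent does not exceed vec_len, SplitLoop returns the fixed factors
# [1, 1, 64] instead of calling SamplePerfectTile, so nothing is recorded in decision_0.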