Skip to content

Commit 9160dc4

Browse files
authored
Add sliding_window operator (#9816)

* Add windows operator
* remove TODO
* Convert ICHECKs to CHECKs
* Report errors using diagnostic context
* Use more readable CHECKs
* Remove example; move comments to test
* Revert "Remove example; move comments to test"
  (This is a partial revert of commit c810c2db7637ce9537adc49d1016caddd5093d3a.)
* Add newline to fix Sphinx error
* windows -> sliding_window
* whitespace
* fmt
1 parent 44fe7ef commit 9160dc4

File tree

9 files changed

+327
-0
lines changed

9 files changed

+327
-0
lines changed

include/tvm/relay/attrs/transform.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,29 @@
3333
namespace tvm {
3434
namespace relay {
3535

36+
/*! \brief Attributes used for the sliding_window operator */
37+
struct SlidingWindowAttrs : public tvm::AttrsNode<SlidingWindowAttrs> {
38+
int axis;
39+
Array<Integer> window_shape;
40+
Array<Integer> strides;
41+
TVM_DECLARE_ATTRS(SlidingWindowAttrs, "relay.attrs.SlidingWindowAttrs") {
42+
TVM_ATTR_FIELD(axis).describe(
43+
"What axis the sliding window begin forming over."
44+
"Window will be slid over this axis and all following axes."
45+
"The axis value determines the window shape (and thus, the"
46+
"number of strides):"
47+
"window shape and strides must both be of length"
48+
"`data.ndim-axis`.");
49+
TVM_ATTR_FIELD(window_shape)
50+
.describe(
51+
"The window shape to form over the input."
52+
"Window shape must be of length `data.ndim-axis`.");
53+
TVM_ATTR_FIELD(strides).describe(
54+
"How to stride the window along each dimension."
55+
"Strides must be of length `data.ndim-axis`.");
56+
}
57+
}; // struct SlidingWindowAttrs
58+
3659
/*! \brief data type cast */
3760
struct CastAttrs : public tvm::AttrsNode<CastAttrs> {
3861
DataType dtype;

include/tvm/topi/transform.h

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,90 @@ namespace topi {
4747
using namespace tvm::te;
4848
using namespace topi::detail;
4949

50+
/*!
51+
* \brief Creates an operation to slide a window over the input x.
52+
*
53+
* \param x The input tensor.
54+
* \param axis What axis the window begins sliding over. Window will be slid
55+
* over this axis and all following axes. The axis value determines the window
56+
* shape (and thus, the number of strides): window shape and strides must both
57+
* be of length `data.ndim-axis`.
58+
* \param window_shape The window shape to form over the input. Window shape
59+
* must be of length `data.ndim-axis`.
60+
* \param strides How to stride the window along each dimension. Strides must be
61+
* of length `data.ndim-axis`.
62+
* \param name The name of the operation
63+
* \param tag The tag to mark the operation
64+
*
65+
* \return A Tensor whose op member is the sliding_window operation
66+
*/
67+
inline Tensor sliding_window(const Tensor& x, int axis, Array<Integer> window_shape,
68+
Array<Integer> strides, std::string name = "T_sliding_window",
69+
std::string tag = "") {
70+
CHECK_GE(axis, 0);
71+
auto _axis = size_t(axis);
72+
CHECK_LT(_axis, x->shape.size()) << "axis must be a valid dimension index of x.";
73+
CHECK_EQ(x->shape.size() - _axis, window_shape.size())
74+
<< "There must be a window shape for every dimension of x "
75+
<< "over which we are sliding the window.";
76+
CHECK_EQ(strides.size(), window_shape.size()) << "Windows and strides should be the same length.";
77+
78+
// Compute the new shape.
79+
Array<PrimExpr> new_shape;
80+
// Dimensions up until `axis` remain the same.
81+
for (size_t i = 0; i < _axis; ++i) {
82+
new_shape.push_back(x->shape[i]);
83+
}
84+
85+
// New dimensions which result from sliding the window in each dimension. One new dimension per
86+
// window dimension.
87+
for (size_t i = 0; i < window_shape.size(); ++i) {
88+
// Length of the shape along this dimension.
89+
auto dim_len = x->shape[_axis + i];
90+
// Length of the window along this dimension.
91+
auto window_len = window_shape[i];
92+
// Strides along this dimension.
93+
auto stride = strides[i];
94+
95+
new_shape.push_back(floordiv(dim_len - (window_len - 1) + stride - 1, stride));
96+
}
97+
98+
// Dimensions comprising the window.
99+
for (size_t i = 0; i < window_shape.size(); ++i) {
100+
new_shape.push_back(window_shape[i]);
101+
}
102+
103+
ICHECK(new_shape.size() == _axis + 2 * window_shape.size());
104+
105+
return compute(
106+
new_shape,
107+
[&](const Array<Var>& indices) {
108+
// The index at which to index the old tensor x.
109+
Array<PrimExpr> idx;
110+
111+
// Dimensions up until `axis` remain the same.
112+
for (size_t i = 0; i < _axis; ++i) {
113+
idx.push_back(indices[i]);
114+
}
115+
116+
for (size_t i = 0; i < window_shape.size(); ++i) {
117+
// Which window in this dimension we are indexing.
118+
auto window_idx = indices[_axis + i];
119+
// Which index within the window we are indexing.
120+
auto idx_within_window = indices[_axis + window_shape.size() + i];
121+
// Stride value for this dimension.
122+
auto stride = strides[i];
123+
124+
idx.push_back(window_idx * stride + idx_within_window);
125+
}
126+
127+
ICHECK(idx.size() == x->shape.size());
128+
129+
return x(idx);
130+
},
131+
name, tag);
132+
}
133+
50134
/*!
51135
* \brief Creates an operation to insert new dimensions of length 1
52136
*

python/tvm/relay/op/_transform.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,15 @@
7070
# concatenate
7171
_reg.register_schedule("concatenate", strategy.schedule_concatenate)
7272

73+
# sliding_window
@_reg.register_compute("sliding_window")
def compute_sliding_window(attrs, inputs, output_type):
    """Compute definition of sliding_window"""
    data = inputs[0]
    return [topi.sliding_window(data, attrs.axis, attrs.window_shape, attrs.strides)]


_reg.register_strategy("sliding_window", strategy.sliding_window_strategy)
81+
7382
# strided_set
7483
@_reg.register_compute("strided_set")
7584
def compute_strided_set(attrs, inputs, output_type):

python/tvm/relay/op/strategy/generic.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1705,6 +1705,28 @@ def uniform_strategy(attrs, inputs, out_type, target):
17051705
return strategy
17061706

17071707

1708+
# sliding_window
def wrap_compute_sliding_window():
    """Wrap sliding_window topi compute"""

    def _compute_sliding_window(attrs, inputs, _):
        data = inputs[0]
        return [topi.sliding_window(data, attrs.axis, attrs.window_shape, attrs.strides)]

    return _compute_sliding_window


@override_native_generic_func("sliding_window_strategy")
def sliding_window_strategy(attrs, inputs, out_type, target):
    """sliding_window generic strategy"""
    # Scheduled as an extern op: the compute defines the full output and no
    # target-specific schedule is provided.
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_sliding_window(),
        wrap_topi_schedule(topi.generic.schedule_extern),
        name="sliding_window.generic",
    )
    return strategy
1728+
1729+
17081730
@override_native_generic_func("normal_strategy")
17091731
def normal_strategy(attrs, inputs, out_type, target):
17101732
"""normal generic strategy"""

python/tvm/relay/op/transform.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,64 @@
2525
from .tensor import shape_of
2626

2727

28+
def sliding_window(data, axis, window_shape, strides):
    """Slide a window over the data tensor.

    Parameters
    ----------
    data : relay.Expr
        The input data to the operator.

    axis : int
        What axis the window begins sliding over. Window will be slid over
        this axis and all following axes. The axis value determines the window
        shape (and thus, the number of strides): window shape and strides must
        both be of length `data.ndim-axis`.

    window_shape : List[int]
        The window shape to form over the input. Window shape must be of length
        `data.ndim-axis`.

    strides : List[int]
        How to stride the window along each dimension. Strides must be of length
        `data.ndim-axis`.

    Returns
    -------
    result : relay.Expr
        The resulting tensor.

    Examples
    --------
    .. code-block:: python

        # Slide a window of shape (3, 4, 5) over the x tensor, beginning with
        # dimension 1, which slides the window over the two subtensors of
        # shape (3, 32, 32).
        x = relay.var("x", relay.TensorType((2, 3, 32, 32), "float32"))
        y = relay.sliding_window(x, 1, [3, 4, 5], [1, 2, 3])

        data = np.random.rand(2, 3, 32, 32).astype("float32")
        result = create_executor().evaluate(y, {x: relay.const(data)}).numpy()

        # The resulting shape still has batch size 2. Each dimension in
        # (1, 15, 10) represents the locations where we were able to
        # form a window; that is, we were able to place the window
        # in one place along the dimension of length 3, 15 places along
        # the dimension of length 32 (when striding by 2), and 10 places
        # along the second dimension of length 32 (when striding by 3).
        # The remaining dimension (3, 4, 5) represent the formed windows.
        assert result.shape == (2, 1, 15, 10, 3, 4, 5)

        assert np.array_equal(result[0, 0, 0, 0, :, :, :], data[0, :, 0:4, 0:5])
        assert np.array_equal(result[1, 0, 7, 3, :, :, :], data[1, :, 14:18, 9:14])
        assert np.array_equal(result[1, 0, 14, 9, :, :, :], data[1, :, 28:32, 27:32])
    """
    # The op is registered on the global "relay.ir.*" namespace, so it is
    # reached through relay's top-level _ffi_api rather than this module's
    # _make; imported locally to avoid a circular import at module load.
    from .. import _ffi_api

    return _ffi_api.sliding_window(data, axis, window_shape, strides)
84+
85+
2886
def cast(data, dtype):
2987
"""Cast input tensor to data type.
3088

python/tvm/topi/transform.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,3 +971,33 @@ def invert_permutation(data):
971971
r_ind = data[ind]
972972
result[r_ind] = ind
973973
return result
974+
975+
976+
def sliding_window(data, axis, window_shape, strides):
    """Slide a window over the data tensor.

    Parameters
    ----------
    data : tvm.te.Tensor
        The input data to the operator.

    axis : int
        What axis the window begins sliding over. Window will be slid over
        this axis and all following axes. The axis value determines the window
        shape (and thus, the number of strides): window shape and strides must
        both be of length `data.ndim-axis`.

    window_shape : List[int]
        The window shape to form over the input. Window shape must be of length
        `data.ndim-axis`.

    strides : List[int]
        How to stride the window along each dimension. Strides must be of length
        `data.ndim-axis`.

    Returns
    -------
    result : tvm.te.Tensor
        The resulting tensor.
    """
    # Thin wrapper over the C++ topi implementation (topi::sliding_window),
    # which operates on te.Tensor — not relay.Expr as the relay-level wrapper
    # does.
    return cpp.sliding_window(data, axis, window_shape, strides)

src/relay/op/tensor/transform.cc

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,77 @@ namespace tvm {
5252
namespace relay {
5353
using tir::IntImmNode;
5454

55+
TVM_REGISTER_NODE_TYPE(SlidingWindowAttrs);
56+
57+
bool SlidingWindowRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
58+
const TypeReporter& reporter) {
59+
// `types` contains: [data, result]
60+
ICHECK_EQ(types.size(), 2);
61+
const auto* data = types[0].as<TensorTypeNode>();
62+
if (data == nullptr) {
63+
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
64+
<< "SlidingWindow operator expects input to be of TensorType "
65+
<< "but got " << PrettyPrint(types[0]));
66+
return false;
67+
}
68+
const auto* param = attrs.as<SlidingWindowAttrs>();
69+
const int axis = param->axis;
70+
71+
std::vector<IndexExpr> oshape;
72+
73+
// Dimensions up until `axis` remain the same.
74+
for (int i = 0; i < axis; ++i) {
75+
oshape.emplace_back(data->shape[i]);
76+
}
77+
78+
// New dimensions which result from sliding the window in each dimension. One new dimension per
79+
// window dimension.
80+
for (size_t i = 0; i < param->window_shape.size(); ++i) {
81+
// Length of the shape along this dimension.
82+
auto dim_len = data->shape[axis + i];
83+
// Length of the window along this dimension.
84+
auto window_len = param->window_shape[i];
85+
// Strides along this dimension.
86+
auto stride = param->strides[i];
87+
88+
oshape.push_back(floordiv(dim_len - (window_len - 1) + stride - 1, stride));
89+
}
90+
91+
// Dimensions comprising the window.
92+
for (size_t i = 0; i < param->window_shape.size(); ++i) {
93+
oshape.push_back(param->window_shape[i]);
94+
}
95+
96+
reporter->Assign(types[1], TensorType(oshape, data->dtype));
97+
return true;
98+
}
99+
100+
/*! \brief Lower a sliding_window call to its topi compute. */
Array<te::Tensor> SlidingWindowCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                       const Type& out_type) {
  const auto* param = attrs.as<SlidingWindowAttrs>();
  ICHECK(param != nullptr);
  return Array<te::Tensor>{
      topi::sliding_window(inputs[0], param->axis, param->window_shape, param->strides)};
}
106+
107+
Expr MakeSlidingWindow(Expr data, int axis, Array<Integer> window_shape, Array<Integer> strides) {
108+
auto attrs = make_object<SlidingWindowAttrs>();
109+
attrs->axis = axis;
110+
attrs->window_shape = window_shape;
111+
attrs->strides = strides;
112+
static const Op& op = Op::Get("sliding_window");
113+
return Call(op, {data}, Attrs(attrs), {});
114+
}
115+
116+
TVM_REGISTER_GLOBAL("relay.ir.sliding_window").set_body_typed(MakeSlidingWindow);
117+
118+
RELAY_REGISTER_OP("sliding_window")
119+
.describe(R"code(Slide window over a tensor.)code" TVM_ADD_FILELINE)
120+
.set_num_inputs(1)
121+
.set_attrs_type<SlidingWindowAttrs>()
122+
.add_argument("data", "Tensor", "The input tensor.")
123+
.add_type_rel("SlidingWindow", SlidingWindowRel)
124+
.set_attr<TOpPattern>("TOpPattern", kOpaque);
125+
55126
// relay.cast
56127
TVM_REGISTER_NODE_TYPE(CastAttrs);
57128

src/topi/transform.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ TVM_REGISTER_GLOBAL("topi.reshape").set_body([](TVMArgs args, TVMRetValue* rv) {
5454
*rv = reshape(args[0], args[1]);
5555
});
5656

57+
// FFI wrapper exposing topi::sliding_window as "topi.sliding_window".
TVM_REGISTER_GLOBAL("topi.sliding_window").set_body([](TVMArgs args, TVMRetValue* rv) {
  // args: (data, axis, window_shape, strides)
  *rv = sliding_window(args[0], args[1], args[2], args[3]);
});
60+
5761
TVM_REGISTER_GLOBAL("topi.squeeze").set_body([](TVMArgs args, TVMRetValue* rv) {
5862
*rv = squeeze(args[0], ArrayOrInt(args[1]));
5963
});

tests/python/relay/test_op_level3.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,32 @@ def test_cast():
9191
assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")
9292

9393

94+
def test_sliding_window():
    # Slide a window of shape (3, 4, 5) over the x tensor, beginning with
    # dimension 1, which slides the window over the two subtensors of shape
    # (3, 32, 32).
    x = relay.var("x", relay.TensorType((2, 3, 32, 32), "float32"))
    y = relay.sliding_window(x, 1, [3, 4, 5], [1, 2, 3])

    # The resulting shape still has batch size 2. Each dimension in (1, 15, 10)
    # represents the locations where we were able to form a window; that is, we
    # were able to place the window in one place along the dimension of length
    # 3, 15 places along the dimension of length 32 (when striding by 2), and 10
    # places along the second dimension of length 32 (when striding by 3). The
    # remaining dimensions (3, 4, 5) represent the formed windows.
    yy = run_infer_type(y)
    assert yy.checked_type == relay.TensorType((2, 1, 15, 10, 3, 4, 5), "float32")

    data = np.random.rand(2, 3, 32, 32).astype("float32")
    result_np = create_executor().evaluate(y, {x: relay.const(data)}).numpy()
    assert result_np.shape == (2, 1, 15, 10, 3, 4, 5)

    # Spot-check windows against the equivalent numpy slices: window (b, i, j, k)
    # covers data[b, i:i+3, 2*j:2*j+4, 3*k:3*k+5].
    for b, i, j, k in [(0, 0, 0, 0), (1, 0, 7, 3), (1, 0, 14, 9)]:
        expected = data[b, i : i + 3, 2 * j : 2 * j + 4, 3 * k : 3 * k + 5]
        assert np.array_equal(result_np[b, i, j, k, :, :, :], expected)
118+
119+
94120
def test_clip():
95121
a = relay.var("a", relay.TensorType((10, 4), "float32"))
96122
y = relay.clip(a, 1.0, 4.0)

0 commit comments

Comments
 (0)