@@ -47,6 +47,90 @@ namespace topi {
4747using namespace tvm ::te;
4848using namespace topi ::detail;
4949
50+ /* !
51+ * \brief Creates an operation to slide a window over the input x.
52+ *
53+ * \param x The input tensor.
54+ * \param axis What axis the window begins sliding over. Window will be slid
55+ * over this axis and all following axes. The axis value determines the window
56+ * shape (and thus, the number of strides): window shape and strides must both
57+ * be of length `data.ndim-axis`.
58+ * \param window_shape The window shape to form over the input. Window shape
59+ * must be of length `data.ndim-axis`.
60+ * \param strides How to stride the window along each dimension. Strides must be
61+ * of length `data.ndim-axis`.
62+ * \param name The name of the operation
63+ * \param tag The tag to mark the operation
64+ *
65+ * \return A Tensor whose op member is the sliding_window operation
66+ */
67+ inline Tensor sliding_window (const Tensor& x, int axis, Array<Integer> window_shape,
68+ Array<Integer> strides, std::string name = " T_sliding_window" ,
69+ std::string tag = " " ) {
70+ CHECK_GE (axis, 0 );
71+ auto _axis = size_t (axis);
72+ CHECK_LT (_axis, x->shape .size ()) << " axis must be a valid dimension index of x." ;
73+ CHECK_EQ (x->shape .size () - _axis, window_shape.size ())
74+ << " There must be a window shape for every dimension of x "
75+ << " over which we are sliding the window." ;
76+ CHECK_EQ (strides.size (), window_shape.size ()) << " Windows and strides should be the same length." ;
77+
78+ // Compute the new shape.
79+ Array<PrimExpr> new_shape;
80+ // Dimensions up until `axis` remain the same.
81+ for (size_t i = 0 ; i < _axis; ++i) {
82+ new_shape.push_back (x->shape [i]);
83+ }
84+
85+ // New dimensions which result from sliding the window in each dimension. One new dimension per
86+ // window dimension.
87+ for (size_t i = 0 ; i < window_shape.size (); ++i) {
88+ // Length of the shape along this dimension.
89+ auto dim_len = x->shape [_axis + i];
90+ // Length of the window along this dimension.
91+ auto window_len = window_shape[i];
92+ // Strides along this dimension.
93+ auto stride = strides[i];
94+
95+ new_shape.push_back (floordiv (dim_len - (window_len - 1 ) + stride - 1 , stride));
96+ }
97+
98+ // Dimensions comprising the window.
99+ for (size_t i = 0 ; i < window_shape.size (); ++i) {
100+ new_shape.push_back (window_shape[i]);
101+ }
102+
103+ ICHECK (new_shape.size () == _axis + 2 * window_shape.size ());
104+
105+ return compute (
106+ new_shape,
107+ [&](const Array<Var>& indices) {
108+ // The index at which to index the old tensor x.
109+ Array<PrimExpr> idx;
110+
111+ // Dimensions up until `axis` remain the same.
112+ for (size_t i = 0 ; i < _axis; ++i) {
113+ idx.push_back (indices[i]);
114+ }
115+
116+ for (size_t i = 0 ; i < window_shape.size (); ++i) {
117+ // Which window in this dimension we are indexing.
118+ auto window_idx = indices[_axis + i];
119+ // Which index within the window we are indexing.
120+ auto idx_within_window = indices[_axis + window_shape.size () + i];
121+ // Stride value for this dimension.
122+ auto stride = strides[i];
123+
124+ idx.push_back (window_idx * stride + idx_within_window);
125+ }
126+
127+ ICHECK (idx.size () == x->shape .size ());
128+
129+ return x (idx);
130+ },
131+ name, tag);
132+ }
133+
50134/* !
51135 * \brief Creates an operation to insert new dimensions of length 1
52136 *
0 commit comments