
Commit 9049d66

apivovarov authored and kevinthesun committed
[Relay][Legalize] Legalize conv2d_transpose for NHWC (#4399)
1 parent 87bd799 commit 9049d66

File tree

6 files changed: +165, -5 lines changed

python/tvm/relay/op/nn/_nn.py

Lines changed: 20 additions & 0 deletions
@@ -278,6 +278,26 @@ def schedule_conv2d_transpose(attrs, outs, target):
     return topi.generic.schedule_conv2d_transpose_nchw(outs)
 
 
+@reg.register_legalize("nn.conv2d_transpose")
+def legalize_conv2d_transpose(attrs, inputs, types):
+    """Legalize conv2d_transpose op.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current Transposed convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    return topi.nn.conv2d_transpose_legalize(attrs, inputs, types)
+
 reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # bias_add
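
For context, a hedged sketch (not part of the diff) of how this hook gets exercised: running the Relay Legalize pass on an NHWC conv2d_transpose dispatches to the function registered above, which in turn calls topi.nn.conv2d_transpose_legalize and rewrites the op into NCHW form. The snippet assumes a TVM build from around this commit, where relay.Module and relay.transform.Legalize are available; it mirrors the shapes used in the new test.

```python
from tvm import relay

# NHWC data; kernel left untyped, as in the new test, so type inference
# fills in its shape from channels/kernel_size.
x = relay.var("x", shape=(1, 18, 18, 3), dtype="float32")
w = relay.var("w")
y = relay.nn.conv2d_transpose(x, w, channels=10, kernel_size=(3, 3),
                              strides=(2, 2), padding=(1, 1),
                              data_layout="NHWC", kernel_layout="HWIO")
mod = relay.Module.from_expr(relay.Function([x, w], y))

# The Legalize pass calls the registered hook and rewrites the op:
# transposes around an NCHW conv2d_transpose.
mod = relay.transform.Legalize()(mod)
print(mod)
```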

python/tvm/relay/op/op_attrs.py

Lines changed: 5 additions & 0 deletions
@@ -284,3 +284,8 @@ class BinaryConv2DAttrs(Attrs):
 @register_relay_attr_node
 class BinaryDenseAttrs(Attrs):
     """Attributes used in bitserial dense operators"""
+
+
+@register_relay_attr_node
+class Conv2DTransposeAttrs(Attrs):
+    """Attributes used in Transposed Conv2D operators"""

tests/python/relay/test_op_level2.py

Lines changed: 32 additions & 4 deletions
@@ -311,8 +311,8 @@ def test_conv2d_transpose_infer_type():
         (10, 15, 3, 3), "float32")
 
     # infer by shape of w, mixed precision
-    n, c, h, w = tvm.var("n"), 10, 10, 12
-    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
+    n, h, w, c = tvm.var("n"), 10, 10, 12
+    x = relay.var("x", relay.TensorType((n, h, w, c), "float32"))
     w = relay.var("w", relay.TensorType((12, 11, 5, 5), "float32"))
     y = relay.nn.conv2d_transpose(x, w,
                                   output_padding=(1, 1),
@@ -323,7 +323,7 @@ def test_conv2d_transpose_infer_type():
         (n, 15, 15, 11), "float32")
 
 
-def test_conv2d_transpose_run():
+def test_conv2d_transpose_nchw_run():
     dshape = (1, 3, 18, 18)
     kshape = (3, 10, 3, 3)
     oshape = (1, 10, 37, 37)
@@ -348,6 +348,33 @@ def test_conv2d_transpose_run():
         tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
 
 
+def test_conv2d_transpose_nhwc_run():
+    dshape_nhwc = (1, 18, 18, 3)
+    kshape_hwoi = (3, 3, 10, 3)
+    oshape_nhwc = (1, 37, 37, 10)
+    x = relay.var("x", shape=dshape_nhwc)
+    w = relay.var("w")
+    # kshape and kernel_layout should have swapped IO.
+    # kshape is HWOI and kernel_layout is HWIO
+    y = relay.nn.conv2d_transpose(x, w,
+                                  channels=10, kernel_size=(3, 3), strides=(2, 2),
+                                  padding=(1, 1), output_padding=(2, 2),
+                                  data_layout="NHWC", kernel_layout="HWIO")
+    func = relay.Function([x, w], y)
+    dtype = "float32"
+    data = np.random.uniform(size=dshape_nhwc).astype(dtype)
+    kernel = np.random.uniform(size=kshape_hwoi).astype(dtype)
+    # use true kshape layout here - HWOI
+    c_np = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI', 2, 1)
+    d_np = np.zeros(shape=oshape_nhwc)
+    d_np[:, 0:c_np.shape[1], 0:c_np.shape[2], :] = c_np
+    ref_res = d_np
+
+    for target, ctx in ctx_list():
+        intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+        op_res1 = intrp1.evaluate(func)(data, kernel)
+        tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
 
 def test_upsampling_infer_type():
     n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
@@ -819,7 +846,8 @@ def test_bitpack_infer_type():
     test_pad_infer_type()
     test_pad_run()
     test_conv2d_transpose_infer_type()
-    test_conv2d_transpose_run()
+    test_conv2d_transpose_nchw_run()
+    test_conv2d_transpose_nhwc_run()
     test_conv2d_run()
     test_conv2d_winograd()
     test_bitserial_conv2d_infer_type()
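
One step the new NHWC test relies on without spelling it out: the reference helper returns the result before output_padding is applied, so the test zero-pads c_np into oshape_nhwc. A quick check of the numbers (illustrative, not part of the commit), using the standard transposed-convolution size formula:

```python
# out = (in - 1) * stride - 2 * pad + kernel, then output_padding extends it.
in_size, stride, pad, kernel, out_pad = 18, 2, 1, 3, 2
ref_size = (in_size - 1) * stride - 2 * pad + kernel  # 35: spatial size of c_np
out_size = ref_size + out_pad                          # 37: spatial size in oshape_nhwc
print(ref_size, out_size)
```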

topi/python/topi/nn/conv2d_transpose.py

Lines changed: 60 additions & 0 deletions
@@ -18,6 +18,7 @@
 """Transposed 2D convolution operators (sometimes called Deconvolution)."""
 from __future__ import absolute_import as _abs
 import tvm
+from tvm import relay
 from .dilate import dilate
 from .pad import pad
 from .util import get_pad_tuple
@@ -102,3 +103,62 @@ def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype)
             axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
 
     return Output
+
+
+@tvm.target.generic_func
+def conv2d_transpose_legalize(attrs, inputs, types):
+    """Legalizes Transposed 2D convolution op.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current Transposed 2D convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    if attrs['data_layout'] == 'NHWC':
+        data, kernel = inputs
+        kernel_layout = attrs['kernel_layout']
+        # Convert kernel layout to IOHW.
+        # kernel_layout differs from the actual input kernel layout - IO is swapped.
+        if kernel_layout == 'HWIO':
+            # input kernel layout is swapped to HWOI
+            # output kernel layout will be IOHW
+            kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
+        elif kernel_layout == 'HWOI':
+            # input kernel layout is swapped to HWIO
+            # output kernel layout will be IOHW
+            kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
+        elif kernel_layout == 'IOHW':
+            # input kernel layout is swapped to OIHW
+            # output kernel layout will be IOHW
+            kernel = relay.transpose(kernel, axes=(1, 0, 2, 3))
+        elif kernel_layout == 'OIHW':
+            # input kernel layout is swapped to IOHW
+            # output kernel layout will be IOHW
+            pass
+        else:
+            # Skip legalize. Let relay.nn.conv2d_transpose handle the case.
+            return None
+
+        # Set new attrs for conv2d_transpose.
+        new_attrs = {k: attrs[k] for k in attrs.keys()}
+        new_attrs['data_layout'] = 'NCHW'
+        # layout of kernel should be IOHW, but kernel_layout should be swapped - OIHW
+        new_attrs['kernel_layout'] = 'OIHW'

+        # Convert data to NCHW.
+        data = relay.transpose(data, axes=(0, 3, 1, 2))
+        deconv = relay.nn.conv2d_transpose(data, kernel, **new_attrs)
+        # Convert back to original NHWC layout.
+        out = relay.transpose(deconv, axes=(0, 2, 3, 1))
+        return out
+
+    return None
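
To make the IO-swapping bookkeeping above concrete, here is a hedged, hand-written sketch (not part of the commit) of the expression the legalization produces when data_layout='NHWC' and kernel_layout='HWIO', using the shapes from the new test; the variable names are illustrative only:

```python
from tvm import relay

x = relay.var("x", shape=(1, 18, 18, 3))   # NHWC data
w = relay.var("w", shape=(3, 3, 10, 3))    # kernel values laid out HWOI

data = relay.transpose(x, axes=(0, 3, 1, 2))     # NHWC -> NCHW
kernel = relay.transpose(w, axes=(3, 2, 0, 1))   # HWOI -> IOHW
deconv = relay.nn.conv2d_transpose(data, kernel,
                                   channels=10, kernel_size=(3, 3),
                                   strides=(2, 2), padding=(1, 1),
                                   output_padding=(2, 2),
                                   data_layout="NCHW", kernel_layout="OIHW")
out = relay.transpose(deconv, axes=(0, 2, 3, 1))  # NCHW -> NHWC
```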

topi/python/topi/testing/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 from .conv2d_hwcn_python import conv2d_hwcn_python
 from .conv2d_nchw_python import conv2d_nchw_python
 from .conv2d_nhwc_python import conv2d_nhwc_python
-from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python
+from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python
 from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python
 from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
 from .dilate_python import dilate_python
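
With the rename, both reference implementations remain importable from topi.testing. A minimal usage sketch (assuming a build containing this change); the shapes mirror the new test:

```python
import numpy as np
import topi.testing

data = np.random.uniform(size=(1, 18, 18, 3)).astype("float32")   # NHWC
kernel = np.random.uniform(size=(3, 3, 10, 3)).astype("float32")  # HWOI
ref = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI', 2, 1)
print(ref.shape)  # expected (1, 35, 35, 10), before output_padding is applied
```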

topi/python/topi/testing/conv2d_transpose_nchw_python.py renamed to topi/python/topi/testing/conv2d_transpose_python.py

Lines changed: 47 additions & 0 deletions
@@ -73,3 +73,50 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
                     padded_a_np[n, c], w_np[c, f], mode='valid')
                 b_np[n, f] += out
     return b_np
+
+
+def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding):
+    """Transposed convolution operator in NHWC layout.
+
+    Parameters
+    ----------
+    a_nhwc : numpy.ndarray
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    weight : numpy.ndarray
+        4-D in formats HWIO, HWOI, OIHW or IOHW
+
+    weight_format : str
+        ['HWIO', 'HWOI', 'OIHW', 'IOHW']
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    Returns
+    -------
+    b_np : np.ndarray
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+    assert a_nhwc.ndim == 4, "a_nhwc number of dimensions should be 4"
+    assert weight.ndim == 4, "weight number of dimensions should be 4"
+
+    a_nchw = np.transpose(a_nhwc, (0, 3, 1, 2))
+
+    # conv2d_transpose_nchw_python needs kernel layout to be IOHW
+    if weight_format == 'HWIO':
+        w_iohw = np.transpose(weight, (2, 3, 0, 1))
+    elif weight_format == 'HWOI':
+        w_iohw = np.transpose(weight, (3, 2, 0, 1))
+    elif weight_format == 'OIHW':
+        w_iohw = np.transpose(weight, (1, 0, 2, 3))
+    elif weight_format == 'IOHW':
+        w_iohw = weight
+    else:
+        raise ValueError('Valid weight_formats are HWIO, HWOI, OIHW or IOHW')
+
+    res_nchw = conv2d_transpose_nchw_python(a_nchw, w_iohw, stride, padding)
+    res_nhwc = np.transpose(res_nchw, (0, 2, 3, 1))
+    return res_nhwc
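
An illustrative consistency check (not part of the commit) of what the helper does internally: the NHWC result should match running the NCHW reference on transposed inputs and transposing the output back.

```python
import numpy as np
import topi.testing

a_nhwc = np.random.uniform(size=(1, 18, 18, 3)).astype("float32")
w_hwoi = np.random.uniform(size=(3, 3, 10, 3)).astype("float32")

out_nhwc = topi.testing.conv2d_transpose_nhwc_python(a_nhwc, w_hwoi, 'HWOI', 2, 1)

# Same computation done manually through the NCHW reference:
a_nchw = np.transpose(a_nhwc, (0, 3, 1, 2))  # NHWC -> NCHW
w_iohw = np.transpose(w_hwoi, (3, 2, 0, 1))  # HWOI -> IOHW
out_nchw = topi.testing.conv2d_transpose_nchw_python(a_nchw, w_iohw, 2, 1)

np.testing.assert_allclose(out_nhwc, np.transpose(out_nchw, (0, 2, 3, 1)), rtol=1e-5)
```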
