diff --git a/src/core.py b/src/core.py index 91baa4b..9557efd 100644 --- a/src/core.py +++ b/src/core.py @@ -90,7 +90,7 @@ def assert_img_range(img): def decategorize(mask): return iu.decategorize(mask, iu.rgb2wk_map) - with tf.Session() as sess: + with tf.compat.v1.Session() as sess: snet_in = consts.snet_in('0.1.0', sess) snet_out = consts.snet_out('0.1.0', sess) def snet(img): @@ -169,7 +169,7 @@ def inpainted(image, segmap): ''' assert (255 >= image).all(), image.max() assert (image >= 0).all(), image.min() - with tf.Session() as sess: + with tf.compat.v1.Session() as sess: cnet_in = consts.cnet_in('0.1.0',sess) cnet_out = consts.cnet_out('0.1.0',sess) return inpaint(