From 40b85798db48d8885ec84736d448e922c369acb5 Mon Sep 17 00:00:00 2001
From: Dan Moldovan
Date: Thu, 20 Jun 2019 20:15:29 -0700
Subject: [PATCH] Preserve the UNSPECIFIED conversion status when not
 converting by default. Fix the conversion decorator to only apply
 functools.wraps if the target is a function.

PiperOrigin-RevId: 254323896
---
 tensorflow/python/autograph/impl/api.py      | 34 +++++++++++++++-----
 tensorflow/python/autograph/impl/api_test.py | 33 +++++++++++++++++++
 2 files changed, 59 insertions(+), 8 deletions(-)

diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
index 90c261a5ce9a53..689a3a10bd211c 100644
--- a/tensorflow/python/autograph/impl/api.py
+++ b/tensorflow/python/autograph/impl/api.py
@@ -159,14 +159,18 @@ def tf_convert(f, ctx, convert_by_default=True, force_conversion=False):
   f_wrapper = f
   decorators, f = tf_decorator.unwrap(f)
 
-  apply_autograph = ((ctx.status == ag_ctx.Status.ENABLED) or
-                     (convert_by_default and
-                      ctx.status == ag_ctx.Status.UNSPECIFIED))
-  if apply_autograph:
-    # TODO(mdan): Grab features from context.
+  # TODO(mdan): Grab features from context.
+  if ctx.status == ag_ctx.Status.ENABLED:
     wrapper = convert(recursive=True, force_conversion=force_conversion)(f)
-  else:
+  elif ctx.status == ag_ctx.Status.DISABLED:
     wrapper = do_not_convert(f)
+  elif ctx.status == ag_ctx.Status.UNSPECIFIED:
+    if convert_by_default:
+      wrapper = convert(recursive=True, force_conversion=force_conversion)(f)
+    else:
+      wrapper = call_with_unspecified_conversion_status(f)
+  else:
+    raise ValueError(ctx.status)
 
   if decorators:
     wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)
@@ -246,6 +250,19 @@ class RunMode(Enum):
   PY_FUNC = 2
 
 
+def call_with_unspecified_conversion_status(func):
+  """Decorator that resets the conversion context to the unspecified status."""
+  def wrapper(*args, **kwargs):
+    with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED):
+      return func(*args, **kwargs)
+
+  if inspect.isfunction(func) or inspect.ismethod(func):
+    wrapper = functools.update_wrapper(wrapper, func)
+
+  setattr(wrapper, '__ag_compiled', True)
+  return wrapper
+
+
 def do_not_convert_internal(f):
   """Decorator that marks internal functions which do not need conversion."""
   setattr(f, '__ag_compiled', True)
@@ -279,12 +296,10 @@ def do_not_convert(func=None, run_as=RunMode.GRAPH, return_dtypes=None):
         run_as=run_as,
         return_dtypes=return_dtypes)
 
-  @functools.wraps(func)
   def graph_wrapper(*args, **kwargs):
     with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
       return func(*args, **kwargs)
 
-  @functools.wraps(func)
   def py_func_wrapper(*args, **kwargs):
     if kwargs:
       raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
@@ -299,6 +314,9 @@ def py_func_wrapper(*args, **kwargs):
   else:
     raise ValueError('unknown value for run_as: %s' % run_as)
 
+  if inspect.isfunction(func) or inspect.ismethod(func):
+    wrapper = functools.update_wrapper(wrapper, func)
+
   setattr(wrapper, '__ag_compiled', True)
   return wrapper
 
diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
index ea38988f7fc882..0ade2c3094eba1 100644
--- a/tensorflow/python/autograph/impl/api_test.py
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -212,6 +212,16 @@ def test_method(self, x, y):
     self.assertEqual((),
                      tuple(function_utils.fn_args(tc.test_method_whitelisted)))
 
+  def test_do_not_convert_callable_object(self):
+
+    class TestClass(object):
+
+      def __call__(self):
+        return 1
+
+    tc = TestClass()
+    self.assertEqual(1, api.do_not_convert(tc)())
+
   @test_util.run_deprecated_v1
   def test_convert_call_site_decorator(self):
 
@@ -729,6 +739,12 @@ def converted_fn():
     self.assertEqual(
         ag_ctx.control_status_ctx().status, ag_ctx.Status.UNSPECIFIED)
 
+    @api.call_with_unspecified_conversion_status
+    def unspecified_fn():
+      self.assertEqual(
+          ag_ctx.control_status_ctx().status, ag_ctx.Status.UNSPECIFIED)
+    unspecified_fn()
+
   def test_to_graph_basic(self):
 
     def test_fn(x, s):
@@ -888,6 +904,23 @@ def test_fn(ctx):
       # The code in `f` is only valid with AutoGraph.
       test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED))
 
+  def test_tf_convert_unspecified_not_converted_by_default(self):
+
+    def f():
+      self.assertEqual(
+          ag_ctx.control_status_ctx().status, ag_ctx.Status.UNSPECIFIED)
+      if tf.reduce_sum([1, 2]) > 0:
+        return -1
+      return 1
+
+    @def_function.function
+    def test_fn(ctx):
+      return api.tf_convert(f, ctx, convert_by_default=False)()
+
+    with self.assertRaisesRegex(TypeError, 'tf.Tensor.*bool'):
+      # The code in `f` is only valid with AutoGraph.
+      test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED))
+
   def test_tf_convert_whitelisted_method(self):
 
     model = sequential.Sequential([