Skip to content
This repository has been archived by the owner on Aug 22, 2023. It is now read-only.

Commit

Permalink
Preserve the UNSPECIFIED conversion status when not converting by def…
Browse files Browse the repository at this point in the history
…ault. Fix the conversion decorator tot only apply functools.wraps if the target is a function.

PiperOrigin-RevId: 254323896
  • Loading branch information
Dan Moldovan authored and tensorflower-gardener committed Jun 21, 2019
1 parent 9fe47db commit 40b8579
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 8 deletions.
34 changes: 26 additions & 8 deletions tensorflow/python/autograph/impl/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,18 @@ def tf_convert(f, ctx, convert_by_default=True, force_conversion=False):
f_wrapper = f
decorators, f = tf_decorator.unwrap(f)

apply_autograph = ((ctx.status == ag_ctx.Status.ENABLED) or
(convert_by_default and
ctx.status == ag_ctx.Status.UNSPECIFIED))
if apply_autograph:
# TODO(mdan): Grab features from context.
# TODO(mdan): Grab features from context.
if ctx.status == ag_ctx.Status.ENABLED:
wrapper = convert(recursive=True, force_conversion=force_conversion)(f)
else:
elif ctx.status == ag_ctx.Status.DISABLED:
wrapper = do_not_convert(f)
elif ctx.status == ag_ctx.Status.UNSPECIFIED:
if convert_by_default:
wrapper = convert(recursive=True, force_conversion=force_conversion)(f)
else:
wrapper = call_with_unspecified_conversion_status(f)
else:
raise ValueError(ctx.status)

if decorators:
wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)
Expand Down Expand Up @@ -246,6 +250,19 @@ class RunMode(Enum):
PY_FUNC = 2


def call_with_unspecified_conversion_status(func):
"""Decorator that resets the conversion context to the unspecified status."""
def wrapper(*args, **kwargs):
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED):
return func(*args, **kwargs)

if inspect.isfunction(func) or inspect.ismethod(func):
wrapper = functools.update_wrapper(wrapper, func)

setattr(wrapper, '__ag_compiled', True)
return wrapper


def do_not_convert_internal(f):
"""Decorator that marks internal functions which do not need conversion."""
setattr(f, '__ag_compiled', True)
Expand Down Expand Up @@ -279,12 +296,10 @@ def do_not_convert(func=None, run_as=RunMode.GRAPH, return_dtypes=None):
run_as=run_as,
return_dtypes=return_dtypes)

@functools.wraps(func)
def graph_wrapper(*args, **kwargs):
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
return func(*args, **kwargs)

@functools.wraps(func)
def py_func_wrapper(*args, **kwargs):
if kwargs:
raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
Expand All @@ -299,6 +314,9 @@ def py_func_wrapper(*args, **kwargs):
else:
raise ValueError('unknown value for run_as: %s' % run_as)

if inspect.isfunction(func) or inspect.ismethod(func):
wrapper = functools.update_wrapper(wrapper, func)

setattr(wrapper, '__ag_compiled', True)
return wrapper

Expand Down
33 changes: 33 additions & 0 deletions tensorflow/python/autograph/impl/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,16 @@ def test_method(self, x, y):
self.assertEqual((),
tuple(function_utils.fn_args(tc.test_method_whitelisted)))

def test_do_not_convert_callable_object(self):

class TestClass(object):

def __call__(self):
return 1

tc = TestClass()
self.assertEqual(1, api.do_not_convert(tc)())

@test_util.run_deprecated_v1
def test_convert_call_site_decorator(self):

Expand Down Expand Up @@ -729,6 +739,12 @@ def converted_fn():
self.assertEqual(
ag_ctx.control_status_ctx().status, ag_ctx.Status.UNSPECIFIED)

@api.call_with_unspecified_conversion_status
def unspecified_fn():
self.assertEqual(
ag_ctx.control_status_ctx().status, ag_ctx.Status.UNSPECIFIED)
unspecified_fn()

def test_to_graph_basic(self):

def test_fn(x, s):
Expand Down Expand Up @@ -888,6 +904,23 @@ def test_fn(ctx):
# The code in `f` is only valid with AutoGraph.
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED))

def test_tf_convert_unspecified_not_converted_by_default(self):

def f():
self.assertEqual(
ag_ctx.control_status_ctx().status, ag_ctx.Status.UNSPECIFIED)
if tf.reduce_sum([1, 2]) > 0:
return -1
return 1

@def_function.function
def test_fn(ctx):
return api.tf_convert(f, ctx, convert_by_default=False)()

with self.assertRaisesRegex(TypeError, 'tf.Tensor.*bool'):
# The code in `f` is only valid with AutoGraph.
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED))

def test_tf_convert_whitelisted_method(self):

model = sequential.Sequential([
Expand Down

0 comments on commit 40b8579

Please sign in to comment.