Skip to content

Commit

Permalink
Cast output of local_func_inv and local_exp_log to float when needed
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Aug 19, 2021
1 parent 43ed901 commit d0a9488
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
14 changes: 12 additions & 2 deletions aesara/tensor/math_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,12 @@ def local_func_inv(fgraph, node):
if is_inverse_pair(node_op, prev_op, inv_pair):
# We don't need to copy stack trace, because the optimization
# is trivial and maintains the earlier stack trace
return x.owner.inputs
ottype = node.out.dtype
inp = x.owner.inputs[0]
# Functions may have casted integer input to float
if inp.dtype != ottype:
inp = cast(inp, ottype)
return [inp]

return

Expand All @@ -246,7 +251,12 @@ def local_exp_log(fgraph, node):

# Case for log(exp(x))
if isinstance(prev_op, aes.Exp) and isinstance(node_op, aes.Log):
return x.owner.inputs
new_out = x.owner.inputs[0]
old_out = node.outputs[0]
# Exp may have casted integer input to float
if new_out.dtype != old_out.dtype:
new_out = cast(new_out, old_out.dtype)
return [new_out]

# Case for exp(softplus(x)) aka exp(log1pexp)
if isinstance(prev_op, aes_math.Softplus) and isinstance(node_op, aes.Exp):
Expand Down
21 changes: 21 additions & 0 deletions tests/tensor/test_math_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2488,6 +2488,16 @@ def test(self):
self.assert_func_pair_optimized(rad2deg, rad2deg, dx, should_copy=False)
self.assert_func_pair_optimized(rad2deg, cosh, dx, should_copy=False)

def test_integer_upcast(self):
"""
All invertible methods (except for `Neg`) can upgrade their input to float.
Here we test that the rewrite works with just one pair of methods
"""
x = ivector("x")
f = function([x], deg2rad(rad2deg(x)), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1


class TestExpLog:
def setup_method(self):
Expand All @@ -2512,6 +2522,17 @@ def test_log_exp(self):
assert len(ops_graph) == 0
np.testing.assert_array_equal(f(data), data)

def test_log_exp_integer_upcast(self):
x = ivector("x")
f = function([x], log(exp(x)), mode=self.mode)
ops_graph = [
node
for node in f.maker.fgraph.toposort()
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, (aes.Log, aes.Exp))
]
assert len(ops_graph) == 0

def test_exp_log(self):
# exp(log(x)) -> switch(x >= 0, x, nan)
data_valid = np.random.random((4, 3)).astype("float32")
Expand Down

0 comments on commit d0a9488

Please sign in to comment.