Make local_rv_size_lift a local optimization and simplify tests
brandonwillard committed May 25, 2021
1 parent 5db98be commit ea52882
Showing 2 changed files with 116 additions and 142 deletions.
20 changes: 12 additions & 8 deletions aesara/tensor/random/opt.py
@@ -40,19 +40,21 @@ def random_make_inplace(fgraph, node):
 )
 
 
-def lift_rv_shapes(node):
-    """Lift `RandomVariable`'s shape-related parameters.
+@local_optimizer(tracks=None)
+def local_rv_size_lift(fgraph, node):
+    """Lift the ``size`` parameter in a ``RandomVariable``.
 
-    In other words, this will broadcast the distribution parameters and
-    extra dimensions added by the `size` parameter.
+    In other words, this will broadcast the distribution parameters by adding
+    the extra dimensions implied by the ``size`` parameter, and remove the
+    ``size`` parameter in the process.
 
-    For example, ``normal([0.0, 1.0], 5.0, size=(3, 2))`` becomes
-    ``normal([[0., 1.], [0., 1.], [0., 1.]], [[5., 5.], [5., 5.], [5., 5.]])``.
+    For example, ``normal(0, 1, size=(1, 2))`` becomes
+    ``normal([[0, 0]], [[1, 1]], size=())``.
 
     """
     if not isinstance(node.op, RandomVariable):
-        return False
+        return
 
     rng, size, dtype, *dist_params = node.inputs
 
@@ -65,13 +67,15 @@ def lift_rv_shapes(node):
             )
             for p in dist_params
         ]
+    else:
+        return
 
     new_node = node.op.make_node(rng, None, dtype, *dist_params)
 
     if config.compute_test_value != "off":
         compute_test_value(new_node)
 
-    return new_node
+    return new_node.outputs
 
 
 @local_optimizer([DimShuffle])
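A minimal NumPy sketch of the lift that the new docstring describes (not part of the commit; the parameter values and the use of `np.broadcast_to` are illustrative assumptions):

    # The `size` of normal(0, 1, size=(1, 2)) is folded into broadcasted
    # parameters, leaving an empty `size`.
    import numpy as np

    loc, scale, size = np.array(0.0), np.array(1.0), (1, 2)

    # Broadcast each distribution parameter to the shape implied by `size` ...
    lifted_loc = np.broadcast_to(loc, size)      # [[0., 0.]]
    lifted_scale = np.broadcast_to(scale, size)  # [[1., 1.]]

    # ... so drawing with the lifted parameters and an empty size yields the
    # same output shape as the original size=(1, 2) graph.
    rng = np.random.RandomState(1233532)
    assert rng.normal(lifted_loc, lifted_scale).shape == (1, 2)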
238 changes: 104 additions & 134 deletions tests/tensor/random/test_opt.py
@@ -19,26 +19,59 @@
 )
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.random.opt import (
-    lift_rv_shapes,
     local_dimshuffle_rv_lift,
+    local_rv_size_lift,
     local_subtensor_rv_lift,
 )
 from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
 from aesara.tensor.type import iscalar, vector
 
 
+inplace_mode = Mode(
+    "py", OptimizationQuery(include=["random_make_inplace"], exclude=[])
+)
+no_mode = Mode("py", OptimizationQuery(include=[], exclude=[]))
+
+
+def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng):
+    dist_params_aet = []
+    for p in dist_params:
+        p_aet = aet.as_tensor(p).type()
+        p_aet.tag.test_value = p
+        dist_params_aet.append(p_aet)
+
+    size_aet = []
+    for s in size:
+        s_aet = iscalar()
+        s_aet.tag.test_value = s
+        size_aet.append(s_aet)
+
+    dist_st = op_fn(dist_op(*dist_params_aet, size=size_aet, rng=rng))
+
+    f_inputs = [
+        p for p in dist_params_aet + size_aet if not isinstance(p, (slice, Constant))
+    ]
+
+    mode = Mode("py", EquilibriumOptimizer([opt], max_use_ratio=100))
+
+    f_opt = function(
+        f_inputs,
+        dist_st,
+        mode=mode,
+    )
+
+    (new_out,) = f_opt.maker.fgraph.outputs
+
+    return new_out, f_inputs, dist_st, f_opt
+
+
 def test_inplace_optimization():
 
     out = normal(0, 1)
 
     assert out.owner.op.inplace is False
 
-    inplace_mode = Mode(
-        "py", OptimizationQuery(include=["random_make_inplace"], exclude=[])
-    )
-
     f = function(
         [],
         out,
@@ -55,80 +88,62 @@ def test_inplace_optimization():
     )
 
 
-def check_shape_lifted_rv(rv, params, size, rng):
-    aet_params = []
-    for p in params:
-        p_aet = aet.as_tensor(p)
-        p_aet = p_aet.type()
-        p_aet.tag.test_value = p
-        aet_params.append(p_aet)
-
-    aet_size = []
-    for s in size:
-        s_aet = aet.as_tensor(s)
-        s_aet = s_aet.type()
-        s_aet.tag.test_value = s
-        aet_size.append(s_aet)
-
-    rv = rv(*aet_params, size=aet_size, rng=rng)
-    rv_lifted = lift_rv_shapes(rv.owner)
-
-    # Make sure the size input is empty
-    assert np.array_equal(rv_lifted.inputs[1].data, [])
-
-    f_ref = function(
-        aet_params + aet_size,
-        rv,
-        mode=no_mode,
-    )
-    f_lifted = function(
-        aet_params + aet_size,
-        rv_lifted.outputs[1],
-        mode=no_mode,
-    )
-    f_ref_val = f_ref(*(params + size))
-    f_lifted_val = f_lifted(*(params + size))
-    assert np.array_equal(f_ref_val, f_lifted_val)
-
-
-@config.change_flags(compute_test_value="raise")
-def test_lift_rv_shapes():
-
+@pytest.mark.parametrize(
+    "dist_op, dist_params, size",
+    [
+        (
+            normal,
+            [
+                np.array(1.0, dtype=config.floatX),
+                np.array(5.0, dtype=config.floatX),
+            ],
+            [],
+        ),
+        (
+            normal,
+            [
+                np.array([0.0, 1.0], dtype=config.floatX),
+                np.array(5.0, dtype=config.floatX),
+            ],
+            [],
+        ),
+        (
+            normal,
+            [
+                np.array([0.0, 1.0], dtype=config.floatX),
+                np.array(5.0, dtype=config.floatX),
+            ],
+            [3, 2],
+        ),
+        (
+            multivariate_normal,
+            [
+                np.array([[0], [10], [100]], dtype=config.floatX),
+                np.diag(np.array([1e-6], dtype=config.floatX)),
+            ],
+            [2, 3],
+        ),
+        (
+            dirichlet,
+            [np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX)],
+            [2, 3],
+        ),
+    ],
+)
+def test_local_rv_size_lift(dist_op, dist_params, size):
     rng = shared(np.random.RandomState(1233532), borrow=False)
 
-    test_params = [
-        np.array(1.0, dtype=config.floatX),
-        np.array(5.0, dtype=config.floatX),
-    ]
-    test_size = []
-    check_shape_lifted_rv(normal, test_params, test_size, rng)
-
-    test_params = [
-        np.array([0.0, 1.0], dtype=config.floatX),
-        np.array(5.0, dtype=config.floatX),
-    ]
-    test_size = []
-    check_shape_lifted_rv(normal, test_params, test_size, rng)
-
-    test_params = [
-        np.array([0.0, 1.0], dtype=config.floatX),
-        np.array(5.0, dtype=config.floatX),
-    ]
-    test_size = [3, 2]
-    check_shape_lifted_rv(normal, test_params, test_size, rng)
-
-    test_params = [
-        np.array([[0], [10], [100]], dtype=config.floatX),
-        np.diag(np.array([1e-6], dtype=config.floatX)),
-    ]
-    test_size = [2, 3]
-    check_shape_lifted_rv(multivariate_normal, test_params, test_size, rng)
+    new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv(
+        local_rv_size_lift,
+        lambda rv: rv,
+        dist_op,
+        dist_params,
+        size,
+        rng,
+    )
 
-    test_params = [
-        np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX)
-    ]
-    test_size = [2, 3]
-    check_shape_lifted_rv(dirichlet, test_params, test_size, rng)
+    assert aet.get_vector_length(new_out.owner.inputs[1]) == 0
 
 
 @pytest.mark.parametrize(
@@ -274,36 +289,15 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):

     rng = shared(np.random.RandomState(1233532), borrow=False)
 
-    dist_params_aet = []
-    for p in dist_params:
-        p_aet = aet.as_tensor(p).type()
-        p_aet.tag.test_value = p
-        dist_params_aet.append(p_aet)
-
-    size_aet = []
-    for s in size:
-        s_aet = iscalar()
-        s_aet.tag.test_value = s
-        size_aet.append(s_aet)
-
-    dist_st = dist_op(*dist_params_aet, size=size_aet, rng=rng).dimshuffle(ds_order)
-
-    f_inputs = [
-        p for p in dist_params_aet + size_aet if not isinstance(p, (slice, Constant))
-    ]
-
-    mode = Mode(
-        "py", EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100)
-    )
-
-    f_opt = function(
-        f_inputs,
-        dist_st,
-        mode=mode,
+    new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv(
+        local_dimshuffle_rv_lift,
+        lambda rv: rv.dimshuffle(ds_order),
+        dist_op,
+        dist_params,
+        size,
+        rng,
     )
 
-    (new_out,) = f_opt.maker.fgraph.outputs
-
     if lifted:
         assert new_out.owner.op == dist_op
         assert all(
@@ -407,50 +401,26 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
 )
 @config.change_flags(compute_test_value_opt="raise", compute_test_value="raise")
 def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
-    from aesara.tensor.subtensor import as_index_constant
-
     rng = shared(np.random.RandomState(1233532), borrow=False)
 
-    dist_params_aet = []
-    for p in dist_params:
-        p_aet = aet.as_tensor(p).type()
-        p_aet.tag.test_value = p
-        dist_params_aet.append(p_aet)
-
-    size_aet = []
-    for s in size:
-        s_aet = iscalar()
-        s_aet.tag.test_value = s
-        size_aet.append(s_aet)
-
+    from aesara.tensor.subtensor import as_index_constant
+
     indices_aet = ()
     for i in indices:
         i_aet = as_index_constant(i)
         if not isinstance(i_aet, slice):
             i_aet.tag.test_value = i
         indices_aet += (i_aet,)
 
-    dist_st = dist_op(*dist_params_aet, size=size_aet, rng=rng)[indices_aet]
-
-    f_inputs = [
-        p
-        for p in dist_params_aet + size_aet + list(indices_aet)
-        if not isinstance(p, (slice, Constant))
-    ]
-
-    mode = Mode(
-        "py", EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100)
+    new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv(
+        local_subtensor_rv_lift,
+        lambda rv: rv[indices_aet],
+        dist_op,
+        dist_params,
+        size,
+        rng,
     )
 
-    f_opt = function(
-        f_inputs,
-        dist_st,
-        mode=mode,
-    )
-
-    (new_out,) = f_opt.maker.fgraph.outputs
-
     if lifted:
         assert isinstance(new_out.owner.op, RandomVariable)
         assert all(
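All three rewritten tests now funnel through the new `apply_local_opt_to_rv` helper. A rough standalone sketch of the same flow (assuming the Aesara API at this commit; the import locations and the position of the `size` input are inferred from the diff above, so treat them as assumptions):

    import numpy as np
    import aesara.tensor as aet
    from aesara import config, function, shared
    from aesara.compile.mode import Mode
    from aesara.graph.opt import EquilibriumOptimizer
    from aesara.tensor.random.basic import normal
    from aesara.tensor.random.opt import local_rv_size_lift

    rng = shared(np.random.RandomState(1233532), borrow=False)

    # Apply just this one local optimization until a fixed point, as the
    # helper's EquilibriumOptimizer mode does.
    mode = Mode("py", EquilibriumOptimizer([local_rv_size_lift], max_use_ratio=100))

    # A normal RV whose output shape comes entirely from `size`.
    x = normal(
        np.array(0.0, dtype=config.floatX),
        np.array(1.0, dtype=config.floatX),
        size=(3, 2),
        rng=rng,
    )

    f_opt = function([], x, mode=mode)
    (new_out,) = f_opt.maker.fgraph.outputs

    # After the rewrite, the node's second input (its `size`) is empty; the
    # shape information now lives in the broadcasted parameters.
    assert aet.get_vector_length(new_out.owner.inputs[1]) == 0
    assert f_opt().shape == (3, 2)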
