From ea5288208b009bb7aaf93bd3cf0f4603716a3198 Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Mon, 24 May 2021 20:10:07 -0500
Subject: [PATCH] Make local_rv_size_lift a local optimization and simplify
 tests

---
 aesara/tensor/random/opt.py     |  20 +--
 tests/tensor/random/test_opt.py | 238 ++++++++++++++------------------
 2 files changed, 116 insertions(+), 142 deletions(-)

diff --git a/aesara/tensor/random/opt.py b/aesara/tensor/random/opt.py
index f1aa462817..7bbef84761 100644
--- a/aesara/tensor/random/opt.py
+++ b/aesara/tensor/random/opt.py
@@ -40,19 +40,21 @@ def random_make_inplace(fgraph, node):
 )


-def lift_rv_shapes(node):
-    """Lift `RandomVariable`'s shape-related parameters.
+@local_optimizer(tracks=None)
+def local_rv_size_lift(fgraph, node):
+    """Lift the ``size`` parameter in a ``RandomVariable``.

-    In other words, this will broadcast the distribution parameters and
-    extra dimensions added by the `size` parameter.
+    In other words, this will broadcast the distribution parameters by adding
+    the extra dimensions implied by the ``size`` parameter, and remove the
+    ``size`` parameter in the process.

-    For example, ``normal([0.0, 1.0], 5.0, size=(3, 2))`` becomes
-    ``normal([[0., 1.], [0., 1.], [0., 1.]], [[5., 5.], [5., 5.], [5., 5.]])``.
+    For example, ``normal(0, 1, size=(1, 2))`` becomes
+    ``normal([[0, 0]], [[1, 1]], size=())``.

     """
     if not isinstance(node.op, RandomVariable):
-        return False
+        return

     rng, size, dtype, *dist_params = node.inputs

@@ -65,13 +67,15 @@ def lift_rv_shapes(node):
             )
             for p in dist_params
         ]
+    else:
+        return

     new_node = node.op.make_node(rng, None, dtype, *dist_params)

     if config.compute_test_value != "off":
         compute_test_value(new_node)

-    return new_node
+    return new_node.outputs


 @local_optimizer([DimShuffle])
diff --git a/tests/tensor/random/test_opt.py b/tests/tensor/random/test_opt.py
index c31271a1e8..2e8bb7ca38 100644
--- a/tests/tensor/random/test_opt.py
+++ b/tests/tensor/random/test_opt.py
@@ -19,26 +19,59 @@
 )
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.random.opt import (
-    lift_rv_shapes,
     local_dimshuffle_rv_lift,
+    local_rv_size_lift,
     local_subtensor_rv_lift,
 )
 from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
 from aesara.tensor.type import iscalar, vector


-inplace_mode = Mode(
-    "py", OptimizationQuery(include=["random_make_inplace"], exclude=[])
-)
 no_mode = Mode("py", OptimizationQuery(include=[], exclude=[]))


+def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng):
+    dist_params_aet = []
+    for p in dist_params:
+        p_aet = aet.as_tensor(p).type()
+        p_aet.tag.test_value = p
+        dist_params_aet.append(p_aet)
+
+    size_aet = []
+    for s in size:
+        s_aet = iscalar()
+        s_aet.tag.test_value = s
+        size_aet.append(s_aet)
+
+    dist_st = op_fn(dist_op(*dist_params_aet, size=size_aet, rng=rng))
+
+    f_inputs = [
+        p for p in dist_params_aet + size_aet if not isinstance(p, (slice, Constant))
+    ]
+
+    mode = Mode("py", EquilibriumOptimizer([opt], max_use_ratio=100))
+
+    f_opt = function(
+        f_inputs,
+        dist_st,
+        mode=mode,
+    )
+
+    (new_out,) = f_opt.maker.fgraph.outputs
+
+    return new_out, f_inputs, dist_st, f_opt
+
+
 def test_inplace_optimization():

     out = normal(0, 1)

     assert out.owner.op.inplace is False

+    inplace_mode = Mode(
+        "py", OptimizationQuery(include=["random_make_inplace"], exclude=[])
+    )
+
     f = function(
         [],
         out,
@@ -55,80 +88,62 @@ def test_inplace_optimization():
 )


-def check_shape_lifted_rv(rv, params, size, rng):
-    aet_params = []
-    for p in params:
-        p_aet = aet.as_tensor(p)
-        p_aet = p_aet.type()
-        p_aet.tag.test_value = p
-        aet_params.append(p_aet)
-
-    aet_size = []
-    for s in size:
-        s_aet = aet.as_tensor(s)
-        s_aet = s_aet.type()
-        s_aet.tag.test_value = s
-        aet_size.append(s_aet)
-
-    rv = rv(*aet_params, size=aet_size, rng=rng)
-    rv_lifted = lift_rv_shapes(rv.owner)
-
-    # Make sure the size input is empty
-    assert np.array_equal(rv_lifted.inputs[1].data, [])
-
-    f_ref = function(
-        aet_params + aet_size,
-        rv,
-        mode=no_mode,
-    )
-    f_lifted = function(
-        aet_params + aet_size,
-        rv_lifted.outputs[1],
-        mode=no_mode,
-    )
-    f_ref_val = f_ref(*(params + size))
-    f_lifted_val = f_lifted(*(params + size))
-    assert np.array_equal(f_ref_val, f_lifted_val)
-
-
 @config.change_flags(compute_test_value="raise")
-def test_lift_rv_shapes():
-
+@pytest.mark.parametrize(
+    "dist_op, dist_params, size",
+    [
+        (
+            normal,
+            [
+                np.array(1.0, dtype=config.floatX),
+                np.array(5.0, dtype=config.floatX),
+            ],
+            [],
+        ),
+        (
+            normal,
+            [
+                np.array([0.0, 1.0], dtype=config.floatX),
+                np.array(5.0, dtype=config.floatX),
+            ],
+            [],
+        ),
+        (
+            normal,
+            [
+                np.array([0.0, 1.0], dtype=config.floatX),
+                np.array(5.0, dtype=config.floatX),
+            ],
+            [3, 2],
+        ),
+        (
+            multivariate_normal,
+            [
+                np.array([[0], [10], [100]], dtype=config.floatX),
+                np.diag(np.array([1e-6], dtype=config.floatX)),
+            ],
+            [2, 3],
+        ),
+        (
+            dirichlet,
+            [np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX)],
+            [2, 3],
+        ),
+    ],
+)
+def test_local_rv_size_lift(dist_op, dist_params, size):
     rng = shared(np.random.RandomState(1233532), borrow=False)

-    test_params = [
-        np.array(1.0, dtype=config.floatX),
-        np.array(5.0, dtype=config.floatX),
-    ]
-    test_size = []
-    check_shape_lifted_rv(normal, test_params, test_size, rng)
-
-    test_params = [
-        np.array([0.0, 1.0], dtype=config.floatX),
-        np.array(5.0, dtype=config.floatX),
-    ]
-    test_size = []
-    check_shape_lifted_rv(normal, test_params, test_size, rng)
-
-    test_params = [
-        np.array([0.0, 1.0], dtype=config.floatX),
-        np.array(5.0, dtype=config.floatX),
-    ]
-    test_size = [3, 2]
-    check_shape_lifted_rv(normal, test_params, test_size, rng)
-
-    test_params = [
-        np.array([[0], [10], [100]], dtype=config.floatX),
-        np.diag(np.array([1e-6], dtype=config.floatX)),
-    ]
-    test_size = [2, 3]
-    check_shape_lifted_rv(multivariate_normal, test_params, test_size, rng)
+    new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv(
+        local_rv_size_lift,
+        lambda rv: rv,
+        dist_op,
+        dist_params,
+        size,
+        rng,
+    )

-    test_params = [
-        np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX)
-    ]
-    test_size = [2, 3]
-    check_shape_lifted_rv(dirichlet, test_params, test_size, rng)
+    assert aet.get_vector_length(new_out.owner.inputs[1]) == 0


 @pytest.mark.parametrize(
@@ -274,36 +289,15 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):

     rng = shared(np.random.RandomState(1233532), borrow=False)

-    dist_params_aet = []
-    for p in dist_params:
-        p_aet = aet.as_tensor(p).type()
-        p_aet.tag.test_value = p
-        dist_params_aet.append(p_aet)
-
-    size_aet = []
-    for s in size:
-        s_aet = iscalar()
-        s_aet.tag.test_value = s
-        size_aet.append(s_aet)
-
-    dist_st = dist_op(*dist_params_aet, size=size_aet, rng=rng).dimshuffle(ds_order)
-
-    f_inputs = [
-        p for p in dist_params_aet + size_aet if not isinstance(p, (slice, Constant))
-    ]
-
-    mode = Mode(
-        "py", EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100)
-    )
-
-    f_opt = function(
-        f_inputs,
-        dist_st,
-        mode=mode,
+    new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv(
+        local_dimshuffle_rv_lift,
+        lambda rv: rv.dimshuffle(ds_order),
+        dist_op,
+        dist_params,
+        size,
+        rng,
     )

-    (new_out,) = f_opt.maker.fgraph.outputs
-
     if lifted:
         assert new_out.owner.op == dist_op
         assert all(
@@ -407,23 +401,10 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
 )
 @config.change_flags(compute_test_value_opt="raise", compute_test_value="raise")
 def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
+    from aesara.tensor.subtensor import as_index_constant

     rng = shared(np.random.RandomState(1233532), borrow=False)

-    dist_params_aet = []
-    for p in dist_params:
-        p_aet = aet.as_tensor(p).type()
-        p_aet.tag.test_value = p
-        dist_params_aet.append(p_aet)
-
-    size_aet = []
-    for s in size:
-        s_aet = iscalar()
-        s_aet.tag.test_value = s
-        size_aet.append(s_aet)
-
-    from aesara.tensor.subtensor import as_index_constant
-
     indices_aet = ()
     for i in indices:
         i_aet = as_index_constant(i)
@@ -431,26 +412,15 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
             i_aet.tag.test_value = i
         indices_aet += (i_aet,)

-    dist_st = dist_op(*dist_params_aet, size=size_aet, rng=rng)[indices_aet]
-
-    f_inputs = [
-        p
-        for p in dist_params_aet + size_aet + list(indices_aet)
-        if not isinstance(p, (slice, Constant))
-    ]
-
-    mode = Mode(
-        "py", EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100)
+    new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv(
+        local_subtensor_rv_lift,
+        lambda rv: rv[indices_aet],
+        dist_op,
+        dist_params,
+        size,
+        rng,
     )

-    f_opt = function(
-        f_inputs,
-        dist_st,
-        mode=mode,
-    )
-
-    (new_out,) = f_opt.maker.fgraph.outputs
-
     if lifted:
         assert isinstance(new_out.owner.op, RandomVariable)
         assert all(
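
For reference, a minimal sketch of exercising the new rewrite on its own, mirroring
what the `apply_local_opt_to_rv` helper does for `local_rv_size_lift`. It assumes
the commit-era Aesara layout used by the tests (`EquilibriumOptimizer` from
`aesara.graph.opt`, `Mode` from `aesara.compile.mode`); the names `mu`, `x`, and
`f_opt` are illustrative, not part of the patch:

    import numpy as np

    import aesara.tensor as aet
    from aesara import config, function, shared
    from aesara.compile.mode import Mode
    from aesara.graph.opt import EquilibriumOptimizer
    from aesara.tensor.random.basic import normal
    from aesara.tensor.random.opt import local_rv_size_lift

    # A graph in which `size` adds a dimension beyond the parameters' shapes.
    rng = shared(np.random.RandomState(1233532), borrow=False)
    mu = aet.as_tensor(np.array([0.0, 1.0], dtype=config.floatX))
    x = normal(mu, 1.0, size=(3, 2), rng=rng)

    # Compile with only the size-lift rewrite enabled, as the test helper does.
    mode = Mode("py", EquilibriumOptimizer([local_rv_size_lift], max_use_ratio=100))
    f_opt = function([], x, mode=mode)

    # After the rewrite, the node's `size` input is empty; the extra
    # dimensions are carried by the broadcasted parameters instead.
    (new_out,) = f_opt.maker.fgraph.outputs
    assert aet.get_vector_length(new_out.owner.inputs[1]) == 0
    assert f_opt().shape == (3, 2)

Because the rewrite is registered with `@local_optimizer(tracks=None)`, the
`EquilibriumOptimizer` offers it every node in the graph, and it simply returns
`None` for anything that is not a `RandomVariable` with a non-empty `size`.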