Skip to content

Commit

Permalink
Remove default_updates from local variables of a Scan
Browse files Browse the repository at this point in the history
This adds `SharedVariable` construction tracking that allows one to determine
which variables were created within a specific scope (e.g. within a Python
function).  With this ability, we're able to determine which shared variable
update should and shouldn't be performed within the iterations of a `Scan` node.
  • Loading branch information
brandonwillard committed Apr 18, 2022
1 parent 110e345 commit 00e0d80
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 87 deletions.
24 changes: 21 additions & 3 deletions aesara/compile/sharedvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

import copy
import logging
from contextlib import contextmanager
from typing import List, Optional

import numpy as np

Expand All @@ -14,8 +15,20 @@
from aesara.link.c.type import generic


_logger = logging.getLogger("aesara.compile.sharedvalue")
__docformat__ = "restructuredtext en"
__SHARED_CONTEXT__: Optional[List[Variable]] = None


@contextmanager
def collect_new_shareds():
r"""Return all the `SharedVariable`\s created within this context manager."""
global __SHARED_CONTEXT__
old_context = __SHARED_CONTEXT__
context = []
try:
__SHARED_CONTEXT__ = context
yield context
finally:
__SHARED_CONTEXT__ = old_context


class SharedVariable(Variable):
Expand Down Expand Up @@ -85,6 +98,11 @@ def __init__(self, name, type, value, strict, allow_downcast=None, container=Non
allow_downcast=allow_downcast,
)

global __SHARED_CONTEXT__

if isinstance(__SHARED_CONTEXT__, list):
__SHARED_CONTEXT__.append(self)

def get_value(self, borrow=False, return_internal_type=False):
"""
Get the non-symbolic value associated with this SharedVariable.
Expand Down
64 changes: 39 additions & 25 deletions aesara/scan/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import numpy as np

import aesara.tensor as at
from aesara.compile import SharedVariable
from aesara.compile.function.pfunc import construct_pfunc_ins_and_outs
from aesara.compile.sharedvalue import SharedVariable, collect_new_shareds
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, clone_replace, graph_inputs
from aesara.graph.op import get_test_value
Expand Down Expand Up @@ -861,7 +861,10 @@ def wrap_into_list(x):
# when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated

condition, outputs, updates = get_updates_and_outputs(fn(*args))
with collect_new_shareds() as new_shareds:
raw_inner_outputs = fn(*args)

condition, outputs, updates = get_updates_and_outputs(raw_inner_outputs)
if condition is not None:
as_while = True
else:
Expand Down Expand Up @@ -974,13 +977,36 @@ def wrap_into_list(x):
shared_inner_inputs = []
shared_inner_outputs = []
sit_sot_shared = []
no_update_shared_inputs = []
for input in dummy_inputs:
if isinstance(input.variable, SharedVariable) and input.update:
if not isinstance(input.variable, SharedVariable):
continue

is_local = input.variable in new_shareds

# We only want to add shared variable updates that were either
# user-specified within the inner-function (e.g. by returning an update
# `dict`) or the `SharedVariable.default_update`s of a shared variable
# created in the inner-function.
if input.update and (is_local or input.variable in updates):
# We need to remove the `default_update`s on the shared
# variables created within the context of the loop function
# (e.g. via use of `RandomStream`); otherwise, they'll get
# picked up during compilation and produce errors when the
# updates include inner-graph variables.
# We also don't want to remove a default update that applies to
# the scope/context containing this `Scan`, so we only remove
# default updates on "local" variables.
if is_local and hasattr(input.variable, "default_update"):
del input.variable.default_update

new_var = safe_new(input.variable)

if getattr(input.variable, "name", None) is not None:
new_var.name = input.variable.name + "_copy"

inner_replacements[input.variable] = new_var

if isinstance(new_var.type, TensorType):
sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append(
Expand All @@ -989,6 +1015,7 @@ def wrap_into_list(x):
actual_n_steps,
)
)

tensor_update = at.as_tensor_variable(input.update)
sit_sot_inner_outputs.append(tensor_update)
# Note that `pos` is not a negative index. The sign of `pos` is used
Expand All @@ -1000,14 +1027,14 @@ def wrap_into_list(x):
# refers to the update rule with index `-1 - pos`.
sit_sot_rightOrder.append(-1 - len(sit_sot_shared))
sit_sot_shared.append(input.variable)
inner_replacements[input.variable] = new_var

else:
shared_inner_inputs.append(new_var)
shared_scan_inputs.append(input.variable)
shared_inner_outputs.append(input.update)
inner_replacements[input.variable] = new_var
n_shared_outs += 1
else:
no_update_shared_inputs.append(input)

n_sit_sot = len(sit_sot_inner_inputs)

Expand Down Expand Up @@ -1048,33 +1075,20 @@ def wrap_into_list(x):

other_shared_scan_args = [
arg.variable
for arg in dummy_inputs
if (
isinstance(arg.variable, SharedVariable)
and not arg.update
and arg.variable in non_seqs_set
)
for arg in no_update_shared_inputs
if arg.variable in non_seqs_set
]
other_shared_inner_args = [
safe_new(arg.variable, "_copy")
for arg in dummy_inputs
if (
isinstance(arg.variable, SharedVariable)
and not arg.update
and arg.variable in non_seqs_set
)
for arg in no_update_shared_inputs
if arg.variable in non_seqs_set
]
else:
other_shared_scan_args = [
arg.variable
for arg in dummy_inputs
if (isinstance(arg.variable, SharedVariable) and not arg.update)
]
other_shared_scan_args = [arg.variable for arg in no_update_shared_inputs]
other_shared_inner_args = [
safe_new(arg.variable, "_copy")
for arg in dummy_inputs
if (isinstance(arg.variable, SharedVariable) and not arg.update)
safe_new(arg.variable, "_copy") for arg in no_update_shared_inputs
]

inner_replacements.update(
dict(zip(other_shared_scan_args, other_shared_inner_args))
)
Expand Down
49 changes: 49 additions & 0 deletions tests/scan/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh
from aesara.tensor.nnet import categorical_crossentropy
from aesara.tensor.random import normal
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.shape import Shape_i, reshape, specify_shape
from aesara.tensor.sharedvar import SharedVariable
Expand Down Expand Up @@ -240,6 +241,54 @@ def scan_nodes_from_fct(fct):


class TestScan:
@pytest.mark.parametrize(
"rng_type",
[
np.random.default_rng,
np.random.RandomState,
],
)
def test_inner_graph_cloning(self, rng_type):
r"""Scan should remove the updates-providing special properties on `RandomType`\s."""

inner_inner_rng = shared(rng_type(), name="inner_inner_rng")

y = shared(np.array(1.0, dtype=config.floatX), name="y")
y.default_update = y + 1

z_rng = shared(rng_type(), name="z_rng")
z = normal(0, 1, rng=z_rng, name="z")

z_rng_update = z.owner.outputs[0]
z_rng_update.name = "z_rng_update"
z_rng.default_update = z_rng_update

inner_rng = None

def inner_fn(x):
inner_rng = shared(rng_type(), name="inner_rng")
inner_rng.default_update = inner_inner_rng
inner_inner_rng.default_update = inner_rng

r = normal(x, rng=inner_rng)
return r + y + z, z

out, out_updates = scan(
inner_fn,
outputs_info=[at.as_tensor(0.0, dtype=config.floatX), None],
n_steps=4,
)

assert not hasattr(inner_rng, "default_update")
assert hasattr(inner_inner_rng, "default_update")
assert hasattr(y, "default_update")
assert hasattr(z_rng, "default_update")

out_fn = function([], out, mode=Mode(optimizer=None))
res, z_res = out_fn()
assert len(set(res)) == 4
assert len(set(z_res)) == 1

@pytest.mark.skipif(
isinstance(get_default_mode(), DebugMode),
reason="This test fails in DebugMode, because it is not yet picklable.",
Expand Down
Loading

0 comments on commit 00e0d80

Please sign in to comment.