Add a non-one shape constraint to TensorType
brandonwillard committed Nov 1, 2022
1 parent 454f8ae commit 92b3309
Showing 8 changed files with 196 additions and 81 deletions.
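For orientation: this commit replaces boolean broadcastable patterns with an encoded static shape, in which a non-negative integer is a known extent, ``None`` is unknown, and the new ``-1`` means "known not to be 1". Internally, `shape_key` maps ``None`` to ``-2`` so every encoded entry is an integer. Below is a minimal sketch of the intended behavior, based only on the hunks in this commit; the printed values are my reading of `__str__` and the `shape` property, not captured output.

from aesara.tensor.type import TensorType

# Known extents and unknown dimensions, as before
t1 = TensorType("float64", shape=(1, None))
print(t1)                # TensorType(float64, (1, ?))
print(t1.shape)          # (1, None)

# The new non-one constraint: -1 encodes "this dimension is never 1"
t2 = TensorType("float64", shape=(-1, 3))
print(t2)                # TensorType(float64, (>1, 3))
print(t2.shape)          # (None, 3); the public shape hides the constraint
print(t2.shape_encoded)  # (-1, 3)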
2 changes: 1 addition & 1 deletion aesara/sandbox/rng_mrg.py
@@ -395,7 +395,7 @@ def new(cls, rstate, ndim, dtype, size):
v_size = as_tensor_variable(size)
if ndim is None:
ndim = get_vector_length(v_size)
op = cls(TensorType(dtype, (False,) * ndim))
op = cls(TensorType(dtype, shape=(None,) * ndim))
return op(rstate, v_size)

def perform(self, node, inp, out, params):
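This hunk swaps the boolean broadcastable tuple for the `shape` keyword: each ``False`` (non-broadcastable, size unknown) dimension is now spelled ``None``. A small sketch of the equivalence, with a hypothetical ``ndim`` standing in for `get_vector_length(v_size)`:

from aesara.tensor.type import TensorType

ndim = 2
# Before: TensorType(dtype, (False,) * ndim)
# After: one None per dimension, i.e. a fully unknown static shape
out_type = TensorType("float64", shape=(None,) * ndim)
assert out_type.shape == (None, None)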
6 changes: 2 additions & 4 deletions aesara/sparse/type.py
Expand Up @@ -70,9 +70,8 @@ def __init__(
dtype: Union[str, np.dtype],
shape: Optional[Iterable[Optional[Union[bool, int]]]] = None,
name: Optional[str] = None,
broadcastable: Optional[Iterable[bool]] = None,
):
if shape is None and broadcastable is None:
if shape is None:
shape = (None, None)

if format not in self.format_cls:
@@ -82,13 +81,12 @@

self.format = format

super().__init__(dtype, shape=shape, name=name, broadcastable=broadcastable)
super().__init__(dtype, shape=shape, name=name)

def clone(
self,
dtype=None,
shape=None,
broadcastable=None,
**kwargs,
):
format: Optional[SparsityTypes] = kwargs.pop("format", self.format)
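With the `broadcastable` keyword removed from both `__init__` and `clone`, sparse types are now specified by `shape` alone, defaulting to ``(None, None)``. A sketch under the assumption that this class is the ``SparseTensorType`` of `aesara.sparse.type` and that ``"csr"`` is a registered format:

from aesara.sparse.type import SparseTensorType

st = SparseTensorType("csr", "float64")   # shape defaults to (None, None)
st32 = st.clone(dtype="float32")          # no `broadcastable` argument anymore
assert st.shape == (None, None)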
73 changes: 58 additions & 15 deletions aesara/tensor/shape.py
@@ -1,7 +1,7 @@
import warnings
from numbers import Number
from textwrap import dedent
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Sequence, Tuple, Union

import numpy as np

@@ -16,11 +16,65 @@
from aesara.tensor import basic as at
from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
from aesara.tensor.type import (
DenseTensorType,
TensorType,
int_dtypes,
shape_key,
tensor,
)
from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorConstant, TensorVariable


def filter_shape_vars(
ref_shape: Tuple[int, ...], shape: Sequence[Variable], shape_is_encoded: bool = True
) -> Tuple[int, ...]:
r"""Compute the most \"informative\" shape based on a static reference.
Parameters
----------
ref_shape
A static shape reference using static shape constraint encoding.
shape
A symbolic shape.
shape_is_encoded
If ``True``, `shape` is assumed to be static shape constraint encoded.
Returns
-------
The most specific, and compatible (with `ref_shape`), static shape
constraint encoded values.
"""
shape_bottom = shape_key(None)
type_shape = ()
for i, (xts, s) in enumerate(zip(ref_shape, shape)):

try:
# TODO FIXME: We shouldn't need to do this; let a rewrite
# do constant folding and update the `TensorType`s.
s_val = at.get_scalar_constant_value(s)

if isinstance(s_val, np.ndarray):
s_val = s_val.item()

if shape_is_encoded or s_val is not None and s_val > 0:
type_s = shape_key(s_val)
else:
type_s = shape_bottom
except NotScalarConstantError:
type_s = shape_bottom

if not (xts <= -1 or type_s <= -1 or type_s == xts):
raise AssertionError(
f"SpecifyShape: Got shape {xts} at index {i}, expected {type_s}."
)

type_shape += (max(type_s, xts),)

return type_shape


def register_shape_c_code(type, code, version=()):
"""
Tell Shape Op how to generate C code for an Aesara Type.
@@ -383,7 +437,6 @@ class SpecifyShape(COp):
_f16_ok = True

def make_node(self, x, *shape):
from aesara.tensor.basic import get_scalar_constant_value

x = at.as_tensor_variable(x)

Expand All @@ -406,18 +459,7 @@ def make_node(self, x, *shape):
f"Input `x` is {x.type.ndim}-dimensional and will never match a shape of length {len(shape)}."
)

type_shape = [None] * x.ndim
for i, (xts, s) in enumerate(zip(x.type.shape, shape)):
if xts is not None:
type_shape[i] = xts
else:
try:
type_s = get_scalar_constant_value(s)
if type_s is not None:
type_shape[i] = int(type_s)
except NotScalarConstantError:
pass

type_shape = filter_shape_vars(x.type.shape_encoded, shape)
out_var = x.type.clone(shape=type_shape)()

return Apply(self, [x, *shape], [out_var])
@@ -601,6 +643,7 @@ def make_node(self, x, shp):
x = at.as_tensor_variable(x)
shp_orig = shp
shp = at.as_tensor_variable(shp, ndim=1)

if not (
shp.dtype in int_dtypes
or (isinstance(shp, TensorConstant) and shp.data.size == 0)
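The new `filter_shape_vars` centralizes the constant folding that `SpecifyShape.make_node` previously did inline, and it merges encoded constraints (``-2`` unknown, ``-1`` non-one, ``>= 0`` known) by taking the per-dimension maximum. A sketch of both entry points, assuming ``floatX`` is ``float64``; expected results appear as comments:

import aesara.tensor as at
from aesara.tensor.shape import filter_shape_vars, specify_shape

# SpecifyShape now folds constant shape entries into the output type
x = at.matrix("x")                      # static shape (None, None)
y = specify_shape(x, (None, 3))
print(y.type.shape)                     # (None, 3)

# Direct use: merge an encoded reference shape with symbolic entries
ref = (-2, 3)                           # encoded form of (None, 3)
shp = [at.constant(5), at.constant(3)]  # constant symbolic shape entries
print(filter_shape_vars(ref, shp))      # (5, 3)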
10 changes: 8 additions & 2 deletions aesara/tensor/subtensor.py
@@ -1549,7 +1549,10 @@ def make_node(self, x, y, *inputs):
f"Wrong type for Subtensor template. Expected {input.type}, got {expected_type}."
)

return Apply(self, (x, y) + inputs, [x.type()])
out_var = x.type.clone(
shape=tuple(1 if s == 1 else None for s in x.type.shape)
)()
return Apply(self, (x, y) + inputs, [out_var])

def decl_view(self):
return "PyArrayObject * zview = NULL;"
@@ -2180,7 +2183,10 @@ def make_node(self, x, y, ilist):
% (opname, x_.type.ndim, y_.type.ndim)
)

return Apply(self, [x_, y_, ilist_], [x_.type()])
out_var = x_.type.clone(
shape=tuple(1 if s == 1 else None for s in x_.type.shape)
)()
return Apply(self, [x_, y_, ilist_], [out_var])

def copy_of_x(self, x):
"""
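Both hunks relax the output type of these `Op`s: instead of reusing `x.type` verbatim, the output keeps only the dimensions statically known to be 1 and marks everything else unknown, so exact extents (and the new non-one constraints) are not propagated blindly. A sketch of the resulting static shapes, with hypothetical variables and expected values as comments:

import aesara.tensor as at

x = at.tensor("float64", shape=(1, 5))
y = at.set_subtensor(x[:, 1:3], 0.0)

print(x.type.shape)   # (1, 5)
print(y.type.shape)   # (1, None); the ==1 dim survives, the 5 is dropped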
134 changes: 92 additions & 42 deletions aesara/tensor/type.py
@@ -54,6 +54,26 @@
}


def parse_bcast_and_shape(s):
if s is None:
return (None, False)
elif s == 1:
return (1, False)
elif s >= 0:
return (s, True)
elif s < 0:
# The second flag states that this dimension's size cannot be
# equal to 1
return (None, True)


def shape_key(s):
if s is None:
return -2

return s


class TensorType(CType[np.ndarray], HasDataType, HasShape):
r"""Symbolic `Type` representing `numpy.ndarray`\s."""

@@ -72,7 +92,6 @@ def __init__(
dtype: Union[str, np.dtype],
shape: Optional[Iterable[Optional[Union[bool, int]]]] = None,
name: Optional[str] = None,
broadcastable: Optional[Iterable[bool]] = None,
):
r"""
@@ -82,7 +101,8 @@
A NumPy dtype (e.g. ``"int64"``).
shape
The static shape information. ``None``\s are used to indicate
unknown shape values for their respective dimensions.
unknown shape values for their respective dimensions and ``-1`` to
indicate the constraint ``shape != 1``.
If `shape` is a list of ``bool``\s, the ``True`` elements are
converted to ``1``\s and the ``False`` values are converted to
``None``\s.
@@ -91,12 +111,7 @@
"""

if broadcastable is not None:
warnings.warn(
"The `broadcastable` keyword is deprecated; use `shape`.",
DeprecationWarning,
)
shape = broadcastable
self.name = name

if str(dtype) == "floatX":
self.dtype = config.floatX
@@ -106,30 +121,21 @@

self.dtype = np.dtype(dtype).name

def parse_bcast_and_shape(s):
if isinstance(s, (bool, np.bool_)):
return 1 if s else None
else:
return s

self.shape = tuple(parse_bcast_and_shape(s) for s in shape)
self.dtype_specs() # error checking is done there
self.name = name
self.numpy_dtype = np.dtype(self.dtype)

def clone(
self, dtype=None, shape=None, broadcastable=None, **kwargs
) -> "TensorType":
if broadcastable is not None:
warnings.warn(
"The `broadcastable` keyword is deprecated; use `shape`.",
DeprecationWarning,
)
shape = broadcastable
self.shape_encoded = tuple(shape_key(s) for s in shape)

assert isinstance(self.shape_encoded, tuple)
assert all(
isinstance(s, int) and not isinstance(s, bool) for s in self.shape_encoded
)

def clone(self, dtype=None, shape=None, **kwargs) -> "TensorType":
if dtype is None:
dtype = self.dtype
if shape is None:
shape = self.shape
shape = self.shape_encoded
return type(self)(dtype, shape, name=self.name)

def filter(self, data, strict=False, allow_downcast=None):
@@ -243,16 +249,24 @@ def filter(self, data, strict=False, allow_downcast=None):
" Aesara C code does not support that.",
)

if not all(
ds == ts if ts is not None else True
for ds, ts in zip(data.shape, self.shape)
):
raise TypeError(
f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})"
)
def check_shape_info(i, s_val, s_info):
if s_info == -1 and s_val == 1:
raise ValueError(
f"Value's shape in dimension {i} is not compatible "
f"with the constraint: {s_val} != 1"
)
if s_info > -1 and s_val != s_info:
raise ValueError(
f"Value's shape in dimension {i} is not compatible "
f"with the constraint: {s_val} == {s_info}"
)

for i, (s_val, s_info) in enumerate(zip(np.shape(data), self.shape_encoded)):
check_shape_info(i, s_val, s_info)

if self.filter_checks_isfinite and not np.all(np.isfinite(data)):
raise ValueError("Non-finite elements not allowed")

return data

def filter_variable(self, other, allow_convert=True):
@@ -308,7 +322,10 @@ def in_same_class(self, otype):
if (
isinstance(otype, TensorType)
and otype.dtype == self.dtype
and otype.broadcastable == self.broadcastable
and all(
s == o_s if s == 1 or o_s == 1 else True
for s, o_s in zip(self.shape, otype.shape)
)
):
return True
return False
@@ -320,7 +337,11 @@ def is_super(self, otype):
and otype.ndim == self.ndim
# `otype` is allowed to be as or more shape-specific than `self`,
# but not less
and all(sb == ob or sb is None for sb, ob in zip(self.shape, otype.shape))
and all(
s == o_s if s > -1 and o_s > -1 else s <= o_s
# not (s is not None and s >= 0 and max(s, o_s, key=shape_key) != s)
for s, o_s in zip(self.shape_encoded, otype.shape_encoded)
)
):
return True

@@ -334,13 +355,22 @@ def convert_variable(self, var):
if (self.ndim == var.type.ndim) and (self.dtype == var.type.dtype):
# `var.type` only differs from `self` in that its shape is (at least partially)
# less specific than `self`, so we convert `var` to `self`'s `Type`.
# `specify_shape` will combine the more precise shapes of the two types
return aesara.tensor.specify_shape(var, self.shape)
# `specify_shape` will combine the more precise shapes of the two types.

new_shape_encoded = ()
for s, o_s in zip(self.shape_encoded, var.type.shape_encoded):

if s > -1 and o_s > -1 and s != o_s:
raise ValueError(
f"Incompatible shapes: {self.shape_encoded}, {var.type.shape_encoded}"
)

new_shape_encoded += (max(s, o_s),)

return aesara.tensor.specify_shape(var, new_shape_encoded)

@staticmethod
def values_eq(a, b, force_same_dtype=True):
# TODO: check to see if the shapes must match; for now, we err on safe
# side...
if a.shape != b.shape:
return False
if force_same_dtype and a.dtype != b.dtype:
@@ -367,14 +397,23 @@ def __eq__(self, other):
if type(self) != type(other):
return NotImplemented

return other.dtype == self.dtype and other.shape == self.shape
return other.dtype == self.dtype and other.shape_encoded == self.shape_encoded

def __hash__(self):
return hash((type(self), self.dtype, self.shape))
return hash((type(self), self.dtype, self.shape_encoded))

@property
def shape(self) -> Tuple[Optional[Union[int]]]:
"""Return a static shape tuple with unknown values equal to ``None``."""
return tuple(s if s > -1 else None for s in self.shape_encoded)

@property
def broadcastable(self):
"""A boolean tuple indicating which dimensions have a shape equal to one."""
warnings.warn(
"TensorType.broadcastable is deprecated; use TensorType.shape",
DeprecationWarning,
)
return tuple(s == 1 for s in self.shape)

@property
@@ -386,7 +425,18 @@ def __str__(self):
if self.name:
return self.name
else:
return f"TensorType({self.dtype}, {self.shape})"

def shape_str(s):
if s == -1:
return ">1"
elif s < -1:
return "?"
else:
return str(s)

formatted_shape = ", ".join([shape_str(s) for s in self.shape_encoded])

return f"TensorType({self.dtype}, ({formatted_shape}))"

def __repr__(self):
return str(self)
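The `filter` and subtyping changes are where the new encoding becomes observable at runtime: `filter` rejects values that violate a ``-1`` (non-one) constraint, and `is_super`/`in_same_class` compare encoded entries instead of broadcastable flags. A sketch of the expected behavior; outputs are my reading of the code above, not captured:

import numpy as np
from aesara.tensor.type import TensorType

t = TensorType("float64", shape=(-1, 3))   # first dim constrained to be != 1

t.filter(np.zeros((4, 3)))                 # accepted: 4 != 1 and 3 == 3
try:
    t.filter(np.zeros((1, 3)))             # violates the non-one constraint
except ValueError as err:
    print(err)  # Value's shape in dimension 0 is not compatible with the constraint: 1 != 1

# A fully unknown shape is a supertype of the constrained one
t_any = TensorType("float64", shape=(None, None))
print(t_any.is_super(t))                   # True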
(The remaining 3 changed files are not shown here.)
