Skip to content

Commit

Permalink
Relax type constraints for Stacked (#177)
Browse files Browse the repository at this point in the history
* relax type constraints for Stacked

* additional changes to make new Stacked work

* added default constructor for Stacked

* fixed typo

* fixed tests for stacked

* aight finally fixed tests for Stacked

* added suggestion from review

* fixed typo

* fixed impl for Stacked using array

* updated tests

* bump patch version
  • Loading branch information
torfjelde authored Jun 5, 2021
1 parent 60dc43c commit edbb560
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ end
function mapvcat(f, args...)
out = map(f, args...)
init = vcat(out[1])
return reshape(reduce(vcat, drop(out, 1); init = init), size(out))
return reduce(vcat, drop(out, 1); init = init)
end

function maphcat(f, args...)
Expand Down
57 changes: 31 additions & 26 deletions src/bijectors/stacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,12 @@ b = stack(b1, b2)
b([0.0, 1.0]) == [b1(0.0), 1.0] # => true
```
"""
struct Stacked{Bs, N} <: Bijector{1}
struct Stacked{Bs, Rs} <: Bijector{1}
bs::Bs
ranges::NTuple{N, UnitRange{Int}}

function Stacked(
bs::C,
ranges::NTuple{N, UnitRange{Int}}
) where {N, C<:Tuple{Vararg{<:ZeroOrOneDimBijector, N}}}
return new{C, N}(bs, ranges)
end

function Stacked(
bs::A,
ranges::NTuple{N, UnitRange{Int}}
) where {N, A<:AbstractArray{<:Bijector}}
@assert length(bs) == N "number of bijectors is not same as number of ranges"
@assert all(b -> isa(b, ZeroOrOneDimBijector), bs)
return new{A, N}(bs, ranges)
end
ranges::Rs
end
Stacked(bs, ranges::AbstractArray) = Stacked(bs, tuple(ranges...))
Stacked(bs) = Stacked(bs, tuple([i:i for i = 1:length(bs)]...))
Stacked(bs::Tuple) = Stacked(bs, ntuple(i -> i:i, length(bs)))
Stacked(bs::AbstractArray) = Stacked(bs, [i:i for i in 1:length(bs)])

# define nested numerical parameters
# TODO: replace with `Functors.@functor Stacked (bs,)` when
Expand Down Expand Up @@ -95,7 +79,10 @@ function (sb::Stacked{<:Tuple})(x::AbstractVector{<:Real})
return y
end
# The Stacked{<:AbstractArray} version is not TrackedArray friendly
function (sb::Stacked{<:AbstractArray, N})(x::AbstractVector{<:Real}) where {N}
function (sb::Stacked{<:AbstractArray})(x::AbstractVector{<:Real})
N = length(sb.bs)
N == 1 && return sb.bs[1](x[sb.ranges[1]])

y = mapvcat(1:N) do i
sb.bs[i](x[sb.ranges[i]])
end
Expand All @@ -105,7 +92,23 @@ end

(sb::Stacked)(x::AbstractMatrix{<:Real}) = eachcolmaphcat(sb, x)
function logabsdetjac(
b::Stacked{<:Any, N},
b::Stacked,
x::AbstractVector{<:Real}
)
N = length(b.bs)
init = sum(logabsdetjac(b.bs[1], x[b.ranges[1]]))

return if N > 1
init + sum(2:N) do i
sum(logabsdetjac(b.bs[i], x[b.ranges[i]]))
end
else
init
end
end

function logabsdetjac(
b::Stacked{<:Tuple{Vararg{<:Any, N}}},
x::AbstractVector{<:Real}
) where {N}
init = sum(logabsdetjac(b.bs[1], x[b.ranges[1]]))
Expand All @@ -114,7 +117,8 @@ function logabsdetjac(
end
end

function logabsdetjac(b::Stacked{<:Any, 1}, x::AbstractVector{<:Real})
# Handle the case of just one bijector
function logabsdetjac(b::Stacked{<:Tuple{<:Bijector}}, x::AbstractVector{<:Real})
return sum(logabsdetjac(b.bs[1], x[b.ranges[1]]))
end

Expand All @@ -133,15 +137,15 @@ end
# logjac += sum(_logjac)
# return (rv = vcat(y_1, y_2), logabsdetjac = logjac)
# end
@generated function forward(b::Stacked{T, N}, x::AbstractVector) where {N, T<:Tuple}
@generated function forward(b::Stacked{<:Tuple{Vararg{<:Any, N}}}, x::AbstractVector) where {N}
expr = Expr(:block)
y_names = []

push!(expr.args, :((y_1, _logjac) = forward(b.bs[1], x[b.ranges[1]])))
# TODO: drop the `sum` when we have dimensionality
push!(expr.args, :(logjac = sum(_logjac)))
push!(y_names, :y_1)
for i = 2:length(T.parameters)
for i = 2:N
y_name = Symbol("y_$i")
push!(expr.args, :(($y_name, _logjac) = forward(b.bs[$i], x[b.ranges[$i]])))

Expand All @@ -155,7 +159,8 @@ end
return expr
end

function forward(sb::Stacked{<:AbstractArray, N}, x::AbstractVector) where {N}
function forward(sb::Stacked{<:AbstractArray}, x::AbstractVector)
N = length(sb.bs)
yinit, linit = forward(sb.bs[1], x[sb.ranges[1]])
logjac = sum(linit)
ys = mapvcat(drop(sb.bs, 1), drop(sb.ranges, 1)) do b, r
Expand Down
23 changes: 17 additions & 6 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ end
(Exp{1}() PlanarLayer(2) RadialLayer(2), randn(2, 3)),
(SimplexBijector(), mapslices(z -> normalize(z, 1), rand(2, 3); dims = 1)),
(stack(Exp{0}(), Scale(2.0)), randn(2, 3)),
(Stacked((Exp{1}(), SimplexBijector()), [1:1, 2:3]),
(Stacked((Exp{1}(), SimplexBijector()), (1:1, 2:3)),
mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1)),
(RationalQuadraticSpline(randn(3), randn(3), randn(3 - 1), 2.), [-0.5, 0.5]),
(LeakyReLU(0.1), randn(3)),
Expand Down Expand Up @@ -611,7 +611,7 @@ end


# TODO: change when we have dimensionality in the type
sb = @inferred Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), [1:1, 2:3])
sb = @inferred Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), (1:1, 2:3))
x = ones(3) ./ 3.0
res = @inferred forward(sb, x)
@test sb(param(x)) isa TrackedArray
Expand All @@ -623,8 +623,18 @@ end
x = ones(4) ./ 4.0
@test_throws AssertionError sb(x)

@test_throws AssertionError Stacked([Bijectors.Exp(), ], (1:1, 2:3))
@test_throws MethodError Stacked((Bijectors.Exp(), ), (1:1, 2:3))
# Array-version
sb = Stacked([Bijectors.Exp(), Bijectors.SimplexBijector()], [1:1, 2:3])
x = ones(3) ./ 3.0
res = forward(sb, x)
@test sb(param(x)) isa TrackedArray
@test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...]
@test res.rv == [exp(x[1]), sb.bs[2](x[2:3])...]
@test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2])
@test res.logabsdetjac == logabsdetjac(sb, x)

x = ones(4) ./ 4.0
@test_throws AssertionError sb(x)

@testset "Stacked: ADVI with MvNormal" begin
# MvNormal test
Expand All @@ -649,6 +659,7 @@ end
push!(ranges, idx:idx + length(d) - 1)
idx += length(d)
end
ranges = tuple(ranges...)

num_params = ranges[end][end]
d = MvNormal(zeros(num_params), ones(num_params))
Expand All @@ -675,7 +686,7 @@ end
ibs = inv.(bs)
sb = @inferred Stacked(ibs, ranges)
isb = @inferred inv(sb)
@test sb isa Stacked{<: Tuple}
@test sb isa Stacked{<:Tuple}

# inverse
td = @inferred transformed(d, sb)
Expand All @@ -697,7 +708,7 @@ end

# Ensure `Stacked` works for a single bijector
d = (MvNormal(2, 1.0),)
sb = Stacked(bijector.(d), [1:2])
sb = Stacked(bijector.(d), (1:2, ))
x = [.5, 1.]
@test sb(x) == x
@test logabsdetjac(sb, x) == 0
Expand Down

0 comments on commit edbb560

Please sign in to comment.