From edbb5603004e9567f59f6a659b18923fc66b18ec Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 5 Jun 2021 08:46:54 +0100 Subject: [PATCH] Relax type constraints for `Stacked` (#177) * relax type constraints for Stacked * additional changes to make new Stacked work * added default constructor for Stacked * fixed typo * fixed tests for stacked * aight finally fixed tests for Stacked * added suggestion from review * fixed typo * fixed impl for Stacked using array * updated tests * bump patch version --- src/Bijectors.jl | 2 +- src/bijectors/stacked.jl | 57 ++++++++++++++++++++++------------------ test/interface.jl | 23 +++++++++++----- 3 files changed, 49 insertions(+), 33 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 955d0be7..fb9e680a 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -93,7 +93,7 @@ end function mapvcat(f, args...) out = map(f, args...) init = vcat(out[1]) - return reshape(reduce(vcat, drop(out, 1); init = init), size(out)) + return reduce(vcat, drop(out, 1); init = init) end function maphcat(f, args...) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 0131a76c..3fbc591c 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -21,28 +21,12 @@ b = stack(b1, b2) b([0.0, 1.0]) == [b1(0.0), 1.0] # => true ``` """ -struct Stacked{Bs, N} <: Bijector{1} +struct Stacked{Bs, Rs} <: Bijector{1} bs::Bs - ranges::NTuple{N, UnitRange{Int}} - - function Stacked( - bs::C, - ranges::NTuple{N, UnitRange{Int}} - ) where {N, C<:Tuple{Vararg{<:ZeroOrOneDimBijector, N}}} - return new{C, N}(bs, ranges) - end - - function Stacked( - bs::A, - ranges::NTuple{N, UnitRange{Int}} - ) where {N, A<:AbstractArray{<:Bijector}} - @assert length(bs) == N "number of bijectors is not same as number of ranges" - @assert all(b -> isa(b, ZeroOrOneDimBijector), bs) - return new{A, N}(bs, ranges) - end + ranges::Rs end -Stacked(bs, ranges::AbstractArray) = Stacked(bs, tuple(ranges...)) -Stacked(bs) = Stacked(bs, tuple([i:i for i = 1:length(bs)]...)) +Stacked(bs::Tuple) = Stacked(bs, ntuple(i -> i:i, length(bs))) +Stacked(bs::AbstractArray) = Stacked(bs, [i:i for i in 1:length(bs)]) # define nested numerical parameters # TODO: replace with `Functors.@functor Stacked (bs,)` when @@ -95,7 +79,10 @@ function (sb::Stacked{<:Tuple})(x::AbstractVector{<:Real}) return y end # The Stacked{<:AbstractArray} version is not TrackedArray friendly -function (sb::Stacked{<:AbstractArray, N})(x::AbstractVector{<:Real}) where {N} +function (sb::Stacked{<:AbstractArray})(x::AbstractVector{<:Real}) + N = length(sb.bs) + N == 1 && return sb.bs[1](x[sb.ranges[1]]) + y = mapvcat(1:N) do i sb.bs[i](x[sb.ranges[i]]) end @@ -105,7 +92,23 @@ end (sb::Stacked)(x::AbstractMatrix{<:Real}) = eachcolmaphcat(sb, x) function logabsdetjac( - b::Stacked{<:Any, N}, + b::Stacked, + x::AbstractVector{<:Real} +) + N = length(b.bs) + init = sum(logabsdetjac(b.bs[1], x[b.ranges[1]])) + + return if N > 1 + init + sum(2:N) do i + sum(logabsdetjac(b.bs[i], x[b.ranges[i]])) + end + else + init + end +end + +function logabsdetjac( + b::Stacked{<:Tuple{Vararg{<:Any, N}}}, x::AbstractVector{<:Real} ) where {N} init = sum(logabsdetjac(b.bs[1], x[b.ranges[1]])) @@ -114,7 +117,8 @@ function logabsdetjac( end end -function logabsdetjac(b::Stacked{<:Any, 1}, x::AbstractVector{<:Real}) +# Handle the case of just one bijector +function logabsdetjac(b::Stacked{<:Tuple{<:Bijector}}, x::AbstractVector{<:Real}) return sum(logabsdetjac(b.bs[1], x[b.ranges[1]])) end @@ -133,7 +137,7 @@ end # logjac += sum(_logjac) # return (rv = vcat(y_1, y_2), logabsdetjac = logjac) # end -@generated function forward(b::Stacked{T, N}, x::AbstractVector) where {N, T<:Tuple} +@generated function forward(b::Stacked{<:Tuple{Vararg{<:Any, N}}}, x::AbstractVector) where {N} expr = Expr(:block) y_names = [] @@ -141,7 +145,7 @@ end # TODO: drop the `sum` when we have dimensionality push!(expr.args, :(logjac = sum(_logjac))) push!(y_names, :y_1) - for i = 2:length(T.parameters) + for i = 2:N y_name = Symbol("y_$i") push!(expr.args, :(($y_name, _logjac) = forward(b.bs[$i], x[b.ranges[$i]]))) @@ -155,7 +159,8 @@ end return expr end -function forward(sb::Stacked{<:AbstractArray, N}, x::AbstractVector) where {N} +function forward(sb::Stacked{<:AbstractArray}, x::AbstractVector) + N = length(sb.bs) yinit, linit = forward(sb.bs[1], x[sb.ranges[1]]) logjac = sum(linit) ys = mapvcat(drop(sb.bs, 1), drop(sb.ranges, 1)) do b, r diff --git a/test/interface.jl b/test/interface.jl index 406a7ceb..9e594813 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -158,7 +158,7 @@ end (Exp{1}() ∘ PlanarLayer(2) ∘ RadialLayer(2), randn(2, 3)), (SimplexBijector(), mapslices(z -> normalize(z, 1), rand(2, 3); dims = 1)), (stack(Exp{0}(), Scale(2.0)), randn(2, 3)), - (Stacked((Exp{1}(), SimplexBijector()), [1:1, 2:3]), + (Stacked((Exp{1}(), SimplexBijector()), (1:1, 2:3)), mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1)), (RationalQuadraticSpline(randn(3), randn(3), randn(3 - 1), 2.), [-0.5, 0.5]), (LeakyReLU(0.1), randn(3)), @@ -611,7 +611,7 @@ end # TODO: change when we have dimensionality in the type - sb = @inferred Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), [1:1, 2:3]) + sb = @inferred Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), (1:1, 2:3)) x = ones(3) ./ 3.0 res = @inferred forward(sb, x) @test sb(param(x)) isa TrackedArray @@ -623,8 +623,18 @@ end x = ones(4) ./ 4.0 @test_throws AssertionError sb(x) - @test_throws AssertionError Stacked([Bijectors.Exp(), ], (1:1, 2:3)) - @test_throws MethodError Stacked((Bijectors.Exp(), ), (1:1, 2:3)) + # Array-version + sb = Stacked([Bijectors.Exp(), Bijectors.SimplexBijector()], [1:1, 2:3]) + x = ones(3) ./ 3.0 + res = forward(sb, x) + @test sb(param(x)) isa TrackedArray + @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] + @test res.rv == [exp(x[1]), sb.bs[2](x[2:3])...] + @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2]) + @test res.logabsdetjac == logabsdetjac(sb, x) + + x = ones(4) ./ 4.0 + @test_throws AssertionError sb(x) @testset "Stacked: ADVI with MvNormal" begin # MvNormal test @@ -649,6 +659,7 @@ end push!(ranges, idx:idx + length(d) - 1) idx += length(d) end + ranges = tuple(ranges...) num_params = ranges[end][end] d = MvNormal(zeros(num_params), ones(num_params)) @@ -675,7 +686,7 @@ end ibs = inv.(bs) sb = @inferred Stacked(ibs, ranges) isb = @inferred inv(sb) - @test sb isa Stacked{<: Tuple} + @test sb isa Stacked{<:Tuple} # inverse td = @inferred transformed(d, sb) @@ -697,7 +708,7 @@ end # Ensure `Stacked` works for a single bijector d = (MvNormal(2, 1.0),) - sb = Stacked(bijector.(d), [1:2]) + sb = Stacked(bijector.(d), (1:2, )) x = [.5, 1.] @test sb(x) == x @test logabsdetjac(sb, x) == 0