From c822896a9a5633e61275ff152bda84fb41619ad7 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Sun, 26 Dec 2021 14:16:59 +0000 Subject: [PATCH] Allow AbstractGPs in WrappedGP (#217) * Bump patch * Allow AbstractGP in WrappedGP * Test nested GPPP * Check WrappedGP with fresh AbstractGP * Remove redundant code * Remove redudant code --- Project.toml | 2 +- src/Stheno.jl | 1 - src/gp/gp.jl | 4 +- src/util/block_arrays/diagonal.jl | 231 -------------- ...aussian_process_probabilistic_programme.jl | 19 +- test/gp/gp.jl | 6 + test/runtests.jl | 5 +- test/util/block_arrays/diagonal.jl | 299 ------------------ 8 files changed, 27 insertions(+), 540 deletions(-) delete mode 100644 src/util/block_arrays/diagonal.jl delete mode 100644 test/util/block_arrays/diagonal.jl diff --git a/Project.toml b/Project.toml index b45e74be..83463f68 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Stheno" uuid = "8188c328-b5d6-583d-959b-9690869a5511" -version = "0.7.15" +version = "0.7.16" [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" diff --git a/src/Stheno.jl b/src/Stheno.jl index 9096c9c8..bd946297 100644 --- a/src/Stheno.jl +++ b/src/Stheno.jl @@ -37,7 +37,6 @@ module Stheno include(joinpath("util", "zygote_rules.jl")) include(joinpath("util", "covariance_matrices.jl")) include(joinpath("util", "block_arrays", "dense.jl")) - include(joinpath("util", "block_arrays", "diagonal.jl")) include(joinpath("util", "abstract_data_set.jl")) include(joinpath("util", "proper_type_piracy.jl")) diff --git a/src/gp/gp.jl b/src/gp/gp.jl index fc4926c9..847221da 100644 --- a/src/gp/gp.jl +++ b/src/gp/gp.jl @@ -12,14 +12,14 @@ struct WrappedGP{Tgp<:AbstractGP} <: SthenoAbstractGP gp::Tgp n::Int gpc::GPC - function WrappedGP{Tgp}(gp::Tgp, gpc::GPC) where {Tgp<:GP} + function WrappedGP{Tgp}(gp::Tgp, gpc::GPC) where {Tgp<:AbstractGP} wgp = new{Tgp}(gp, next_index(gpc), gpc) gpc.n += 1 return wgp end end -wrap(gp::Tgp, gpc::GPC) where {Tgp<:GP} = WrappedGP{Tgp}(gp, gpc) +wrap(gp::Tgp, gpc::GPC) where {Tgp<:AbstractGP} = WrappedGP{Tgp}(gp, gpc) mean(f::WrappedGP, x::AbstractVector) = mean(f.gp, x) diff --git a/src/util/block_arrays/diagonal.jl b/src/util/block_arrays/diagonal.jl deleted file mode 100644 index 8f70195f..00000000 --- a/src/util/block_arrays/diagonal.jl +++ /dev/null @@ -1,231 +0,0 @@ -const BlockDiagonal{T, TM} = BlockMatrix{T, <:Diagonal{TM}} where {TM <: AbstractMatrix{T}} - - - -# -# Constructors -# - -block_diagonal(vs::AbstractVector{<:AbstractMatrix}) = mortar(Diagonal(vs)) - - - -# -# Accumulation rule for Zygote. -# - -function Zygote.accum(A::BlockDiagonal, B::BlockDiagonal) - return block_diagonal(accum.(A.blocks.diag, B.blocks.diag)) -end - - -# -# adjoint / transpose - ensure we get a BlockDiagonal back -# - -LinearAlgebra.adjoint(A::BlockDiagonal) = block_diagonal(adjoint.(A.blocks.diag)) -LinearAlgebra.transpose(A::BlockDiagonal) = block_diagonal(transpose.(A.blocks.diag)) - - -# -# UpperTriangular - ensure we get a BlockDiagonal back -# - -function LinearAlgebra.UpperTriangular(A::BlockDiagonal) - return block_diagonal(UpperTriangular.(A.blocks.diag)) -end - - - -# -# Symmetric - ensure we get a BlockDiagonal back -# - -LinearAlgebra.Symmetric(A::BlockDiagonal) = block_diagonal(Symmetric.(A.blocks.diag)) - -ZygoteRules.@adjoint function LinearAlgebra.Symmetric(A::BlockDiagonal) - return Zygote.pullback(A->block_diagonal(Symmetric.(A.blocks.diag)), A) -end - - - -# -# Addition -# - -function Base.:+(A::BlockDiagonal, B::BlockDiagonal) - return block_diagonal([a + b for (a, b) in zip(A.blocks.diag, B.blocks.diag)]) -end - -function Base.:+(A::Matrix, B::BlockDiagonal) - @assert size(A) == size(B) - C = copy(A) - cs = cumulsizes(B, 1) - for n in 1:nblocks(B, 1) - idx = cs[n]:cs[n+1]-1 - C[idx, idx] += B[Block(n, n)] - end - return C -end - -ZygoteRules.@adjoint function Base.:+(A::Matrix, B::BlockDiagonal{T, <:Matrix{T}} where {T}) - return A + B, function(Δ) - cs = cumulsizes(B, 1) - blks = [Δ[cs[n]:cs[n+1]-1, cs[n]:cs[n+1]-1] for n in 1:nblocks(B, 1)] - return (Δ, block_diagonal(blks)) - end -end - - -# -# Negation -# - -Base.:-(A::BlockDiagonal) = block_diagonal([-a for a in A.blocks.diag]) - - -# -# BlockDiagonal multiplication -# - -function Base.:*(A::BlockDiagonal{<:Real}, B::BlockDiagonal{<:Real}) - return block_diagonal([a * b for (a, b) in zip(A.blocks.diag, B.blocks.diag)]) -end - -ZygoteRules.@adjoint function Base.:*(A::BlockDiagonal{<:Real}, B::BlockDiagonal{<:Real}) - return A * B, Δ->(Δ * B', A' * Δ) -end - - -function Base.:*(A::BlockDiagonal{<:Real}, B::Matrix{<:Real}) - A_blks, B_blks = A.blocks.diag, BlockArray(B, blocksizes(A, 1), [size(B, 2)]).blocks - return Matrix(mortar(reshape([a * b for (a, b) in zip(A_blks, B_blks)], :, 1))) -end -function Base.:*(A::BlockDiagonal{<:Real}, x::Vector{<:Real}) - A_blks, x_blks = diag(A.blocks), BlockArray(x, blocksizes(A, 1)).blocks - return Vector(mortar([a * x for (a, x) in zip(A_blks, x_blks)])) -end - - -# -# BlockDiagonal ldiv -# - -function Base.:\(A::BlockDiagonal{<:Real}, B::BlockDiagonal{<:Real}) - A_blks, B_blks = diag(A.blocks), diag(B.blocks) - return block_diagonal([a \ b for (a, b) in zip(A_blks, B_blks)]) -end - -ZygoteRules.@adjoint function Base.:\(A::BlockDiagonal{<:Real}, B::BlockDiagonal{<:Real}) - Y = A \ B - return Y, function(Ȳ::BlockDiagonal) - B̄ = A' \ Ȳ - return (-B̄ * Y', B̄) - end -end - -function Base.:\(A::BlockDiagonal{<:Real}, B::AbstractMatrix{<:Real}) - A_blks = diag(A.blocks) - B_blks = BlockArray(collect(B), blocksizes(A, 1), [size(B, 2)]).blocks - - return Matrix(mortar(reshape([a \ b for (a, b) in zip(A_blks, B_blks)], :, 1))) -end -ZygoteRules.@adjoint function Base.:\(A::BlockDiagonal{<:Real}, B::AbstractMatrix{<:Real}) - Y = A \ B - return Y, function(Ȳ::AbstractMatrix{<:Real}) - B̄ = A' \ Ȳ - return (_block_diag_bit(-B̄, Y', A), B̄) - end -end - -function Base.:\(A::BlockDiagonal{<:Real}, x::AbstractVector{<:Real}) - return reshape(A \ reshape(x, :, 1), :) -end - -ZygoteRules.@adjoint function Base.:\(A::BlockDiagonal{<:Real}, x::AbstractVector{<:Real}) - y_mat, back = Zygote.pullback(\, A, reshape(x, :, 1)) - return vec(y_mat), function(Δ::AbstractVector{<:Real}) - Ā, x̄ = back(reshape(Δ, :, 1)) - return Ā, vec(x̄) - end -end - -function _block_diag_bit(A::AbstractMatrix, B::AbstractMatrix, R::BlockDiagonal) - A_blks = vec(BlockArray(A, blocksizes(R, 1), [size(A, 2)]).blocks) - B_blks = vec(BlockArray(B', blocksizes(R, 1), [size(B, 1)]).blocks) - return block_diagonal([a * b' for (a, b) in zip(A_blks, B_blks)]) -end - - -# -# cholesky -# - -function LinearAlgebra.cholesky(A::BlockDiagonal{<:Real}) - Cs = map(A->cholesky(A).U, diag(A.blocks)) - return Cholesky(block_diagonal(Cs), :U, 0) -end -ZygoteRules.@adjoint function LinearAlgebra.cholesky(A::BlockDiagonal{<:Real}) - Cs_backs = map(A->Zygote.pullback(A->cholesky(A).U, A), diag(A.blocks)) - Cs, backs = first.(Cs_backs), last.(Cs_backs) - function back(Ū::BlockDiagonal) - return (block_diagonal(map((Ū, back)->first(back(Ū)), diag(Ū.blocks), backs)),) - end - return Cholesky(block_diagonal(Cs), :U, 0), Δ->back(Δ.factors) -end - -function LinearAlgebra.logdet(C::Cholesky{T, <:BlockDiagonal{T}} where {T<:Real}) - return 2 * sum([logdet(c) for c in C.factors.blocks.diag]) -end - -ZygoteRules.@adjoint function LinearAlgebra.logdet( - C::Cholesky{T, <:BlockDiagonal{T}} where {T<:Real}, -) - return logdet(C), function(Δ::Real) - blks = C.factors.blocks.diag - factors = block_diagonal([diagm(0=>2Δ ./ diag(b)) for b in blks]) - return ((factors=factors,),) - end -end - - -# -# Misc -# - -# -# BlockDiagonal mul! and * -# - -Base.:*(D::BlockDiagonal, x::BlockVector) = mul!(copy(x), D, x) -Base.:*(D::BlockDiagonal, X::BlockMatrix) = mul!(copy(X), D, X) - -function LinearAlgebra.mul!(y::BlockVector, D::BlockDiagonal, x::BlockVector) - @assert are_conformal(D, x) && are_conformal(D, y) - blocks = D.blocks.diag - for r in 1:nblocks(D, 1) - mul!(view(y, Block(r)), blocks[r], view(x, Block(r))) - end - return y -end - -@adjoint function Base.:*(D::BlockDiagonal, x::BlockVector) - y = D * x - return y, function(ȳ::BlockVector) - @assert blocksizes(y, 1) == blocksizes(ȳ, 1) - D̄_blocks = map((x_blk, ȳ_blk) -> x_blk * ȳ_blk', x.blocks, ȳ.blocks) - D̄ = mortar(Diagonal(D̄_blocks)) - return D̄, D' * ȳ - end -end - -function LinearAlgebra.mul!(Y::BlockMatrix, D::BlockDiagonal, X::BlockMatrix) - @assert are_conformal(D, X) && are_conformal(D, Y) - blocks = D.blocks.diag - for r in 1:nblocks(D, 1) - for c in 1:nblocks(X, 2) - mul!(view(Y, Block(r, c)), blocks[r], view(X, Block(r, c))) - end - end - return Y -end diff --git a/test/gaussian_process_probabilistic_programme.jl b/test/gaussian_process_probabilistic_programme.jl index 3948efa1..b03f5d30 100644 --- a/test/gaussian_process_probabilistic_programme.jl +++ b/test/gaussian_process_probabilistic_programme.jl @@ -82,8 +82,7 @@ GPPPInput(:f1, randn(4)), ), ] - rng = MersenneTwister(123456) - AbstractGPs.TestUtils.test_internal_abstractgps_interface(rng, f, x0, x1) + test_internal_abstractgps_interface(MersenneTwister(123456), f, x0, x1) end @timedtestset "gppp macro" begin @@ -103,4 +102,20 @@ y = rand(f(x, s)) Zygote.gradient((x, y, f, s) -> logpdf(f(x, s), y), x, y, f, s) end + + # Check that we can use one GPPP inside another. + @timedtestset "nested gppp" begin + + gpc_outer = GPC() + f1_outer = Stheno.wrap(f, gpc_outer) + f2_outer = 5 * f1_outer + f_outer = Stheno.GPPP((f1=f1_outer, f2=f2_outer), gpc_outer) + + x0 = GPPPInput(:f1, randn(5)) + x1 = GPPPInput(:f2, randn(4)) + x0_outer = GPPPInput(:f1, x0) + x1_outer = GPPPInput(:f2, x1) + rng = MersenneTwister(123456) + test_internal_abstractgps_interface(rng, f_outer, x0_outer, x1_outer) + end end diff --git a/test/gp/gp.jl b/test/gp/gp.jl index 0e9e8c94..91ee06b0 100644 --- a/test/gp/gp.jl +++ b/test/gp/gp.jl @@ -1,3 +1,5 @@ +struct ToyAbstractGP <: AbstractGP end + @timedtestset "gp" begin # Ensure that basic functionality works as expected. @@ -33,4 +35,8 @@ @test cov(f1, f1, x′, x) ≈ cov(f1, f1, x, x′)' end + + @timedtestset "wrapped AbstractGP" begin + wrap(ToyAbstractGP(), GPC()) + end end diff --git a/test/runtests.jl b/test/runtests.jl index e4442894..85d00b49 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,7 +24,6 @@ using Stheno: GPC, AV, FiniteGP, - block_diagonal, AbstractGP, BlockData, blocks, @@ -36,10 +35,9 @@ using Stheno: diag_At_B, diag_Xt_invA_X, diag_Xt_invA_Y, - block_diagonal, - BlockDiagonal, blocksizes +using Stheno.AbstractGPs.TestUtils: test_internal_abstractgps_interface using Stheno.AbstractGPs.Distributions: MvNormal using FiniteDifferences: j′vp @@ -60,7 +58,6 @@ include("test_util.jl") @testset "block_arrays" begin include(joinpath("util", "block_arrays", "test_util.jl")) include(joinpath("util", "block_arrays", "dense.jl")) - include(joinpath("util", "block_arrays", "diagonal.jl")) end include(joinpath("util", "abstract_data_set.jl")) end diff --git a/test/util/block_arrays/diagonal.jl b/test/util/block_arrays/diagonal.jl deleted file mode 100644 index fc67d1ed..00000000 --- a/test/util/block_arrays/diagonal.jl +++ /dev/null @@ -1,299 +0,0 @@ -function general_BlockDiagonal_tests(rng, blocks) - d = block_diagonal(blocks) - Ps, Qs = size.(blocks, 1), size.(blocks, 2) - - @testset "general" begin - @test blocksizes(d, 1) == Ps - @test blocksizes(d, 2) == Qs - - @test view(d, Block(1, 1)) == blocks[1] - @test view(d, Block(2, 2)) == blocks[2] - @test view(d, Block(1, 2)) == zeros(Ps[1], Qs[2]) - @test view(d, Block(2, 1)) == zeros(Ps[2], Qs[1]) - - @test d[Block(1, 1)] == view(d, Block(1, 1)) - end -end - -function BlockDiagonal_mul_tests(rng, blocks) - D, Ps = block_diagonal(blocks), size.(blocks, 1) - Dmat = Matrix(D) - - U = UpperTriangular(D) - - xs, ys = [randn(rng, P) for P in Ps], [randn(rng, P) for P in Ps] - y, x = mortar(ys), mortar(xs) - - # Matrix-Vector product - @test mul!(y, D, x) ≈ Dmat * Vector(x) - @test mul!(y, D, x) == D * x - @test mul!(y, U, x) ≈ Matrix(U) * Vector(x) - @test mul!(y, U, x) == U * x - - Qs = [3, 4] - X = mortar([randn(rng, P, Q) for P in Ps, Q in Qs]) - Y = mortar([randn(rng, P, Q) for P in Ps, Q in Qs]) - - # Matrix-Matrix product - @test mul!(Y, D, X) ≈ Dmat * X - @test mul!(Y, U, X) ≈ Matrix(U) * Matrix(X) - @test mul!(Y, D, X) == D * X - @test mul!(Y, U, X) == U * X -end - -function BlockDiagonal_chol_tests(rng, blocks) - - D, Ps = block_diagonal(blocks), size.(blocks, 1) - Dmat = Matrix(D) - - C, Cmat = cholesky(D), cholesky(Dmat) - - @test C.U ≈ Cmat.U - @test logdet(C) ≈ logdet(Cmat) - - Csym = cholesky(Symmetric(D)) - @test C.U ≈ Csym.U - - # Test backprop for accessing `U`. - U_diag, back_diag = Zygote.pullback(D->cholesky(D).U, D) - U_dens, back_dens = Zygote.pullback(D->cholesky(D).U, Matrix(D)) - - @test U_diag ≈ U_dens - - Ū = block_diagonal([randn(rng, P, P) for P in Ps]) - D̄_diag = first(back_diag(Ū)) - D̄_dens = first(back_dens(Matrix(Ū))) - @test Matrix(D̄_diag) ≈ D̄_dens - @test D̄_diag isa BlockDiagonal - - # Test backprop for logdet of a Cholesky. - l_diag, l_back_diag = Zygote.pullback(D->logdet(cholesky(D)), D) - l_dens, l_back_dens = Zygote.pullback(D->logdet(cholesky(D)), Matrix(D)) - - @test l_diag ≈ l_dens - - l̄ = randn(rng) - D̄_diag = first(l_back_diag(l̄)) - D̄_dens = first(l_back_dens(l̄)) - @test Matrix(D̄_diag) ≈ D̄_dens - @test D̄_diag isa BlockDiagonal -end - -function BlockDiagonal_add_tests(rng, blks; grad=true) - - D = block_diagonal(blks) - Dmat = Matrix(D) - A = randn(rng, size(D)) - - A_copy = copy(A) - C = A_copy + D - @test A_copy == A - @test C == A + Dmat - - if grad == true - @assert length(blks) == 2 - adjoint_test( - (A, b1, b2)->A + block_diagonal([b1, b2]), - randn(rng, size(A)), A, blks[1], blks[2], - ) - end -end - -@timedtestset "BlockDiagonal" begin - @timedtestset "Matrix" begin - rng, Ps, Qs = MersenneTwister(123456), [2, 3], [4, 5] - vs = [randn(rng, Ps[1], Qs[1]), randn(rng, Ps[2], Qs[2])] - general_BlockDiagonal_tests(rng, vs) - - As = [randn(rng, Ps[n], Ps[n]) for n in eachindex(Ps)] - blks = [As[n] * As[n]' + I for n in eachindex(As)] - BlockDiagonal_mul_tests(rng, blks) - BlockDiagonal_mul_tests(rng, UpperTriangular.(blks)) - BlockDiagonal_mul_tests(rng, Hermitian.(blks)) - BlockDiagonal_mul_tests(rng, Symmetric.(blks)) - BlockDiagonal_chol_tests(rng, blks) - BlockDiagonal_add_tests(rng, blks; grad=false) - end - @timedtestset "Diagonal{T, <:Vector{T}}" begin - rng, Ps = MersenneTwister(123456), [2, 3] - vs = [Diagonal(randn(rng, Ps[n])) for n in eachindex(Ps)] - general_BlockDiagonal_tests(rng, vs) - - blocks = [Diagonal(ones(P) + exp.(randn(rng, P))) for P in Ps] - BlockDiagonal_add_tests(rng, blocks; grad=false) - @timedtestset "cholesky" begin - x, ȳ = randn(rng, sum(Ps)), randn(rng, sum(Ps)) - adjoint_test((X, blks)->cholesky(block_diagonal(blks)).U \ X, ȳ, x, blocks) - - X, Ȳ = randn(rng, sum(Ps), 7), randn(rng, sum(Ps), 7) - adjoint_test((X, blks)->cholesky(block_diagonal(blks)).U \ X, Ȳ, X, blocks) - adjoint_test(blks->logdet(cholesky(block_diagonal(blks))), randn(rng), blocks) - end - end - @timedtestset "Negation" begin - rng, Ps = MersenneTwister(123456), [4, 5, 6, 7] - A = block_diagonal([randn(rng, P, P) for P in Ps]) - - @test Matrix(-A) == -Matrix(A) - @test -A isa BlockDiagonal - - Y_diag, back_diag = Zygote.pullback(-, A) - Y_dens, back_dens = Zygote.pullback(-, Matrix(A)) - - Ȳ = block_diagonal([randn(rng, P, P) for P in Ps]) - @test Y_diag == -A - @test Matrix(first(back_diag(Ȳ))) == first(back_dens(Matrix(Ȳ))) - @test first(back_diag(Ȳ)) isa BlockDiagonal - end - @timedtestset "adjoint" begin - rng, Ps = MersenneTwister(123456), [4, 5, 6] - A = block_diagonal([randn(rng, P, P) for P in Ps]) - - @test Matrix(A') == Matrix(A)' - @test A' isa BlockDiagonal - - Y, back = Zygote.pullback(adjoint, A) - Y_dens, back_dens = Zygote.pullback(adjoint, Matrix(A)) - Ȳ = block_diagonal([randn(rng, P, P) for P in Ps]) - @test Y == A' - @test Matrix(first(back(Ȳ))) == first(back_dens(Matrix(Ȳ))) - @test first(back(Ȳ)) isa BlockDiagonal - end - @timedtestset "transpose" begin - rng, Ps = MersenneTwister(123456), [4, 5, 6] - A = block_diagonal([randn(rng, P, P) for P in Ps]) - - @test Matrix(transpose(A)) == transpose(Matrix(A)) - @test transpose(A) isa BlockDiagonal - - Y, back = Zygote.pullback(transpose, A) - Y_dens, back_dens = Zygote.pullback(transpose, Matrix(A)) - Ȳ = block_diagonal([randn(rng, P, P) for P in Ps]) - @test Y == transpose(A) - @test Matrix(first(back(Ȳ))) == first(back_dens(Matrix(Ȳ))) - @test first(back(Ȳ)) isa BlockDiagonal - end - @timedtestset "UpperTriangular" begin - rng, Ps = MersenneTwister(123456), [4, 5, 6] - A = block_diagonal([randn(rng, P, P) for P in Ps]) - - @test Matrix(UpperTriangular(A)) == UpperTriangular(Matrix(A)) - @test UpperTriangular(A) isa BlockDiagonal - - B_diag, back_diag = Zygote.pullback(UpperTriangular, A) - B_dens, back_dens = Zygote.pullback(UpperTriangular, Matrix(A)) - @test Matrix(B_diag) == B_dens - - B̄ = block_diagonal([randn(rng, P, P) for P in Ps]) - Ā_diag, Ā_dens = first(back_diag(B̄)), first(back_dens(Matrix(B̄))) - @test Ā_diag == Ā_dens - @test Ā_diag isa BlockDiagonal - end - @timedtestset "Symmetric" begin - rng, Ps = MersenneTwister(123456), [4, 5, 6] - A = block_diagonal([randn(rng, P, P) for P in Ps]) - S = Symmetric(A) - @test S == Symmetric(Matrix(A)) - @test S isa BlockDiagonal - - S_diag, back_diag = Zygote.pullback(Symmetric, A) - S_dens, back_dens = Zygote.pullback(Symmetric, Matrix(A)) - @test S_diag ≈ S_dens - end - @timedtestset "BlockDiagonal * BlockDiagonal" begin - rng, Ps = MersenneTwister(123456), [4, 5, 6] - A = block_diagonal([randn(rng, P, P) for P in Ps]) - B = block_diagonal([randn(rng, P, P) for P in Ps]) - - @test Matrix(A * B) ≈ Matrix(A) * Matrix(B) - @test A * B isa BlockDiagonal - - Y_diag, back_diag = Zygote.pullback(*, A, B) - Y_dens, back_dens = Zygote.pullback(*, Matrix(A), Matrix(B)) - - Ȳ = block_diagonal([randn(rng, P, P) for P in Ps]) - @test Y_diag == A * B - - Ā_diag, B̄_diag = back_diag(Ȳ) - Ā_dens, B̄_dens = back_dens(Matrix(Ȳ)) - - @test Matrix(Ā_diag) ≈ Ā_dens - @test Matrix(B̄_diag) ≈ B̄_dens - - @test Ā_diag isa BlockDiagonal - @test B̄_diag isa BlockDiagonal - end - @timedtestset "BlockDiagonal * Matrix" begin - rng, Ps, Q = MersenneTwister(123456), [4, 5, 6], 11 - A = block_diagonal([randn(rng, P, P) for P in Ps]) - B = randn(rng, sum(Ps), Q) - @test Matrix(A * B) ≈ Matrix(A) * B - @test Matrix(A * collect(B')') ≈ Matrix(A) * B - - Y_diag, back_diag = Zygote.pullback(*, A, B) - Y_dens, back_dens = Zygote.pullback(*, Matrix(A), B) - @test Y_diag ≈ Y_dens - - Ȳ = randn(rng, sum(Ps), Q) - Ā_diag, B̄_diag = back_diag(Ȳ) - Ā_dens, B̄_dens = back_dens(Ȳ) - @test Matrix(Ā_diag) ≈ Ā_dens - @test B̄_diag ≈ B̄_dens - @test_broken Ā_diag isa BlockDiagonal - end - @timedtestset "BlockDiagonal * Vector" begin - rng, Ps = MersenneTwister(123456), [4, 5, 6] - A = block_diagonal([randn(rng, P, P) for P in Ps]) - x = randn(rng, sum(Ps)) - @test Vector(A * x) ≈ Matrix(A) * x - - Y_diag, back_diag = Zygote.pullback(*, A, x) - Y_dens, back_dens = Zygote.pullback(*, Matrix(A), x) - @test Y_diag ≈ Y_dens - - ȳ = randn(rng, sum(Ps)) - Ā_diag, x̄_diag = back_diag(ȳ) - Ā_dens, x̄_dens = back_dens(ȳ) - @test Matrix(Ā_diag) ≈ Ā_dens - @test x̄_diag ≈ x̄_dens - @test_broken Ā_diag isa BlockDiagonal - end - @timedtestset "ldiv(BlockDiagonal, Matrix)" begin - rng, Ps, Q = MersenneTwister(123456), [4, 5, 6], 11 - A = block_diagonal([randn(rng, P, P) for P in Ps]) - B = randn(rng, sum(Ps), Q) - @test Matrix(A \ B) ≈ Matrix(A) \ B - - Y_diag, back_diag = Zygote.pullback(\, A, B) - Y_dens, back_dens = Zygote.pullback(\, Matrix(A), B) - @test Y_diag ≈ Y_dens - - Ȳ = randn(rng, sum(Ps), Q) - Ā_diag, B̄_diag = back_diag(Ȳ) - Ā_dens, B̄_dens = back_dens(Ȳ) - @test_broken Matrix(Ā_diag) ≈ Ā_dens # we're not checking the right bits of the matrix here - @test B̄_diag ≈ B̄_dens - @test Ā_diag isa BlockDiagonal - @test blocksizes(Ā_diag, 1) == blocksizes(A, 1) - @test blocksizes(Ā_diag, 2) == blocksizes(A, 2) - end - @timedtestset "ldiv(BlockDiagonal, Vector)" begin - rng, Ps = MersenneTwister(123456), [4, 5, 6] - A = block_diagonal([randn(rng, P, P) for P in Ps]) - B = randn(rng, sum(Ps)) - @test Vector(A \ B) ≈ Matrix(A) \ B - - Y_diag, back_diag = Zygote.pullback(\, A, B) - Y_dens, back_dens = Zygote.pullback(\, Matrix(A), B) - @test Y_diag ≈ Y_dens - - Ȳ = randn(rng, sum(Ps)) - Ā_diag, B̄_diag = back_diag(Ȳ) - Ā_dens, B̄_dens = back_dens(Ȳ) - @test_broken Matrix(Ā_diag) ≈ Ā_dens # we're not checking the right bits of the matrix here - @test B̄_diag ≈ B̄_dens - @test Ā_diag isa BlockDiagonal - @test blocksizes(Ā_diag, 1) == blocksizes(A, 1) - @test blocksizes(Ā_diag, 2) == blocksizes(A, 2) - end -end