diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl index f2bbd230..4047f30e 100644 --- a/src/hamiltonian.jl +++ b/src/hamiltonian.jl @@ -42,8 +42,11 @@ end ∂H∂r(h::Hamiltonian{<:UnitEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat) = copy(r) ∂H∂r(h::Hamiltonian{<:DiagEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat) = h.metric.M⁻¹ .* r -∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat) = - h.metric.M⁻¹ * r +function ∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat) + out = similar(r) # Make sure the output of this function is of the same type as r + mul!(out, h.metric.M⁻¹, r) + out +end struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat},V<:DualValue} θ::T # Position variables / model parameters. @@ -51,7 +54,7 @@ struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat},V<:DualValue} ℓπ::V # Cached neg potential energy for the current θ. ℓκ::V # Cached neg kinect energy for the current r. function PhasePoint(θ::T, r::T, ℓπ::V, ℓκ::V) where {T,V} - @argcheck length(θ) == length(r) == length(ℓπ.gradient) == length(ℓπ.gradient) + @argcheck length(θ) == length(r) == length(ℓπ.gradient) == length(ℓκ.gradient) if any(isfinite.((θ, r, ℓπ, ℓκ)) .== false) # @warn "The current proposal will be rejected due to numerical error(s)." isfinite.((θ, r, ℓπ, ℓκ)) # NOTE eltype has to be inlined to avoid type stability issue; see #267 diff --git a/test/hamiltonian.jl b/test/hamiltonian.jl index 47948325..fd23e784 100644 --- a/test/hamiltonian.jl +++ b/test/hamiltonian.jl @@ -1,6 +1,7 @@ using ReTest, AdvancedHMC using AdvancedHMC: GaussianKinetic, DualValue, PhasePoint using LinearAlgebra: dot, diagm +using ComponentArrays @testset "Hamiltonian" begin f = x -> dot(x, x) @@ -38,7 +39,7 @@ end end end -@testset "Metric" begin +@testset "Metric Base Array" begin n_tests = 10 for T in [Float32, Float64] @@ -49,18 +50,50 @@ end h = Hamiltonian(UnitEuclideanMetric(T, D), ℓπ, ∂ℓπ∂θ) @test -AdvancedHMC.neg_energy(h, r_init, θ_init) == sum(abs2, r_init) / 2 @test AdvancedHMC.∂H∂r(h, r_init) == r_init + @test typeof(AdvancedHMC.∂H∂r(h, r_init)) == typeof(r_init) M⁻¹ = ones(T, D) + abs.(randn(T, D)) h = Hamiltonian(DiagEuclideanMetric(M⁻¹), ℓπ, ∂ℓπ∂θ) @test -AdvancedHMC.neg_energy(h, r_init, θ_init) ≈ r_init' * diagm(0 => M⁻¹) * r_init / 2 @test AdvancedHMC.∂H∂r(h, r_init) == M⁻¹ .* r_init + @test typeof(AdvancedHMC.∂H∂r(h, r_init)) == typeof(r_init) m = randn(T, D, D) M⁻¹ = m' * m h = Hamiltonian(DenseEuclideanMetric(M⁻¹), ℓπ, ∂ℓπ∂θ) @test -AdvancedHMC.neg_energy(h, r_init, θ_init) ≈ r_init' * M⁻¹ * r_init / 2 @test AdvancedHMC.∂H∂r(h, r_init) == M⁻¹ * r_init + @test typeof(AdvancedHMC.∂H∂r(h, r_init)) == typeof(r_init) + end + end +end + +@testset "Metric ComponentArrays" begin + n_tests = 10 + for T in [Float32, Float64] + for _ = 1:n_tests + θ_init = ComponentArray(a = randn(T, D), b = randn(T, D)) + r_init = ComponentArray(a = randn(T, D), b = randn(T, D)) + + h = Hamiltonian(UnitEuclideanMetric(T, 2*D), ℓπ, ∂ℓπ∂θ) + @test -AdvancedHMC.neg_energy(h, r_init, θ_init) == sum(abs2, r_init) / 2 + @test AdvancedHMC.∂H∂r(h, r_init) == r_init + @test typeof(AdvancedHMC.∂H∂r(h, r_init)) == typeof(r_init) + + M⁻¹ = ones(T, 2*D) + abs.(randn(T, 2*D)) + h = Hamiltonian(DiagEuclideanMetric(M⁻¹), ℓπ, ∂ℓπ∂θ) + @test -AdvancedHMC.neg_energy(h, r_init, θ_init) ≈ + r_init' * diagm(0 => M⁻¹) * r_init / 2 + @test AdvancedHMC.∂H∂r(h, r_init) == M⁻¹ .* r_init + @test typeof(AdvancedHMC.∂H∂r(h, r_init)) == typeof(r_init) + + m = randn(T, 2*D, 2*D) + M⁻¹ = m' * m + h = Hamiltonian(DenseEuclideanMetric(M⁻¹), ℓπ, ∂ℓπ∂θ) + @test -AdvancedHMC.neg_energy(h, r_init, θ_init) ≈ r_init' * M⁻¹ * r_init / 2 + @test all(AdvancedHMC.∂H∂r(h, r_init) .== M⁻¹ * r_init) + @test typeof(AdvancedHMC.∂H∂r(h, r_init)) == typeof(r_init) end end end