Skip to content

Commit

Permalink
PhasePoint constructor bug when using GPU (#267)
Browse files Browse the repository at this point in the history
* Fix PhasePoint constructor GPU bug

* Fix typo

* Update hamiltonian.jl

Co-authored-by: Kai Xu <[email protected]>
  • Loading branch information
treigerm and xukai92 authored Jun 24, 2021
1 parent c8feff3 commit 6cec3da
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 4 deletions.
12 changes: 9 additions & 3 deletions src/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,15 @@ struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat}, V<:DualValue}
@argcheck length(θ) == length(r) == length(ℓπ.gradient) == length(ℓπ.gradient)
if any(isfinite.((θ, r, ℓπ, ℓκ)) .== false)
@warn "The current proposal will be rejected due to numerical error(s)." isfinite.((θ, r, ℓπ, ℓκ))
E = eltype(T)
ℓπ = DualValue(map(v -> isfinite(v) ? v : -E(Inf), ℓπ.value), ℓπ.gradient)
ℓκ = DualValue(map(v -> isfinite(v) ? v : -E(Inf), ℓκ.value), ℓκ.gradient)
# NOTE eltype has to be inlined to avoid type stability issue; see #267
ℓπ = DualValue(
map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓπ.value),
ℓπ.gradient
)
ℓκ = DualValue(
map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓκ.value),
ℓκ.gradient
)
end
new{T,V}(θ, r, ℓπ, ℓκ)
end
Expand Down
27 changes: 27 additions & 0 deletions test/cuda.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test
using AdvancedHMC
using AdvancedHMC: DualValue, PhasePoint
using CUDA

@testset "AdvancedHMC GPU" begin
Expand All @@ -21,4 +22,30 @@ using CUDA
proposal = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(5)))

samples, stats = sample(hamiltonian, proposal, θ₀, n_samples)
end

@testset "PhasePoint GPU" begin
for T in [Float32, Float64]
init_z1() = PhasePoint(
CuArray([T(NaN) T(NaN)]),
CuArray([T(NaN) T(NaN)]),
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2)))
)
init_z2() = PhasePoint(
CuArray([T(Inf) T(Inf)]),
CuArray([T(Inf) T(Inf)]),
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2)))
)

@test_logs (:warn, "The current proposal will be rejected due to numerical error(s).") init_z1()
@test_logs (:warn, "The current proposal will be rejected due to numerical error(s).") init_z2()

z1 = init_z1()
z2 = init_z2()

@test z1.ℓπ.value == z2.ℓπ.value
@test z1.ℓκ.value == z2.ℓκ.value
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Distributed, Test, CUDA

println("Envronment variables for testing")
println("Environment variables for testing")
println(ENV)

@testset "AdvancedHMC" begin
Expand Down

0 comments on commit 6cec3da

Please sign in to comment.