diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index cd3b3658..2e9d6bcf 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -14,6 +14,7 @@ jobs: fail-fast: false matrix: package: + - {user: TuringLang, repo: AdvancedHMC.jl} - {user: TuringLang, repo: AdvancedMH.jl} - {user: TuringLang, repo: EllipticalSliceSampling.jl} - {user: TuringLang, repo: MCMCChains.jl} diff --git a/Project.toml b/Project.toml index 4053aa83..99ef46a1 100644 --- a/Project.toml +++ b/Project.toml @@ -3,12 +3,13 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.2" +version = "4.2.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" @@ -20,6 +21,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] BangBang = "0.3.19" ConsoleProgressMonitor = "0.1" +LogDensityProblems = "2" LoggingExtras = "0.4, 0.5" ProgressLogging = "0.1" StatsBase = "0.32, 0.33" diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 44f56a9c..64f20f97 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -2,6 +2,7 @@ module AbstractMCMC using BangBang: BangBang using ConsoleProgressMonitor: ConsoleProgressMonitor +using LogDensityProblems: LogDensityProblems using LoggingExtras: LoggingExtras using ProgressLogging: ProgressLogging using StatsBase: StatsBase diff --git a/src/logdensityproblems.jl b/src/logdensityproblems.jl index 98615cde..54db36bb 100644 --- a/src/logdensityproblems.jl +++ b/src/logdensityproblems.jl @@ -12,4 +12,16 @@ that the wrapped object implements the LogDensityProblems.jl interface. """ struct LogDensityModel{L} <: AbstractModel logdensity::L + function LogDensityModel{L}(logdensity::L) where {L} + if LogDensityProblems.capabilities(logdensity) === nothing + throw( + ArgumentError( + "The log density function does not support the LogDensityProblems.jl interface", + ), + ) + end + return new{L}(logdensity) + end end + +LogDensityModel(logdensity::L) where {L} = LogDensityModel{L}(logdensity)