From 7d78781ebcd3dbe6dbab22c52657bf773da4fc43 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Sat, 25 Feb 2023 09:47:49 -0800 Subject: [PATCH] Backport 3-arg `rand` to AbstractPPL 0.5.x (#80) * Back-port 3-arg `rand` interface. * Bump version for interface change. * Update CI.yml * Fix test errors. --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- .github/workflows/CI.yml | 4 ++++ Project.toml | 2 +- src/abstractprobprog.jl | 20 ++++++++++++++++++++ src/graphinfo.jl | 10 +++++++--- src/varname.jl | 6 ++++-- test/abstractprobprog.jl | 37 +++++++++++++++++++++++++++++++++++++ test/runtests.jl | 3 ++- 7 files changed, 75 insertions(+), 7 deletions(-) create mode 100644 test/abstractprobprog.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 261aec4..e01ec47 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -9,6 +9,10 @@ on: - trying # Build the main branch. - main + pull_request: + branches: + - main + - releases-0.5.x jobs: test: diff --git a/Project.toml b/Project.toml index 3da73be..b532662 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.5.3" +version = "0.5.4" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/abstractprobprog.jl b/src/abstractprobprog.jl index a6773fb..30b8b35 100644 --- a/src/abstractprobprog.jl +++ b/src/abstractprobprog.jl @@ -1,5 +1,6 @@ using AbstractMCMC using DensityInterface +using Random """ @@ -60,3 +61,22 @@ m = decondition(condition(m, obs)) should hold for generative models `m` and arbitrary `obs`. """ function condition end + + +""" + rand([rng=Random.default_rng()], [T=NamedTuple], model::AbstractProbabilisticProgram) -> T + +Draw a sample from the joint distribution of the model specified by the probabilistic program. + +The sample will be returned as format specified by `T`. +""" +Base.rand(rng::Random.AbstractRNG, ::Type, model::AbstractProbabilisticProgram) +function Base.rand(rng::Random.AbstractRNG, model::AbstractProbabilisticProgram) + return rand(rng, NamedTuple, model) +end +function Base.rand(::Type{T}, model::AbstractProbabilisticProgram) where {T} + return rand(Random.default_rng(), T, model) +end +function Base.rand(model::AbstractProbabilisticProgram) + return rand(Random.default_rng(), NamedTuple, model) +end diff --git a/src/graphinfo.jl b/src/graphinfo.jl index 4c415db..3bd7fcc 100644 --- a/src/graphinfo.jl +++ b/src/graphinfo.jl @@ -444,9 +444,9 @@ function Random.rand!(m::AbstractPPL.GraphPPL.Model{T}) where T end """ - rand!(rng::AbstractRNG, m::Model) + rand(m::Model) -Draw random samples from the model and mutate the node values. +Draw random samples from the model and return the samples as NamedTuple. # Examples @@ -470,11 +470,15 @@ julia> rand(m) (μ = 1.0, s2 = 1.0907695400401212, y = 0.05821954440386368) ``` """ -function Random.rand(rng::AbstractRNG, sm::Random.SamplerTrivial{Model{Tnames, Tinput, Tvalue, Teval, Tkind}}) where {Tnames, Tinput, Tvalue, Teval, Tkind} +function Base.rand(rng::AbstractRNG, sm::Random.SamplerTrivial{Model{Tnames, Tinput, Tvalue, Teval, Tkind}}) where {Tnames, Tinput, Tvalue, Teval, Tkind} m = deepcopy(sm[]) get_model_values(rand!(rng, m)) end +function Base.rand(rng::AbstractRNG, ::Type{NamedTuple}, m::Model) + rand(rng, Random.SamplerTrivial(m)) +end + """ logdensityof(m::Model) diff --git a/src/varname.jl b/src/varname.jl index 116ad94..04fb604 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -460,8 +460,10 @@ resolved as `VarName` only supports non-dynamic indexing as determined by ```jldoctest julia> # Dynamic indexing is not allowed in `VarName` @varname(x[end]) -ERROR: UndefVarError: x not defined -[...] +ERROR: UndefVarError: `x` not defined +Stacktrace: + [1] top-level scope + @ none:1 julia> # To be able to resolve `end` we need `x` to be available. x = randn(2); @varname(x[end]) diff --git a/test/abstractprobprog.jl b/test/abstractprobprog.jl new file mode 100644 index 0000000..00230be --- /dev/null +++ b/test/abstractprobprog.jl @@ -0,0 +1,37 @@ +using AbstractPPL +using Random +using Test + +mutable struct RandModel <: AbstractProbabilisticProgram + rng + T +end + +function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::RandModel) where {T} + model.rng = rng + model.T = T + return nothing +end + +@testset "AbstractProbabilisticProgram" begin + @testset "rand defaults" begin + model = RandModel(nothing, nothing) + rand(model) + @test model.rng == Random.default_rng() + @test model.T === NamedTuple + rngs = [Random.default_rng(), Random.MersenneTwister(42)] + Ts = [NamedTuple, Dict] + @testset for T in Ts + model = RandModel(nothing, nothing) + rand(T, model) + @test model.rng == Random.default_rng() + @test model.T === T + end + @testset for rng in rngs + model = RandModel(nothing, nothing) + rand(rng, model) + @test model.rng === rng + @test model.T === NamedTuple + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index ea1a9b9..33aa19a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,7 @@ using Test @testset "AbstractPPL.jl" begin include("deprecations.jl") + include("abstractprobprog.jl") include("graphinfo/graphinfo.jl") @testset "doctests" begin DocMeta.setdocmeta!( @@ -20,6 +21,6 @@ using Test :(using AbstractPPL); recursive=true, ) - doctest(AbstractPPL; manual=false) + doctest(AbstractPPL; manual=false, fix=true) end end \ No newline at end of file