From cf380a82420f068cf09851298fb950071290f7cf Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 20 Dec 2024 03:49:29 +0000 Subject: [PATCH] fix bugs and add test --- Project.toml | 2 +- src/abstractprobprog.jl | 2 +- test/abstractprobprog.jl | 11 +++++++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index df46559..c3723c8 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.10.0" +version = "0.10.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/abstractprobprog.jl b/src/abstractprobprog.jl index 32e125d..d6da0d8 100644 --- a/src/abstractprobprog.jl +++ b/src/abstractprobprog.jl @@ -128,5 +128,5 @@ end Draw a sample from the predictive distribution specified by `model` with its parameters fixed to `params`. """ function StatsBase.predict(model::AbstractProbabilisticProgram, params) - return predict(Random.default_rng(), NamedTuple, model, params) + return StatsBase.predict(Random.default_rng(), model, params) end diff --git a/test/abstractprobprog.jl b/test/abstractprobprog.jl index 00230be..083bed3 100644 --- a/test/abstractprobprog.jl +++ b/test/abstractprobprog.jl @@ -13,6 +13,12 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::RandModel) where { return nothing end +struct PredModel <: AbstractProbabilisticProgram end + +function AbstractPPL.predict(rng::Random.AbstractRNG, model::PredModel, params) + return rng +end + @testset "AbstractProbabilisticProgram" begin @testset "rand defaults" begin model = RandModel(nothing, nothing) @@ -34,4 +40,9 @@ end @test model.T === NamedTuple end end + + @testset "predict defaults" begin + model = PredModel() + @test AbstractPPL.predict(model, nothing) == Random.default_rng() + end end