Skip to content

Commit

Permalink
fix bugs and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Dec 20, 2024
1 parent f1ffda1 commit cf380a8
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
keywords = ["probablistic programming"]
license = "MIT"
desc = "Common interfaces for probabilistic programming"
version = "0.10.0"
version = "0.10.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
2 changes: 1 addition & 1 deletion src/abstractprobprog.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,5 @@ end
Draw a sample from the predictive distribution specified by `model` with its parameters fixed to `params`.
"""
function StatsBase.predict(model::AbstractProbabilisticProgram, params)
return predict(Random.default_rng(), NamedTuple, model, params)
return StatsBase.predict(Random.default_rng(), model, params)
end
11 changes: 11 additions & 0 deletions test/abstractprobprog.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::RandModel) where {
return nothing
end

struct PredModel <: AbstractProbabilisticProgram end

function AbstractPPL.predict(rng::Random.AbstractRNG, model::PredModel, params)
return rng
end

@testset "AbstractProbabilisticProgram" begin
@testset "rand defaults" begin
model = RandModel(nothing, nothing)
Expand All @@ -34,4 +40,9 @@ end
@test model.T === NamedTuple
end
end

@testset "predict defaults" begin
model = PredModel()
@test AbstractPPL.predict(model, nothing) == Random.default_rng()
end
end

0 comments on commit cf380a8

Please sign in to comment.