From f1ffda114137b196559b17655d257d35b4cb6f71 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Dec 2024 19:36:41 +0100 Subject: [PATCH] Add StatsBase.predict to the interface (#81) * Add StatsBase as a dependency * Implement StatsBase.predict * use `fix` and fix some errors * Format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Bump StatsBase compat * slim down implementations --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Co-authored-by: Xianda Sun Co-authored-by: Penelope Yong Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Project.toml | 2 ++ src/abstractprobprog.jl | 14 ++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/Project.toml b/Project.toml index db0655f..df46559 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] AbstractMCMC = "2, 3, 4, 5" @@ -18,4 +19,5 @@ Accessors = "0.1" DensityInterface = "0.4" JSON = "0.19 - 0.21" Random = "1.6" +StatsBase = "0.32, 0.33, 0.34" julia = "~1.6.6, 1.7.3" diff --git a/src/abstractprobprog.jl b/src/abstractprobprog.jl index 07e5546..32e125d 100644 --- a/src/abstractprobprog.jl +++ b/src/abstractprobprog.jl @@ -1,6 +1,7 @@ using AbstractMCMC using DensityInterface using Random +using StatsBase """ AbstractProbabilisticProgram @@ -116,3 +117,16 @@ end function Base.rand(model::AbstractProbabilisticProgram) return rand(Random.default_rng(), NamedTuple, model) end + +""" + predict( + [rng::AbstractRNG=Random.default_rng(),] + model::AbstractProbabilisticProgram, + params, + ) + +Draw a sample from the predictive distribution specified by `model` with its parameters fixed to `params`. +""" +function StatsBase.predict(model::AbstractProbabilisticProgram, params) + return predict(Random.default_rng(), NamedTuple, model, params) +end