# abstractprobprog.jl
using AbstractMCMC
using DensityInterface
using Random
using StatsBase
"""
    AbstractProbabilisticProgram

Abstract supertype shared by all models that are written as probabilistic programs.

It is declared as a subtype of `AbstractMCMC.AbstractModel`.
"""
abstract type AbstractProbabilisticProgram <: AbstractMCMC.AbstractModel end
# Trait declaration: every probabilistic program reports `HasDensity()` as its
# `DensityInterface.DensityKind`.
function DensityInterface.DensityKind(::AbstractProbabilisticProgram)
    return HasDensity()
end
"""
    logdensityof(model, trace)

Evaluate the (possibly unnormalized) density of the probabilistic program `model` at the
random-variable values supplied through `trace`.

`trace` may be any supported internal trace type, or a fixed probability expression.

Implementations of `logdensityof` should interact with conditioning and deconditioning in
the way that probability theory requires.
"""
DensityInterface.logdensityof(::AbstractProbabilisticProgram, ::AbstractModelTrace)
"""
    decondition(conditioned_model)

Strip the conditioning (that is, the observation data) from `conditioned_model`. The result
is a generative model over both the prior and the observed variables.

For a model `m` whose conditioned variables are `obs`, the invariant

```
m == condition(decondition(m), obs)
```

should hold.
"""
function decondition end
"""
    condition(model, observations)

Condition the generative model `model` on some observed data, creating a new model of the
(possibly unnormalized) posterior distribution over them.

`observations` can be of any supported internal trace type, or a fixed probability
expression.

The invariant

```
m == decondition(condition(m, obs))
```

should hold for generative models `m` and arbitrary `obs`.
"""
function condition end
"""
    fix(model, params)

Pin the parameters listed in `params` to the given values inside the probabilistic model
`model`.

Fixing is equivalent to treating each fixed parameter as a draw from a point-mass
distribution centered at the value given in `params`. As a consequence, fixed parameters no
longer contribute to the accumulated log density.

Conceptually this resembles Pearl's do-operator from causal inference: variables are
intervened on by setting them to specific values, which effectively cuts their dependencies
on their usual causes in the model.

For any model `m` and parameters `params`, the invariant

```
m == unfix(fix(m, params))
```

should hold.
"""
function fix end
"""
    unfix(model)

Return a new model in which none of the parameters of `model` are fixed any more.

`unfix` undoes `fix`: the parameter constraints set earlier are removed, and every parameter
that was fixed is again allowed to vary according to its original distribution in the model.

For any model `m` and parameters `params`, the invariant

```
m == unfix(fix(m, params))
```

should hold.
"""
function unfix end
"""
    rand([rng=Random.default_rng()], [T=NamedTuple], model::AbstractProbabilisticProgram) -> T

Sample once from the joint distribution of the model described by the probabilistic program
`model`.

The sample is returned in the format specified by `T`.
"""
Base.rand(rng::Random.AbstractRNG, ::Type, model::AbstractProbabilisticProgram)
# Convenience method: when no output type is given, sample as a `NamedTuple`.
Base.rand(rng::Random.AbstractRNG, model::AbstractProbabilisticProgram) =
    rand(rng, NamedTuple, model)
# Convenience method: when no RNG is given, sample with `Random.default_rng()`.
function Base.rand(::Type{T}, model::AbstractProbabilisticProgram) where {T}
    rng = Random.default_rng()
    return rand(rng, T, model)
end
# Convenience method: default both the RNG (`Random.default_rng()`) and the output
# type (`NamedTuple`).
Base.rand(model::AbstractProbabilisticProgram) =
    rand(Random.default_rng(), NamedTuple, model)
"""
    predict(
        [rng::AbstractRNG=Random.default_rng(),]
        model::AbstractProbabilisticProgram,
        params,
    )

Sample once from the predictive distribution that `model` specifies when its parameters are
fixed to `params`.
"""
function StatsBase.predict(model::AbstractProbabilisticProgram, params)
    # No RNG supplied: forward to the three-argument method with the default RNG.
    rng = Random.default_rng()
    return StatsBase.predict(rng, model, params)
end