@@ -13,11 +13,9 @@ function TuringDenseMvNormal(m::AbstractVector, A::AbstractMatrix)
13
13
return TuringDenseMvNormal (m, cholesky (A))
14
14
end
15
15
Base. length (d:: TuringDenseMvNormal ) = length (d. m)
16
- function Distributions. rand (rng:: Random.AbstractRNG , d:: TuringDenseMvNormal )
17
- return d. m .+ d. C. U' * randn (rng, length (d))
18
- end
19
- function Distributions. rand (rng:: Random.AbstractRNG , d:: TuringDenseMvNormal , n:: Int )
20
- return d. m .+ d. C. U' * randn (rng, length (d), n)
16
+ Distributions. rand (d:: TuringDenseMvNormal , n:: Int... ) = rand (Random. GLOBAL_RNG, d, n... )
17
+ function Distributions. rand (rng:: Random.AbstractRNG , d:: TuringDenseMvNormal , n:: Int... )
18
+ return d. m .+ d. C. U' * randn (rng, length (d), n... )
21
19
end
22
20
23
21
"""
32
30
33
31
Base. length (d:: TuringDiagMvNormal ) = length (d. m)
34
32
Base. size (d:: TuringDiagMvNormal ) = (length (d), length (d))
35
- function Distributions. rand (rng:: Random.AbstractRNG , d:: TuringDiagMvNormal )
36
- return d. m .+ d. σ .* randn (rng, length (d))
37
- end
38
- function Distributions. rand (rng:: Random.AbstractRNG , d:: TuringDiagMvNormal , n:: Int )
39
- return d. m .+ d. σ .* randn (rng, length (d), n)
40
- end
41
-
42
- struct TuringScalMvNormal{Tm<: AbstractVector , Tσ<: Real } <: ContinuousMultivariateDistribution
43
- m:: Tm
44
- σ:: T σ
45
- end
46
-
47
- Base. length (d:: TuringScalMvNormal ) = length (d. m)
48
- Base. size (d:: TuringScalMvNormal ) = (length (d), length (d))
49
- function Distributions. rand (rng:: Random.AbstractRNG , d:: TuringScalMvNormal )
50
- return d. m .+ d. σ .* randn (rng, length (d))
51
- end
52
- function Distributions. rand (rng:: Random.AbstractRNG , d:: TuringScalMvNormal , n:: Int )
53
- return d. m .+ d. σ .* randn (rng, length (d), n)
54
-
55
- Base. length (d:: TuringDiagMvNormal ) = length (d. m)
56
- Base. size (d:: TuringDiagMvNormal ) = (length (d), length (d))
57
- function Distributions. rand (rng:: Random.AbstractRNG , d:: TuringDiagMvNormal )
58
- return d. m .+ d. σ .* randn (rng, length (d))
33
+ Distributions. rand (d:: TuringDiagMvNormal , n:: Int... ) = rand (Random. GLOBAL_RNG, d, n... )
34
+ function Distributions. rand (rng:: Random.AbstractRNG , d:: TuringDiagMvNormal , n:: Int... )
35
+ return d. m .+ d. σ .* randn (rng, length (d), n... )
59
36
end
60
37
61
-
62
38
struct TuringScalMvNormal{Tm<: AbstractVector , Tσ<: Real } <: ContinuousMultivariateDistribution
63
39
m:: Tm
64
40
σ:: T σ
65
41
end
66
42
67
43
Base. length (d:: TuringScalMvNormal ) = length (d. m)
68
44
Base. size (d:: TuringScalMvNormal ) = (length (d), length (d))
69
- function Distributions. rand (rng :: Random.AbstractRNG , d:: TuringScalMvNormal )
70
- return d . m .+ d . σ .* randn (rng, length (d) )
71
- >>>>>> > multiple distributions as one
45
+ Distributions. rand (d :: TuringScalMvNormal , n :: Int... ) = rand ( Random. GLOBAL_RNG , d, n ... )
46
+ function Distributions . rand (rng:: Random.AbstractRNG , d :: TuringScalMvNormal , n :: Int... )
47
+ return d . m .+ d . σ .* randn (rng, length (d), n ... )
72
48
end
73
49
74
50
for T in (:AbstractVector , :AbstractMatrix )
@@ -95,31 +71,6 @@ function _logpdf(d::TuringDenseMvNormal, x::AbstractVector)
95
71
end
96
72
function _logpdf (d:: TuringDenseMvNormal , x:: AbstractMatrix )
97
73
return - ((size (x, 1 ) * log (2 π) + logdet (d. C)) .+ vec (sum (abs2 .(zygote_ldiv (d. C. U' , x .- d. m)), dims= 1 ))) ./ 2
98
- ====== =
99
- for T in (:TrackedVector , :TrackedMatrix )
100
- @eval function Distributions. logpdf (d:: MvNormal{<:Any, <:PDMats.ScalMat} , x:: $T )
101
- logpdf (TuringScalMvNormal (d. μ, d. Σ. value), x)
102
- end
103
- end
104
-
105
- function _logpdf (d:: TuringScalMvNormal , x:: AbstractVector )
106
- return - (length (x) * log (2 π) + 2 * sum (log (d. σ)) + sum (abs2, (x .- d. m) ./ d. σ)) / 2
107
- end
108
- function _logpdf (d:: TuringScalMvNormal , x:: AbstractMatrix )
109
- return - (size (x, 2 ) * log (2 π) .+ 2 * sum (log (d. σ)) .+ sum (abs2, (x .- d. m) ./ d. σ, dims= 1 )' ) ./ 2
110
- end
111
-
112
- function _logpdf (d:: TuringDiagMvNormal , x:: AbstractVector )
113
- return - (length (x) * log (2 π) + 2 * sum (log .(d. σ)) + sum (abs2, (x .- d. m) ./ d. σ)) / 2
114
- end
115
- function _logpdf (d:: TuringDiagMvNormal , x:: AbstractMatrix )
116
- return - (size (x, 2 ) * log (2 π) .+ 2 * sum (log .(d. σ)) .+ sum (abs2, (x .- d. m) ./ d. σ, dims= 1 )' ) ./ 2
117
- end
118
- function _logpdf (d:: TuringDenseMvNormal , x:: AbstractVector )
119
- return - (length (x) * log (2 π) + logdet (d. C) + sum (abs2, zygote_ldiv (d. C. U' , x .- d. m))) / 2
120
- end
121
- function _logpdf (d:: TuringDenseMvNormal , x:: AbstractMatrix )
122
- return - (size (x, 2 ) * log (2 π) .+ logdet (d. C) .+ sum (abs2, zygote_ldiv (d. C. U' , x .- d. m), dims= 1 )' ) ./ 2
123
74
end
124
75
125
76
# zero mean, dense covariance
@@ -129,9 +80,9 @@ MvNormal(A::TrackedMatrix) = TuringMvNormal(A)
129
80
MvNormal (σ:: TrackedVector ) = TuringMvNormal (σ)
130
81
131
82
# dense mean, dense covariance
132
- MvNormal (m:: TrackedVector{<:Real} , A:: TrackedMatrix{<:Real} ) = TuringDenseMvNormal (m, A)
133
- MvNormal (m:: TrackedVector{<:Real} , A:: Matrix{<:Real} ) = TuringDenseMvNormal (m, A)
134
- MvNormal (m:: AbstractVector{<:Real} , A:: TrackedMatrix{<:Real} ) = TuringDenseMvNormal (m, A)
83
+ MvNormal (m:: TrackedVector{<:Real} , A:: TrackedMatrix{<:Real} ) = TuringMvNormal (m, A)
84
+ MvNormal (m:: TrackedVector{<:Real} , A:: Matrix{<:Real} ) = TuringMvNormal (m, A)
85
+ MvNormal (m:: AbstractVector{<:Real} , A:: TrackedMatrix{<:Real} ) = TuringMvNormal (m, A)
135
86
136
87
# dense mean, diagonal covariance
137
88
function MvNormal (
@@ -237,9 +188,9 @@ MvLogNormal(A::TrackedMatrix) = TuringMvLogNormal(TuringMvNormal(A))
237
188
MvLogNormal (σ:: TrackedVector ) = TuringMvLogNormal (TuringMvNormal (σ))
238
189
239
190
# dense mean, dense covariance
240
- MvLogNormal (m:: TrackedVector{<:Real} , A:: TrackedMatrix{<:Real} ) = TuringMvLogNormal (TuringDenseMvNormal (m, A))
241
- MvLogNormal (m:: TrackedVector{<:Real} , A:: Matrix{<:Real} ) = TuringMvLogNormal (TuringDenseMvNormal (m, A))
242
- MvLogNormal (m:: AbstractVector{<:Real} , A:: TrackedMatrix{<:Real} ) = TuringMvLogNormal (TuringDenseMvNormal (m, A))
191
+ MvLogNormal (m:: TrackedVector{<:Real} , A:: TrackedMatrix{<:Real} ) = TuringMvLogNormal (TuringMvNormal (m, A))
192
+ MvLogNormal (m:: TrackedVector{<:Real} , A:: Matrix{<:Real} ) = TuringMvLogNormal (TuringMvNormal (m, A))
193
+ MvLogNormal (m:: AbstractVector{<:Real} , A:: TrackedMatrix{<:Real} ) = TuringMvLogNormal (TuringMvNormal (m, A))
243
194
244
195
# dense mean, diagonal covariance
245
196
function MvLogNormal (
0 commit comments