Skip to content

Commit

Permalink
Fixes to Bayes SDE notebook (#449)
Browse files Browse the repository at this point in the history
* Updated retcode success check to the current version used by SciMLBase. Introduced new noisy observations that are better suited for the problem, and corrected the model to calculate the likelihood based on multiple trajectories rather than a single trajectory.

* Added more explanation on the likelihood calculation

---------

Co-authored-by: Gabriel Gress <[email protected]>
  • Loading branch information
gjgress and Gabriel Gress authored May 27, 2024
1 parent d901459 commit 5df81f3
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions tutorials/10-bayesian-stochastic-differential-equations/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Pkg.instantiate();
```{julia}
using Turing
using DifferentialEquations
using DifferentialEquations.EnsembleAnalysis
# Load StatsPlots for visualizations and diagnostics.
using StatsPlots
Expand Down Expand Up @@ -117,6 +118,10 @@ prob_sde = SDEProblem(lotka_volterra!, multiplicative_noise!, u0, tspan, p)
ensembleprob = EnsembleProblem(prob_sde)
data = solve(ensembleprob, SOSRI(); saveat=0.1, trajectories=1000)
plot(EnsembleSummary(data))
# We generate new noisy observations based on the stochastic model for the parameter estimation tasks in this tutorial.
# We create our observations by adding random normally distributed noise to the mean of the ensemble simulation.
sdedata = reduce(hcat, timeseries_steps_mean(data).u) + 0.8 * randn(size(reduce(hcat, timeseries_steps_mean(data).u)))
```

```{julia}
Expand All @@ -132,17 +137,24 @@ plot(EnsembleSummary(data))
# Simulate stochastic Lotka-Volterra model.
p = [α, β, γ, δ, ϕ1, ϕ2]
predicted = solve(prob, SOSRI(); p=p, saveat=0.1)
remake(prob, p = p)
ensembleprob = EnsembleProblem(prob)
predicted = solve(ensembleprob, SOSRI(); saveat=0.1, trajectories = 1000)
# Early exit if simulation could not be computed successfully.
if predicted.retcode !== :Success
Turing.@addlogprob! -Inf
return nothing
for i in 1:length(predicted)
if !SciMLBase.successful_retcode(predicted[i])
Turing.@addlogprob! -Inf
return nothing
end
end
# Observations.
for i in 1:length(predicted)
data[:, i] ~ MvNormal(predicted[i], σ^2 * I)
# We compute the likelihood for each trajectory of our simulation in order to better approximate the overall likelihood of our choice of parameters
for j in 1:length(predicted)
for i in 1:length(predicted[j])
data[:, i] ~ MvNormal(predicted[j][i], σ^2 * I)
end
end
return nothing
Expand All @@ -154,9 +166,8 @@ Therefore we use NUTS with a low target acceptance rate of `0.25` and specify a
SGHMC might be a more suitable algorithm to be used here.

```{julia}
model_sde = fitlv_sde(odedata, prob_sde)
model_sde = fitlv_sde(sdedata, prob_sde)
setadbackend(:forwarddiff)
chain_sde = sample(
model_sde,
NUTS(0.25),
Expand Down

0 comments on commit 5df81f3

Please sign in to comment.