-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add some interface functions to support the new Gibbs sampler in Turing #144
Closed
Closed
Changes from 53 commits
Commits
Show all changes
56 commits
Select commit
Hold shift + click to select a range
dcf1da9
very incomplete draft
sunxd3 cdaa663
update `getparams`
sunxd3 57275f5
Upstream `condition` and `decondition` from `AbstractPPL`
sunxd3 26027ea
remove `condition` and `decondition`
sunxd3 6ebab49
add Compat to make new interface functions public
sunxd3 e1099f9
bump minor version
sunxd3 95d781b
bump minor version instead
sunxd3 f05f293
unfinished gibbs example
sunxd3 590d37f
some updates
sunxd3 3afc232
more progress; still need to deal with w being on simplex
sunxd3 55dbab5
bit of format
sunxd3 67ff8e8
results is wrong
sunxd3 f758a4c
Apply suggestions from code review
sunxd3 7d0ba7c
add hierarchical normal problem
sunxd3 1ab6dd9
some updates; add doc
sunxd3 923c116
move folder into test
sunxd3 63028d3
setup as a test
sunxd3 44de81c
add to doc
sunxd3 be43178
format
sunxd3 1a6e0d5
bump patch version
sunxd3 6b60b72
reverse version bump -- already done
sunxd3 c58b39a
remove dep on `Compat`
sunxd3 ac0ce7a
updates to doc
sunxd3 280eaf1
update gibbs to add to the src folder
sunxd3 b262ea9
update mh code
sunxd3 c47ade4
update code further
sunxd3 8d29ad3
fix test errors
sunxd3 c28a75a
format
sunxd3 1382054
fix doctest error
sunxd3 8962d40
tidy up
sunxd3 dc6001c
updates
sunxd3 e194108
Update test/gibbs_example/mh.jl
sunxd3 64eb0e4
fix error
sunxd3 9361c39
typo fix
sunxd3 39c4d87
Update src/gibbs.jl
sunxd3 7f889cf
rename gibbs test file to prepare for moving
sunxd3 62a2332
move gibbs.jl
sunxd3 6132f0c
update code
sunxd3 af208bc
updates
sunxd3 fd472df
rework the code; still not type stable
sunxd3 4306aee
fix test
sunxd3 b798b2e
update doc -- need proofread
sunxd3 3ed5cb3
fix 1.6 struct field splatting compat issue
sunxd3 6fde198
update code and doc
sunxd3 c7f577d
relax test error
sunxd3 8f11a15
rename gibbs markdown file
sunxd3 48a160d
change title
sunxd3 8d74889
update code and note
sunxd3 bceb510
fix doc example
sunxd3 c177271
try to fix doc example error
sunxd3 bdba893
fix doc deps
sunxd3 e7e2870
fix more doc example error
sunxd3 80df187
minor update
sunxd3 076e431
Apply suggestions from code review
sunxd3 4293868
Update docs/src/state_interface.md
sunxd3 1cee0ab
Update docs/src/state_interface.md
sunxd3 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# On `AbstractMCMC` Interface Supporting `Gibbs` | ||
|
||
This is written at Oct 1st, 2024. Version of packages described in this passage are: | ||
|
||
* `Turing.jl`: 0.34.1 | ||
|
||
In this passage, `Gibbs` refers to `Experimental.Gibbs`. | ||
|
||
## Current Implementation of `Gibbs` in `Turing` | ||
|
||
Here I describe the current implementation of `Gibbs` in `Turing` and the interface it requires from its sampler states. | ||
|
||
### Interface 1: `getparams` | ||
|
||
From the [definition of `GibbsState`](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/experimental/gibbs.jl#L244-L248), we can see that a `vi::DynamicPPL.AbstractVarInfo` field is used to keep track of the names and values of parameters and the log density. The `states` field collects the sampler-specific *state*s. | ||
|
||
(The *link*ing of *varinfo*s is omitted in this discussion.) | ||
A local `VarInfo` is initially created with `DynamicPPL.subset(::VarInfo, ::Vector{<:VarName})` to make the conditioned model. After the Gibbs step, an updated `varinfo` is obtained by calling `Turing.Inference.varinfo` on the sampler state. | ||
|
||
For samplers and their states defined in `Turing` (including `DynamicHMC`, as `DynamicNUTSState` is defined by `Turing` in the package extension), we (à la `Turing.jl` package) assume that the *state*s all have a field called `vi`. Then `varinfo(_some_sampler_state_)` is simply `varinfo(state) = state.vi` (defined in [`src/mcmc/gibbs.jl`](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/mcmc/gibbs.jl#L97)). (`GibbsState` conforms to this assumption.) | ||
|
||
For `ExternalSamplers`, we currently only support `AdvancedHMC` and `AdvancedMH`. The mechanism is as follows: at the end of the `step` call with an external sampler, [`transition_to_turing` and `state_to_turing` are called](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/mcmc/abstractmcmc.jl#L147). These two functions then call `getparams` on the sampler state of the external samplers. `getparams` for `AdvancedHMC.HMCState` and `AdvancedMH.Transition` (`AdvancedMH` uses `Transition` as state) are defined in `abstractmcmc.jl`. | ||
|
||
Thus, the first interface emerges: `getparams`. As `getparams` is designed to be implemented by a sampler that works with the `LogDensityProblems` interface, it makes sense for `getparams` to return a vector of `Real`s. The `logdensity_problem` should then be responsible for performing the transformation between its underlying representation and the vector of `Real`s. | ||
|
||
It's worth noting that: | ||
|
||
* `getparams` is not a function specific for `Gibbs`. It is required for the current support of external samplers. | ||
* There is another [`getparams`](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/mcmc/Inference.jl#L328-L351) in `Turing.jl` that takes *model* and *varinfo*, then returns a `NamedTuple`. | ||
|
||
### Interface 2: `recompute_logp!!` | ||
|
||
Consider a model with multiple groups of variables, say $\theta_1, \theta_2, \ldots, \theta_k$. At the beginning of the $t$-th Gibbs step, the model parameters in the `GibbsState` are typically updated and different from the $(t-1)$-th step. The `GibbsState` maintains $k$ sub-states, one for each variable group, denoted as $\text{state}_{t,1}, \text{state}_{t,2}, \ldots, \text{state}_{t,k}$. | ||
|
||
The parameter values in each sub-state, i.e., $\theta_{t,i}$ in $\text{state}_{t,i}$, are always in sync with the corresponding values in the `GibbsState`. At the end of the $t$-th Gibbs step, $\text{state}_{t,i}$ will store the log density of the $i$-th variable group conditioned on all other variable groups at their values from step $t$, denoted as $\log p(\theta_{t,i} \mid \theta_{t,-i})$. This log density is equal to the joint log density of the whole model evaluated at the current parameter values $(\theta_{t,1}, \ldots, \theta_{t,k})$. | ||
|
||
However, the log density stored in each sub-state is in general not equal to the log density needed for the next Gibbs step at $t+1$, i.e., $\log p(\theta_{t,i} \mid \theta_{t+1,-i})$. This is because the values of the other variable groups $\theta_{-i}$ will have been updated in the Gibbs step from $t$ to $t+1$, changing the conditioning set. Therefore, the log density typically needs to be recomputed at each Gibbs step to account for the updated values of the conditioning variables. | ||
|
||
Only in certain special cases, the recomputation can be skipped. For example, in a Metropolis-Hastings step where the proposal is rejected for all other variable groups, i.e., $\theta_{t+1,-i} = \theta_{t,-i}$, the log density $\log p(\theta_{t,i} \mid \theta_{t,-i})$ remains valid and doesn't need to be recomputed. | ||
|
||
The `recompute_logp!!` function in `abstractmcmc.jl` handles this recomputation. It takes an updated conditioned log density function $\log p(\cdot \mid \theta_{t+1,j})$ and the parameter values $\theta_{t,i}$ stored in $\text{state}_{t,i}$ to compute the updated log density $\log p(\theta_{t,i} \mid \theta_{t+1,j})$. | ||
|
||
## Proposed Interface | ||
|
||
The two functions `getparams` and `recompute_logp!!` form a minimal interface to support the `Gibbs` implementation. However, there are concerns about introducing them directly into `AbstractMCMC`. The main reason is that `AbstractMCMC` is a root dependency of the `Turing` packages, so we want to be very careful with new releases. | ||
|
||
Here, I propose some alternative functions that achieve the same functionality as `getparams` and `recompute_logp!!`, but without introducing new interface functions. | ||
sunxd3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
For `getparams`, I propose we use `Base.vec`. It is a `Base` function, so there's no need to export anything from `AbstractMCMC`. Since `getparams` should return a vector, using `vec` makes sense. The concern is that, officially, `Base.vec` is defined for `AbstractArray`, so it remains a question whether we should only introduce `vec` in the absence of other `AbstractArray` interfaces. | ||
sunxd3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
For `recompute_logp!!`, I propose we overload `LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, state::State; recompute_logp=true)` to compute the log probability. If `recompute_logp` is `true`, it should recompute the log probability of the state. Otherwise, it could use the log probability stored in the state. To allow updating the log probability stored in the state, samplers should define outer constructor for their state types `StateType(state::StateType, logp)` that takes an existing `state` and a log probability value `logp`, and returns a new state of the same type with the updated log probability. | ||
sunxd3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
While overloading `LogDensityProblems.logdensity` to take a state object instead of a vector for the second argument somewhat deviate from the interface in `LogDensityProblems`, I believe it provides a clean and extensible solution for handling log probability recomputation within the existing interface. | ||
sunxd3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
An example demonstrating these interfaces is provided in `src/state_interface.md`. | ||
|
||
## A More Standalone `Gibbs` Implementation | ||
|
||
`Gibbs` should not manage a `variable name → sampler` but rather `range → sampler`, i.e. it maintain a vector of parameter values. while `logdensity_problem` should manage both the name and transformations. | ||
sunxd3 marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fair, but if we now make a release where we assume that certain functionality is overloaded, then that seems strictly worse, no?