
Commit

update documentation
Red-Portal committed Dec 24, 2024
1 parent f40df75 commit 97f64e1
Showing 7 changed files with 33 additions and 17 deletions.
9 changes: 6 additions & 3 deletions docs/src/elbo/repgradelbo.md
@@ -218,7 +218,8 @@ _, _, stats_cfe, _ = AdvancedVI.optimize(
max_iter;
show_progress = false,
adtype = AutoForwardDiff(),
- optimizer = ProjectScale(Optimisers.Adam(3e-3)),
+ optimizer = Optimisers.Adam(3e-3),
+ operator = ClipScale(),
callback = callback,
);
@@ -229,7 +230,8 @@ _, _, stats_stl, _ = AdvancedVI.optimize(
max_iter;
show_progress = false,
adtype = AutoForwardDiff(),
- optimizer = ProjectScale(Optimisers.Adam(3e-3)),
+ optimizer = Optimisers.Adam(3e-3),
+ operator = ClipScale(),
callback = callback,
);
@@ -316,7 +318,8 @@ _, _, stats_qmc, _ = AdvancedVI.optimize(
max_iter;
show_progress = false,
adtype = AutoForwardDiff(),
- optimizer = ProjectScale(Optimisers.Adam(3e-3)),
+ optimizer = Optimisers.Adam(3e-3),
+ operator = ClipScale(),
callback = callback,
);
7 changes: 4 additions & 3 deletions docs/src/examples.md
@@ -117,13 +117,14 @@ q_avg_trans, q_trans, stats, _ = AdvancedVI.optimize(
n_max_iter;
show_progress=false,
adtype=AutoForwardDiff(),
- optimizer=ProjectScale(Optimisers.Adam(1e-3)),
+ optimizer=Optimisers.Adam(1e-3),
+ operator=ClipScale(),
);
nothing
```

- `ProjectScale` is a wrapper around an optimization rule such that the variational approximation stays within a stable region of the variational family.
- For more information see [this section](@ref projectscale).
+ `ClipScale` is a projection operator that ensures the variational approximation stays within a stable region of the variational family.
+ For more information, see [this section](@ref clipscale).

`q_avg_trans` is the final output of the optimization procedure.
If a parameter averaging strategy is used through the keyword argument `averager`, `q_avg_trans` will be the output of the averaging strategy, while `q_trans` is the last iterate.
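
A hedged sketch of the paragraph above (the positional arguments `model`, `objective`, and `q_init` are stand-ins for whatever the full example defines; only the keyword arguments are confirmed by this diff): supplying an `averager` makes the first return value the averaged iterate and the second the last iterate.

```julia
using AdvancedVI, Optimisers
using ADTypes: AutoForwardDiff

# Sketch only: PolynomialAveraging is the averaging strategy documented in
# docs/src/optimization.md; the first output is the averaged approximation.
q_avg_trans, q_trans, stats, _ = AdvancedVI.optimize(
    model, objective, q_init, n_max_iter;
    show_progress = false,
    adtype        = AutoForwardDiff(),
    optimizer     = Optimisers.Adam(1e-3),
    averager      = PolynomialAveraging(),
    operator      = ClipScale(),
);
```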
10 changes: 0 additions & 10 deletions docs/src/families.md
@@ -56,16 +56,6 @@ FullRankGaussian
MeanFieldGaussian
```

- ### [Scale Projection Operator](@id projectscale)
- 
- For the location scale, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020].
- To ensure this, we provide the following wrapper around optimization rule:
- 
- ```@docs
- ProjectScale
- ```
- 
- [^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*.
### Gaussian Variational Families

```julia
Expand Down
16 changes: 16 additions & 0 deletions docs/src/optimization.md
@@ -26,3 +26,19 @@ PolynomialAveraging
[^DCAMHV2020]: Dhaka, A. K., Catalina, A., Andersen, M. R., Magnusson, M., Huggins, J., & Vehtari, A. (2020). Robust, accurate stochastic optimization for variational inference. Advances in Neural Information Processing Systems, 33, 10961-10973.
[^KMJ2024]: Khaled, A., Mishchenko, K., & Jin, C. (2023). DoWG unleashed: An efficient universal parameter-free gradient descent method. Advances in Neural Information Processing Systems, 36, 6748-6769.
[^IHC2023]: Ivgi, M., Hinder, O., & Carmon, Y. (2023). DoG is SGD's best friend: A parameter-free dynamic step size schedule. In International Conference on Machine Learning (pp. 14465-14499). PMLR.

+ ## Operators
+ 
+ Depending on the variational family, variational objective, and optimization strategy, it might be necessary to modify the variational parameters after performing a gradient-based update.
+ For this, an operator acting on the parameters can be supplied via the `operator` keyword argument of `optimize`.
+ 
+ ### [`ClipScale`](@id clipscale)
+ 
+ For location-scale families, optimization is often stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020].
+ To ensure this, we provide the following projection operator:
+ 
+ ```@docs
+ ClipScale
+ ```
+ 
+ [^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*.
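
To make the new section concrete, here is a hedged sketch of a user-defined operator. It relies only on the `operate(op, family, params, restructure)` contract visible in the `src/AdvancedVI.jl` diff below; `ClampParams` and its behavior are invented for illustration and are not part of the package:

```julia
using AdvancedVI

# Hypothetical operator: clamp every variational parameter into
# [-bound, bound] after each gradient step.
struct ClampParams{T<:Real} <: AdvancedVI.AbstractOperator
    bound::T
end

# Mirrors `operate(::IdentityOperator, family, params, restructure) = params`,
# but returns an elementwise-clamped copy of the flat parameter vector.
AdvancedVI.operate(op::ClampParams, family, params, restructure) =
    clamp.(params, -op.bound, op.bound)

# It would then be supplied as `operator = ClampParams(100.0)` to `optimize`.
```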
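
For intuition about what `ClipScale` projects: in the mean-field case the scale matrix is diagonal, so its eigenvalues are exactly the diagonal entries, and the projection reduces to an elementwise clamp. A minimal conceptual sketch, assuming a diagonal scale (this is not the package implementation):

```julia
# Conceptual only: force each diagonal entry of the scale to be no smaller
# than ϵ, i.e., keep all eigenvalues of the (diagonal) scale at least ϵ.
project_scale_diag(scale_diag::AbstractVector, ϵ::Real = 1e-5) = max.(scale_diag, ϵ)

project_scale_diag([0.3, 1e-9, 2.0])  # -> [0.3, 1.0e-5, 2.0]
```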
5 changes: 5 additions & 0 deletions src/AdvancedVI.jl
@@ -232,6 +232,11 @@ Apply operator `op` on the variational parameters `params`. For instance, `op` c
"""
function operate end

"""
IdentityOperator()
Identity operator.
"""
struct IdentityOperator <: AbstractOperator end

operate(::IdentityOperator, family, params, restructure) = params
2 changes: 1 addition & 1 deletion src/optimization/clip_scale.jl
@@ -2,7 +2,7 @@
"""
    ClipScale(ϵ = 1e-5)
- Apply a projection ensuring that an `MvLocationScale` or `MvLocationScaleLowRank` has a scale with eigenvalues larger than `ϵ`.
+ Projection operator ensuring that an `MvLocationScale` or `MvLocationScaleLowRank` has a scale with eigenvalues larger than `ϵ`.
`ClipScale` also supports operating on `MvLocationScale` and `MvLocationScaleLowRank` objects wrapped in a `Bijectors.TransformedDistribution`.
"""
Optimisers.@def struct ClipScale <: AbstractOperator
1 change: 1 addition & 0 deletions src/optimize.jl
@@ -17,6 +17,7 @@ This requires the variational approximation to be marked as a functor through `F
- `adtype::ADtypes.AbstractADType`: Automatic differentiation backend.
- `optimizer::Optimisers.AbstractRule`: Optimizer used for inference. (Default: `Adam`.)
- `averager::AbstractAverager`: Parameter averaging strategy. (Default: `NoAveraging()`.)
+ - `operator::AbstractOperator`: Operator applied to the parameters after each optimization step. (Default: `IdentityOperator()`.)
- `rng::AbstractRNG`: Random number generator. (Default: `Random.default_rng()`.)
- `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.)
- `callback`: Callback function called after every iteration. See further information below. (Default: `nothing`.)
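
A hedged sketch of how the default interacts with `operate`, based on the method shown in the `src/AdvancedVI.jl` diff above (whether `IdentityOperator` is exported is an assumption, so it is fully qualified here):

```julia
using AdvancedVI

# With the default IdentityOperator(), parameters pass through unchanged;
# this method ignores `family` and `restructure`, so `nothing` suffices here.
params = [0.1, -2.0, 1e-9]
out = AdvancedVI.operate(AdvancedVI.IdentityOperator(), nothing, params, nothing)
@assert out === params  # same object returned, no copy made
```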
