refactor interface for projections/proximal operators #147
Conversation
Benchmark Results
| Benchmark suite | Current: 635ea4e | Previous: 9887bb4 | Ratio |
|---|---|---|---|
| normal/RepGradELBO + STL/meanfield/Zygote | 15002944353 ns | 14589178867 ns | 1.03 |
| normal/RepGradELBO + STL/meanfield/ForwardDiff | 3213209140 ns | 3168929241 ns | 1.01 |
| normal/RepGradELBO + STL/meanfield/ReverseDiff | 3201755852 ns | 3202361222 ns | 1.00 |
| normal/RepGradELBO + STL/fullrank/Zygote | 14902278767 ns | 14433568159 ns | 1.03 |
| normal/RepGradELBO + STL/fullrank/ForwardDiff | 3604116961 ns | 3452371390 ns | 1.04 |
| normal/RepGradELBO + STL/fullrank/ReverseDiff | 5831854767 ns | 5762008942 ns | 1.01 |
| normal/RepGradELBO/meanfield/Zygote | 7098681152 ns | 6877924349 ns | 1.03 |
| normal/RepGradELBO/meanfield/ForwardDiff | 2361672495 ns | 2305614010.5 ns | 1.02 |
| normal/RepGradELBO/meanfield/ReverseDiff | 1459158489 ns | 1433074275 ns | 1.02 |
| normal/RepGradELBO/fullrank/Zygote | 7131701632 ns | 6826219707 ns | 1.04 |
| normal/RepGradELBO/fullrank/ForwardDiff | 2542847767 ns | 2553813983 ns | 1.00 |
| normal/RepGradELBO/fullrank/ReverseDiff | 2679946317 ns | 2543813427 ns | 1.05 |
| normal + bijector/RepGradELBO + STL/meanfield/Zygote | 23446836790 ns | 22908413254 ns | 1.02 |
| normal + bijector/RepGradELBO + STL/meanfield/ForwardDiff | 10436881038 ns | 10230701628 ns | 1.02 |
| normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff | 5149189755 ns | 5103157400 ns | 1.01 |
| normal + bijector/RepGradELBO + STL/fullrank/Zygote | 23488381369 ns | 22668345453 ns | 1.04 |
| normal + bijector/RepGradELBO + STL/fullrank/ForwardDiff | 10953526681 ns | 10781493177 ns | 1.02 |
| normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff | 8405261867 ns | 8288136356 ns | 1.01 |
| normal + bijector/RepGradELBO/meanfield/Zygote | 14855248592 ns | 14406529461 ns | 1.03 |
| normal + bijector/RepGradELBO/meanfield/ForwardDiff | 9322626292 ns | 9026101677 ns | 1.03 |
| normal + bijector/RepGradELBO/meanfield/ReverseDiff | 3143892806 ns | 3137765457 ns | 1.00 |
| normal + bijector/RepGradELBO/fullrank/Zygote | 14925584523 ns | 14347124133 ns | 1.04 |
| normal + bijector/RepGradELBO/fullrank/ForwardDiff | 9458708256 ns | 9935024364 ns | 0.95 |
| normal + bijector/RepGradELBO/fullrank/ReverseDiff | 4589456910 ns | 4576538483 ns | 1.00 |
This comment was automatically generated by workflow using github-action-benchmark.
@yebai I'll mark the v0.3 release (at last!) after this PR
Could we have a test where the eigenvalues drift too low, so we can check both that it fails when not using `ProjectScale` and that it then succeeds when using `ProjectScale`? Just to very concretely see the effect, and to see that the first case fails in the expected (rather than some other, unexpected) way.
Except for the above request, I'm happy with the software engineering. I would prefer it, though, if someone else who has views on the design choices here gave a second, approving opinion. I have little idea of what users need and want from their interfaces here, e.g. whether the name `ProjectScale` is intuitive for users, or whether there should be a warning if someone tries to optimise a `LocationScale` without using `ProjectScale`.
Wrapping an optimiser inside `ProjectScale(...)` feels slightly strange to me. Using `ProjectScale` might be appropriate for a specific paper, but the terminology is not (yet) widely accepted. I think we could introduce an additional keyword argument to pass this information instead of overloading the optimiser argument for too many purposes.
Thank you both for chiming in! @yebai I was thinking of this as being similar in functionality to operations like gradient clipping. How about I change the name to `ClipScale`?
@Red-Portal Your proposal looks good!
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##           master     #147      +/-   ##
==========================================
- Coverage   93.54%   91.76%   -1.79%
==========================================
  Files          12       13       +1
  Lines         372      352      -20
==========================================
- Hits          348      323      -25
- Misses         24       29       +5
```

☔ View full report in Codecov by Sentry.
@yebai Needed to change the API a little bit (see the summary at the top comment), do you agree with it?
Thanks @Red-Portal -- looks very good. I left two minor comments. Otherwise, this is ready to go!
Thanks @Red-Portal!
This PR refactors how post-hoc modifications are applied to the iterates after performing a gradient descent step. For instance, before, updating the parameters of `LocationScale` always silently applied a projection step. Now, each such modification needs to be made into its own `OptimisationRule`, which makes things more modular and explicit.

More concretely, this PR changes the following:
- The `LocationScale` distribution is no longer projected by default.
- A new keyword argument, `operator`, is added to `optimize`.
- The `operator` object is applied to the parameters after each gradient descent step.
- A new operator, `ClipScale`, is added, which clips the diagonal of the scale matrix to be strictly positive.

For example:
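A minimal sketch, assuming a `LogDensityProblems`-style target, the `MeanFieldGaussian` constructor, and the keyword names shown below; the exact `optimize` signature may differ from this sketch:

```julia
using AdvancedVI, Optimisers, LinearAlgebra
using ADTypes, ForwardDiff, LogDensityProblems

# Toy target: a standard 2-D Gaussian in LogDensityProblems form.
struct Target end
LogDensityProblems.logdensity(::Target, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::Target) = 2
LogDensityProblems.capabilities(::Type{Target}) = LogDensityProblems.LogDensityOrder{0}()

# Mean-field location-scale initialisation (constructor name assumed).
q0 = MeanFieldGaussian(zeros(2), Diagonal(ones(2)))

# The projection is no longer implicit: request it via the `operator` keyword.
result = optimize(
    Target(), RepGradELBO(10), q0, 5_000;
    adtype    = AutoForwardDiff(),
    optimizer = Optimisers.Adam(1e-3),
    operator  = ClipScale(),   # clip the scale diagonal so it stays strictly positive
)
```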
For gradient descent, the operator is applied as

$$\lambda_{t+1} = \mathrm{operator}\left(\lambda_t - \gamma_t \, g_t\right),$$

where $\lambda_t$ are the variational parameters, $g_t$ is the gradient estimator, and $\gamma_t$ is the stepsize.
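Schematically, in plain Julia (a sketch of the update rule above, not the package's internal code; the clipping threshold `ϵ` is an illustrative choice):

```julia
# One update: a plain gradient descent step, then the post-hoc operator.
# `clip_scale` stands in for what ClipScale does on the scale diagonal.
clip_scale(location, scale_diag; ϵ=1e-5) = (location, max.(scale_diag, ϵ))

function descent_step(location, scale_diag, g_loc, g_scale, γ)
    # λ_t - γ_t g_t: gradient step on both parameter blocks
    loc_new   = location   .- γ .* g_loc
    scale_new = scale_diag .- γ .* g_scale
    # operator(⋅): keep the scale diagonal strictly positive
    return clip_scale(loc_new, scale_new)
end

# A step that would push the scale diagonal negative gets clipped back:
loc, scale = descent_step(zeros(2), [0.1, 0.2], [0.0, 0.0], [1.0, 1.0], 0.5)
# scale == [1.0e-5, 1.0e-5]
```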