Skip to content

Commit

Permalink
Setup formatter (#17)
Browse files Browse the repository at this point in the history
* apply formatter, add formatting rule
* add Formatter action

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
Red-Portal and github-actions[bot] authored Nov 30, 2024
1 parent 7b261ae commit de41cd6
Show file tree
Hide file tree
Showing 18 changed files with 345 additions and 371 deletions.
7 changes: 7 additions & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

style = "blue"
align_assignment = true
align_struct_field = true
align_pair_arrow = true
align_matrix = true
align_conditional = true
26 changes: 26 additions & 0 deletions .github/workflows/Format.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: Format suggestions

on:
pull_request:

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
format:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: 1
- run: |
julia -e 'using Pkg; Pkg.add("JuliaFormatter")'
julia -e 'using JuliaFormatter; format("."; verbose=true)'
- uses: reviewdog/action-suggester@v1
with:
tool_name: JuliaFormatter
fail_on_error: true
7 changes: 2 additions & 5 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@ makedocs(;
"Univariate Slice Sampling" => "univariate_slice.md",
"Meta Multivariate Samplers" => "meta_multivariate.md",
"Latent Slice Sampling" => "latent_slice.md",
"Gibbsian Polar Slice Sampling" => "gibbs_polar.md"
"Gibbsian Polar Slice Sampling" => "gibbs_polar.md",
],
)

deploydocs(;
repo="github.com/TuringLang/SliceSampling.jl",
push_preview=true
)
deploydocs(; repo="github.com/TuringLang/SliceSampling.jl", push_preview=true)
59 changes: 30 additions & 29 deletions ext/SliceSamplingTuringExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ if isdefined(Base, :get_extension)
using Random
using SliceSampling
using Turing
# using Turing: Turing, Experimental
# using Turing: Turing, Experimental
else
using ..LogDensityProblemsAD
using ..Random
Expand All @@ -17,46 +17,47 @@ end

# Required for using the slice samplers as `externalsampler`s in Turing
# begin
Turing.Inference.getparams(
::Turing.DynamicPPL.Model,
sample::SliceSampling.Transition
) = sample.params
function Turing.Inference.getparams(
::Turing.DynamicPPL.Model, sample::SliceSampling.Transition
)
return sample.params
end
# end

# Required for using the slice samplers as `Experimental.Gibbs` samplers in Turing
# begin
Turing.Inference.getparams(
::Turing.DynamicPPL.Model,
state::SliceSampling.UnivariateSliceState
) = state.transition.params
function Turing.Inference.getparams(
::Turing.DynamicPPL.Model, state::SliceSampling.UnivariateSliceState
)
return state.transition.params
end

Turing.Inference.getparams(
::Turing.DynamicPPL.Model,
state::SliceSampling.GibbsState
) = state.transition.params
function Turing.Inference.getparams(
::Turing.DynamicPPL.Model, state::SliceSampling.GibbsState
)
return state.transition.params
end

Turing.Inference.getparams(
::Turing.DynamicPPL.Model,
state::SliceSampling.HitAndRunState
) = state.transition.params
function Turing.Inference.getparams(
::Turing.DynamicPPL.Model, state::SliceSampling.HitAndRunState
)
return state.transition.params
end

Turing.Experimental.gibbs_requires_recompute_logprob(
function Turing.Experimental.gibbs_requires_recompute_logprob(
model_dst,
::Turing.DynamicPPL.Sampler{
<: Turing.Inference.ExternalSampler{
<: SliceSampling.AbstractSliceSampling, A, U
}
<:Turing.Inference.ExternalSampler{<:SliceSampling.AbstractSliceSampling,A,U}
},
sampler_src,
state_dst,
state_src
) where {A,U} = false
state_src,
) where {A,U}
return false
end
# end

function SliceSampling.initial_sample(
rng::Random.AbstractRNG,
::Turing.LogDensityFunction
)
function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction)
model =.model
spl = Turing.SampleFromUniform()
vi = Turing.VarInfo(rng, model, spl)
Expand All @@ -67,14 +68,14 @@ function SliceSampling.initial_sample(
if init_attempt_count == 10
@warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword"
end

# NOTE: This will sample in the unconstrained space.
vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromUniform()))
θ = vi[spl]

init_attempt_count += 1
end
θ
return θ
end

end
55 changes: 26 additions & 29 deletions src/SliceSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Struct containing the results of the transition.
- `lp::Real`: Log-target density of the samples.
- `info::NamedTuple`: Named tuple containing information about the transition.
"""
struct Transition{P, L <: Real, I <: NamedTuple}
struct Transition{P,L<:Real,I<:NamedTuple}
"current state of the slice sampling chain"
params::P

Expand All @@ -53,47 +53,44 @@ Return the initial sample for the `model` using the random number generator `rng
- `model`: The target `LogDensityProblem`.
"""
function initial_sample(::Random.AbstractRNG, ::Any)
error(
return error(
"`initial_sample` is not implemented but an initialization wasn't provided. ",
"Consider supplying an initialization to `initial_params`."
"Consider supplying an initialization to `initial_params`.",
)
end

# If target is from `LogDensityProblemsAD`, unwrap target before calling `initial_sample`.
# This is necessary since Turing wraps `DynamicPPL.Model`s when passed to an `externalsampler`.
initial_sample(
rng::Random.AbstractRNG,
wrap::LogDensityProblemsAD.ADGradientWrapper
) = initial_sample(rng, parent(wrap))
function initial_sample(
rng::Random.AbstractRNG, wrap::LogDensityProblemsAD.ADGradientWrapper
)
return initial_sample(rng, parent(wrap))
end

function exceeded_max_prop(max_prop::Int)
error("Exceeded maximum number of proposal $(max_prop), ",
"which indicates an acceptance rate less than $(1/max_prop*100)%. ",
"A quick fix is to increase `max_prop`, ",
"but an acceptance rate that is too low often indicates that there is a problem. ",
"Here are some possible causes:\n",
"- The model might be broken or degenerate (most likely cause).\n",
"- The tunable parameters of the sampler are suboptimal.\n",
"- The initialization is pathologic. (try supplying a (different) `initial_params`)\n",
"- There might be a bug in the sampler. (if this is suspected, file an issue to `SliceSampling`)\n"
)
return error(
"Exceeded maximum number of proposal $(max_prop), ",
"which indicates an acceptance rate less than $(1/max_prop*100)%. ",
"A quick fix is to increase `max_prop`, ",
"but an acceptance rate that is too low often indicates that there is a problem. ",
"Here are some possible causes:\n",
"- The model might be broken or degenerate (most likely cause).\n",
"- The tunable parameters of the sampler are suboptimal.\n",
"- The initialization is pathologic. (try supplying a (different) `initial_params`)\n",
"- There might be a bug in the sampler. (if this is suspected, file an issue to `SliceSampling`)\n",
)
end

## Univariate Slice Sampling Algorithms
export Slice, SliceSteppingOut, SliceDoublingOut

abstract type AbstractUnivariateSliceSampling <: AbstractSliceSampling end
abstract type AbstractUnivariateSliceSampling <: AbstractSliceSampling end

accept_slice_proposal(
::AbstractSliceSampling,
::Any,
::Real,
::Real,
::Real,
::Real,
::Real,
::Real,
) = true
function accept_slice_proposal(
::AbstractSliceSampling, ::Any, ::Real, ::Real, ::Real, ::Real, ::Real, ::Real
)
return true
end

function find_interval end

Expand All @@ -103,7 +100,7 @@ include("univariate/steppingout.jl")
include("univariate/doublingout.jl")

## Multivariate slice sampling algorithms
abstract type AbstractMultivariateSliceSampling <: AbstractSliceSampling end
abstract type AbstractMultivariateSliceSampling <: AbstractSliceSampling end

# Meta Multivariate Samplers
export RandPermGibbs, HitAndRun
Expand Down
Loading

0 comments on commit de41cd6

Please sign in to comment.