Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add varname tests from DPPL + format repo #111

Merged
merged 3 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 99 additions & 111 deletions README.md

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ using Documenter
using AbstractPPL

# Doctest setup
DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive = true)
DocMeta.setdocmeta!(AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive=true)

makedocs(;
sitename = "AbstractPPL",
modules = [AbstractPPL],
pages = ["Home" => "index.md", "API" => "api.md"],
checkdocs = :exports,
doctest = false,
sitename="AbstractPPL",
modules=[AbstractPPL],
pages=["Home" => "index.md", "API" => "api.md"],
checkdocs=:exports,
doctest=false,
)
5 changes: 2 additions & 3 deletions src/AbstractPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@ export VarName,
varname_to_string,
string_to_varname


# Abstract model functions
export AbstractProbabilisticProgram, condition, decondition, fix, unfix, logdensityof, densityof, AbstractContext, evaluate!!
export AbstractProbabilisticProgram,
condition, decondition, fix, unfix, logdensityof, densityof, AbstractContext, evaluate!!

# Abstract traces
export AbstractModelTrace


include("varname.jl")
include("abstractmodeltrace.jl")
include("abstractprobprog.jl")
Expand Down
5 changes: 0 additions & 5 deletions src/abstractprobprog.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ using AbstractMCMC
using DensityInterface
using Random


"""
AbstractProbabilisticProgram

Expand All @@ -12,7 +11,6 @@ abstract type AbstractProbabilisticProgram <: AbstractMCMC.AbstractModel end

DensityInterface.DensityKind(::AbstractProbabilisticProgram) = HasDensity()


"""
logdensityof(model, trace)

Expand All @@ -26,7 +24,6 @@ probability theory.
"""
DensityInterface.logdensityof(::AbstractProbabilisticProgram, ::AbstractModelTrace)


"""
decondition(conditioned_model)

Expand All @@ -43,7 +40,6 @@ should hold for models `m` with conditioned variables `obs`.
"""
function decondition end


"""
condition(model, observations)

Expand Down Expand Up @@ -84,7 +80,6 @@ should hold for any model `m` and parameters `params`.
"""
function fix end


"""
unfix(model)

Expand Down
115 changes: 77 additions & 38 deletions src/varname.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ struct VarName{sym,T}

function VarName{sym}(optic=identity) where {sym}
if !is_static_optic(typeof(optic))
throw(ArgumentError("attempted to construct `VarName` with unsupported optic of type $(nameof(typeof(optic)))"))
throw(
ArgumentError(
"attempted to construct `VarName` with unsupported optic of type $(nameof(typeof(optic)))",
),
)
end
return new{sym,typeof(optic)}(optic)
end
Expand Down Expand Up @@ -168,7 +172,7 @@ end

function Base.show(io::IO, vn::VarName{sym,T}) where {sym,T}
print(io, getsym(vn))
_show_optic(io, getoptic(vn))
return _show_optic(io, getoptic(vn))
end

# modified from https://github.com/JuliaObjects/Accessors.jl/blob/01528a81fdf17c07436e1f3d99119d3f635e4c26/src/sugar.jl#L502
Expand All @@ -181,7 +185,7 @@ function _show_optic(io::IO, optic)
print(io, " ∘ ")
end
shortstr = reduce(_shortstring, inner; init="")
print(io, shortstr)
return print(io, shortstr)
end

_shortstring(prev, o::IndexLens) = "$prev[$(join(map(prettify_index, o.indices), ", "))]"
Expand All @@ -207,7 +211,6 @@ Symbol("x[1][:]")
"""
Base.Symbol(vn::VarName) = Symbol(string(vn)) # simplified symbol


"""
inspace(vn::Union{VarName, Symbol}, space::Tuple)

Expand Down Expand Up @@ -244,7 +247,6 @@ inspace(vn::VarName, space::Tuple) = any(_in(vn, s) for s in space)
_in(vn::VarName, s::Symbol) = getsym(vn) == s
_in(vn::VarName, s::VarName) = subsumes(s, vn)


"""
subsumes(u::VarName, v::VarName)

Expand Down Expand Up @@ -297,8 +299,9 @@ subsumes(::typeof(identity), ::typeof(identity)) = true
subsumes(::typeof(identity), ::ALLOWED_OPTICS) = true
subsumes(::ALLOWED_OPTICS, ::typeof(identity)) = false

subsumes(t::ComposedOptic, u::ComposedOptic) =
subsumes(t.outer, u.outer) && subsumes(t.inner, u.inner)
function subsumes(t::ComposedOptic, u::ComposedOptic)
return subsumes(t.outer, u.outer) && subsumes(t.inner, u.inner)
end

# If `t` is still a composed lens, then there is no way it can subsume `u` since `u` is a
# leaf of the "lens-tree".
Expand All @@ -317,11 +320,12 @@ subsumes(t::PropertyLens, u::PropertyLens) = false
# FIXME: Does not support `DynamicIndexLens`.
# FIXME: Does not correctly handle cases such as `subsumes(x, x[:])`
# (but neither did old implementation).
subsumes(
function subsumes(
t::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}},
u::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}}
) = subsumes_indices(t, u)

u::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}},
)
return subsumes_indices(t, u)
end

"""
subsumedby(t, u)
Expand Down Expand Up @@ -444,7 +448,6 @@ subsumes_index(i::Colon, j) = true
subsumes_index(i::AbstractVector, j) = issubset(j, i)
subsumes_index(i, j) = i == j


"""
ConcretizedSlice(::Base.Slice)

Expand All @@ -455,10 +458,13 @@ struct ConcretizedSlice{T,R} <: AbstractVector{T}
range::R
end

ConcretizedSlice(s::Base.Slice{R}) where {R} = ConcretizedSlice{eltype(s.indices),R}(s.indices)
function ConcretizedSlice(s::Base.Slice{R}) where {R}
return ConcretizedSlice{eltype(s.indices),R}(s.indices)
end
Base.show(io::IO, s::ConcretizedSlice) = print(io, ":")
Base.show(io::IO, ::MIME"text/plain", s::ConcretizedSlice) =
print(io, "ConcretizedSlice(", s.range, ")")
function Base.show(io::IO, ::MIME"text/plain", s::ConcretizedSlice)
return print(io, "ConcretizedSlice(", s.range, ")")
end
Base.size(s::ConcretizedSlice) = size(s.range)
Base.iterate(s::ConcretizedSlice, state...) = Base.iterate(s.range, state...)
Base.collect(s::ConcretizedSlice) = collect(s.range)
Expand All @@ -480,8 +486,9 @@ The only purpose of this are special cases like `:`, which we want to avoid beco
`ConcretizedSlice` based on the `lowered_index`, just what you'd get with an explicit `begin:end`
"""
reconcretize_index(original_index, lowered_index) = lowered_index
reconcretize_index(original_index::Colon, lowered_index::Base.Slice) =
ConcretizedSlice(lowered_index)
function reconcretize_index(original_index::Colon, lowered_index::Base.Slice)
return ConcretizedSlice(lowered_index)
end

"""
concretize(l, x)
Expand All @@ -495,7 +502,9 @@ the result close to the original indexing.
"""
concretize(I::ALLOWED_OPTICS, x) = I
concretize(I::DynamicIndexLens, x) = concretize(IndexLens(I.f(x)), x)
concretize(I::IndexLens, x) = IndexLens(reconcretize_index.(I.indices, to_indices(x, I.indices)))
function concretize(I::IndexLens, x)
return IndexLens(reconcretize_index.(I.indices, to_indices(x, I.indices)))
end
function concretize(I::ComposedOptic, x)
x_inner = I.inner(x) # TODO: get view here
return ComposedOptic(concretize(I.outer, x_inner), concretize(I.inner, x))
Expand Down Expand Up @@ -646,11 +655,9 @@ function varname(expr::Expr, concretize=Accessors.need_dynamic_optic(expr))
end

if concretize
return :(
$(AbstractPPL.VarName){$sym}(
return :($(AbstractPPL.VarName){$sym}(
$(AbstractPPL.concretize)($optics, $sym_escaped)
)
)
))
elseif Accessors.need_dynamic_optic(expr)
error("Variable name `$(expr)` is dynamic and requires concretization!")
else
Expand All @@ -672,7 +679,7 @@ end
function _parse_obj_optic(ex)
obj, optics = _parse_obj_optics(ex)
optic = Expr(:call, Accessors.opticcompose, optics...)
obj, optic
return obj, optic
end

# Accessors doesn't have the same support for interpolation
Expand All @@ -688,7 +695,8 @@ function _parse_obj_optics(ex)
indices = Accessors.replace_underscore.(indices, collection)
dims = length(indices) == 1 ? nothing : 1:length(indices)
lindices = esc.(Accessors.lower_index.(collection, indices, dims))
optics = :($(Accessors.DynamicIndexLens)($(esc(collection)) -> ($(lindices...),)))
optics =
:($(Accessors.DynamicIndexLens)($(esc(collection)) -> ($(lindices...),)))
else
index = esc(Expr(:tuple, indices...))
optics = :($(Accessors.IndexLens)($index))
Expand All @@ -702,16 +710,20 @@ function _parse_obj_optics(ex)
elseif Meta.isexpr(property, :$, 1)
optics = :($(Accessors.PropertyLens){$(esc(property.args[1]))}())
else
throw(ArgumentError(
string("Error while parsing :($ex). Second argument to `getproperty` can only be",
"a `Symbol` or `String` literal, received `$property` instead.")
))
throw(
ArgumentError(
string(
"Error while parsing :($ex). Second argument to `getproperty` can only be",
"a `Symbol` or `String` literal, received `$property` instead.",
),
),
)
end
else
obj = esc(ex)
return obj, ()
end
obj, tuple(frontoptics..., optics)
return obj, tuple(frontoptics..., optics)
end

"""
Expand Down Expand Up @@ -778,12 +790,27 @@ Convert an index `i` to a dictionary representation.
"""
index_to_dict(i::Integer) = Dict("type" => _BASE_INTEGER_TYPE, "value" => i)
index_to_dict(v::Vector{Int}) = Dict("type" => _BASE_VECTOR_TYPE, "values" => v)
index_to_dict(r::UnitRange) = Dict("type" => _BASE_UNITRANGE_TYPE, "start" => r.start, "stop" => r.stop)
index_to_dict(r::StepRange) = Dict("type" => _BASE_STEPRANGE_TYPE, "start" => r.start, "stop" => r.stop, "step" => r.step)
index_to_dict(r::Base.OneTo{I}) where {I} = Dict("type" => _BASE_ONETO_TYPE, "stop" => r.stop)
function index_to_dict(r::UnitRange)
return Dict("type" => _BASE_UNITRANGE_TYPE, "start" => r.start, "stop" => r.stop)
end
function index_to_dict(r::StepRange)
return Dict(
"type" => _BASE_STEPRANGE_TYPE,
"start" => r.start,
"stop" => r.stop,
"step" => r.step,
)
end
function index_to_dict(r::Base.OneTo{I}) where {I}
return Dict("type" => _BASE_ONETO_TYPE, "stop" => r.stop)
end
index_to_dict(::Colon) = Dict("type" => _BASE_COLON_TYPE)
index_to_dict(s::ConcretizedSlice{T,R}) where {T,R} = Dict("type" => _CONCRETIZED_SLICE_TYPE, "range" => index_to_dict(s.range))
index_to_dict(t::Tuple) = Dict("type" => _BASE_TUPLE_TYPE, "values" => map(index_to_dict, t))
function index_to_dict(s::ConcretizedSlice{T,R}) where {T,R}
return Dict("type" => _CONCRETIZED_SLICE_TYPE, "range" => index_to_dict(s.range))
end
function index_to_dict(t::Tuple)
return Dict("type" => _BASE_TUPLE_TYPE, "values" => map(index_to_dict, t))
end

"""
dict_to_index(dict)
Expand Down Expand Up @@ -839,9 +866,17 @@ function dict_to_index(dict)
end

optic_to_dict(::typeof(identity)) = Dict("type" => "identity")
optic_to_dict(::PropertyLens{sym}) where {sym} = Dict("type" => "property", "field" => String(sym))
function optic_to_dict(::PropertyLens{sym}) where {sym}
return Dict("type" => "property", "field" => String(sym))
end
optic_to_dict(i::IndexLens) = Dict("type" => "index", "indices" => index_to_dict(i.indices))
optic_to_dict(c::ComposedOptic) = Dict("type" => "composed", "outer" => optic_to_dict(c.outer), "inner" => optic_to_dict(c.inner))
function optic_to_dict(c::ComposedOptic)
return Dict(
"type" => "composed",
"outer" => optic_to_dict(c.outer),
"inner" => optic_to_dict(c.inner),
)
end

function dict_to_optic(dict)
if dict["type"] == "identity"
Expand All @@ -857,9 +892,13 @@ function dict_to_optic(dict)
end
end

varname_to_dict(vn::VarName) = Dict("sym" => getsym(vn), "optic" => optic_to_dict(getoptic(vn)))
function varname_to_dict(vn::VarName)
return Dict("sym" => getsym(vn), "optic" => optic_to_dict(getoptic(vn)))
end

dict_to_varname(dict::Dict{<:AbstractString, Any}) = VarName{Symbol(dict["sym"])}(dict_to_optic(dict["optic"]))
function dict_to_varname(dict::Dict{<:AbstractString,Any})
return VarName{Symbol(dict["sym"])}(dict_to_optic(dict["optic"]))
end

"""
varname_to_string(vn::VarName)
Expand Down
5 changes: 1 addition & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ using Test
include("abstractprobprog.jl")
@testset "doctests" begin
DocMeta.setdocmeta!(
AbstractPPL,
:DocTestSetup,
:(using AbstractPPL);
recursive=true,
AbstractPPL, :DocTestSetup, :(using AbstractPPL); recursive=true
)
doctest(AbstractPPL; manual=false)
end
Expand Down
Loading
Loading