Segfault when calling rand with Bijectors #2074

Comments
Hm, so sadly this does not error on my computer.
@wsmoses Hmm, let me check a few things on my system. As a last resort, would it be useful for you if I reproduce this in a Docker container?
Starting from a fresh pull:
docker pull kyrkim/enzymeissue2074
docker run -it kyrkim/enzymeissue2074 bash
Then copy-paste the snippet above into the pre-installed Julia REPL.
@wsmoses Would this be enough to reproduce on your end? This bug is breaking all the Enzyme tests in
I managed to repro and slightly reduce with the following, but I'll need more help reducing:

using Bijectors
using Distributions  # needed for ContinuousDistribution, Normal, and extending Distributions.rand
using LinearAlgebra
using Functors
using Optimisers
using Enzyme
using Random, StableRNGs
using Statistics  # for mean
struct TestProb1 end
logdensity(::TestProb1, θ) = sum(θ)
function Bijectors.bijector(::TestProb1)
return Bijectors.Stacked(
[Base.Fix1(broadcast, log), identity],
[1:1, 2:3],
)
end
struct TestProb2 end
logdensity(::TestProb2, θ) = sum(θ)
struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <:
ContinuousMultivariateDistribution
location::L
scale::S
dist::D
scale_eps::E
end
Base.length(q::MvLocationScale) = length(q.location)
Functors.@functor MvLocationScale (location, scale)
# This specialization improves AD performance of the sampling path
function Distributions.rand(
rng::AbstractRNG, q::MvLocationScale{<:Diagonal,L}, num_samples::Int
) where {L}
(; location, scale) = q
n_dims = length(location)
scale_diag = diag(scale)
return randn(rng, n_dims, num_samples)
end
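# AD target: rebuild q from the flat parameter vector, draw samples, and average the log-density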
function estimate_repgradelbo_ad_forward(params, aux)
(; rng, problem, restructure) = aux
q = restructure(params)
zs = rand(rng, q, 10)
return mean(Base.Fix1(logdensity, problem), eachcol(zs))
end
function main()
d = 5
seed = 0x38bef07cf9cc549d
rng = StableRNG(seed)
for prob in [TestProb1(), TestProb2()]
q = if prob isa TestProb1
MvLocationScale(zeros(d), Diagonal(ones(d)), Normal(), 1e-5)
else
Bijectors.TransformedDistribution(
MvLocationScale(zeros(d), Diagonal(ones(d)), Normal(), 1e-5),
inverse(
Bijectors.Stacked(
[Base.Fix1(broadcast, log), identity],
[1:1, 2:d],
)
)
)
end
params, re = Optimisers.destructure(q)
buf = zero(params)
aux = (rng=rng, problem=prob, restructure=re)
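# Differentiate w.r.t. params (Duplicated) with aux held Const; runtime activity is
# enabled since activity cannot be fully resolved at compile time here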
Enzyme.autodiff(
set_runtime_activity(Enzyme.ReverseWithPrimal, true),
estimate_repgradelbo_ad_forward,
Enzyme.Active,
Enzyme.Duplicated(params, buf),
Enzyme.Const(aux),
)
end
end
main()
@wsmoses Unfortunately, your version does not fail on my system 😓
Naturally |
Oh, adding back the
Hope the following works on your end:

using Bijectors
using Distributions  # needed for ContinuousMultivariateDistribution and extending Distributions.rand
using LinearAlgebra
using Functors
using Optimisers
using Enzyme
using Random
using Statistics  # for mean
struct TestProb1 end
logdensity(::TestProb1, θ) = sum(θ)
function Bijectors.bijector(::TestProb1)
return Bijectors.Stacked(
[Base.Fix1(broadcast, log), identity],
[1:1, 2:3],
)
end
struct TestProb2 end
logdensity(::TestProb2, θ) = sum(θ)
struct MvLocationScale{L} <: ContinuousMultivariateDistribution
location::L
end
Base.length(q::MvLocationScale) = length(q.location)
Functors.@functor MvLocationScale (location,)
# This specialization improves AD performance of the sampling path
function Distributions.rand(
rng::AbstractRNG, q::MvLocationScale, num_samples::Int
)
(; location,) = q
n_dims = length(location)
return randn(rng, n_dims, num_samples)
end
function estimate_repgradelbo_ad_forward(params, aux)
(; rng, problem, restructure) = aux
q = restructure(params)
zs = rand(rng, q, 10)
return mean(Base.Fix1(logdensity, problem), eachcol(zs))
end
function main()
d = 5
rng = Random.default_rng()
for prob in [TestProb1(), TestProb2()]
q = if prob isa TestProb1
MvLocationScale(zeros(d))
else
Bijectors.TransformedDistribution(
MvLocationScale(zeros(d)),
inverse(
Bijectors.Stacked(
[Base.Fix1(broadcast, log), identity],
[1:1, 2:d],
)
)
)
end
params, re = Optimisers.destructure(q)
buf = zero(params)
aux = (rng=rng, problem=prob, restructure=re)
Enzyme.autodiff(
set_runtime_activity(Enzyme.ReverseWithPrimal, true),
estimate_repgradelbo_ad_forward,
Enzyme.Active,
Enzyme.Duplicated(params, buf),
Enzyme.Const(aux),
)
end
end
main()
Still segfaults for me! Any chance you can reduce further (and also ideally get rid of Bijectors)?
@wsmoses I strongly suspect Bijectors is the offender here; the tests not involving Bijectors never failed. I could try opening up Bijectors, though that might take some time. |
Bingo. I got it distilled:

using Statistics
using Base.Iterators
using Distributions: ContinuousMultivariateDistribution  # only the supertype is still needed
using LinearAlgebra
using Functors
using Optimisers
using Enzyme
using Random
struct TransformedDistribution{D,B}
dist::D
transform::B
end
Functors.@functor TransformedDistribution
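# Sample from the base distribution, then apply the transform column by column (distilled from Bijectors)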
function rand(rng::AbstractRNG, td::TransformedDistribution, num_samples::Int)
samples = rand(rng, td.dist, num_samples)
res = reduce(
hcat,
map(axes(samples, 2)) do i
return td.transform(view(samples, :, i))
end,
)
return res
end
struct Stacked{Bs,Rs<:Union{Tuple,AbstractArray}}
bs::Bs
ranges_in::Rs
ranges_out::Rs
length_in::Int
length_out::Int
end
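# map f, then vcat the results one block at a time (distilled from Bijectors' mapvcat)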
function mapvcat(f, args...)
out = map(f, args...)
init = vcat(out[1])
return reduce(vcat, drop(out, 1); init=init)
end
@generated function _transform_stacked_recursive(
x, rs::NTuple{N,UnitRange{Int}}, bs...
) where {N}
exprs = []
for i in 1:N
push!(exprs, :(bs[$i](x[rs[$i]])))
end
return :(vcat($(exprs...)))
end
function _transform_stacked_recursive(x, rs::NTuple{1,UnitRange{Int}}, b)
rs[1] == 1:length(x) || error("range must be 1:length(x)")
return b(x)
end
function _transform_stacked(sb::Stacked{<:Tuple,<:Tuple}, x::AbstractVector{<:Real})
y = _transform_stacked_recursive(x, sb.ranges_in, sb.bs...)
return y
end
function _transform_stacked(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real})
N = length(sb.bs)
N == 1 && return sb.bs[1](x[sb.ranges_in[1]])
y = mapvcat(1:N) do i
sb.bs[i](x[sb.ranges_in[i]])
end
return y
end
function (sb::Stacked)(x::AbstractVector{<:Real})
y = _transform_stacked(sb, x)
return y
end
struct TestProb1 end
logdensity(::TestProb1, θ) = sum(θ)
struct TestProb2 end
logdensity(::TestProb2, θ) = sum(θ)
struct MvLocationScale{L} <: ContinuousMultivariateDistribution
location::L
end
Base.length(q::MvLocationScale) = length(q.location)
Functors.@functor MvLocationScale (location,)
# This specialization improves AD performance of the sampling path
function rand(
rng::AbstractRNG, q::MvLocationScale, num_samples::Int
)
(; location,) = q
n_dims = length(location)
return randn(rng, n_dims, num_samples)
end
function estimate_repgradelbo_ad_forward(params, aux)
(; rng, problem, restructure) = aux
q = restructure(params)
zs = rand(rng, q, 10)
return mean(Base.Fix1(logdensity, problem), eachcol(zs))
end
function main()
d = 5
rng = Random.default_rng()
for prob in [TestProb1(), TestProb2()]
q = if prob isa TestProb1
MvLocationScale(zeros(d))
else
TransformedDistribution(
MvLocationScale(zeros(d)),
Stacked(
[Base.Fix1(broadcast, exp), identity],
[1:1, 2:d],
[1:1, 2:d],
d, d,
)
)
end
params, re = Optimisers.destructure(q)
buf = zero(params)
aux = (rng=rng, problem=prob, restructure=re)
Enzyme.autodiff(
set_runtime_activity(Enzyme.ReverseWithPrimal, true),
estimate_repgradelbo_ad_forward,
Enzyme.Active,
Enzyme.Duplicated(params, buf),
Enzyme.Const(aux),
)
end
end
main()
Statistics is fine. Would it be possible to also get rid of Functors and Optimisers?
That... I am afraid it is going to be too much of a pain. |
Let me see if we can debug the segfault as is (this will likely generate hundreds of thousands of instructions, so any simplification here will be immensely helpful, and also reducing the dependencies will make sure we can add this as a test) |
I could try to do it, but restructure/destructure take up a big portion of what
Sounds good! And yeah, sorry this is so much of a pain (sadly, this is usually what segfaults are now, and they often end up as bugs in Julia itself =/)
Huge pain incoming:

using Statistics
using Base.Iterators
using LinearAlgebra
using Enzyme
using Random
struct FunctionConstructor{F} end
_isgensym(s::Symbol) = occursin("#", string(s))
@generated function (fc::FunctionConstructor{F})(args...) where F
isempty(args) && return Expr(:new, F)
T = getfield(parentmodule(F), nameof(F))
# We assume all gensym names are anonymous functions
_isgensym(nameof(F)) || return :($T(args...))
# Define `new` for rebuilt function type that matches args
exp = Expr(:new, Expr(:curly, T, args...))
for i in 1:length(args)
push!(exp.args, :(args[$i]))
end
return exp
end
const NoChildren = Tuple{}
function constructorof(f::Type{F}) where F <: Function
FunctionConstructor{F}()
end
_vec(x::Number) = LinRange(x,x,1)
_vec(x::AbstractArray) = vec(x)
struct ExcludeWalk{T, F, G}
walk::T
fn::F
exclude::G
end
_map(f, x::Dict, ys...) = Dict(k => f(v, (y[k] for y in ys)...) for (k, v) in x)
_map(f, x::D, ys...) where {D<:AbstractDict} =
constructorof(D)([k => f(v, (y[k] for y in ys)...) for (k, v) in x]...)
struct DefaultWalk end
function (::DefaultWalk)(recurse, x, ys...)
func, re = functor(x)
yfuncs = map(y -> functor(typeof(x), y)[1], ys)
re(_map(recurse, func, yfuncs...))
end
(walk::ExcludeWalk)(recurse, x, ys...) =
walk.exclude(x) ? walk.fn(x, ys...) : walk.walk(recurse, x, ys...)
struct NoKeyword end
struct CachedWalk{T, S, C <: AbstractDict}
walk::T
prune::S
cache::C
end
CachedWalk(walk; prune = NoKeyword(), cache = IdDict()) =
    CachedWalk(walk, prune, cache)
# cacheget is referenced below but was missing from this vendored code; a plain
# lookup (matching Functors.jl's fallback behavior) is assumed to suffice here
cacheget(cache::AbstractDict, x, args...) = cache[x]
function (walk::CachedWalk)(recurse, x, ys...)
should_cache = usecache(walk.cache, x)
if should_cache && haskey(walk.cache, x)
return walk.prune isa NoKeyword ? cacheget(walk.cache, x, recurse, x, ys...) : walk.prune
else
ret = walk.walk(recurse, x, ys...)
if should_cache
walk.cache[x] = ret
end
return ret
end
end
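# True if x or any of its fields, recursively, is mutable; decides whether x can key the IdDict cache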
@generated function anymutable(x::T) where {T}
ismutabletype(T) && return true
fns = QuoteNode.(filter(n -> fieldtype(T, n) != T, fieldnames(T)))
subs = [:(anymutable(getfield(x, $f))) for f in fns]
return Expr(:(||), subs...)
end
usecache(::Union{AbstractDict, AbstractSet}, x) =
isleaf(x) ? anymutable(x) : ismutable(x)
usecache(::Nothing, x) = false
struct WalkCache{K, V, W, C <: AbstractDict{K, V}} <: AbstractDict{K, V}
walk::W
cache::C
WalkCache(walk, cache::AbstractDict{K, V} = IdDict()) where {K, V} = new{K, V, typeof(walk), typeof(cache)}(walk, cache)
end
Base.length(cache::WalkCache) = length(cache.cache)
Base.empty!(cache::WalkCache) = empty!(cache.cache)
Base.haskey(cache::WalkCache, x) = haskey(cache.cache, x)
Base.get(cache::WalkCache, x, default) = haskey(cache.cache, x) ? cache[x] : default
Base.iterate(cache::WalkCache, state...) = iterate(cache.cache, state...)
Base.setindex!(cache::WalkCache, value, key) = setindex!(cache.cache, value, key)
Base.getindex(cache::WalkCache, x) = cache.cache[x]
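# Core Functors.jl primitive: return x's children plus a closure that reconstructs x from them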
function functor(T, x)
names = fieldnames(T)
if isempty(names)
return NoChildren(), _ -> x
end
S = constructorof(T) # remove parameters from parametric types and support anonymous functions
vals = ntuple(i -> getfield(x, names[i]), length(names))
return NamedTuple{names}(vals), y -> S(y...)
end
functor(::Type{<:Tuple}, x) = x, identity
functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty(x, s), L)), identity
functor(::Type{<:Dict}, x) = Dict(k => x[k] for k in keys(x)), identity
functor(::Type{<:AbstractArray}, x) = x, identity
macro leaf(T)
:(functor(::Type{<:$(esc(T))}, x) = (NoChildren(), _ -> x))
end
@leaf Type
@leaf Number
@leaf AbstractArray{<:Number}
function execute(walk, x, ys...)
recurse(xs...) = walk(var"#self#", xs...)
walk(recurse, x, ys...)
end
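# fmap: map f over every leaf of x, honoring exclude and caching shared or mutable nodes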
function fmap(f, x, ys...; exclude = isleaf,
walk = DefaultWalk(),
cache = IdDict(),
prune = NoKeyword())
_walk = ExcludeWalk(walk, f, exclude)
if !isnothing(cache)
_walk = CachedWalk(_walk, prune, WalkCache(_walk, cache))
end
execute(_walk, x, ys...)
end
isnumeric(x::AbstractArray{<:Number}) = isleaf(x)
isnumeric(x::AbstractArray{<:Integer}) = false
isnumeric(x) = false
children(x) = functor(x)[1]
isleaf(@nospecialize(x)) = children(x) === NoChildren()
struct TrainableStructWalk end
mapvalue(f, x...) = map(f, x...)
mapvalue(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x)
trainable(x) = functor(x)[1]
_trainable(x) = _trainable(functor(x)[1], trainable(x))
_trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
_trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr
_trainable(ch::AbstractArray, tr::AbstractArray) = tr
_trainable(ch::Dict, tr::Dict) = merge(mapvalue(_ -> nothing, ch), tr)
(::TrainableStructWalk)(recurse, x) = mapvalue(recurse, _trainable(x))
function _flatten(x)
isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case
arrays = AbstractVector[]
len = Ref(0)
off = fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y
push!(arrays, _vec(y))
o = len[]
len[] = o + length(y)
o
end
isempty(arrays) && return Bool[], off, 0
return reduce(vcat, arrays), off, len[]
end
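# Flatten all trainable numeric leaves of x into one vector; Restructure inverts the flattening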
function destructure(x)
flat, off, len = _flatten(x)
flat, Restructure(x, off, len)
end
struct Restructure{T,S}
model::T
offsets::S
length::Int
end
struct _Trainable_biwalk end
_getat(y::Number, o::Int, flat::AbstractVector) = flat[o + 1]
_getat(y::AbstractArray, o::Int, flat::AbstractVector) = reshape(flat[o .+ (1:length(y))], axes(y))
function _trainmap(f, ch, tr, aux)
map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
isnothing(t) ? c : f(t, a)
end
end
function (::_Trainable_biwalk)(f, x, aux)
ch, re = functor(typeof(x), x)
au, _ = functor(typeof(x), aux)
_trainmap(f, ch, _trainable(x), au) |> re
end
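# Write entries of flat back into the numeric leaves of the model at the recorded offsets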
function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _Trainable_biwalk(), kw...)
fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
_getat(y, o, flat)
end
end
(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat, re.length)
struct TransformedDistribution{D,B}
dist::D
transform::B
end
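# Hand-rolled stand-in for Functors.@functor: emits a functor method for T over the fields fs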
function makefunctor(T, fs = fieldnames(T))
fidx = Ref(0)
escargs = map(fieldnames(T)) do f
f in fs ? :(y[$(fidx[] += 1)]) : :(x.$f)
end
escargs_nt = map(fieldnames(T)) do f
f in fs ? :(y[$(Meta.quot(f))]) : :(x.$f)
end
escfs = [:($f=x.$f) for f in fs]
@eval begin
function functor(::Type{<:$T}, x)
reconstruct(y) = $T($(escargs...))
reconstruct(y::NamedTuple) = $T($(escargs_nt...))
return (;$(escfs...)), reconstruct
end
end
end
functor(x) = functor(typeof(x), x)
makefunctor(TransformedDistribution)
function rand(rng::AbstractRNG, td::TransformedDistribution, num_samples::Int)
samples = rand(rng, td.dist, num_samples)
res = reduce(
hcat,
map(axes(samples, 2)) do i
return td.transform(view(samples, :, i))
end,
)
return res
end
struct Stacked{Bs,Rs<:Union{Tuple,AbstractArray}}
bs::Bs
ranges_in::Rs
ranges_out::Rs
length_in::Int
length_out::Int
end
makefunctor(Stacked, (:bs,))
function mapvcat(f, args...)
out = map(f, args...)
init = vcat(out[1])
return reduce(vcat, drop(out, 1); init=init)
end
@generated function _transform_stacked_recursive(
x, rs::NTuple{N,UnitRange{Int}}, bs...
) where {N}
exprs = []
for i in 1:N
push!(exprs, :(bs[$i](x[rs[$i]])))
end
return :(vcat($(exprs...)))
end
function _transform_stacked_recursive(x, rs::NTuple{1,UnitRange{Int}}, b)
rs[1] == 1:length(x) || error("range must be 1:length(x)")
return b(x)
end
function _transform_stacked(sb::Stacked{<:Tuple,<:Tuple}, x::AbstractVector{<:Real})
y = _transform_stacked_recursive(x, sb.ranges_in, sb.bs...)
return y
end
function _transform_stacked(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real})
N = length(sb.bs)
N == 1 && return sb.bs[1](x[sb.ranges_in[1]])
y = mapvcat(1:N) do i
sb.bs[i](x[sb.ranges_in[i]])
end
return y
end
function (sb::Stacked)(x::AbstractVector{<:Real})
y = _transform_stacked(sb, x)
return y
end
struct TestProb1 end
logdensity(::TestProb1, θ) = sum(θ)
struct TestProb2 end
logdensity(::TestProb2, θ) = sum(θ)
struct MvLocationScale{L}
location::L
end
Base.length(q::MvLocationScale) = length(q.location)
makefunctor(MvLocationScale)
# This specialization improves AD performance of the sampling path
function rand(
rng::AbstractRNG, q::MvLocationScale, num_samples::Int
)
(; location,) = q
n_dims = length(location)
return randn(rng, n_dims, num_samples)
end
function estimate_repgradelbo_ad_forward(params, aux)
(; rng, problem, restructure) = aux
q = restructure(params)
zs = rand(rng, q, 10)
return mean(Base.Fix1(logdensity, problem), eachcol(zs))
end
function main()
d = 5
rng = Random.default_rng()
for prob in [TestProb1(), TestProb2()]
q = if prob isa TestProb1
MvLocationScale(zeros(d))
else
TransformedDistribution(
MvLocationScale(zeros(d)),
Stacked(
[Base.Fix1(broadcast, exp), identity],
[1:1, 2:d],
[1:1, 2:d],
d, d,
)
)
end
params, re = destructure(q)
buf = zero(params)
aux = (rng=rng, problem=prob, restructure=re)
Enzyme.autodiff(
set_runtime_activity(Enzyme.ReverseWithPrimal, true),
estimate_repgradelbo_ad_forward,
Enzyme.Active,
Enzyme.Duplicated(params, buf),
Enzyme.Const(aux),
)
end
end
main()
darn, sadly this no longer segfaults |
@wsmoses Even with
@wsmoses Can you confirm whether you still can't reproduce? |
Got it, and significantly reduced to this:

using Statistics
using Base.Iterators
using LinearAlgebra
using Enzyme
using Random
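# Debugging flags: make Enzyme dump and print its generated IR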
Enzyme.Compiler.DumpPostOpt[] = true
Enzyme.API.printall!(true)
struct Stacked
end
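# inferencebarrier hides the return type from inference, forcing a type-unstable call site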
@inline function myrand(rng::AbstractRNG, td::Stacked, num_samples::Int)
return Base.inferencebarrier(ones(1))
end
struct TestProb1 end
logdensity(::TestProb1, θ) = sum(θ)
struct TestProb2 end
logdensity(::TestProb2, θ) = sum(θ)
struct MvLocationScale
end
# This specialization improves AD performance of the sampling path
@inline function myrand(
rng::AbstractRNG, q::MvLocationScale, num_samples::Int
)
return ones(5, num_samples)
end
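# Minimal stand-in for mean(f, eachcol(zs)): just enough control flow to trigger the bug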
function mymean(problem, A::AbstractArray)
isempty(A) && return sum(Base.Fix1(logdensity, problem), A)
x1 = sum(@inbounds first(A))
return 1.0
end
function estimate_repgradelbo_ad_forward(rng, problem, model)
zs = myrand(rng, model, 10)
return mymean(problem, eachcol(zs))
end
function main()
d = 5
rng = Random.default_rng()
for prob in [TestProb1(), TestProb2()]
q = if prob isa TestProb1
MvLocationScale()
else
Stacked()
end
Enzyme.autodiff(
set_runtime_activity(Enzyme.ReverseWithPrimal, true),
estimate_repgradelbo_ad_forward,
Enzyme.Active,
Enzyme.Const(rng),
Enzyme.Const(prob),
Enzyme.Const(q),
)
end
end
main()
# main0()
# main1()
No way... that's magical |
Wow, it also seems like it was quite a fix.
Hi!
The following code segfaults on 1.10:
This bug is very sensitive: seemingly minor changes (like swapping the order of TestProb1 and TestProb2) immediately make it go away. As such it was pretty hard to pin down, but the snippet above seems to do it reliably. Below is the segfault error message.