From ac1a8df30099d94862714c20fb0249d7f864c59a Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Wed, 2 Oct 2024 22:04:58 -0700 Subject: [PATCH 1/6] Fix make_zero(!) bugs --- src/make_zero.jl | 60 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/src/make_zero.jl b/src/make_zero.jl index f2fd055c61..d744956672 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -104,7 +104,7 @@ end prev::Complex{RT}, ::Val{copy_if_inactive} = Val(false), )::Complex{RT} where {copy_if_inactive,RT<:AbstractFloat} - return RT(0) + return Complex{RT}(0) end @inline function EnzymeCore.make_zero( @@ -191,11 +191,10 @@ end return seen[prev] end prev2 = prev.contents - res = Core.Box() - seen[prev] = res - res.contents = Base.Ref( + res = Core.Box( EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)), ) + seen[prev] = res return res end @@ -234,7 +233,9 @@ end end if nf == 0 - return prev + # Unclear what types might end up here rather than in specialized methods or + # guaranteed_const_nongen, but as a last-ditch attempt try falling back to Base.zero + return Base.zero(prev)::RT end flds = Vector{Any}(undef, nf) @@ -261,27 +262,41 @@ function make_zero_immutable!( prev::Complex{T}, seen::S, )::Complex{T} where {T<:AbstractFloat,S} - zero(T) + zero(Complex{T}) end function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} ntuple(Val(length(T.parameters))) do i Base.@_inline_meta - make_zero_immutable!(prev[i], seen) + p = prev[i] + SBT = Core.Typeof(p) + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + make_zero_immutable!(p, seen) + else + EnzymeCore.make_zero!(p, seen) + p + end end end function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} - NamedTuple{a,b}(ntuple(Val(length(T.parameters))) do i + NamedTuple{a,b}(ntuple(Val(length(b.parameters))) do i Base.@_inline_meta - make_zero_immutable!(prev[a[i]], seen) + p = prev[a[i]] + SBT = Core.Typeof(p) + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + make_zero_immutable!(p, seen) + else + EnzymeCore.make_zero!(p, seen) + p + end end) end function make_zero_immutable!(prev::T, seen::S)::T where {T,S} if guaranteed_const_nongen(T, nothing) - return prev + return prev # Note: unreachable from make_zero! end @assert !ismutable(prev) @@ -313,7 +328,7 @@ end prev::Base.RefValue{T}, seen::ST, )::Nothing where {T<:AbstractFloat,ST} - T[] = zero(T) + prev[] = zero(T) nothing end @@ -321,7 +336,7 @@ end prev::Base.RefValue{Complex{T}}, seen::ST, )::Nothing where {T<:AbstractFloat,ST} - T[] = zero(Complex{T}) + prev[] = zero(Complex{T}) nothing end @@ -390,7 +405,7 @@ end if guaranteed_const_nongen(T, nothing) return end - if in(seen, prev) + if prev in seen return end push!(seen, prev) @@ -429,7 +444,7 @@ end if guaranteed_const_nongen(T, nothing) return end - if in(seen, prev) + if prev in seen return end push!(seen, prev) @@ -459,7 +474,7 @@ end if guaranteed_const_nongen(T, nothing) return end - if in(seen, prev) + if prev in seen return end push!(seen, prev) @@ -482,13 +497,13 @@ end if guaranteed_const_nongen(T, nothing) return end - if in(seen, prev) + if prev in seen return end push!(seen, prev) SBT = Core.Typeof(pv) if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - prev.contents = EnzymeCore.make_zero_immutable!(pv, seen) + prev.contents = make_zero_immutable!(pv, seen) nothing else EnzymeCore.make_zero!(pv, seen) @@ -504,7 +519,7 @@ end if guaranteed_const_nongen(T, nothing) return end - if in(prev, seen) + if prev in seen return end @assert !Base.isabstracttype(T) @@ -513,7 +528,7 @@ end if nf == 0 - return + error("cannot zero $T in-place: it is apparently differentiable but has no fields") end push!(seen, prev) @@ -526,7 +541,12 @@ end continue end if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - setfield!(prev, i, make_zero_immutable!(xi, seen)) + yi = make_zero_immutable!(xi, seen) + if Base.isconst(T, i) + ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), prev, i-1, yi) + else + setfield!(prev, i, yi) + end nothing else EnzymeCore.make_zero!(xi, seen) From 7838c290c035213e92d445d9add80f824fd49404 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Thu, 3 Oct 2024 12:12:20 -0700 Subject: [PATCH 2/6] Add make_zero(!) tests Aiming for full coverage of both new and old implementations of make_zero(!) --- test/abi.jl | 32 ----- test/make_zero.jl | 355 ++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 356 insertions(+), 32 deletions(-) create mode 100644 test/make_zero.jl diff --git a/test/abi.jl b/test/abi.jl index 20747f2aaa..b6898ac1ba 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -489,38 +489,6 @@ mulsin(x) = sin(x[1] * x[2]) @test Enzyme.autodiff(ForwardWithPrimal, () -> Enzyme.within_autodiff())[1] end -mutable struct ConstVal - x::Float64 - const y::Float64 -end - -struct WithIO{F} - v::Vector{Float64} - callback::F - function WithIO(v, io) - callback() = println(io, "hello") - return new{typeof(callback)}(v, callback) - end -end - -@testset "Make Zero" begin - v = ConstVal(2.0, 3.0) - dv = make_zero(v) - @test dv isa ConstVal - @test dv.x ≈ 0.0 - @test dv.y ≈ 0.0 - - f = WithIO([1.0, 2.0], stdout) - df = @test_nowarn try - # catch errors to get failed test instead of "exception outside of a @test" - make_zero(f) - catch e - showerror(stderr, e) - end - @test df.v == [0.0, 0.0] - @test df.callback === f.callback -end - @testset "Type inference" begin x = ones(10) @inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x)) diff --git a/test/make_zero.jl b/test/make_zero.jl new file mode 100644 index 0000000000..0f04db2fdb --- /dev/null +++ b/test/make_zero.jl @@ -0,0 +1,355 @@ +using Enzyme +using Test + +mutable struct MutableWrapper{T} + x::T +end + +Base.:(==)(a::MutableWrapper, b::MutableWrapper) = (a === b) || isequal(a.x, b.x) + +struct Incomplete{T} + s::String + x::Float64 + w::T + z # not initialized + Incomplete(s, x, w) = new{typeof(w)}(s, x, w) +end + +function Base.:(==)(a::Incomplete, b::Incomplete) + (a === b) && return true + (isequal(a.s, b.s) && isequal(a.x, b.x) && isequal(a.w, b.w)) || return false + if isdefined(a, :z) && isdefined(b, :z) + isequal(a.z, b.z) || return false + elseif isdefined(a, :z) || isdefined(b, :z) + return false + end + return true +end + +mutable struct MutableIncomplete{T} + s::String + const x::Float64 + y::Float64 + z # not initialized + w::T + function MutableIncomplete(s, x, y, w) + ret = new{typeof(w)}(s, x, y) + ret.w = w + return ret + end +end + +function Base.:(==)(a::MutableIncomplete, b::MutableIncomplete) + (a === b) && return true + if !isequal(a.s, b.s) || !isequal(a.x, b.x) || !isequal(a.y, b.y) || !isequal(a.w, b.w) + return false + end + if isdefined(a, :z) && isdefined(b, :z) + isequal(a.z, b.z) || return false + elseif isdefined(a, :z) || isdefined(b, :z) + return false + end + return true +end + +struct WithIO{F} # issue 2091 + v::Vector{Float64} + callback::F + function WithIO(v, io) + callback() = println(io, "hello") + return new{typeof(callback)}(v, callback) + end +end + +macro test_noerr(expr) + return quote + @test_nowarn try + # catch errors to get failed test instead of "exception outside of a @test" + $(esc(expr)) + catch e + showerror(stderr, e) + end + end +end + +@testset "make_zero" begin + # floats + @test make_zero(1.0) == 0.0 + @test make_zero(1.0im) == 0.0im + + # float arrays + multiple references + rearr = [1.0] + imarr = [1.0im] + rearr0 = make_zero(rearr) + imarr0 = make_zero(imarr) + @test typeof(rearr0) === typeof(rearr) + @test typeof(imarr0) === typeof(imarr) + @test rearr == [1.0] # no mutation + @test imarr == [1.0im] # no mutation + @test rearr0 == [0.0] + @test imarr0 == [0.0im] + rearrs0 = make_zero((rearr, rearr)) + imarrs0 = make_zero((imarr, imarr)) + @test typeof(rearrs0) === typeof((rearr, rearr)) + @test typeof(imarrs0) === typeof((imarr, imarr)) + @test rearr == [1.0] # no mutation + @test imarr == [1.0im] # no mutation + @test rearrs0[1] === rearrs0[2] + @test imarrs0[1] === imarrs0[2] + @test rearrs0[1] == [0.0] + @test imarrs0[1] == [0.0im] + + # floats in structs + rewrapped = MutableWrapper(1.0) + imwrapped = MutableWrapper(1.0im) + rewrapped0 = make_zero(rewrapped) + imwrapped0 = make_zero(imwrapped) + @test typeof(rewrapped0) === typeof(rewrapped) + @test typeof(imwrapped0) === typeof(imwrapped) + @test rewrapped == MutableWrapper(1.0) # no mutation + @test imwrapped == MutableWrapper(1.0im) # no mutation + @test rewrapped0 == MutableWrapper(0.0) + @test imwrapped0 == MutableWrapper(0.0im) + + # generic array + multiple references + wrapped = MutableWrapper(1.0) + mixarr = ["a", 1.0, wrapped] + mixarr0 = make_zero(mixarr) + @test typeof(mixarr0) === typeof(mixarr) + @test view(mixarr, 1:2) == ["a", 1.0] # no mutation + @test mixarr[3] === wrapped # no mutation + @test mixarr0 == ["a", 0.0, MutableWrapper(0.0)] + mixarrs0 = make_zero((mixarr, mixarr)) + @test typeof(mixarrs0) === typeof((mixarr, mixarr)) + @test view(mixarr, 1:2) == ["a", 1.0] # no mutation + @test mixarr[3] === wrapped # no mutation + @test mixarrs0[1] === mixarrs0[2] + @test mixarrs0[1] == ["a", 0.0, MutableWrapper(0.0)] + + # non-differentiable array + copy_if_inactive + constarr = ["a"] + constarr0 = make_zero(constarr) + @test typeof(constarr0) === typeof(constarr) + @test constarr == ["a"] # no mutation + @test constarr0 === constarr + constarr0copy = make_zero(constarr, #=copy_if_inactive=#Val(true)) + @test typeof(constarr0copy) === typeof(constarr0) + @test constarr == ["a"] # no mutation + @test constarr0copy !== constarr + @test constarr0copy == constarr + + # Tuple + tup = ("a", 1.0, MutableWrapper(1.0)) + tup0 = make_zero(tup) + @test typeof(tup0) === typeof(tup) + @test tup == ("a", 1.0, MutableWrapper(1.0)) # no mutation + @test tup0 == ("a", 0.0, MutableWrapper(0.0)) + + # NamedTuple + ntup = (a="a", b=1.0, c=MutableWrapper(1.0)) + ntup0 = make_zero(ntup) + @test typeof(ntup0) === typeof(ntup) + @test ntup == (a="a", b=1.0, c=MutableWrapper(1.0)) # no mutation + @test ntup0 == (a="a", b=0.0, c=MutableWrapper(0.0)) + + # Box + multiple references + box = Core.Box(1.0) + box0 = make_zero(box) + @test typeof(box0) === typeof(box) + @test box.contents == 1.0 # no mutation + @test box0.contents == 0.0 + boxes0 = make_zero((box, box)) + @test typeof(boxes0) === typeof((box, box)) + @test box.contents == 1.0 # no mutation + @test boxes0[1] === boxes0[2] + @test boxes0[1].contents == 0.0 + + # differentiable custom type + multiple references + wrapped = MutableWrapper(1.0) + wrapped0 = make_zero(wrapped) + @test typeof(wrapped0) === typeof(wrapped) + @test wrapped == MutableWrapper(1.0) # no mutation + @test wrapped0 == MutableWrapper(0.0) + wrappeds0 = make_zero((wrapped, wrapped)) + @test typeof(wrappeds0) === typeof((wrapped, wrapped)) + @test wrapped == MutableWrapper(1.0) # no mutation + @test wrappeds0[1] === wrappeds0[2] + @test wrappeds0[1] == MutableWrapper(0.0) + + # non-differentiable custom type + copy_if_inactive + constwrapped = MutableWrapper("a") + constwrapped0 = make_zero(constwrapped) + @test typeof(constwrapped0) === typeof(constwrapped) + @test constwrapped == MutableWrapper("a") # no mutation + @test constwrapped0 === constwrapped + constwrapped0copy = make_zero(constwrapped, #=copy_if_inactive=#Val(true)) + @test typeof(constwrapped0copy) === typeof(constwrapped0) + @test constwrapped == MutableWrapper("a") # no mutation + @test constwrapped0copy !== constwrapped + @test constwrapped0copy == constwrapped + + # immutable struct with active, mutable, inactive and undefined fields + incomplete = Incomplete("a", 1.0, MutableWrapper(1.0)) + incomplete0 = make_zero(incomplete) + @test typeof(incomplete0) === typeof(incomplete) + @test incomplete == Incomplete("a", 1.0, MutableWrapper(1.0)) # no mutation + @test incomplete0 == Incomplete("a", 0.0, MutableWrapper(0.0)) + + # mutable struct with inactive, active, undefined, and mutable fields + # + multiple references + incompletemut = MutableIncomplete("a", 1.0, 1.0, MutableWrapper(1.0)) + incompletemut0 = make_zero(incompletemut) + @test typeof(incompletemut0) === typeof(incompletemut) + @test incompletemut == MutableIncomplete("a", 1.0, 1.0, MutableWrapper(1.0)) # no mutation + @test incompletemut0 == MutableIncomplete("a", 0.0, 0.0, MutableWrapper(0.0)) + incompletemuts0 = make_zero((incompletemut, incompletemut)) + @test typeof(incompletemuts0) === typeof((incompletemut, incompletemut)) + @test incompletemut == MutableIncomplete("a", 1.0, 1.0, MutableWrapper(1.0)) # no mutation + @test incompletemuts0[1] === incompletemuts0[2] + @test incompletemuts0[1] == MutableIncomplete("a", 0.0, 0.0, MutableWrapper(0.0)) + + # containing IO (issue #2091) + f = WithIO([1.0, 2.0], stdout) + df = @test_noerr make_zero(f) + @test df.v == [0.0, 0.0] + @test df.callback === f.callback +end + +@testset "make_zero!" begin + # floats in mutable struct + rewrapped, imwrapped = MutableWrapper(1.0), MutableWrapper(1.0im) + make_zero!(rewrapped) + make_zero!(imwrapped) + @test rewrapped == MutableWrapper(0.0) + @test imwrapped == MutableWrapper(0.0im) + + # mixed tuple in mutable container + wrapped = MutableWrapper(1.0) + tuparr = [(1.0, wrapped)] + make_zero!(tuparr) + @test tuparr[1] === (0.0, wrapped) + @test wrapped == MutableWrapper(0.0) + + # mixed namedtuple in mutable container + wrapped = MutableWrapper(1.0) + ntuparr = [(a=1.0, b=wrapped)] + make_zero!(ntuparr) + @test ntuparr[1] === (a=0.0, b=wrapped) + @test wrapped == MutableWrapper(0.0) + + # immutable struct with active, mutable, inactive and undefined fields in mutable container + wrapped = MutableWrapper(1.0) + incompletearr = [Incomplete("a", 1.0, wrapped)] + make_zero!(incompletearr) + @test incompletearr[1] == Incomplete("a", 0.0, wrapped) + @test wrapped == MutableWrapper(0.0) + + # floats in Ref + reref, imref = Ref(1.0), Ref(1.0im) + make_zero!(reref) + make_zero!(imref) + @test reref[] == 0.0 + @test imref[] == 0.0im + + # float arrays + rearr, imarr = [1.0], [1.0im] + make_zero!(rearr) + make_zero!(imarr) + @test rearr[1] == 0.0 + @test imarr[1] == 0.0im + + # non-differentiable array + constarr = ["a"] + make_zero!(constarr) + @test constarr[1] == "a" + + # array with active, mutable, inactive and unassigned elements + multiple references + wrapped = MutableWrapper(1.0) + genericarr = Vector(undef, 4) + genericarr[1:3] .= ("a", 1.0, wrapped) + genericarrs = [genericarr, genericarr] + make_zero!(genericarrs) + @test genericarrs[1] === genericarrs[2] + @test genericarrs[1] === genericarr + @test view(genericarr, 1:2) == ["a", 0.0] + @test genericarr[3] === wrapped + @test wrapped == MutableWrapper(0.0) + @test !isassigned(genericarr, 4) + + # Ref with multiple references + genericref = Ref((1.0,)) + genericrefs = [genericref, genericref] + make_zero!(genericrefs) + @test genericrefs[1] === genericrefs[2] + @test genericrefs[1] === genericref + @test genericref[] == (0.0,) + + # Ref with mutable value + wrapped = MutableWrapper(1.0) + mutref = Ref(wrapped) + make_zero!(mutref) + @test mutref[] === wrapped + @test wrapped == MutableWrapper(0.0) + + # Ref with non-differentiable value + constref = Ref("a") + make_zero!(constref) + @test constref[] == "a" + + # Box with multiple references + box = Core.Box(1.0) + boxes = [box, box] + make_zero!(boxes) + @test boxes[1] === boxes[2] + @test boxes[1] === box + @test box.contents == 0.0 + + # Box with mutable value + wrapped = MutableWrapper(1.0) + mutbox = Core.Box(wrapped) + make_zero!(mutbox) + @test mutbox.contents === wrapped + @test wrapped == MutableWrapper(0.0) + + # Box with non-differentiable value + constbox = Core.Box("a") + make_zero!(constbox) + @test constbox.contents == "a" + + # mutable struct with inactive, active, const active, undefined, and mutable fields + # + multiple references + wrapped = MutableWrapper(1.0) + incompletemut = MutableIncomplete("a", #=const=#1.0, 1.0, wrapped) + incompletemuts = [incompletemut, incompletemut] + make_zero!(incompletemuts) + @test incompletemuts[1] === incompletemuts[2] + @test incompletemuts[1] === incompletemut + @test incompletemut == MutableIncomplete("a", #=const=#0.0, 0.0, MutableWrapper(0.0)) + @test incompletemut.w === wrapped + + # wrapped differentiable array + arr = [1.0] + arrwrapped = MutableWrapper(arr) + make_zero!(arrwrapped) + @test arrwrapped.x === arr + @test arr == [0.0] + + # early error on active/mixed type + @test_throws ErrorException make_zero!(1.0) + @test_throws ErrorException make_zero!((1.0, MutableWrapper(1.0))) + + # immutable struct with both active and undefined fields in immutable container + # (currently fails due to #1935) + wrapped = MutableWrapper(1.0) + incompletetuparr = [(Incomplete("a", 1.0, wrapped),)] + make_zero!(incompletetuparr) + @test incompletetuparr[1][1] == Incomplete("a", 0.0, MutableWrapper(0.0)) + @test incompletetuparr[1][1].w === wrapped + + # containing IO (issue #2091) + f = WithIO([1.0, 2.0], stdout) + fwrapped = [f] + @test_noerr make_zero!(fwrapped) + @test fwrapped[1] === f + @test fwrapped[1].v == [0.0, 0.0] +end diff --git a/test/runtests.jl b/test/runtests.jl index 26587c3892..e331645378 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,6 +74,7 @@ end include("abi.jl") include("typetree.jl") include("optimize.jl") +include("make_zero.jl") include("rules.jl") include("rrules.jl") From 48a6e6f59b43e50cd9bad0adee259a9eb944e1df Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Tue, 22 Oct 2024 13:39:17 -0700 Subject: [PATCH 3/6] Fix more make_zero(!) bugs and add more tests --- ext/EnzymeStaticArraysExt.jl | 34 +- src/make_zero.jl | 201 ++++---- test/make_zero.jl | 936 ++++++++++++++++++++++++----------- 3 files changed, 798 insertions(+), 373 deletions(-) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index c2639a4c99..7e1d131116 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -32,11 +32,37 @@ end end end -@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:SArray} - return Base.zero(x) +@inline function Enzyme.EnzymeCore.make_zero( + ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive} + return Base.zero(prev)::FT end -@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:MArray} - return Base.zero(x) +@inline function Enzyme.EnzymeCore.make_zero( + ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T},copy_if_inactive} + if haskey(seen, prev) + return seen[prev] + end + new = Base.zero(prev)::FT + seen[prev] = new + return new +end +@inline function Enzyme.EnzymeCore.make_zero!( + prev::FT, seen +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end + fill!(prev, zero(T)) + return nothing +end +@inline function Enzyme.EnzymeCore.make_zero!( + prev::FT +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} + return Enzyme.EnzymeCore.make_zero!(prev, nothing) end end diff --git a/src/make_zero.jl b/src/make_zero.jl index d744956672..562a8afe7d 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -1,4 +1,3 @@ - @inline function EnzymeCore.make_zero(x::FT)::FT where {FT<:AbstractFloat} return Base.zero(x) end @@ -178,7 +177,9 @@ end prev::NamedTuple{A,RT}, ::Val{copy_if_inactive} = Val(false), )::NamedTuple{A,RT} where {copy_if_inactive,A,RT} - return NamedTuple{A,RT}(EnzymeCore.make_zero(RT, seen, RT(prev), Val(copy_if_inactive))) + prevtup = RT(prev) + TT = Core.Typeof(prevtup) # RT can be abstract + return NamedTuple{A,RT}(EnzymeCore.make_zero(TT, seen, prevtup, Val(copy_if_inactive))) end @inline function EnzymeCore.make_zero( @@ -191,10 +192,9 @@ end return seen[prev] end prev2 = prev.contents - res = Core.Box( - EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)), - ) + res = Core.Box() seen[prev] = res + res.contents = EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)) return res end @@ -213,7 +213,6 @@ end @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) - if ismutable(prev) y = ccall(:jl_new_struct_uninit, Any, (Any,), RT)::RT seen[prev] = y @@ -231,13 +230,10 @@ end end return y end - if nf == 0 - # Unclear what types might end up here rather than in specialized methods or - # guaranteed_const_nongen, but as a last-ditch attempt try falling back to Base.zero - return Base.zero(prev)::RT + # nothing to do, assume inactive + return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev end - flds = Vector{Any}(undef, nf) for i = 1:nf if isdefined(prev, i) @@ -255,22 +251,27 @@ end end function make_zero_immutable!(prev::T, seen::S)::T where {T<:AbstractFloat,S} - zero(T) + return zero(T) end function make_zero_immutable!( prev::Complex{T}, seen::S, )::Complex{T} where {T<:AbstractFloat,S} - zero(Complex{T}) + return zero(Complex{T}) end function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} + if guaranteed_const_nongen(T, nothing) + return prev # unreachable from make_zero! + end ntuple(Val(length(T.parameters))) do i Base.@_inline_meta p = prev[i] SBT = Core.Typeof(p) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + p # covered by several tests even if not shown in coverage + elseif !ismutabletype(SBT) make_zero_immutable!(p, seen) else EnzymeCore.make_zero!(p, seen) @@ -280,11 +281,16 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} end function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} + if guaranteed_const_nongen(NamedTuple{a,b}, nothing) + return prev # unreachable from make_zero! + end NamedTuple{a,b}(ntuple(Val(length(b.parameters))) do i Base.@_inline_meta p = prev[a[i]] SBT = Core.Typeof(p) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + p # covered by several tests even if not shown in coverage + elseif !ismutabletype(SBT) make_zero_immutable!(p, seen) else EnzymeCore.make_zero!(p, seen) @@ -296,21 +302,20 @@ end function make_zero_immutable!(prev::T, seen::S)::T where {T,S} if guaranteed_const_nongen(T, nothing) - return prev # Note: unreachable from make_zero! + return prev # unreachable from make_zero! end - @assert !ismutable(prev) - - RT = Core.Typeof(prev) - @assert !Base.isabstracttype(RT) - @assert Base.isconcretetype(RT) - nf = fieldcount(RT) - + @assert !ismutabletype(T) + @assert !Base.isabstracttype(T) + @assert Base.isconcretetype(T) + nf = fieldcount(T) flds = Vector{Any}(undef, nf) for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) ST = Core.Typeof(xi) - flds[i] = if active_reg_inner(ST, (), nothing, Val(true)) == ActiveState #=justActive=# + flds[i] = if guaranteed_const_nongen(ST, nothing) + xi + elseif !ismutabletype(ST) make_zero_immutable!(xi, seen) else EnzymeCore.make_zero!(xi, seen) @@ -321,39 +326,63 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T,S} break end end - ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)::T + return ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, flds, nf)::T end @inline function EnzymeCore.make_zero!( prev::Base.RefValue{T}, seen::ST, )::Nothing where {T<:AbstractFloat,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end prev[] = zero(T) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Base.RefValue{Complex{T}}, seen::ST, )::Nothing where {T<:AbstractFloat,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end prev[] = zero(Complex{T}) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Array{T,N}, seen::ST, )::Nothing where {T<:AbstractFloat,N,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(T)) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Array{Complex{T},N}, seen::ST, )::Nothing where {T<:AbstractFloat,N,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(Complex{T})) - nothing + return nothing end @static if VERSION < v"1.11-" @@ -362,16 +391,28 @@ else prev::GenericMemory{kind, T}, seen::ST, )::Nothing where {T<:AbstractFloat,kind,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(T)) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::GenericMemory{kind, Complex{T}}, seen::ST, )::Nothing where {T<:AbstractFloat,kind,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(Complex{T})) - nothing + return nothing end end @@ -379,90 +420,88 @@ end prev::Base.RefValue{T}, )::Nothing where {T<:AbstractFloat} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Base.RefValue{Complex{T}}, )::Nothing where {T<:AbstractFloat} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::Array{T,N})::Nothing where {T<:AbstractFloat,N} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Array{Complex{T},N}, )::Nothing where {T<:AbstractFloat,N} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::Array{T,N}, seen::ST)::Nothing where {T,N,ST} if guaranteed_const_nongen(T, nothing) - return + return nothing end if prev in seen - return + return nothing end push!(seen, prev) - for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + continue + elseif !ismutabletype(SBT) @inbounds prev[I] = make_zero_immutable!(pv, seen) - nothing else EnzymeCore.make_zero!(pv, seen) - nothing end end end - nothing + return nothing end @static if VERSION < v"1.11-" else @inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T})::Nothing where {T<:AbstractFloat,kind} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::GenericMemory{kind, Complex{T}}, )::Nothing where {T<:AbstractFloat, kind} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T}, seen::ST)::Nothing where {T,kind,ST} if guaranteed_const_nongen(T, nothing) - return + return nothing end if prev in seen - return + return nothing end push!(seen, prev) - for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + continue + elseif !ismutabletype(SBT) @inbounds prev[I] = make_zero_immutable!(pv, seen) - nothing else EnzymeCore.make_zero!(pv, seen) - nothing end end end - nothing + return nothing end end @@ -472,87 +511,77 @@ end seen::ST, )::Nothing where {T,ST} if guaranteed_const_nongen(T, nothing) - return + return nothing end if prev in seen - return + return nothing end push!(seen, prev) - pv = prev[] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + return nothing + elseif !ismutabletype(SBT) prev[] = make_zero_immutable!(pv, seen) - nothing else EnzymeCore.make_zero!(pv, seen) - nothing end - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST} - pv = prev.contents - T = Core.Typeof(pv) - if guaranteed_const_nongen(T, nothing) - return - end if prev in seen - return + return nothing end push!(seen, prev) + pv = prev.contents SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + return nothing + elseif !ismutabletype(SBT) prev.contents = make_zero_immutable!(pv, seen) - nothing else EnzymeCore.make_zero!(pv, seen) - nothing end - nothing + return nothing end -@inline function EnzymeCore.make_zero!( - prev::T, - seen::S = Base.IdSet{Any}(), -)::Nothing where {T,S} +@inline function EnzymeCore.make_zero!(prev::T, seen::S)::Nothing where {T,S} if guaranteed_const_nongen(T, nothing) - return + return nothing end if prev in seen - return + return nothing end @assert !Base.isabstracttype(T) @assert Base.isconcretetype(T) nf = fieldcount(T) - - if nf == 0 - error("cannot zero $T in-place: it is apparently differentiable but has no fields") + return nothing end - push!(seen, prev) - for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) SBT = Core.Typeof(xi) - if guaranteed_const_nongen(SBT, nothing) + activitystate = active_reg_inner(SBT, (), nothing) + if activitystate == AnyState # guaranteed_const continue - end - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + elseif ismutabletype(T) && !ismutabletype(SBT) yi = make_zero_immutable!(xi, seen) if Base.isconst(T, i) ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), prev, i-1, yi) else setfield!(prev, i, yi) end - nothing - else + elseif activitystate == DupState EnzymeCore.make_zero!(xi, seen) - nothing + else + throw(ArgumentError("$xi of type $SBT cannot be zeroed in-place")) end end end - return + return nothing end + +@inline EnzymeCore.make_zero!(prev) = EnzymeCore.make_zero!(prev, Base.IdSet()) diff --git a/test/make_zero.jl b/test/make_zero.jl index 0f04db2fdb..cbe2f2159f 100644 --- a/test/make_zero.jl +++ b/test/make_zero.jl @@ -1,11 +1,91 @@ +module MakeZeroTests + using Enzyme +using StaticArrays using Test +# Universal getters/setters for built-in and custom containers/wrappers +getx(w::Base.RefValue) = w[] +getx(w::Core.Box) = w.contents +getx(w) = first(w) +gety(w) = last(w) + +setx!(w::Base.RefValue, x) = (w[] = x) +setx!(w::Core.Box, x) = (w.contents = x) +setx!(w, x) = (w[begin] = x) +sety!(w, y) = (w[end] = y) + +# non-isbits MArray doesn't support setindex!, so requires a little hack +function setx!(w::MArray{S,T}, x) where {S,T} + if isbitstype(T) + w[begin] = x + else + w.data = (x, Base.tail(w.data)...) + end + return x +end + +function sety!(w::MArray{S,T}, y) where {S,T} + if isbitstype(T) + w[end] = y + else + w.data = (Base.front(w.data)..., y) + end + return y +end + +struct Empty end + +mutable struct MutableEmpty end + +Base.:(==)(::MutableEmpty, ::MutableEmpty) = true + +struct Wrapper{T} + x::T +end + +Base.:(==)(a::Wrapper, b::Wrapper) = (a === b) || (a.x == b.x) +getx(a::Wrapper) = a.x + mutable struct MutableWrapper{T} x::T end -Base.:(==)(a::MutableWrapper, b::MutableWrapper) = (a === b) || isequal(a.x, b.x) +Base.:(==)(a::MutableWrapper, b::MutableWrapper) = (a === b) || (a.x == b.x) + +getx(a::MutableWrapper) = a.x +setx!(a::MutableWrapper, x) = (a.x = x) + +struct DualWrapper{Tx,Ty} + x::Tx + y::Ty +end + +DualWrapper{T}(x::T, y) where {T} = DualWrapper{T,typeof(y)}(x, y) + +function Base.:(==)(a::DualWrapper, b::DualWrapper) + return (a === b) || ((a.x == b.x) && (a.y == b.y)) +end + +getx(a::DualWrapper) = a.x +gety(a::DualWrapper) = a.y + +mutable struct MutableDualWrapper{Tx,Ty} + x::Tx + y::Ty +end + +MutableDualWrapper{T}(x::T, y) where {T} = MutableDualWrapper{T,typeof(y)}(x, y) + +function Base.:(==)(a::MutableDualWrapper, b::MutableDualWrapper) + return (a === b) || ((a.x == b.x) && (a.y == b.y)) +end + +getx(a::MutableDualWrapper) = a.x +gety(a::MutableDualWrapper) = a.y + +setx!(a::MutableDualWrapper, x) = (a.x = x) +sety!(a::MutableDualWrapper, y) = (a.y = y) struct Incomplete{T} s::String @@ -17,9 +97,9 @@ end function Base.:(==)(a::Incomplete, b::Incomplete) (a === b) && return true - (isequal(a.s, b.s) && isequal(a.x, b.x) && isequal(a.w, b.w)) || return false + ((a.s == b.s) && (a.x == b.x) && (a.w == b.w)) || return false if isdefined(a, :z) && isdefined(b, :z) - isequal(a.z, b.z) || return false + (a.z == b.z) || return false elseif isdefined(a, :z) || isdefined(b, :z) return false end @@ -41,17 +121,51 @@ end function Base.:(==)(a::MutableIncomplete, b::MutableIncomplete) (a === b) && return true - if !isequal(a.s, b.s) || !isequal(a.x, b.x) || !isequal(a.y, b.y) || !isequal(a.w, b.w) + if (a.s != b.s) || (a.x != b.x) || (a.y != b.y) || (a.w != b.w) return false end if isdefined(a, :z) && isdefined(b, :z) - isequal(a.z, b.z) || return false + (a.z == b.z) || return false elseif isdefined(a, :z) || isdefined(b, :z) return false end return true end +mutable struct CustomVector{T} <: AbstractVector{T} + data::Vector{T} +end + +Base.:(==)(a::CustomVector, b::CustomVector) = (a === b) || (a.data == b.data) + +function Enzyme.EnzymeCore.make_zero( + ::Type{CV}, seen::IdDict, prev::CV, ::Val{copy_if_inactive} +) where {CV<:CustomVector{<:AbstractFloat},copy_if_inactive} + @info "make_zero(::CustomVector)" + if haskey(seen, prev) + return seen[prev] + end + new = CustomVector(zero(prev.data))::CV + seen[prev] = new + return new +end + +function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}, seen)::Nothing + @info "make_zero!(::CustomVector)" + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end + fill!(prev.data, false) + return nothing +end + +function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}) + return Enzyme.EnzymeCore.make_zero!(prev, nothing) +end + struct WithIO{F} # issue 2091 v::Vector{Float64} callback::F @@ -72,284 +186,540 @@ macro test_noerr(expr) end end -@testset "make_zero" begin - # floats - @test make_zero(1.0) == 0.0 - @test make_zero(1.0im) == 0.0im - - # float arrays + multiple references - rearr = [1.0] - imarr = [1.0im] - rearr0 = make_zero(rearr) - imarr0 = make_zero(imarr) - @test typeof(rearr0) === typeof(rearr) - @test typeof(imarr0) === typeof(imarr) - @test rearr == [1.0] # no mutation - @test imarr == [1.0im] # no mutation - @test rearr0 == [0.0] - @test imarr0 == [0.0im] - rearrs0 = make_zero((rearr, rearr)) - imarrs0 = make_zero((imarr, imarr)) - @test typeof(rearrs0) === typeof((rearr, rearr)) - @test typeof(imarrs0) === typeof((imarr, imarr)) - @test rearr == [1.0] # no mutation - @test imarr == [1.0im] # no mutation - @test rearrs0[1] === rearrs0[2] - @test imarrs0[1] === imarrs0[2] - @test rearrs0[1] == [0.0] - @test imarrs0[1] == [0.0im] - - # floats in structs - rewrapped = MutableWrapper(1.0) - imwrapped = MutableWrapper(1.0im) - rewrapped0 = make_zero(rewrapped) - imwrapped0 = make_zero(imwrapped) - @test typeof(rewrapped0) === typeof(rewrapped) - @test typeof(imwrapped0) === typeof(imwrapped) - @test rewrapped == MutableWrapper(1.0) # no mutation - @test imwrapped == MutableWrapper(1.0im) # no mutation - @test rewrapped0 == MutableWrapper(0.0) - @test imwrapped0 == MutableWrapper(0.0im) - - # generic array + multiple references - wrapped = MutableWrapper(1.0) - mixarr = ["a", 1.0, wrapped] - mixarr0 = make_zero(mixarr) - @test typeof(mixarr0) === typeof(mixarr) - @test view(mixarr, 1:2) == ["a", 1.0] # no mutation - @test mixarr[3] === wrapped # no mutation - @test mixarr0 == ["a", 0.0, MutableWrapper(0.0)] - mixarrs0 = make_zero((mixarr, mixarr)) - @test typeof(mixarrs0) === typeof((mixarr, mixarr)) - @test view(mixarr, 1:2) == ["a", 1.0] # no mutation - @test mixarr[3] === wrapped # no mutation - @test mixarrs0[1] === mixarrs0[2] - @test mixarrs0[1] == ["a", 0.0, MutableWrapper(0.0)] - - # non-differentiable array + copy_if_inactive - constarr = ["a"] - constarr0 = make_zero(constarr) - @test typeof(constarr0) === typeof(constarr) - @test constarr == ["a"] # no mutation - @test constarr0 === constarr - constarr0copy = make_zero(constarr, #=copy_if_inactive=#Val(true)) - @test typeof(constarr0copy) === typeof(constarr0) - @test constarr == ["a"] # no mutation - @test constarr0copy !== constarr - @test constarr0copy == constarr - - # Tuple - tup = ("a", 1.0, MutableWrapper(1.0)) - tup0 = make_zero(tup) - @test typeof(tup0) === typeof(tup) - @test tup == ("a", 1.0, MutableWrapper(1.0)) # no mutation - @test tup0 == ("a", 0.0, MutableWrapper(0.0)) - - # NamedTuple - ntup = (a="a", b=1.0, c=MutableWrapper(1.0)) - ntup0 = make_zero(ntup) - @test typeof(ntup0) === typeof(ntup) - @test ntup == (a="a", b=1.0, c=MutableWrapper(1.0)) # no mutation - @test ntup0 == (a="a", b=0.0, c=MutableWrapper(0.0)) - - # Box + multiple references - box = Core.Box(1.0) - box0 = make_zero(box) - @test typeof(box0) === typeof(box) - @test box.contents == 1.0 # no mutation - @test box0.contents == 0.0 - boxes0 = make_zero((box, box)) - @test typeof(boxes0) === typeof((box, box)) - @test box.contents == 1.0 # no mutation - @test boxes0[1] === boxes0[2] - @test boxes0[1].contents == 0.0 - - # differentiable custom type + multiple references - wrapped = MutableWrapper(1.0) - wrapped0 = make_zero(wrapped) - @test typeof(wrapped0) === typeof(wrapped) - @test wrapped == MutableWrapper(1.0) # no mutation - @test wrapped0 == MutableWrapper(0.0) - wrappeds0 = make_zero((wrapped, wrapped)) - @test typeof(wrappeds0) === typeof((wrapped, wrapped)) - @test wrapped == MutableWrapper(1.0) # no mutation - @test wrappeds0[1] === wrappeds0[2] - @test wrappeds0[1] == MutableWrapper(0.0) - - # non-differentiable custom type + copy_if_inactive - constwrapped = MutableWrapper("a") - constwrapped0 = make_zero(constwrapped) - @test typeof(constwrapped0) === typeof(constwrapped) - @test constwrapped == MutableWrapper("a") # no mutation - @test constwrapped0 === constwrapped - constwrapped0copy = make_zero(constwrapped, #=copy_if_inactive=#Val(true)) - @test typeof(constwrapped0copy) === typeof(constwrapped0) - @test constwrapped == MutableWrapper("a") # no mutation - @test constwrapped0copy !== constwrapped - @test constwrapped0copy == constwrapped - - # immutable struct with active, mutable, inactive and undefined fields - incomplete = Incomplete("a", 1.0, MutableWrapper(1.0)) - incomplete0 = make_zero(incomplete) - @test typeof(incomplete0) === typeof(incomplete) - @test incomplete == Incomplete("a", 1.0, MutableWrapper(1.0)) # no mutation - @test incomplete0 == Incomplete("a", 0.0, MutableWrapper(0.0)) - - # mutable struct with inactive, active, undefined, and mutable fields - # + multiple references - incompletemut = MutableIncomplete("a", 1.0, 1.0, MutableWrapper(1.0)) - incompletemut0 = make_zero(incompletemut) - @test typeof(incompletemut0) === typeof(incompletemut) - @test incompletemut == MutableIncomplete("a", 1.0, 1.0, MutableWrapper(1.0)) # no mutation - @test incompletemut0 == MutableIncomplete("a", 0.0, 0.0, MutableWrapper(0.0)) - incompletemuts0 = make_zero((incompletemut, incompletemut)) - @test typeof(incompletemuts0) === typeof((incompletemut, incompletemut)) - @test incompletemut == MutableIncomplete("a", 1.0, 1.0, MutableWrapper(1.0)) # no mutation - @test incompletemuts0[1] === incompletemuts0[2] - @test incompletemuts0[1] == MutableIncomplete("a", 0.0, 0.0, MutableWrapper(0.0)) - - # containing IO (issue #2091) - f = WithIO([1.0, 2.0], stdout) - df = @test_noerr make_zero(f) - @test df.v == [0.0, 0.0] - @test df.callback === f.callback +const scalartypes = [Float32, ComplexF32, Float64, ComplexF64] + +const inactivetup = ("a", Empty(), MutableEmpty()) +const inactivearr = [inactivetup] + +const wrappers = [ + (name="Tuple{X}", f=tuple, N=1, mutable=false, typed=true), + (name="@NamedTuple{x::X}", f=(NamedTuple{(:x,)} ∘ tuple), N=1, mutable=false, typed=true), + (name="struct{X}", f=Wrapper, N=1, mutable=false, typed=true), + + (name="@NamedTuple{x}", f=(@NamedTuple{x} ∘ tuple), N=1, mutable=false, typed=false), + (name="struct{Any}", f=Wrapper{Any}, N=1, mutable=false, typed=false), + + (name="Array{X}", f=(x -> [x]), N=1, mutable=true, typed=true), + (name="Base.RefValue{X}", f=Ref, N=1, mutable=true, typed=true), + (name="mutable struct{X}", f=MutableWrapper, N=1, mutable=true, typed=true), + + (name="Array{Any}", f=(x -> Any[x]), N=1, mutable=true, typed=false), + (name="Base.RefValue{Any}", f=Ref{Any}, N=1, mutable=true, typed=false), + (name="Core.Box", f=Core.Box, N=1, mutable=true, typed=false), + (name="mutable struct{Any}", f=MutableWrapper{Any}, N=1, mutable=true, typed=false), + + (name="Tuple{X,Y}", f=tuple, N=2, mutable=false, typed=true), + (name="@NamedTuple{x::X,y::Y}", f=(NamedTuple{(:x, :y)} ∘ tuple), N=2, mutable=false, typed=true), + (name="struct{X,Y}", f=DualWrapper, N=2, mutable=false, typed=true), + + (name="@NamedTuple{x,y::Y}", f=((x, y) -> @NamedTuple{x,y::typeof(y)}((x, y))), N=2, mutable=false, typed=:partial), + (name="struct{Any,Y}", f=DualWrapper{Any}, N=2, mutable=false, typed=:partial), + + (name="@NamedTuple{x,y}", f=@NamedTuple{x,y} ∘ tuple, N=2, mutable=false, typed=false), + (name="struct{Any}", f=DualWrapper{Any,Any}, N=2, mutable=false, typed=false), + + (name="mutable struct{X,Y}", f=MutableDualWrapper, N=2, mutable=true, typed=true), + + (name="Array{promote_type(X,Y)}", f=((x, y) -> [x, y]), N=2, mutable=true, typed=:promoted), + (name="mutable struct{Any,Y}", f=MutableDualWrapper{Any}, N=2, mutable=true, typed=:partial), + + (name="Array{Any}", f=((x, y) -> Any[x, y]), N=2, mutable=true, typed=false), + (name="mutable struct{Any,Any}", f=MutableDualWrapper{Any,Any}, N=2, mutable=true, typed=false), + + # StaticArrays extension + (name="SVector{1,X}", f=SVector{1} ∘ tuple, N=1, mutable=false, typed=true), + (name="SVector{1,Any}", f=SVector{1,Any} ∘ tuple, N=1, mutable=false, typed=false), + (name="MVector{1,X}", f=MVector{1} ∘ tuple, N=1, mutable=true, typed=true), + (name="MVector{1,Any}", f=MVector{1,Any} ∘ tuple, N=1, mutable=true, typed=false), + (name="SVector{2,promote_type(X,Y)}", f=SVector{2} ∘ tuple, N=2, mutable=false, typed=:promoted), + (name="SVector{2,Any}", f=SVector{2,Any} ∘ tuple, N=2, mutable=false, typed=false), + (name="MVector{2,promote_type(X,Y)}", f=MVector{2} ∘ tuple, N=2, mutable=true, typed=:promoted), + (name="MVector{2,Any}", f=MVector{2,Any} ∘ tuple, N=2, mutable=true, typed=false), +] + +@static if VERSION < v"1.11-" +else +_memory(x::Vector) = Memory{eltype(x)}(x) +push!( + wrappers, + (name="Memory{X}", f=(x -> _memory([x])), N=1, mutable=true, typed=true), + (name="Memory{Any}", f=(x -> _memory(Any[x])), N=1, mutable=true, typed=false), + (name="Memory{promote_type(X,Y)}", f=((x, y) -> _memory([x, y])), N=2, mutable=true, typed=:promoted), + (name="Memory{Any}", f=((x, y) -> _memory(Any[x, y])), N=2, mutable=true, typed=false), +) end -@testset "make_zero!" begin - # floats in mutable struct - rewrapped, imwrapped = MutableWrapper(1.0), MutableWrapper(1.0im) - make_zero!(rewrapped) - make_zero!(imwrapped) - @test rewrapped == MutableWrapper(0.0) - @test imwrapped == MutableWrapper(0.0im) - - # mixed tuple in mutable container - wrapped = MutableWrapper(1.0) - tuparr = [(1.0, wrapped)] - make_zero!(tuparr) - @test tuparr[1] === (0.0, wrapped) - @test wrapped == MutableWrapper(0.0) - - # mixed namedtuple in mutable container - wrapped = MutableWrapper(1.0) - ntuparr = [(a=1.0, b=wrapped)] - make_zero!(ntuparr) - @test ntuparr[1] === (a=0.0, b=wrapped) - @test wrapped == MutableWrapper(0.0) - - # immutable struct with active, mutable, inactive and undefined fields in mutable container - wrapped = MutableWrapper(1.0) - incompletearr = [Incomplete("a", 1.0, wrapped)] - make_zero!(incompletearr) - @test incompletearr[1] == Incomplete("a", 0.0, wrapped) - @test wrapped == MutableWrapper(0.0) - - # floats in Ref - reref, imref = Ref(1.0), Ref(1.0im) - make_zero!(reref) - make_zero!(imref) - @test reref[] == 0.0 - @test imref[] == 0.0im - - # float arrays - rearr, imarr = [1.0], [1.0im] - make_zero!(rearr) - make_zero!(imarr) - @test rearr[1] == 0.0 - @test imarr[1] == 0.0im - - # non-differentiable array - constarr = ["a"] - make_zero!(constarr) - @test constarr[1] == "a" - - # array with active, mutable, inactive and unassigned elements + multiple references - wrapped = MutableWrapper(1.0) - genericarr = Vector(undef, 4) - genericarr[1:3] .= ("a", 1.0, wrapped) - genericarrs = [genericarr, genericarr] - make_zero!(genericarrs) - @test genericarrs[1] === genericarrs[2] - @test genericarrs[1] === genericarr - @test view(genericarr, 1:2) == ["a", 0.0] - @test genericarr[3] === wrapped - @test wrapped == MutableWrapper(0.0) - @test !isassigned(genericarr, 4) - - # Ref with multiple references - genericref = Ref((1.0,)) - genericrefs = [genericref, genericref] - make_zero!(genericrefs) - @test genericrefs[1] === genericrefs[2] - @test genericrefs[1] === genericref - @test genericref[] == (0.0,) - - # Ref with mutable value - wrapped = MutableWrapper(1.0) - mutref = Ref(wrapped) - make_zero!(mutref) - @test mutref[] === wrapped - @test wrapped == MutableWrapper(0.0) - - # Ref with non-differentiable value - constref = Ref("a") - make_zero!(constref) - @test constref[] == "a" - - # Box with multiple references - box = Core.Box(1.0) - boxes = [box, box] - make_zero!(boxes) - @test boxes[1] === boxes[2] - @test boxes[1] === box - @test box.contents == 0.0 - - # Box with mutable value - wrapped = MutableWrapper(1.0) - mutbox = Core.Box(wrapped) - make_zero!(mutbox) - @test mutbox.contents === wrapped - @test wrapped == MutableWrapper(0.0) - - # Box with non-differentiable value - constbox = Core.Box("a") - make_zero!(constbox) - @test constbox.contents == "a" - - # mutable struct with inactive, active, const active, undefined, and mutable fields - # + multiple references - wrapped = MutableWrapper(1.0) - incompletemut = MutableIncomplete("a", #=const=#1.0, 1.0, wrapped) - incompletemuts = [incompletemut, incompletemut] - make_zero!(incompletemuts) - @test incompletemuts[1] === incompletemuts[2] - @test incompletemuts[1] === incompletemut - @test incompletemut == MutableIncomplete("a", #=const=#0.0, 0.0, MutableWrapper(0.0)) - @test incompletemut.w === wrapped - - # wrapped differentiable array - arr = [1.0] - arrwrapped = MutableWrapper(arr) - make_zero!(arrwrapped) - @test arrwrapped.x === arr - @test arr == [0.0] - - # early error on active/mixed type - @test_throws ErrorException make_zero!(1.0) - @test_throws ErrorException make_zero!((1.0, MutableWrapper(1.0))) - - # immutable struct with both active and undefined fields in immutable container - # (currently fails due to #1935) - wrapped = MutableWrapper(1.0) - incompletetuparr = [(Incomplete("a", 1.0, wrapped),)] - make_zero!(incompletetuparr) - @test incompletetuparr[1][1] == Incomplete("a", 0.0, MutableWrapper(0.0)) - @test incompletetuparr[1][1].w === wrapped - - # containing IO (issue #2091) - f = WithIO([1.0, 2.0], stdout) - fwrapped = [f] - @test_noerr make_zero!(fwrapped) - @test fwrapped[1] === f - @test fwrapped[1].v == [0.0, 0.0] +function test_make_zero() + @testset "scalars" begin + @testset "$T" for T in scalartypes + x = oneunit(T) + x_makez = make_zero(x) + @test typeof(x_makez) === T # correct type + @test x_makez == zero(T) # correct value + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + end + @testset "nested types" begin + @testset "$T in $(wrapper.name)" for + T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) + x = oneunit(T) + w = wrapper.f(x) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === T # correct type + @test getx(w_makez) == zero(T) # correct value + @test getx(w) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + @testset "doubly included in $(dualwrapper.name)" for + dualwrapper in filter(w -> (w.N == 2), wrappers) + w_inner = wrapper.f(x) + d_outer = dualwrapper.f(w_inner, w_inner) + d_outer_makez = make_zero(d_outer) + @test typeof(d_outer_makez) === typeof(d_outer) # correct type + @test typeof(getx(d_outer_makez)) === typeof(w_inner) # correct type + @test typeof(getx(getx(d_outer_makez))) === T # correct type + @test getx(d_outer_makez) === gety(d_outer_makez) # correct topology + @test getx(getx(d_outer_makez)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # no mutation of original + @test getx(d_outer) === w_inner # no mutation of original + @test getx(w_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + d_inner = dualwrapper.f(x, x) + w_outer = wrapper.f(d_inner) + w_outer_makez = make_zero(w_outer) + @test typeof(w_outer_makez) === typeof(w_outer) # correct type + @test typeof(getx(w_outer_makez)) === typeof(d_inner) # correct type + @test typeof(getx(getx(w_outer_makez))) === T # correct type + @test getx(getx(w_outer_makez)) == gety(getx(w_outer_makez)) # correct topology + @test getx(getx(w_outer_makez)) == zero(T) # correct value + @test getx(w_outer) === d_inner # no mutation of original + @test getx(d_inner) === gety(d_inner) # no mutation of original + @test getx(d_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + if wrapper.mutable && !dualwrapper.mutable + # some code paths can only be hit with three layers of wrapping: + # mutable(immutable(mutable(scalar))) + @testset "all wrapped in $(outerwrapper.name)" for + outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) + w_inner = wrapper.f(x) + d_middle = dualwrapper.f(w_inner, w_inner) + w_outer = outerwrapper.f(d_middle) + w_outer_makez = make_zero(w_outer) + @test typeof(w_outer_makez) === typeof(w_outer) # correct type + @test typeof(getx(w_outer_makez)) === typeof(d_middle) # correct type + @test typeof(getx(getx(w_outer_makez))) === typeof(w_inner) # correct type + @test typeof(getx(getx(getx(w_outer_makez)))) === T # correct type + @test getx(getx(w_outer_makez)) === gety(getx(w_outer_makez)) # correct topology + @test getx(getx(getx(w_outer_makez))) == zero(T) # correct value + @test getx(w_outer) === d_middle # no mutation of original + @test getx(d_middle) === gety(d_middle) # no mutation of original + @test getx(d_middle) === w_inner # no mutation of original + @test getx(w_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + end + end + end + end + @testset "inactive" begin + @testset "in $(wrapper.name)" for wrapper in wrappers + if wrapper.N == 1 + w = wrapper.f(inactivearr) + w_makez = make_zero(w) + if wrapper.typed == true + @test w_makez === w # preserved wrapper identity if guaranteed const + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + @test getx(w) === inactivearr # no mutation of original + else # wrapper.N == 2 + @testset "multiple references" begin + w = wrapper.f(inactivearr, inactivearr) + w_makez = make_zero(w) + if wrapper.typed == true + @test w_makez === w # preserved wrapper identity if guaranteed const + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === gety(w_makez) # preserved topology + @test getx(w_makez) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + @test getx(w) === gety(w) # no mutation of original + @test getx(w) === inactivearr # no mutation of original + end + @testset "alongside active" begin + a = [1.0] + w = wrapper.f(a, inactivearr) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === typeof(a) # correct type + @test getx(w_makez) == [0.0] # correct value + @test gety(w_makez) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + @test getx(w) === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + @test gety(w) === inactivearr # no mutation of original + if wrapper.typed == :partial + # above: untyped active / typed inactive + # below: untyped inactive / typed active + w = wrapper.f(inactivearr, a) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + @test typeof(gety(w_makez)) === typeof(a) # correct type + @test gety(w_makez) == [0.0] # correct value + @test getx(w) === inactivearr # no mutation of original + @test gety(w) === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + end + end + @testset "copy_if_inactive $value" for (value, args) in [ + ("unspecified", ()), + ("= false", (Val(false),)), + ("= true", (Val(true),)), + ] + a = [1.0] + w = Any[a, inactivearr, inactivearr] + w_makez = make_zero(w, args...) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(w_makez[1]) === typeof(a) # correct type + @test w_makez[1] == [0.0] # correct value + @test w_makez[2] === w_makez[3] # correct topology (topology should propagate even when copy_if_inactive = Val(true)) + @test w[1] === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + @test w[2] === w[3] # no mutation of original + @test w[2] === inactivearr # no mutation of original + @test inactivearr[1] === inactivetup # no mutation of original + if args == (Val(true),) + @test typeof(w_makez[2]) === typeof(inactivearr) # correct type + @test w_makez[2] == inactivearr # correct value + @test w_makez[2][1] !== inactivetup # correct identity + else + @test w_makez[2] === inactivearr # correct value/type/identity + end + end + end + @testset "heterogeneous containers" begin + scalars, scalarsz = oneunit.(scalartypes), zero.(scalartypes) + wraps, wrapsz = Wrapper.(scalars), Wrapper.(scalarsz) + mwraps, mwrapsz = MutableWrapper.(scalars), MutableWrapper.(scalarsz) + items = (inactivetup..., scalars..., wraps..., mwraps...) + itemsz = (inactivetup..., scalarsz..., wrapsz..., mwrapsz...) + labels = Symbol.("i" .* string.(1:length(items))) + @testset "$name" for (name, c, cz) in [ + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] + c_makez = make_zero(c) + @test typeof(c_makez) === typeof(c) # correct type + @test all(typeof(czj) === typeof(cj) for (czj, cj) in zip(c_makez, c)) # correct type + @test c_makez == cz # correct value + @test all(czj === inj for (czj, inj) in zip(c_makez, inactivetup)) # preserved inactive identities + @test all(cj === itj for (cj, itj) in zip(c, items)) # no mutation of original + @test all(m.x == oneunit(m.x) for m in mwraps) # no mutation of original + end + end + @testset "circular references" begin + @testset "$(wrapper.name)" for wrapper in ( + filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) + ) + a = [1.0] + if wrapper.N == 1 + w = wrapper.f(nothing) + setx!(w, (w, a)) + else + w = wrapper.f(nothing, a) + setx!(w, w) + end + w_makez = @test_noerr make_zero(w) + if wrapper.N == 1 + xz, yz = getx(w_makez) + x, y = getx(w) + else + xz, yz = getx(w_makez), gety(w_makez) + x, y = getx(w), gety(w) + end + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(xz) === typeof(w) # correct type + @test typeof(yz) === typeof(a) # correct type + @test xz === w_makez # correct self-reference + @test yz == [0.0] # correct value + @test x === w # no mutation of original + @test y === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "bring your own IdDict" begin + a = [1.0] + seen = IdDict() + a_makez = make_zero(typeof(a), seen, a) + @test typeof(a_makez) === typeof(a) # correct type + @test a_makez == [0.0] # correct value + @test a[1] === 1.0 # no mutation of original + @test haskey(seen, a) # original added to IdDict + @test seen[a] === a_makez # original points to zeroed value + end + @testset "custom leaf type" begin + a = [1.0] + v = CustomVector(a) + # include optional arg Val(false) to avoid calling the custom method directly; + # it should still be invoked + v_makez = @test_logs (:info, "make_zero(::CustomVector)") make_zero(v, Val(false)) + @test typeof(v_makez) === typeof(v) # correct type + @test typeof(v_makez.data) === typeof(a) # correct type + @test v_makez == CustomVector([0.0]) # correct value + @test v.data === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + @testset "undefined fields/unassigned elements" begin + @testset "array w inactive/active/mutable/unassigned" begin + a = [1.0] + values = ("a", 1.0, a) + arr = Vector{Any}(undef, 4) + arr[1:3] .= values + arr_makez = make_zero(arr) + @views begin + @test typeof(arr_makez) === typeof(arr) # correct type + @test all(typeof.(arr_makez[1:3]) .=== typeof.(values)) # correct type + @test arr_makez[1:3] == ["a", 0.0, [0.0]] # correct value + @test !isassigned(arr_makez, 4) # propagated undefined + @test all(arr[1:3] .=== values) # no mutation of original + @test !isassigned(arr, 4) # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "struct w inactive/active/mutable/undefined" begin + a = [1.0] + incomplete = Incomplete("a", 1.0, a) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == Incomplete("a", 0.0, [0.0]) # correct value, propagated undefined + @test a[1] === 1.0 # no mutation of original + end + @testset "mutable struct w inactive/const active/active/mutable/undefined" begin + a = [1.0] + incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, propagated undefined + @test incomplete == MutableIncomplete("a", 1.0, 1.0, a) # no mutation of original + @test incomplete.w === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "containing IO" begin # issue #2091 + f = WithIO([1.0, 2.0], stdout) + df = @test_noerr make_zero(f) + @test df.v == [0.0, 0.0] + @test df.callback === f.callback + end + return nothing end + +function test_make_zero!() + @testset "nested types" begin + @testset "$T in $(wrapper.name)" for + T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) + x = oneunit(T) + if wrapper.mutable + w = wrapper.f(x) + make_zero!(w) + @test typeof(getx(w)) === T # preserved type + @test getx(w) == zero(T) # correct value + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + end + @testset "doubly included in $(dualwrapper.name)" for dualwrapper in ( + filter(w -> ((w.N == 2) && (w.mutable || wrapper.mutable)), wrappers) + ) + w_inner = wrapper.f(x) + d_outer = dualwrapper.f(w_inner, w_inner) + make_zero!(d_outer) + @test typeof(getx(d_outer)) === typeof(w_inner) # preserved type + @test typeof(getx(getx(d_outer))) === T # preserved type + @test getx(getx(d_outer)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # preserved topology + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + if wrapper.mutable + @test getx(d_outer) === w_inner # preserved identity + end + d_inner = dualwrapper.f(x, x) + w_outer = wrapper.f(d_inner) + make_zero!(w_outer) + @test typeof(getx(w_outer)) === typeof(d_inner) # preserved type + @test typeof(getx(getx(w_outer))) === T # preserved type + @test getx(getx(w_outer)) == zero(T) # correct value + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + if dualwrapper.mutable + @test getx(w_outer) === d_inner # preserved identity + end + if wrapper.mutable && !dualwrapper.mutable + # some code paths can only be hit with three layers of wrapping: + # mutable(immutable(mutable(scalar))) + @assert !dualwrapper.mutable # sanity check + @testset "all wrapped in $(outerwrapper.name)" for + outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) + w_inner = wrapper.f(x) + d_middle = dualwrapper.f(w_inner, w_inner) + w_outer = outerwrapper.f(d_middle) + make_zero!(w_outer) + @test typeof(getx(w_outer)) === typeof(d_middle) # preserved type + @test typeof(getx(getx(w_outer))) === typeof(w_inner) # preserved type + @test typeof(getx(getx(getx(w_outer)))) === T # preserved type + @test getx(getx(getx(w_outer))) == zero(T) # correct value + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology + @test getx(getx(w_outer)) === w_inner # preserved identity + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + end + end + end + end + end + @testset "inactive" begin + @testset "in $(wrapper.name)" for + wrapper in filter(w -> (w.mutable || (w.typed == true)), wrappers) + if wrapper.N == 1 + w = wrapper.f(inactivearr) + make_zero!(w) + @test getx(w) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + else # wrapper.N == 2 + @testset "multiple references" begin + w = wrapper.f(inactivearr, inactivearr) + make_zero!(w) + @test getx(w) === gety(w) # preserved topology + @test getx(w) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + end + @testset "alongside active" begin + a = [1.0] + w = wrapper.f(a, inactivearr) + make_zero!(w) + @test getx(w) === a # preserved identity + @test a[1] === 0.0 # correct value + @test gety(w) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + end + end + end + end + @testset "heterogeneous containers" begin + mwraps = MutableWrapper.(oneunit.(scalartypes)) + mwrapsz = MutableWrapper.(zero.(scalartypes)) + items = (inactivetup..., mwraps...) + itemsz = (inactivetup..., mwrapsz...) + labels = Symbol.("i" .* string.(1:length(items))) + @testset "$name" for (name, c, cz) in [ + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] + make_zero!(c) + @test all(cj === itj for (cj, itj) in zip(c, items)) # preserved identities + @test c == cz # correct value + end + end + @testset "circular references" begin + @testset "$(wrapper.name)" for wrapper in ( + filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) + ) + a = [1.0] + if wrapper.N == 1 + w = wrapper.f(nothing) + setx!(w, (w, a)) + else + w = wrapper.f(nothing, a) + setx!(w, w) + end + @test_noerr make_zero!(w) + if wrapper.N == 1 + x, y = getx(w) + else + x, y = getx(w), gety(w) + end + @test x === w # preserved self-referential identity + @test y === a # preserved identity + @test a[1] === 0.0 # correct value + end + end + @testset "bring your own IdSet" begin + a = [1.0] + seen = Base.IdSet() + make_zero!(a, seen) + @test a[1] === 0.0 # correct value + @test (a in seen) # object added to IdSet + end + @testset "custom leaf type" begin + a = [1.0] + v = CustomVector(a) + # bringing own IdSet to avoid calling the custom method directly; + # it should still be invoked + @test_logs (:info, "make_zero!(::CustomVector)") make_zero!(v, Base.IdSet()) + @test v.data === a # preserved identity + @test a[1] === 0.0 # correct value + end + @testset "undefined fields/unassigned elements" begin + @testset "array w inactive/active/mutable/unassigned" begin + a = [1.0] + values = ("a", 1.0, a) + arr = Vector{Any}(undef, 4) + arr[1:3] .= values + make_zero!(arr) + @views begin + @test all(typeof.(arr[1:3]) .=== typeof.(values)) # preserved types + @test arr[1:3] == ["a", 0.0, [0.0]] # correct value + @test arr[3] === a # preserved identity + @test !isassigned(arr, 4) # preserved unassigned + end + end + @testset "struct w inactive/active/mutable/undefined" begin + a = [1.0] + incompletearr = [Incomplete("a", 1.0, a)] + make_zero!(incompletearr) + @test incompletearr == [Incomplete("a", 0.0, [0.0])] # correct value, preserved undefined + @test incompletearr[1].w === a # preserved identity + end + @testset "mutable struct w inactive/const active/active/mutable/undefined" begin + a = [1.0] + incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) + make_zero!(incomplete) + @test incomplete == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, preserved undefined + @test incomplete.w === a # preserved identity + end + @testset "Array{Tuple{struct w undefined}} (issue #1935)" begin + # old implementation triggered #1935 + # new implementation would work regardless due to limited use of justActive + a = [1.0] + incomplete = Incomplete("a", 1.0, a) + incompletetuparr = [(incomplete,)] + make_zero!(incompletetuparr) + @test typeof(incompletetuparr[1]) === typeof((incomplete,)) # preserved type + @test incompletetuparr == [(Incomplete("a", 0.0, [0.0]),)] # correct value + @test incompletetuparr[1][1].w === a # preserved identity + end + end + @testset "active/mixed type error" begin + @test_throws ArgumentError make_zero!((1.0,)) + @test_throws ArgumentError make_zero!((1.0, [1.0])) + @test_throws ArgumentError make_zero!((Incomplete("a", 1.0, 1.0im),)) # issue #1935 + end + @testset "containing IO" begin # issue #2091 + f = WithIO([1.0, 2.0], stdout) + fwrapped = [f] + @test_noerr make_zero!(fwrapped) + @test fwrapped[1] === f + @test fwrapped[1].v == [0.0, 0.0] + end + return nothing +end + +@testset "make_zero" test_make_zero() +@testset "make_zero!" test_make_zero!() + +end # module MakeZeroTests From 0d0022411d7e047cf0994f9b9327607baef7920c Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Wed, 27 Nov 2024 14:50:07 -0800 Subject: [PATCH 4/6] Improve make_zero! error message --- src/make_zero.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/make_zero.jl b/src/make_zero.jl index 562a8afe7d..5fa1c8bd47 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -577,7 +577,8 @@ end elseif activitystate == DupState EnzymeCore.make_zero!(xi, seen) else - throw(ArgumentError("$xi of type $SBT cannot be zeroed in-place")) + msg = "cannot set $xi to zero in-place, as it contains differentiable values in immutable positions" + throw(ArgumentError(msg)) end end end From cc8824076076f9fec11fb3e9c1a8cd93659fbdcc Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Wed, 27 Nov 2024 18:36:08 -0800 Subject: [PATCH 5/6] Simplify likely dead branch --- src/make_zero.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/make_zero.jl b/src/make_zero.jl index 5fa1c8bd47..5c7b49a749 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -231,8 +231,7 @@ end return y end if nf == 0 - # nothing to do, assume inactive - return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev + return prev end flds = Vector{Any}(undef, nf) for i = 1:nf From 636b025dc9061a9718702b4621587641e1655d23 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Wed, 27 Nov 2024 18:41:45 -0800 Subject: [PATCH 6/6] Reinstate single-arg StaticArrays methods --- ext/EnzymeStaticArraysExt.jl | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index 7e1d131116..ef955ebd9b 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -33,12 +33,23 @@ end end @inline function Enzyme.EnzymeCore.make_zero( - ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} + prev::FT +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T}} + return Base.zero(prev)::FT +end +@inline function Enzyme.EnzymeCore.make_zero( + prev::FT +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} + return Base.zero(prev)::FT +end + +@inline function Enzyme.EnzymeCore.make_zero( + ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false) ) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive} return Base.zero(prev)::FT end @inline function Enzyme.EnzymeCore.make_zero( - ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} + ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false) ) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T},copy_if_inactive} if haskey(seen, prev) return seen[prev] @@ -47,6 +58,7 @@ end seen[prev] = new return new end + @inline function Enzyme.EnzymeCore.make_zero!( prev::FT, seen ) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} @@ -62,7 +74,8 @@ end @inline function Enzyme.EnzymeCore.make_zero!( prev::FT ) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} - return Enzyme.EnzymeCore.make_zero!(prev, nothing) + Enzyme.EnzymeCore.make_zero!(prev, nothing) + return nothing end end