diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index c2639a4c99..7e1d131116 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -32,11 +32,37 @@ end end end -@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:SArray} - return Base.zero(x) +@inline function Enzyme.EnzymeCore.make_zero( + ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive} + return Base.zero(prev)::FT end -@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:MArray} - return Base.zero(x) +@inline function Enzyme.EnzymeCore.make_zero( + ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T},copy_if_inactive} + if haskey(seen, prev) + return seen[prev] + end + new = Base.zero(prev)::FT + seen[prev] = new + return new +end +@inline function Enzyme.EnzymeCore.make_zero!( + prev::FT, seen +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end + fill!(prev, zero(T)) + return nothing +end +@inline function Enzyme.EnzymeCore.make_zero!( + prev::FT +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} + return Enzyme.EnzymeCore.make_zero!(prev, nothing) end end diff --git a/src/make_zero.jl b/src/make_zero.jl index d744956672..562a8afe7d 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -1,4 +1,3 @@ - @inline function EnzymeCore.make_zero(x::FT)::FT where {FT<:AbstractFloat} return Base.zero(x) end @@ -178,7 +177,9 @@ end prev::NamedTuple{A,RT}, ::Val{copy_if_inactive} = Val(false), )::NamedTuple{A,RT} where {copy_if_inactive,A,RT} - return NamedTuple{A,RT}(EnzymeCore.make_zero(RT, seen, RT(prev), Val(copy_if_inactive))) + prevtup = RT(prev) + TT = Core.Typeof(prevtup) # RT can be abstract + return NamedTuple{A,RT}(EnzymeCore.make_zero(TT, seen, prevtup, Val(copy_if_inactive))) end @inline function EnzymeCore.make_zero( @@ -191,10 +192,9 @@ end return seen[prev] end prev2 = prev.contents - res = Core.Box( - EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)), - ) + res = Core.Box() seen[prev] = res + res.contents = EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)) return res end @@ -213,7 +213,6 @@ end @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) - if ismutable(prev) y = ccall(:jl_new_struct_uninit, Any, (Any,), RT)::RT seen[prev] = y @@ -231,13 +230,10 @@ end end return y end - if nf == 0 - # Unclear what types might end up here rather than in specialized methods or - # guaranteed_const_nongen, but as a last-ditch attempt try falling back to Base.zero - return Base.zero(prev)::RT + # nothing to do, assume inactive + return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev end - flds = Vector{Any}(undef, nf) for i = 1:nf if isdefined(prev, i) @@ -255,22 +251,27 @@ end end function make_zero_immutable!(prev::T, seen::S)::T where {T<:AbstractFloat,S} - zero(T) + return zero(T) end function make_zero_immutable!( prev::Complex{T}, seen::S, )::Complex{T} where {T<:AbstractFloat,S} - zero(Complex{T}) + return zero(Complex{T}) end function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} + if guaranteed_const_nongen(T, nothing) + return prev # unreachable from make_zero! + end ntuple(Val(length(T.parameters))) do i Base.@_inline_meta p = prev[i] SBT = Core.Typeof(p) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + p # covered by several tests even if not shown in coverage + elseif !ismutabletype(SBT) make_zero_immutable!(p, seen) else EnzymeCore.make_zero!(p, seen) @@ -280,11 +281,16 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} end function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} + if guaranteed_const_nongen(NamedTuple{a,b}, nothing) + return prev # unreachable from make_zero! + end NamedTuple{a,b}(ntuple(Val(length(b.parameters))) do i Base.@_inline_meta p = prev[a[i]] SBT = Core.Typeof(p) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + p # covered by several tests even if not shown in coverage + elseif !ismutabletype(SBT) make_zero_immutable!(p, seen) else EnzymeCore.make_zero!(p, seen) @@ -296,21 +302,20 @@ end function make_zero_immutable!(prev::T, seen::S)::T where {T,S} if guaranteed_const_nongen(T, nothing) - return prev # Note: unreachable from make_zero! + return prev # unreachable from make_zero! end - @assert !ismutable(prev) - - RT = Core.Typeof(prev) - @assert !Base.isabstracttype(RT) - @assert Base.isconcretetype(RT) - nf = fieldcount(RT) - + @assert !ismutabletype(T) + @assert !Base.isabstracttype(T) + @assert Base.isconcretetype(T) + nf = fieldcount(T) flds = Vector{Any}(undef, nf) for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) ST = Core.Typeof(xi) - flds[i] = if active_reg_inner(ST, (), nothing, Val(true)) == ActiveState #=justActive=# + flds[i] = if guaranteed_const_nongen(ST, nothing) + xi + elseif !ismutabletype(ST) make_zero_immutable!(xi, seen) else EnzymeCore.make_zero!(xi, seen) @@ -321,39 +326,63 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T,S} break end end - ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)::T + return ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, flds, nf)::T end @inline function EnzymeCore.make_zero!( prev::Base.RefValue{T}, seen::ST, )::Nothing where {T<:AbstractFloat,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end prev[] = zero(T) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Base.RefValue{Complex{T}}, seen::ST, )::Nothing where {T<:AbstractFloat,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end prev[] = zero(Complex{T}) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Array{T,N}, seen::ST, )::Nothing where {T<:AbstractFloat,N,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(T)) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Array{Complex{T},N}, seen::ST, )::Nothing where {T<:AbstractFloat,N,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(Complex{T})) - nothing + return nothing end @static if VERSION < v"1.11-" @@ -362,16 +391,28 @@ else prev::GenericMemory{kind, T}, seen::ST, )::Nothing where {T<:AbstractFloat,kind,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(T)) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::GenericMemory{kind, Complex{T}}, seen::ST, )::Nothing where {T<:AbstractFloat,kind,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(Complex{T})) - nothing + return nothing end end @@ -379,90 +420,88 @@ end prev::Base.RefValue{T}, )::Nothing where {T<:AbstractFloat} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Base.RefValue{Complex{T}}, )::Nothing where {T<:AbstractFloat} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::Array{T,N})::Nothing where {T<:AbstractFloat,N} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Array{Complex{T},N}, )::Nothing where {T<:AbstractFloat,N} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::Array{T,N}, seen::ST)::Nothing where {T,N,ST} if guaranteed_const_nongen(T, nothing) - return + return nothing end if prev in seen - return + return nothing end push!(seen, prev) - for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + continue + elseif !ismutabletype(SBT) @inbounds prev[I] = make_zero_immutable!(pv, seen) - nothing else EnzymeCore.make_zero!(pv, seen) - nothing end end end - nothing + return nothing end @static if VERSION < v"1.11-" else @inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T})::Nothing where {T<:AbstractFloat,kind} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::GenericMemory{kind, Complex{T}}, )::Nothing where {T<:AbstractFloat, kind} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T}, seen::ST)::Nothing where {T,kind,ST} if guaranteed_const_nongen(T, nothing) - return + return nothing end if prev in seen - return + return nothing end push!(seen, prev) - for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + continue + elseif !ismutabletype(SBT) @inbounds prev[I] = make_zero_immutable!(pv, seen) - nothing else EnzymeCore.make_zero!(pv, seen) - nothing end end end - nothing + return nothing end end @@ -472,87 +511,77 @@ end seen::ST, )::Nothing where {T,ST} if guaranteed_const_nongen(T, nothing) - return + return nothing end if prev in seen - return + return nothing end push!(seen, prev) - pv = prev[] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + return nothing + elseif !ismutabletype(SBT) prev[] = make_zero_immutable!(pv, seen) - nothing else EnzymeCore.make_zero!(pv, seen) - nothing end - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST} - pv = prev.contents - T = Core.Typeof(pv) - if guaranteed_const_nongen(T, nothing) - return - end if prev in seen - return + return nothing end push!(seen, prev) + pv = prev.contents SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + return nothing + elseif !ismutabletype(SBT) prev.contents = make_zero_immutable!(pv, seen) - nothing else EnzymeCore.make_zero!(pv, seen) - nothing end - nothing + return nothing end -@inline function EnzymeCore.make_zero!( - prev::T, - seen::S = Base.IdSet{Any}(), -)::Nothing where {T,S} +@inline function EnzymeCore.make_zero!(prev::T, seen::S)::Nothing where {T,S} if guaranteed_const_nongen(T, nothing) - return + return nothing end if prev in seen - return + return nothing end @assert !Base.isabstracttype(T) @assert Base.isconcretetype(T) nf = fieldcount(T) - - if nf == 0 - error("cannot zero $T in-place: it is apparently differentiable but has no fields") + return nothing end - push!(seen, prev) - for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) SBT = Core.Typeof(xi) - if guaranteed_const_nongen(SBT, nothing) + activitystate = active_reg_inner(SBT, (), nothing) + if activitystate == AnyState # guaranteed_const continue - end - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + elseif ismutabletype(T) && !ismutabletype(SBT) yi = make_zero_immutable!(xi, seen) if Base.isconst(T, i) ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), prev, i-1, yi) else setfield!(prev, i, yi) end - nothing - else + elseif activitystate == DupState EnzymeCore.make_zero!(xi, seen) - nothing + else + throw(ArgumentError("$xi of type $SBT cannot be zeroed in-place")) end end end - return + return nothing end + +@inline EnzymeCore.make_zero!(prev) = EnzymeCore.make_zero!(prev, Base.IdSet()) diff --git a/test/make_zero.jl b/test/make_zero.jl index 0f04db2fdb..cbe2f2159f 100644 --- a/test/make_zero.jl +++ b/test/make_zero.jl @@ -1,11 +1,91 @@ +module MakeZeroTests + using Enzyme +using StaticArrays using Test +# Universal getters/setters for built-in and custom containers/wrappers +getx(w::Base.RefValue) = w[] +getx(w::Core.Box) = w.contents +getx(w) = first(w) +gety(w) = last(w) + +setx!(w::Base.RefValue, x) = (w[] = x) +setx!(w::Core.Box, x) = (w.contents = x) +setx!(w, x) = (w[begin] = x) +sety!(w, y) = (w[end] = y) + +# non-isbits MArray doesn't support setindex!, so requires a little hack +function setx!(w::MArray{S,T}, x) where {S,T} + if isbitstype(T) + w[begin] = x + else + w.data = (x, Base.tail(w.data)...) + end + return x +end + +function sety!(w::MArray{S,T}, y) where {S,T} + if isbitstype(T) + w[end] = y + else + w.data = (Base.front(w.data)..., y) + end + return y +end + +struct Empty end + +mutable struct MutableEmpty end + +Base.:(==)(::MutableEmpty, ::MutableEmpty) = true + +struct Wrapper{T} + x::T +end + +Base.:(==)(a::Wrapper, b::Wrapper) = (a === b) || (a.x == b.x) +getx(a::Wrapper) = a.x + mutable struct MutableWrapper{T} x::T end -Base.:(==)(a::MutableWrapper, b::MutableWrapper) = (a === b) || isequal(a.x, b.x) +Base.:(==)(a::MutableWrapper, b::MutableWrapper) = (a === b) || (a.x == b.x) + +getx(a::MutableWrapper) = a.x +setx!(a::MutableWrapper, x) = (a.x = x) + +struct DualWrapper{Tx,Ty} + x::Tx + y::Ty +end + +DualWrapper{T}(x::T, y) where {T} = DualWrapper{T,typeof(y)}(x, y) + +function Base.:(==)(a::DualWrapper, b::DualWrapper) + return (a === b) || ((a.x == b.x) && (a.y == b.y)) +end + +getx(a::DualWrapper) = a.x +gety(a::DualWrapper) = a.y + +mutable struct MutableDualWrapper{Tx,Ty} + x::Tx + y::Ty +end + +MutableDualWrapper{T}(x::T, y) where {T} = MutableDualWrapper{T,typeof(y)}(x, y) + +function Base.:(==)(a::MutableDualWrapper, b::MutableDualWrapper) + return (a === b) || ((a.x == b.x) && (a.y == b.y)) +end + +getx(a::MutableDualWrapper) = a.x +gety(a::MutableDualWrapper) = a.y + +setx!(a::MutableDualWrapper, x) = (a.x = x) +sety!(a::MutableDualWrapper, y) = (a.y = y) struct Incomplete{T} s::String @@ -17,9 +97,9 @@ end function Base.:(==)(a::Incomplete, b::Incomplete) (a === b) && return true - (isequal(a.s, b.s) && isequal(a.x, b.x) && isequal(a.w, b.w)) || return false + ((a.s == b.s) && (a.x == b.x) && (a.w == b.w)) || return false if isdefined(a, :z) && isdefined(b, :z) - isequal(a.z, b.z) || return false + (a.z == b.z) || return false elseif isdefined(a, :z) || isdefined(b, :z) return false end @@ -41,17 +121,51 @@ end function Base.:(==)(a::MutableIncomplete, b::MutableIncomplete) (a === b) && return true - if !isequal(a.s, b.s) || !isequal(a.x, b.x) || !isequal(a.y, b.y) || !isequal(a.w, b.w) + if (a.s != b.s) || (a.x != b.x) || (a.y != b.y) || (a.w != b.w) return false end if isdefined(a, :z) && isdefined(b, :z) - isequal(a.z, b.z) || return false + (a.z == b.z) || return false elseif isdefined(a, :z) || isdefined(b, :z) return false end return true end +mutable struct CustomVector{T} <: AbstractVector{T} + data::Vector{T} +end + +Base.:(==)(a::CustomVector, b::CustomVector) = (a === b) || (a.data == b.data) + +function Enzyme.EnzymeCore.make_zero( + ::Type{CV}, seen::IdDict, prev::CV, ::Val{copy_if_inactive} +) where {CV<:CustomVector{<:AbstractFloat},copy_if_inactive} + @info "make_zero(::CustomVector)" + if haskey(seen, prev) + return seen[prev] + end + new = CustomVector(zero(prev.data))::CV + seen[prev] = new + return new +end + +function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}, seen)::Nothing + @info "make_zero!(::CustomVector)" + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end + fill!(prev.data, false) + return nothing +end + +function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}) + return Enzyme.EnzymeCore.make_zero!(prev, nothing) +end + struct WithIO{F} # issue 2091 v::Vector{Float64} callback::F @@ -72,284 +186,540 @@ macro test_noerr(expr) end end -@testset "make_zero" begin - # floats - @test make_zero(1.0) == 0.0 - @test make_zero(1.0im) == 0.0im - - # float arrays + multiple references - rearr = [1.0] - imarr = [1.0im] - rearr0 = make_zero(rearr) - imarr0 = make_zero(imarr) - @test typeof(rearr0) === typeof(rearr) - @test typeof(imarr0) === typeof(imarr) - @test rearr == [1.0] # no mutation - @test imarr == [1.0im] # no mutation - @test rearr0 == [0.0] - @test imarr0 == [0.0im] - rearrs0 = make_zero((rearr, rearr)) - imarrs0 = make_zero((imarr, imarr)) - @test typeof(rearrs0) === typeof((rearr, rearr)) - @test typeof(imarrs0) === typeof((imarr, imarr)) - @test rearr == [1.0] # no mutation - @test imarr == [1.0im] # no mutation - @test rearrs0[1] === rearrs0[2] - @test imarrs0[1] === imarrs0[2] - @test rearrs0[1] == [0.0] - @test imarrs0[1] == [0.0im] - - # floats in structs - rewrapped = MutableWrapper(1.0) - imwrapped = MutableWrapper(1.0im) - rewrapped0 = make_zero(rewrapped) - imwrapped0 = make_zero(imwrapped) - @test typeof(rewrapped0) === typeof(rewrapped) - @test typeof(imwrapped0) === typeof(imwrapped) - @test rewrapped == MutableWrapper(1.0) # no mutation - @test imwrapped == MutableWrapper(1.0im) # no mutation - @test rewrapped0 == MutableWrapper(0.0) - @test imwrapped0 == MutableWrapper(0.0im) - - # generic array + multiple references - wrapped = MutableWrapper(1.0) - mixarr = ["a", 1.0, wrapped] - mixarr0 = make_zero(mixarr) - @test typeof(mixarr0) === typeof(mixarr) - @test view(mixarr, 1:2) == ["a", 1.0] # no mutation - @test mixarr[3] === wrapped # no mutation - @test mixarr0 == ["a", 0.0, MutableWrapper(0.0)] - mixarrs0 = make_zero((mixarr, mixarr)) - @test typeof(mixarrs0) === typeof((mixarr, mixarr)) - @test view(mixarr, 1:2) == ["a", 1.0] # no mutation - @test mixarr[3] === wrapped # no mutation - @test mixarrs0[1] === mixarrs0[2] - @test mixarrs0[1] == ["a", 0.0, MutableWrapper(0.0)] - - # non-differentiable array + copy_if_inactive - constarr = ["a"] - constarr0 = make_zero(constarr) - @test typeof(constarr0) === typeof(constarr) - @test constarr == ["a"] # no mutation - @test constarr0 === constarr - constarr0copy = make_zero(constarr, #=copy_if_inactive=#Val(true)) - @test typeof(constarr0copy) === typeof(constarr0) - @test constarr == ["a"] # no mutation - @test constarr0copy !== constarr - @test constarr0copy == constarr - - # Tuple - tup = ("a", 1.0, MutableWrapper(1.0)) - tup0 = make_zero(tup) - @test typeof(tup0) === typeof(tup) - @test tup == ("a", 1.0, MutableWrapper(1.0)) # no mutation - @test tup0 == ("a", 0.0, MutableWrapper(0.0)) - - # NamedTuple - ntup = (a="a", b=1.0, c=MutableWrapper(1.0)) - ntup0 = make_zero(ntup) - @test typeof(ntup0) === typeof(ntup) - @test ntup == (a="a", b=1.0, c=MutableWrapper(1.0)) # no mutation - @test ntup0 == (a="a", b=0.0, c=MutableWrapper(0.0)) - - # Box + multiple references - box = Core.Box(1.0) - box0 = make_zero(box) - @test typeof(box0) === typeof(box) - @test box.contents == 1.0 # no mutation - @test box0.contents == 0.0 - boxes0 = make_zero((box, box)) - @test typeof(boxes0) === typeof((box, box)) - @test box.contents == 1.0 # no mutation - @test boxes0[1] === boxes0[2] - @test boxes0[1].contents == 0.0 - - # differentiable custom type + multiple references - wrapped = MutableWrapper(1.0) - wrapped0 = make_zero(wrapped) - @test typeof(wrapped0) === typeof(wrapped) - @test wrapped == MutableWrapper(1.0) # no mutation - @test wrapped0 == MutableWrapper(0.0) - wrappeds0 = make_zero((wrapped, wrapped)) - @test typeof(wrappeds0) === typeof((wrapped, wrapped)) - @test wrapped == MutableWrapper(1.0) # no mutation - @test wrappeds0[1] === wrappeds0[2] - @test wrappeds0[1] == MutableWrapper(0.0) - - # non-differentiable custom type + copy_if_inactive - constwrapped = MutableWrapper("a") - constwrapped0 = make_zero(constwrapped) - @test typeof(constwrapped0) === typeof(constwrapped) - @test constwrapped == MutableWrapper("a") # no mutation - @test constwrapped0 === constwrapped - constwrapped0copy = make_zero(constwrapped, #=copy_if_inactive=#Val(true)) - @test typeof(constwrapped0copy) === typeof(constwrapped0) - @test constwrapped == MutableWrapper("a") # no mutation - @test constwrapped0copy !== constwrapped - @test constwrapped0copy == constwrapped - - # immutable struct with active, mutable, inactive and undefined fields - incomplete = Incomplete("a", 1.0, MutableWrapper(1.0)) - incomplete0 = make_zero(incomplete) - @test typeof(incomplete0) === typeof(incomplete) - @test incomplete == Incomplete("a", 1.0, MutableWrapper(1.0)) # no mutation - @test incomplete0 == Incomplete("a", 0.0, MutableWrapper(0.0)) - - # mutable struct with inactive, active, undefined, and mutable fields - # + multiple references - incompletemut = MutableIncomplete("a", 1.0, 1.0, MutableWrapper(1.0)) - incompletemut0 = make_zero(incompletemut) - @test typeof(incompletemut0) === typeof(incompletemut) - @test incompletemut == MutableIncomplete("a", 1.0, 1.0, MutableWrapper(1.0)) # no mutation - @test incompletemut0 == MutableIncomplete("a", 0.0, 0.0, MutableWrapper(0.0)) - incompletemuts0 = make_zero((incompletemut, incompletemut)) - @test typeof(incompletemuts0) === typeof((incompletemut, incompletemut)) - @test incompletemut == MutableIncomplete("a", 1.0, 1.0, MutableWrapper(1.0)) # no mutation - @test incompletemuts0[1] === incompletemuts0[2] - @test incompletemuts0[1] == MutableIncomplete("a", 0.0, 0.0, MutableWrapper(0.0)) - - # containing IO (issue #2091) - f = WithIO([1.0, 2.0], stdout) - df = @test_noerr make_zero(f) - @test df.v == [0.0, 0.0] - @test df.callback === f.callback +const scalartypes = [Float32, ComplexF32, Float64, ComplexF64] + +const inactivetup = ("a", Empty(), MutableEmpty()) +const inactivearr = [inactivetup] + +const wrappers = [ + (name="Tuple{X}", f=tuple, N=1, mutable=false, typed=true), + (name="@NamedTuple{x::X}", f=(NamedTuple{(:x,)} ∘ tuple), N=1, mutable=false, typed=true), + (name="struct{X}", f=Wrapper, N=1, mutable=false, typed=true), + + (name="@NamedTuple{x}", f=(@NamedTuple{x} ∘ tuple), N=1, mutable=false, typed=false), + (name="struct{Any}", f=Wrapper{Any}, N=1, mutable=false, typed=false), + + (name="Array{X}", f=(x -> [x]), N=1, mutable=true, typed=true), + (name="Base.RefValue{X}", f=Ref, N=1, mutable=true, typed=true), + (name="mutable struct{X}", f=MutableWrapper, N=1, mutable=true, typed=true), + + (name="Array{Any}", f=(x -> Any[x]), N=1, mutable=true, typed=false), + (name="Base.RefValue{Any}", f=Ref{Any}, N=1, mutable=true, typed=false), + (name="Core.Box", f=Core.Box, N=1, mutable=true, typed=false), + (name="mutable struct{Any}", f=MutableWrapper{Any}, N=1, mutable=true, typed=false), + + (name="Tuple{X,Y}", f=tuple, N=2, mutable=false, typed=true), + (name="@NamedTuple{x::X,y::Y}", f=(NamedTuple{(:x, :y)} ∘ tuple), N=2, mutable=false, typed=true), + (name="struct{X,Y}", f=DualWrapper, N=2, mutable=false, typed=true), + + (name="@NamedTuple{x,y::Y}", f=((x, y) -> @NamedTuple{x,y::typeof(y)}((x, y))), N=2, mutable=false, typed=:partial), + (name="struct{Any,Y}", f=DualWrapper{Any}, N=2, mutable=false, typed=:partial), + + (name="@NamedTuple{x,y}", f=@NamedTuple{x,y} ∘ tuple, N=2, mutable=false, typed=false), + (name="struct{Any}", f=DualWrapper{Any,Any}, N=2, mutable=false, typed=false), + + (name="mutable struct{X,Y}", f=MutableDualWrapper, N=2, mutable=true, typed=true), + + (name="Array{promote_type(X,Y)}", f=((x, y) -> [x, y]), N=2, mutable=true, typed=:promoted), + (name="mutable struct{Any,Y}", f=MutableDualWrapper{Any}, N=2, mutable=true, typed=:partial), + + (name="Array{Any}", f=((x, y) -> Any[x, y]), N=2, mutable=true, typed=false), + (name="mutable struct{Any,Any}", f=MutableDualWrapper{Any,Any}, N=2, mutable=true, typed=false), + + # StaticArrays extension + (name="SVector{1,X}", f=SVector{1} ∘ tuple, N=1, mutable=false, typed=true), + (name="SVector{1,Any}", f=SVector{1,Any} ∘ tuple, N=1, mutable=false, typed=false), + (name="MVector{1,X}", f=MVector{1} ∘ tuple, N=1, mutable=true, typed=true), + (name="MVector{1,Any}", f=MVector{1,Any} ∘ tuple, N=1, mutable=true, typed=false), + (name="SVector{2,promote_type(X,Y)}", f=SVector{2} ∘ tuple, N=2, mutable=false, typed=:promoted), + (name="SVector{2,Any}", f=SVector{2,Any} ∘ tuple, N=2, mutable=false, typed=false), + (name="MVector{2,promote_type(X,Y)}", f=MVector{2} ∘ tuple, N=2, mutable=true, typed=:promoted), + (name="MVector{2,Any}", f=MVector{2,Any} ∘ tuple, N=2, mutable=true, typed=false), +] + +@static if VERSION < v"1.11-" +else +_memory(x::Vector) = Memory{eltype(x)}(x) +push!( + wrappers, + (name="Memory{X}", f=(x -> _memory([x])), N=1, mutable=true, typed=true), + (name="Memory{Any}", f=(x -> _memory(Any[x])), N=1, mutable=true, typed=false), + (name="Memory{promote_type(X,Y)}", f=((x, y) -> _memory([x, y])), N=2, mutable=true, typed=:promoted), + (name="Memory{Any}", f=((x, y) -> _memory(Any[x, y])), N=2, mutable=true, typed=false), +) end -@testset "make_zero!" begin - # floats in mutable struct - rewrapped, imwrapped = MutableWrapper(1.0), MutableWrapper(1.0im) - make_zero!(rewrapped) - make_zero!(imwrapped) - @test rewrapped == MutableWrapper(0.0) - @test imwrapped == MutableWrapper(0.0im) - - # mixed tuple in mutable container - wrapped = MutableWrapper(1.0) - tuparr = [(1.0, wrapped)] - make_zero!(tuparr) - @test tuparr[1] === (0.0, wrapped) - @test wrapped == MutableWrapper(0.0) - - # mixed namedtuple in mutable container - wrapped = MutableWrapper(1.0) - ntuparr = [(a=1.0, b=wrapped)] - make_zero!(ntuparr) - @test ntuparr[1] === (a=0.0, b=wrapped) - @test wrapped == MutableWrapper(0.0) - - # immutable struct with active, mutable, inactive and undefined fields in mutable container - wrapped = MutableWrapper(1.0) - incompletearr = [Incomplete("a", 1.0, wrapped)] - make_zero!(incompletearr) - @test incompletearr[1] == Incomplete("a", 0.0, wrapped) - @test wrapped == MutableWrapper(0.0) - - # floats in Ref - reref, imref = Ref(1.0), Ref(1.0im) - make_zero!(reref) - make_zero!(imref) - @test reref[] == 0.0 - @test imref[] == 0.0im - - # float arrays - rearr, imarr = [1.0], [1.0im] - make_zero!(rearr) - make_zero!(imarr) - @test rearr[1] == 0.0 - @test imarr[1] == 0.0im - - # non-differentiable array - constarr = ["a"] - make_zero!(constarr) - @test constarr[1] == "a" - - # array with active, mutable, inactive and unassigned elements + multiple references - wrapped = MutableWrapper(1.0) - genericarr = Vector(undef, 4) - genericarr[1:3] .= ("a", 1.0, wrapped) - genericarrs = [genericarr, genericarr] - make_zero!(genericarrs) - @test genericarrs[1] === genericarrs[2] - @test genericarrs[1] === genericarr - @test view(genericarr, 1:2) == ["a", 0.0] - @test genericarr[3] === wrapped - @test wrapped == MutableWrapper(0.0) - @test !isassigned(genericarr, 4) - - # Ref with multiple references - genericref = Ref((1.0,)) - genericrefs = [genericref, genericref] - make_zero!(genericrefs) - @test genericrefs[1] === genericrefs[2] - @test genericrefs[1] === genericref - @test genericref[] == (0.0,) - - # Ref with mutable value - wrapped = MutableWrapper(1.0) - mutref = Ref(wrapped) - make_zero!(mutref) - @test mutref[] === wrapped - @test wrapped == MutableWrapper(0.0) - - # Ref with non-differentiable value - constref = Ref("a") - make_zero!(constref) - @test constref[] == "a" - - # Box with multiple references - box = Core.Box(1.0) - boxes = [box, box] - make_zero!(boxes) - @test boxes[1] === boxes[2] - @test boxes[1] === box - @test box.contents == 0.0 - - # Box with mutable value - wrapped = MutableWrapper(1.0) - mutbox = Core.Box(wrapped) - make_zero!(mutbox) - @test mutbox.contents === wrapped - @test wrapped == MutableWrapper(0.0) - - # Box with non-differentiable value - constbox = Core.Box("a") - make_zero!(constbox) - @test constbox.contents == "a" - - # mutable struct with inactive, active, const active, undefined, and mutable fields - # + multiple references - wrapped = MutableWrapper(1.0) - incompletemut = MutableIncomplete("a", #=const=#1.0, 1.0, wrapped) - incompletemuts = [incompletemut, incompletemut] - make_zero!(incompletemuts) - @test incompletemuts[1] === incompletemuts[2] - @test incompletemuts[1] === incompletemut - @test incompletemut == MutableIncomplete("a", #=const=#0.0, 0.0, MutableWrapper(0.0)) - @test incompletemut.w === wrapped - - # wrapped differentiable array - arr = [1.0] - arrwrapped = MutableWrapper(arr) - make_zero!(arrwrapped) - @test arrwrapped.x === arr - @test arr == [0.0] - - # early error on active/mixed type - @test_throws ErrorException make_zero!(1.0) - @test_throws ErrorException make_zero!((1.0, MutableWrapper(1.0))) - - # immutable struct with both active and undefined fields in immutable container - # (currently fails due to #1935) - wrapped = MutableWrapper(1.0) - incompletetuparr = [(Incomplete("a", 1.0, wrapped),)] - make_zero!(incompletetuparr) - @test incompletetuparr[1][1] == Incomplete("a", 0.0, MutableWrapper(0.0)) - @test incompletetuparr[1][1].w === wrapped - - # containing IO (issue #2091) - f = WithIO([1.0, 2.0], stdout) - fwrapped = [f] - @test_noerr make_zero!(fwrapped) - @test fwrapped[1] === f - @test fwrapped[1].v == [0.0, 0.0] +function test_make_zero() + @testset "scalars" begin + @testset "$T" for T in scalartypes + x = oneunit(T) + x_makez = make_zero(x) + @test typeof(x_makez) === T # correct type + @test x_makez == zero(T) # correct value + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + end + @testset "nested types" begin + @testset "$T in $(wrapper.name)" for + T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) + x = oneunit(T) + w = wrapper.f(x) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === T # correct type + @test getx(w_makez) == zero(T) # correct value + @test getx(w) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + @testset "doubly included in $(dualwrapper.name)" for + dualwrapper in filter(w -> (w.N == 2), wrappers) + w_inner = wrapper.f(x) + d_outer = dualwrapper.f(w_inner, w_inner) + d_outer_makez = make_zero(d_outer) + @test typeof(d_outer_makez) === typeof(d_outer) # correct type + @test typeof(getx(d_outer_makez)) === typeof(w_inner) # correct type + @test typeof(getx(getx(d_outer_makez))) === T # correct type + @test getx(d_outer_makez) === gety(d_outer_makez) # correct topology + @test getx(getx(d_outer_makez)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # no mutation of original + @test getx(d_outer) === w_inner # no mutation of original + @test getx(w_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + d_inner = dualwrapper.f(x, x) + w_outer = wrapper.f(d_inner) + w_outer_makez = make_zero(w_outer) + @test typeof(w_outer_makez) === typeof(w_outer) # correct type + @test typeof(getx(w_outer_makez)) === typeof(d_inner) # correct type + @test typeof(getx(getx(w_outer_makez))) === T # correct type + @test getx(getx(w_outer_makez)) == gety(getx(w_outer_makez)) # correct topology + @test getx(getx(w_outer_makez)) == zero(T) # correct value + @test getx(w_outer) === d_inner # no mutation of original + @test getx(d_inner) === gety(d_inner) # no mutation of original + @test getx(d_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + if wrapper.mutable && !dualwrapper.mutable + # some code paths can only be hit with three layers of wrapping: + # mutable(immutable(mutable(scalar))) + @testset "all wrapped in $(outerwrapper.name)" for + outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) + w_inner = wrapper.f(x) + d_middle = dualwrapper.f(w_inner, w_inner) + w_outer = outerwrapper.f(d_middle) + w_outer_makez = make_zero(w_outer) + @test typeof(w_outer_makez) === typeof(w_outer) # correct type + @test typeof(getx(w_outer_makez)) === typeof(d_middle) # correct type + @test typeof(getx(getx(w_outer_makez))) === typeof(w_inner) # correct type + @test typeof(getx(getx(getx(w_outer_makez)))) === T # correct type + @test getx(getx(w_outer_makez)) === gety(getx(w_outer_makez)) # correct topology + @test getx(getx(getx(w_outer_makez))) == zero(T) # correct value + @test getx(w_outer) === d_middle # no mutation of original + @test getx(d_middle) === gety(d_middle) # no mutation of original + @test getx(d_middle) === w_inner # no mutation of original + @test getx(w_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + end + end + end + end + @testset "inactive" begin + @testset "in $(wrapper.name)" for wrapper in wrappers + if wrapper.N == 1 + w = wrapper.f(inactivearr) + w_makez = make_zero(w) + if wrapper.typed == true + @test w_makez === w # preserved wrapper identity if guaranteed const + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + @test getx(w) === inactivearr # no mutation of original + else # wrapper.N == 2 + @testset "multiple references" begin + w = wrapper.f(inactivearr, inactivearr) + w_makez = make_zero(w) + if wrapper.typed == true + @test w_makez === w # preserved wrapper identity if guaranteed const + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === gety(w_makez) # preserved topology + @test getx(w_makez) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + @test getx(w) === gety(w) # no mutation of original + @test getx(w) === inactivearr # no mutation of original + end + @testset "alongside active" begin + a = [1.0] + w = wrapper.f(a, inactivearr) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === typeof(a) # correct type + @test getx(w_makez) == [0.0] # correct value + @test gety(w_makez) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + @test getx(w) === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + @test gety(w) === inactivearr # no mutation of original + if wrapper.typed == :partial + # above: untyped active / typed inactive + # below: untyped inactive / typed active + w = wrapper.f(inactivearr, a) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + @test typeof(gety(w_makez)) === typeof(a) # correct type + @test gety(w_makez) == [0.0] # correct value + @test getx(w) === inactivearr # no mutation of original + @test gety(w) === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + end + end + @testset "copy_if_inactive $value" for (value, args) in [ + ("unspecified", ()), + ("= false", (Val(false),)), + ("= true", (Val(true),)), + ] + a = [1.0] + w = Any[a, inactivearr, inactivearr] + w_makez = make_zero(w, args...) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(w_makez[1]) === typeof(a) # correct type + @test w_makez[1] == [0.0] # correct value + @test w_makez[2] === w_makez[3] # correct topology (topology should propagate even when copy_if_inactive = Val(true)) + @test w[1] === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + @test w[2] === w[3] # no mutation of original + @test w[2] === inactivearr # no mutation of original + @test inactivearr[1] === inactivetup # no mutation of original + if args == (Val(true),) + @test typeof(w_makez[2]) === typeof(inactivearr) # correct type + @test w_makez[2] == inactivearr # correct value + @test w_makez[2][1] !== inactivetup # correct identity + else + @test w_makez[2] === inactivearr # correct value/type/identity + end + end + end + @testset "heterogeneous containers" begin + scalars, scalarsz = oneunit.(scalartypes), zero.(scalartypes) + wraps, wrapsz = Wrapper.(scalars), Wrapper.(scalarsz) + mwraps, mwrapsz = MutableWrapper.(scalars), MutableWrapper.(scalarsz) + items = (inactivetup..., scalars..., wraps..., mwraps...) + itemsz = (inactivetup..., scalarsz..., wrapsz..., mwrapsz...) + labels = Symbol.("i" .* string.(1:length(items))) + @testset "$name" for (name, c, cz) in [ + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] + c_makez = make_zero(c) + @test typeof(c_makez) === typeof(c) # correct type + @test all(typeof(czj) === typeof(cj) for (czj, cj) in zip(c_makez, c)) # correct type + @test c_makez == cz # correct value + @test all(czj === inj for (czj, inj) in zip(c_makez, inactivetup)) # preserved inactive identities + @test all(cj === itj for (cj, itj) in zip(c, items)) # no mutation of original + @test all(m.x == oneunit(m.x) for m in mwraps) # no mutation of original + end + end + @testset "circular references" begin + @testset "$(wrapper.name)" for wrapper in ( + filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) + ) + a = [1.0] + if wrapper.N == 1 + w = wrapper.f(nothing) + setx!(w, (w, a)) + else + w = wrapper.f(nothing, a) + setx!(w, w) + end + w_makez = @test_noerr make_zero(w) + if wrapper.N == 1 + xz, yz = getx(w_makez) + x, y = getx(w) + else + xz, yz = getx(w_makez), gety(w_makez) + x, y = getx(w), gety(w) + end + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(xz) === typeof(w) # correct type + @test typeof(yz) === typeof(a) # correct type + @test xz === w_makez # correct self-reference + @test yz == [0.0] # correct value + @test x === w # no mutation of original + @test y === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "bring your own IdDict" begin + a = [1.0] + seen = IdDict() + a_makez = make_zero(typeof(a), seen, a) + @test typeof(a_makez) === typeof(a) # correct type + @test a_makez == [0.0] # correct value + @test a[1] === 1.0 # no mutation of original + @test haskey(seen, a) # original added to IdDict + @test seen[a] === a_makez # original points to zeroed value + end + @testset "custom leaf type" begin + a = [1.0] + v = CustomVector(a) + # include optional arg Val(false) to avoid calling the custom method directly; + # it should still be invoked + v_makez = @test_logs (:info, "make_zero(::CustomVector)") make_zero(v, Val(false)) + @test typeof(v_makez) === typeof(v) # correct type + @test typeof(v_makez.data) === typeof(a) # correct type + @test v_makez == CustomVector([0.0]) # correct value + @test v.data === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + @testset "undefined fields/unassigned elements" begin + @testset "array w inactive/active/mutable/unassigned" begin + a = [1.0] + values = ("a", 1.0, a) + arr = Vector{Any}(undef, 4) + arr[1:3] .= values + arr_makez = make_zero(arr) + @views begin + @test typeof(arr_makez) === typeof(arr) # correct type + @test all(typeof.(arr_makez[1:3]) .=== typeof.(values)) # correct type + @test arr_makez[1:3] == ["a", 0.0, [0.0]] # correct value + @test !isassigned(arr_makez, 4) # propagated undefined + @test all(arr[1:3] .=== values) # no mutation of original + @test !isassigned(arr, 4) # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "struct w inactive/active/mutable/undefined" begin + a = [1.0] + incomplete = Incomplete("a", 1.0, a) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == Incomplete("a", 0.0, [0.0]) # correct value, propagated undefined + @test a[1] === 1.0 # no mutation of original + end + @testset "mutable struct w inactive/const active/active/mutable/undefined" begin + a = [1.0] + incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, propagated undefined + @test incomplete == MutableIncomplete("a", 1.0, 1.0, a) # no mutation of original + @test incomplete.w === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "containing IO" begin # issue #2091 + f = WithIO([1.0, 2.0], stdout) + df = @test_noerr make_zero(f) + @test df.v == [0.0, 0.0] + @test df.callback === f.callback + end + return nothing end + +function test_make_zero!() + @testset "nested types" begin + @testset "$T in $(wrapper.name)" for + T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) + x = oneunit(T) + if wrapper.mutable + w = wrapper.f(x) + make_zero!(w) + @test typeof(getx(w)) === T # preserved type + @test getx(w) == zero(T) # correct value + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + end + @testset "doubly included in $(dualwrapper.name)" for dualwrapper in ( + filter(w -> ((w.N == 2) && (w.mutable || wrapper.mutable)), wrappers) + ) + w_inner = wrapper.f(x) + d_outer = dualwrapper.f(w_inner, w_inner) + make_zero!(d_outer) + @test typeof(getx(d_outer)) === typeof(w_inner) # preserved type + @test typeof(getx(getx(d_outer))) === T # preserved type + @test getx(getx(d_outer)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # preserved topology + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + if wrapper.mutable + @test getx(d_outer) === w_inner # preserved identity + end + d_inner = dualwrapper.f(x, x) + w_outer = wrapper.f(d_inner) + make_zero!(w_outer) + @test typeof(getx(w_outer)) === typeof(d_inner) # preserved type + @test typeof(getx(getx(w_outer))) === T # preserved type + @test getx(getx(w_outer)) == zero(T) # correct value + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + if dualwrapper.mutable + @test getx(w_outer) === d_inner # preserved identity + end + if wrapper.mutable && !dualwrapper.mutable + # some code paths can only be hit with three layers of wrapping: + # mutable(immutable(mutable(scalar))) + @assert !dualwrapper.mutable # sanity check + @testset "all wrapped in $(outerwrapper.name)" for + outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) + w_inner = wrapper.f(x) + d_middle = dualwrapper.f(w_inner, w_inner) + w_outer = outerwrapper.f(d_middle) + make_zero!(w_outer) + @test typeof(getx(w_outer)) === typeof(d_middle) # preserved type + @test typeof(getx(getx(w_outer))) === typeof(w_inner) # preserved type + @test typeof(getx(getx(getx(w_outer)))) === T # preserved type + @test getx(getx(getx(w_outer))) == zero(T) # correct value + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology + @test getx(getx(w_outer)) === w_inner # preserved identity + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + end + end + end + end + end + @testset "inactive" begin + @testset "in $(wrapper.name)" for + wrapper in filter(w -> (w.mutable || (w.typed == true)), wrappers) + if wrapper.N == 1 + w = wrapper.f(inactivearr) + make_zero!(w) + @test getx(w) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + else # wrapper.N == 2 + @testset "multiple references" begin + w = wrapper.f(inactivearr, inactivearr) + make_zero!(w) + @test getx(w) === gety(w) # preserved topology + @test getx(w) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + end + @testset "alongside active" begin + a = [1.0] + w = wrapper.f(a, inactivearr) + make_zero!(w) + @test getx(w) === a # preserved identity + @test a[1] === 0.0 # correct value + @test gety(w) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + end + end + end + end + @testset "heterogeneous containers" begin + mwraps = MutableWrapper.(oneunit.(scalartypes)) + mwrapsz = MutableWrapper.(zero.(scalartypes)) + items = (inactivetup..., mwraps...) + itemsz = (inactivetup..., mwrapsz...) + labels = Symbol.("i" .* string.(1:length(items))) + @testset "$name" for (name, c, cz) in [ + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] + make_zero!(c) + @test all(cj === itj for (cj, itj) in zip(c, items)) # preserved identities + @test c == cz # correct value + end + end + @testset "circular references" begin + @testset "$(wrapper.name)" for wrapper in ( + filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) + ) + a = [1.0] + if wrapper.N == 1 + w = wrapper.f(nothing) + setx!(w, (w, a)) + else + w = wrapper.f(nothing, a) + setx!(w, w) + end + @test_noerr make_zero!(w) + if wrapper.N == 1 + x, y = getx(w) + else + x, y = getx(w), gety(w) + end + @test x === w # preserved self-referential identity + @test y === a # preserved identity + @test a[1] === 0.0 # correct value + end + end + @testset "bring your own IdSet" begin + a = [1.0] + seen = Base.IdSet() + make_zero!(a, seen) + @test a[1] === 0.0 # correct value + @test (a in seen) # object added to IdSet + end + @testset "custom leaf type" begin + a = [1.0] + v = CustomVector(a) + # bringing own IdSet to avoid calling the custom method directly; + # it should still be invoked + @test_logs (:info, "make_zero!(::CustomVector)") make_zero!(v, Base.IdSet()) + @test v.data === a # preserved identity + @test a[1] === 0.0 # correct value + end + @testset "undefined fields/unassigned elements" begin + @testset "array w inactive/active/mutable/unassigned" begin + a = [1.0] + values = ("a", 1.0, a) + arr = Vector{Any}(undef, 4) + arr[1:3] .= values + make_zero!(arr) + @views begin + @test all(typeof.(arr[1:3]) .=== typeof.(values)) # preserved types + @test arr[1:3] == ["a", 0.0, [0.0]] # correct value + @test arr[3] === a # preserved identity + @test !isassigned(arr, 4) # preserved unassigned + end + end + @testset "struct w inactive/active/mutable/undefined" begin + a = [1.0] + incompletearr = [Incomplete("a", 1.0, a)] + make_zero!(incompletearr) + @test incompletearr == [Incomplete("a", 0.0, [0.0])] # correct value, preserved undefined + @test incompletearr[1].w === a # preserved identity + end + @testset "mutable struct w inactive/const active/active/mutable/undefined" begin + a = [1.0] + incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) + make_zero!(incomplete) + @test incomplete == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, preserved undefined + @test incomplete.w === a # preserved identity + end + @testset "Array{Tuple{struct w undefined}} (issue #1935)" begin + # old implementation triggered #1935 + # new implementation would work regardless due to limited use of justActive + a = [1.0] + incomplete = Incomplete("a", 1.0, a) + incompletetuparr = [(incomplete,)] + make_zero!(incompletetuparr) + @test typeof(incompletetuparr[1]) === typeof((incomplete,)) # preserved type + @test incompletetuparr == [(Incomplete("a", 0.0, [0.0]),)] # correct value + @test incompletetuparr[1][1].w === a # preserved identity + end + end + @testset "active/mixed type error" begin + @test_throws ArgumentError make_zero!((1.0,)) + @test_throws ArgumentError make_zero!((1.0, [1.0])) + @test_throws ArgumentError make_zero!((Incomplete("a", 1.0, 1.0im),)) # issue #1935 + end + @testset "containing IO" begin # issue #2091 + f = WithIO([1.0, 2.0], stdout) + fwrapped = [f] + @test_noerr make_zero!(fwrapped) + @test fwrapped[1] === f + @test fwrapped[1].v == [0.0, 0.0] + end + return nothing +end + +@testset "make_zero" test_make_zero() +@testset "make_zero!" test_make_zero!() + +end # module MakeZeroTests