From 45f01bd94ca560940328cd996cd7036276b7db9d Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Wed, 27 Nov 2024 21:27:40 -0800 Subject: [PATCH] make_zero(!) bugfixes and improved tests (#1961) * Fix make_zero(!) bugs * Add make_zero(!) tests Aiming for full coverage of both new and old implementations of make_zero(!) * Fix more make_zero(!) bugs and add more tests * Improve make_zero! error message * Simplify likely dead branch * Reinstate single-arg StaticArrays methods --- ext/EnzymeStaticArraysExt.jl | 47 ++- src/make_zero.jl | 237 +++++++----- test/abi.jl | 32 -- test/make_zero.jl | 725 +++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 5 files changed, 912 insertions(+), 130 deletions(-) create mode 100644 test/make_zero.jl diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index c2639a4c99..ef955ebd9b 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -32,11 +32,50 @@ end end end -@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:SArray} - return Base.zero(x) +@inline function Enzyme.EnzymeCore.make_zero( + prev::FT +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T}} + return Base.zero(prev)::FT end -@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:MArray} - return Base.zero(x) +@inline function Enzyme.EnzymeCore.make_zero( + prev::FT +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} + return Base.zero(prev)::FT +end + +@inline function Enzyme.EnzymeCore.make_zero( + ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false) +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive} + return Base.zero(prev)::FT +end +@inline function Enzyme.EnzymeCore.make_zero( + ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false) +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T},copy_if_inactive} + if haskey(seen, prev) + return seen[prev] + end + new = Base.zero(prev)::FT + seen[prev] = new + return new +end + +@inline function Enzyme.EnzymeCore.make_zero!( + prev::FT, seen +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end + fill!(prev, zero(T)) + return nothing +end +@inline function Enzyme.EnzymeCore.make_zero!( + prev::FT +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} + Enzyme.EnzymeCore.make_zero!(prev, nothing) + return nothing end end diff --git a/src/make_zero.jl b/src/make_zero.jl index f2fd055c61..5c7b49a749 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -1,4 +1,3 @@ - @inline function EnzymeCore.make_zero(x::FT)::FT where {FT<:AbstractFloat} return Base.zero(x) end @@ -104,7 +103,7 @@ end prev::Complex{RT}, ::Val{copy_if_inactive} = Val(false), )::Complex{RT} where {copy_if_inactive,RT<:AbstractFloat} - return RT(0) + return Complex{RT}(0) end @inline function EnzymeCore.make_zero( @@ -178,7 +177,9 @@ end prev::NamedTuple{A,RT}, ::Val{copy_if_inactive} = Val(false), )::NamedTuple{A,RT} where {copy_if_inactive,A,RT} - return NamedTuple{A,RT}(EnzymeCore.make_zero(RT, seen, RT(prev), Val(copy_if_inactive))) + prevtup = RT(prev) + TT = Core.Typeof(prevtup) # RT can be abstract + return NamedTuple{A,RT}(EnzymeCore.make_zero(TT, seen, prevtup, Val(copy_if_inactive))) end @inline function EnzymeCore.make_zero( @@ -193,9 +194,7 @@ end prev2 = prev.contents res = Core.Box() seen[prev] = res - res.contents = Base.Ref( - EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)), - ) + res.contents = EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)) return res end @@ -214,7 +213,6 @@ end @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) - if ismutable(prev) y = ccall(:jl_new_struct_uninit, Any, (Any,), RT)::RT seen[prev] = y @@ -232,11 +230,9 @@ end end return y end - if nf == 0 return prev end - flds = Vector{Any}(undef, nf) for i = 1:nf if isdefined(prev, i) @@ -254,48 +250,71 @@ end end function make_zero_immutable!(prev::T, seen::S)::T where {T<:AbstractFloat,S} - zero(T) + return zero(T) end function make_zero_immutable!( prev::Complex{T}, seen::S, )::Complex{T} where {T<:AbstractFloat,S} - zero(T) + return zero(Complex{T}) end function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} + if guaranteed_const_nongen(T, nothing) + return prev # unreachable from make_zero! + end ntuple(Val(length(T.parameters))) do i Base.@_inline_meta - make_zero_immutable!(prev[i], seen) + p = prev[i] + SBT = Core.Typeof(p) + if guaranteed_const_nongen(SBT, nothing) + p # covered by several tests even if not shown in coverage + elseif !ismutabletype(SBT) + make_zero_immutable!(p, seen) + else + EnzymeCore.make_zero!(p, seen) + p + end end end function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} - NamedTuple{a,b}(ntuple(Val(length(T.parameters))) do i + if guaranteed_const_nongen(NamedTuple{a,b}, nothing) + return prev # unreachable from make_zero! + end + NamedTuple{a,b}(ntuple(Val(length(b.parameters))) do i Base.@_inline_meta - make_zero_immutable!(prev[a[i]], seen) + p = prev[a[i]] + SBT = Core.Typeof(p) + if guaranteed_const_nongen(SBT, nothing) + p # covered by several tests even if not shown in coverage + elseif !ismutabletype(SBT) + make_zero_immutable!(p, seen) + else + EnzymeCore.make_zero!(p, seen) + p + end end) end function make_zero_immutable!(prev::T, seen::S)::T where {T,S} if guaranteed_const_nongen(T, nothing) - return prev + return prev # unreachable from make_zero! end - @assert !ismutable(prev) - - RT = Core.Typeof(prev) - @assert !Base.isabstracttype(RT) - @assert Base.isconcretetype(RT) - nf = fieldcount(RT) - + @assert !ismutabletype(T) + @assert !Base.isabstracttype(T) + @assert Base.isconcretetype(T) + nf = fieldcount(T) flds = Vector{Any}(undef, nf) for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) ST = Core.Typeof(xi) - flds[i] = if active_reg_inner(ST, (), nothing, Val(true)) == ActiveState #=justActive=# + flds[i] = if guaranteed_const_nongen(ST, nothing) + xi + elseif !ismutabletype(ST) make_zero_immutable!(xi, seen) else EnzymeCore.make_zero!(xi, seen) @@ -306,39 +325,63 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T,S} break end end - ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)::T + return ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, flds, nf)::T end @inline function EnzymeCore.make_zero!( prev::Base.RefValue{T}, seen::ST, )::Nothing where {T<:AbstractFloat,ST} - T[] = zero(T) - nothing + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end + prev[] = zero(T) + return nothing end @inline function EnzymeCore.make_zero!( prev::Base.RefValue{Complex{T}}, seen::ST, )::Nothing where {T<:AbstractFloat,ST} - T[] = zero(Complex{T}) - nothing + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end + prev[] = zero(Complex{T}) + return nothing end @inline function EnzymeCore.make_zero!( prev::Array{T,N}, seen::ST, )::Nothing where {T<:AbstractFloat,N,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(T)) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Array{Complex{T},N}, seen::ST, )::Nothing where {T<:AbstractFloat,N,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(Complex{T})) - nothing + return nothing end @static if VERSION < v"1.11-" @@ -347,16 +390,28 @@ else prev::GenericMemory{kind, T}, seen::ST, )::Nothing where {T<:AbstractFloat,kind,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(T)) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::GenericMemory{kind, Complex{T}}, seen::ST, )::Nothing where {T<:AbstractFloat,kind,ST} + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end fill!(prev, zero(Complex{T})) - nothing + return nothing end end @@ -364,90 +419,88 @@ end prev::Base.RefValue{T}, )::Nothing where {T<:AbstractFloat} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Base.RefValue{Complex{T}}, )::Nothing where {T<:AbstractFloat} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::Array{T,N})::Nothing where {T<:AbstractFloat,N} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::Array{Complex{T},N}, )::Nothing where {T<:AbstractFloat,N} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::Array{T,N}, seen::ST)::Nothing where {T,N,ST} if guaranteed_const_nongen(T, nothing) - return + return nothing end - if in(seen, prev) - return + if prev in seen + return nothing end push!(seen, prev) - for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + continue + elseif !ismutabletype(SBT) @inbounds prev[I] = make_zero_immutable!(pv, seen) - nothing else EnzymeCore.make_zero!(pv, seen) - nothing end end end - nothing + return nothing end @static if VERSION < v"1.11-" else @inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T})::Nothing where {T<:AbstractFloat,kind} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!( prev::GenericMemory{kind, Complex{T}}, )::Nothing where {T<:AbstractFloat, kind} EnzymeCore.make_zero!(prev, nothing) - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T}, seen::ST)::Nothing where {T,kind,ST} if guaranteed_const_nongen(T, nothing) - return + return nothing end - if in(seen, prev) - return + if prev in seen + return nothing end push!(seen, prev) - for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + continue + elseif !ismutabletype(SBT) @inbounds prev[I] = make_zero_immutable!(pv, seen) - nothing else EnzymeCore.make_zero!(pv, seen) - nothing end end end - nothing + return nothing end end @@ -457,82 +510,78 @@ end seen::ST, )::Nothing where {T,ST} if guaranteed_const_nongen(T, nothing) - return + return nothing end - if in(seen, prev) - return + if prev in seen + return nothing end push!(seen, prev) - pv = prev[] SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + if guaranteed_const_nongen(SBT, nothing) + return nothing + elseif !ismutabletype(SBT) prev[] = make_zero_immutable!(pv, seen) - nothing else EnzymeCore.make_zero!(pv, seen) - nothing end - nothing + return nothing end @inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST} - pv = prev.contents - T = Core.Typeof(pv) - if guaranteed_const_nongen(T, nothing) - return - end - if in(seen, prev) - return + if prev in seen + return nothing end push!(seen, prev) + pv = prev.contents SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - prev.contents = EnzymeCore.make_zero_immutable!(pv, seen) - nothing + if guaranteed_const_nongen(SBT, nothing) + return nothing + elseif !ismutabletype(SBT) + prev.contents = make_zero_immutable!(pv, seen) else EnzymeCore.make_zero!(pv, seen) - nothing end - nothing + return nothing end -@inline function EnzymeCore.make_zero!( - prev::T, - seen::S = Base.IdSet{Any}(), -)::Nothing where {T,S} +@inline function EnzymeCore.make_zero!(prev::T, seen::S)::Nothing where {T,S} if guaranteed_const_nongen(T, nothing) - return + return nothing end - if in(prev, seen) - return + if prev in seen + return nothing end @assert !Base.isabstracttype(T) @assert Base.isconcretetype(T) nf = fieldcount(T) - - if nf == 0 - return + return nothing end - push!(seen, prev) - for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) SBT = Core.Typeof(xi) - if guaranteed_const_nongen(SBT, nothing) + activitystate = active_reg_inner(SBT, (), nothing) + if activitystate == AnyState # guaranteed_const continue - end - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - setfield!(prev, i, make_zero_immutable!(xi, seen)) - nothing - else + elseif ismutabletype(T) && !ismutabletype(SBT) + yi = make_zero_immutable!(xi, seen) + if Base.isconst(T, i) + ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), prev, i-1, yi) + else + setfield!(prev, i, yi) + end + elseif activitystate == DupState EnzymeCore.make_zero!(xi, seen) - nothing + else + msg = "cannot set $xi to zero in-place, as it contains differentiable values in immutable positions" + throw(ArgumentError(msg)) end end end - return + return nothing end + +@inline EnzymeCore.make_zero!(prev) = EnzymeCore.make_zero!(prev, Base.IdSet()) diff --git a/test/abi.jl b/test/abi.jl index 20747f2aaa..b6898ac1ba 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -489,38 +489,6 @@ mulsin(x) = sin(x[1] * x[2]) @test Enzyme.autodiff(ForwardWithPrimal, () -> Enzyme.within_autodiff())[1] end -mutable struct ConstVal - x::Float64 - const y::Float64 -end - -struct WithIO{F} - v::Vector{Float64} - callback::F - function WithIO(v, io) - callback() = println(io, "hello") - return new{typeof(callback)}(v, callback) - end -end - -@testset "Make Zero" begin - v = ConstVal(2.0, 3.0) - dv = make_zero(v) - @test dv isa ConstVal - @test dv.x ≈ 0.0 - @test dv.y ≈ 0.0 - - f = WithIO([1.0, 2.0], stdout) - df = @test_nowarn try - # catch errors to get failed test instead of "exception outside of a @test" - make_zero(f) - catch e - showerror(stderr, e) - end - @test df.v == [0.0, 0.0] - @test df.callback === f.callback -end - @testset "Type inference" begin x = ones(10) @inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x)) diff --git a/test/make_zero.jl b/test/make_zero.jl new file mode 100644 index 0000000000..cbe2f2159f --- /dev/null +++ b/test/make_zero.jl @@ -0,0 +1,725 @@ +module MakeZeroTests + +using Enzyme +using StaticArrays +using Test + +# Universal getters/setters for built-in and custom containers/wrappers +getx(w::Base.RefValue) = w[] +getx(w::Core.Box) = w.contents +getx(w) = first(w) +gety(w) = last(w) + +setx!(w::Base.RefValue, x) = (w[] = x) +setx!(w::Core.Box, x) = (w.contents = x) +setx!(w, x) = (w[begin] = x) +sety!(w, y) = (w[end] = y) + +# non-isbits MArray doesn't support setindex!, so requires a little hack +function setx!(w::MArray{S,T}, x) where {S,T} + if isbitstype(T) + w[begin] = x + else + w.data = (x, Base.tail(w.data)...) + end + return x +end + +function sety!(w::MArray{S,T}, y) where {S,T} + if isbitstype(T) + w[end] = y + else + w.data = (Base.front(w.data)..., y) + end + return y +end + +struct Empty end + +mutable struct MutableEmpty end + +Base.:(==)(::MutableEmpty, ::MutableEmpty) = true + +struct Wrapper{T} + x::T +end + +Base.:(==)(a::Wrapper, b::Wrapper) = (a === b) || (a.x == b.x) +getx(a::Wrapper) = a.x + +mutable struct MutableWrapper{T} + x::T +end + +Base.:(==)(a::MutableWrapper, b::MutableWrapper) = (a === b) || (a.x == b.x) + +getx(a::MutableWrapper) = a.x +setx!(a::MutableWrapper, x) = (a.x = x) + +struct DualWrapper{Tx,Ty} + x::Tx + y::Ty +end + +DualWrapper{T}(x::T, y) where {T} = DualWrapper{T,typeof(y)}(x, y) + +function Base.:(==)(a::DualWrapper, b::DualWrapper) + return (a === b) || ((a.x == b.x) && (a.y == b.y)) +end + +getx(a::DualWrapper) = a.x +gety(a::DualWrapper) = a.y + +mutable struct MutableDualWrapper{Tx,Ty} + x::Tx + y::Ty +end + +MutableDualWrapper{T}(x::T, y) where {T} = MutableDualWrapper{T,typeof(y)}(x, y) + +function Base.:(==)(a::MutableDualWrapper, b::MutableDualWrapper) + return (a === b) || ((a.x == b.x) && (a.y == b.y)) +end + +getx(a::MutableDualWrapper) = a.x +gety(a::MutableDualWrapper) = a.y + +setx!(a::MutableDualWrapper, x) = (a.x = x) +sety!(a::MutableDualWrapper, y) = (a.y = y) + +struct Incomplete{T} + s::String + x::Float64 + w::T + z # not initialized + Incomplete(s, x, w) = new{typeof(w)}(s, x, w) +end + +function Base.:(==)(a::Incomplete, b::Incomplete) + (a === b) && return true + ((a.s == b.s) && (a.x == b.x) && (a.w == b.w)) || return false + if isdefined(a, :z) && isdefined(b, :z) + (a.z == b.z) || return false + elseif isdefined(a, :z) || isdefined(b, :z) + return false + end + return true +end + +mutable struct MutableIncomplete{T} + s::String + const x::Float64 + y::Float64 + z # not initialized + w::T + function MutableIncomplete(s, x, y, w) + ret = new{typeof(w)}(s, x, y) + ret.w = w + return ret + end +end + +function Base.:(==)(a::MutableIncomplete, b::MutableIncomplete) + (a === b) && return true + if (a.s != b.s) || (a.x != b.x) || (a.y != b.y) || (a.w != b.w) + return false + end + if isdefined(a, :z) && isdefined(b, :z) + (a.z == b.z) || return false + elseif isdefined(a, :z) || isdefined(b, :z) + return false + end + return true +end + +mutable struct CustomVector{T} <: AbstractVector{T} + data::Vector{T} +end + +Base.:(==)(a::CustomVector, b::CustomVector) = (a === b) || (a.data == b.data) + +function Enzyme.EnzymeCore.make_zero( + ::Type{CV}, seen::IdDict, prev::CV, ::Val{copy_if_inactive} +) where {CV<:CustomVector{<:AbstractFloat},copy_if_inactive} + @info "make_zero(::CustomVector)" + if haskey(seen, prev) + return seen[prev] + end + new = CustomVector(zero(prev.data))::CV + seen[prev] = new + return new +end + +function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}, seen)::Nothing + @info "make_zero!(::CustomVector)" + if !isnothing(seen) + if prev in seen + return nothing + end + push!(seen, prev) + end + fill!(prev.data, false) + return nothing +end + +function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}) + return Enzyme.EnzymeCore.make_zero!(prev, nothing) +end + +struct WithIO{F} # issue 2091 + v::Vector{Float64} + callback::F + function WithIO(v, io) + callback() = println(io, "hello") + return new{typeof(callback)}(v, callback) + end +end + +macro test_noerr(expr) + return quote + @test_nowarn try + # catch errors to get failed test instead of "exception outside of a @test" + $(esc(expr)) + catch e + showerror(stderr, e) + end + end +end + +const scalartypes = [Float32, ComplexF32, Float64, ComplexF64] + +const inactivetup = ("a", Empty(), MutableEmpty()) +const inactivearr = [inactivetup] + +const wrappers = [ + (name="Tuple{X}", f=tuple, N=1, mutable=false, typed=true), + (name="@NamedTuple{x::X}", f=(NamedTuple{(:x,)} ∘ tuple), N=1, mutable=false, typed=true), + (name="struct{X}", f=Wrapper, N=1, mutable=false, typed=true), + + (name="@NamedTuple{x}", f=(@NamedTuple{x} ∘ tuple), N=1, mutable=false, typed=false), + (name="struct{Any}", f=Wrapper{Any}, N=1, mutable=false, typed=false), + + (name="Array{X}", f=(x -> [x]), N=1, mutable=true, typed=true), + (name="Base.RefValue{X}", f=Ref, N=1, mutable=true, typed=true), + (name="mutable struct{X}", f=MutableWrapper, N=1, mutable=true, typed=true), + + (name="Array{Any}", f=(x -> Any[x]), N=1, mutable=true, typed=false), + (name="Base.RefValue{Any}", f=Ref{Any}, N=1, mutable=true, typed=false), + (name="Core.Box", f=Core.Box, N=1, mutable=true, typed=false), + (name="mutable struct{Any}", f=MutableWrapper{Any}, N=1, mutable=true, typed=false), + + (name="Tuple{X,Y}", f=tuple, N=2, mutable=false, typed=true), + (name="@NamedTuple{x::X,y::Y}", f=(NamedTuple{(:x, :y)} ∘ tuple), N=2, mutable=false, typed=true), + (name="struct{X,Y}", f=DualWrapper, N=2, mutable=false, typed=true), + + (name="@NamedTuple{x,y::Y}", f=((x, y) -> @NamedTuple{x,y::typeof(y)}((x, y))), N=2, mutable=false, typed=:partial), + (name="struct{Any,Y}", f=DualWrapper{Any}, N=2, mutable=false, typed=:partial), + + (name="@NamedTuple{x,y}", f=@NamedTuple{x,y} ∘ tuple, N=2, mutable=false, typed=false), + (name="struct{Any}", f=DualWrapper{Any,Any}, N=2, mutable=false, typed=false), + + (name="mutable struct{X,Y}", f=MutableDualWrapper, N=2, mutable=true, typed=true), + + (name="Array{promote_type(X,Y)}", f=((x, y) -> [x, y]), N=2, mutable=true, typed=:promoted), + (name="mutable struct{Any,Y}", f=MutableDualWrapper{Any}, N=2, mutable=true, typed=:partial), + + (name="Array{Any}", f=((x, y) -> Any[x, y]), N=2, mutable=true, typed=false), + (name="mutable struct{Any,Any}", f=MutableDualWrapper{Any,Any}, N=2, mutable=true, typed=false), + + # StaticArrays extension + (name="SVector{1,X}", f=SVector{1} ∘ tuple, N=1, mutable=false, typed=true), + (name="SVector{1,Any}", f=SVector{1,Any} ∘ tuple, N=1, mutable=false, typed=false), + (name="MVector{1,X}", f=MVector{1} ∘ tuple, N=1, mutable=true, typed=true), + (name="MVector{1,Any}", f=MVector{1,Any} ∘ tuple, N=1, mutable=true, typed=false), + (name="SVector{2,promote_type(X,Y)}", f=SVector{2} ∘ tuple, N=2, mutable=false, typed=:promoted), + (name="SVector{2,Any}", f=SVector{2,Any} ∘ tuple, N=2, mutable=false, typed=false), + (name="MVector{2,promote_type(X,Y)}", f=MVector{2} ∘ tuple, N=2, mutable=true, typed=:promoted), + (name="MVector{2,Any}", f=MVector{2,Any} ∘ tuple, N=2, mutable=true, typed=false), +] + +@static if VERSION < v"1.11-" +else +_memory(x::Vector) = Memory{eltype(x)}(x) +push!( + wrappers, + (name="Memory{X}", f=(x -> _memory([x])), N=1, mutable=true, typed=true), + (name="Memory{Any}", f=(x -> _memory(Any[x])), N=1, mutable=true, typed=false), + (name="Memory{promote_type(X,Y)}", f=((x, y) -> _memory([x, y])), N=2, mutable=true, typed=:promoted), + (name="Memory{Any}", f=((x, y) -> _memory(Any[x, y])), N=2, mutable=true, typed=false), +) +end + +function test_make_zero() + @testset "scalars" begin + @testset "$T" for T in scalartypes + x = oneunit(T) + x_makez = make_zero(x) + @test typeof(x_makez) === T # correct type + @test x_makez == zero(T) # correct value + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + end + @testset "nested types" begin + @testset "$T in $(wrapper.name)" for + T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) + x = oneunit(T) + w = wrapper.f(x) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === T # correct type + @test getx(w_makez) == zero(T) # correct value + @test getx(w) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + @testset "doubly included in $(dualwrapper.name)" for + dualwrapper in filter(w -> (w.N == 2), wrappers) + w_inner = wrapper.f(x) + d_outer = dualwrapper.f(w_inner, w_inner) + d_outer_makez = make_zero(d_outer) + @test typeof(d_outer_makez) === typeof(d_outer) # correct type + @test typeof(getx(d_outer_makez)) === typeof(w_inner) # correct type + @test typeof(getx(getx(d_outer_makez))) === T # correct type + @test getx(d_outer_makez) === gety(d_outer_makez) # correct topology + @test getx(getx(d_outer_makez)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # no mutation of original + @test getx(d_outer) === w_inner # no mutation of original + @test getx(w_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + d_inner = dualwrapper.f(x, x) + w_outer = wrapper.f(d_inner) + w_outer_makez = make_zero(w_outer) + @test typeof(w_outer_makez) === typeof(w_outer) # correct type + @test typeof(getx(w_outer_makez)) === typeof(d_inner) # correct type + @test typeof(getx(getx(w_outer_makez))) === T # correct type + @test getx(getx(w_outer_makez)) == gety(getx(w_outer_makez)) # correct topology + @test getx(getx(w_outer_makez)) == zero(T) # correct value + @test getx(w_outer) === d_inner # no mutation of original + @test getx(d_inner) === gety(d_inner) # no mutation of original + @test getx(d_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + if wrapper.mutable && !dualwrapper.mutable + # some code paths can only be hit with three layers of wrapping: + # mutable(immutable(mutable(scalar))) + @testset "all wrapped in $(outerwrapper.name)" for + outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) + w_inner = wrapper.f(x) + d_middle = dualwrapper.f(w_inner, w_inner) + w_outer = outerwrapper.f(d_middle) + w_outer_makez = make_zero(w_outer) + @test typeof(w_outer_makez) === typeof(w_outer) # correct type + @test typeof(getx(w_outer_makez)) === typeof(d_middle) # correct type + @test typeof(getx(getx(w_outer_makez))) === typeof(w_inner) # correct type + @test typeof(getx(getx(getx(w_outer_makez)))) === T # correct type + @test getx(getx(w_outer_makez)) === gety(getx(w_outer_makez)) # correct topology + @test getx(getx(getx(w_outer_makez))) == zero(T) # correct value + @test getx(w_outer) === d_middle # no mutation of original + @test getx(d_middle) === gety(d_middle) # no mutation of original + @test getx(d_middle) === w_inner # no mutation of original + @test getx(w_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + end + end + end + end + @testset "inactive" begin + @testset "in $(wrapper.name)" for wrapper in wrappers + if wrapper.N == 1 + w = wrapper.f(inactivearr) + w_makez = make_zero(w) + if wrapper.typed == true + @test w_makez === w # preserved wrapper identity if guaranteed const + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + @test getx(w) === inactivearr # no mutation of original + else # wrapper.N == 2 + @testset "multiple references" begin + w = wrapper.f(inactivearr, inactivearr) + w_makez = make_zero(w) + if wrapper.typed == true + @test w_makez === w # preserved wrapper identity if guaranteed const + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === gety(w_makez) # preserved topology + @test getx(w_makez) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + @test getx(w) === gety(w) # no mutation of original + @test getx(w) === inactivearr # no mutation of original + end + @testset "alongside active" begin + a = [1.0] + w = wrapper.f(a, inactivearr) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === typeof(a) # correct type + @test getx(w_makez) == [0.0] # correct value + @test gety(w_makez) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + @test getx(w) === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + @test gety(w) === inactivearr # no mutation of original + if wrapper.typed == :partial + # above: untyped active / typed inactive + # below: untyped inactive / typed active + w = wrapper.f(inactivearr, a) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + @test typeof(gety(w_makez)) === typeof(a) # correct type + @test gety(w_makez) == [0.0] # correct value + @test getx(w) === inactivearr # no mutation of original + @test gety(w) === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + end + end + @testset "copy_if_inactive $value" for (value, args) in [ + ("unspecified", ()), + ("= false", (Val(false),)), + ("= true", (Val(true),)), + ] + a = [1.0] + w = Any[a, inactivearr, inactivearr] + w_makez = make_zero(w, args...) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(w_makez[1]) === typeof(a) # correct type + @test w_makez[1] == [0.0] # correct value + @test w_makez[2] === w_makez[3] # correct topology (topology should propagate even when copy_if_inactive = Val(true)) + @test w[1] === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + @test w[2] === w[3] # no mutation of original + @test w[2] === inactivearr # no mutation of original + @test inactivearr[1] === inactivetup # no mutation of original + if args == (Val(true),) + @test typeof(w_makez[2]) === typeof(inactivearr) # correct type + @test w_makez[2] == inactivearr # correct value + @test w_makez[2][1] !== inactivetup # correct identity + else + @test w_makez[2] === inactivearr # correct value/type/identity + end + end + end + @testset "heterogeneous containers" begin + scalars, scalarsz = oneunit.(scalartypes), zero.(scalartypes) + wraps, wrapsz = Wrapper.(scalars), Wrapper.(scalarsz) + mwraps, mwrapsz = MutableWrapper.(scalars), MutableWrapper.(scalarsz) + items = (inactivetup..., scalars..., wraps..., mwraps...) + itemsz = (inactivetup..., scalarsz..., wrapsz..., mwrapsz...) + labels = Symbol.("i" .* string.(1:length(items))) + @testset "$name" for (name, c, cz) in [ + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] + c_makez = make_zero(c) + @test typeof(c_makez) === typeof(c) # correct type + @test all(typeof(czj) === typeof(cj) for (czj, cj) in zip(c_makez, c)) # correct type + @test c_makez == cz # correct value + @test all(czj === inj for (czj, inj) in zip(c_makez, inactivetup)) # preserved inactive identities + @test all(cj === itj for (cj, itj) in zip(c, items)) # no mutation of original + @test all(m.x == oneunit(m.x) for m in mwraps) # no mutation of original + end + end + @testset "circular references" begin + @testset "$(wrapper.name)" for wrapper in ( + filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) + ) + a = [1.0] + if wrapper.N == 1 + w = wrapper.f(nothing) + setx!(w, (w, a)) + else + w = wrapper.f(nothing, a) + setx!(w, w) + end + w_makez = @test_noerr make_zero(w) + if wrapper.N == 1 + xz, yz = getx(w_makez) + x, y = getx(w) + else + xz, yz = getx(w_makez), gety(w_makez) + x, y = getx(w), gety(w) + end + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(xz) === typeof(w) # correct type + @test typeof(yz) === typeof(a) # correct type + @test xz === w_makez # correct self-reference + @test yz == [0.0] # correct value + @test x === w # no mutation of original + @test y === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "bring your own IdDict" begin + a = [1.0] + seen = IdDict() + a_makez = make_zero(typeof(a), seen, a) + @test typeof(a_makez) === typeof(a) # correct type + @test a_makez == [0.0] # correct value + @test a[1] === 1.0 # no mutation of original + @test haskey(seen, a) # original added to IdDict + @test seen[a] === a_makez # original points to zeroed value + end + @testset "custom leaf type" begin + a = [1.0] + v = CustomVector(a) + # include optional arg Val(false) to avoid calling the custom method directly; + # it should still be invoked + v_makez = @test_logs (:info, "make_zero(::CustomVector)") make_zero(v, Val(false)) + @test typeof(v_makez) === typeof(v) # correct type + @test typeof(v_makez.data) === typeof(a) # correct type + @test v_makez == CustomVector([0.0]) # correct value + @test v.data === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + @testset "undefined fields/unassigned elements" begin + @testset "array w inactive/active/mutable/unassigned" begin + a = [1.0] + values = ("a", 1.0, a) + arr = Vector{Any}(undef, 4) + arr[1:3] .= values + arr_makez = make_zero(arr) + @views begin + @test typeof(arr_makez) === typeof(arr) # correct type + @test all(typeof.(arr_makez[1:3]) .=== typeof.(values)) # correct type + @test arr_makez[1:3] == ["a", 0.0, [0.0]] # correct value + @test !isassigned(arr_makez, 4) # propagated undefined + @test all(arr[1:3] .=== values) # no mutation of original + @test !isassigned(arr, 4) # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "struct w inactive/active/mutable/undefined" begin + a = [1.0] + incomplete = Incomplete("a", 1.0, a) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == Incomplete("a", 0.0, [0.0]) # correct value, propagated undefined + @test a[1] === 1.0 # no mutation of original + end + @testset "mutable struct w inactive/const active/active/mutable/undefined" begin + a = [1.0] + incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, propagated undefined + @test incomplete == MutableIncomplete("a", 1.0, 1.0, a) # no mutation of original + @test incomplete.w === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "containing IO" begin # issue #2091 + f = WithIO([1.0, 2.0], stdout) + df = @test_noerr make_zero(f) + @test df.v == [0.0, 0.0] + @test df.callback === f.callback + end + return nothing +end + +function test_make_zero!() + @testset "nested types" begin + @testset "$T in $(wrapper.name)" for + T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) + x = oneunit(T) + if wrapper.mutable + w = wrapper.f(x) + make_zero!(w) + @test typeof(getx(w)) === T # preserved type + @test getx(w) == zero(T) # correct value + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + end + @testset "doubly included in $(dualwrapper.name)" for dualwrapper in ( + filter(w -> ((w.N == 2) && (w.mutable || wrapper.mutable)), wrappers) + ) + w_inner = wrapper.f(x) + d_outer = dualwrapper.f(w_inner, w_inner) + make_zero!(d_outer) + @test typeof(getx(d_outer)) === typeof(w_inner) # preserved type + @test typeof(getx(getx(d_outer))) === T # preserved type + @test getx(getx(d_outer)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # preserved topology + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + if wrapper.mutable + @test getx(d_outer) === w_inner # preserved identity + end + d_inner = dualwrapper.f(x, x) + w_outer = wrapper.f(d_inner) + make_zero!(w_outer) + @test typeof(getx(w_outer)) === typeof(d_inner) # preserved type + @test typeof(getx(getx(w_outer))) === T # preserved type + @test getx(getx(w_outer)) == zero(T) # correct value + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + if dualwrapper.mutable + @test getx(w_outer) === d_inner # preserved identity + end + if wrapper.mutable && !dualwrapper.mutable + # some code paths can only be hit with three layers of wrapping: + # mutable(immutable(mutable(scalar))) + @assert !dualwrapper.mutable # sanity check + @testset "all wrapped in $(outerwrapper.name)" for + outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) + w_inner = wrapper.f(x) + d_middle = dualwrapper.f(w_inner, w_inner) + w_outer = outerwrapper.f(d_middle) + make_zero!(w_outer) + @test typeof(getx(w_outer)) === typeof(d_middle) # preserved type + @test typeof(getx(getx(w_outer))) === typeof(w_inner) # preserved type + @test typeof(getx(getx(getx(w_outer)))) === T # preserved type + @test getx(getx(getx(w_outer))) == zero(T) # correct value + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology + @test getx(getx(w_outer)) === w_inner # preserved identity + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + end + end + end + end + end + @testset "inactive" begin + @testset "in $(wrapper.name)" for + wrapper in filter(w -> (w.mutable || (w.typed == true)), wrappers) + if wrapper.N == 1 + w = wrapper.f(inactivearr) + make_zero!(w) + @test getx(w) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + else # wrapper.N == 2 + @testset "multiple references" begin + w = wrapper.f(inactivearr, inactivearr) + make_zero!(w) + @test getx(w) === gety(w) # preserved topology + @test getx(w) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + end + @testset "alongside active" begin + a = [1.0] + w = wrapper.f(a, inactivearr) + make_zero!(w) + @test getx(w) === a # preserved identity + @test a[1] === 0.0 # correct value + @test gety(w) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + end + end + end + end + @testset "heterogeneous containers" begin + mwraps = MutableWrapper.(oneunit.(scalartypes)) + mwrapsz = MutableWrapper.(zero.(scalartypes)) + items = (inactivetup..., mwraps...) + itemsz = (inactivetup..., mwrapsz...) + labels = Symbol.("i" .* string.(1:length(items))) + @testset "$name" for (name, c, cz) in [ + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] + make_zero!(c) + @test all(cj === itj for (cj, itj) in zip(c, items)) # preserved identities + @test c == cz # correct value + end + end + @testset "circular references" begin + @testset "$(wrapper.name)" for wrapper in ( + filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) + ) + a = [1.0] + if wrapper.N == 1 + w = wrapper.f(nothing) + setx!(w, (w, a)) + else + w = wrapper.f(nothing, a) + setx!(w, w) + end + @test_noerr make_zero!(w) + if wrapper.N == 1 + x, y = getx(w) + else + x, y = getx(w), gety(w) + end + @test x === w # preserved self-referential identity + @test y === a # preserved identity + @test a[1] === 0.0 # correct value + end + end + @testset "bring your own IdSet" begin + a = [1.0] + seen = Base.IdSet() + make_zero!(a, seen) + @test a[1] === 0.0 # correct value + @test (a in seen) # object added to IdSet + end + @testset "custom leaf type" begin + a = [1.0] + v = CustomVector(a) + # bringing own IdSet to avoid calling the custom method directly; + # it should still be invoked + @test_logs (:info, "make_zero!(::CustomVector)") make_zero!(v, Base.IdSet()) + @test v.data === a # preserved identity + @test a[1] === 0.0 # correct value + end + @testset "undefined fields/unassigned elements" begin + @testset "array w inactive/active/mutable/unassigned" begin + a = [1.0] + values = ("a", 1.0, a) + arr = Vector{Any}(undef, 4) + arr[1:3] .= values + make_zero!(arr) + @views begin + @test all(typeof.(arr[1:3]) .=== typeof.(values)) # preserved types + @test arr[1:3] == ["a", 0.0, [0.0]] # correct value + @test arr[3] === a # preserved identity + @test !isassigned(arr, 4) # preserved unassigned + end + end + @testset "struct w inactive/active/mutable/undefined" begin + a = [1.0] + incompletearr = [Incomplete("a", 1.0, a)] + make_zero!(incompletearr) + @test incompletearr == [Incomplete("a", 0.0, [0.0])] # correct value, preserved undefined + @test incompletearr[1].w === a # preserved identity + end + @testset "mutable struct w inactive/const active/active/mutable/undefined" begin + a = [1.0] + incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) + make_zero!(incomplete) + @test incomplete == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, preserved undefined + @test incomplete.w === a # preserved identity + end + @testset "Array{Tuple{struct w undefined}} (issue #1935)" begin + # old implementation triggered #1935 + # new implementation would work regardless due to limited use of justActive + a = [1.0] + incomplete = Incomplete("a", 1.0, a) + incompletetuparr = [(incomplete,)] + make_zero!(incompletetuparr) + @test typeof(incompletetuparr[1]) === typeof((incomplete,)) # preserved type + @test incompletetuparr == [(Incomplete("a", 0.0, [0.0]),)] # correct value + @test incompletetuparr[1][1].w === a # preserved identity + end + end + @testset "active/mixed type error" begin + @test_throws ArgumentError make_zero!((1.0,)) + @test_throws ArgumentError make_zero!((1.0, [1.0])) + @test_throws ArgumentError make_zero!((Incomplete("a", 1.0, 1.0im),)) # issue #1935 + end + @testset "containing IO" begin # issue #2091 + f = WithIO([1.0, 2.0], stdout) + fwrapped = [f] + @test_noerr make_zero!(fwrapped) + @test fwrapped[1] === f + @test fwrapped[1].v == [0.0, 0.0] + end + return nothing +end + +@testset "make_zero" test_make_zero() +@testset "make_zero!" test_make_zero!() + +end # module MakeZeroTests diff --git a/test/runtests.jl b/test/runtests.jl index 26587c3892..e331645378 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,6 +74,7 @@ end include("abi.jl") include("typetree.jl") include("optimize.jl") +include("make_zero.jl") include("rules.jl") include("rrules.jl")