Skip to content

Commit

Permalink
make_zero(!) bugfixes and improved tests (#1961)
Browse files Browse the repository at this point in the history
* Fix make_zero(!) bugs

* Add make_zero(!) tests

Aiming for full coverage of both new and old implementations of
make_zero(!)

* Fix more make_zero(!) bugs and add more tests

* Improve make_zero! error message

* Simplify likely dead branch

* Reinstate single-arg StaticArrays methods
  • Loading branch information
danielwe authored Nov 28, 2024
1 parent 06e791e commit 45f01bd
Show file tree
Hide file tree
Showing 5 changed files with 912 additions and 130 deletions.
47 changes: 43 additions & 4 deletions ext/EnzymeStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,50 @@ end
end
end

@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:SArray}
return Base.zero(x)
@inline function Enzyme.EnzymeCore.make_zero(
prev::FT
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T}}
return Base.zero(prev)::FT
end
@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:MArray}
return Base.zero(x)
@inline function Enzyme.EnzymeCore.make_zero(
prev::FT
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
return Base.zero(prev)::FT
end

@inline function Enzyme.EnzymeCore.make_zero(
::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false)
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive}
return Base.zero(prev)::FT
end
@inline function Enzyme.EnzymeCore.make_zero(
::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false)
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T},copy_if_inactive}
if haskey(seen, prev)
return seen[prev]
end
new = Base.zero(prev)::FT
seen[prev] = new
return new
end

@inline function Enzyme.EnzymeCore.make_zero!(
prev::FT, seen
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
if !isnothing(seen)
if prev in seen
return nothing
end
push!(seen, prev)
end
fill!(prev, zero(T))
return nothing
end
@inline function Enzyme.EnzymeCore.make_zero!(
prev::FT
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
Enzyme.EnzymeCore.make_zero!(prev, nothing)
return nothing
end

end
Loading

0 comments on commit 45f01bd

Please sign in to comment.