Skip to content

Commit

Permalink
Fix more make_zero(!) bugs and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
danielwe committed Nov 27, 2024
1 parent 6ff40b6 commit 327a2f9
Show file tree
Hide file tree
Showing 3 changed files with 798 additions and 373 deletions.
34 changes: 30 additions & 4 deletions ext/EnzymeStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,37 @@ end
end
end

@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:SArray}
return Base.zero(x)
@inline function Enzyme.EnzymeCore.make_zero(
::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive}
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive}
return Base.zero(prev)::FT
end
@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:MArray}
return Base.zero(x)
@inline function Enzyme.EnzymeCore.make_zero(
::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive}
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T},copy_if_inactive}
if haskey(seen, prev)
return seen[prev]
end
new = Base.zero(prev)::FT
seen[prev] = new
return new
end
@inline function Enzyme.EnzymeCore.make_zero!(
prev::FT, seen
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
if !isnothing(seen)
if prev in seen
return nothing
end
push!(seen, prev)
end
fill!(prev, zero(T))
return nothing
end
@inline function Enzyme.EnzymeCore.make_zero!(
prev::FT
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
return Enzyme.EnzymeCore.make_zero!(prev, nothing)
end

end
Loading

0 comments on commit 327a2f9

Please sign in to comment.