diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index 7e1d131116..ef955ebd9b 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -33,12 +33,23 @@ end end @inline function Enzyme.EnzymeCore.make_zero( - ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} + prev::FT +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T}} + return Base.zero(prev)::FT +end +@inline function Enzyme.EnzymeCore.make_zero( + prev::FT +) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} + return Base.zero(prev)::FT +end + +@inline function Enzyme.EnzymeCore.make_zero( + ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false) ) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive} return Base.zero(prev)::FT end @inline function Enzyme.EnzymeCore.make_zero( - ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} + ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false) ) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T},copy_if_inactive} if haskey(seen, prev) return seen[prev] @@ -47,6 +58,7 @@ end seen[prev] = new return new end + @inline function Enzyme.EnzymeCore.make_zero!( prev::FT, seen ) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} @@ -62,7 +74,8 @@ end @inline function Enzyme.EnzymeCore.make_zero!( prev::FT ) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} - return Enzyme.EnzymeCore.make_zero!(prev, nothing) + Enzyme.EnzymeCore.make_zero!(prev, nothing) + return nothing end end