diff --git a/Project.toml b/Project.toml index 9a56b55..abdd078 100644 --- a/Project.toml +++ b/Project.toml @@ -12,9 +12,11 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] +TaylorDiffNNlibExt = ["NNlib"] TaylorDiffSFExt = ["SpecialFunctions"] [compat] @@ -22,6 +24,7 @@ ChainRules = "1" ChainRulesCore = "1" ChainRulesOverloadGeneration = "0.1" SpecialFunctions = "2" +NNlib = "0.9" SymbolicUtils = "1" Symbolics = "5" Zygote = "0.6.55" diff --git a/ext/TaylorDiffNNlibExt.jl b/ext/TaylorDiffNNlibExt.jl new file mode 100644 index 0000000..02e728e --- /dev/null +++ b/ext/TaylorDiffNNlibExt.jl @@ -0,0 +1,18 @@ +module TaylorDiffNNlibExt + +using TaylorDiff +import NNlib: oftf +import NNlib: sigmoid_fast, tanh_fast, rrelu, leakyrelu + +@inline sigmoid_fast(t::TaylorScalar) = one(t) / (one(t) + exp(-t)) + +@inline tanh_fast(t::TaylorScalar) = tanh(t) + +@inline function rrelu(t::TaylorScalar{T, N}, + l = oftf(t, 1 / 8), + u = oftf(t, 1 / 3)) where {T, N} + a = (u - l) * rand(float(T)) + l + return leakyrelu(t, a) +end + +end diff --git a/ext/TaylorDiffSFExt.jl b/ext/TaylorDiffSFExt.jl index c2059da..41a1d16 100644 --- a/ext/TaylorDiffSFExt.jl +++ b/ext/TaylorDiffSFExt.jl @@ -8,7 +8,8 @@ using ChainRules, ChainRulesCore dummy = (NoTangent(), 1) @variables z -for func in (erf,) +# logerfc, logerfcx, erfinv, gamma, digamma, trigamma +for func in (erf, erfc, erfcinv, erfcx, erfi) F = typeof(func) # base case @eval function (op::$F)(t::TaylorScalar{T, 2}) where {T} diff --git a/src/codegen.jl b/src/codegen.jl index 6186944..6a5c003 100644 --- a/src/codegen.jl +++ b/src/codegen.jl @@ -11,7 +11,8 @@ for func in (+, -, deg2rad, rad2deg, asin, acos, atan, asec, acsc, acot, log, log10, log1p, log2, asinh, acosh, atanh, asech, acsch, - acoth) + acoth, + abs, sign) F = typeof(func) # base case @eval function (op::$F)(t::TaylorScalar{T, 2}) where {T} diff --git a/src/derivative.jl b/src/derivative.jl index 6010a8e..538313a 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -38,7 +38,7 @@ end # Core APIs # Added to help Zygote infer types -make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1)) +make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, convert(T, t1)) @inline function derivative(f, x::T, ::Val{N}) where {T <: TN, N} t = TaylorScalar{T, N}(x, one(x)) diff --git a/src/scalar.jl b/src/scalar.jl index b0b7afd..4bba134 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -88,3 +88,7 @@ for op in (:+, :-, :*, :/) @eval @inline $op(a::Number, b::TaylorScalar) = $op(promote(a, b)...) end transpose(t::TaylorScalar) = t + +function Base.AbstractFloat(x::TaylorScalar{T, N}) where {T, N} + TaylorScalar{Float64, N}(convert(NTuple{N, Float64}, x.value)) +end