From 1581f432c2b97cd4408f9976ed935dea41065c15 Mon Sep 17 00:00:00 2001 From: zhujch Date: Tue, 21 Nov 2023 18:10:58 -0500 Subject: [PATCH 1/4] "NNlib Extension" --- Project.toml | 3 +++ ext/TaylorDiffNNlibExt.jl | 18 ++++++++++++++++++ src/codegen.jl | 3 ++- src/derivative.jl | 2 +- src/scalar.jl | 2 ++ 5 files changed, 26 insertions(+), 2 deletions(-) create mode 100644 ext/TaylorDiffNNlibExt.jl diff --git a/Project.toml b/Project.toml index 9a56b55..abdd078 100644 --- a/Project.toml +++ b/Project.toml @@ -12,9 +12,11 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] +TaylorDiffNNlibExt = ["NNlib"] TaylorDiffSFExt = ["SpecialFunctions"] [compat] @@ -22,6 +24,7 @@ ChainRules = "1" ChainRulesCore = "1" ChainRulesOverloadGeneration = "0.1" SpecialFunctions = "2" +NNlib = "0.9" SymbolicUtils = "1" Symbolics = "5" Zygote = "0.6.55" diff --git a/ext/TaylorDiffNNlibExt.jl b/ext/TaylorDiffNNlibExt.jl new file mode 100644 index 0000000..7d7cccf --- /dev/null +++ b/ext/TaylorDiffNNlibExt.jl @@ -0,0 +1,18 @@ +module TaylorDiffNNlibExt + +using TaylorDiff +import NNlib: oftf +import NNlib: sigmoid_fast, tanh_fast, rrelu, leakyrelu + +println("revise!") + +@inline sigmoid_fast(t::TaylorScalar) = one(t) / (one(t) + exp(-t)) + +@inline tanh_fast(t::TaylorScalar) = tanh(t) + +@inline function rrelu(t::TaylorScalar{T, N}, l=oftf(t, 1/8), u=oftf(t, 1/3)) where {T, N} + a = (u - l) * rand(float(T)) + l + return leakyrelu(t, a) +end + +end \ No newline at end of file diff --git a/src/codegen.jl b/src/codegen.jl index 6186944..1c87080 100644 --- a/src/codegen.jl +++ b/src/codegen.jl @@ -11,7 +11,8 @@ for func in (+, -, deg2rad, rad2deg, asin, acos, atan, asec, acsc, acot, log, log10, log1p, log2, asinh, acosh, atanh, asech, acsch, - acoth) + acoth, + abs, sign) F = typeof(func) # base case @eval function (op::$F)(t::TaylorScalar{T, 2}) where {T} diff --git a/src/derivative.jl b/src/derivative.jl index 6010a8e..538313a 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -38,7 +38,7 @@ end # Core APIs # Added to help Zygote infer types -make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1)) +make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, convert(T, t1)) @inline function derivative(f, x::T, ::Val{N}) where {T <: TN, N} t = TaylorScalar{T, N}(x, one(x)) diff --git a/src/scalar.jl b/src/scalar.jl index b0b7afd..5a0b840 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -88,3 +88,5 @@ for op in (:+, :-, :*, :/) @eval @inline $op(a::Number, b::TaylorScalar) = $op(promote(a, b)...) end transpose(t::TaylorScalar) = t + +Base.AbstractFloat(x::TaylorScalar{T, N}) where {T, N} = TaylorScalar{Float64, N}(convert(NTuple{N, Float64}, x.value)) \ No newline at end of file From e5ac0a165f1ec48ac6e1d50405df692167a06d7b Mon Sep 17 00:00:00 2001 From: zhujch Date: Wed, 22 Nov 2023 14:57:06 -0500 Subject: [PATCH 2/4] NNlibExt --- ext/TaylorDiffNNlibExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/TaylorDiffNNlibExt.jl b/ext/TaylorDiffNNlibExt.jl index 7d7cccf..64a32d7 100644 --- a/ext/TaylorDiffNNlibExt.jl +++ b/ext/TaylorDiffNNlibExt.jl @@ -4,7 +4,6 @@ using TaylorDiff import NNlib: oftf import NNlib: sigmoid_fast, tanh_fast, rrelu, leakyrelu -println("revise!") @inline sigmoid_fast(t::TaylorScalar) = one(t) / (one(t) + exp(-t)) From eefb212acd16b161f711ab4b0955496c0dfdf07e Mon Sep 17 00:00:00 2001 From: zhujch Date: Tue, 28 Nov 2023 17:28:28 -0500 Subject: [PATCH 3/4] format --- ext/TaylorDiffNNlibExt.jl | 7 ++++--- src/codegen.jl | 2 +- src/scalar.jl | 4 +++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/ext/TaylorDiffNNlibExt.jl b/ext/TaylorDiffNNlibExt.jl index 64a32d7..02e728e 100644 --- a/ext/TaylorDiffNNlibExt.jl +++ b/ext/TaylorDiffNNlibExt.jl @@ -4,14 +4,15 @@ using TaylorDiff import NNlib: oftf import NNlib: sigmoid_fast, tanh_fast, rrelu, leakyrelu - @inline sigmoid_fast(t::TaylorScalar) = one(t) / (one(t) + exp(-t)) @inline tanh_fast(t::TaylorScalar) = tanh(t) -@inline function rrelu(t::TaylorScalar{T, N}, l=oftf(t, 1/8), u=oftf(t, 1/3)) where {T, N} +@inline function rrelu(t::TaylorScalar{T, N}, + l = oftf(t, 1 / 8), + u = oftf(t, 1 / 3)) where {T, N} a = (u - l) * rand(float(T)) + l return leakyrelu(t, a) end -end \ No newline at end of file +end diff --git a/src/codegen.jl b/src/codegen.jl index 1c87080..6a5c003 100644 --- a/src/codegen.jl +++ b/src/codegen.jl @@ -11,7 +11,7 @@ for func in (+, -, deg2rad, rad2deg, asin, acos, atan, asec, acsc, acot, log, log10, log1p, log2, asinh, acosh, atanh, asech, acsch, - acoth, + acoth, abs, sign) F = typeof(func) # base case diff --git a/src/scalar.jl b/src/scalar.jl index 5a0b840..4bba134 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -89,4 +89,6 @@ for op in (:+, :-, :*, :/) end transpose(t::TaylorScalar) = t -Base.AbstractFloat(x::TaylorScalar{T, N}) where {T, N} = TaylorScalar{Float64, N}(convert(NTuple{N, Float64}, x.value)) \ No newline at end of file +function Base.AbstractFloat(x::TaylorScalar{T, N}) where {T, N} + TaylorScalar{Float64, N}(convert(NTuple{N, Float64}, x.value)) +end From 8c77e7247022b6eb5ca653a1310f3e53aef4f903 Mon Sep 17 00:00:00 2001 From: zhujch Date: Tue, 28 Nov 2023 17:29:05 -0500 Subject: [PATCH 4/4] support more special functions --- ext/TaylorDiffSFExt.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/TaylorDiffSFExt.jl b/ext/TaylorDiffSFExt.jl index c2059da..41a1d16 100644 --- a/ext/TaylorDiffSFExt.jl +++ b/ext/TaylorDiffSFExt.jl @@ -8,7 +8,8 @@ using ChainRules, ChainRulesCore dummy = (NoTangent(), 1) @variables z -for func in (erf,) +# logerfc, logerfcx, erfinv, gamma, digamma, trigamma +for func in (erf, erfc, erfcinv, erfcx, erfi) F = typeof(func) # base case @eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}