Skip to content

Commit

Permalink
"NNlib Extension"
Browse files Browse the repository at this point in the history
  • Loading branch information
zhujch1 committed Nov 28, 2023
1 parent 770c190 commit 1581f43
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 2 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
TaylorDiffNNlibExt = ["NNlib"]
TaylorDiffSFExt = ["SpecialFunctions"]

[compat]
ChainRules = "1"
ChainRulesCore = "1"
ChainRulesOverloadGeneration = "0.1"
SpecialFunctions = "2"
NNlib = "0.9"
SymbolicUtils = "1"
Symbolics = "5"
Zygote = "0.6.55"
Expand Down
18 changes: 18 additions & 0 deletions ext/TaylorDiffNNlibExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module TaylorDiffNNlibExt

using TaylorDiff
import NNlib: oftf
import NNlib: sigmoid_fast, tanh_fast, rrelu, leakyrelu

println("revise!")

@inline sigmoid_fast(t::TaylorScalar) = one(t) / (one(t) + exp(-t))

@inline tanh_fast(t::TaylorScalar) = tanh(t)

@inline function rrelu(t::TaylorScalar{T, N}, l=oftf(t, 1/8), u=oftf(t, 1/3)) where {T, N}
a = (u - l) * rand(float(T)) + l
return leakyrelu(t, a)
end

end
3 changes: 2 additions & 1 deletion src/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ for func in (+, -, deg2rad, rad2deg,
asin, acos, atan, asec, acsc, acot,
log, log10, log1p, log2,
asinh, acosh, atanh, asech, acsch,
acoth)
acoth,
abs, sign)
F = typeof(func)
# base case
@eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}
Expand Down
2 changes: 1 addition & 1 deletion src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ end
# Core APIs

# Added to help Zygote infer types
make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1))
make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, convert(T, t1))

@inline function derivative(f, x::T, ::Val{N}) where {T <: TN, N}
t = TaylorScalar{T, N}(x, one(x))
Expand Down
2 changes: 2 additions & 0 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,5 @@ for op in (:+, :-, :*, :/)
@eval @inline $op(a::Number, b::TaylorScalar) = $op(promote(a, b)...)
end
transpose(t::TaylorScalar) = t

Base.AbstractFloat(x::TaylorScalar{T, N}) where {T, N} = TaylorScalar{Float64, N}(convert(NTuple{N, Float64}, x.value))

0 comments on commit 1581f43

Please sign in to comment.