Skip to content

Commit

Permalink
Remove redundant frule definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
zhujch1 committed Nov 7, 2023
1 parent 61326ef commit 6f0fcd5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 30 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
SliceMap = "82cb661a-3f19-5665-9e27-df437c7e54c8"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ChainRules = "1"
ChainRulesCore = "1"
ChainRulesOverloadGeneration = "0.1"
IrrationalConstants = "0.2"
SliceMap = "0.2"
SpecialFunctions = "2"
IrrationalConstants = "0.2"
SymbolicUtils = "1"
Zygote = "0.6.55"
julia = "1.6"
35 changes: 6 additions & 29 deletions src/codegen.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,13 @@
using ChainRules
using ChainRulesCore
using SpecialFunctions
using IrrationalConstants: sqrtπ
using Symbolics: @variables
using SymbolicUtils, SymbolicUtils.Code
using SymbolicUtils: BasicSymbolic, Pow

@scalar_rule +(x::BasicSymbolic) true
@scalar_rule -(x::BasicSymbolic) -1
@scalar_rule deg2rad(x::BasicSymbolic) deg2rad(one(x))
@scalar_rule rad2deg(x::BasicSymbolic) rad2deg(one(x))
@scalar_rule asin(x::BasicSymbolic) inv(sqrt(1 - x^2))
@scalar_rule acos(x::BasicSymbolic) inv(-sqrt(1 - x^2))
@scalar_rule atan(x::BasicSymbolic) inv(-(1 + x^2))
@scalar_rule acot(x::BasicSymbolic) inv(-(1 + x^2))
@scalar_rule acsc(x::BasicSymbolic) inv(x^2 * -sqrt(1 - x^-2))
@scalar_rule asec(x::BasicSymbolic) inv(x^2 * sqrt(1 - x^-2))
@scalar_rule log(x::BasicSymbolic) inv(x)
@scalar_rule log10(x::BasicSymbolic) inv(log(10.0) * x)
@scalar_rule log1p(x::BasicSymbolic) inv(x + 1)
@scalar_rule log2(x::BasicSymbolic) inv(log(2.0) * x)
@scalar_rule sinh(x::BasicSymbolic) cosh(x)
@scalar_rule cosh(x::BasicSymbolic) sinh(x)
@scalar_rule tanh(x::BasicSymbolic) 1-Ω^2
@scalar_rule acosh(x::BasicSymbolic) inv(sqrt(x - 1) * sqrt(x + 1))
@scalar_rule acoth(x::BasicSymbolic) inv(1 - x^2)
@scalar_rule acsch(x::BasicSymbolic) inv(x^2 * -sqrt(1 + x^-2))
@scalar_rule asech(x::BasicSymbolic) inv(x * -sqrt(1 - x^2))
@scalar_rule asinh(x::BasicSymbolic) inv(sqrt(x^2 + 1))
@scalar_rule atanh(x::BasicSymbolic) inv(1 - x^2)
@scalar_rule erf(x::BasicSymbolic) exp(-x^2) * 2/sqrtπ
using SymbolicUtils: Pow

dummy = (NoTangent(), 1)
@syms t₁
@variables z
for func in (+, -, deg2rad, rad2deg,
sinh, cosh, tanh,
asin, acos, atan, asec, acsc, acot,
Expand All @@ -43,15 +20,15 @@ for func in (+, -, deg2rad, rad2deg,
t0, t1 = value(t)
TaylorScalar{T, 2}(frule((NoTangent(), t1), op, t0))
end
der = frule(dummy, func, t₁)[2]
der = frule(dummy, func, z)[2]
term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise)
# recursion by raising
@eval @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N}
der_expr = $(QuoteNode(toexpr(term)))
f = $func
quote
$(Expr(:meta, :inline))
t₁ = TaylorScalar{T, N - 1}(t)
z = TaylorScalar{T, N - 1}(t)
df = $der_expr
$$raiser($f(value(t)[1]), df, t)
end
Expand Down

0 comments on commit 6f0fcd5

Please sign in to comment.