Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ext #66

Merged
merged 3 commits into from
Nov 9, 2023
Merged

Ext #66

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,20 @@ version = "0.2.1"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
TaylorDiffSFExt = ["SpecialFunctions"]

[compat]
ChainRules = "1"
ChainRulesCore = "1"
ChainRulesOverloadGeneration = "0.1"
IrrationalConstants = "0.2"
SpecialFunctions = "2"
SymbolicUtils = "1"
Symbolics = "5"
Expand Down
2 changes: 1 addition & 1 deletion benchmark/mlp.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function create_benchmark_mlp(mlp_conf::Tuple{Int, Int}, x::Vector{T},
l::Vector{T}) where {T <: Number}
l::Vector{T}) where {T <: Number}
input, hidden = mlp_conf
W₁, W₂, b₁, b₂ = rand(hidden, input), rand(1, hidden), rand(hidden), rand(1)
σ = exp
Expand Down
33 changes: 33 additions & 0 deletions ext/TaylorDiffSFExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Package extension adding Taylor-mode overloads for SpecialFunctions
# (currently only `erf`). It mirrors the rule-generation loop in
# src/codegen.jl: the first-order rule is obtained from ChainRules' `frule`,
# and higher orders are built by recursively "raising" the symbolic
# derivative expression.
module TaylorDiffSFExt
using TaylorDiff, SpecialFunctions
using Symbolics: @variables
using SymbolicUtils, SymbolicUtils.Code
using SymbolicUtils: Pow
# NOTE(fix): `raiseinv` must be imported alongside `raise` — it is referenced
# below when the symbolic derivative is a reciprocal power, and without this
# import that branch would throw an UndefVarError at load time for such rules.
using TaylorDiff: value, raise, raiseinv
using ChainRules, ChainRulesCore

# Dummy tangent passed to `frule` purely to extract the symbolic derivative.
const dummy = (NoTangent(), 1)
@variables z
for func in (erf,)
    F = typeof(func)
    # Base case: first-order TaylorScalar (N = 2) is handled directly by the
    # forward rule on the primal value and first coefficient.
    @eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}
        t0, t1 = value(t)
        TaylorScalar{T, 2}(frule((NoTangent(), t1), op, t0))
    end
    # Symbolic first derivative of `func` evaluated at the symbol `z`.
    der = frule(dummy, func, z)[2]
    # If the derivative has the shape base^-1, raise via `raiseinv`;
    # otherwise use the generic `raise` (this matches src/codegen.jl).
    term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise)
    # Recursive case: evaluate the derivative on the order-(N-1) truncation of
    # `t`, then raise the result back to order N.
    @eval @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N}
        der_expr = $(QuoteNode(toexpr(term)))
        f = $func
        quote
            $(Expr(:meta, :inline))
            z = TaylorScalar{T, N - 1}(t)
            df = $der_expr
            $$raiser($f(value(t)[1]), df, t)
        end
    end
end

end
16 changes: 9 additions & 7 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
end

function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
i::Integer) where {N, T}
i::Integer) where {N, T}
function extract_derivative_pullback(d̄)
NoTangent(), TaylorScalar{T, N}(ntuple(j -> j === i ? d̄ : zero(T), Val(N))),
NoTangent()
Expand All @@ -31,7 +31,7 @@
end

function rrule(::typeof(*), A::AbstractMatrix{S},
t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T}
t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T}
project_A = ProjectTo(A)
function gemv_pullback(x̄)
x̂ = reinterpret(reshape, T, x̄)
Expand All @@ -41,15 +41,17 @@
return A * t, gemv_pullback
end

function rrule(::typeof(*), A::AbstractMatrix{S},

Check warning on line 44 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L44

Added line #L44 was not covered by tests
B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T}
B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T}
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function gemm_pullback(x̄)
X̄ = unthunk(x̄)
NoTangent(), @thunk(project_A(X̄ * transpose(B))), @thunk(project_B(transpose(A) * X̄))
NoTangent(),
@thunk(project_A(X̄ * transpose(B))),
@thunk(project_B(transpose(A) * X̄))

Check warning on line 52 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L46-L52

Added lines #L46 - L52 were not covered by tests
end
return A * B, gemm_pullback

Check warning on line 54 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L54

Added line #L54 was not covered by tests
end

@adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T}
Expand Down Expand Up @@ -85,8 +87,8 @@
ind::I
axes::A
function TaylorOneElement(val::T, ind::I,
axes::A) where {T <: TaylorScalar, I <: NTuple{N, Int},
A <: NTuple{N, AbstractUnitRange}} where {N}
axes::A) where {T <: TaylorScalar, I <: NTuple{N, Int},
A <: NTuple{N, AbstractUnitRange}} where {N}
new{T, N, I, A}(val, ind, axes)
end
end
Expand Down Expand Up @@ -125,7 +127,7 @@
end

function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar,
more::TaylorScalar...)
more::TaylorScalar...)
Ω2, back2 = rrule(*, x, y)
Ω3, back3 = rrule(*, Ω2, z)
Ω4, back4 = rrule(*, Ω3, more...)
Expand Down
4 changes: 1 addition & 3 deletions src/codegen.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using ChainRules
using ChainRulesCore
using SpecialFunctions
using IrrationalConstants: sqrtπ
using Symbolics: @variables
using SymbolicUtils, SymbolicUtils.Code
using SymbolicUtils: Pow
Expand All @@ -13,7 +11,7 @@ for func in (+, -, deg2rad, rad2deg,
asin, acos, atan, asec, acsc, acot,
log, log10, log1p, log2,
asinh, acosh, atanh, asech, acsch,
acoth, erf)
acoth)
F = typeof(func)
# base case
@eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}
Expand Down
21 changes: 11 additions & 10 deletions src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,25 @@
export derivative

"""
derivative(f, x::T, order::Int64)
derivative(f, x, order::Int64)
derivative(f, x, l, order::Int64)

Wrapper functions for converting order from a number to a type. Actual APIs are detailed below:

derivative(f, x::T, ::Val{N})

Computes `order`-th derivative of `f` w.r.t. scalar `x`.

derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, order::Int64)
derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, ::Val{N})

Computes `order`-th directional derivative of `f` w.r.t. vector `x` in direction `l`.

derivative(f, x::AbstractMatrix{T}, order::Int64)
derivative(f, x::AbstractMatrix{T}, ::Val{N})
derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, order::Int64)
derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, ::Val{N})

Shorthand notations for multiple calculations.
For a M-by-N matrix, calculate the directional derivative for each column.
For a 1-by-N matrix (row vector), calculate the derivative for each scalar.
Batch mode derivative / directional derivative calculations, where each column of `x` represents a scalar or a vector. `f` is expected to accept matrices as input.
- For a M-by-N matrix, calculate the directional derivative for each column.
- For a 1-by-N matrix (row vector), calculate the derivative for each scalar.
"""
function derivative end

Expand All @@ -45,7 +46,7 @@ make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t
end

@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S},
vN::Val{N}) where {T <: TN, S <: TN, N}
vN::Val{N}) where {T <: TN, S <: TN, N}
t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l)
# equivalent to map(TaylorScalar{T, N}, x, l)
return extract_derivative(f(t), N)
Expand All @@ -55,12 +56,12 @@ end

@inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: TN, N}
size(x)[1] != 1 && @warn "x is not a row vector."
t = make_taylor.(x, one(N), vN)
t = make_taylor.(x, one(T), vN)
return extract_derivative.(f(t), N)
end

@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{S},
vN::Val{N}) where {T <: TN, S <: TN, N}
vN::Val{N}) where {T <: TN, S <: TN, N}
t = make_taylor.(x, l, vN)
return extract_derivative.(f(t), N)
end
4 changes: 2 additions & 2 deletions src/primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ end
end

@generated function raise(f::T, df::TaylorScalar{T, M},
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
return quote
$(Expr(:meta, :inline))
vdf, vt = value(df), value(t)
Expand All @@ -162,7 +162,7 @@ end
raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Number, T, N} = df * t

@generated function raiseinv(f::T, df::TaylorScalar{T, M},
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
ex = quote
vdf, vt = value(df), value(t)
v1 = vt[2] / vdf[1]
Expand Down
4 changes: 2 additions & 2 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ end
@inline value(t::TaylorScalar) = t.value
@inline extract_derivative(t::TaylorScalar, i::Integer) = t.value[i]
@inline function extract_derivative(v::AbstractArray{T},
i::Integer) where {T <: TaylorScalar}
i::Integer) where {T <: TaylorScalar}
map(t -> extract_derivative(t, i), v)
end
@inline extract_derivative(r, i::Integer) = false
Expand All @@ -73,7 +73,7 @@ adjoint(t::TaylorScalar) = t
conj(t::TaylorScalar) = t

function promote_rule(::Type{TaylorScalar{T, N}},
::Type{S}) where {T, S, N}
::Type{S}) where {T, S, N}
TaylorScalar{promote_type(T, S), N}
end

Expand Down
16 changes: 16 additions & 0 deletions test/derivative.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

@testset "Derivative" begin
    # Scalar mode: d/dx x^3 at x = 1.0 is 3x^2 = 3.
    @test derivative(x -> x^3, 1.0, 1) ≈ 3

    # Batch (row-vector) mode: one elementwise derivative per column.
    @test derivative(x -> x .^ 3, [2.0 3.0], 1) ≈ [12.0 27.0]
end

@testset "Directional derivative" begin
    # ∇(x₁² + x₂²) ⋅ l at x = (1, 2) with l = (1, 0) equals 2·x₁ = 2.
    @test derivative(x -> x[1] * x[1] + x[2] * x[2], [1.0, 2.0], [1.0, 0.0], 1) ≈ 2.0

    # Matrix batch mode: directional derivative of columnwise sums.
    @test derivative(x -> sum(x, dims = 1), [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) ≈ [2.0 2.0]
end
3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
using TaylorDiff
using Test

include("scalar.jl")
include("vector.jl")
include("primitive.jl")
include("derivative.jl")
include("zygote.jl")
# include("lux.jl")
6 changes: 0 additions & 6 deletions test/scalar.jl

This file was deleted.

6 changes: 0 additions & 6 deletions test/vector.jl

This file was deleted.

8 changes: 4 additions & 4 deletions test/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ using Zygote, LinearAlgebra
for f in (exp, log, sqrt, sin, asin, sinh, asinh)
@test gradient(x -> derivative(f, x, 2), some_number)[1] ≈
derivative(f, some_number, 3)
derivative_result = vec(derivative(f, some_numbers, 3))
@test Zygote.jacobian(x -> derivative(f, x, 2), some_numbers)[1] ≈
derivative_result = vec(derivative.(f, some_numbers, 3))
@test Zygote.jacobian(x -> derivative.(f, x, 2), some_numbers)[1] ≈
diagm(derivative_result)
end

some_matrix = [0.7 0.1; 0.4 0.2]
f = x -> sum(tanh.(x), dims = 1)
dfdx1(m, x) = derivative(u -> sum(m(u)), x, [1.0, 0.0], 1)
dfdx2(m, x) = derivative(u -> sum(m(u)), x, [0.0, 1.0], 1)
dfdx1(m, x) = derivative(m, x, [1.0, 0.0], 1)
dfdx2(m, x) = derivative(m, x, [0.0, 1.0], 1)
res(m, x) = dfdx1(m, x) .+ 2 * dfdx2(m, x)
grads = Zygote.gradient(some_matrix) do x
sum(res(f, x))
Expand Down
Loading