Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zhujch1 authored and tansongchen committed Nov 9, 2023
1 parent eff4ac7 commit 9a037b3
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 27 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ TaylorDiffSFExt = ["SpecialFunctions"]
ChainRules = "1"
ChainRulesCore = "1"
ChainRulesOverloadGeneration = "0.1"
IrrationalConstants = "0.2"
SpecialFunctions = "2"
SymbolicUtils = "1"
Symbolics = "5"
Expand Down
17 changes: 9 additions & 8 deletions src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,25 @@
export derivative

"""
derivative(f, x::T, order::Int64)
derivative(f, x, order::Int64)
derivative(f, x, l, order::Int64)
Wrapper functions for converting order from a number to a type. Actual APIs are detailed below:
derivative(f, x::T, ::Val{N})
Computes `order`-th derivative of `f` w.r.t. scalar `x`.
derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, order::Int64)
derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, ::Val{N})
Computes `order`-th directional derivative of `f` w.r.t. vector `x` in direction `l`.
derivative(f, x::AbstractMatrix{T}, order::Int64)
derivative(f, x::AbstractMatrix{T}, ::Val{N})
derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, order::Int64)
derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, ::Val{N})
Shorthand notations for multiple calculations.
For a M-by-N matrix, calculate the directional derivative for each column.
For a 1-by-N matrix (row vector), calculate the derivative for each scalar.
Batch mode derivative / directional derivative calculations, where each column of `x` represents a scalar or a vector. `f` is expected to accept matrices as input.
- For a M-by-N matrix, calculate the directional derivative for each column.
- For a 1-by-N matrix (row vector), calculate the derivative for each scalar.
"""
function derivative end

Expand Down Expand Up @@ -55,7 +56,7 @@ end

@inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: TN, N}
size(x)[1] != 1 && @warn "x is not a row vector."
t = make_taylor.(x, one(N), vN)
t = make_taylor.(x, one(T), vN)
return extract_derivative.(f(t), N)
end

Expand Down
16 changes: 16 additions & 0 deletions test/derivative.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

@testset "Derivative" begin
g(x) = x^3
@test derivative(g, 1.0, 1) 3

h(x) = x.^3
@test derivative(h, [2.0 3.0], 1) [12.0 27.0]
end

@testset "Directional derivative" begin
g(x) = x[1] * x[1] + x[2] * x[2]
@test derivative(g, [1.0, 2.0], [1.0, 0.0], 1) 2.0

h(x) = sum(x, dims=1)
@test derivative(h, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) [2. 2.]
end
3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
using TaylorDiff
using Test

include("scalar.jl")
include("vector.jl")
include("primitive.jl")
include("derivative.jl")
include("zygote.jl")
# include("lux.jl")
6 changes: 0 additions & 6 deletions test/scalar.jl

This file was deleted.

6 changes: 0 additions & 6 deletions test/vector.jl

This file was deleted.

8 changes: 4 additions & 4 deletions test/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ using Zygote, LinearAlgebra
for f in (exp, log, sqrt, sin, asin, sinh, asinh)
@test gradient(x -> derivative(f, x, 2), some_number)[1]
derivative(f, some_number, 3)
derivative_result = vec(derivative(f, some_numbers, 3))
@test Zygote.jacobian(x -> derivative(f, x, 2), some_numbers)[1]
derivative_result = vec(derivative.(f, some_numbers, 3))
@test Zygote.jacobian(x -> derivative.(f, x, 2), some_numbers)[1]
diagm(derivative_result)
end

some_matrix = [0.7 0.1; 0.4 0.2]
f = x -> sum(tanh.(x), dims = 1)
dfdx1(m, x) = derivative(u -> sum(m(u)), x, [1.0, 0.0], 1)
dfdx2(m, x) = derivative(u -> sum(m(u)), x, [0.0, 1.0], 1)
dfdx1(m, x) = derivative(m, x, [1.0, 0.0], 1)
dfdx2(m, x) = derivative(m, x, [0.0, 1.0], 1)
res(m, x) = dfdx1(m, x) .+ 2 * dfdx2(m, x)
grads = Zygote.gradient(some_matrix) do x
sum(res(f, x))
Expand Down

0 comments on commit 9a037b3

Please sign in to comment.