From 9a037b36184bbf73f4c7c2b4ee6930df36b1cc8c Mon Sep 17 00:00:00 2001 From: zhujch Date: Wed, 8 Nov 2023 16:29:35 -0500 Subject: [PATCH] Update tests --- Project.toml | 1 - src/derivative.jl | 17 +++++++++-------- test/derivative.jl | 16 ++++++++++++++++ test/runtests.jl | 3 +-- test/scalar.jl | 6 ------ test/vector.jl | 6 ------ test/zygote.jl | 8 ++++---- 7 files changed, 30 insertions(+), 27 deletions(-) create mode 100644 test/derivative.jl delete mode 100644 test/scalar.jl delete mode 100644 test/vector.jl diff --git a/Project.toml b/Project.toml index 011789f..9a56b55 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,6 @@ TaylorDiffSFExt = ["SpecialFunctions"] ChainRules = "1" ChainRulesCore = "1" ChainRulesOverloadGeneration = "0.1" -IrrationalConstants = "0.2" SpecialFunctions = "2" SymbolicUtils = "1" Symbolics = "5" diff --git a/src/derivative.jl b/src/derivative.jl index 62a60c7..e3559d3 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -2,24 +2,25 @@ export derivative """ - derivative(f, x::T, order::Int64) + derivative(f, x, order::Int64) + derivative(f, x, l, order::Int64) + +Wrapper functions for converting order from a number to a type. Actual APIs are detailed below: + derivative(f, x::T, ::Val{N}) Computes `order`-th derivative of `f` w.r.t. scalar `x`. - derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, order::Int64) derivative(f, x::AbstractVector{T}, l::AbstractVector{T}, ::Val{N}) Computes `order`-th directional derivative of `f` w.r.t. vector `x` in direction `l`. - derivative(f, x::AbstractMatrix{T}, order::Int64) derivative(f, x::AbstractMatrix{T}, ::Val{N}) - derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, order::Int64) derivative(f, x::AbstractMatrix{T}, l::AbstractVector{T}, ::Val{N}) -Shorthand notations for multiple calculations. -For a M-by-N matrix, calculate the directional derivative for each column. -For a 1-by-N matrix (row vector), calculate the derivative for each scalar. +Batch mode derivative / directional derivative calculations, where each column of `x` represents a scalar or a vector. `f` is expected to accept matrices as input. +- For a M-by-N matrix, calculate the directional derivative for each column. +- For a 1-by-N matrix (row vector), calculate the derivative for each scalar. """ function derivative end @@ -55,7 +56,7 @@ end @inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: TN, N} size(x)[1] != 1 && @warn "x is not a row vector." - t = make_taylor.(x, one(N), vN) + t = make_taylor.(x, one(T), vN) return extract_derivative.(f(t), N) end diff --git a/test/derivative.jl b/test/derivative.jl new file mode 100644 index 0000000..e33126a --- /dev/null +++ b/test/derivative.jl @@ -0,0 +1,16 @@ + +@testset "Derivative" begin + g(x) = x^3 + @test derivative(g, 1.0, 1) ≈ 3 + + h(x) = x.^3 + @test derivative(h, [2.0 3.0], 1) ≈ [12.0 27.0] +end + +@testset "Directional derivative" begin + g(x) = x[1] * x[1] + x[2] * x[2] + @test derivative(g, [1.0, 2.0], [1.0, 0.0], 1) ≈ 2.0 + + h(x) = sum(x, dims=1) + @test derivative(h, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) ≈ [2. 2.] +end diff --git a/test/runtests.jl b/test/runtests.jl index 1ef7378..a28c383 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,7 @@ using TaylorDiff using Test -include("scalar.jl") -include("vector.jl") include("primitive.jl") +include("derivative.jl") include("zygote.jl") # include("lux.jl") diff --git a/test/scalar.jl b/test/scalar.jl deleted file mode 100644 index 444fc68..0000000 --- a/test/scalar.jl +++ /dev/null @@ -1,6 +0,0 @@ - -@testset "Scalar" begin - g(x) = x^3 - @test derivative(g, 1.0, 1) ≈ 3 - @test derivative(g, [2.0 3.0], 1) ≈ [12.0 27.0] -end diff --git a/test/vector.jl b/test/vector.jl deleted file mode 100644 index 7b62531..0000000 --- a/test/vector.jl +++ /dev/null @@ -1,6 +0,0 @@ - -@testset "Vector" begin - g(x) = x[1] * x[1] + x[2] * x[2] - @test derivative(g, [1.0, 2.0], [1.0, 0.0], 1) ≈ 2.0 - @test derivative(g, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) ≈ [6.0 10.0] -end diff --git a/test/zygote.jl b/test/zygote.jl index a1d2c2b..1ef667f 100644 --- a/test/zygote.jl +++ b/test/zygote.jl @@ -6,15 +6,15 @@ using Zygote, LinearAlgebra for f in (exp, log, sqrt, sin, asin, sinh, asinh) @test gradient(x -> derivative(f, x, 2), some_number)[1] ≈ derivative(f, some_number, 3) - derivative_result = vec(derivative(f, some_numbers, 3)) - @test Zygote.jacobian(x -> derivative(f, x, 2), some_numbers)[1] ≈ + derivative_result = vec(derivative.(f, some_numbers, 3)) + @test Zygote.jacobian(x -> derivative.(f, x, 2), some_numbers)[1] ≈ diagm(derivative_result) end some_matrix = [0.7 0.1; 0.4 0.2] f = x -> sum(tanh.(x), dims = 1) - dfdx1(m, x) = derivative(u -> sum(m(u)), x, [1.0, 0.0], 1) - dfdx2(m, x) = derivative(u -> sum(m(u)), x, [0.0, 1.0], 1) + dfdx1(m, x) = derivative(m, x, [1.0, 0.0], 1) + dfdx2(m, x) = derivative(m, x, [0.0, 1.0], 1) res(m, x) = dfdx1(m, x) .+ 2 * dfdx2(m, x) grads = Zygote.gradient(some_matrix) do x sum(res(f, x))