Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
zhujch1 authored and tansongchen committed Nov 9, 2023
1 parent 9a037b3 commit 770c190
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 19 deletions.
2 changes: 1 addition & 1 deletion benchmark/mlp.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function create_benchmark_mlp(mlp_conf::Tuple{Int, Int}, x::Vector{T},
l::Vector{T}) where {T <: Number}
l::Vector{T}) where {T <: Number}
input, hidden = mlp_conf
W₁, W₂, b₁, b₂ = rand(hidden, input), rand(1, hidden), rand(hidden), rand(1)
σ = exp
Expand Down
4 changes: 2 additions & 2 deletions ext/TaylorDiffSFExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using ChainRules, ChainRulesCore

dummy = (NoTangent(), 1)
@variables z
for func in (erf, )
for func in (erf,)
F = typeof(func)
# base case
@eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}
Expand All @@ -30,4 +30,4 @@ for func in (erf, )
end
end

end
end
16 changes: 9 additions & 7 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
end

function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
i::Integer) where {N, T}
i::Integer) where {N, T}
function extract_derivative_pullback(d̄)
NoTangent(), TaylorScalar{T, N}(ntuple(j -> j === i ?: zero(T), Val(N))),
NoTangent()
Expand All @@ -31,7 +31,7 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
end

function rrule(::typeof(*), A::AbstractMatrix{S},
t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T}
t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T}
project_A = ProjectTo(A)
function gemv_pullback(x̄)
= reinterpret(reshape, T, x̄)
Expand All @@ -42,12 +42,14 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
end

function rrule(::typeof(*), A::AbstractMatrix{S},
B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T}
B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T}
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function gemm_pullback(x̄)
= unthunk(x̄)
NoTangent(), @thunk(project_A(X̄ * transpose(B))), @thunk(project_B(transpose(A) * X̄))
NoTangent(),
@thunk(project_A(X̄ * transpose(B))),
@thunk(project_B(transpose(A) * X̄))
end
return A * B, gemm_pullback
end
Expand Down Expand Up @@ -85,8 +87,8 @@ struct TaylorOneElement{T, N, I, A} <: AbstractArray{T, N}
ind::I
axes::A
function TaylorOneElement(val::T, ind::I,
axes::A) where {T <: TaylorScalar, I <: NTuple{N, Int},
A <: NTuple{N, AbstractUnitRange}} where {N}
axes::A) where {T <: TaylorScalar, I <: NTuple{N, Int},
A <: NTuple{N, AbstractUnitRange}} where {N}
new{T, N, I, A}(val, ind, axes)
end
end
Expand Down Expand Up @@ -125,7 +127,7 @@ function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar)
end

function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar,
more::TaylorScalar...)
more::TaylorScalar...)
Ω2, back2 = rrule(*, x, y)
Ω3, back3 = rrule(*, Ω2, z)
Ω4, back4 = rrule(*, Ω3, more...)
Expand Down
4 changes: 2 additions & 2 deletions src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t
end

@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S},
vN::Val{N}) where {T <: TN, S <: TN, N}
vN::Val{N}) where {T <: TN, S <: TN, N}
t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l)
# equivalent to map(TaylorScalar{T, N}, x, l)
return extract_derivative(f(t), N)
Expand All @@ -61,7 +61,7 @@ end
end

@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{S},
vN::Val{N}) where {T <: TN, S <: TN, N}
vN::Val{N}) where {T <: TN, S <: TN, N}
t = make_taylor.(x, l, vN)
return extract_derivative.(f(t), N)
end
4 changes: 2 additions & 2 deletions src/primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ end
end

@generated function raise(f::T, df::TaylorScalar{T, M},
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
return quote
$(Expr(:meta, :inline))
vdf, vt = value(df), value(t)
Expand All @@ -162,7 +162,7 @@ end
raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Number, T, N} = df * t

@generated function raiseinv(f::T, df::TaylorScalar{T, M},
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
ex = quote
vdf, vt = value(df), value(t)
v1 = vt[2] / vdf[1]
Expand Down
4 changes: 2 additions & 2 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ end
@inline value(t::TaylorScalar) = t.value
@inline extract_derivative(t::TaylorScalar, i::Integer) = t.value[i]
@inline function extract_derivative(v::AbstractArray{T},
i::Integer) where {T <: TaylorScalar}
i::Integer) where {T <: TaylorScalar}
map(t -> extract_derivative(t, i), v)
end
@inline extract_derivative(r, i::Integer) = false
Expand All @@ -73,7 +73,7 @@ adjoint(t::TaylorScalar) = t
conj(t::TaylorScalar) = t

function promote_rule(::Type{TaylorScalar{T, N}},
::Type{S}) where {T, S, N}
::Type{S}) where {T, S, N}
TaylorScalar{promote_type(T, S), N}
end

Expand Down
6 changes: 3 additions & 3 deletions test/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
g(x) = x^3
@test derivative(g, 1.0, 1) 3

h(x) = x.^3
h(x) = x .^ 3
@test derivative(h, [2.0 3.0], 1) [12.0 27.0]
end

@testset "Directional derivative" begin
g(x) = x[1] * x[1] + x[2] * x[2]
@test derivative(g, [1.0, 2.0], [1.0, 0.0], 1) 2.0

h(x) = sum(x, dims=1)
@test derivative(h, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) [2. 2.]
h(x) = sum(x, dims = 1)
@test derivative(h, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) [2.0 2.0]
end

0 comments on commit 770c190

Please sign in to comment.