From 770c1908652f877aa3b1e0851b62b384404968ed Mon Sep 17 00:00:00 2001 From: zhujch Date: Wed, 8 Nov 2023 16:34:54 -0500 Subject: [PATCH] Format --- benchmark/mlp.jl | 2 +- ext/TaylorDiffSFExt.jl | 4 ++-- src/chainrules.jl | 16 +++++++++------- src/derivative.jl | 4 ++-- src/primitive.jl | 4 ++-- src/scalar.jl | 4 ++-- test/derivative.jl | 6 +++--- 7 files changed, 21 insertions(+), 19 deletions(-) diff --git a/benchmark/mlp.jl b/benchmark/mlp.jl index dacd34a..aba9952 100644 --- a/benchmark/mlp.jl +++ b/benchmark/mlp.jl @@ -1,5 +1,5 @@ function create_benchmark_mlp(mlp_conf::Tuple{Int, Int}, x::Vector{T}, - l::Vector{T}) where {T <: Number} + l::Vector{T}) where {T <: Number} input, hidden = mlp_conf W₁, W₂, b₁, b₂ = rand(hidden, input), rand(1, hidden), rand(hidden), rand(1) σ = exp diff --git a/ext/TaylorDiffSFExt.jl b/ext/TaylorDiffSFExt.jl index 1705cfb..c2059da 100644 --- a/ext/TaylorDiffSFExt.jl +++ b/ext/TaylorDiffSFExt.jl @@ -8,7 +8,7 @@ using ChainRules, ChainRulesCore dummy = (NoTangent(), 1) @variables z -for func in (erf, ) +for func in (erf,) F = typeof(func) # base case @eval function (op::$F)(t::TaylorScalar{T, 2}) where {T} @@ -30,4 +30,4 @@ for func in (erf, ) end end -end \ No newline at end of file +end diff --git a/src/chainrules.jl b/src/chainrules.jl index 4019fce..92b4ed3 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -22,7 +22,7 @@ function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T} end function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N}, - i::Integer) where {N, T} + i::Integer) where {N, T} function extract_derivative_pullback(d̄) NoTangent(), TaylorScalar{T, N}(ntuple(j -> j === i ? d̄ : zero(T), Val(N))), NoTangent() @@ -31,7 +31,7 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N}, end function rrule(::typeof(*), A::AbstractMatrix{S}, - t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T} + t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T} project_A = ProjectTo(A) function gemv_pullback(x̄) x̂ = reinterpret(reshape, T, x̄) @@ -42,12 +42,14 @@ function rrule(::typeof(*), A::AbstractMatrix{S}, end function rrule(::typeof(*), A::AbstractMatrix{S}, - B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T} + B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T} project_A = ProjectTo(A) project_B = ProjectTo(B) function gemm_pullback(x̄) X̄ = unthunk(x̄) - NoTangent(), @thunk(project_A(X̄ * transpose(B))), @thunk(project_B(transpose(A) * X̄)) + NoTangent(), + @thunk(project_A(X̄ * transpose(B))), + @thunk(project_B(transpose(A) * X̄)) end return A * B, gemm_pullback end @@ -85,8 +87,8 @@ struct TaylorOneElement{T, N, I, A} <: AbstractArray{T, N} ind::I axes::A function TaylorOneElement(val::T, ind::I, - axes::A) where {T <: TaylorScalar, I <: NTuple{N, Int}, - A <: NTuple{N, AbstractUnitRange}} where {N} + axes::A) where {T <: TaylorScalar, I <: NTuple{N, Int}, + A <: NTuple{N, AbstractUnitRange}} where {N} new{T, N, I, A}(val, ind, axes) end end @@ -125,7 +127,7 @@ function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar) end function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar, - more::TaylorScalar...) + more::TaylorScalar...) Ω2, back2 = rrule(*, x, y) Ω3, back3 = rrule(*, Ω2, z) Ω4, back4 = rrule(*, Ω3, more...) diff --git a/src/derivative.jl b/src/derivative.jl index e3559d3..6010a8e 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -46,7 +46,7 @@ make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t end @inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S}, - vN::Val{N}) where {T <: TN, S <: TN, N} + vN::Val{N}) where {T <: TN, S <: TN, N} t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l) # equivalent to map(TaylorScalar{T, N}, x, l) return extract_derivative(f(t), N) @@ -61,7 +61,7 @@ end end @inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{S}, - vN::Val{N}) where {T <: TN, S <: TN, N} + vN::Val{N}) where {T <: TN, S <: TN, N} t = make_taylor.(x, l, vN) return extract_derivative.(f(t), N) end diff --git a/src/primitive.jl b/src/primitive.jl index 3fd55b9..32c2dcb 100644 --- a/src/primitive.jl +++ b/src/primitive.jl @@ -148,7 +148,7 @@ end end @generated function raise(f::T, df::TaylorScalar{T, M}, - t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N + t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N return quote $(Expr(:meta, :inline)) vdf, vt = value(df), value(t) @@ -162,7 +162,7 @@ end raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Number, T, N} = df * t @generated function raiseinv(f::T, df::TaylorScalar{T, M}, - t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N + t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N ex = quote vdf, vt = value(df), value(t) v1 = vt[2] / vdf[1] diff --git a/src/scalar.jl b/src/scalar.jl index 80ec1f0..b0b7afd 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -58,7 +58,7 @@ end @inline value(t::TaylorScalar) = t.value @inline extract_derivative(t::TaylorScalar, i::Integer) = t.value[i] @inline function extract_derivative(v::AbstractArray{T}, - i::Integer) where {T <: TaylorScalar} + i::Integer) where {T <: TaylorScalar} map(t -> extract_derivative(t, i), v) end @inline extract_derivative(r, i::Integer) = false @@ -73,7 +73,7 @@ adjoint(t::TaylorScalar) = t conj(t::TaylorScalar) = t function promote_rule(::Type{TaylorScalar{T, N}}, - ::Type{S}) where {T, S, N} + ::Type{S}) where {T, S, N} TaylorScalar{promote_type(T, S), N} end diff --git a/test/derivative.jl b/test/derivative.jl index e33126a..020b125 100644 --- a/test/derivative.jl +++ b/test/derivative.jl @@ -3,7 +3,7 @@ g(x) = x^3 @test derivative(g, 1.0, 1) ≈ 3 - h(x) = x.^3 + h(x) = x .^ 3 @test derivative(h, [2.0 3.0], 1) ≈ [12.0 27.0] end @@ -11,6 +11,6 @@ end g(x) = x[1] * x[1] + x[2] * x[2] @test derivative(g, [1.0, 2.0], [1.0, 0.0], 1) ≈ 2.0 - h(x) = sum(x, dims=1) - @test derivative(h, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) ≈ [2. 2.] + h(x) = sum(x, dims = 1) + @test derivative(h, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) ≈ [2.0 2.0] end