diff --git a/src/chainrules.jl b/src/chainrules.jl index 68ebcce..d2647a5 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -6,7 +6,7 @@ function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N} mapreduce(*, +, value(a), value(b)) end -function rrule(::Type{TaylorScalar{T, N}}, v::NTuple{N, T}) where {N, T <: Number} +function rrule(::Type{TaylorScalar{T, N}}, v::NTuple{N, T}) where {N, T} taylor_scalar_pullback(t̄) = NoTangent(), value(t̄) return TaylorScalar(v), taylor_scalar_pullback end @@ -22,7 +22,7 @@ function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T} end function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N}, - i::Integer) where {N, T <: Number} + i::Integer) where {N, T} function extract_derivative_pullback(d̄) NoTangent(), TaylorScalar{T, N}(ntuple(j -> j === i ? d̄ : zero(T), Val(N))), NoTangent() @@ -31,7 +31,7 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N}, end function rrule(::typeof(*), A::AbstractMatrix{S}, - t::AbstractVector{TaylorScalar{T, N}}) where {N, S <: Number, T} + t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T} project_A = ProjectTo(A) function gemv_pullback(x̄) x̂ = reinterpret(reshape, T, x̄) @@ -41,17 +41,17 @@ function rrule(::typeof(*), A::AbstractMatrix{S}, return A * t, gemv_pullback end -@adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T <: Number} +@adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T} project_v = ProjectTo(v) t + v, x̄ -> (x̄, project_v(x̄)) end -@adjoint function +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number} +@adjoint function +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T} project_v = ProjectTo(v) v + t, x̄ -> (project_v(x̄), x̄) end -(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = primal(dx) +(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T} = primal(dx) # Not-a-number patches diff --git a/src/derivative.jl b/src/derivative.jl index c5f2b0b..1cf1d6f 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -40,13 +40,13 @@ end # Added to help Zygote infer types make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1)) -@inline function derivative(f, x::T, ::Val{N}) where {T <: Number, N} +@inline function derivative(f, x::T, ::Val{N}) where {T <: TN, N} t = TaylorScalar{T, N}(x, one(x)) return extract_derivative(f(t), N) end @inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S}, - vN::Val{N}) where {T <: Number, S <: Number, N} + vN::Val{N}) where {T <: TN, S <: TN, N} t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l) # equivalent to map(TaylorScalar{T, N}, x, l) return extract_derivative(f(t), N) @@ -54,12 +54,12 @@ end # shorthand notations for matrices -@inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: Number, N} +@inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: TN, N} size(x)[1] != 1 && @warn "x is not a row vector." mapcols(u -> derivative(f, u[1], vN), x) end @inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{S}, - vN::Val{N}) where {T <: Number, S <: Number, N} + vN::Val{N}) where {T <: TN, S <: TN, N} mapcols(u -> derivative(f, u, l, vN), x) end diff --git a/src/scalar.jl b/src/scalar.jl index d427580..80ec1f0 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -17,7 +17,7 @@ struct TaylorScalar{T, N} value::NTuple{N, T} end -TaylorOrNumber = Union{TaylorScalar, Number} +TN = Union{TaylorScalar, Number} @inline TaylorScalar(xs::Vararg{T, N}) where {T, N} = TaylorScalar(xs)