Get the type of automatic differentiation (AD) right in Haskell.

Only forward-accumulation-mode automatic differentiation is supported, for now. (There is no barrier in principle to extending it to cover reverse mode.)

This code cannot actually be used to take derivatives, because DA.lift is not implemented for functions. Implementing it seems to require a reflective mechanism, i.e. the ability to inspect a function's definition, which Haskell does not provide.
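For comparison, here is a minimal standalone sketch of forward-mode AD with dual numbers. It is hypothetical: the Dual type, diffLifted and cube below are illustrative and are not this repository's Numeric.Dual or Diff API. It sidesteps the missing DA.lift the way Haskell AD libraries usually do: the caller supplies the function already at the lifted type Dual a -> Dual a. Any function written polymorphically over Num can be instantiated at that type, while a monomorphic Double -> Double function is rejected by the type checker, which is exactly the gap DA.lift was meant to fill.

```haskell
-- Hypothetical standalone sketch; not this repository's Numeric.Dual.

-- A dual number x + dx*eps with eps^2 = 0: primal value and tangent.
data Dual a = Dual a a deriving (Eq, Show)

primal, tangent :: Dual a -> a
primal (Dual x _) = x
tangent (Dual _ dx) = dx

-- Arithmetic on the primal part; the tangent part follows the usual
-- differentiation rules (sum rule, product rule, ...).
instance Num a => Num (Dual a) where
  Dual x dx + Dual y dy = Dual (x + y) (dx + dy)
  Dual x dx - Dual y dy = Dual (x - y) (dx - dy)
  Dual x dx * Dual y dy = Dual (x * y) (x * dy + dx * y)
  negate (Dual x dx)    = Dual (negate x) (negate dx)
  abs (Dual x dx)       = Dual (abs x) (dx * signum x)
  signum (Dual x _)     = Dual (signum x) 0
  fromInteger n         = Dual (fromInteger n) 0   -- constants have zero tangent

-- Forward-mode derivative: seed the input tangent with 1 and read the output
-- tangent. The function must already be "lifted" to Dual a -> Dual a;
-- this is the job DA.lift would otherwise do.
diffLifted :: Num a => (Dual a -> Dual a) -> a -> a
diffLifted f x = tangent (f (Dual x 1))

-- A function written polymorphically, so it can be used at Dual a.
cube :: Num b => b -> b
cube x = x * x * x
```

For example, diffLifted cube 2 evaluates to 12, the derivative of x³ at 2. The ad package on Hackage takes broadly the same route, accepting functions that are polymorphic in their number type rather than lifting arbitrary closures.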

There are two main mechanisms involved. One is a "tangent bundle" type class, which is the T functor from differential geometry: T (a → b) = a → T b and T Double = Dual Double. The other is a "differential algebra" type class, which lifts an object by replacing every Double value, inside and out, with a Dual Double, so that D (a → b) = D a → D b. The former is the "right" API for differentiation, namely the pushforward operator; the latter is the natural implementation mechanism. Goo is defined to translate between the two as needed, at the interface between the API of differentiation and its implementation.
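The type equations for T and D can be written down directly as open type families. The following is a hypothetical type-level sketch: the names T and D follow the prose, but the repository itself expresses this with type classes (for example the multi-parameter ConvertTBandDA), not with these families. It also assumes that the pushforward of f :: a -> b has type T a -> T b, and the equations for Dual, pairs and lists are inferred from the GHCi types in this README rather than stated in it. Each Refl type-checks only if the equations reproduce one of those printed types.

```haskell
{-# LANGUAGE TypeFamilies, TypeOperators #-}
-- Hypothetical type-level sketch; not this repository's TangentBundle or
-- DifferentialAlgebra modules.
import Data.Type.Equality ((:~:)(Refl))

data Dual a = Dual a a deriving (Eq, Show)

-- Tangent bundle functor: only the codomain of a function is affected.
type family T a
type instance T Double   = Dual Double
type instance T (Dual a) = Dual (Dual a)   -- assumption, inferred from the session
type instance T (a -> b) = a -> T b
type instance T (a, b)   = (T a, T b)      -- assumption, inferred from the session
type instance T [a]      = [T a]           -- assumption, inferred from the session

-- Differential algebra: every Double is replaced, inside and out.
type family D a
type instance D Double   = Dual Double
type instance D (a -> b) = D a -> D b
type instance D (a, b)   = (D a, D b)
type instance D [a]      = [D a]

-- Assumed: pushforward f :: T a -> T b for f :: a -> b.
-- pushforward ((*) :: Double -> Double -> Double)
mulPush :: (T Double -> T (Double -> Double))
       :~: (Dual Double -> Double -> Dual Double)
mulPush = Refl

-- pushforward (pushforward ((*) :: Double -> Double -> Double))
mulPush2 :: (T (Dual Double) -> T (Double -> Dual Double))
        :~: (Dual (Dual Double) -> Double -> Dual (Dual Double))
mulPush2 = Refl

-- pushforward (uncurry ((*) :: Double -> Double -> Double))
uncurryPush :: (T (Double, Double) -> T Double)
           :~: ((Dual Double, Dual Double) -> Dual Double)
uncurryPush = Refl

-- pushforward (sum :: [Double] -> Double)
sumPush :: (T [Double] -> T Double) :~: ([Dual Double] -> Dual Double)
sumPush = Refl

-- By contrast, D lifts both sides of every arrow.
mulLift :: D (Double -> Double -> Double)
       :~: (Dual Double -> Dual Double -> Dual Double)
mulLift = Refl
```

The contrast between mulPush and mulLift is the mismatch that Goo has to bridge: the pushforward leaves the second argument as a plain Double, while the D-lifted function expects a Dual Double there.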

The heart of the implementation then becomes the apparently unimplementable DA.lift on functions.

But the types work.

```
$ ghci -Wall
GHCi, version 7.6.3

Prelude> :l Diff.hs
[1 of 5] Compiling Numeric.Dual     ( Numeric/Dual.hs, interpreted )
[2 of 5] Compiling TangentBundle    ( TangentBundle.hs, interpreted )
[3 of 5] Compiling DifferentialAlgebra ( DifferentialAlgebra.hs, interpreted )
[4 of 5] Compiling ConvertTBandDA   ( ConvertTBandDA.hs, interpreted )
[5 of 5] Compiling Diff             ( Diff.hs, interpreted )
Ok, modules loaded: Diff, ConvertTBandDA, DifferentialAlgebra, TangentBundle, Numeric.Dual.

> :m + Numeric.Dual

> :t diff (sin::Double -> Double)
diff (sin::Double -> Double) :: Double -> Double

> :t iterate diff ((*)::Double -> Double -> Double)
iterate diff ((*)::Double -> Double -> Double)
  :: [Double -> Double -> Double]

> :t pushforward ((*)::Double -> Double -> Double)
pushforward ((*)::Double -> Double -> Double)
  :: Dual Double -> Double -> Dual Double

> :t pushforward (pushforward ((*)::Double -> Double -> Double))
pushforward (pushforward ((*)::Double -> Double -> Double))
  :: Dual (Dual Double) -> Double -> Dual (Dual Double)

> :t diff (pushforward (pushforward ((*)::Double -> Double -> Double)))
diff (pushforward (pushforward ((*)::Double -> Double -> Double)))
  :: Dual (Dual Double) -> Double -> Dual (Dual Double)

> :t pushforward (flip (pushforward (pushforward ((*)::Double -> Double -> Double))))
pushforward (flip (pushforward (pushforward ((*)::Double -> Double -> Double))))
  :: Dual Double -> Dual (Dual Double) -> Dual (Dual (Dual Double))

> :t pushforward (uncurry ((*)::Double -> Double -> Double))
pushforward (uncurry ((*)::Double -> Double -> Double))
  :: (Dual Double, Dual Double) -> Dual Double

> :t pushforward (sum::[Double] -> Double)
pushforward (sum::[Double] -> Double) :: [Dual Double] -> Dual Double

> :t pushforward sum
pushforward sum
  :: (Num a, ConvertTBandDA a a' ta da ba) => [ta] -> ta
```
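As a concrete, hypothetical illustration of the conversion Goo performs, take (*). Assuming a hand-written D-lifted multiplication (standing in for the missing DA.lift on functions), the T-style pushforward with the type printed in the session, Dual Double -> Double -> Dual Double, is obtained by embedding the plain Double argument with a zero tangent. Differentiating with respect to the first argument then seeds that argument's tangent with 1. The names mulD, pushforwardMul and diffMul are illustrative only.

```haskell
-- Hypothetical sketch of a T/D conversion for one specific function, (*).
data Dual a = Dual a a deriving (Eq, Show)

tangent :: Dual a -> a
tangent (Dual _ dx) = dx

-- The D-lifted (*): both arguments and the result are Dual Double.
-- Written by hand here, since DA.lift on functions is what is missing.
mulD :: Dual Double -> Dual Double -> Dual Double
mulD (Dual x dx) (Dual y dy) = Dual (x * y) (x * dy + dx * y)

-- T-style pushforward: only the first argument carries a tangent, so the
-- plain Double second argument is embedded as a constant (tangent 0).
pushforwardMul :: Dual Double -> Double -> Dual Double
pushforwardMul dx y = mulD dx (Dual y 0)

-- Derivative with respect to the first argument, with the same type as
-- diff ((*) :: Double -> Double -> Double) in the session.
diffMul :: Double -> Double -> Double
diffMul x y = tangent (pushforwardMul (Dual x 1) y)
```

With these definitions, diffMul 3 5 evaluates to 5, the partial derivative of x * y with respect to x at y = 5.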