Skip to content

Commit

Permalink
Create a test suite for fusion, using inspection-testing
Browse files Browse the repository at this point in the history
see the comments in `fusion-test/Canary.hs` for an outline.

Currently, “only” 71 tests are defined; there is more tedious work to be
done if this approach is found to be good.

This words towards fixing haskell#229.
  • Loading branch information
nomeata committed Nov 17, 2018
1 parent cc06420 commit 7a0f68a
Show file tree
Hide file tree
Showing 3 changed files with 389 additions and 0 deletions.
120 changes: 120 additions & 0 deletions fusion-test/Canary.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
{- |
This module provide the function 'fuseHere', which can be inserted into a
pipeline of vector-processing functions. It also contains copies of all the
fusion-related rewrite rules from "Data.Vector.Generic", with 'fuseHere'
inserted. This way, if fusion happens at this point, the 'fuseHere' function
disappears.
Having to maintain a complete copy of all the rewrite rules is a big downsid of
this approach, and a better way would be appreciated.
-}
module Canary (fuseHere) where

import qualified Data.Vector.Generic as V
import qualified Data.Vector.Generic.New as New
import Data.Vector.Fusion.Stream.Monadic ( Stream )
import qualified Data.Vector.Fusion.Bundle as Bundle
import Data.Vector.Fusion.Bundle ( Bundle, MBundle, lift, inplace )
import qualified Data.Vector.Fusion.Bundle.Monadic as MBundle

-- | Put this function into vector pipelines where you want them to fuse
fuseHere :: a -> a
fuseHere = id
{-# NOINLINE fuseHere #-}

-- | We need to copy all fusion rules here, with fuseHere inserted in the right
-- spot.

{-# RULES

"(!)/fuseHere/unstream [Vector]" forall i s.
fuseHere (V.new (New.unstream s)) V.! i = s Bundle.!! i

"(!?)/fuseHere/unstream [Vector]" forall i s.
fuseHere (V.new (New.unstream s)) V.!? i = s Bundle.!? i

"head/fuseHere/unstream [Vector]" forall s.
V.head (fuseHere (V.new (New.unstream s))) = Bundle.head s

"last/fuseHere/unstream [Vector]" forall s.
V.last (fuseHere (V.new (New.unstream s))) = Bundle.last s

"unsafeIndex/fuseHere/unstream [Vector]" forall i s.
V.unsafeIndex (fuseHere (V.new (New.unstream s))) i = s Bundle.!! i

"unsafeHead/fuseHere/unstream [Vector]" forall s.
V.unsafeHead (fuseHere (V.new (New.unstream s))) = Bundle.head s

"unsafeLast/fuseHere/unstream [Vector]" forall s.
V.unsafeLast (fuseHere (V.new (New.unstream s))) = Bundle.last s #-}

{-# RULES

"indexM/fuseHere/unstream [Vector]" forall s i.
V.indexM (fuseHere (V.new (New.unstream s))) i = lift s MBundle.!! i

"headM/fuseHere/unstream [Vector]" forall s.
V.headM (fuseHere (V.new (New.unstream s))) = MBundle.head (lift s)

"lastM/fuseHere/unstream [Vector]" forall s.
V.lastM (fuseHere (V.new (New.unstream s))) = MBundle.last (lift s)

"unsafeIndexM/fuseHere/unstream [Vector]" forall s i.
V.unsafeIndexM (fuseHere (V.new (New.unstream s))) i = lift s MBundle.!! i

"unsafeHeadM/fuseHere/unstream [Vector]" forall s.
V.unsafeHeadM (fuseHere (V.new (New.unstream s))) = MBundle.head (lift s)

"unsafeLastM/fuseHere/unstream [Vector]" forall s.
V.unsafeLastM (fuseHere (V.new (New.unstream s))) = MBundle.last (lift s) #-}

{-# RULES

"slice/fuseHere/new [Vector]" forall i n p.
V.slice i n (fuseHere (V.new p)) = V.new (New.slice i n p)

"init/fuseHere/new [Vector]" forall p.
V.init (fuseHere (V.new p)) = V.new (New.init p)

"tail/fuseHere/new [Vector]" forall p.
V.tail (fuseHere (V.new p)) = V.new (New.tail p)

"take/fuseHere/new [Vector]" forall n p.
V.take n (fuseHere (V.new p)) = V.new (New.take n p)

"drop/fuseHere/new [Vector]" forall n p.
V.drop n (fuseHere (V.new p)) = V.new (New.drop n p)

"unsafeSlice/fuseHere/new [Vector]" forall i n p.
V.unsafeSlice i n (fuseHere (V.new p)) = V.new (New.unsafeSlice i n p)

"unsafeInit/fuseHere/new [Vector]" forall p.
V.unsafeInit (fuseHere (V.new p)) = V.new (New.unsafeInit p)

"unsafeTail/fuseHere/new [Vector]" forall p.
V.unsafeTail (fuseHere (V.new p)) = V.new (New.unsafeTail p) #-}


{-# RULES

"stream/fuseHere/unstream [Vector]" forall s.
V.stream (fuseHere (V.new (New.unstream s))) = s

"New.unstream/fuseHere/stream [Vector]" forall v.
New.unstream (fuseHere (V.stream v)) = V.clone v

"clone/fuseHere/new [Vector]" forall p.
V.clone (fuseHere (V.new p)) = p

"inplace [Vector]"
forall (f :: forall m. Monad m => Stream m a -> Stream m a) g m.
New.unstream (inplace f g (V.stream (V.new m))) = New.transform f g m

"uninplace [Vector]"
forall (f :: forall m. Monad m => Stream m a -> Stream m a) g m.
V.stream (V.new (New.transform f g m)) = inplace f g (V.stream (V.new m))
#-}


258 changes: 258 additions & 0 deletions fusion-test/Main.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fplugin=Test.Inspection.Plugin #-}

{-
Fusion tests for vector
=======================
This file tests whether fusion happens the way we want it.
It does so by creating a typical pipeline with a function. For example, to test
`map`, we write
test_map f = V.toList . V.map f . V.fromList
Actually, we insert the `fuseHere` function, from the Canary module, in the
spots where we expect fusion to happen:
test_map f = V.toList . fuseHere . V.map f . fuseHere . V.fromList
The `fuseHere` function has rules included that make it disappear if fusion
happens at this spot. See Canary.hs for more details.
So if `fuseHere` remains in the code, fusion did not happen as expected. We
check this using the `inspection-testing` library.
Even if `fuseHere` disappeared, which means that for example the
`unstream/stream` rule fired, we want to check that all of the constructors
of the `Step` data type (`Skip`, `Done`, `Yield`) have been eliminiated. We
test this again using `inspection-testing`.
-}

module Main where

import qualified Data.Vector as V
import Canary
import Test.Inspection
import Data.Vector.Fusion.Stream.Monadic (Step(..))
import Control.Monad
import qualified Language.Haskell.TH as TH

main :: IO ()
main = return ()

-- Testing pipelines for producers, transformers, consumers
p f = V.toList . fuseHere . f
t f = V.toList . fuseHere . f . fuseHere . V.fromList
c f = f . fuseHere . V.fromList
{-# INLINE t #-}
{-# INLINE c #-}
{-# INLINE p #-}

-- To help the type checker, avoid ambiguous Monad ctraints
inIO :: IO a -> IO a
inIO = id

-- Length information

test_length = c V.length
test_null = c V.null

-- Indexing

test_bang = c (V.! 42)
test_safe_bang = c (V.!? 42)
test_head = c V.head
test_last = c V.last
test_unsafeIndex = c (`V.unsafeIndex` 42)
test_unsafeHead = c V.unsafeHead
test_unsafeLast = c V.unsafeLast

-- Monadic Indexing

test_indexM = inIO . c (`V.indexM` 42)
test_headM = inIO . c V.headM
test_lastM = inIO . c V.lastM
test_unsafeIndexM = inIO . c (`V.unsafeIndexM` 42)
test_unsafeHeadM = inIO . c V.unsafeHeadM
test_unsafeLastM = inIO . c V.unsafeLastM

-- Extracting subvectors (slicing)
test_slice = t (V.slice 23 42)
test_init = t V.init
test_tail = t V.tail
test_take = t (V.take 42)
test_drop = t (V.drop 42)
-- splitAt: hard to test
test_unsafeSlice = t (V.unsafeSlice 23 42)
test_unsafeInit = t V.unsafeInit
test_unsafeTail = t V.unsafeTail
test_unsafeTake = t (V.unsafeTake 42)
-- Does not actually fuse
-- test_unsafeDrop = t (V.unsafeDrop 42)

-- Initialisation

-- Does not fuse, as the ctant expression floats out
--test_empty = p (\(_::()) -> V.empty)
test_singleton = p V.singleton
test_replicate = p (V.replicate 5)
test_generate = p (V.generate 5)
test_iterateN = p (V.iterateN 5 (+1))

-- Monadic Initialisation
--
-- These don't fuse (no rules for unstreamM)

-- Unfolding

test_unfoldr x = p (V.unfoldr x)
test_unfoldrN x = p (V.unfoldrN 42 x)
-- ctructN and ctructrN cannot fuse

-- Enumeration

test_enumFromN (x::Integer) = p (V.enumFromN x)
test_enumFromStepN (x::Integer) y = p (V.enumFromStepN x y)
test_enumFromTo (x::Integer) = p (V.enumFromTo x)
test_enumFromThenTo (x::Integer) y = p (V.enumFromThenTo x y)

-- Concatenation

test_cons x = t (V.cons x)
test_snoc x = t (`V.snoc` x)
test_append_r x = t (x V.++)
test_append_l x = t (V.++ x)
test_concat = p V.concat

-- Bulk updates

-- bulk updates fuse as a consumers, but not as a producer
test_upd x = c (V.// x)
test_update_l x = c (`V.update` x)
test_update_r x = c (x `V.update`)
test_update__1 y z = c (\x -> V.update_ x y z)
test_update__2 y z = c (\x -> V.update_ y x z)
test_update__3 y z = c (\x -> V.update_ y z x)
test_unsafeUpd x = c (`V.unsafeUpd` x)
test_unsafeUpdate_l x = c (`V.unsafeUpdate` x)
test_unsafeUpdate_r x = c (x `V.unsafeUpdate`)
test_unsafeUpdate__1 y z = c (\x -> V.unsafeUpdate_ x y z)
test_unsafeUpdate__2 y z = c (\x -> V.unsafeUpdate_ y x z)
test_unsafeUpdate__3 y z = c (\x -> V.unsafeUpdate_ y z x)

-- Accumulations

test_accum f y = c (\x -> V.accum f x y)
test_accumulate_l f y = c (\x -> V.accumulate f x y)
test_accumulate_r f y = c (\x -> V.accumulate f y x)
test_accumulate__1 f y z = c (\x -> V.accumulate_ f x y z)
test_accumulate__2 f y z = c (\x -> V.accumulate_ f y x z)
test_accumulate__3 f y z = c (\x -> V.accumulate_ f y z x)
test_unsafeAccum f y = c (\x -> V.unsafeAccum f x y)
test_unsafeAccumulate_l f y = c (\x -> V.unsafeAccumulate f x y)
test_unsafeAccumulate_r f y = c (\x -> V.unsafeAccumulate f y x)
test_unsafeAccumulate__1 f y z = c (\x -> V.unsafeAccumulate_ f x y z)
test_unsafeAccumulate__2 f y z = c (\x -> V.unsafeAccumulate_ f y x z)
test_unsafeAccumulate__3 f y z = c (\x -> V.unsafeAccumulate_ f y z x)

-- Permutations

-- reverse is only a good producer, not a good consumer
test_reverse = p V.reverse
-- backpermute is only a good consumer in the second argument,
-- but not the first and not a good producer
test_backpermute y = c (V.backpermute y)
test_unsafeBackpermute y = c (V.unsafeBackpermute y)

-- Elementwise operations

test_indexed = t V.indexed
test_map f = t (V.map f)
test_imap f = t (V.imap f)
test_concatMap f = t (V.concatMap f)
-- No deep fusion?
-- test_concatMap_deep f = t (V.concatMap (\ x -> fuseHere (f x)))

-- ... much more to come ...

fmap (concat . reverse)$ forM
[ 'test_head
, 'test_null
, 'test_bang
, 'test_safe_bang
, 'test_head
, 'test_last
, 'test_unsafeIndex
, 'test_unsafeHead
, 'test_unsafeLast
, 'test_indexM
, 'test_headM
, 'test_lastM
, 'test_unsafeIndexM
, 'test_unsafeHeadM
, 'test_unsafeLastM
, 'test_slice
, 'test_init
, 'test_tail
, 'test_take
, 'test_drop
, 'test_unsafeSlice
, 'test_unsafeInit
, 'test_unsafeTail
, 'test_unsafeTake
-- , 'test_unsafeDrop
-- , 'test_empty
, 'test_singleton
, 'test_replicate
, 'test_generate
, 'test_iterateN
, 'test_unfoldr
, 'test_unfoldrN
, 'test_enumFromN
, 'test_enumFromStepN
, 'test_enumFromTo
, 'test_enumFromThenTo
, 'test_cons
, 'test_cons
, 'test_snoc
, 'test_append_r
, 'test_append_l
, 'test_concat
, 'test_upd
, 'test_update_l
, 'test_update_r
, 'test_update__1
, 'test_update__2
, 'test_update__3
, 'test_unsafeUpd
, 'test_unsafeUpdate_l
, 'test_unsafeUpdate_r
, 'test_unsafeUpdate__1
, 'test_unsafeUpdate__2
, 'test_unsafeUpdate__3
, 'test_accum
, 'test_accumulate_l
, 'test_accumulate_r
, 'test_accumulate__1
, 'test_accumulate__2
, 'test_accumulate__3
, 'test_unsafeAccum
, 'test_unsafeAccumulate_l
, 'test_unsafeAccumulate_r
, 'test_unsafeAccumulate__1
, 'test_unsafeAccumulate__2
, 'test_unsafeAccumulate__3
, 'test_reverse
, 'test_backpermute
, 'test_unsafeBackpermute
, 'test_indexed
, 'test_map
, 'test_imap
, 'test_concatMap
] $ \thn -> inspect
(mkObligation thn (NoUseOf ['fuseHere, 'Yield, 'Skip, 'Done]))
{ testName = Just (TH.nameBase thn) }

11 changes: 11 additions & 0 deletions vector.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,14 @@ test-suite vector-tests-O2
if impl(ghc >= 8.0) && impl(ghc < 8.1)
Ghc-Options: -Wno-redundant-constraints

test-suite fusion-tests
Default-Language: Haskell2010
type: exitcode-stdio-1.0
Main-Is: Main.hs

other-modules: Canary
hs-source-dirs: fusion-test
Build-Depends: base >= 4.5 && < 5, vector, template-haskell,
inspection-testing >= 0.4.1

Ghc-Options: -dsuppress-all

0 comments on commit 7a0f68a

Please sign in to comment.