Skip to content

Commit fb9dcb1

Browse files
authored
Merge pull request #171 from haskell/lehins/use-array-for-shuffle
Implement a faster and unbiased version of list shuffling
2 parents 6b30bd9 + a79f427 commit fb9dcb1

File tree

7 files changed

+223
-37
lines changed

7 files changed

+223
-37
lines changed

CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
* Add `Seed`, `SeedGen`, `seedSize`, `mkSeed` and `unSeed`:
55
[#162](https://github.com/haskell/random/pull/162)
66
* Add `SplitGen` and `splitGen`: [#160](https://github.com/haskell/random/pull/160)
7-
* Add `shuffleList` and `shuffleListM`: [#140](https://github.com/haskell/random/pull/140)
7+
* Add `uniformShuffleList` and `uniformShuffleListM`: [#140](https://github.com/haskell/random/pull/140)
8+
* Add `uniformWordR`: [#140](https://github.com/haskell/random/pull/140)
89
* Add `mkStdGen64`: [#155](https://github.com/haskell/random/pull/155)
910
* Add `uniformListRM`, `uniformList`, `uniformListR`, `uniforms` and `uniformRs`:
1011
[#154](https://github.com/haskell/random/pull/154)

bench/Main.hs

+17-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ module Main (main) where
77
import Control.Monad
88
import Control.Monad.State.Strict
99
import Data.Int
10+
import Data.List (sortOn)
1011
import Data.Proxy
1112
import Data.Typeable
1213
import Data.Word
@@ -263,9 +264,15 @@ main = do
263264
, env getStdGen $ \gen ->
264265
bench "uniformByteArray 100MB" $ nf (\n -> uniformByteArray False n gen) sz100MiB
265266
, env getStdGen $ \gen ->
266-
bench "genByteString 100MB" $ nf (\k -> genByteString k gen) sz100MiB
267+
bench "genByteString 100MB" $ nf (`genByteString` gen) sz100MiB
267268
]
268269
]
270+
, env (pure [0 :: Integer .. 200000]) $ \xs ->
271+
bgroup "shuffle"
272+
[ env getStdGen $ bench "uniformShuffleList" . nf (uniformShuffleList xs)
273+
, env getStdGen $ bench "uniformShuffleListM" . nf (`runStateGen` uniformShuffleListM xs)
274+
, env getStdGen $ bench "naiveShuffleListM" . nf (`runStateGen` naiveShuffleListM xs)
275+
]
269276
]
270277

271278
pureUniformRFullBench ::
@@ -351,3 +358,12 @@ fillMutablePrimArrayM f ma g = do
351358
go 0
352359
unsafeFreezePrimArray ma
353360
#endif
361+
362+
363+
naiveShuffleListM :: StatefulGen g m => [a] -> g -> m [a]
364+
naiveShuffleListM xs gen = do
365+
is <- uniformListM n gen
366+
pure $ map snd $ sortOn fst $ zip (is :: [Int]) xs
367+
where
368+
!n = length xs
369+
{-# INLINE naiveShuffleListM #-}

src/System/Random.hs

+9-8
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ module System.Random
4545
, uniformRs
4646
, uniformList
4747
, uniformListR
48-
, shuffleList
48+
, uniformShuffleList
4949
-- ** Bytes
5050
, uniformByteArray
5151
, uniformByteString
@@ -94,6 +94,7 @@ import Data.IORef
9494
import Data.Word
9595
import Foreign.C.Types
9696
import GHC.Exts
97+
import System.Random.Array (shuffleListST)
9798
import System.Random.GFinite (Finite)
9899
import System.Random.Internal
99100
import System.Random.Seed
@@ -294,18 +295,18 @@ uniformListR :: (UniformRange a, RandomGen g) => Int -> (a, a) -> g -> ([a], g)
294295
uniformListR n r g = runStateGen g (uniformListRM n r)
295296
{-# INLINE uniformListR #-}
296297

297-
-- | Shuffle elements of a list in a random order.
298+
-- | Shuffle elements of a list in a uniformly random order.
298299
--
299300
-- ====__Examples__
300301
--
301-
-- >>> let gen = mkStdGen 2023
302-
-- >>> shuffleList ['a'..'z'] gen
303-
-- ("renlhfqmgptwksdiyavbxojzcu",StdGen {unStdGen = SMGen 9882508430712573120 1920468677557965761})
302+
-- >>> uniformShuffleList "ELVIS" $ mkStdGen 252
303+
-- ("LIVES",StdGen {unStdGen = SMGen 17676540583805057877 5302934877338729551})
304304
--
305305
-- @since 1.3.0
306-
shuffleList :: RandomGen g => [a] -> g -> ([a], g)
307-
shuffleList xs g = runStateGen g (shuffleListM xs)
308-
{-# INLINE shuffleList #-}
306+
uniformShuffleList :: RandomGen g => [a] -> g -> ([a], g)
307+
uniformShuffleList xs g =
308+
runStateGenST g $ \gen -> shuffleListST (`uniformWordR` gen) xs
309+
{-# INLINE uniformShuffleList #-}
309310

310311
-- | Generates a 'ByteString' of the specified size using a pure pseudo-random
311312
-- number generator. See 'uniformByteStringM' for the monadic version.

src/System/Random/Array.hs

+156
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,17 @@ module System.Random.Array
2828
, byteArrayToShortByteString
2929
, getSizeOfMutableByteArray
3030
, shortByteStringToByteString
31+
-- ** MutableArray
32+
, Array (..)
33+
, MutableArray (..)
34+
, newMutableArray
35+
, freezeMutableArray
36+
, writeArray
37+
, shuffleListM
38+
, shuffleListST
3139
) where
3240

41+
import Control.Monad.Trans (lift, MonadTrans)
3342
import Control.Monad (when)
3443
import Control.Monad.ST
3544
import Data.Array.Byte (ByteArray(..), MutableByteArray(..))
@@ -54,6 +63,10 @@ import Data.ByteString (ByteString)
5463
wordSizeInBits :: Int
5564
wordSizeInBits = finiteBitSize (0 :: Word)
5665

66+
----------------
67+
-- Byte Array --
68+
----------------
69+
5770
-- Architecture independent helpers:
5871

5972
sizeOfByteArray :: ByteArray -> Int
@@ -204,3 +217,146 @@ pinnedByteArrayToForeignPtr ba# =
204217
ForeignPtr (byteArrayContents# ba#) (PlainPtr (unsafeCoerce# ba#))
205218
{-# INLINE pinnedByteArrayToForeignPtr #-}
206219
#endif
220+
221+
-----------------
222+
-- Boxed Array --
223+
-----------------
224+
225+
data Array a = Array (Array# a)
226+
227+
data MutableArray s a = MutableArray (MutableArray# s a)
228+
229+
newMutableArray :: Int -> a -> ST s (MutableArray s a)
230+
newMutableArray (I# n#) a =
231+
ST $ \s# ->
232+
case newArray# n# a s# of
233+
(# s'#, ma# #) -> (# s'#, MutableArray ma# #)
234+
{-# INLINE newMutableArray #-}
235+
236+
freezeMutableArray :: MutableArray s a -> ST s (Array a)
237+
freezeMutableArray (MutableArray ma#) =
238+
ST $ \s# ->
239+
case unsafeFreezeArray# ma# s# of
240+
(# s'#, a# #) -> (# s'#, Array a# #)
241+
{-# INLINE freezeMutableArray #-}
242+
243+
sizeOfMutableArray :: MutableArray s a -> Int
244+
sizeOfMutableArray (MutableArray ma#) = I# (sizeofMutableArray# ma#)
245+
{-# INLINE sizeOfMutableArray #-}
246+
247+
readArray :: MutableArray s a -> Int -> ST s a
248+
readArray (MutableArray ma#) (I# i#) = ST (readArray# ma# i#)
249+
{-# INLINE readArray #-}
250+
251+
writeArray :: MutableArray s a -> Int -> a -> ST s ()
252+
writeArray (MutableArray ma#) (I# i#) a = st_ (writeArray# ma# i# a)
253+
{-# INLINE writeArray #-}
254+
255+
swapArray :: MutableArray s a -> Int -> Int -> ST s ()
256+
swapArray ma i j = do
257+
x <- readArray ma i
258+
y <- readArray ma j
259+
writeArray ma j x
260+
writeArray ma i y
261+
{-# INLINE swapArray #-}
262+
263+
-- | Write contents of the list into the mutable array. Make sure that the array is big
264+
-- enough to hold every element of the list, otherwise a segfault will happen.
265+
fillMutableArrayFromList :: MutableArray s a -> [a] -> ST s ()
266+
fillMutableArrayFromList ma = go 0
267+
where
268+
go _ [] = pure ()
269+
go i (x:xs) = writeArray ma i x >> go (i + 1) xs
270+
{-# INLINE fillMutableArrayFromList #-}
271+
272+
readListFromMutableArray :: MutableArray s a -> ST s [a]
273+
readListFromMutableArray ma = go (len - 1) []
274+
where
275+
len = sizeOfMutableArray ma
276+
go i !acc
277+
| i >= 0 = do
278+
x <- readArray ma i
279+
go (i - 1) (x : acc)
280+
| otherwise = pure acc
281+
{-# INLINE readListFromMutableArray #-}
282+
283+
284+
-- | Generate a list of indices that will be used for swapping elements in uniform shuffling. The indices are drawn, in order, from these ranges:
285+
--
286+
-- @
287+
-- [ (0, n - 1)
288+
-- , (0, n - 2)
289+
-- , (0, n - 3)
290+
-- , ...
291+
-- , (0, 3)
292+
-- , (0, 2)
293+
-- , (0, 1)
294+
-- ]
295+
-- @
296+
genSwapIndices
297+
:: Monad m
298+
=> (Word -> m Word)
299+
-- ^ Action that generates a Word in the supplied range.
300+
-> Word
301+
-- ^ Number of elements @n@ to be shuffled; @n - 1@ swap indices are generated (none when @n <= 1@).
302+
-> m [Int]
303+
genSwapIndices genWordR n = go 1 []
304+
where
305+
go i !acc
306+
| i >= n = pure acc
307+
| otherwise = do
308+
x <- genWordR i
309+
let !xi = fromIntegral x
310+
go (i + 1) (xi : acc)
311+
{-# INLINE genSwapIndices #-}
312+
313+
314+
-- | Implementation of a mutable version of the Fisher-Yates shuffle. Unfortunately, we cannot generally
315+
-- interleave pseudo-random number generation with mutation in the `ST` monad, therefore we have to
316+
-- pre-generate all of the index swaps with `genSwapIndices` and store them in a list before we can
317+
-- perform the actual swaps.
318+
shuffleListM :: Monad m => (Word -> m Word) -> [a] -> m [a]
319+
shuffleListM genWordR ls
320+
| len <= 1 = pure ls
321+
| otherwise = do
322+
swapIxs <- genSwapIndices genWordR (fromIntegral len)
323+
pure $ runST $ do
324+
ma <- newMutableArray len $ error "Impossible: shuffleListM"
325+
fillMutableArrayFromList ma ls
326+
327+
-- Shuffle elements of the mutable array according to the uniformly generated index swap list
328+
let goSwap _ [] = pure ()
329+
goSwap i (j:js) = swapArray ma i j >> goSwap (i - 1) js
330+
goSwap (len - 1) swapIxs
331+
332+
readListFromMutableArray ma
333+
where
334+
len = length ls
335+
{-# INLINE shuffleListM #-}
336+
337+
-- | This is a roughly 2-3 times more efficient version of `shuffleListM`. It is more efficient because it does
338+
-- not need to pregenerate a list of indices and instead generates them on demand. Because of this the
339+
-- result that will be produced will differ for the same generator, since the order in which index
340+
-- swaps are generated is reversed.
341+
--
342+
-- Unfortunately, most stateful generator monads can't handle `MonadTrans`, so this version is only
343+
-- used for implementing the pure shuffle.
344+
shuffleListST :: (Monad (t (ST s)), MonadTrans t) => (Word -> t (ST s) Word) -> [a] -> t (ST s) [a]
345+
shuffleListST genWordR ls
346+
| len <= 1 = pure ls
347+
| otherwise = do
348+
ma <- lift $ newMutableArray len $ error "Impossible: shuffleListST"
349+
lift $ fillMutableArrayFromList ma ls
350+
351+
-- Shuffle elements of the mutable array according to the uniformly generated index swap
352+
let goSwap i =
353+
when (i > 0) $ do
354+
j <- genWordR $ (fromIntegral :: Int -> Word) i
355+
lift $ swapArray ma i ((fromIntegral :: Word -> Int) j)
356+
goSwap (i - 1)
357+
goSwap (len - 1)
358+
359+
lift $ readListFromMutableArray ma
360+
where
361+
len = length ls
362+
{-# INLINE shuffleListST #-}

src/System/Random/Internal.hs

+20-24
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ module System.Random.Internal
5757
, Uniform(..)
5858
, uniformViaFiniteM
5959
, UniformRange(..)
60+
, uniformWordR
6061
, uniformDouble01M
6162
, uniformDoublePositive01M
6263
, uniformFloat01M
@@ -65,7 +66,6 @@ module System.Random.Internal
6566
, uniformEnumRM
6667
, uniformListM
6768
, uniformListRM
68-
, shuffleListM
6969
, isInRangeOrd
7070
, isInRangeEnum
7171

@@ -108,7 +108,6 @@ import Data.ByteString (ByteString)
108108
import Data.ByteString.Short.Internal (ShortByteString(SBS))
109109
import Data.IORef (IORef, newIORef)
110110
import Data.Int
111-
import Data.List (sortOn)
112111
import Data.Word
113112
import Foreign.C.Types
114113
import Foreign.Storable (Storable)
@@ -221,7 +220,6 @@ class RandomGen g where
221220
-- /Note/ - This function will be removed from the type class in the next major release as
222221
-- it is no longer needed because of `unsafeUniformFillMutableByteArray`.
223222
--
224-
--
225223
-- @since 1.2.0
226224
genShortByteString :: Int -> g -> (ShortByteString, g)
227225
genShortByteString n g =
@@ -273,10 +271,10 @@ class RandomGen g where
273271
{-# DEPRECATED split "In favor of `splitGen`" #-}
274272

275273
-- | Pseudo-random generators that can be split into two separate and independent
276-
-- psuedo-random generators can have an instance for this type class.
274+
-- pseudo-random generators should provide an instance for this type class.
277275
--
278276
-- Historically this functionality was included in the `RandomGen` type class in the
279-
-- `split` function, however, few pseudo-random generators posses this property of
277+
-- `split` function, however, few pseudo-random generators possess this property of
280278
-- splittability. This led to the old `split` function usually being implemented in terms of
281279
-- `error`.
282280
--
@@ -784,25 +782,6 @@ uniformListRM :: (StatefulGen g m, UniformRange a) => Int -> (a, a) -> g -> m [a
784782
uniformListRM n range gen = replicateM n (uniformRM range gen)
785783
{-# INLINE uniformListRM #-}
786784

787-
-- | Shuffle elements of a list in a random order.
788-
--
789-
-- ====__Examples__
790-
--
791-
-- >>> import System.Random.Stateful
792-
-- >>> let pureGen = mkStdGen 2023
793-
-- >>> g <- newIOGenM pureGen
794-
-- >>> shuffleListM ['a'..'z'] g :: IO String
795-
-- "renlhfqmgptwksdiyavbxojzcu"
796-
--
797-
-- @since 1.3.0
798-
shuffleListM :: StatefulGen g m => [a] -> g -> m [a]
799-
shuffleListM xs gen = do
800-
is <- uniformListM n gen
801-
pure $ map snd $ sortOn fst $ zip (is :: [Int]) xs
802-
where
803-
!n = length xs
804-
{-# INLINE shuffleListM #-}
805-
806785
-- | The standard pseudo-random number generator.
807786
newtype StdGen = StdGen { unStdGen :: SM.SMGen }
808787
deriving (Show, RandomGen, SplitGen, NFData)
@@ -1128,6 +1107,23 @@ instance UniformRange Word where
11281107
{-# INLINE uniformRM #-}
11291108
isInRange = isInRangeOrd
11301109

1110+
-- | Architecture-specific generation of a `Word` that is no greater than the supplied maximum value
1111+
--
1112+
-- @since 1.3.0
1113+
uniformWordR ::
1114+
StatefulGen g m
1115+
=> Word
1116+
-- ^ Maximum value to generate
1117+
-> g
1118+
-- ^ Stateful generator
1119+
-> m Word
1120+
uniformWordR r
1121+
| wordSizeInBits == 64 =
1122+
fmap (fromIntegral :: Word64 -> Word) . uniformWord64R ((fromIntegral :: Word -> Word64) r)
1123+
| otherwise =
1124+
fmap (fromIntegral :: Word32 -> Word) . uniformWord32R ((fromIntegral :: Word -> Word32) r)
1125+
{-# INLINE uniformWordR #-}
1126+
11311127
instance Uniform Word8 where
11321128
uniformM = uniformWord8
11331129
{-# INLINE uniformM #-}

0 commit comments

Comments
 (0)