@@ -28,8 +28,17 @@ module System.Random.Array
28
28
, byteArrayToShortByteString
29
29
, getSizeOfMutableByteArray
30
30
, shortByteStringToByteString
31
+ -- ** MutableArray
32
+ , Array (.. )
33
+ , MutableArray (.. )
34
+ , newMutableArray
35
+ , freezeMutableArray
36
+ , writeArray
37
+ , shuffleListM
38
+ , shuffleListST
31
39
) where
32
40
41
+ import Control.Monad.Trans (lift , MonadTrans )
33
42
import Control.Monad (when )
34
43
import Control.Monad.ST
35
44
import Data.Array.Byte (ByteArray (.. ), MutableByteArray (.. ))
@@ -54,6 +63,10 @@ import Data.ByteString (ByteString)
54
63
wordSizeInBits :: Int
55
64
wordSizeInBits = finiteBitSize (0 :: Word )
56
65
66
+ ----------------
67
+ -- Byte Array --
68
+ ----------------
69
+
57
70
-- Architecture independent helpers:
58
71
59
72
sizeOfByteArray :: ByteArray -> Int
@@ -204,3 +217,146 @@ pinnedByteArrayToForeignPtr ba# =
204
217
ForeignPtr (byteArrayContents# ba# ) (PlainPtr (unsafeCoerce# ba# ))
205
218
{-# INLINE pinnedByteArrayToForeignPtr #-}
206
219
#endif
220
+
221
+ -----------------
222
+ -- Boxed Array --
223
+ -----------------
224
+
225
+ data Array a = Array (Array # a )
226
+
227
+ data MutableArray s a = MutableArray (MutableArray # s a )
228
+
229
+ newMutableArray :: Int -> a -> ST s (MutableArray s a )
230
+ newMutableArray (I # n# ) a =
231
+ ST $ \ s# ->
232
+ case newArray# n# a s# of
233
+ (# s'# , ma# # ) -> (# s'# , MutableArray ma# # )
234
+ {-# INLINE newMutableArray #-}
235
+
236
+ freezeMutableArray :: MutableArray s a -> ST s (Array a )
237
+ freezeMutableArray (MutableArray ma# ) =
238
+ ST $ \ s# ->
239
+ case unsafeFreezeArray# ma# s# of
240
+ (# s'# , a# # ) -> (# s'# , Array a# # )
241
+ {-# INLINE freezeMutableArray #-}
242
+
243
+ sizeOfMutableArray :: MutableArray s a -> Int
244
+ sizeOfMutableArray (MutableArray ma# ) = I # (sizeofMutableArray# ma# )
245
+ {-# INLINE sizeOfMutableArray #-}
246
+
247
+ readArray :: MutableArray s a -> Int -> ST s a
248
+ readArray (MutableArray ma# ) (I # i# ) = ST (readArray# ma# i# )
249
+ {-# INLINE readArray #-}
250
+
251
+ writeArray :: MutableArray s a -> Int -> a -> ST s ()
252
+ writeArray (MutableArray ma# ) (I # i# ) a = st_ (writeArray# ma# i# a)
253
+ {-# INLINE writeArray #-}
254
+
255
+ swapArray :: MutableArray s a -> Int -> Int -> ST s ()
256
+ swapArray ma i j = do
257
+ x <- readArray ma i
258
+ y <- readArray ma j
259
+ writeArray ma j x
260
+ writeArray ma i y
261
+ {-# INLINE swapArray #-}
262
+
263
+ -- | Write contents of the list into the mutable array. Make sure that array is big
264
+ -- enough or segfault will happen.
265
+ fillMutableArrayFromList :: MutableArray s a -> [a ] -> ST s ()
266
+ fillMutableArrayFromList ma = go 0
267
+ where
268
+ go _ [] = pure ()
269
+ go i (x: xs) = writeArray ma i x >> go (i + 1 ) xs
270
+ {-# INLINE fillMutableArrayFromList #-}
271
+
272
+ readListFromMutableArray :: MutableArray s a -> ST s [a ]
273
+ readListFromMutableArray ma = go (len - 1 ) []
274
+ where
275
+ len = sizeOfMutableArray ma
276
+ go i ! acc
277
+ | i >= 0 = do
278
+ x <- readArray ma i
279
+ go (i - 1 ) (x : acc)
280
+ | otherwise = pure acc
281
+ {-# INLINE readListFromMutableArray #-}
282
+
283
+
284
+ -- | Generate a list of indices that will be used for swapping elements in uniform shuffling:
285
+ --
286
+ -- @
287
+ -- [ (0, n - 1)
288
+ -- , (0, n - 2)
289
+ -- , (0, n - 3)
290
+ -- , ...
291
+ -- , (0, 3)
292
+ -- , (0, 2)
293
+ -- , (0, 1)
294
+ -- ]
295
+ -- @
296
+ genSwapIndices
297
+ :: Monad m
298
+ => (Word -> m Word )
299
+ -- ^ Action that generates a Word in the supplied range.
300
+ -> Word
301
+ -- ^ Number of index swaps to generate.
302
+ -> m [Int ]
303
+ genSwapIndices genWordR n = go 1 []
304
+ where
305
+ go i ! acc
306
+ | i >= n = pure acc
307
+ | otherwise = do
308
+ x <- genWordR i
309
+ let ! xi = fromIntegral x
310
+ go (i + 1 ) (xi : acc)
311
+ {-# INLINE genSwapIndices #-}
312
+
313
+
314
+ -- | Implementation of mutable version of Fisher-Yates shuffle. Unfortunately, we cannot generally
315
+ -- interleave pseudo-random number generation and mutation of `ST` monad, therefore we have to
316
+ -- pre-generate all of the index swaps with `genSwapIndices` and store them in a list before we can
317
+ -- perform the actual swaps.
318
+ shuffleListM :: Monad m => (Word -> m Word ) -> [a ] -> m [a ]
319
+ shuffleListM genWordR ls
320
+ | len <= 1 = pure ls
321
+ | otherwise = do
322
+ swapIxs <- genSwapIndices genWordR (fromIntegral len)
323
+ pure $ runST $ do
324
+ ma <- newMutableArray len $ error " Impossible: shuffleListM"
325
+ fillMutableArrayFromList ma ls
326
+
327
+ -- Shuffle elements of the mutable array according to the uniformly generated index swap list
328
+ let goSwap _ [] = pure ()
329
+ goSwap i (j: js) = swapArray ma i j >> goSwap (i - 1 ) js
330
+ goSwap (len - 1 ) swapIxs
331
+
332
+ readListFromMutableArray ma
333
+ where
334
+ len = length ls
335
+ {-# INLINE shuffleListM #-}
336
+
337
+ -- | This is a ~x2-x3 more efficient version of `shuffleListM`. It is more efficient because it does
338
+ -- not need to pregenerate a list of indices and instead generates them on demand. Because of this the
339
+ -- result that will be produced will differ for the same generator, since the order in which index
340
+ -- swaps are generated is reversed.
341
+ --
342
+ -- Unfortunately, most stateful generator monads can't handle `MonadTrans`, so this version is only
343
+ -- used for implementing the pure shuffle.
344
+ shuffleListST :: (Monad (t (ST s )), MonadTrans t ) => (Word -> t (ST s ) Word ) -> [a ] -> t (ST s ) [a ]
345
+ shuffleListST genWordR ls
346
+ | len <= 1 = pure ls
347
+ | otherwise = do
348
+ ma <- lift $ newMutableArray len $ error " Impossible: shuffleListST"
349
+ lift $ fillMutableArrayFromList ma ls
350
+
351
+ -- Shuffle elements of the mutable array according to the uniformly generated index swap
352
+ let goSwap i =
353
+ when (i > 0 ) $ do
354
+ j <- genWordR $ (fromIntegral :: Int -> Word ) i
355
+ lift $ swapArray ma i ((fromIntegral :: Word -> Int ) j)
356
+ goSwap (i - 1 )
357
+ goSwap (len - 1 )
358
+
359
+ lift $ readListFromMutableArray ma
360
+ where
361
+ len = length ls
362
+ {-# INLINE shuffleListST #-}
0 commit comments