Skip to content

Commit

Permalink
Portability fixes (#1469)
Browse files Browse the repository at this point in the history
- Fix portability of `choose_multiple_array`
- Fix portability of `rand::distributions::Slice`
  • Loading branch information
dhardy authored Jul 23, 2024
1 parent f3aab23 commit 605476c
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 18 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.

## [Unreleased]
- Add `rand::distributions::WeightedIndex::{weight, weights, total_weight}` (#1420)
- Add `IndexedRandom::choose_multiple_array`, `index::sample_array` (#1453)
- Add `IndexedRandom::choose_multiple_array`, `index::sample_array` (#1453, #1469)
- Bump the MSRV to 1.61.0
- Rename `Rng::gen` to `Rng::random` to avoid conflict with the new `gen` keyword in Rust 2024 (#1435)
- Move all benchmarks to new `benches` crate (#1439)
- Annotate panicking methods with `#[track_caller]` (#1442, #1447)
- Enable feature `small_rng` by default (#1455)
- Allow `UniformFloat::new` samples and `UniformFloat::sample_single` to yield `high` (#1462)
- Fix portability of `rand::distributions::Slice` (#1469)

## [0.9.0-alpha.1] - 2024-03-18
- Add the `Slice::num_choices` method to the Slice distribution (#1402)
Expand Down
9 changes: 1 addition & 8 deletions rand_distr/src/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,20 +333,13 @@ where
#[cfg(test)]
mod test {
use super::*;
use alloc::vec::Vec;

#[test]
fn test_dirichlet() {
let d = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
let mut rng = crate::test::rng(221);
let samples = d.sample(&mut rng);
let _: Vec<f64> = samples
.into_iter()
.map(|x| {
assert!(x > 0.0);
x
})
.collect();
assert!(samples.into_iter().all(|x: f64| x > 0.0));
}

#[test]
Expand Down
48 changes: 46 additions & 2 deletions src/distributions/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,39 @@
use core::num::NonZeroUsize;

use crate::distributions::{Distribution, Uniform};
use crate::Rng;
#[cfg(feature = "alloc")]
use alloc::string::String;

#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))]
compile_error!("unsupported pointer width");

#[derive(Debug, Clone, Copy)]
enum UniformUsize {
U32(Uniform<u32>),
#[cfg(target_pointer_width = "64")]
U64(Uniform<u64>),
}

impl UniformUsize {
pub fn new(ubound: usize) -> Result<Self, super::uniform::Error> {
#[cfg(target_pointer_width = "64")]
if ubound > (u32::MAX as usize) {
return Uniform::new(0, ubound as u64).map(UniformUsize::U64);
}

Uniform::new(0, ubound as u32).map(UniformUsize::U32)
}

pub fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
match self {
UniformUsize::U32(uu) => uu.sample(rng) as usize,
#[cfg(target_pointer_width = "64")]
UniformUsize::U64(uu) => uu.sample(rng) as usize,
}
}
}

/// A distribution to sample items uniformly from a slice.
///
/// [`Slice::new`] constructs a distribution referencing a slice and uniformly
Expand Down Expand Up @@ -68,7 +98,7 @@ use alloc::string::String;
#[derive(Debug, Clone, Copy)]
pub struct Slice<'a, T> {
slice: &'a [T],
range: Uniform<usize>,
range: UniformUsize,
num_choices: NonZeroUsize,
}

Expand All @@ -80,7 +110,7 @@ impl<'a, T> Slice<'a, T> {

Ok(Self {
slice,
range: Uniform::new(0, num_choices.get()).unwrap(),
range: UniformUsize::new(num_choices.get()).unwrap(),
num_choices,
})
}
Expand Down Expand Up @@ -161,3 +191,17 @@ impl<'a> super::DistString for Slice<'a, char> {
}
}
}

#[cfg(test)]
mod test {
use super::*;
use core::iter;

#[test]
fn value_stability() {
let rng = crate::test::rng(651);
let slice = Slice::new(b"escaped emus explore extensively").unwrap();
let expected = b"eaxee";
assert!(iter::zip(slice.sample_iter(rng), expected).all(|(a, b)| a == b));
}
}
2 changes: 2 additions & 0 deletions src/distributions/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ pub(crate) trait FloatSIMDScalarUtils: FloatSIMDUtils {

/// Implement functions on f32/f64 to give them APIs similar to SIMD types
pub(crate) trait FloatAsSIMD: Sized {
#[cfg(test)]
const LEN: usize = 1;

#[inline(always)]
fn splat(scalar: Self) -> Self {
scalar
Expand Down
5 changes: 4 additions & 1 deletion src/rngs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ pub mod mock; // Public so we don't export `StepRng` directly, making it a bit

#[cfg(feature = "small_rng")]
mod small;
#[cfg(all(feature = "small_rng", not(target_pointer_width = "64")))]
#[cfg(all(
feature = "small_rng",
any(target_pointer_width = "32", target_pointer_width = "16")
))]
mod xoshiro128plusplus;
#[cfg(all(feature = "small_rng", target_pointer_width = "64"))]
mod xoshiro256plusplus;
Expand Down
4 changes: 2 additions & 2 deletions src/rngs/small.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@

use rand_core::{RngCore, SeedableRng};

#[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))]
type Rng = super::xoshiro128plusplus::Xoshiro128PlusPlus;
#[cfg(target_pointer_width = "64")]
type Rng = super::xoshiro256plusplus::Xoshiro256PlusPlus;
#[cfg(not(target_pointer_width = "64"))]
type Rng = super::xoshiro128plusplus::Xoshiro128PlusPlus;

/// A small-state, fast, non-crypto, non-portable PRNG
///
Expand Down
3 changes: 2 additions & 1 deletion src/seq/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// except according to those terms.

//! Low-level API for sampling indices
use super::gen_index;
#[cfg(feature = "alloc")]
use alloc::vec::{self, Vec};
use core::slice;
Expand Down Expand Up @@ -288,7 +289,7 @@ where
// Floyd's algorithm
let mut indices = [0; N];
for (i, j) in (len - N..len).enumerate() {
let t = rng.gen_range(0..=j);
let t = gen_index(rng, j + 1);
if let Some(pos) = indices[0..i].iter().position(|&x| x == t) {
indices[pos] = j;
}
Expand Down
11 changes: 8 additions & 3 deletions src/seq/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,19 +495,24 @@ mod test {
assert_eq!(chars.choose(&mut r), Some(&'l'));
assert_eq!(nums.choose_mut(&mut r), Some(&mut 3));

assert_eq!(
&chars.choose_multiple_array(&mut r),
&Some(['f', 'i', 'd', 'b', 'c', 'm', 'j', 'k'])
);

#[cfg(feature = "alloc")]
assert_eq!(
&chars
.choose_multiple(&mut r, 8)
.cloned()
.collect::<Vec<char>>(),
&['f', 'i', 'd', 'b', 'c', 'm', 'j', 'k']
&['h', 'm', 'd', 'b', 'c', 'e', 'n', 'f']
);

#[cfg(feature = "alloc")]
assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'l'));
assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'i'));
#[cfg(feature = "alloc")]
assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 8));
assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 2));

let mut r = crate::test::rng(414);
nums.shuffle(&mut r);
Expand Down

0 comments on commit 605476c

Please sign in to comment.