diff --git a/README.md b/README.md index 73c33223..7eb109d9 100644 --- a/README.md +++ b/README.md @@ -238,7 +238,9 @@ Supported search helpers: - [LastOr](#LastOr) - [Nth](#nth) - [Sample](#sample) +- [SampleBy](#sampleby) - [Samples](#samples) +- [SamplesBy](#samplesby) Conditional helpers: @@ -2494,6 +2496,21 @@ lo.Sample([]string{}) // "" ``` +### SampleBy + +Returns a random item from collection, using a given random integer generator. + +```go +import "math/rand" + +r := rand.New(rand.NewSource(42)) +lo.SampleBy([]string{"a", "b", "c"}, r.Intn) +// a random string from []string{"a", "b", "c"}, using a seeded random generator + +lo.SampleBy([]string{}, r.Intn) +// "" +``` + ### Samples Returns N random unique items from collection. @@ -2503,6 +2520,16 @@ lo.Samples([]string{"a", "b", "c"}, 3) // []string{"a", "b", "c"} in random order ``` +### SamplesBy + +Returns N random unique items from collection, using a given random integer generator. + +```go +r := rand.New(rand.NewSource(42)) +lo.SamplesBy([]string{"a", "b", "c"}, 3, r.Intn) +// []string{"a", "b", "c"} in random order, using a seeded random generator +``` + ### Ternary A 1 line if/else statement. diff --git a/find.go b/find.go index 59c23460..8e80156b 100644 --- a/find.go +++ b/find.go @@ -473,18 +473,33 @@ func Nth[T any, N constraints.Integer](collection []T, nth N) (T, error) { return collection[l+n], nil } +// randomIntGenerator is a function that should return a random integer in the range [0, n) +// where n is the parameter passed to the randomIntGenerator. +type randomIntGenerator func(n int) int + // Sample returns a random item from collection. func Sample[T any](collection []T) T { + result := SampleBy(collection, rand.IntN) + return result +} + +// SampleBy returns a random item from collection, using randomIntGenerator as the random index generator. +func SampleBy[T any](collection []T, randomIntGenerator randomIntGenerator) T { size := len(collection) if size == 0 { return Empty[T]() } - - return collection[rand.IntN(size)] + return collection[randomIntGenerator(size)] } // Samples returns N random unique items from collection. func Samples[T any, Slice ~[]T](collection Slice, count int) Slice { + results := SamplesBy(collection, count, rand.IntN) + return results +} + +// SamplesBy returns N random unique items from collection, using randomIntGenerator as the random index generator. +func SamplesBy[T any, Slice ~[]T](collection Slice, count int, randomIntGenerator randomIntGenerator) Slice { size := len(collection) copy := append(Slice{}, collection...) @@ -494,7 +509,7 @@ func Samples[T any, Slice ~[]T](collection Slice, count int) Slice { for i := 0; i < size && i < count; i++ { copyLength := size - i - index := rand.IntN(size - i) + index := randomIntGenerator(size - i) results = append(results, copy[index]) // Removes element. diff --git a/find_test.go b/find_test.go index b1533997..f52bb6be 100644 --- a/find_test.go +++ b/find_test.go @@ -551,6 +551,19 @@ func TestSample(t *testing.T) { is.Equal(result2, "") } +func TestSampleBy(t *testing.T) { + t.Parallel() + is := assert.New(t) + + r := rand.New(rand.NewSource(42)) + + result1 := SampleBy([]string{"a", "b", "c"}, r.Intn) + result2 := SampleBy([]string{}, rand.Intn) + + is.True(Contains([]string{"a", "b", "c"}, result1)) + is.Equal(result2, "") +} + func TestSamples(t *testing.T) { t.Parallel() is := assert.New(t) @@ -570,3 +583,23 @@ func TestSamples(t *testing.T) { nonempty := Samples(allStrings, 2) is.IsType(nonempty, allStrings, "type preserved") } + +func TestSamplesBy(t *testing.T) { + t.Parallel() + is := assert.New(t) + + r := rand.New(rand.NewSource(42)) + + result1 := SamplesBy([]string{"a", "b", "c"}, 3, r.Intn) + result2 := SamplesBy([]string{}, 3, r.Intn) + + sort.Strings(result1) + + is.Equal(result1, []string{"a", "b", "c"}) + is.Equal(result2, []string{}) + + type myStrings []string + allStrings := myStrings{"", "foo", "bar"} + nonempty := SamplesBy(allStrings, 2, r.Intn) + is.IsType(nonempty, allStrings, "type preserved") +}