Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(iter): Added context accepting variants of Map & ForEach #114

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
57 changes: 49 additions & 8 deletions iter/iter.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package iter

import (
"context"
"runtime"
"sync/atomic"

"github.com/sourcegraph/conc"
"github.com/sourcegraph/conc/pool"
)

// defaultMaxGoroutines returns the default maximum number of
Expand Down Expand Up @@ -57,29 +58,69 @@ func ForEachIdx[T any](input []T, f func(int, *T)) { Iterator[T]{}.ForEachIdx(in
// ForEachIdx is the same as ForEach except it also provides the
// index of the element to the callback.
func (iter Iterator[T]) ForEachIdx(input []T, f func(int, *T)) {
_ = iter.ForEachIdxCtx(context.Background(), input, func(_ context.Context, idx int, input *T) error {
f(idx, input)
return nil
})
}

// ForEachCtx is the same as ForEach except it also accepts a context
// that it uses to manages the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func ForEachCtx[T any](ctx context.Context, input []T, f func(context.Context, *T) error) error {
return Iterator[T]{}.ForEachCtx(ctx, input, f)
}

// ForEachCtx is the same as ForEach except it also accepts a context
// that it uses to manages the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func (iter Iterator[T]) ForEachCtx(ctx context.Context, input []T, f func(context.Context, *T) error) error {
return iter.ForEachIdxCtx(ctx, input, func(innerctx context.Context, _ int, input *T) error {
return f(innerctx, input)
})
}

// ForEachIdxCtx is the same as ForEachIdx except it also accepts a context
// that it uses to manages the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func ForEachIdxCtx[T any](ctx context.Context, input []T, f func(context.Context, int, *T) error) error {
return Iterator[T]{}.ForEachIdxCtx(ctx, input, f)
}

// ForEachIdxCtx is the same as ForEachIdx except it also accepts a context
// that it uses to manages the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func (iter Iterator[T]) ForEachIdxCtx(ctx context.Context, input []T, f func(context.Context, int, *T) error) error {
if iter.MaxGoroutines == 0 {
// iter is a value receiver and is hence safe to mutate
iter.MaxGoroutines = defaultMaxGoroutines()
}

numInput := len(input)
if iter.MaxGoroutines > numInput {
if iter.MaxGoroutines > numInput && numInput > 0 {
// No more concurrent tasks than the number of input items.
iter.MaxGoroutines = numInput
}

var idx atomic.Int64
// Create the task outside the loop to avoid extra closure allocations.
task := func() {
task := func(innerctx context.Context) error {
i := int(idx.Add(1) - 1)
for ; i < numInput; i = int(idx.Add(1) - 1) {
f(i, &input[i])
for ; i < numInput && innerctx.Err() == nil; i = int(idx.Add(1) - 1) {
if err := f(innerctx, i, &input[i]); err != nil {
return err
}
}
return innerctx.Err() // nil if the context was never cancelled
}

var wg conc.WaitGroup
runner := pool.New().
WithContext(ctx).
WithCancelOnError().
WithFirstError().
WithMaxGoroutines(iter.MaxGoroutines)
for i := 0; i < iter.MaxGoroutines; i++ {
wg.Go(task)
runner.Go(task)
}
wg.Wait()
return runner.Wait()
}
77 changes: 70 additions & 7 deletions iter/iter_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package iter_test

import (
"context"
"errors"
"fmt"
"strconv"
"sync/atomic"
Expand Down Expand Up @@ -72,16 +74,18 @@ func TestIterator(t *testing.T) {
})
}

func TestForEachIdx(t *testing.T) {
func TestForEachIdxCtx(t *testing.T) {
t.Parallel()

bgctx := context.Background()
t.Run("empty", func(t *testing.T) {
t.Parallel()
f := func() {
ints := []int{}
iter.ForEachIdx(ints, func(i int, val *int) {
err := iter.ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error {
panic("this should never be called")
})
require.NoError(t, err)
}
require.NotPanics(t, f)
})
Expand All @@ -90,33 +94,57 @@ func TestForEachIdx(t *testing.T) {
t.Parallel()
f := func() {
ints := []int{1}
iter.ForEachIdx(ints, func(i int, val *int) {
panic("super bad thing happened")
})
_ = iter.ForEachIdxCtx(bgctx, ints,
func(ctx context.Context, i int, val *int) error {
panic("super bad thing happened")
})
}
require.Panics(t, f)
})

t.Run("mutating inputs is fine", func(t *testing.T) {
t.Parallel()
ints := []int{1, 2, 3, 4, 5}
iter.ForEachIdx(ints, func(i int, val *int) {
err := iter.ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error {
*val += 1
return nil
})
require.Equal(t, []int{2, 3, 4, 5, 6}, ints)
require.NoError(t, err)
})

t.Run("huge inputs", func(t *testing.T) {
t.Parallel()
ints := make([]int, 10000)
iter.ForEachIdx(ints, func(i int, val *int) {
err := iter.ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error {
*val = i
return nil
})
expected := make([]int, 10000)
for i := 0; i < 10000; i++ {
expected[i] = i
}
require.Equal(t, expected, ints)
require.NoError(t, err)
})

err1 := errors.New("error1")
err2 := errors.New("error2")

t.Run("first error is propagated", func(t *testing.T) {
t.Parallel()
ints := []int{1, 2, 3, 4, 5}
err := iter.ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error {
if *val == 3 {
return err1
}
if *val == 4 {
return err2
}
return nil
})
require.ErrorIs(t, err, err1)
require.NotErrorIs(t, err, err2)
})
}

Expand Down Expand Up @@ -168,6 +196,41 @@ func TestForEach(t *testing.T) {
})
}

func TestForEachCtx(t *testing.T) {
t.Parallel()

bgctx := context.Background()
t.Run("mutating inputs is fine", func(t *testing.T) {
t.Parallel()
ints := []int{1, 2, 3, 4, 5}
err := iter.ForEachCtx(bgctx, ints, func(ctx context.Context, val *int) error {
*val += 1
return nil
})
require.Equal(t, []int{2, 3, 4, 5, 6}, ints)
require.NoError(t, err)
})

err1 := errors.New("error1")
err2 := errors.New("error2")

t.Run("first error is propagated", func(t *testing.T) {
t.Parallel()
ints := []int{1, 2, 3, 4, 5}
err := iter.ForEachCtx(bgctx, ints, func(ctx context.Context, val *int) error {
if *val == 3 {
return err1
}
if *val == 4 {
return err2
}
return nil
})
require.ErrorIs(t, err, err1)
require.NotErrorIs(t, err, err2)
})
}

func BenchmarkForEach(b *testing.B) {
for _, count := range []int{0, 1, 8, 100, 1000, 10000, 100000} {
b.Run(strconv.Itoa(count), func(b *testing.B) {
Expand Down
35 changes: 28 additions & 7 deletions iter/map.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package iter

import (
"context"
"errors"
"sync"
)
Expand All @@ -24,9 +25,8 @@ func Map[T, R any](input []T, f func(*T) R) []R {
//
// Map uses up to the configured Mapper's maximum number of goroutines.
func (m Mapper[T, R]) Map(input []T, f func(*T) R) []R {
res := make([]R, len(input))
Iterator[T](m).ForEachIdx(input, func(i int, t *T) {
res[i] = f(t)
res, _ := m.MapCtx(context.Background(), input, func(_ context.Context, t *T) (R, error) {
return f(t), nil
})
return res
}
Expand All @@ -46,18 +46,39 @@ func MapErr[T, R any](input []T, f func(*T) (R, error)) ([]R, error) {
// Map uses up to the configured Mapper's maximum number of goroutines.
func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) {
var (
res = make([]R, len(input))
errMux sync.Mutex
errs []error
)
Iterator[T](m).ForEachIdx(input, func(i int, t *T) {
var err error
res[i], err = f(t)
// MapErr handles its own errors by accumulating them as a multierror, ignoring the error from MapCtx which is only the first error
res, _ := m.MapCtx(context.Background(), input, func(ctx context.Context, t *T) (R, error) {
ires, err := f(t)
if err != nil {
errMux.Lock()
errs = append(errs, err)
errMux.Unlock()
}
return ires, nil
})
return res, errors.Join(errs...)
}

// MapCtx is the same as Map except it also accepts a context
// that it uses to manages the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func MapCtx[T, R any](ctx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) {
return Mapper[T, R]{}.MapCtx(ctx, input, f)
}

// MapCtx is the same as Map except it also accepts a context
// that it uses to manages the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func (m Mapper[T, R]) MapCtx(ctx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm interested in this as well 🙂 What do you think about using a builder for the Mapper? I would like a WithCancelOnError as well.

var (
res = make([]R, len(input))
)
return res, Iterator[T](m).ForEachIdxCtx(ctx, input, func(innerctx context.Context, i int, t *T) error {
var err error
res[i], err = f(innerctx, t)
return err
})
}
Loading