From 0bfcec1d8d92cf6484c5696ba7c765b18c75a14f Mon Sep 17 00:00:00 2001 From: pavel Date: Fri, 26 Apr 2024 18:26:55 +0300 Subject: [PATCH] refine ForEachIdxErr --- iter/iter.go | 14 ++++++++------ iter/iter_test.go | 30 +++++++++++++++--------------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/iter/iter.go b/iter/iter.go index 6851eac..5f3358e 100644 --- a/iter/iter.go +++ b/iter/iter.go @@ -1,12 +1,12 @@ package iter import ( + "errors" "runtime" "sync" "sync/atomic" "github.com/sourcegraph/conc" - "github.com/sourcegraph/conc/internal/multierror" ) // defaultMaxGoroutines returns the default maximum number of @@ -127,7 +127,7 @@ func (iter Iterator[T]) ForEachIdxErr(input []T, f func(int, *T) error) error { iter.MaxGoroutines = numInput } - var errs error + var errs []error var errsMu sync.Mutex var idx atomic.Int64 var failed atomic.Bool @@ -137,9 +137,11 @@ func (iter Iterator[T]) ForEachIdxErr(input []T, f func(int, *T) error) error { i := int(idx.Add(1) - 1) for ; i < numInput && !failed.Load(); i = int(idx.Add(1) - 1) { if err := f(i, &input[i]); err != nil { - errsMu.Lock() - errs = multierror.Join(errs, err) - errsMu.Unlock() + if alreadyFailedFast := failed.Swap(iter.FailFast); !alreadyFailedFast { + errsMu.Lock() + errs = append(errs, err) + errsMu.Unlock() + } failed.Store(iter.FailFast) } @@ -152,5 +154,5 @@ func (iter Iterator[T]) ForEachIdxErr(input []T, f func(int, *T) error) error { } wg.Wait() - return errs + return errors.Join(errs...) } diff --git a/iter/iter_test.go b/iter/iter_test.go index de360b9..539803b 100644 --- a/iter/iter_test.go +++ b/iter/iter_test.go @@ -174,13 +174,13 @@ func TestForIterator_EachIdxErr(t *testing.T) { t.Parallel() t.Run("failFast=false", func(t *testing.T) { - it := Iterator[int]{MaxGoroutines: 999} + it := iter.Iterator[int]{MaxGoroutines: 999} forEach := noIndex(it.ForEachIdxErr) testForEachErr(t, false, forEach) }) t.Run("failFast=true", func(t *testing.T) { - it := Iterator[int]{MaxGoroutines: 999} + it := iter.Iterator[int]{MaxGoroutines: 999} forEach := noIndex(it.ForEachIdxErr) testForEachErr(t, true, forEach) }) @@ -190,7 +190,7 @@ func TestForIterator_EachIdxErr(t *testing.T) { input := []int{1, 2, 3, 4, 5} errTest := errors.New("test error") - iterator := Iterator[int]{MaxGoroutines: 1, FailFast: true} + iterator := iter.Iterator[int]{MaxGoroutines: 1, FailFast: true} var mu sync.Mutex var results []int @@ -211,7 +211,7 @@ func TestForIterator_EachIdxErr(t *testing.T) { t.Run("safe for reuse", func(t *testing.T) { t.Parallel() - iterator := Iterator[int]{MaxGoroutines: 999} + iterator := iter.Iterator[int]{MaxGoroutines: 999} // iter.Concurrency > numInput case that updates iter.Concurrency _ = iterator.ForEachIdxErr([]int{1, 2, 3}, func(i int, t *int) error { @@ -224,12 +224,12 @@ func TestForIterator_EachIdxErr(t *testing.T) { t.Run("allows more than defaultMaxGoroutines() concurrent tasks", func(t *testing.T) { t.Parallel() - wantConcurrency := 2 * defaultMaxGoroutines() + wantConcurrency := 2 * iter.DefaultMaxGoroutines() maxConcurrencyHit := make(chan struct{}) tasks := make([]int, wantConcurrency) - iterator := Iterator[int]{MaxGoroutines: wantConcurrency} + iterator := iter.Iterator[int]{MaxGoroutines: wantConcurrency} var concurrentTasks atomic.Int64 _ = iterator.ForEachIdxErr(tasks, func(_ int, t *int) error { @@ -257,19 +257,19 @@ func TestForIterator_EachErr(t *testing.T) { t.Parallel() t.Run("failFast=false", func(t *testing.T) { - it := Iterator[int]{MaxGoroutines: 999} + it := iter.Iterator[int]{MaxGoroutines: 999} testForEachErr(t, false, it.ForEachErr) }) t.Run("failFast=true", func(t *testing.T) { - it := Iterator[int]{MaxGoroutines: 999} + it := iter.Iterator[int]{MaxGoroutines: 999} testForEachErr(t, true, it.ForEachErr) }) t.Run("safe for reuse", func(t *testing.T) { t.Parallel() - iterator := Iterator[int]{MaxGoroutines: 999} + iterator := iter.Iterator[int]{MaxGoroutines: 999} // iter.Concurrency > numInput case that updates iter.Concurrency _ = iterator.ForEachErr([]int{1, 2, 3}, func(t *int) error { @@ -284,7 +284,7 @@ func TestForIterator_EachErr(t *testing.T) { input := []int{1, 2, 3, 4, 5} errTest := errors.New("test error") - iterator := Iterator[int]{MaxGoroutines: 1, FailFast: true} + iterator := iter.Iterator[int]{MaxGoroutines: 1, FailFast: true} var mu sync.Mutex var results []int @@ -305,12 +305,12 @@ func TestForIterator_EachErr(t *testing.T) { t.Run("allows more than defaultMaxGoroutines() concurrent tasks", func(t *testing.T) { t.Parallel() - wantConcurrency := 2 * defaultMaxGoroutines() + wantConcurrency := 2 * iter.DefaultMaxGoroutines() maxConcurrencyHit := make(chan struct{}) tasks := make([]int, wantConcurrency) - iterator := Iterator[int]{MaxGoroutines: wantConcurrency} + iterator := iter.Iterator[int]{MaxGoroutines: wantConcurrency} var concurrentTasks atomic.Int64 _ = iterator.ForEachErr(tasks, func(t *int) error { @@ -338,7 +338,7 @@ func TestForEachIdxErr(t *testing.T) { t.Parallel() t.Run("standart", func(t *testing.T) { - forEach := noIndex(ForEachIdxErr[int]) + forEach := noIndex(iter.ForEachIdxErr[int]) testForEachErr(t, false, forEach) }) @@ -347,7 +347,7 @@ func TestForEachIdxErr(t *testing.T) { got := []int{} gotMu := sync.Mutex{} - err := ForEachIdxErr(ints, func(i int, _ *int) error { + err := iter.ForEachIdxErr(ints, func(i int, _ *int) error { gotMu.Lock() defer gotMu.Unlock() got = append(got, i) @@ -362,7 +362,7 @@ func TestForEachIdxErr(t *testing.T) { func TestForEachErr(t *testing.T) { t.Parallel() - testForEachErr(t, false, ForEachErr[int]) + testForEachErr(t, false, iter.ForEachErr[int]) } // noIndex converts a ForEachIdxErr function (or method) into a ForEachErr function (or method).