[RE-5863] Resolve potential race condition in log fields #185

Open · wants to merge 9 commits into base: master

Changes from 4 commits
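For context on the bug being fixed: before this PR the context carried a plain LogFields map, so concurrent AddToLogContext calls raced on it. A standalone, hypothetical reproduction of that failure mode (not code from this repository; run with `go run -race` to see the detector fire):

    package main

    import (
    	"context"
    	"fmt"
    	"sync"
    )

    type ctxKey string

    type LogFields map[string]interface{}

    func main() {
    	ctx := context.WithValue(context.Background(), ctxKey("fields"), make(LogFields))
    	fields := ctx.Value(ctxKey("fields")).(LogFields)

    	var wg sync.WaitGroup
    	for i := 0; i < 2; i++ {
    		wg.Add(1)
    		go func(i int) {
    			defer wg.Done()
    			// Unsynchronized writes to the shared map: the race detector
    			// flags this, and it can crash with "concurrent map writes".
    			fields[fmt.Sprintf("key%d", i)] = i
    		}(i)
    	}
    	wg.Wait()
    }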
40 changes: 33 additions & 7 deletions utils/log/loggers/fields.go
@@ -2,6 +2,7 @@ package loggers
 
 import (
 	"context"
+	"sync"
 )
 
 type logsContext string
@@ -12,6 +13,10 @@ var (
 
 //LogFields contains all fields that have to be added to logs
 type LogFields map[string]interface{}
+type ProtectedLogFields struct {
+	Content LogFields
+	mtx     sync.RWMutex

[Review comment on the mtx field, marked as resolved by achichen] Contributor: As @ankurs suggested, sync.Map is indeed a better option. According to https://pkg.go.dev/sync#Map, it optimizes for our use case: "when the entry for a given key is only ever written once but read many times, as in caches that only grow." (A sketch of that alternative follows this hunk.)

+}
 
 // Add or modify log fields
 func (o LogFields) Add(key string, value interface{}) {
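Following up on the resolved sync.Map comment above: a minimal sketch of what a sync.Map-based holder could look like. This is hypothetical code, not part of the PR; the logsContext/contextKey names merely mirror the package's, and a non-nil ctx is assumed on the write path.

    package loggers

    import (
    	"context"
    	"sync"
    )

    type logsContext string

    var contextKey = logsContext("loggercontext") // hypothetical key, mirrors the package's

    // AddToLogContext stores the field in a *sync.Map carried by the context.
    // sync.Map suits write-once/read-many keys, matching log fields that are
    // set a few times per request and read on every log call.
    func AddToLogContext(ctx context.Context, key string, value interface{}) context.Context {
    	m, ok := ctx.Value(contextKey).(*sync.Map)
    	if !ok {
    		m = &sync.Map{}
    		ctx = context.WithValue(ctx, contextKey, m)
    	}
    	m.Store(key, value)
    	return ctx
    }

    // FromContext snapshots the fields into a plain map for the log encoder.
    func FromContext(ctx context.Context) map[string]interface{} {
    	if ctx == nil {
    		return nil
    	}
    	m, ok := ctx.Value(contextKey).(*sync.Map)
    	if !ok {
    		return nil
    	}
    	out := make(map[string]interface{})
    	m.Range(func(k, v interface{}) bool {
    		out[k.(string)] = v
    		return true
    	})
    	return out
    }

The upside over the RWMutex version is that disjoint writers do not contend on a single lock; the snapshot in FromContext stays, since callers still need a stable map to hand to the logger.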
@@ -28,14 +33,17 @@ func (o LogFields) Del(key string) {
 //AddToLogContext adds log fields to context.
 // Any info added here will be added to all logs using this context
 func AddToLogContext(ctx context.Context, key string, value interface{}) context.Context {
-	data := FromContext(ctx)
+	data := fromContext(ctx)
 	//Initialize if key doesn't exist
 	if data == nil {
-		ctx = context.WithValue(ctx, contextKey, make(LogFields))
-		data = FromContext(ctx)
+		ctx = context.WithValue(ctx, contextKey, &ProtectedLogFields{Content: make(LogFields)})
+		data = fromContext(ctx)
 	}
 	m := ctx.Value(contextKey)
-	if data, ok := m.(LogFields); ok {
-		data.Add(key, value)
+	if data, ok := m.(*ProtectedLogFields); ok {
+		data.mtx.Lock()
+		defer data.mtx.Unlock()
+		data.Content.Add(key, value)
 	}
 	return ctx
 }
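A usage note implied by the signature: the first AddToLogContext call wraps ctx with a new *ProtectedLogFields, so callers must keep the returned context; later calls find the holder and mutate it in place. A minimal sketch, using the s import alias the PR's tests use:

    ctx := context.Background()
    ctx = s.AddToLogContext(ctx, "requestID", "abc-123") // first call: returns a wrapped context
    ctx = s.AddToLogContext(ctx, "userID", 42)           // later calls add to the same shared holder
    fmt.Println(s.FromContext(ctx))                      // map[requestID:abc-123 userID:42]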
@@ -46,8 +54,26 @@ func FromContext(ctx context.Context) LogFields {
 		return nil
 	}
 	if h := ctx.Value(contextKey); h != nil {
-		if logData, ok := h.(LogFields); ok {
-			return logData
+		if plf, ok := h.(*ProtectedLogFields); ok {
+			plf.mtx.RLock()
+			defer plf.mtx.RUnlock()
+			content := make(LogFields)
+			for k, v := range plf.Content {
+				content[k] = v
+			}
+			return content
 		}
 	}
 	return nil
 }

+func fromContext(ctx context.Context) *ProtectedLogFields {
+	if ctx == nil {
+		return nil
+	}
+	if h := ctx.Value(contextKey); h != nil {
+		if plf, ok := h.(*ProtectedLogFields); ok {
+			return plf
+		}
+	}
+	return nil
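Two properties of the new accessors are worth spelling out. FromContext builds a defensive copy under a read lock, so callers get a stable snapshot that later writers cannot race with; the unexported fromContext hands the write path the shared holder directly, skipping that copy. A small illustration of the snapshot behavior (same s alias as the tests):

    ctx := s.AddToLogContext(context.Background(), "k", "v")
    snapshot := s.FromContext(ctx)
    snapshot["k"] = "overwritten"        // mutates only the local copy
    fmt.Println(s.FromContext(ctx)["k"]) // still prints: v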
27 changes: 27 additions & 0 deletions utils/log/loggers/test/benchmark_test.go
@@ -0,0 +1,27 @@
//go test -v -bench=. -run=none .
package loggers_test

import (
"context"
"fmt"
"testing"

s "github.com/carousell/Orion/utils/log/loggers"
)

func BenchmarkFromContext(b *testing.B) {
	ctx := context.Background()
	for i := 0; i < 10000; i++ {
		// keep the returned context so the 10000 fields actually accumulate
		ctx = s.AddToLogContext(ctx, fmt.Sprintf("key%d", i), "good value")
	}
	// exclude the setup loop from the measured time
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		s.FromContext(ctx)
	}
}

func BenchmarkFromAddToLogContext(b *testing.B) {
[Review comment] Contributor: function name should be BenchmarkAddToLogContext
ctx := context.Background()
for i := 0; i < b.N; i++ {
		ctx = s.AddToLogContext(ctx, fmt.Sprintf("key%d", i), "good value") // keep the returned context across iterations
}
}
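Worth keeping in mind when reading these numbers: FromContext snapshots every stored field under an RLock, so BenchmarkFromContext's per-iteration cost scales with the 10000 fields written during setup; that copy is the main overhead the new locking scheme introduces.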
19 changes: 19 additions & 0 deletions utils/log/loggers/test/example_test.go
@@ -0,0 +1,19 @@
package loggers_test

import (
"context"
"fmt"

s "github.com/carousell/Orion/utils/log/loggers"
)

func ExampleFromContext() {
ctx := context.Background()
ctx = s.AddToLogContext(ctx, "indispensable", "amazing data")
ctx = s.AddToLogContext(ctx, "preciousData", "valuable key")
lf := s.FromContext(ctx)
fmt.Println(lf)

// Output:
// map[indispensable:amazing data preciousData:valuable key]
}
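The // Output comparison is deterministic because fmt prints maps with their keys in sorted order (Go 1.12 and later), so indispensable sorts before preciousData.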
103 changes: 103 additions & 0 deletions utils/log/loggers/test/parallelism_test.go
@@ -0,0 +1,103 @@
//go test -race
package loggers_test

import (
"context"
"fmt"
"math/rand"
"sync"
"testing"
"time"

s "github.com/carousell/Orion/utils/log/loggers"
"github.com/stretchr/testify/assert"
)

const readWorkerCount = 50
const writeWorkerCount = 50

func readWorker(idx int, ctx context.Context) {
s.FromContext(ctx)
// simulate reading task
time.Sleep(time.Millisecond * 250)
}

func writeWorker(idx int, ctx context.Context) context.Context {
key := fmt.Sprintf("key%d", idx)
val := fmt.Sprintf("val%d", rand.Intn(10000))
ctx = s.AddToLogContext(ctx, key, val)
time.Sleep(time.Millisecond * 250)
return ctx
}

func TestParallelRead(t *testing.T) {
// LogContext init, non-parallel
ctx := context.Background()
ctx = s.AddToLogContext(ctx, "k1", "v1")
ctx = s.AddToLogContext(ctx, "k2", "v2")

var wg sync.WaitGroup
for i := 1; i <= readWorkerCount; i++ {
wg.Add(1)
go func(j int) {
defer wg.Done()
readWorker(j, ctx)
}(i)
}
wg.Wait()
}

func TestParallelWrite(t *testing.T) {
ctx := context.Background()
ctx = s.AddToLogContext(ctx, "test-key", "test-value")

var wg sync.WaitGroup
for i := 1; i <= writeWorkerCount; i++ {
wg.Add(1)
go func(j int) {
defer wg.Done()
writeWorker(j, ctx)
}(i)
}
wg.Wait()

lf := s.FromContext(ctx)

assert.Contains(t, lf, "test-key")
for i := 1; i <= writeWorkerCount; i++ {
key := fmt.Sprintf("key%d", i)
assert.Contains(t, lf, key)
}
}

func TestParallelReadAndWrite(t *testing.T) {
ctx := context.Background()
ctx = s.AddToLogContext(ctx, "test-key", "test-value")

var wgRead sync.WaitGroup
for i := 1; i <= readWorkerCount; i++ {
wgRead.Add(1)
go func(j int) {
defer wgRead.Done()
readWorker(j, ctx)
}(i)
}
var wgWrite sync.WaitGroup
for i := 1; i <= writeWorkerCount; i++ {
wgWrite.Add(1)
go func(j int) {
defer wgWrite.Done()
writeWorker(j, ctx)
}(i)
}
wgRead.Wait()
wgWrite.Wait()

[Review comment] it seems we could use a single wait group here; how about using the same wg? (A merged-WaitGroup sketch follows this test.)
[Reply] achichen (Author): okay! changed.

lf := s.FromContext(ctx)

assert.Contains(t, lf, "test-key")
for i := 1; i <= writeWorkerCount; i++ {
key := fmt.Sprintf("key%d", i)
assert.Contains(t, lf, key)
}
}
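Per the conversation above, a sketch of how the goroutine bookkeeping in TestParallelReadAndWrite might look after the author's change; hypothetical, since the merged commit is outside this 4-commit view:

    // One WaitGroup tracks both reader and writer goroutines.
    var wg sync.WaitGroup
    for i := 1; i <= readWorkerCount; i++ {
    	wg.Add(1)
    	go func(j int) {
    		defer wg.Done()
    		readWorker(j, ctx)
    	}(i)
    }
    for i := 1; i <= writeWorkerCount; i++ {
    	wg.Add(1)
    	go func(j int) {
    		defer wg.Done()
    		writeWorker(j, ctx)
    	}(i)
    }
    wg.Wait()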