diff --git a/encode_test.go b/encode_test.go index 1165f6a6..0c26d84a 100644 --- a/encode_test.go +++ b/encode_test.go @@ -1999,30 +1999,33 @@ type marshalContextKey struct{} type marshalContextStructType struct{} func (t *marshalContextStructType) MarshalJSON(ctx context.Context) ([]byte, error) { + if ctx == nil { + return []byte(`"no context"`), nil + } v := ctx.Value(marshalContextKey{}) + if v == nil { + return []byte(`"no value in context"`), nil + } s, ok := v.(string) if !ok { - return nil, fmt.Errorf("failed to propagate parent context.Context") + return []byte(`"unexpected value in context"`), nil } - if s != "hello" { - return nil, fmt.Errorf("failed to propagate parent context.Context") - } - return []byte(`"success"`), nil + return []byte(`"` + s + `"`), nil } func TestEncodeContextOption(t *testing.T) { t.Run("MarshalContext", func(t *testing.T) { - ctx := context.WithValue(context.Background(), marshalContextKey{}, "hello") + ctx := context.WithValue(context.Background(), marshalContextKey{}, "success") b, err := json.MarshalContext(ctx, &marshalContextStructType{}) if err != nil { t.Fatal(err) } if string(b) != `"success"` { - t.Fatal("failed to encode with MarshalerContext") + t.Fatal("failed to encode with MarshalContext") } }) t.Run("EncodeContext", func(t *testing.T) { - ctx := context.WithValue(context.Background(), marshalContextKey{}, "hello") + ctx := context.WithValue(context.Background(), marshalContextKey{}, "success") buf := bytes.NewBuffer([]byte{}) if err := json.NewEncoder(buf).EncodeContext(ctx, &marshalContextStructType{}); err != nil { t.Fatal(err) @@ -2031,6 +2034,25 @@ func TestEncodeContextOption(t *testing.T) { t.Fatal("failed to encode with EncodeContext") } }) + t.Run("Marshal after MarshalContext", func(t *testing.T) { + // Regression test for https://github.com/goccy/go-json/issues/499 + ctx := context.WithValue(context.Background(), marshalContextKey{}, "success") + b, err := json.MarshalContext(ctx, &marshalContextStructType{}) + if err != nil { + t.Fatal(err) + } + if string(b) != `"success"` { + t.Fatal("failed to encode with MarshalContext") + } + + b, err = json.Marshal(&marshalContextStructType{}) + if err != nil { + t.Fatal(err) + } + if string(b) != `"no context"` { + t.Fatal("failed to encode with Marshal") + } + }) } func TestInterfaceWithPointer(t *testing.T) { diff --git a/internal/decoder/context.go b/internal/decoder/context.go index cb2ffdaf..ff2b3696 100644 --- a/internal/decoder/context.go +++ b/internal/decoder/context.go @@ -27,6 +27,7 @@ func TakeRuntimeContext() *RuntimeContext { } func ReleaseRuntimeContext(ctx *RuntimeContext) { + ctx.Option.Context = nil runtimeContextPool.Put(ctx) } diff --git a/internal/encoder/context.go b/internal/encoder/context.go index 3833d0c8..c6795ea6 100644 --- a/internal/encoder/context.go +++ b/internal/encoder/context.go @@ -1,7 +1,6 @@ package encoder import ( - "context" "sync" "unsafe" @@ -69,7 +68,6 @@ var ( ) type RuntimeContext struct { - Context context.Context Buf []byte MarshalBuf []byte Ptrs []uintptr @@ -101,5 +99,6 @@ func TakeRuntimeContext() *RuntimeContext { } func ReleaseRuntimeContext(ctx *RuntimeContext) { + ctx.Option.Context = nil runtimeContextPool.Put(ctx) }