diff --git a/common/json/badjson/json.go b/common/json/badjson/json.go index 04dba1eb..35f33a8b 100644 --- a/common/json/badjson/json.go +++ b/common/json/badjson/json.go @@ -2,13 +2,14 @@ package badjson import ( "bytes" + "context" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" ) -func Decode(content []byte) (any, error) { - decoder := json.NewDecoder(bytes.NewReader(content)) +func Decode(ctx context.Context, content []byte) (any, error) { + decoder := json.NewDecoderContext(ctx, bytes.NewReader(content)) return decodeJSON(decoder) } diff --git a/common/json/badjson/merge.go b/common/json/badjson/merge.go index ee7193e7..ac1f12fe 100644 --- a/common/json/badjson/merge.go +++ b/common/json/badjson/merge.go @@ -1,6 +1,7 @@ package badjson import ( + "context" "os" "reflect" @@ -9,75 +10,75 @@ import ( "github.com/sagernet/sing/common/json" ) -func Omitempty[T any](value T) (T, error) { +func Omitempty[T any](ctx context.Context, value T) (T, error) { objectContent, err := json.Marshal(value) if err != nil { return common.DefaultValue[T](), E.Cause(err, "marshal object") } - rawNewObject, err := Decode(objectContent) + rawNewObject, err := Decode(ctx, objectContent) if err != nil { return common.DefaultValue[T](), err } - newObjectContent, err := json.Marshal(rawNewObject) + newObjectContent, err := json.MarshalContext(ctx, rawNewObject) if err != nil { return common.DefaultValue[T](), E.Cause(err, "marshal new object") } var newObject T - err = json.Unmarshal(newObjectContent, &newObject) + err = json.UnmarshalContext(ctx, newObjectContent, &newObject) if err != nil { return common.DefaultValue[T](), E.Cause(err, "unmarshal new object") } return newObject, nil } -func Merge[T any](source T, destination T, disableAppend bool) (T, error) { - rawSource, err := json.Marshal(source) +func Merge[T any](ctx context.Context, source T, destination T, disableAppend bool) (T, error) { + rawSource, err := json.MarshalContext(ctx, source) if err != nil { return common.DefaultValue[T](), E.Cause(err, "marshal source") } - rawDestination, err := json.Marshal(destination) + rawDestination, err := json.MarshalContext(ctx, destination) if err != nil { return common.DefaultValue[T](), E.Cause(err, "marshal destination") } - return MergeFrom[T](rawSource, rawDestination, disableAppend) + return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend) } -func MergeFromSource[T any](rawSource json.RawMessage, destination T, disableAppend bool) (T, error) { +func MergeFromSource[T any](ctx context.Context, rawSource json.RawMessage, destination T, disableAppend bool) (T, error) { if rawSource == nil { return destination, nil } - rawDestination, err := json.Marshal(destination) + rawDestination, err := json.MarshalContext(ctx, destination) if err != nil { return common.DefaultValue[T](), E.Cause(err, "marshal destination") } - return MergeFrom[T](rawSource, rawDestination, disableAppend) + return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend) } -func MergeFromDestination[T any](source T, rawDestination json.RawMessage, disableAppend bool) (T, error) { +func MergeFromDestination[T any](ctx context.Context, source T, rawDestination json.RawMessage, disableAppend bool) (T, error) { if rawDestination == nil { return source, nil } - rawSource, err := json.Marshal(source) + rawSource, err := json.MarshalContext(ctx, source) if err != nil { return common.DefaultValue[T](), E.Cause(err, "marshal source") } - return MergeFrom[T](rawSource, rawDestination, disableAppend) + return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend) } -func MergeFrom[T any](rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) { - rawMerged, err := MergeJSON(rawSource, rawDestination, disableAppend) +func MergeFrom[T any](ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) { + rawMerged, err := MergeJSON(ctx, rawSource, rawDestination, disableAppend) if err != nil { return common.DefaultValue[T](), E.Cause(err, "merge options") } var merged T - err = json.Unmarshal(rawMerged, &merged) + err = json.UnmarshalContext(ctx, rawMerged, &merged) if err != nil { return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options") } return merged, nil } -func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) { +func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) { if rawSource == nil && rawDestination == nil { return nil, os.ErrInvalid } else if rawSource == nil { @@ -85,16 +86,16 @@ func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disabl } else if rawDestination == nil { return rawSource, nil } - source, err := Decode(rawSource) + source, err := Decode(ctx, rawSource) if err != nil { return nil, E.Cause(err, "decode source") } - destination, err := Decode(rawDestination) + destination, err := Decode(ctx, rawDestination) if err != nil { return nil, E.Cause(err, "decode destination") } if source == nil { - return json.Marshal(destination) + return json.MarshalContext(ctx, destination) } else if destination == nil { return json.Marshal(source) } @@ -102,7 +103,7 @@ func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disabl if err != nil { return nil, err } - return json.Marshal(merged) + return json.MarshalContext(ctx, merged) } func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, error) { diff --git a/common/json/badjson/merge_objects.go b/common/json/badjson/merge_objects.go index 37a5dafa..fa6c2d42 100644 --- a/common/json/badjson/merge_objects.go +++ b/common/json/badjson/merge_objects.go @@ -1,32 +1,42 @@ package badjson import ( + "context" + E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" ) func MarshallObjects(objects ...any) ([]byte, error) { + return MarshallObjectsContext(context.Background(), objects...) +} + +func MarshallObjectsContext(ctx context.Context, objects ...any) ([]byte, error) { if len(objects) == 1 { return json.Marshal(objects[0]) } var content JSONObject for _, object := range objects { - objectMap, err := newJSONObject(object) + objectMap, err := newJSONObject(ctx, object) if err != nil { return nil, err } content.PutAll(objectMap) } - return content.MarshalJSON() + return content.MarshalJSONContext(ctx) } func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error { - parentContent, err := newJSONObject(parentObject) + return UnmarshallExcludedContext(context.Background(), inputContent, parentObject, object) +} + +func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error { + parentContent, err := newJSONObject(ctx, parentObject) if err != nil { return err } var content JSONObject - err = content.UnmarshalJSON(inputContent) + err = content.UnmarshalJSONContext(ctx, inputContent) if err != nil { return err } @@ -39,20 +49,20 @@ func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error } return E.New("unexpected key: ", content.Keys()[0]) } - inputContent, err = content.MarshalJSON() + inputContent, err = content.MarshalJSONContext(ctx) if err != nil { return err } - return json.UnmarshalDisallowUnknownFields(inputContent, object) + return json.UnmarshalContextDisallowUnknownFields(ctx, inputContent, object) } -func newJSONObject(object any) (*JSONObject, error) { - inputContent, err := json.Marshal(object) +func newJSONObject(ctx context.Context, object any) (*JSONObject, error) { + inputContent, err := json.MarshalContext(ctx, object) if err != nil { return nil, err } var content JSONObject - err = content.UnmarshalJSON(inputContent) + err = content.UnmarshalJSONContext(ctx, inputContent) if err != nil { return nil, err } diff --git a/common/json/badjson/object.go b/common/json/badjson/object.go index 61d5862d..3f5dab41 100644 --- a/common/json/badjson/object.go +++ b/common/json/badjson/object.go @@ -2,6 +2,7 @@ package badjson import ( "bytes" + "context" "strings" "github.com/sagernet/sing/common" @@ -28,6 +29,10 @@ func (m *JSONObject) IsEmpty() bool { } func (m *JSONObject) MarshalJSON() ([]byte, error) { + return m.MarshalJSONContext(context.Background()) +} + +func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) { buffer := new(bytes.Buffer) buffer.WriteString("{") items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool { @@ -38,13 +43,13 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) { }) iLen := len(items) for i, entry := range items { - keyContent, err := json.Marshal(entry.Key) + keyContent, err := json.MarshalContext(ctx, entry.Key) if err != nil { return nil, err } buffer.WriteString(strings.TrimSpace(string(keyContent))) buffer.WriteString(": ") - valueContent, err := json.Marshal(entry.Value) + valueContent, err := json.MarshalContext(ctx, entry.Value) if err != nil { return nil, err } @@ -58,7 +63,11 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) { } func (m *JSONObject) UnmarshalJSON(content []byte) error { - decoder := json.NewDecoder(bytes.NewReader(content)) + return m.UnmarshalJSONContext(context.Background(), content) +} + +func (m *JSONObject) UnmarshalJSONContext(ctx context.Context, content []byte) error { + decoder := json.NewDecoderContext(ctx, bytes.NewReader(content)) m.Clear() objectStart, err := decoder.Token() if err != nil { diff --git a/common/json/badjson/typed.go b/common/json/badjson/typed.go index 66f41a6e..aef85c9b 100644 --- a/common/json/badjson/typed.go +++ b/common/json/badjson/typed.go @@ -2,6 +2,7 @@ package badjson import ( "bytes" + "context" "strings" E "github.com/sagernet/sing/common/exceptions" @@ -14,18 +15,22 @@ type TypedMap[K comparable, V any] struct { } func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) { + return m.MarshalJSONContext(context.Background()) +} + +func (m TypedMap[K, V]) MarshalJSONContext(ctx context.Context) ([]byte, error) { buffer := new(bytes.Buffer) buffer.WriteString("{") items := m.Entries() iLen := len(items) for i, entry := range items { - keyContent, err := json.Marshal(entry.Key) + keyContent, err := json.MarshalContext(ctx, entry.Key) if err != nil { return nil, err } buffer.WriteString(strings.TrimSpace(string(keyContent))) buffer.WriteString(": ") - valueContent, err := json.Marshal(entry.Value) + valueContent, err := json.MarshalContext(ctx, entry.Value) if err != nil { return nil, err } @@ -39,7 +44,11 @@ func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) { } func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error { - decoder := json.NewDecoder(bytes.NewReader(content)) + return m.UnmarshalJSONContext(context.Background(), content) +} + +func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byte) error { + decoder := json.NewDecoderContext(ctx, bytes.NewReader(content)) m.Clear() objectStart, err := decoder.Token() if err != nil { @@ -47,7 +56,7 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error { } else if objectStart != json.Delim('{') { return E.New("expected json object start, but starts with ", objectStart) } - err = m.decodeJSON(decoder) + err = m.decodeJSON(ctx, decoder) if err != nil { return E.Cause(err, "decode json object content") } @@ -60,18 +69,18 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error { return nil } -func (m *TypedMap[K, V]) decodeJSON(decoder *json.Decoder) error { +func (m *TypedMap[K, V]) decodeJSON(ctx context.Context, decoder *json.Decoder) error { for decoder.More() { keyToken, err := decoder.Token() if err != nil { return err } - keyContent, err := json.Marshal(keyToken) + keyContent, err := json.MarshalContext(ctx, keyToken) if err != nil { return err } var entryKey K - err = json.Unmarshal(keyContent, &entryKey) + err = json.UnmarshalContext(ctx, keyContent, &entryKey) if err != nil { return err } diff --git a/common/json/context_ext.go b/common/json/context_ext.go new file mode 100644 index 00000000..aec149a2 --- /dev/null +++ b/common/json/context_ext.go @@ -0,0 +1,23 @@ +package json + +import ( + "context" + + "github.com/sagernet/sing/common/json/internal/contextjson" +) + +var ( + MarshalContext = json.MarshalContext + UnmarshalContext = json.UnmarshalContext + NewEncoderContext = json.NewEncoderContext + NewDecoderContext = json.NewDecoderContext + UnmarshalContextDisallowUnknownFields = json.UnmarshalContextDisallowUnknownFields +) + +type ContextMarshaler interface { + MarshalJSONContext(ctx context.Context) ([]byte, error) +} + +type ContextUnmarshaler interface { + UnmarshalJSONContext(ctx context.Context, content []byte) error +} diff --git a/common/json/internal/contextjson/context.go b/common/json/internal/contextjson/context.go new file mode 100644 index 00000000..ded69d7d --- /dev/null +++ b/common/json/internal/contextjson/context.go @@ -0,0 +1,11 @@ +package json + +import "context" + +type ContextMarshaler interface { + MarshalJSONContext(ctx context.Context) ([]byte, error) +} + +type ContextUnmarshaler interface { + UnmarshalJSONContext(ctx context.Context, content []byte) error +} diff --git a/common/json/internal/contextjson/context_test.go b/common/json/internal/contextjson/context_test.go new file mode 100644 index 00000000..cffecbb0 --- /dev/null +++ b/common/json/internal/contextjson/context_test.go @@ -0,0 +1,43 @@ +package json_test + +import ( + "context" + "testing" + + "github.com/sagernet/sing/common/json/internal/contextjson" + + "github.com/stretchr/testify/require" +) + +type myStruct struct { + value string +} + +func (m *myStruct) MarshalJSONContext(ctx context.Context) ([]byte, error) { + return json.Marshal(ctx.Value("key").(string)) +} + +func (m *myStruct) UnmarshalJSONContext(ctx context.Context, content []byte) error { + m.value = ctx.Value("key").(string) + return nil +} + +//nolint:staticcheck +func TestMarshalContext(t *testing.T) { + t.Parallel() + ctx := context.WithValue(context.Background(), "key", "value") + var s myStruct + b, err := json.MarshalContext(ctx, &s) + require.NoError(t, err) + require.Equal(t, []byte(`"value"`), b) +} + +//nolint:staticcheck +func TestUnmarshalContext(t *testing.T) { + t.Parallel() + ctx := context.WithValue(context.Background(), "key", "value") + var s myStruct + err := json.UnmarshalContext(ctx, []byte(`{}`), &s) + require.NoError(t, err) + require.Equal(t, "value", s.value) +} diff --git a/common/json/internal/contextjson/decode.go b/common/json/internal/contextjson/decode.go index 8457171e..20c7ac68 100644 --- a/common/json/internal/contextjson/decode.go +++ b/common/json/internal/contextjson/decode.go @@ -8,6 +8,7 @@ package json import ( + "context" "encoding" "encoding/base64" "fmt" @@ -95,10 +96,15 @@ import ( // Instead, they are replaced by the Unicode replacement // character U+FFFD. func Unmarshal(data []byte, v any) error { + return UnmarshalContext(context.Background(), data, v) +} + +func UnmarshalContext(ctx context.Context, data []byte, v any) error { // Check for well-formedness. // Avoids filling out half a data structure // before discovering a JSON syntax error. var d decodeState + d.ctx = ctx err := checkValid(data, &d.scan) if err != nil { return err @@ -209,6 +215,7 @@ type errorContext struct { // decodeState represents the state while decoding a JSON value. type decodeState struct { + ctx context.Context data []byte off int // next read offset in data opcode int // last read result @@ -428,7 +435,7 @@ func (d *decodeState) valueQuoted() any { // If it encounters an Unmarshaler, indirect stops and returns that. // If decodingNull is true, indirect stops at the first settable pointer so it // can be set to nil. -func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { +func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshaler, encoding.TextUnmarshaler, reflect.Value) { // Issue #24153 indicates that it is generally not a guaranteed property // that you may round-trip a reflect.Value by calling Value.Addr().Elem() // and expect the value to still be settable for values derived from @@ -482,11 +489,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm } if v.Type().NumMethod() > 0 && v.CanInterface() { if u, ok := v.Interface().(Unmarshaler); ok { - return u, nil, reflect.Value{} + return u, nil, nil, reflect.Value{} + } + if cu, ok := v.Interface().(ContextUnmarshaler); ok { + return nil, cu, nil, reflect.Value{} } if !decodingNull { if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { - return nil, u, reflect.Value{} + return nil, nil, u, reflect.Value{} } } } @@ -498,14 +508,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm v = v.Elem() } } - return nil, nil, v + return nil, nil, nil, v } // array consumes an array from d.data[d.off-1:], decoding into v. // The first byte of the array ('[') has been read already. func (d *decodeState) array(v reflect.Value) error { // Check for unmarshaler. - u, ut, pv := indirect(v, false) + u, cu, ut, pv := indirect(v, false) if u != nil { start := d.readIndex() d.skip() @@ -515,6 +525,15 @@ func (d *decodeState) array(v reflect.Value) error { } return nil } + if cu != nil { + start := d.readIndex() + d.skip() + err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off]) + if err != nil { + d.saveError(err) + } + return nil + } if ut != nil { d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)}) d.skip() @@ -612,7 +631,7 @@ var ( // The first byte ('{') of the object has been read already. func (d *decodeState) object(v reflect.Value) error { // Check for unmarshaler. - u, ut, pv := indirect(v, false) + u, cu, ut, pv := indirect(v, false) if u != nil { start := d.readIndex() d.skip() @@ -622,6 +641,15 @@ func (d *decodeState) object(v reflect.Value) error { } return nil } + if cu != nil { + start := d.readIndex() + d.skip() + err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off]) + if err != nil { + d.saveError(err) + } + return nil + } if ut != nil { d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)}) d.skip() @@ -870,7 +898,7 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool return nil } isNull := item[0] == 'n' // null - u, ut, pv := indirect(v, isNull) + u, cu, ut, pv := indirect(v, isNull) if u != nil { err := u.UnmarshalJSON(item) if err != nil { @@ -878,6 +906,13 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool } return nil } + if cu != nil { + err := cu.UnmarshalJSONContext(d.ctx, item) + if err != nil { + d.saveError(err) + } + return nil + } if ut != nil { if item[0] != '"' { if fromQuoted { diff --git a/common/json/internal/contextjson/encode.go b/common/json/internal/contextjson/encode.go index 296177a5..27f901be 100644 --- a/common/json/internal/contextjson/encode.go +++ b/common/json/internal/contextjson/encode.go @@ -12,6 +12,7 @@ package json import ( "bytes" + "context" "encoding" "encoding/base64" "fmt" @@ -156,7 +157,11 @@ import ( // handle them. Passing cyclic structures to Marshal will result in // an error. func Marshal(v any) ([]byte, error) { - e := newEncodeState() + return MarshalContext(context.Background(), v) +} + +func MarshalContext(ctx context.Context, v any) ([]byte, error) { + e := newEncodeState(ctx) defer encodeStatePool.Put(e) err := e.marshal(v, encOpts{escapeHTML: true}) @@ -251,6 +256,7 @@ var hex = "0123456789abcdef" type encodeState struct { bytes.Buffer // accumulated output + ctx context.Context // Keep track of what pointers we've seen in the current recursive call // path, to avoid cycles that could lead to a stack overflow. Only do // the relatively expensive map operations if ptrLevel is larger than @@ -264,7 +270,7 @@ const startDetectingCyclesAfter = 1000 var encodeStatePool sync.Pool -func newEncodeState() *encodeState { +func newEncodeState(ctx context.Context) *encodeState { if v := encodeStatePool.Get(); v != nil { e := v.(*encodeState) e.Reset() @@ -274,7 +280,7 @@ func newEncodeState() *encodeState { e.ptrLevel = 0 return e } - return &encodeState{ptrSeen: make(map[any]struct{})} + return &encodeState{ctx: ctx, ptrSeen: make(map[any]struct{})} } // jsonError is an error wrapper type for internal use only. @@ -371,8 +377,9 @@ func typeEncoder(t reflect.Type) encoderFunc { } var ( - marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() - textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() + contextMarshalerType = reflect.TypeOf((*ContextMarshaler)(nil)).Elem() + textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() ) // newTypeEncoder constructs an encoderFunc for a type. @@ -385,9 +392,15 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) { return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false)) } + if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(contextMarshalerType) { + return newCondAddrEncoder(addrContextMarshalerEncoder, newTypeEncoder(t, false)) + } if t.Implements(marshalerType) { return marshalerEncoder } + if t.Implements(contextMarshalerType) { + return contextMarshalerEncoder + } if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) { return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false)) } @@ -470,6 +483,47 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { } } +func contextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { + if v.Kind() == reflect.Pointer && v.IsNil() { + e.WriteString("null") + return + } + m, ok := v.Interface().(ContextMarshaler) + if !ok { + e.WriteString("null") + return + } + b, err := m.MarshalJSONContext(e.ctx) + if err == nil { + e.Grow(len(b)) + out := availableBuffer(&e.Buffer) + out, err = appendCompact(out, b, opts.escapeHTML) + e.Buffer.Write(out) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err, "MarshalJSON"}) + } +} + +func addrContextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { + va := v.Addr() + if va.IsNil() { + e.WriteString("null") + return + } + m := va.Interface().(ContextMarshaler) + b, err := m.MarshalJSONContext(e.ctx) + if err == nil { + e.Grow(len(b)) + out := availableBuffer(&e.Buffer) + out, err = appendCompact(out, b, opts.escapeHTML) + e.Buffer.Write(out) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err, "MarshalJSON"}) + } +} + func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { if v.Kind() == reflect.Pointer && v.IsNil() { e.WriteString("null") @@ -827,7 +881,7 @@ func newSliceEncoder(t reflect.Type) encoderFunc { // Byte slices get special treatment; arrays don't. if t.Elem().Kind() == reflect.Uint8 { p := reflect.PointerTo(t.Elem()) - if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) { + if !p.Implements(marshalerType) && !p.Implements(contextMarshalerType) && !p.Implements(textMarshalerType) { return encodeByteSlice } } diff --git a/common/json/internal/contextjson/stream.go b/common/json/internal/contextjson/stream.go index a670ab14..2849dbf9 100644 --- a/common/json/internal/contextjson/stream.go +++ b/common/json/internal/contextjson/stream.go @@ -6,6 +6,7 @@ package json import ( "bytes" + "context" "errors" "io" ) @@ -29,7 +30,11 @@ type Decoder struct { // The decoder introduces its own buffering and may // read data from r beyond the JSON values requested. func NewDecoder(r io.Reader) *Decoder { - return &Decoder{r: r} + return NewDecoderContext(context.Background(), r) +} + +func NewDecoderContext(ctx context.Context, r io.Reader) *Decoder { + return &Decoder{r: r, d: decodeState{ctx: ctx}} } // UseNumber causes the Decoder to unmarshal a number into an interface{} as a @@ -183,6 +188,7 @@ func nonSpace(b []byte) bool { // An Encoder writes JSON values to an output stream. type Encoder struct { + ctx context.Context w io.Writer err error escapeHTML bool @@ -194,7 +200,11 @@ type Encoder struct { // NewEncoder returns a new encoder that writes to w. func NewEncoder(w io.Writer) *Encoder { - return &Encoder{w: w, escapeHTML: true} + return NewEncoderContext(context.Background(), w) +} + +func NewEncoderContext(ctx context.Context, w io.Writer) *Encoder { + return &Encoder{ctx: ctx, w: w, escapeHTML: true} } // Encode writes the JSON encoding of v to the stream, @@ -207,7 +217,7 @@ func (enc *Encoder) Encode(v any) error { return enc.err } - e := newEncodeState() + e := newEncodeState(enc.ctx) defer encodeStatePool.Put(e) err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML}) diff --git a/common/json/internal/contextjson/unmarshal.go b/common/json/internal/contextjson/unmarshal.go index 29405395..04c13cbe 100644 --- a/common/json/internal/contextjson/unmarshal.go +++ b/common/json/internal/contextjson/unmarshal.go @@ -1,5 +1,7 @@ package json +import "context" + func UnmarshalDisallowUnknownFields(data []byte, v any) error { var d decodeState d.disallowUnknownFields = true @@ -10,3 +12,15 @@ func UnmarshalDisallowUnknownFields(data []byte, v any) error { d.init(data) return d.unmarshal(v) } + +func UnmarshalContextDisallowUnknownFields(ctx context.Context, data []byte, v any) error { + var d decodeState + d.ctx = ctx + d.disallowUnknownFields = true + err := checkValid(data, &d.scan) + if err != nil { + return err + } + d.init(data) + return d.unmarshal(v) +} diff --git a/common/json/unmarshal.go b/common/json/unmarshal.go index 7505ebc3..94a2d764 100644 --- a/common/json/unmarshal.go +++ b/common/json/unmarshal.go @@ -2,6 +2,7 @@ package json import ( "bytes" + "context" "errors" "strings" @@ -10,7 +11,11 @@ import ( ) func UnmarshalExtended[T any](content []byte) (T, error) { - decoder := NewDecoder(NewCommentFilter(bytes.NewReader(content))) + return UnmarshalExtendedContext[T](context.Background(), content) +} + +func UnmarshalExtendedContext[T any](ctx context.Context, content []byte) (T, error) { + decoder := NewDecoderContext(ctx, NewCommentFilter(bytes.NewReader(content))) var value T err := decoder.Decode(&value) if err == nil {