diff --git a/decode_test.go b/decode_test.go index 8f32cedb..df914eda 100644 --- a/decode_test.go +++ b/decode_test.go @@ -197,13 +197,6 @@ func Test_Decoder(t *testing.T) { assertEq(t, "interface{}", v.F, nil) assertEq(t, "nilfunc", true, v.G == nil) }) - t.Run("struct.pointer must be nil", func(t *testing.T) { - var v struct { - A *int - } - json.Unmarshal([]byte(`{"a": "alpha"}`), &v) - assertEq(t, "struct.A", v.A, (*int)(nil)) - }) }) t.Run("interface", func(t *testing.T) { t.Run("number", func(t *testing.T) { diff --git a/internal/decoder/ptr.go b/internal/decoder/ptr.go index ae229946..de12e105 100644 --- a/internal/decoder/ptr.go +++ b/internal/decoder/ptr.go @@ -85,7 +85,6 @@ func (d *ptrDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.P } c, err := d.dec.Decode(ctx, cursor, depth, newptr) if err != nil { - *(*unsafe.Pointer)(p) = nil return 0, err } cursor = c diff --git a/internal/encoder/code.go b/internal/encoder/code.go index 5b08faef..fec45a4b 100644 --- a/internal/encoder/code.go +++ b/internal/encoder/code.go @@ -518,6 +518,7 @@ func (c *StructCode) ToAnonymousOpcode(ctx *compileContext) Opcodes { prevField = firstField codes = codes.Add(fieldCodes...) } + ctx.structTypeToCodes[uintptr(unsafe.Pointer(c.typ))] = codes return codes } diff --git a/internal/encoder/compiler.go b/internal/encoder/compiler.go index 3ae39ba8..6b0c7f7a 100644 --- a/internal/encoder/compiler.go +++ b/internal/encoder/compiler.go @@ -86,11 +86,13 @@ func getFilteredCodeSetIfNeeded(ctx *RuntimeContext, codeSet *OpcodeSet) (*Opcod type Compiler struct { structTypeToCode map[uintptr]*StructCode + anonymousStructTypeToCode map[uintptr]*StructCode } func newCompiler() *Compiler { return &Compiler{ structTypeToCode: map[uintptr]*StructCode{}, + anonymousStructTypeToCode: map[uintptr]*StructCode{}, } } @@ -169,11 +171,11 @@ func (c *Compiler) typeToCode(typ *runtime.Type) (Code, error) { return c.sliceCode(typ) case reflect.Map: if isPtr { - return c.ptrCode(runtime.PtrTo(typ)) + return c.ptrCode(runtime.PtrTo(typ), false) } return c.mapCode(typ) case reflect.Struct: - return c.structCode(typ, isPtr) + return c.structCode(typ, isPtr, false) case reflect.Int: return c.intCode(typ, isPtr) case reflect.Int8: @@ -208,11 +210,11 @@ func (c *Compiler) typeToCode(typ *runtime.Type) (Code, error) { if isPtr && typ.Implements(marshalTextType) { typ = orgType } - return c.typeToCodeWithPtr(typ, isPtr) + return c.typeToCodeWithPtr(typ, isPtr, false) } } -func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr bool) (Code, error) { +func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr, isAnonymous bool) (Code, error) { switch { case c.implementsMarshalJSON(typ): return c.marshalJSONCode(typ) @@ -221,7 +223,7 @@ func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr bool) (Code, error } switch typ.Kind() { case reflect.Ptr: - return c.ptrCode(typ) + return c.ptrCode(typ, isAnonymous) case reflect.Slice: elem := typ.Elem() if elem.Kind() == reflect.Uint8 { @@ -236,7 +238,7 @@ func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr bool) (Code, error case reflect.Map: return c.mapCode(typ) case reflect.Struct: - return c.structCode(typ, isPtr) + return c.structCode(typ, isPtr, isAnonymous) case reflect.Interface: return c.interfaceCode(typ, false) case reflect.Int: @@ -424,8 +426,8 @@ func (c *Compiler) marshalTextCode(typ *runtime.Type) (*MarshalTextCode, error) }, nil } -func (c *Compiler) ptrCode(typ *runtime.Type) (*PtrCode, error) { - code, err := c.typeToCodeWithPtr(typ.Elem(), true) +func (c *Compiler) ptrCode(typ *runtime.Type, isAnonymous bool) (*PtrCode, error) { + code, err := c.typeToCodeWithPtr(typ.Elem(), true, isAnonymous) if err != nil { return nil, err } @@ -485,12 +487,12 @@ func (c *Compiler) listElemCode(typ *runtime.Type) (Code, error) { case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType): return c.marshalTextCode(typ) case typ.Kind() == reflect.Map: - return c.ptrCode(runtime.PtrTo(typ)) + return c.ptrCode(runtime.PtrTo(typ), false) default: // isPtr was originally used to indicate whether the type of top level is pointer. // However, since the slice/array element is a specification that can get the pointer address, explicitly set isPtr to true. // See here for related issues: https://github.com/goccy/go-json/issues/370 - code, err := c.typeToCodeWithPtr(typ, true) + code, err := c.typeToCodeWithPtr(typ, true, false) if err != nil { return nil, err } @@ -511,7 +513,7 @@ func (c *Compiler) mapKeyCode(typ *runtime.Type) (Code, error) { } switch typ.Kind() { case reflect.Ptr: - return c.ptrCode(typ) + return c.ptrCode(typ, false) case reflect.String: return c.stringCode(typ, false) case reflect.Int: @@ -543,9 +545,9 @@ func (c *Compiler) mapKeyCode(typ *runtime.Type) (Code, error) { func (c *Compiler) mapValueCode(typ *runtime.Type) (Code, error) { switch typ.Kind() { case reflect.Map: - return c.ptrCode(runtime.PtrTo(typ)) + return c.ptrCode(runtime.PtrTo(typ), false) default: - code, err := c.typeToCodeWithPtr(typ, false) + code, err := c.typeToCodeWithPtr(typ, false, false) if err != nil { return nil, err } @@ -559,16 +561,20 @@ func (c *Compiler) mapValueCode(typ *runtime.Type) (Code, error) { } } -func (c *Compiler) structCode(typ *runtime.Type, isPtr bool) (*StructCode, error) { +func (c *Compiler) structCode(typ *runtime.Type, isPtr, isAnonymous bool) (*StructCode, error) { typeptr := uintptr(unsafe.Pointer(typ)) - if code, exists := c.structTypeToCode[typeptr]; exists { + structTypeToCode := c.structTypeToCode + if isAnonymous { + structTypeToCode = c.anonymousStructTypeToCode + } + if code, exists := structTypeToCode[typeptr]; exists { derefCode := *code derefCode.isRecursive = true return &derefCode, nil } indirect := runtime.IfaceIndir(typ) code := &StructCode{typ: typ, isPtr: isPtr, isIndirect: indirect} - c.structTypeToCode[typeptr] = code + structTypeToCode[typeptr] = code fieldNum := typ.NumField() tags := c.typeToStructTags(typ) @@ -613,7 +619,7 @@ func (c *Compiler) structCode(typ *runtime.Type, isPtr bool) (*StructCode, error if !code.disableIndirectConversion && !indirect && isPtr { code.enableIndirect() } - delete(c.structTypeToCode, typeptr) + delete(structTypeToCode, typeptr) return code, nil } @@ -680,7 +686,7 @@ func (c *Compiler) structFieldCode(structCode *StructCode, tag *runtime.StructTa fieldCode.isAddrForMarshaler = true fieldCode.isNilCheck = false default: - code, err := c.typeToCodeWithPtr(fieldType, isPtr) + code, err := c.typeToCodeWithPtr(fieldType, isPtr, fieldCode.isAnonymous) if err != nil { return nil, err }