diff --git a/decoder.go b/decoder.go index b87a9e3..e100068 100644 --- a/decoder.go +++ b/decoder.go @@ -733,7 +733,11 @@ func DecodeSliceOfUint64sContent[T ~uint64](dec *Decoder, ns *[]T, maxItems uint size := dec.retrieveSize() if size == 0 { // Empty slice, remove anything extra - *ns = (*ns)[:0] + if *ns == nil { + *ns = make([]T, 0) // Don't leave nil, init to empty + } else { + *ns = (*ns)[:0] + } return } // Compute the number of items based on the item size of the type @@ -880,7 +884,11 @@ func DecodeSliceOfStaticBytesContent[T commonBytesLengths](dec *Decoder, blobs * size := dec.retrieveSize() if size == 0 { // Empty slice, remove anything extra - *blobs = (*blobs)[:0] + if *blobs == nil { + *blobs = make([]T, 0) // Don't leave nil, init to empty + } else { + *blobs = (*blobs)[:0] + } return } // Compute the number of items based on the item size of the type @@ -967,7 +975,11 @@ func DecodeSliceOfDynamicBytesContent(dec *Decoder, blobs *[][]byte, maxItems ui size := dec.retrieveSize() if size == 0 { // Empty slice, remove anything extra - *blobs = (*blobs)[:0] + if *blobs == nil { + *blobs = make([][]byte, 0) // Don't leave nil, init to empty + } else { + *blobs = (*blobs)[:0] + } return } if size < 4 { @@ -1048,7 +1060,11 @@ func DecodeSliceOfStaticObjectsContent[T newableStaticObject[U], U any](dec *Dec size := dec.retrieveSize() if size == 0 { // Empty slice, remove anything extra - *objects = (*objects)[:0] + if *objects == nil { + *objects = make([]T, 0) // Don't leave nil, init to empty + } else { + *objects = (*objects)[:0] + } return } // Compute the number of items based on the item size of the type @@ -1122,7 +1138,11 @@ func DecodeSliceOfDynamicObjectsContent[T newableDynamicObject[U], U any](dec *D size := dec.retrieveSize() if size == 0 { // Empty slice, remove anything extra - *objects = (*objects)[:0] + if *objects == nil { + *objects = make([]T, 0) // Don't leave nil, init to empty + } else { + *objects = (*objects)[:0] + } return } if size < 4 { diff --git a/tests/corner_cases_test.go b/tests/corner_cases_test.go index 9564bf4..9f5c95f 100644 --- a/tests/corner_cases_test.go +++ b/tests/corner_cases_test.go @@ -102,3 +102,66 @@ func TestInvalidBoolean(t *testing.T) { t.Errorf("decode error mismatch: have %v, want %v", err, ssz.ErrInvalidBoolean) } } + +// Tests that decoding empty slices will init them instead of leaving as nil. +func TestEmptySliceInit(t *testing.T) { + obj := new(testEmptySlicesType) + buf := new(bytes.Buffer) + + if err := ssz.EncodeToStream(buf, obj); err != nil { + panic(err) + } + if err := ssz.DecodeFromBytes(buf.Bytes(), obj); err != nil { + panic(err) + } + if obj.A == nil { + t.Errorf("failed to init empty uint64 slice") + } + if obj.B == nil { + t.Errorf("failed to init empty statc bytes slice") + } + if obj.C == nil { + t.Errorf("failed to init empty dynamic bytes slice") + } + if obj.D == nil { + t.Errorf("failed to init empty static objects slice") + } + if obj.E == nil { + t.Errorf("failed to init empty dynamic objects slice") + } +} + +type testEmptySlicesType struct { + A []uint64 // Slice of uint64 + B [][32]byte // Slice of static bytes + C [][]byte // Slice of dynamic bytes + D []*types.Withdrawal // Slice of static objects + E []*types.ExecutionPayload // Slice of dynamic objects +} + +func (t *testEmptySlicesType) SizeSSZ(sizer *ssz.Sizer, fixed bool) (size uint32) { + size = 5 * 4 + if fixed { + return size + } + size += ssz.SizeSliceOfUint64s(sizer, t.A) + size += ssz.SizeSliceOfStaticBytes(sizer, t.B) + size += ssz.SizeSliceOfDynamicBytes(sizer, t.C) + size += ssz.SizeSliceOfStaticObjects(sizer, t.D) + size += ssz.SizeSliceOfDynamicObjects(sizer, t.E) + + return size +} +func (t *testEmptySlicesType) DefineSSZ(codec *ssz.Codec) { + ssz.DefineSliceOfUint64sOffset(codec, &t.A, 16) + ssz.DefineSliceOfStaticBytesOffset(codec, &t.B, 16) + ssz.DefineSliceOfDynamicBytesOffset(codec, &t.C, 16, 16) + ssz.DefineSliceOfStaticObjectsOffset(codec, &t.D, 16) + ssz.DefineSliceOfDynamicObjectsOffset(codec, &t.E, 16) + + ssz.DefineSliceOfUint64sContent(codec, &t.A, 16) + ssz.DefineSliceOfStaticBytesContent(codec, &t.B, 16) + ssz.DefineSliceOfDynamicBytesContent(codec, &t.C, 16, 16) + ssz.DefineSliceOfStaticObjectsContent(codec, &t.D, 16) + ssz.DefineSliceOfDynamicObjectsContent(codec, &t.E, 16) +}