diff --git a/src/MemoryPack.Core/MemoryPackSerializer.Deserialize.cs b/src/MemoryPack.Core/MemoryPackSerializer.Deserialize.cs index 718ebb8..6325523 100644 --- a/src/MemoryPack.Core/MemoryPackSerializer.Deserialize.cs +++ b/src/MemoryPack.Core/MemoryPackSerializer.Deserialize.cs @@ -75,6 +75,51 @@ public static int Deserialize< #endif T>(in ReadOnlySequence buffer, ref T? value, MemoryPackSerializerOptions? options = default) { + if (!RuntimeHelpers.IsReferenceOrContainsReferences()) + { + int sizeOfT = Unsafe.SizeOf(); + if (buffer.Length < sizeOfT) + { + MemoryPackSerializationException.ThrowInvalidRange(Unsafe.SizeOf(), (int)buffer.Length); + } + + ReadOnlySequence sliced = buffer.Slice(0, sizeOfT); + + if (sliced.IsSingleSegment) + { + value = Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(sliced.FirstSpan)); + return sizeOfT; + } + else + { + // We can't read directly from ReadOnlySequence to T, so we copy to a temp array. + // if less than 512 bytes, use stackalloc, otherwise use MemoryPool + byte[]? tempArray = null; + + Span tempSpan = sizeOfT <= 512 ? stackalloc byte[sizeOfT] : default; + + try + { + if (sizeOfT > 512) + { + tempArray = ArrayPool.Shared.Rent(sizeOfT); + tempSpan = tempArray; + } + + sliced.CopyTo(tempSpan); + value = Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(tempSpan)); + return sizeOfT; + } + finally + { + if (tempArray is not null) + { + ArrayPool.Shared.Return(tempArray); + } + } + } + } + var state = threadStaticReaderOptionalState; if (state == null) { diff --git a/tests/MemoryPack.Tests/DeserializeTest.cs b/tests/MemoryPack.Tests/DeserializeTest.cs index c937fed..d54255e 100644 --- a/tests/MemoryPack.Tests/DeserializeTest.cs +++ b/tests/MemoryPack.Tests/DeserializeTest.cs @@ -1,14 +1,16 @@ using System; +using System.Buffers; using System.Collections.Generic; using System.Dynamic; using System.IO; using System.Linq; +using System.Runtime.InteropServices; using System.Text; using System.Threading.Tasks; namespace MemoryPack.Tests; -public class DeserializeTest +public partial class DeserializeTest { [Fact] public async Task StreamTest() @@ -30,6 +32,106 @@ public async Task StreamTest() result.Should().Equal(expected); } + [Fact] + public void GenericValueStructTest() + { + GenericStruct value = new() { Id = 75, Value = 23 }; + + RunMultiSegmentTest(value); + } + + [Fact] + public void LargeGenericValueStructTest() + { + GenericStruct value = new() { Id = 75, Value = new PrePaddedInt() { Value = 23 } }; + + RunMultiSegmentTest(value); + } + + [Fact] + public void GenericReferenceStructTest() + { + GenericStruct value = new GenericStruct() { Id = 75, Value = "Hello World!" }; + + RunMultiSegmentTest(value); + } + + [Fact] + public void LargeGenericReferenceStructTest() + { + GenericStruct value = new() { Id = 75, Value = new PrePaddedString() { Value = "Hello World!" } }; + + RunMultiSegmentTest(value); + } + + private void RunMultiSegmentTest(T value) + { + byte[] bytes = MemoryPackSerializer.Serialize(value); + + byte[] firstHalf = new byte[bytes.Length / 2]; + Array.Copy(bytes, 0, firstHalf, 0, firstHalf.Length); + + int secondHalfLength = bytes.Length / 2; + if (bytes.Length % 2 != 0) + { + secondHalfLength++; + } + + byte[] secondHalf = new byte[secondHalfLength]; + + Array.Copy(bytes, firstHalf.Length, secondHalf, 0, secondHalfLength); + + ReadOnlySequence sequence = ReadOnlySequenceBuilder.Create(firstHalf, secondHalf); + + T? result = MemoryPackSerializer.Deserialize(sequence); + result.Should().Be(value); + } + + [MemoryPackable] + public partial struct GenericStruct + { + public int Id; + public T Value; + + public override string ToString() + { + return $"{Id}, {Value}"; + } + } + + [StructLayout(LayoutKind.Explicit, Size = 516)] + struct PrePaddedInt + { + [FieldOffset(512)] + public int Value; + } + + [MemoryPackable] + private partial class PrePaddedString : IEquatable + { + private PrePaddedInt _padding; + public string Value { get; set; } = ""; + + public bool Equals(PrePaddedString? other) + { + if (other is null) + return false; + + return Value.Equals(other.Value); + } + + public override bool Equals(object? obj) + { + if (obj is PrePaddedString other) + return Equals(other); + return false; + } + + public override int GetHashCode() + { + return Value.GetHashCode(); + } + } class RandomStream : Stream {