Skip to content

Commit

Permalink
feat: adding writer for dictionary
Browse files Browse the repository at this point in the history
  • Loading branch information
James-Frowen committed Nov 5, 2023
1 parent 9f9a4fe commit f89d596
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 12 deletions.
32 changes: 32 additions & 0 deletions Assets/Mirage/Runtime/Serialization/CollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,20 @@ public static void WriteArraySegment<T>(this NetworkWriter writer, ArraySegment<
}
}

[WeaverSerializeCollection]
public static void WriteGodotDictionary<TKey, TValue>(this NetworkWriter writer, Dictionary<TKey, TValue> dictionary)
{
WriteCountPlusOne(writer, dictionary?.Count);

if (dictionary is null)
return;

foreach (var kvp in dictionary)
{
writer.Write(kvp.Key);
writer.Write(kvp.Value);
}
}

/// <returns>array or null</returns>
public static byte[] ReadBytesAndSize(this NetworkReader reader)
Expand Down Expand Up @@ -159,6 +173,24 @@ public static ArraySegment<T> ReadArraySegment<T>(this NetworkReader reader)
return array != null ? new ArraySegment<T>(array) : default;
}

[WeaverSerializeCollection]
public static Dictionary<TKey, TValue> ReadGodotDictionary<TKey, TValue>(this NetworkReader reader)
{
var hasValue = ReadCountPlusOne(reader, out var length);
if (!hasValue)
return null;

ValidateSize(reader, length);

var result = new Dictionary<TKey, TValue>();
for (var i = 0; i < length; i++)
{
var key = reader.Read<TKey>();
var value = reader.Read<TValue>();
result[key] = value;
}
return result;
}

/// <summary>Writes null as 0, and all over values as +1</summary>
/// <param name="count">The real count or null if collection is is null</param>
Expand Down
11 changes: 7 additions & 4 deletions Assets/Mirage/Weaver/Serialization/Readers.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using Mirage.CodeGen;
using Mirage.Serialization;
Expand All @@ -15,7 +16,7 @@ public class Readers : SerializeFunctionBase
public Readers(ModuleDefinition module, IWeaverLogger logger) : base(module, logger) { }

protected override string FunctionTypeLog => "read function";
protected override Expression<Action> ArrayExpression => () => CollectionExtensions.ReadArray<byte>(default);
protected override Expression<Action> ArrayExpression => () => Mirage.Serialization.CollectionExtensions.ReadArray<byte>(default);

protected override MethodReference GetGenericFunction()
{
Expand Down Expand Up @@ -84,17 +85,19 @@ private ReadMethod GenerateReaderFunction(TypeReference variable)
return new ReadMethod(definition, readParameter, worker);
}

protected override MethodReference GenerateCollectionFunction(TypeReference typeReference, TypeReference elementType, MethodReference collectionMethod)
protected override MethodReference GenerateCollectionFunction(TypeReference typeReference, List<TypeReference> elementTypes, MethodReference collectionMethod)
{
// generate readers for the element
_ = GetFunction_Throws(elementType);
foreach (var elementType in elementTypes)
_ = GetFunction_Throws(elementType);

var readMethod = GenerateReaderFunction(typeReference);

var collectionReader = collectionMethod.GetElementMethod();

var methodRef = new GenericInstanceMethod(collectionReader);
methodRef.GenericArguments.Add(elementType);
foreach (var elementType in elementTypes)
methodRef.GenericArguments.Add(elementType);

// generates
// return reader.ReadList<T>()
Expand Down
11 changes: 7 additions & 4 deletions Assets/Mirage/Weaver/Serialization/Writers.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using Mirage.CodeGen;
using Mirage.Serialization;
Expand All @@ -14,7 +15,7 @@ public class Writers : SerializeFunctionBase
public Writers(ModuleDefinition module, IWeaverLogger logger) : base(module, logger) { }

protected override string FunctionTypeLog => "write function";
protected override Expression<Action> ArrayExpression => () => CollectionExtensions.WriteArray<byte>(default, default);
protected override Expression<Action> ArrayExpression => () => Mirage.Serialization.CollectionExtensions.WriteArray<byte>(default, default);

protected override MethodReference GetGenericFunction()
{
Expand Down Expand Up @@ -151,17 +152,19 @@ private void WriteAllFields(TypeReference type, WriteMethod writerFunc)
}
}

protected override MethodReference GenerateCollectionFunction(TypeReference typeReference, TypeReference elementType, MethodReference collectionMethod)
protected override MethodReference GenerateCollectionFunction(TypeReference typeReference, List<TypeReference> elementTypes, MethodReference collectionMethod)
{
// make sure element has a writer
// collection writers use the generic writer, so this will make sure one exists
_ = GetFunction_Throws(elementType);
foreach (var elementType in elementTypes)
_ = GetFunction_Throws(elementType);

var writerMethod = GenerateWriterFunc(typeReference);
var collectionWriter = collectionMethod.GetElementMethod();

var methodRef = new GenericInstanceMethod(collectionWriter);
methodRef.GenericArguments.Add(elementType);
foreach (var elementType in elementTypes)
methodRef.GenericArguments.Add(elementType);

// generates
// reader.WriteArray<T>(array);
Expand Down
10 changes: 6 additions & 4 deletions Assets/Mirage/Weaver/SerializeFunctionBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ private MethodReference GenerateFunction(TypeReference typeReference)
}
var elementType = typeReference.GetElementType();
var arrayMethod = module.ImportReference(ArrayExpression);
return GenerateCollectionFunction(typeReference, elementType, arrayMethod);
return GenerateCollectionFunction(typeReference, new List<TypeReference> { elementType }, arrayMethod);
}

var typeDefinition = typeReference.Resolve();
Expand All @@ -173,9 +173,11 @@ private MethodReference GenerateFunction(TypeReference typeReference)
if (collectionMethods.TryGetValue(typeDefinition, out var collectionMethod))
{
var genericInstance = (GenericInstanceType)typeReference;
var elementType = genericInstance.GenericArguments[0];
var elementTypes = new List<TypeReference>();
foreach (var type in genericInstance.GenericArguments)
elementTypes.Add(type);

return GenerateCollectionFunction(typeReference, elementType, collectionMethod);
return GenerateCollectionFunction(typeReference, elementTypes, collectionMethod);
}

// check for invalid types
Expand Down Expand Up @@ -261,7 +263,7 @@ private GenericInstanceMethod CreateGenericFunction(TypeReference argument)
protected abstract MethodReference GetNetworkBehaviourFunction(TypeReference typeReference);

protected abstract MethodReference GenerateEnumFunction(TypeReference typeReference);
protected abstract MethodReference GenerateCollectionFunction(TypeReference typeReference, TypeReference elementType, MethodReference collectionMethod);
protected abstract MethodReference GenerateCollectionFunction(TypeReference typeReference, List<TypeReference> elementTypes, MethodReference collectionMethod);

protected abstract Expression<Action> ArrayExpression { get; }

Expand Down
40 changes: 40 additions & 0 deletions Assets/Tests/Runtime/Serialization/NetworkWriterTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,46 @@ public void NullableUlong(ulong? value)
Assert.That(unpacked, Is.EqualTo(value));
}

[Test]
public void DictionaryNull()
{
writer.Write<Dictionary<string, Vector3>>(null);
reader.Reset(writer.ToArraySegment());
var unpacked = reader.Read<Dictionary<string, Vector3>>();

Assert.That(unpacked, Is.Null);
}
[Test]
public void DictionaryEmpty()
{
var dict = new Dictionary<string, Vector3>();
writer.Write(dict);
reader.Reset(writer.ToArraySegment());
var unpacked = reader.Read<Dictionary<string, Vector3>>();

Assert.That(unpacked, Is.Not.Null);
Assert.That(unpacked.Count, Is.Zero);
}
[Test]
public void Dictionary()
{
var dict = new Dictionary<string, Vector3>
{
{ "one", Vector3.one },
{ "two", Vector3.one * 2 },
{ "left", Vector3.left }
};

writer.Write(dict);
reader.Reset(writer.ToArraySegment());
var unpacked = reader.Read<Dictionary<string, Vector3>>();

Assert.That(unpacked.Count, Is.EqualTo(3));
Assert.That(unpacked["one"], Is.EqualTo(Vector3.one));
Assert.That(unpacked["two"], Is.EqualTo(Vector3.one * 2));
Assert.That(unpacked["left"], Is.EqualTo(Vector3.left));
}

[Test]
public void SByteLength()
{
Expand Down

0 comments on commit f89d596

Please sign in to comment.