Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adding batching to unreliable messages #1172

Merged
merged 2 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 12 additions & 50 deletions Assets/Mirage/Runtime/SocketLayer/Connection/AckSystem.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ public class AckSystem : IDisposable
/// <summary>PacketType, ack sequence, mask</summary>
public const int ACK_HEADER_SIZE = sizeof(byte) + sizeof(ushort) + sizeof(ulong);

public const int RELIABLE_MESSAGE_LENGTH_SIZE = sizeof(ushort);
public const int FRAGMENT_INDEX_SIZE = sizeof(byte);

/// <summary>Smallest size a header for reliable packet, <see cref="RELIABLE_HEADER_SIZE"/> + 2 bytes per message</summary>
public const int MIN_RELIABLE_HEADER_SIZE = RELIABLE_HEADER_SIZE + RELIABLE_MESSAGE_LENGTH_SIZE;
public const int MIN_RELIABLE_HEADER_SIZE = RELIABLE_HEADER_SIZE + Batch.MESSAGE_LENGTH_SIZE;

/// <summary>Smallest size a header for reliable packet, <see cref="RELIABLE_HEADER_SIZE"/> + 1 byte for fragment index</summary>
public const int MIN_RELIABLE_FRAGMENT_HEADER_SIZE = RELIABLE_HEADER_SIZE + FRAGMENT_INDEX_SIZE;
Expand Down Expand Up @@ -66,7 +65,7 @@ public class AckSystem : IDisposable
private float _lastSentTime;
private ushort _lastSentAck;
private int _emptyAckCount = 0;
private ReliablePacket _nextBatch;
private readonly Batch _batch;

/// <summary>
///
Expand All @@ -83,6 +82,7 @@ public AckSystem(IRawConnection connection, Config config, int maxPacketSize, IT
_bufferPool = bufferPool;
_reliablePool = new Pool<ReliablePacket>(ReliablePacket.CreateNew, 0, config.MaxReliablePacketsInSendBufferPerConnection);
_metrics = metrics;
_batch = new ReliableBatch(maxPacketSize, CreateReliableBuffer, SendReliablePacket);

_ackTimeout = config.TimeBeforeEmptyAck;
_emptyAckLimit = config.EmptyAckLimit;
Expand Down Expand Up @@ -112,6 +112,9 @@ public void Dispose()
{
var removeSafety = new HashSet<ByteBuffer>();

if (_batch is IDisposable disposable)
disposable.Dispose();

_sentAckablePackets.ClearAndRelease((packet) =>
{
Debug.Assert(packet.IsValid());
Expand Down Expand Up @@ -142,7 +145,7 @@ public void Dispose()


/// <summary>
/// Gets next Reliable packet in order, packet consists for multiple messsages
/// Gets next Reliable packet in order, packet consists for multiple messages
/// <para>[length, message, length, message, ...]</para>
/// </summary>
/// <param name="packet"></param>
Expand Down Expand Up @@ -191,12 +194,7 @@ public ReliableReceived GetNextFragment()

public void Update()
{
if (_nextBatch != null)
{
SendReliablePacket(_nextBatch);
_nextBatch = null;
}

_batch.Flush();

// todo send ack if not recently been sent
// ack only packet sent if no other sent within last frame
Expand Down Expand Up @@ -308,8 +306,6 @@ public void SendNotify(byte[] inPacket, int inOffset, int inLength, INotifyCallB
}
}



public void SendReliable(byte[] message, int offset, int length)
{
if (_sentAckablePackets.IsFull)
Expand All @@ -325,37 +321,16 @@ public void SendReliable(byte[] message, int offset, int length)
// if there is existing batch, send it first
// we need to do this so that fragmented message arrive in order
// if we dont, a message sent after maybe be added to batch and then have earlier order than fragmented message
if (_nextBatch != null)
{
SendReliablePacket(_nextBatch);
_nextBatch = null;
}

_batch.Flush();
SendFragmented(message, offset, length);
return;
}


if (_nextBatch == null)
{
_nextBatch = CreateReliableBuffer(PacketType.Reliable);
}

var msgLength = length + RELIABLE_MESSAGE_LENGTH_SIZE;
var batchLength = _nextBatch.Length;
if (batchLength + msgLength > _maxPacketSize)
{
// if full, send and create new
SendReliablePacket(_nextBatch);

_nextBatch = CreateReliableBuffer(PacketType.Reliable);
}

AddToBatch(_nextBatch, message, offset, length);
_batch.AddMessage(message, offset, length);
}

/// <summary>
/// Splits messsage into multiple packets
/// Splits message into multiple packets
/// <para>Note: this might just send 1 packet if length is equal to size.
/// This might happen because fragmented header is 1 less that batched header</para>
/// </summary>
Expand Down Expand Up @@ -408,18 +383,6 @@ private ReliablePacket CreateReliableBuffer(PacketType packetType)
return packet;
}

private static void AddToBatch(ReliablePacket packet, byte[] message, int offset, int length)
{
var array = packet.Buffer.array;
var packetOffset = packet.Length;

ByteUtils.WriteUShort(array, ref packetOffset, (ushort)length);
Buffer.BlockCopy(message, offset, array, packetOffset, length);
packetOffset += length;

packet.Length = packetOffset;
}

private void SendReliablePacket(ReliablePacket reliable)
{
ThrowIfBufferLimitReached();
Expand Down Expand Up @@ -447,7 +410,6 @@ private void ThrowIfBufferLimitReached()
}
}


/// <summary>
/// Receives incoming Notify packet
/// <para>Ignores duplicate or late packets</para>
Expand Down Expand Up @@ -743,7 +705,7 @@ public bool IsNotValid()
}
}

private class ReliablePacket
public class ReliablePacket
{
public ushort LastSequence;
public int Length;
Expand Down
130 changes: 130 additions & 0 deletions Assets/Mirage/Runtime/SocketLayer/Connection/Batch.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
using System;

namespace Mirage.SocketLayer
{
public abstract class Batch
{
public const int MESSAGE_LENGTH_SIZE = 2;

private readonly int _maxPacketSize;

public Batch(int maxPacketSize)
{
_maxPacketSize = maxPacketSize;
}

protected abstract bool Created { get; }
protected abstract byte[] GetBatch();
protected abstract ref int GetBatchLength();

protected abstract void CreateNewBatch();
protected abstract void SendAndReset();

public void AddMessage(byte[] message, int offset, int length)
{
if (Created)
{
var msgLength = length + MESSAGE_LENGTH_SIZE;
var batchLength = GetBatchLength();
if (batchLength + msgLength > _maxPacketSize)
{
// if full, send and create new
SendAndReset();
}
}

if (!Created)
CreateNewBatch();

AddToBatch(message, offset, length);
}

private void AddToBatch(byte[] message, int offset, int length)
{
var batch = GetBatch();
ref var batchLength = ref GetBatchLength();
ByteUtils.WriteUShort(batch, ref batchLength, checked((ushort)length));
Buffer.BlockCopy(message, offset, batch, batchLength, length);
batchLength += length;
}

public void Flush()
{
if (Created)
SendAndReset();
}
}

public class ArrayBatch : Batch
{
private readonly Action<byte[], int> _send;
private readonly PacketType _packetType;

private readonly byte[] _batch;
private int _batchLength;

public ArrayBatch(int maxPacketSize, Action<byte[], int> send, PacketType reliable)
: base(maxPacketSize)
{
_batch = new byte[maxPacketSize];
_send = send;
_packetType = reliable;
}

protected override bool Created => _batchLength > 0;

protected override byte[] GetBatch() => _batch;
protected override ref int GetBatchLength() => ref _batchLength;

protected override void CreateNewBatch()
{
_batch[0] = (byte)_packetType;
_batchLength = 1;
}

protected override void SendAndReset()
{
_send.Invoke(_batch, _batchLength);
_batchLength = 0;
}
}

public class ReliableBatch : Batch, IDisposable
{
private AckSystem.ReliablePacket _nextBatch;
private readonly Func<PacketType, AckSystem.ReliablePacket> _createReliableBuffer;
private readonly Action<AckSystem.ReliablePacket> _sendReliablePacket;

public ReliableBatch(int maxPacketSize, Func<PacketType, AckSystem.ReliablePacket> createReliableBuffer, Action<AckSystem.ReliablePacket> sendReliablePacket)
: base(maxPacketSize)
{
_createReliableBuffer = createReliableBuffer;
_sendReliablePacket = sendReliablePacket;
}

protected override bool Created => _nextBatch != null;

protected override byte[] GetBatch() => _nextBatch.Buffer.array;
protected override ref int GetBatchLength() => ref _nextBatch.Length;

protected override void CreateNewBatch()
{
_nextBatch = _createReliableBuffer.Invoke(PacketType.Reliable);
}

protected override void SendAndReset()
{
_sendReliablePacket.Invoke(_nextBatch);
_nextBatch = null;
}

void IDisposable.Dispose()
{
if (_nextBatch != null)
{
_nextBatch.Buffer.Release();
_nextBatch = null;
}
}
}
}
11 changes: 11 additions & 0 deletions Assets/Mirage/Runtime/SocketLayer/Connection/Batch.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Assets/Mirage/Runtime/SocketLayer/Connection/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -205,15 +205,15 @@ private void UpdateConnected()
internal abstract void ReceiveNotifyAck(Packet packet);
internal abstract void ReceiveReliableFragment(Packet packet);

protected void HandleReliableBatched(byte[] array, int offset, int packetLength)
protected void HandleReliableBatched(byte[] array, int offset, int packetLength, PacketType packetType)
{
while (offset < packetLength)
{
var length = ByteUtils.ReadUShort(array, ref offset);
var message = new ArraySegment<byte>(array, offset, length);
offset += length;

_metrics?.OnReceiveMessageReliable(length);
_metrics?.OnReceiveMessage(packetType, length);
_dataHandler.ReceiveMessage(this, message);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,28 @@ namespace Mirage.SocketLayer
/// </summary>
internal sealed class NoReliableConnection : Connection
{
private const int HEADER_SIZE = 1 + MESSAGE_LENGTH_SIZE;
private const int MESSAGE_LENGTH_SIZE = 2;
private const int HEADER_SIZE = 1 + Batch.MESSAGE_LENGTH_SIZE;

private byte[] _nextBatch;
private int _batchLength;
private readonly Batch _nextBatchReliable;

internal NoReliableConnection(Peer peer, IEndPoint endPoint, IDataHandler dataHandler, Config config, int maxPacketSize, Time time, ILogger logger, Metrics metrics)
: base(peer, endPoint, dataHandler, config, maxPacketSize, time, logger, metrics)
{
_nextBatch = new byte[maxPacketSize];
CreateNewBatch();
_nextBatchReliable = new ArrayBatch(maxPacketSize, SendBatchInternal, PacketType.Reliable);

if (maxPacketSize > ushort.MaxValue)
{
throw new ArgumentException($"Max package size can not bigger than {ushort.MaxValue}. NoReliableConnection uses 2 bytes for message length, maxPacketSize over that value will mean that message will be incorrectly batched.");
}
}

private void SendBatchInternal(byte[] batch, int length)
{
_peer.Send(this, batch, length);
}

// just sue SendReliable for unreliable/notify
// note: we dont need to pass in that it is reliable, receiving doesn't really care what channel it is
public override void SendUnreliable(byte[] packet, int offset, int length) => SendReliable(packet, offset, length);
public override void SendNotify(byte[] packet, int offset, int length, INotifyCallBack callBacks)
{
Expand All @@ -53,40 +56,13 @@ public override void SendReliable(byte[] message, int offset, int length)
throw new ArgumentException($"Message is bigger than MTU, size:{length} but max message size is {_maxPacketSize - HEADER_SIZE}");
}


var msgLength = length + MESSAGE_LENGTH_SIZE;
if (_batchLength + msgLength > _maxPacketSize)
{
// if full, send and create new
SendBatch();
}

AddToBatch(message, offset, length);
_nextBatchReliable.AddMessage(message, offset, length);
_metrics?.OnSendMessageReliable(length);
}

private void SendBatch()
{
_peer.Send(this, _nextBatch, _batchLength);
CreateNewBatch();
}

private void CreateNewBatch()
{
_nextBatch[0] = (byte)PacketType.Reliable;
_batchLength = 1;
}

private void AddToBatch(byte[] message, int offset, int length)
{
ByteUtils.WriteUShort(_nextBatch, ref _batchLength, checked((ushort)length));
Buffer.BlockCopy(message, offset, _nextBatch, _batchLength, length);
_batchLength += length;
}

internal override void ReceiveReliablePacket(Packet packet)
{
HandleReliableBatched(packet.Buffer.array, 1, packet.Length);
HandleReliableBatched(packet.Buffer.array, 1, packet.Length, PacketType.Reliable);
}

internal override void ReceiveUnreliablePacket(Packet packet) => throw new NotSupportedException();
Expand All @@ -96,10 +72,7 @@ internal override void ReceiveReliablePacket(Packet packet)

public override void FlushBatch()
{
if (_batchLength > 1)
{
SendBatch();
}
_nextBatchReliable.Flush();
}

internal override bool IsValidSize(Packet packet)
Expand Down
Loading
Loading