Skip to content

Commit

Permalink
feat: adding try catch for return RPC
Browse files Browse the repository at this point in the history
return rpc will now catch if async method throws, and then throw ReturnRpcException on the caller. This will avoid tasks hanging forever if exception is thrown
  • Loading branch information
James-Frowen committed Dec 1, 2024
1 parent 0e67017 commit 6b7e683
Show file tree
Hide file tree
Showing 11 changed files with 219 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Assets/Mirage/Runtime/MethodInvocationException.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System;
using System;
using System.Runtime.Serialization;

namespace Mirage
Expand Down
6 changes: 4 additions & 2 deletions Assets/Mirage/Runtime/RemoteCalls/ClientRpcSender.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ public static void SendTarget(NetworkBehaviour behaviour, int relativeIndex, Net

public static UniTask<T> SendTargetWithReturn<T>(NetworkBehaviour behaviour, int relativeIndex, NetworkWriter writer, INetworkPlayer player)
{
var index = behaviour.Identity.RemoteCallCollection.GetIndexOffset(behaviour) + relativeIndex;
var collection = behaviour.Identity.RemoteCallCollection;
var index = collection.GetIndexOffset(behaviour) + relativeIndex;
Validate(behaviour, index);

(var task, var id) = behaviour.ServerObjectManager._rpcHandler.CreateReplyTask<T>();
var callInfo = collection.GetAbsolute(index);
(var task, var id) = behaviour.ServerObjectManager._rpcHandler.CreateReplyTask<T>(callInfo);
var message = new RpcWithReplyMessage
{
NetId = behaviour.NetId,
Expand Down
38 changes: 30 additions & 8 deletions Assets/Mirage/Runtime/RemoteCalls/RemoteCallHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,38 @@ public void RegisterRequest<T>(int index, string name, bool cmdRequireAuthority,
async UniTaskVoid Wrapper(NetworkBehaviour obj, NetworkReader reader, INetworkPlayer senderPlayer, int replyId)
{
/// invoke the serverRpc and send a reply message
var result = await func(obj, reader, senderPlayer, replyId);
bool success;
T result = default;
try
{
result = await func(obj, reader, senderPlayer, replyId);
success = true;
}
catch (Exception e)
{
success = false;
logger.LogError($"Return RPC threw an Exception: {e}");
}


using (var writer = NetworkWriterPool.GetWriter())
var serverRpcReply = new RpcReply
{
writer.Write(result);
var serverRpcReply = new RpcReply
ReplyId = replyId,
Success = success,
};
if (success)
{
// if success, write payload and send
// else just send it without payload (since there is no result)
using (var writer = NetworkWriterPool.GetWriter())
{
ReplyId = replyId,
Payload = writer.ToArraySegment()
};

writer.Write(result);
serverRpcReply.Payload = writer.ToArraySegment();
senderPlayer.Send(serverRpcReply);
}
}
else
{
senderPlayer.Send(serverRpcReply);
}
}
Expand Down Expand Up @@ -148,6 +169,7 @@ public class RemoteCall
public RemoteCall(NetworkBehaviour behaviour, RpcInvokeType invokeType, RpcDelegate function, bool requireAuthority, string name)
{
Behaviour = behaviour;
DeclaringType = behaviour.GetType();
InvokeType = invokeType;
Function = function;
RequireAuthority = requireAuthority;
Expand Down
9 changes: 9 additions & 0 deletions Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using System;

namespace Mirage.RemoteCalls
{
public class ReturnRpcException : Exception
{
public ReturnRpcException(string message) : base(message) { }
}
}
11 changes: 11 additions & 0 deletions Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 31 additions & 7 deletions Assets/Mirage/Runtime/RemoteCalls/RpcHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ internal class RpcHandler
{
private static readonly ILogger logger = LogFactory.GetLogger<RpcHandler>();

private readonly Dictionary<int, Action<NetworkReader>> _callbacks = new Dictionary<int, Action<NetworkReader>>();
private delegate void ReplyCallbackSuccess(NetworkReader reader);
private delegate void ReplyCallbackFail();

private readonly Dictionary<int, (ReplyCallbackSuccess success, ReplyCallbackFail fail)> _callbacks = new Dictionary<int, (ReplyCallbackSuccess success, ReplyCallbackFail fail)>();
private int _nextReplyId;
/// <summary>
/// Object locator required for deserializing the reply
Expand Down Expand Up @@ -103,29 +106,50 @@ private void ThrowInvalidRpc(RemoteCall remoteCall)
/// </summary>
/// <typeparam name="T"></typeparam>
/// <returns>the task that will be completed when the result is in, and the id to use in the request</returns>
public (UniTask<T> task, int replyId) CreateReplyTask<T>()
public (UniTask<T> task, int replyId) CreateReplyTask<T>(RemoteCall info)
{
var newReplyId = _nextReplyId++;
var completionSource = AutoResetUniTaskCompletionSource<T>.Create();
void Callback(NetworkReader reader)
void CallbackSuccess(NetworkReader reader)
{
var result = reader.Read<T>();
completionSource.TrySetResult(result);
}

_callbacks.Add(newReplyId, Callback);
void CallbackFail()
{
var netId = 0u;
var name = "";
if (info.Behaviour != null)
{
netId = info.Behaviour.NetId;
name = info.Behaviour.name;
}
var message = $"Exception thrown from return RPC. {info.Name} on netId={netId} {name}";
completionSource.TrySetException(new ReturnRpcException(message));
}

_callbacks.Add(newReplyId, (CallbackSuccess, CallbackFail));
return (completionSource.Task, newReplyId);
}

private void OnReply(INetworkPlayer player, RpcReply reply)
{
// find the callback that was waiting for this and invoke it.
if (_callbacks.TryGetValue(reply.ReplyId, out var action))
if (_callbacks.TryGetValue(reply.ReplyId, out var callbacks))
{
_callbacks.Remove(_nextReplyId);
using (var reader = NetworkReaderPool.GetReader(reply.Payload, _objectLocator))

if (reply.Success)
{
using (var reader = NetworkReaderPool.GetReader(reply.Payload, _objectLocator))
{
callbacks.success.Invoke(reader);
}
}
else
{
action.Invoke(reader);
callbacks.fail.Invoke();
}
}
else
Expand Down
2 changes: 2 additions & 0 deletions Assets/Mirage/Runtime/RemoteCalls/RpcMessages.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ public struct RpcWithReplyMessage
public struct RpcReply
{
public int ReplyId;
/// <summary>If result is returned, or exception was thrown</summary>
public bool Success;
public ArraySegment<byte> Payload;
}
}
6 changes: 4 additions & 2 deletions Assets/Mirage/Runtime/RemoteCalls/ServerRpcSender.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ public static void Send(NetworkBehaviour behaviour, int relativeIndex, NetworkWr

public static UniTask<T> SendWithReturn<T>(NetworkBehaviour behaviour, int relativeIndex, NetworkWriter writer, bool requireAuthority)
{
var index = behaviour.Identity.RemoteCallCollection.GetIndexOffset(behaviour) + relativeIndex;
var collection = behaviour.Identity.RemoteCallCollection;
var index = collection.GetIndexOffset(behaviour) + relativeIndex;
Validate(behaviour, index, requireAuthority);
var message = new RpcWithReplyMessage
{
Expand All @@ -35,7 +36,8 @@ public static UniTask<T> SendWithReturn<T>(NetworkBehaviour behaviour, int relat
Payload = writer.ToArraySegment()
};

(var task, var id) = behaviour.ClientObjectManager._rpcHandler.CreateReplyTask<T>();
var callInfo = collection.GetAbsolute(index);
(var task, var id) = behaviour.ClientObjectManager._rpcHandler.CreateReplyTask<T>(callInfo);

message.ReplyId = id;

Expand Down
58 changes: 57 additions & 1 deletion Assets/Tests/Runtime/RpcTests/Async/ReturnRpcClientServerTest.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Collections;
using System.Text.RegularExpressions;
using Cysharp.Threading.Tasks;
using Mirage.Tests.Runtime.Host;
using Mirage.RemoteCalls;
using NUnit.Framework;
using UnityEngine;
using UnityEngine.TestTools;
Expand Down Expand Up @@ -36,6 +37,7 @@ public IEnumerator ClientRpcOwnerReturn() => UniTask.ToCoroutine(async () =>
Assert.That(result, Is.EqualTo(random));
});
}

public class ReturnRpcClientServerTest_float : ClientServerSetup<ReturnRpcComponent_float>
{
[UnityTest]
Expand Down Expand Up @@ -65,6 +67,60 @@ public IEnumerator ClientRpcOwnerReturn() => UniTask.ToCoroutine(async () =>
Assert.That(result, Is.EqualTo(random));
});
}

public class ReturnRpcClientServerTest_throw : ClientServerSetup<ReturnRpcComponent_throw>
{
[UnityTest]
public IEnumerator ServerRpcReturn() => UniTask.ToCoroutine(async () =>
{
LogAssert.Expect(LogType.Error, new Regex(".*Return RPC threw an Exception:.*", RegexOptions.Multiline));
try
{
_ = await clientComponent.GetResultServer();
Assert.Fail();
}
catch (ReturnRpcException e)
{
var fullName = "Mirage.Tests.Runtime.RpcTests.Async.ReturnRpcComponent_throw.GetResultServer";
var message = $"Exception thrown from return RPC. {fullName} on netId={clientComponent.NetId} {clientComponent.name}";
Assert.That(e, Has.Message.EqualTo(message));
}
});

[UnityTest]
public IEnumerator ClientRpcTargetReturn() => UniTask.ToCoroutine(async () =>
{
LogAssert.Expect(LogType.Error, new Regex(".*Return RPC threw an Exception:.*", RegexOptions.Multiline));
try
{
_ = await serverComponent.GetResultTarget(serverPlayer);
Assert.Fail();
}
catch (ReturnRpcException e)
{
var fullName = "Mirage.Tests.Runtime.RpcTests.Async.ReturnRpcComponent_throw.GetResultTarget";
var message = $"Exception thrown from return RPC. {fullName} on netId={serverComponent.NetId} {serverComponent.name}";
Assert.That(e, Has.Message.EqualTo(message));
}
});

[UnityTest]
public IEnumerator ClientRpcOwnerReturn() => UniTask.ToCoroutine(async () =>
{
LogAssert.Expect(LogType.Error, new Regex(".*Return RPC threw an Exception:.*", RegexOptions.Multiline));
try
{
_ = await serverComponent.GetResultOwner();
Assert.Fail();
}
catch (ReturnRpcException e)
{
var fullName = "Mirage.Tests.Runtime.RpcTests.Async.ReturnRpcComponent_throw.GetResultOwner";
var message = $"Exception thrown from return RPC. {fullName} on netId={serverComponent.NetId} {serverComponent.name}";
Assert.That(e, Has.Message.EqualTo(message));
}
});
}
public class ReturnRpcClientServerTest_struct : ClientServerSetup<ReturnRpcComponent_struct>
{
[UnityTest]
Expand Down
22 changes: 22 additions & 0 deletions Assets/Tests/Runtime/RpcTests/Async/ReturnRpcComponents.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using Cysharp.Threading.Tasks;
using UnityEngine;

Expand Down Expand Up @@ -47,6 +48,27 @@ public UniTask<float> GetResultOwner()
return UniTask.FromResult(rpcResult);
}
}
public class ReturnRpcComponent_throw : NetworkBehaviour
{
public static ArgumentException TestException => new System.ArgumentException("some bad thing happened");
[ServerRpc]
public UniTask<float> GetResultServer()
{
throw TestException;
}

[ClientRpc(target = RpcTarget.Player)]
public UniTask<float> GetResultTarget(INetworkPlayer target)
{
throw TestException;
}

[ClientRpc(target = RpcTarget.Owner)]
public UniTask<float> GetResultOwner()
{
throw TestException;
}
}
public class ReturnRpcComponent_struct : NetworkBehaviour
{
public Vector3 rpcResult;
Expand Down
49 changes: 48 additions & 1 deletion Assets/Tests/Runtime/RpcTests/Async/ReturnRpcHostTest.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
using System;
using System.Collections;
using Cysharp.Threading.Tasks;
using Mirage.Tests.Runtime.Host;
using NUnit.Framework;
using UnityEngine;
using UnityEngine.TestTools;
using Random = UnityEngine.Random;

namespace Mirage.Tests.Runtime.RpcTests.Async
{
Expand Down Expand Up @@ -65,6 +66,52 @@ public IEnumerator ClientRpcOwnerReturn() => UniTask.ToCoroutine(async () =>
Assert.That(result, Is.EqualTo(random));
});
}

public class ReturnRpcHostTest_Throw : HostSetup<ReturnRpcComponent_throw>
{
[UnityTest]
public IEnumerator ServerRpcReturn() => UniTask.ToCoroutine(async () =>
{
try
{
_ = await hostComponent.GetResultServer();
Assert.Fail();
}
catch (ArgumentException e) // host invokes direction, so we can catch exception
{
Assert.That(e, Has.Message.EqualTo(ReturnRpcComponent_throw.TestException.Message));
}
});

[UnityTest]
public IEnumerator ClientRpcTargetReturn() => UniTask.ToCoroutine(async () =>
{
try
{
_ = await hostComponent.GetResultTarget(hostServerPlayer);
Assert.Fail();
}
catch (ArgumentException e) // host invokes direction, so we can catch exception
{
Assert.That(e, Has.Message.EqualTo(ReturnRpcComponent_throw.TestException.Message));
}
});

[UnityTest]
public IEnumerator ClientRpcOwnerReturn() => UniTask.ToCoroutine(async () =>
{
try
{
_ = await hostComponent.GetResultOwner();
Assert.Fail();
}
catch (ArgumentException e) // host invokes direction, so we can catch exception
{
Assert.That(e, Has.Message.EqualTo(ReturnRpcComponent_throw.TestException.Message));
}
});
}

public class ReturnRpcHostTest_struct : HostSetup<ReturnRpcComponent_struct>
{
[UnityTest]
Expand Down

0 comments on commit 6b7e683

Please sign in to comment.