From 6b7e68326f29485edf3858695d4615a0b7083802 Mon Sep 17 00:00:00 2001 From: James Frowen Date: Sun, 1 Dec 2024 15:03:23 +0000 Subject: [PATCH] feat: adding try catch for return RPC return rpc will now catch if async method throws, and then throw ReturnRpcException on the caller. This will avoid tasks hanging forever if exception is thrown --- .../Runtime/MethodInvocationException.cs | 2 +- .../Runtime/RemoteCalls/ClientRpcSender.cs | 6 +- .../Runtime/RemoteCalls/RemoteCallHelper.cs | 38 +++++++++--- .../Runtime/RemoteCalls/ReturnRpcException.cs | 9 +++ .../RemoteCalls/ReturnRpcException.cs.meta | 11 ++++ .../Mirage/Runtime/RemoteCalls/RpcHandler.cs | 38 +++++++++--- .../Mirage/Runtime/RemoteCalls/RpcMessages.cs | 2 + .../Runtime/RemoteCalls/ServerRpcSender.cs | 6 +- .../Async/ReturnRpcClientServerTest.cs | 58 ++++++++++++++++++- .../RpcTests/Async/ReturnRpcComponents.cs | 22 +++++++ .../RpcTests/Async/ReturnRpcHostTest.cs | 49 +++++++++++++++- 11 files changed, 219 insertions(+), 22 deletions(-) create mode 100644 Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs create mode 100644 Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs.meta diff --git a/Assets/Mirage/Runtime/MethodInvocationException.cs b/Assets/Mirage/Runtime/MethodInvocationException.cs index facaebb16a1..9edafaad6eb 100644 --- a/Assets/Mirage/Runtime/MethodInvocationException.cs +++ b/Assets/Mirage/Runtime/MethodInvocationException.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Runtime.Serialization; namespace Mirage diff --git a/Assets/Mirage/Runtime/RemoteCalls/ClientRpcSender.cs b/Assets/Mirage/Runtime/RemoteCalls/ClientRpcSender.cs index 5d159742b03..e406909b5ab 100644 --- a/Assets/Mirage/Runtime/RemoteCalls/ClientRpcSender.cs +++ b/Assets/Mirage/Runtime/RemoteCalls/ClientRpcSender.cs @@ -39,10 +39,12 @@ public static void SendTarget(NetworkBehaviour behaviour, int relativeIndex, Net public static UniTask SendTargetWithReturn(NetworkBehaviour behaviour, int relativeIndex, NetworkWriter writer, INetworkPlayer player) { - var index = behaviour.Identity.RemoteCallCollection.GetIndexOffset(behaviour) + relativeIndex; + var collection = behaviour.Identity.RemoteCallCollection; + var index = collection.GetIndexOffset(behaviour) + relativeIndex; Validate(behaviour, index); - (var task, var id) = behaviour.ServerObjectManager._rpcHandler.CreateReplyTask(); + var callInfo = collection.GetAbsolute(index); + (var task, var id) = behaviour.ServerObjectManager._rpcHandler.CreateReplyTask(callInfo); var message = new RpcWithReplyMessage { NetId = behaviour.NetId, diff --git a/Assets/Mirage/Runtime/RemoteCalls/RemoteCallHelper.cs b/Assets/Mirage/Runtime/RemoteCalls/RemoteCallHelper.cs index cae26509a02..603700b3901 100644 --- a/Assets/Mirage/Runtime/RemoteCalls/RemoteCallHelper.cs +++ b/Assets/Mirage/Runtime/RemoteCalls/RemoteCallHelper.cs @@ -61,17 +61,38 @@ public void RegisterRequest(int index, string name, bool cmdRequireAuthority, async UniTaskVoid Wrapper(NetworkBehaviour obj, NetworkReader reader, INetworkPlayer senderPlayer, int replyId) { /// invoke the serverRpc and send a reply message - var result = await func(obj, reader, senderPlayer, replyId); + bool success; + T result = default; + try + { + result = await func(obj, reader, senderPlayer, replyId); + success = true; + } + catch (Exception e) + { + success = false; + logger.LogError($"Return RPC threw an Exception: {e}"); + } + - using (var writer = NetworkWriterPool.GetWriter()) + var serverRpcReply = new RpcReply { - writer.Write(result); - var serverRpcReply = new RpcReply + ReplyId = replyId, + Success = success, + }; + if (success) + { + // if success, write payload and send + // else just send it without payload (since there is no result) + using (var writer = NetworkWriterPool.GetWriter()) { - ReplyId = replyId, - Payload = writer.ToArraySegment() - }; - + writer.Write(result); + serverRpcReply.Payload = writer.ToArraySegment(); + senderPlayer.Send(serverRpcReply); + } + } + else + { senderPlayer.Send(serverRpcReply); } } @@ -148,6 +169,7 @@ public class RemoteCall public RemoteCall(NetworkBehaviour behaviour, RpcInvokeType invokeType, RpcDelegate function, bool requireAuthority, string name) { Behaviour = behaviour; + DeclaringType = behaviour.GetType(); InvokeType = invokeType; Function = function; RequireAuthority = requireAuthority; diff --git a/Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs b/Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs new file mode 100644 index 00000000000..6c12bc84d61 --- /dev/null +++ b/Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs @@ -0,0 +1,9 @@ +using System; + +namespace Mirage.RemoteCalls +{ + public class ReturnRpcException : Exception + { + public ReturnRpcException(string message) : base(message) { } + } +} diff --git a/Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs.meta b/Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs.meta new file mode 100644 index 00000000000..8294c785100 --- /dev/null +++ b/Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 6d7c9adc2757bbd42b1ec02b8742b90e +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/Mirage/Runtime/RemoteCalls/RpcHandler.cs b/Assets/Mirage/Runtime/RemoteCalls/RpcHandler.cs index ba792c0520b..1487d236ca1 100644 --- a/Assets/Mirage/Runtime/RemoteCalls/RpcHandler.cs +++ b/Assets/Mirage/Runtime/RemoteCalls/RpcHandler.cs @@ -12,7 +12,10 @@ internal class RpcHandler { private static readonly ILogger logger = LogFactory.GetLogger(); - private readonly Dictionary> _callbacks = new Dictionary>(); + private delegate void ReplyCallbackSuccess(NetworkReader reader); + private delegate void ReplyCallbackFail(); + + private readonly Dictionary _callbacks = new Dictionary(); private int _nextReplyId; /// /// Object locator required for deserializing the reply @@ -103,29 +106,50 @@ private void ThrowInvalidRpc(RemoteCall remoteCall) /// /// /// the task that will be completed when the result is in, and the id to use in the request - public (UniTask task, int replyId) CreateReplyTask() + public (UniTask task, int replyId) CreateReplyTask(RemoteCall info) { var newReplyId = _nextReplyId++; var completionSource = AutoResetUniTaskCompletionSource.Create(); - void Callback(NetworkReader reader) + void CallbackSuccess(NetworkReader reader) { var result = reader.Read(); completionSource.TrySetResult(result); } - _callbacks.Add(newReplyId, Callback); + void CallbackFail() + { + var netId = 0u; + var name = ""; + if (info.Behaviour != null) + { + netId = info.Behaviour.NetId; + name = info.Behaviour.name; + } + var message = $"Exception thrown from return RPC. {info.Name} on netId={netId} {name}"; + completionSource.TrySetException(new ReturnRpcException(message)); + } + + _callbacks.Add(newReplyId, (CallbackSuccess, CallbackFail)); return (completionSource.Task, newReplyId); } private void OnReply(INetworkPlayer player, RpcReply reply) { // find the callback that was waiting for this and invoke it. - if (_callbacks.TryGetValue(reply.ReplyId, out var action)) + if (_callbacks.TryGetValue(reply.ReplyId, out var callbacks)) { _callbacks.Remove(_nextReplyId); - using (var reader = NetworkReaderPool.GetReader(reply.Payload, _objectLocator)) + + if (reply.Success) + { + using (var reader = NetworkReaderPool.GetReader(reply.Payload, _objectLocator)) + { + callbacks.success.Invoke(reader); + } + } + else { - action.Invoke(reader); + callbacks.fail.Invoke(); } } else diff --git a/Assets/Mirage/Runtime/RemoteCalls/RpcMessages.cs b/Assets/Mirage/Runtime/RemoteCalls/RpcMessages.cs index 6cc9432811a..189176a45f0 100644 --- a/Assets/Mirage/Runtime/RemoteCalls/RpcMessages.cs +++ b/Assets/Mirage/Runtime/RemoteCalls/RpcMessages.cs @@ -28,6 +28,8 @@ public struct RpcWithReplyMessage public struct RpcReply { public int ReplyId; + /// If result is returned, or exception was thrown + public bool Success; public ArraySegment Payload; } } diff --git a/Assets/Mirage/Runtime/RemoteCalls/ServerRpcSender.cs b/Assets/Mirage/Runtime/RemoteCalls/ServerRpcSender.cs index 1e0015322b4..813271b81f9 100644 --- a/Assets/Mirage/Runtime/RemoteCalls/ServerRpcSender.cs +++ b/Assets/Mirage/Runtime/RemoteCalls/ServerRpcSender.cs @@ -26,7 +26,8 @@ public static void Send(NetworkBehaviour behaviour, int relativeIndex, NetworkWr public static UniTask SendWithReturn(NetworkBehaviour behaviour, int relativeIndex, NetworkWriter writer, bool requireAuthority) { - var index = behaviour.Identity.RemoteCallCollection.GetIndexOffset(behaviour) + relativeIndex; + var collection = behaviour.Identity.RemoteCallCollection; + var index = collection.GetIndexOffset(behaviour) + relativeIndex; Validate(behaviour, index, requireAuthority); var message = new RpcWithReplyMessage { @@ -35,7 +36,8 @@ public static UniTask SendWithReturn(NetworkBehaviour behaviour, int relat Payload = writer.ToArraySegment() }; - (var task, var id) = behaviour.ClientObjectManager._rpcHandler.CreateReplyTask(); + var callInfo = collection.GetAbsolute(index); + (var task, var id) = behaviour.ClientObjectManager._rpcHandler.CreateReplyTask(callInfo); message.ReplyId = id; diff --git a/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcClientServerTest.cs b/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcClientServerTest.cs index 28f5d627a90..583c5f4d3f6 100644 --- a/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcClientServerTest.cs +++ b/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcClientServerTest.cs @@ -1,6 +1,7 @@ using System.Collections; +using System.Text.RegularExpressions; using Cysharp.Threading.Tasks; -using Mirage.Tests.Runtime.Host; +using Mirage.RemoteCalls; using NUnit.Framework; using UnityEngine; using UnityEngine.TestTools; @@ -36,6 +37,7 @@ public IEnumerator ClientRpcOwnerReturn() => UniTask.ToCoroutine(async () => Assert.That(result, Is.EqualTo(random)); }); } + public class ReturnRpcClientServerTest_float : ClientServerSetup { [UnityTest] @@ -65,6 +67,60 @@ public IEnumerator ClientRpcOwnerReturn() => UniTask.ToCoroutine(async () => Assert.That(result, Is.EqualTo(random)); }); } + + public class ReturnRpcClientServerTest_throw : ClientServerSetup + { + [UnityTest] + public IEnumerator ServerRpcReturn() => UniTask.ToCoroutine(async () => + { + LogAssert.Expect(LogType.Error, new Regex(".*Return RPC threw an Exception:.*", RegexOptions.Multiline)); + try + { + _ = await clientComponent.GetResultServer(); + Assert.Fail(); + } + catch (ReturnRpcException e) + { + var fullName = "Mirage.Tests.Runtime.RpcTests.Async.ReturnRpcComponent_throw.GetResultServer"; + var message = $"Exception thrown from return RPC. {fullName} on netId={clientComponent.NetId} {clientComponent.name}"; + Assert.That(e, Has.Message.EqualTo(message)); + } + }); + + [UnityTest] + public IEnumerator ClientRpcTargetReturn() => UniTask.ToCoroutine(async () => + { + LogAssert.Expect(LogType.Error, new Regex(".*Return RPC threw an Exception:.*", RegexOptions.Multiline)); + try + { + _ = await serverComponent.GetResultTarget(serverPlayer); + Assert.Fail(); + } + catch (ReturnRpcException e) + { + var fullName = "Mirage.Tests.Runtime.RpcTests.Async.ReturnRpcComponent_throw.GetResultTarget"; + var message = $"Exception thrown from return RPC. {fullName} on netId={serverComponent.NetId} {serverComponent.name}"; + Assert.That(e, Has.Message.EqualTo(message)); + } + }); + + [UnityTest] + public IEnumerator ClientRpcOwnerReturn() => UniTask.ToCoroutine(async () => + { + LogAssert.Expect(LogType.Error, new Regex(".*Return RPC threw an Exception:.*", RegexOptions.Multiline)); + try + { + _ = await serverComponent.GetResultOwner(); + Assert.Fail(); + } + catch (ReturnRpcException e) + { + var fullName = "Mirage.Tests.Runtime.RpcTests.Async.ReturnRpcComponent_throw.GetResultOwner"; + var message = $"Exception thrown from return RPC. {fullName} on netId={serverComponent.NetId} {serverComponent.name}"; + Assert.That(e, Has.Message.EqualTo(message)); + } + }); + } public class ReturnRpcClientServerTest_struct : ClientServerSetup { [UnityTest] diff --git a/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcComponents.cs b/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcComponents.cs index 2457257b206..71155535466 100644 --- a/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcComponents.cs +++ b/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcComponents.cs @@ -1,3 +1,4 @@ +using System; using Cysharp.Threading.Tasks; using UnityEngine; @@ -47,6 +48,27 @@ public UniTask GetResultOwner() return UniTask.FromResult(rpcResult); } } + public class ReturnRpcComponent_throw : NetworkBehaviour + { + public static ArgumentException TestException => new System.ArgumentException("some bad thing happened"); + [ServerRpc] + public UniTask GetResultServer() + { + throw TestException; + } + + [ClientRpc(target = RpcTarget.Player)] + public UniTask GetResultTarget(INetworkPlayer target) + { + throw TestException; + } + + [ClientRpc(target = RpcTarget.Owner)] + public UniTask GetResultOwner() + { + throw TestException; + } + } public class ReturnRpcComponent_struct : NetworkBehaviour { public Vector3 rpcResult; diff --git a/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcHostTest.cs b/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcHostTest.cs index 0304da97910..b010833e158 100644 --- a/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcHostTest.cs +++ b/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcHostTest.cs @@ -1,9 +1,10 @@ +using System; using System.Collections; using Cysharp.Threading.Tasks; using Mirage.Tests.Runtime.Host; using NUnit.Framework; -using UnityEngine; using UnityEngine.TestTools; +using Random = UnityEngine.Random; namespace Mirage.Tests.Runtime.RpcTests.Async { @@ -65,6 +66,52 @@ public IEnumerator ClientRpcOwnerReturn() => UniTask.ToCoroutine(async () => Assert.That(result, Is.EqualTo(random)); }); } + + public class ReturnRpcHostTest_Throw : HostSetup + { + [UnityTest] + public IEnumerator ServerRpcReturn() => UniTask.ToCoroutine(async () => + { + try + { + _ = await hostComponent.GetResultServer(); + Assert.Fail(); + } + catch (ArgumentException e) // host invokes direction, so we can catch exception + { + Assert.That(e, Has.Message.EqualTo(ReturnRpcComponent_throw.TestException.Message)); + } + }); + + [UnityTest] + public IEnumerator ClientRpcTargetReturn() => UniTask.ToCoroutine(async () => + { + try + { + _ = await hostComponent.GetResultTarget(hostServerPlayer); + Assert.Fail(); + } + catch (ArgumentException e) // host invokes direction, so we can catch exception + { + Assert.That(e, Has.Message.EqualTo(ReturnRpcComponent_throw.TestException.Message)); + } + }); + + [UnityTest] + public IEnumerator ClientRpcOwnerReturn() => UniTask.ToCoroutine(async () => + { + try + { + _ = await hostComponent.GetResultOwner(); + Assert.Fail(); + } + catch (ArgumentException e) // host invokes direction, so we can catch exception + { + Assert.That(e, Has.Message.EqualTo(ReturnRpcComponent_throw.TestException.Message)); + } + }); + } + public class ReturnRpcHostTest_struct : HostSetup { [UnityTest]