diff --git a/Assets/Mirage/Runtime/MethodInvocationException.cs b/Assets/Mirage/Runtime/MethodInvocationException.cs index facaebb16a..9edafaad6e 100644 --- a/Assets/Mirage/Runtime/MethodInvocationException.cs +++ b/Assets/Mirage/Runtime/MethodInvocationException.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Runtime.Serialization; namespace Mirage diff --git a/Assets/Mirage/Runtime/RemoteCalls/ClientRpcSender.cs b/Assets/Mirage/Runtime/RemoteCalls/ClientRpcSender.cs index 5d159742b0..e406909b5a 100644 --- a/Assets/Mirage/Runtime/RemoteCalls/ClientRpcSender.cs +++ b/Assets/Mirage/Runtime/RemoteCalls/ClientRpcSender.cs @@ -39,10 +39,12 @@ public static void SendTarget(NetworkBehaviour behaviour, int relativeIndex, Net public static UniTask SendTargetWithReturn(NetworkBehaviour behaviour, int relativeIndex, NetworkWriter writer, INetworkPlayer player) { - var index = behaviour.Identity.RemoteCallCollection.GetIndexOffset(behaviour) + relativeIndex; + var collection = behaviour.Identity.RemoteCallCollection; + var index = collection.GetIndexOffset(behaviour) + relativeIndex; Validate(behaviour, index); - (var task, var id) = behaviour.ServerObjectManager._rpcHandler.CreateReplyTask(); + var callInfo = collection.GetAbsolute(index); + (var task, var id) = behaviour.ServerObjectManager._rpcHandler.CreateReplyTask(callInfo); var message = new RpcWithReplyMessage { NetId = behaviour.NetId, diff --git a/Assets/Mirage/Runtime/RemoteCalls/RemoteCallHelper.cs b/Assets/Mirage/Runtime/RemoteCalls/RemoteCallHelper.cs index cae26509a0..603700b390 100644 --- a/Assets/Mirage/Runtime/RemoteCalls/RemoteCallHelper.cs +++ b/Assets/Mirage/Runtime/RemoteCalls/RemoteCallHelper.cs @@ -61,17 +61,38 @@ public void RegisterRequest(int index, string name, bool cmdRequireAuthority, async UniTaskVoid Wrapper(NetworkBehaviour obj, NetworkReader reader, INetworkPlayer senderPlayer, int replyId) { /// invoke the serverRpc and send a reply message - var result = await func(obj, reader, senderPlayer, replyId); + bool success; + T result = default; + try + { + result = await func(obj, reader, senderPlayer, replyId); + success = true; + } + catch (Exception e) + { + success = false; + logger.LogError($"Return RPC threw an Exception: {e}"); + } + - using (var writer = NetworkWriterPool.GetWriter()) + var serverRpcReply = new RpcReply { - writer.Write(result); - var serverRpcReply = new RpcReply + ReplyId = replyId, + Success = success, + }; + if (success) + { + // if success, write payload and send + // else just send it without payload (since there is no result) + using (var writer = NetworkWriterPool.GetWriter()) { - ReplyId = replyId, - Payload = writer.ToArraySegment() - }; - + writer.Write(result); + serverRpcReply.Payload = writer.ToArraySegment(); + senderPlayer.Send(serverRpcReply); + } + } + else + { senderPlayer.Send(serverRpcReply); } } @@ -148,6 +169,7 @@ public class RemoteCall public RemoteCall(NetworkBehaviour behaviour, RpcInvokeType invokeType, RpcDelegate function, bool requireAuthority, string name) { Behaviour = behaviour; + DeclaringType = behaviour.GetType(); InvokeType = invokeType; Function = function; RequireAuthority = requireAuthority; diff --git a/Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs b/Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs new file mode 100644 index 0000000000..6c12bc84d6 --- /dev/null +++ b/Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs @@ -0,0 +1,9 @@ +using System; + +namespace Mirage.RemoteCalls +{ + public class ReturnRpcException : Exception + { + public ReturnRpcException(string message) : base(message) { } + } +} diff --git a/Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs.meta b/Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs.meta new file mode 100644 index 0000000000..8294c78510 --- /dev/null +++ b/Assets/Mirage/Runtime/RemoteCalls/ReturnRpcException.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 6d7c9adc2757bbd42b1ec02b8742b90e +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/Mirage/Runtime/RemoteCalls/RpcHandler.cs b/Assets/Mirage/Runtime/RemoteCalls/RpcHandler.cs index ba792c0520..1487d236ca 100644 --- a/Assets/Mirage/Runtime/RemoteCalls/RpcHandler.cs +++ b/Assets/Mirage/Runtime/RemoteCalls/RpcHandler.cs @@ -12,7 +12,10 @@ internal class RpcHandler { private static readonly ILogger logger = LogFactory.GetLogger(); - private readonly Dictionary> _callbacks = new Dictionary>(); + private delegate void ReplyCallbackSuccess(NetworkReader reader); + private delegate void ReplyCallbackFail(); + + private readonly Dictionary _callbacks = new Dictionary(); private int _nextReplyId; /// /// Object locator required for deserializing the reply @@ -103,29 +106,50 @@ private void ThrowInvalidRpc(RemoteCall remoteCall) /// /// /// the task that will be completed when the result is in, and the id to use in the request - public (UniTask task, int replyId) CreateReplyTask() + public (UniTask task, int replyId) CreateReplyTask(RemoteCall info) { var newReplyId = _nextReplyId++; var completionSource = AutoResetUniTaskCompletionSource.Create(); - void Callback(NetworkReader reader) + void CallbackSuccess(NetworkReader reader) { var result = reader.Read(); completionSource.TrySetResult(result); } - _callbacks.Add(newReplyId, Callback); + void CallbackFail() + { + var netId = 0u; + var name = ""; + if (info.Behaviour != null) + { + netId = info.Behaviour.NetId; + name = info.Behaviour.name; + } + var message = $"Exception thrown from return RPC. {info.Name} on netId={netId} {name}"; + completionSource.TrySetException(new ReturnRpcException(message)); + } + + _callbacks.Add(newReplyId, (CallbackSuccess, CallbackFail)); return (completionSource.Task, newReplyId); } private void OnReply(INetworkPlayer player, RpcReply reply) { // find the callback that was waiting for this and invoke it. - if (_callbacks.TryGetValue(reply.ReplyId, out var action)) + if (_callbacks.TryGetValue(reply.ReplyId, out var callbacks)) { _callbacks.Remove(_nextReplyId); - using (var reader = NetworkReaderPool.GetReader(reply.Payload, _objectLocator)) + + if (reply.Success) + { + using (var reader = NetworkReaderPool.GetReader(reply.Payload, _objectLocator)) + { + callbacks.success.Invoke(reader); + } + } + else { - action.Invoke(reader); + callbacks.fail.Invoke(); } } else diff --git a/Assets/Mirage/Runtime/RemoteCalls/RpcMessages.cs b/Assets/Mirage/Runtime/RemoteCalls/RpcMessages.cs index 6cc9432811..189176a45f 100644 --- a/Assets/Mirage/Runtime/RemoteCalls/RpcMessages.cs +++ b/Assets/Mirage/Runtime/RemoteCalls/RpcMessages.cs @@ -28,6 +28,8 @@ public struct RpcWithReplyMessage public struct RpcReply { public int ReplyId; + /// If result is returned, or exception was thrown + public bool Success; public ArraySegment Payload; } } diff --git a/Assets/Mirage/Runtime/RemoteCalls/ServerRpcSender.cs b/Assets/Mirage/Runtime/RemoteCalls/ServerRpcSender.cs index 1e0015322b..813271b81f 100644 --- a/Assets/Mirage/Runtime/RemoteCalls/ServerRpcSender.cs +++ b/Assets/Mirage/Runtime/RemoteCalls/ServerRpcSender.cs @@ -26,7 +26,8 @@ public static void Send(NetworkBehaviour behaviour, int relativeIndex, NetworkWr public static UniTask SendWithReturn(NetworkBehaviour behaviour, int relativeIndex, NetworkWriter writer, bool requireAuthority) { - var index = behaviour.Identity.RemoteCallCollection.GetIndexOffset(behaviour) + relativeIndex; + var collection = behaviour.Identity.RemoteCallCollection; + var index = collection.GetIndexOffset(behaviour) + relativeIndex; Validate(behaviour, index, requireAuthority); var message = new RpcWithReplyMessage { @@ -35,7 +36,8 @@ public static UniTask SendWithReturn(NetworkBehaviour behaviour, int relat Payload = writer.ToArraySegment() }; - (var task, var id) = behaviour.ClientObjectManager._rpcHandler.CreateReplyTask(); + var callInfo = collection.GetAbsolute(index); + (var task, var id) = behaviour.ClientObjectManager._rpcHandler.CreateReplyTask(callInfo); message.ReplyId = id; diff --git a/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcClientServerTest.cs b/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcClientServerTest.cs index 28f5d627a9..583c5f4d3f 100644 --- a/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcClientServerTest.cs +++ b/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcClientServerTest.cs @@ -1,6 +1,7 @@ using System.Collections; +using System.Text.RegularExpressions; using Cysharp.Threading.Tasks; -using Mirage.Tests.Runtime.Host; +using Mirage.RemoteCalls; using NUnit.Framework; using UnityEngine; using UnityEngine.TestTools; @@ -36,6 +37,7 @@ public IEnumerator ClientRpcOwnerReturn() => UniTask.ToCoroutine(async () => Assert.That(result, Is.EqualTo(random)); }); } + public class ReturnRpcClientServerTest_float : ClientServerSetup { [UnityTest] @@ -65,6 +67,60 @@ public IEnumerator ClientRpcOwnerReturn() => UniTask.ToCoroutine(async () => Assert.That(result, Is.EqualTo(random)); }); } + + public class ReturnRpcClientServerTest_throw : ClientServerSetup + { + [UnityTest] + public IEnumerator ServerRpcReturn() => UniTask.ToCoroutine(async () => + { + LogAssert.Expect(LogType.Error, new Regex(".*Return RPC threw an Exception:.*", RegexOptions.Multiline)); + try + { + _ = await clientComponent.GetResultServer(); + Assert.Fail(); + } + catch (ReturnRpcException e) + { + var fullName = "Mirage.Tests.Runtime.RpcTests.Async.ReturnRpcComponent_throw.GetResultServer"; + var message = $"Exception thrown from return RPC. {fullName} on netId={clientComponent.NetId} {clientComponent.name}"; + Assert.That(e, Has.Message.EqualTo(message)); + } + }); + + [UnityTest] + public IEnumerator ClientRpcTargetReturn() => UniTask.ToCoroutine(async () => + { + LogAssert.Expect(LogType.Error, new Regex(".*Return RPC threw an Exception:.*", RegexOptions.Multiline)); + try + { + _ = await serverComponent.GetResultTarget(serverPlayer); + Assert.Fail(); + } + catch (ReturnRpcException e) + { + var fullName = "Mirage.Tests.Runtime.RpcTests.Async.ReturnRpcComponent_throw.GetResultTarget"; + var message = $"Exception thrown from return RPC. {fullName} on netId={serverComponent.NetId} {serverComponent.name}"; + Assert.That(e, Has.Message.EqualTo(message)); + } + }); + + [UnityTest] + public IEnumerator ClientRpcOwnerReturn() => UniTask.ToCoroutine(async () => + { + LogAssert.Expect(LogType.Error, new Regex(".*Return RPC threw an Exception:.*", RegexOptions.Multiline)); + try + { + _ = await serverComponent.GetResultOwner(); + Assert.Fail(); + } + catch (ReturnRpcException e) + { + var fullName = "Mirage.Tests.Runtime.RpcTests.Async.ReturnRpcComponent_throw.GetResultOwner"; + var message = $"Exception thrown from return RPC. {fullName} on netId={serverComponent.NetId} {serverComponent.name}"; + Assert.That(e, Has.Message.EqualTo(message)); + } + }); + } public class ReturnRpcClientServerTest_struct : ClientServerSetup { [UnityTest] diff --git a/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcComponents.cs b/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcComponents.cs index 2457257b20..7115553546 100644 --- a/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcComponents.cs +++ b/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcComponents.cs @@ -1,3 +1,4 @@ +using System; using Cysharp.Threading.Tasks; using UnityEngine; @@ -47,6 +48,27 @@ public UniTask GetResultOwner() return UniTask.FromResult(rpcResult); } } + public class ReturnRpcComponent_throw : NetworkBehaviour + { + public static ArgumentException TestException => new System.ArgumentException("some bad thing happened"); + [ServerRpc] + public UniTask GetResultServer() + { + throw TestException; + } + + [ClientRpc(target = RpcTarget.Player)] + public UniTask GetResultTarget(INetworkPlayer target) + { + throw TestException; + } + + [ClientRpc(target = RpcTarget.Owner)] + public UniTask GetResultOwner() + { + throw TestException; + } + } public class ReturnRpcComponent_struct : NetworkBehaviour { public Vector3 rpcResult; diff --git a/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcHostTest.cs b/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcHostTest.cs index 0304da9791..b010833e15 100644 --- a/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcHostTest.cs +++ b/Assets/Tests/Runtime/RpcTests/Async/ReturnRpcHostTest.cs @@ -1,9 +1,10 @@ +using System; using System.Collections; using Cysharp.Threading.Tasks; using Mirage.Tests.Runtime.Host; using NUnit.Framework; -using UnityEngine; using UnityEngine.TestTools; +using Random = UnityEngine.Random; namespace Mirage.Tests.Runtime.RpcTests.Async { @@ -65,6 +66,52 @@ public IEnumerator ClientRpcOwnerReturn() => UniTask.ToCoroutine(async () => Assert.That(result, Is.EqualTo(random)); }); } + + public class ReturnRpcHostTest_Throw : HostSetup + { + [UnityTest] + public IEnumerator ServerRpcReturn() => UniTask.ToCoroutine(async () => + { + try + { + _ = await hostComponent.GetResultServer(); + Assert.Fail(); + } + catch (ArgumentException e) // host invokes direction, so we can catch exception + { + Assert.That(e, Has.Message.EqualTo(ReturnRpcComponent_throw.TestException.Message)); + } + }); + + [UnityTest] + public IEnumerator ClientRpcTargetReturn() => UniTask.ToCoroutine(async () => + { + try + { + _ = await hostComponent.GetResultTarget(hostServerPlayer); + Assert.Fail(); + } + catch (ArgumentException e) // host invokes direction, so we can catch exception + { + Assert.That(e, Has.Message.EqualTo(ReturnRpcComponent_throw.TestException.Message)); + } + }); + + [UnityTest] + public IEnumerator ClientRpcOwnerReturn() => UniTask.ToCoroutine(async () => + { + try + { + _ = await hostComponent.GetResultOwner(); + Assert.Fail(); + } + catch (ArgumentException e) // host invokes direction, so we can catch exception + { + Assert.That(e, Has.Message.EqualTo(ReturnRpcComponent_throw.TestException.Message)); + } + }); + } + public class ReturnRpcHostTest_struct : HostSetup { [UnityTest]