From c8888e12d7f7f96539eca1f131d99ef0c9f58b81 Mon Sep 17 00:00:00 2001 From: Bar Levy <12776887+barchkile@users.noreply.github.com> Date: Thu, 13 Apr 2023 03:00:34 +0300 Subject: [PATCH] Fix condition of method return type in RequestBuilder to only allow Task<> and IObservable<> (#1364) * Fix condition of generic return type in method info * add test * Update RestMethodInfo.cs --------- Co-authored-by: barle Co-authored-by: Glenn <5834289+glennawatson@users.noreply.github.com> --- Refit.Tests/RequestBuilder.cs | 11 +++++++++++ Refit/RestMethodInfo.cs | 7 ++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/Refit.Tests/RequestBuilder.cs b/Refit.Tests/RequestBuilder.cs index cd84c657d..a61680d2f 100644 --- a/Refit.Tests/RequestBuilder.cs +++ b/Refit.Tests/RequestBuilder.cs @@ -250,6 +250,9 @@ public interface IRestMethodInfoTests [Get("/api/{id}")] Task IEnumerableThrowingError([Query(CollectionFormat.Multi)] IEnumerable values); + + [Get("/foo")] + List InvalidGenericReturnType(); } public enum TestEnum { A, B, C } @@ -1146,6 +1149,14 @@ public void ParameterMappingWithHeaderQueryParamAndQueryArrayParam() Assert.Single(fixture.HeaderParameterMap); Assert.Single(fixture.PropertyParameterMap); } + + [Fact] + public void GenericReturnTypeIsNotTaskOrObservableShouldThrow() + { + var input = typeof(IRestMethodInfoTests); + Assert.Throws(() => new RestMethodInfo(input, + input.GetMethods().First(x => x.Name == nameof(IRestMethodInfoTests.InvalidGenericReturnType)))); + } } [Headers("User-Agent: RefitTestClient", "Api-Version: 1")] diff --git a/Refit/RestMethodInfo.cs b/Refit/RestMethodInfo.cs index 760c790bd..04ec347a8 100644 --- a/Refit/RestMethodInfo.cs +++ b/Refit/RestMethodInfo.cs @@ -463,8 +463,9 @@ static Dictionary BuildHeaderParameterMap(List param void DetermineReturnTypeInfo(MethodInfo methodInfo) { var returnType = methodInfo.ReturnType; - if (returnType.IsGenericType && (methodInfo.ReturnType.GetGenericTypeDefinition() != typeof(Task<>) - || methodInfo.ReturnType.GetGenericTypeDefinition() != typeof(IObservable<>))) + if (returnType.IsGenericType && (methodInfo.ReturnType.GetGenericTypeDefinition() == typeof(Task<>) + || methodInfo.ReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>) + || methodInfo.ReturnType.GetGenericTypeDefinition() == typeof(IObservable<>))) { ReturnType = returnType; ReturnResultType = returnType.GetGenericArguments()[0]; @@ -488,7 +489,7 @@ void DetermineReturnTypeInfo(MethodInfo methodInfo) DeserializedResultType = typeof(void); } else - throw new ArgumentException($"Method \"{methodInfo.Name}\" is invalid. All REST Methods must return either Task or IObservable"); + throw new ArgumentException($"Method \"{methodInfo.Name}\" is invalid. All REST Methods must return either Task or ValueTask or IObservable"); } void DetermineIfResponseMustBeDisposed()