Skip to content

Commit

Permalink
Fix ValueTask bugs (#338)
Browse files Browse the repository at this point in the history
  • Loading branch information
jasongin authored Jul 24, 2024
1 parent 9af64a6 commit 3c36102
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 32 deletions.
143 changes: 112 additions & 31 deletions src/NodeApi.DotNetHost/JSMarshaller.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,12 @@ public JSMarshaller()
?? throw new NotImplementedException("JSValue.TryUnwrap");

private static readonly MethodInfo s_getOrCreateObjectWrapper =
typeof(JSRuntimeContext).GetInstanceMethod(nameof(JSRuntimeContext.GetOrCreateObjectWrapper))
?? throw new NotImplementedException("JSRuntimeContext.GetOrCreateObjectWrapper");
typeof(JSRuntimeContext).GetInstanceMethod(
nameof(JSRuntimeContext.GetOrCreateObjectWrapper));

private static readonly MethodInfo s_asVoidPromise =
typeof(TaskExtensions).GetStaticMethod(
nameof(TaskExtensions.AsPromise), new[] { typeof(Task) })
?? throw new NotImplementedException("TaskExtensions.AsPromise");
nameof(TaskExtensions.AsPromise), new[] { typeof(Task) });

/// <summary>
/// Gets or sets a value indicating whether the marshaller automatically converts
Expand All @@ -149,6 +148,7 @@ internal static bool IsConvertedType(Type type)
type == typeof(string) ||
type == typeof(Array) ||
type == typeof(Task) ||
type == typeof(ValueTask) ||
type == typeof(CancellationToken) ||
type == typeof(DateTime) ||
type == typeof(TimeSpan) ||
Expand All @@ -164,7 +164,9 @@ internal static bool IsConvertedType(Type type)
}

if (type.IsGenericTypeDefinition &&
(type == typeof(IEnumerable<>) ||
(type == typeof(Task<>) ||
type == typeof(ValueTask<>) ||
type == typeof(IEnumerable<>) ||
type == typeof(IAsyncEnumerable<>) ||
type == typeof(ICollection<>) ||
type == typeof(IReadOnlyCollection<>) ||
Expand Down Expand Up @@ -275,8 +277,9 @@ public LambdaExpression GetFromJSValueExpression(Type toType)

try
{
if (toType == typeof(Task) ||
(toType.IsGenericType && toType.GetGenericTypeDefinition() == typeof(Task<>)))
if (toType == typeof(Task) || toType == typeof(ValueTask) ||
(toType.IsGenericType && toType.GetGenericTypeDefinition() == typeof(Task<>)) ||
(toType.IsGenericType && toType.GetGenericTypeDefinition() == typeof(ValueTask<>)))
{
return _fromJSExpressions.GetOrAdd(toType, BuildConvertFromJSPromiseExpression);
}
Expand Down Expand Up @@ -307,10 +310,12 @@ public LambdaExpression GetToJSValueExpression(Type fromType)

try
{
if (fromType == typeof(Task) ||
(fromType.IsGenericType && fromType.GetGenericTypeDefinition() == typeof(Task<>)))
if (fromType == typeof(Task) || fromType == typeof(ValueTask) ||
(fromType.IsGenericType && fromType.GetGenericTypeDefinition() == typeof(Task<>)) ||
(fromType.IsGenericType &&
fromType.GetGenericTypeDefinition() == typeof(ValueTask<>)))
{
return _fromJSExpressions.GetOrAdd(fromType, BuildConvertToJSPromiseExpression);
return _toJSExpressions.GetOrAdd(fromType, BuildConvertToJSPromiseExpression);
}
else
{
Expand Down Expand Up @@ -1825,8 +1830,17 @@ private Expression BuildResultExpression(
GetCastToJSValueMethod(typeof(JSPromise))!,
Expression.Call(s_asVoidPromise, resultVariable));
}

if (resultType.IsGenericType && resultType.GetGenericTypeDefinition() == typeof(Task<>))
else if (resultType == typeof(ValueTask))
{
return Expression.Call(
GetCastToJSValueMethod(typeof(JSPromise))!,
Expression.Call(
typeof(TaskExtensions).GetStaticMethod(
nameof(TaskExtensions.AsPromise), new[] { typeof(ValueTask) }),
resultVariable));
}
else if (resultType.IsGenericType &&
resultType.GetGenericTypeDefinition() == typeof(Task<>))
{
Type asyncResultType = resultType;
resultType = resultType.GenericTypeArguments[0];
Expand All @@ -1841,6 +1855,22 @@ private Expression BuildResultExpression(
resultVariable,
GetToJSValueExpression(resultType)));
}
else if (resultType.IsGenericType &&
resultType.GetGenericTypeDefinition() == typeof(ValueTask<>))
{
Type asyncResultType = resultType;
resultType = resultType.GenericTypeArguments[0];
MethodInfo asPromiseMethod = typeof(TaskExtensions).GetStaticMethod(
nameof(TaskExtensions.AsPromise),
new[] { typeof(ValueTask<>), typeof(JSValue.From<>) },
resultType);
return Expression.Call(
GetCastToJSValueMethod(typeof(JSPromise))!,
Expression.Call(
asPromiseMethod,
resultVariable,
GetToJSValueExpression(resultType)));
}

Type? nullableType = null;
if (resultType.IsGenericType && resultType.GetGenericTypeDefinition() == typeof(Nullable<>))
Expand Down Expand Up @@ -2521,6 +2551,10 @@ private LambdaExpression BuildConvertFromJSPromiseExpression(Type toType)
string delegateName = "to_" + FullTypeName(toType);

ParameterExpression valueParameter = Expression.Parameter(typeof(JSValue), "value");
Expression valueAsPromise = Expression.Convert(
valueParameter,
typeof(JSPromise),
typeof(JSPromise).GetExplicitConversion(typeof(JSValue), typeof(JSPromise)));
Expression asTaskExpression;

if (toType == typeof(Task))
Expand All @@ -2531,10 +2565,7 @@ private LambdaExpression BuildConvertFromJSPromiseExpression(Type toType)
asTaskExpression = Expression.Call(
typeof(TaskExtensions).GetStaticMethod(
nameof(TaskExtensions.AsTask), new[] { typeof(JSPromise) }),
Expression.Convert(
valueParameter,
typeof(JSPromise),
typeof(JSPromise).GetExplicitConversion(typeof(JSValue), typeof(JSPromise))));
valueAsPromise);
}
else if (toType.IsGenericType && toType.GetGenericTypeDefinition() == typeof(Task<>))
{
Expand All @@ -2546,11 +2577,31 @@ private LambdaExpression BuildConvertFromJSPromiseExpression(Type toType)
typeof(TaskExtensions).GetStaticMethod(
nameof(TaskExtensions.AsTask),
new[] { typeof(JSPromise), typeof(JSValue.To<>) }, resultType),
Expression.Convert(
valueParameter,
typeof(JSPromise),
typeof(JSPromise).GetExplicitConversion(typeof(JSValue), typeof(JSPromise))),
GetFromJSValueExpression(resultType));
valueAsPromise,
GetFromJSValueExpression(resultType));
}
else if (toType == typeof(ValueTask))
{
/*
* ((JSPromise)value).AsValueTask()
*/
asTaskExpression = Expression.Call(
typeof(TaskExtensions).GetStaticMethod(
nameof(TaskExtensions.AsValueTask), new[] { typeof(JSPromise) }),
valueAsPromise);
}
else if (toType.IsGenericType && toType.GetGenericTypeDefinition() == typeof(ValueTask<>))
{
/*
* ((JSPromise)value).AsValueTask<T>((value) => (T)value)
*/
Type resultType = toType.GenericTypeArguments[0];
asTaskExpression = Expression.Call(
typeof(TaskExtensions).GetStaticMethod(
nameof(TaskExtensions.AsValueTask),
new[] { typeof(JSPromise), typeof(JSValue.To<>) }, resultType),
valueAsPromise,
GetFromJSValueExpression(resultType));
}
else
{
Expand All @@ -2571,6 +2622,14 @@ private LambdaExpression BuildConvertToJSPromiseExpression(Type fromType)
Type delegateType = typeof(JSValue.From<>).MakeGenericType(fromType);
string delegateName = "from_" + FullTypeName(fromType);

static Expression ConvertPromiseToJSValueExpression(Expression promiseExpression)
{
return Expression.Convert(
promiseExpression,
typeof(JSValue),
typeof(JSPromise).GetImplicitConversion(typeof(JSPromise), typeof(JSValue)));
}

ParameterExpression valueParameter = Expression.Parameter(fromType, "value");
Expression asPromiseExpression;

Expand All @@ -2579,29 +2638,51 @@ private LambdaExpression BuildConvertToJSPromiseExpression(Type fromType)
/*
* (JSValue)value.AsPromise()
*/
asPromiseExpression = Expression.Convert(
asPromiseExpression = ConvertPromiseToJSValueExpression(
Expression.Call(
typeof(TaskExtensions).GetStaticMethod(
nameof(TaskExtensions.AsPromise), new[] { typeof(Task) }),
valueParameter),
typeof(JSValue),
typeof(JSPromise).GetImplicitConversion(typeof(JSPromise), typeof(JSValue)));
nameof(TaskExtensions.AsPromise), new[] { typeof(Task) }),
valueParameter));
}
else if (fromType.IsGenericType && fromType.GetGenericTypeDefinition() == typeof(Task<>))
{
/*
* (JSValue)value.AsPromise<T>((value) => (JSValue)value)
*/
Type resultType = fromType.GenericTypeArguments[0];
asPromiseExpression = Expression.Convert(
asPromiseExpression = ConvertPromiseToJSValueExpression(
Expression.Call(
typeof(TaskExtensions).GetStaticMethod(
nameof(TaskExtensions.AsPromise),
new[] { typeof(Task), typeof(JSValue.From<>) }),
new[] { typeof(Task<>), typeof(JSValue.From<>) }),
valueParameter,
GetToJSValueExpression(resultType)),
typeof(JSValue),
typeof(JSPromise).GetImplicitConversion(typeof(JSPromise), typeof(JSValue)));
GetToJSValueExpression(resultType)));
}
if (fromType == typeof(ValueTask))
{
/*
* (JSValue)value.AsPromise()
*/
asPromiseExpression = ConvertPromiseToJSValueExpression(
Expression.Call(
typeof(TaskExtensions).GetStaticMethod(
nameof(TaskExtensions.AsPromise), new[] { typeof(ValueTask) }),
valueParameter));
}
else if (fromType.IsGenericType &&
fromType.GetGenericTypeDefinition() == typeof(ValueTask<>))
{
/*
* (JSValue)value.AsPromise<T>((value) => (JSValue)value)
*/
Type resultType = fromType.GenericTypeArguments[0];
asPromiseExpression = ConvertPromiseToJSValueExpression(
Expression.Call(
typeof(TaskExtensions).GetStaticMethod(
nameof(TaskExtensions.AsPromise),
new[] { typeof(ValueTask<>), typeof(JSValue.From<>) }),
valueParameter,
GetToJSValueExpression(resultType)));
}
else
{
Expand Down
2 changes: 1 addition & 1 deletion src/NodeApi.Generator/SymbolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ private static Type AsType(

// Generating the containing type will also generate the nested type,
// so it should be found in the SymbolicTypes dictionary afterward.
typeSymbol.ContainingType?.AsType(genericTypeParameters: null, buildType);
typeSymbol.ContainingType?.AsType(genericTypeParameters, buildType);

if (SymbolicTypes.TryGetValue(typeFullName, out Type? symbolicType))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,35 @@ public static async Task<T> AsTask<T>(
return fromJS(await jsTask);
}

public static async ValueTask<JSValue> AsValueTask(this JSPromise promise)
{
return await promise.AsTask();
}

public static async ValueTask<JSValue> AsValueTask(
this JSPromise promise,
CancellationToken cancellation)
{
return await promise.AsTask(cancellation);
}

public static async ValueTask<T> AsValueTask<T>(
this JSPromise promise,
JSValue.To<T> fromJS)
{
ValueTask<JSValue> jsTask = promise.AsValueTask();
return fromJS(await jsTask);
}

public static async ValueTask<T> AsValueTask<T>(
this JSPromise promise,
JSValue.To<T> fromJS,
CancellationToken cancellation)
{
ValueTask<JSValue> jsTask = promise.AsValueTask(cancellation);
return fromJS(await jsTask);
}

public static JSPromise AsPromise(this Task task)
{
if (task.Status == TaskStatus.RanToCompletion)
Expand Down
14 changes: 14 additions & 0 deletions test/TestCases/napi-dotnet/AsyncMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,20 @@ public async Task<string> TestAsync(string greeter)
return $"Hey {greeter}!";
}
}


[JSExport("async_method_valuetask")]
public static async ValueTask ValueTaskTest()
{
await Task.Yield();
}

[JSExport("async_method_valuetask_of_string")]
public static async ValueTask<string> ValueTaskTest(string greeter)
{
await Task.Delay(50);
return $"Hey {greeter}!";
}
}

[JSExport]
Expand Down
5 changes: 5 additions & 0 deletions test/TestCases/napi-dotnet/async_methods.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,10 @@ common.runTest(async () => {
// A JS object that implements an interface can be returned from C#.
binding.async_interface = asyncInterfaceImpl;
assert.strictEqual(binding.async_interface, asyncInterfaceImpl);

// Invoke C# methods that return ValueTask.
await binding.async_method_valuetask();
const result5 = await binding.async_method_valuetask_of_string('buddy');
assert.strictEqual(result5, 'Hey buddy!');
});

0 comments on commit 3c36102

Please sign in to comment.