Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing ValueTask support with InProcessNoEmit toolchain #2110

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Added AwaitHelper to properly wait for ValueTasks.
  • Loading branch information
timcassell committed Feb 17, 2023
commit fa1499a26e180e6070958230e75895251ecdd68e
14 changes: 5 additions & 9 deletions src/BenchmarkDotNet/Code/DeclarationsProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ private string GetMethodName(MethodInfo method)
(method.ReturnType.GetGenericTypeDefinition() == typeof(Task<>) ||
method.ReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>))))
{
return $"() => {method.Name}().GetAwaiter().GetResult()";
return $"() => awaitHelper.GetResult({method.Name}())";
}

return method.Name;
Expand Down Expand Up @@ -149,12 +149,10 @@ internal class TaskDeclarationsProvider : VoidDeclarationsProvider
{
public TaskDeclarationsProvider(Descriptor descriptor) : base(descriptor) { }

// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
// and will eventually throw actual exception, not aggregated one
public override string WorkloadMethodDelegate(string passArguments)
=> $"({passArguments}) => {{ {Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult(); }}";
=> $"({passArguments}) => {{ awaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments})); }}";

public override string GetWorkloadMethodCall(string passArguments) => $"{Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult()";
public override string GetWorkloadMethodCall(string passArguments) => $"awaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments}))";

protected override Type WorkloadMethodReturnType => typeof(void);
}
Expand All @@ -168,11 +166,9 @@ internal class GenericTaskDeclarationsProvider : NonVoidDeclarationsProvider

protected override Type WorkloadMethodReturnType => Descriptor.WorkloadMethod.ReturnType.GetTypeInfo().GetGenericArguments().Single();

// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
// and will eventually throw actual exception, not aggregated one
public override string WorkloadMethodDelegate(string passArguments)
=> $"({passArguments}) => {{ return {Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult(); }}";
=> $"({passArguments}) => {{ return awaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments})); }}";

public override string GetWorkloadMethodCall(string passArguments) => $"{Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult()";
public override string GetWorkloadMethodCall(string passArguments) => $"awaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments}))";
}
}
112 changes: 112 additions & 0 deletions src/BenchmarkDotNet/Helpers/AwaitHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
using System;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

namespace BenchmarkDotNet.Helpers
{
public class AwaitHelper
{
private readonly object awaiterLock = new object();
private readonly Action awaiterCallback;
private bool awaiterCompleted;

public AwaitHelper()
{
awaiterCallback = AwaiterCallback;
}

private void AwaiterCallback()
{
lock (awaiterLock)
{
awaiterCompleted = true;
System.Threading.Monitor.Pulse(awaiterLock);
}
}

// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
// and will eventually throw actual exception, not aggregated one
public void GetResult(Task task)
{
task.GetAwaiter().GetResult();
}

public T GetResult<T>(Task<T> task)
{
return task.GetAwaiter().GetResult();
}

// It is illegal to call GetResult from an uncomplete ValueTask, so we must hook up a callback.
public void GetResult(ValueTask task)
{
// Don't continue on the captured context, as that may result in a deadlock if the user runs this in-process.
var awaiter = task.ConfigureAwait(false).GetAwaiter();
if (!awaiter.IsCompleted)
{
lock (awaiterLock)
{
awaiterCompleted = false;
awaiter.UnsafeOnCompleted(awaiterCallback);
// Check if the callback executed synchronously before blocking.
if (!awaiterCompleted)
{
System.Threading.Monitor.Wait(awaiterLock);
}
}
}
awaiter.GetResult();
}

public T GetResult<T>(ValueTask<T> task)
{
// Don't continue on the captured context, as that may result in a deadlock if the user runs this in-process.
var awaiter = task.ConfigureAwait(false).GetAwaiter();
if (!awaiter.IsCompleted)
{
lock (awaiterLock)
{
awaiterCompleted = false;
awaiter.UnsafeOnCompleted(awaiterCallback);
// Check if the callback executed synchronously before blocking.
if (!awaiterCompleted)
{
System.Threading.Monitor.Wait(awaiterLock);
}
}
}
return awaiter.GetResult();
}

internal static MethodInfo GetGetResultMethod(Type taskType)
{
if (!taskType.IsGenericType)
{
return typeof(AwaitHelper).GetMethod(nameof(AwaitHelper.GetResult), BindingFlags.Public | BindingFlags.Instance, null, new Type[1] { taskType }, null);
}

Type compareType = taskType.GetGenericTypeDefinition() == typeof(ValueTask<>) ? typeof(ValueTask<>)
: typeof(Task).IsAssignableFrom(taskType.GetGenericTypeDefinition()) ? typeof(Task<>)
: null;
if (compareType == null)
{
return null;
}
var resultType = taskType
.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)
.ReturnType
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance)
.ReturnType;
return typeof(AwaitHelper).GetMethods(BindingFlags.Public | BindingFlags.Instance)
.First(m =>
{
if (m.Name != nameof(AwaitHelper.GetResult)) return false;
Type paramType = m.GetParameters().First().ParameterType;
// We have to compare the types indirectly, == check doesn't work.
return paramType.Assembly == compareType.Assembly && paramType.Namespace == compareType.Namespace && paramType.Name == compareType.Name;
})
.MakeGenericMethod(new[] { resultType });
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,37 @@ public static void EmitVoidReturn(this ILGenerator ilBuilder, MethodBuilder meth
ilBuilder.Emit(OpCodes.Ret);
}

public static void EmitSetFieldToNewInstance(
this ILGenerator ilBuilder,
FieldBuilder field,
Type instanceType)
{
if (field.IsStatic)
throw new ArgumentException("The field should be instance field", nameof(field));

if (instanceType != null)
{
/*
IL_0006: ldarg.0
IL_0007: newobj instance void BenchmarkDotNet.Helpers.AwaitHelper::.ctor()
IL_000c: stfld class BenchmarkDotNet.Helpers.AwaitHelper BenchmarkDotNet.Autogenerated.Runnable_0::awaitHelper
*/
var ctor = instanceType.GetConstructor(Array.Empty<Type>());
if (ctor == null)
throw new InvalidOperationException($"Bug: instanceType {instanceType.Name} does not have a 0-parameter accessible constructor.");

ilBuilder.Emit(OpCodes.Ldarg_0);
ilBuilder.Emit(OpCodes.Newobj, ctor);
ilBuilder.Emit(OpCodes.Stfld, field);
}
else
{
ilBuilder.Emit(OpCodes.Ldarg_0);
ilBuilder.Emit(OpCodes.Ldnull);
ilBuilder.Emit(OpCodes.Stfld, field);
}
}

public static void EmitSetDelegateToThisField(
this ILGenerator ilBuilder,
FieldBuilder delegateField,
Expand Down
4 changes: 4 additions & 0 deletions src/BenchmarkDotNet/Templates/BenchmarkType.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@

public Runnable_$ID$()
{
awaitHelper = new BenchmarkDotNet.Helpers.AwaitHelper();

globalSetupAction = $GlobalSetupMethodName$;
globalCleanupAction = $GlobalCleanupMethodName$;
iterationSetupAction = $IterationSetupMethodName$;
Expand All @@ -66,6 +68,8 @@
$InitializeArgumentFields$
}

private readonly BenchmarkDotNet.Helpers.AwaitHelper awaitHelper;

private System.Action globalSetupAction;
private System.Action globalCleanupAction;
private System.Action iterationSetupAction;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using BenchmarkDotNet.Engines;
using JetBrains.Annotations;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
Expand All @@ -17,28 +19,24 @@ public ConsumableTypeInfo(Type methodReturnType)

OriginMethodReturnType = methodReturnType;

// Please note this code does not support await over extension methods.
var getAwaiterMethod = methodReturnType.GetMethod(nameof(Task<int>.GetAwaiter), BindingFlagsPublicInstance);
if (getAwaiterMethod == null)
// Only support (Value)Task for parity with other toolchains (and so we can use AwaitHelper).
IsAwaitable = methodReturnType == typeof(Task) || methodReturnType == typeof(ValueTask)
|| (methodReturnType.GetTypeInfo().IsGenericType
&& (methodReturnType.GetTypeInfo().GetGenericTypeDefinition() == typeof(Task<>)
|| methodReturnType.GetTypeInfo().GetGenericTypeDefinition() == typeof(ValueTask<>)));

if (!IsAwaitable)
{
WorkloadMethodReturnType = methodReturnType;
}
else
{
var getResultMethod = getAwaiterMethod
WorkloadMethodReturnType = methodReturnType
.GetMethod(nameof(Task.GetAwaiter), BindingFlagsPublicInstance)
.ReturnType
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlagsPublicInstance);

if (getResultMethod == null)
{
WorkloadMethodReturnType = methodReturnType;
}
else
{
WorkloadMethodReturnType = getResultMethod.ReturnType;
GetAwaiterMethod = getAwaiterMethod;
GetResultMethod = getResultMethod;
}
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlagsPublicInstance)
.ReturnType;
GetResultMethod = Helpers.AwaitHelper.GetGetResultMethod(methodReturnType);
}

if (WorkloadMethodReturnType == null)
Expand Down Expand Up @@ -75,14 +73,13 @@ public ConsumableTypeInfo(Type methodReturnType)
public Type WorkloadMethodReturnType { get; }
public Type OverheadMethodReturnType { get; }

public MethodInfo? GetAwaiterMethod { get; }
public MethodInfo? GetResultMethod { get; }

public bool IsVoid { get; }
public bool IsByRef { get; }
public bool IsConsumable { get; }
public FieldInfo? WorkloadConsumableField { get; }

public bool IsAwaitable => GetAwaiterMethod != null && GetResultMethod != null;
public bool IsAwaitable { get; }
}
}
Loading