Skip to content

Commit

Permalink
#12 Emit leave instruction if rewriting an async method
Browse files Browse the repository at this point in the history
  • Loading branch information
Miista committed May 2, 2024
1 parent 44e57ef commit cdf8430
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 14 deletions.
16 changes: 16 additions & 0 deletions src/Pose/Extensions/TypeExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;

namespace Pose.Extensions
{
Expand Down Expand Up @@ -44,5 +45,20 @@ private static Type GetInterfaceType<TInterface>(this Type type)

return type.GetInterfaces().FirstOrDefault(interfaceType => interfaceType == typeof(TInterface));
}

public static bool IsAsync(this Type thisType)
{
if (thisType == null) throw new ArgumentNullException(nameof(thisType));

return
// State machines are generated by the compiler...
thisType.HasAttribute<CompilerGeneratedAttribute>()

// as nested private classes...
&& thisType.IsNestedPrivate

// which implements IAsyncStateMachine.
&& thisType.ImplementsInterface<IAsyncStateMachine>();
}
}
}
15 changes: 1 addition & 14 deletions src/Pose/Helpers/StubHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,27 +59,14 @@ public static MethodInfo DeVirtualizeMethod(Type thisType, MethodInfo virtualMet
var bindingFlags = BindingFlags.Instance | (virtualMethod.IsPublic ? BindingFlags.Public : BindingFlags.NonPublic);
var types = virtualMethod.GetParameters().Select(p => p.ParameterType).ToArray();

if (IsAsync(thisType))
if (thisType.IsAsync())
{
return thisType.GetExplicitlyImplementedMethod<IAsyncStateMachine>(nameof(IAsyncStateMachine.MoveNext));
}

return thisType.GetMethod(virtualMethod.Name, bindingFlags, null, types, null);
}

private static bool IsAsync(Type thisType)
{
return
// State machines are generated by the compiler...
thisType.HasAttribute<CompilerGeneratedAttribute>()

// as nested private classes...
&& thisType.IsNestedPrivate

// which implements IAsyncStateMachine.
&& thisType.ImplementsInterface<IAsyncStateMachine>();
}

public static Module GetOwningModule() => typeof(StubHelper).Module;

public static bool IsIntrinsic(MethodBase method)
Expand Down
11 changes: 11 additions & 0 deletions src/Pose/IL/MethodRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ internal class MethodRewriter
private readonly MethodBase _method;
private readonly Type _owningType;
private readonly bool _isInterfaceDispatch;
private readonly bool _isAsync;

private int _exceptionBlockLevel;
private TypeInfo _constrainedType;
Expand All @@ -33,6 +34,8 @@ private MethodRewriter(MethodBase method, Type owningType, bool isInterfaceDispa
_method = method ?? throw new ArgumentNullException(nameof(method));
_owningType = owningType ?? throw new ArgumentNullException(nameof(owningType));
_isInterfaceDispatch = isInterfaceDispatch;

_isAsync = method.Name == nameof(IAsyncStateMachine.MoveNext) && (method.DeclaringType?.IsAsync() ?? false);
}

public static MethodRewriter CreateRewriter(MethodBase method, bool isInterfaceDispatch)
Expand Down Expand Up @@ -308,6 +311,14 @@ private void EmitILForInlineBrTarget(ILGenerator ilGenerator, Instruction instru
else if (opCode == OpCodes.Blt_Un_S) opCode = OpCodes.Blt_Un;
else if (opCode == OpCodes.Leave_S) opCode = OpCodes.Leave;

// 'Leave' instructions must be emitted if we are rewriting an async method.
// Otherwise the rewritten method will always start from the beginning every time.
if (opCode == OpCodes.Leave && _isAsync)
{
ilGenerator.Emit(opCode, targetLabel);
return;
}

// Check if 'Leave' opcode is being used in an exception block,
// only emit it if that's not the case
if (opCode == OpCodes.Leave && _exceptionBlockLevel > 0) return;
Expand Down

0 comments on commit cdf8430

Please sign in to comment.