Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@ public ref byte GetResultStorageOrNull()
ref byte data = ref RuntimeHelpers.GetRawData(this);
return ref Unsafe.Add(ref data, (DataOffset - PointerSize) + index * PointerSize);
}

protected void EncodeFieldOffsetInFlags(ref byte field, ContinuationFlags firstBit, ContinuationFlags numBits)
{
int offset = (int)Unsafe.ByteOffset(ref RuntimeHelpers.GetRawData(this), ref field);
offset -= DataOffset;
Debug.Assert(offset >= 0 && offset % PointerSize == 0);
uint index = 1 + (uint)offset / PointerSize;
Debug.Assert(index < (1 << (int)numBits));
Flags |= (ContinuationFlags)((uint)index << (int)firstBit);
}
}

[StructLayout(LayoutKind.Explicit)]
Expand Down Expand Up @@ -202,7 +212,7 @@ private ref struct RuntimeAsyncStackState
// to one of these notifiers.
public ICriticalNotifyCompletion? CriticalNotifier;
public INotifyCompletion? Notifier;
public ValueTaskSourceNotifier? ValueTaskSourceNotifier;
public ValueTaskContinuation? ValueTaskContinuation;
public Task? TaskNotifier;

// When we suspend in the leaf, the contexts are captured into these fields.
Expand Down Expand Up @@ -245,6 +255,7 @@ public void Pop(Thread thread)
private unsafe struct RuntimeAsyncAwaitState
{
public Continuation? SentinelContinuation;
public ValueTaskContinuation? CachedValueTaskContinuation;

// We cache the thread here to avoid unnecessary repeated TLS lookups.
public Thread? CurrentThread;
Expand Down Expand Up @@ -330,6 +341,202 @@ private static unsafe Continuation AllocContinuationClass(Continuation prevConti
}
#endif

private sealed unsafe class ValueTaskContinuation : Continuation
{
// Currently all continuations are expected to capture and restore
// ExecutionContext, even though we do not actually need it here.
public ExecutionContext? ExecutionContext;
private object? _source;
private short _token;
private delegate* managed<object, Action<object?>, object?, short, ValueTaskSourceOnCompletedFlags, void> _onCompleted;
private delegate* managed<object, short, ref byte, void> _getResult;

public ValueTaskContinuation()
{
ResumeInfo = (ResumeInfo*)Unsafe.AsPointer(ref ValueTaskContinuationResume.ResumeInfo);

EncodeFieldOffsetInFlags(
ref Unsafe.As<ExecutionContext?, byte>(ref ExecutionContext),
ContinuationFlags.ExecutionContextIndexFirstBit,
ContinuationFlags.ExecutionContextIndexNumBits);
}

public void OnCompleted(Action<object?> continuation, object? state, ValueTaskSourceOnCompletedFlags flags)
{
Debug.Assert(_source != null);
_onCompleted(_source, continuation, state, _token, flags);
}

public void GetResult(ref byte returnValue)
{
Debug.Assert(_source != null);

// Avoid retaining source. The call below may throw.
object source = _source;
_source = null;

_getResult(source, _token, ref returnValue);
}

public void Initialize(object source, short token)
{
_source = source;
_token = token;
_onCompleted = &OnCompleted;
_getResult = &GetResult;
}

public void Initialize<T>(object source, short token)
{
_source = source;
_token = token;
_onCompleted = &OnCompleted<T>;
_getResult = &GetResult<T>;
}

private static void OnCompleted(object source, Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
{
if (source is Task t)
{
Debug.Assert(state is ITaskCompletionAction);
if (!t.TryAddCompletionAction(Unsafe.As<object, ITaskCompletionAction>(ref state)))
{
ThreadPool.UnsafeQueueUserWorkItemInternal(state, preferLocal: true);
}
}
else
{
Debug.Assert(source is IValueTaskSource);
IValueTaskSource typedSource = Unsafe.As<object, IValueTaskSource>(ref source);
typedSource.OnCompleted(continuation, state, token, flags);
}
}

private static void GetResult(object source, short token, ref byte result)
{
if (source is Task t)
{
TaskAwaiter.ValidateEnd(t);
}
else
{
Debug.Assert(source is IValueTaskSource);
IValueTaskSource typedSource = Unsafe.As<object, IValueTaskSource>(ref source);
typedSource.GetResult(token);
}
}

private static void OnCompleted<T>(object source, Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
{
if (source is Task t)
{
Debug.Assert(state is ITaskCompletionAction);
if (!t.TryAddCompletionAction(Unsafe.As<object, ITaskCompletionAction>(ref state)))
{
ThreadPool.UnsafeQueueUserWorkItemInternal(state, preferLocal: true);
}
}
else
{
Debug.Assert(source is IValueTaskSource<T>);
IValueTaskSource<T> typedSource = Unsafe.As<object, IValueTaskSource<T>>(ref source);
typedSource.OnCompleted(continuation, state, token, flags);
}
}

private static void GetResult<T>(object source, short token, ref byte result)
{
if (source is Task<T> t)
{
TaskAwaiter.ValidateEnd(t);
Unsafe.As<byte, T>(ref result) = t.ResultOnSuccess;
}
else
{
Debug.Assert(source is IValueTaskSource<T>);
IValueTaskSource<T> typedSource = Unsafe.As<object, IValueTaskSource<T>>(ref source);
Unsafe.As<byte, T>(ref result) = typedSource.GetResult(token);
}
}

private static class ValueTaskContinuationResume
{
[FixedAddressValueType]
public static ResumeInfo ResumeInfo = new ResumeInfo
{
DiagnosticIP = null,
Resume = &ResumeValueTaskContinuation,
};

public static Continuation? ResumeValueTaskContinuation(Continuation cont, ref byte result)
{
var vtsCont = (ValueTaskContinuation)cont;
vtsCont.Next = null;
vtsCont.ExecutionContext = null;
t_runtimeAsyncAwaitState.CachedValueTaskContinuation = vtsCont;

vtsCont.GetResult(ref result);
Comment thread
jakobbotsch marked this conversation as resolved.
return null;
}
}
}

[BypassReadyToRun]
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)]
private static unsafe void TransparentAwaitValueTask(ValueTask valueTask)
{
ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation();

ValueTaskContinuation? vtsCont = state.CachedValueTaskContinuation;
if (vtsCont != null)
{
state.CachedValueTaskContinuation = null;
}
else
{
vtsCont = new ValueTaskContinuation();
}

Debug.Assert(valueTask._obj != null);
vtsCont.Initialize(valueTask._obj, valueTask._token);
vtsCont.ExecutionContext = ExecutionContext.CaptureForSuspension(state.CurrentThread!);

sentinelContinuation.Next = vtsCont;
state.StackState->ValueTaskContinuation = vtsCont;

Comment thread
jakobbotsch marked this conversation as resolved.
state.CaptureContexts();
AsyncSuspend(vtsCont);
}

[BypassReadyToRun]
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)]
private static unsafe void TransparentAwaitValueTaskOfT<T>(ValueTask<T?> valueTask)
{
ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation();

ValueTaskContinuation? vtsCont = state.CachedValueTaskContinuation;
if (vtsCont != null)
{
state.CachedValueTaskContinuation = null;
}
else
{
vtsCont = new ValueTaskContinuation();
}

Debug.Assert(valueTask._obj != null);
vtsCont.Initialize<T>(valueTask._obj, valueTask._token);
vtsCont.ExecutionContext = ExecutionContext.CaptureForSuspension(state.CurrentThread!);

sentinelContinuation.Next = vtsCont;
state.StackState->ValueTaskContinuation = vtsCont;

state.CaptureContexts();
AsyncSuspend(vtsCont);
}

/// <summary>
/// Used by internal thunks that implement awaiting on Task or a ValueTask.
/// A ValueTask may wrap:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the comments about ValueTask may be dropped now.

Expand All @@ -339,22 +546,15 @@ private static unsafe Continuation AllocContinuationClass(Continuation prevConti
/// Therefore, when we are awaiting a ValueTask completion we are really
/// awaiting a completion of an underlying Task or ValueTaskSource.
/// </summary>
/// <param name="o"> Task or a ValueTaskNotifier whose completion we are awaiting.</param>
/// <param name="t"> Task whose completion we are awaiting.</param>
[BypassReadyToRun]
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)]
private static unsafe void TransparentAwait(object o)
private static unsafe void TransparentAwait(Task t)
{
ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation();

if (o is Task t)
{
state.StackState->TaskNotifier = t;
}
else
{
state.StackState->ValueTaskSourceNotifier = (ValueTaskSourceNotifier)o;
}
state.StackState->TaskNotifier = t;

state.CaptureContexts();
AsyncSuspend(sentinelContinuation);
Expand Down Expand Up @@ -456,8 +656,9 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
ThreadPool.UnsafeQueueUserWorkItemInternal(this, preferLocal: true);
}
}
else if (stackState->ValueTaskSourceNotifier is { } valueTaskSourceNotifier)
else if (stackState->ValueTaskContinuation is { } valueTaskSourceCont)
{
Debug.Assert(headContinuation == valueTaskSourceCont);
// The awaiter must inform the ValueTaskSource on whether the continuation
// wants to run on a context, although the source may decide to ignore the suggestion.
// Since the behavior of the source takes precedence, we clear the context flags of
Expand Down Expand Up @@ -491,7 +692,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)

// Clear continuation flags, so that continuation runs transparently
nextUserContinuation.Flags &= ~continueFlags;
valueTaskSourceNotifier.OnCompleted(s_runContinuationAction, this, configFlags);
valueTaskSourceCont.OnCompleted(s_runContinuationAction, this, configFlags);
}
else
{
Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/debug/daccess/dacdbiimpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7165,7 +7165,7 @@ HRESULT STDMETHODCALLTYPE DacDbiInterfaceImpl::IsValidObject(CORDB_ADDRESS obj,

if (mt == cls->GetMethodTable())
isValid = TRUE;
else if (!mt->IsCanonicalMethodTable() || mt->IsContinuation())
else if (!mt->IsCanonicalMethodTable() || (mt->IsContinuation() && !mt->IsContinuationWithMetadata()))
isValid = cls->GetMethodTable()->GetClass() == cls;
}
EX_CATCH
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/inc/dacvars.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ DEFINE_DACVAR(UNKNOWN_POINTER_TYPE, dac__g_pWeakReferenceClass, ::g_pWeakReferen
DEFINE_DACVAR(UNKNOWN_POINTER_TYPE, dac__g_pWeakReferenceOfTClass, ::g_pWeakReferenceOfTClass)

DEFINE_DACVAR_VOLATILE(UNKNOWN_POINTER_TYPE, dac__g_pContinuationClassIfSubTypeCreated, ::g_pContinuationClassIfSubTypeCreated)
DEFINE_DACVAR_VOLATILE(UNKNOWN_POINTER_TYPE, dac__g_singletonContinuationEEClass, ::g_singletonContinuationEEClass)

#ifdef FEATURE_COMINTEROP
DEFINE_DACVAR(UNKNOWN_POINTER_TYPE, dac__g_pBaseCOMObject, ::g_pBaseCOMObject)
Expand Down
23 changes: 15 additions & 8 deletions src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -292,21 +292,26 @@ public static MethodIL EmitAsyncMethodThunk(MethodDesc asyncMethod, MethodDesc t
TypeDesc valueTaskType = taskReturningMethodReturnType;
MethodDesc isCompletedMethod;
MethodDesc completionResultMethod;
MethodDesc asTaskOrNotifierMethod;
MethodDesc transparentAwaitValueTaskMethod;

if (!taskReturningMethodReturnType.HasInstantiation)
{
// ValueTask (non-generic)
isCompletedMethod = valueTaskType.GetKnownMethod("get_IsCompleted"u8, null);
completionResultMethod = valueTaskType.GetKnownMethod("ThrowIfCompletedUnsuccessfully"u8, null);
asTaskOrNotifierMethod = valueTaskType.GetKnownMethod("AsTaskOrNotifier"u8, null);
transparentAwaitValueTaskMethod =
context.SystemModule.GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8)
.GetKnownMethod("TransparentAwaitValueTask"u8, null);
}
else
{
// ValueTask<T> (generic)
isCompletedMethod = valueTaskType.GetKnownMethod("get_IsCompleted"u8, null);
completionResultMethod = valueTaskType.GetKnownMethod("get_Result"u8, null);
asTaskOrNotifierMethod = valueTaskType.GetKnownMethod("AsTaskOrNotifier"u8, null);
transparentAwaitValueTaskMethod =
context.SystemModule.GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8)
.GetKnownMethod("TransparentAwaitValueTaskOfT"u8, null)
.MakeInstantiatedMethod(valueTaskType.Instantiation[0]);
}

ILLocalVariable valueTaskLocal = emitter.NewLocal(valueTaskType);
Expand All @@ -315,15 +320,17 @@ public static MethodIL EmitAsyncMethodThunk(MethodDesc asyncMethod, MethodDesc t
// Store value task returned by call to actual user func
codestream.EmitStLoc(valueTaskLocal);
codestream.EmitLdLoca(valueTaskLocal);

// Was it already completed?
codestream.Emit(ILOpcode.call, emitter.NewToken(isCompletedMethod));
codestream.Emit(ILOpcode.brtrue, valueTaskCompletedLabel);

codestream.EmitLdLoca(valueTaskLocal);
codestream.Emit(ILOpcode.call, emitter.NewToken(asTaskOrNotifierMethod));
codestream.Emit(ILOpcode.call, emitter.NewToken(
context.SystemModule.GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8)
.GetKnownMethod("TransparentAwait"u8, null)));
// No, tail await to TransparentAwaitValueTask
codestream.EmitLdLoc(valueTaskLocal);
codestream.Emit(ILOpcode.call, emitter.NewToken(context.GetCoreLibEntryPoint("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8, "TailAwait"u8, null)));
codestream.Emit(ILOpcode.call, emitter.NewToken(transparentAwaitValueTaskMethod));

// Yes, just get the result
codestream.EmitLabel(valueTaskCompletedLabel);
codestream.EmitLdLoca(valueTaskLocal);
codestream.Emit(ILOpcode.call, emitter.NewToken(completionResultMethod));
Expand Down
2 changes: 0 additions & 2 deletions src/coreclr/vm/asynccontinuations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ void AsyncContinuationsManager::NotifyUnloadingClasses()
#endif // PROFILING_SUPPORTED
}

static EEClass* volatile g_singletonContinuationEEClass;

EEClass* AsyncContinuationsManager::GetOrCreateSingletonSubContinuationEEClass()
{
if (g_singletonContinuationEEClass != NULL)
Expand Down
Loading
Loading