diff --git a/src/libraries/System.Private.CoreLib/src/Microsoft/Win32/SafeHandles/SafeFileHandle.OverlappedValueTaskSource.Windows.cs b/src/libraries/System.Private.CoreLib/src/Microsoft/Win32/SafeHandles/SafeFileHandle.OverlappedValueTaskSource.Windows.cs index f8a075cc61c76f..7e2bd4e72af183 100644 --- a/src/libraries/System.Private.CoreLib/src/Microsoft/Win32/SafeHandles/SafeFileHandle.OverlappedValueTaskSource.Windows.cs +++ b/src/libraries/System.Private.CoreLib/src/Microsoft/Win32/SafeHandles/SafeFileHandle.OverlappedValueTaskSource.Windows.cs @@ -14,17 +14,45 @@ namespace Microsoft.Win32.SafeHandles public sealed partial class SafeFileHandle : SafeHandleZeroOrMinusOneIsInvalid { private OverlappedValueTaskSource? _reusableOverlappedValueTaskSource; // reusable OverlappedValueTaskSource that is currently NOT being used + private ManualResetEvent? _reusableSyncWaitEvent; // reusable event for sync-over-async I/O // Rent the reusable OverlappedValueTaskSource, or create a new one to use if we couldn't get one (which // should only happen on first use or if the SafeFileHandle is being used concurrently). internal OverlappedValueTaskSource GetOverlappedValueTaskSource() => Interlocked.Exchange(ref _reusableOverlappedValueTaskSource, null) ?? new OverlappedValueTaskSource(this); + // Rent the reusable ManualResetEvent for sync-over-async I/O, or create a new one. + // The returned event is guaranteed to be in non-signaled state. + internal ManualResetEvent RentSyncWaitEvent() + { + ManualResetEvent? mre = Interlocked.Exchange(ref _reusableSyncWaitEvent, null); + if (mre is not null) + { + mre.Reset(); + return mre; + } + + return new ManualResetEvent(false); + } + + internal void ReturnSyncWaitEvent(ManualResetEvent waitEvent) + { + if (Interlocked.CompareExchange(ref _reusableSyncWaitEvent, waitEvent, null) is not null) + { + waitEvent.Dispose(); + } + else if (IsClosed) + { + Interlocked.Exchange(ref _reusableSyncWaitEvent, null)?.Dispose(); + } + } + protected override bool ReleaseHandle() { bool result = Interop.Kernel32.CloseHandle(handle); Interlocked.Exchange(ref _reusableOverlappedValueTaskSource, null)?.Dispose(); + Interlocked.Exchange(ref _reusableSyncWaitEvent, null)?.Dispose(); return result; } diff --git a/src/libraries/System.Private.CoreLib/src/System/IO/RandomAccess.Windows.cs b/src/libraries/System.Private.CoreLib/src/System/IO/RandomAccess.Windows.cs index ad3614376ab396..7fb1ebea54d71e 100644 --- a/src/libraries/System.Private.CoreLib/src/System/IO/RandomAccess.Windows.cs +++ b/src/libraries/System.Private.CoreLib/src/System/IO/RandomAccess.Windows.cs @@ -16,8 +16,6 @@ namespace System.IO { public static partial class RandomAccess { - private static readonly IOCompletionCallback s_callback = AllocateCallback(); - internal static unsafe void SetFileLength(SafeFileHandle handle, long length) { var eofInfo = new Interop.Kernel32.FILE_END_OF_FILE_INFO @@ -67,68 +65,74 @@ _ when IsEndOfFile(errorCode, handle, fileOffset) => 0, } } - private static unsafe int ReadSyncUsingAsyncHandle(SafeFileHandle handle, Span buffer, long fileOffset) + private static unsafe int ReadSyncUsingAsyncHandle(SafeFileHandle fileHandle, Span buffer, long fileOffset) { - handle.EnsureThreadPoolBindingInitialized(); - - CallbackResetEvent resetEvent = new CallbackResetEvent(handle.ThreadPoolBinding!); - NativeOverlapped* overlapped = null; + ManualResetEvent waitEvent = fileHandle.RentSyncWaitEvent(); + SafeWaitHandle waitHandle = waitEvent.SafeWaitHandle; + bool releaseWaitHandle = false, releaseFileHandle = false; try { - overlapped = GetNativeOverlappedForAsyncHandle(handle, fileOffset, resetEvent); + fileHandle.DangerousAddRef(ref releaseFileHandle); // keep it alive for the whole overlapped operation + waitHandle.DangerousAddRef(ref releaseWaitHandle); + NativeOverlapped overlapped = GetNativeOverlappedForAsyncHandle(fileHandle, fileOffset, waitHandle.DangerousGetHandle()); fixed (byte* pinned = &MemoryMarshal.GetReference(buffer)) { - Interop.Kernel32.ReadFile(handle, pinned, buffer.Length, IntPtr.Zero, overlapped); - - int errorCode = FileStreamHelpers.GetLastWin32ErrorAndDisposeHandleIfInvalid(handle); - if (errorCode == Interop.Errors.ERROR_IO_PENDING) + int errorCode = Interop.Kernel32.ReadFile(fileHandle, pinned, buffer.Length, IntPtr.Zero, &overlapped) == 0 + ? FileStreamHelpers.GetLastWin32ErrorAndDisposeHandleIfInvalid(fileHandle) + : Interop.Errors.ERROR_SUCCESS; + if (errorCode is Interop.Errors.ERROR_IO_PENDING) { - resetEvent.WaitOne(); - errorCode = Interop.Errors.ERROR_SUCCESS; + try + { + waitEvent.WaitOne(); + errorCode = Interop.Errors.ERROR_SUCCESS; + } + catch + { + // WaitOne can throw arbitrary exceptions (e.g., via SynchronizationContext). + // Cancel the pending IO and wait for completion before freeing the overlapped. + Interop.Kernel32.CancelIoEx(fileHandle, &overlapped); + int canceledBytes = 0; + Interop.Kernel32.GetOverlappedResult(fileHandle, &overlapped, ref canceledBytes, bWait: true); + throw; + } } - if (errorCode == Interop.Errors.ERROR_SUCCESS) + if (errorCode is Interop.Errors.ERROR_SUCCESS) { int result = 0; - if (Interop.Kernel32.GetOverlappedResult(handle, overlapped, ref result, bWait: false)) + if (Interop.Kernel32.GetOverlappedResult(fileHandle, &overlapped, ref result, bWait: false)) { Debug.Assert(result >= 0 && result <= buffer.Length, $"GetOverlappedResult returned {result} for {buffer.Length} bytes request"); return result; } - errorCode = FileStreamHelpers.GetLastWin32ErrorAndDisposeHandleIfInvalid(handle); - } - else - { - // The initial errorCode was neither ERROR_IO_PENDING nor ERROR_SUCCESS, so the operation - // failed with an error and the callback won't be invoked. We thus need to decrement the - // ref count on the resetEvent that was initialized to a value under the expectation that - // the callback would be invoked and decrement it. - resetEvent.ReleaseRefCount(overlapped); + errorCode = FileStreamHelpers.GetLastWin32ErrorAndDisposeHandleIfInvalid(fileHandle); } - if (IsEndOfFile(errorCode, handle, fileOffset)) + if (IsEndOfFile(errorCode, fileHandle, fileOffset)) { - // EOF on a pipe. Callback will not be called. - // We clear the overlapped status bit for this special case (failure - // to do so looks like we are freeing a pending overlapped later). - overlapped->InternalLow = IntPtr.Zero; return 0; } - throw Win32Marshal.GetExceptionForWin32Error(errorCode, handle.Path); + throw Win32Marshal.GetExceptionForWin32Error(errorCode, fileHandle.Path); } } finally { - if (overlapped != null) + if (releaseWaitHandle) { - resetEvent.ReleaseRefCount(overlapped); + waitHandle.DangerousRelease(); } - resetEvent.Dispose(); + if (releaseFileHandle) + { + fileHandle.DangerousRelease(); + } + + fileHandle.ReturnSyncWaitEvent(waitEvent); } } @@ -159,51 +163,56 @@ internal static unsafe void WriteAtOffset(SafeFileHandle handle, ReadOnlySpan buffer, long fileOffset) + private static unsafe void WriteSyncUsingAsyncHandle(SafeFileHandle fileHandle, ReadOnlySpan buffer, long fileOffset) { if (buffer.IsEmpty) { return; } - handle.EnsureThreadPoolBindingInitialized(); - - CallbackResetEvent resetEvent = new CallbackResetEvent(handle.ThreadPoolBinding!); - NativeOverlapped* overlapped = null; + ManualResetEvent waitEvent = fileHandle.RentSyncWaitEvent(); + SafeWaitHandle waitHandle = waitEvent.SafeWaitHandle; + bool releaseWaitHandle = false, releaseFileHandle = false; try { - overlapped = GetNativeOverlappedForAsyncHandle(handle, fileOffset, resetEvent); + fileHandle.DangerousAddRef(ref releaseFileHandle); // keep it alive for the whole overlapped operation + waitHandle.DangerousAddRef(ref releaseWaitHandle); + NativeOverlapped overlapped = GetNativeOverlappedForAsyncHandle(fileHandle, fileOffset, waitHandle.DangerousGetHandle()); fixed (byte* pinned = &MemoryMarshal.GetReference(buffer)) { - Interop.Kernel32.WriteFile(handle, pinned, buffer.Length, IntPtr.Zero, overlapped); - - int errorCode = FileStreamHelpers.GetLastWin32ErrorAndDisposeHandleIfInvalid(handle); - if (errorCode == Interop.Errors.ERROR_IO_PENDING) + int errorCode = Interop.Kernel32.WriteFile(fileHandle, pinned, buffer.Length, IntPtr.Zero, &overlapped) == 0 + ? FileStreamHelpers.GetLastWin32ErrorAndDisposeHandleIfInvalid(fileHandle) + : Interop.Errors.ERROR_SUCCESS; + if (errorCode is Interop.Errors.ERROR_IO_PENDING) { - resetEvent.WaitOne(); - errorCode = Interop.Errors.ERROR_SUCCESS; + try + { + waitEvent.WaitOne(); + errorCode = Interop.Errors.ERROR_SUCCESS; + } + catch + { + // WaitOne can throw arbitrary exceptions (e.g., via SynchronizationContext). + // Cancel the pending IO and wait for completion before freeing the overlapped. + Interop.Kernel32.CancelIoEx(fileHandle, &overlapped); + int canceledBytes = 0; + Interop.Kernel32.GetOverlappedResult(fileHandle, &overlapped, ref canceledBytes, bWait: true); + throw; + } } - if (errorCode == Interop.Errors.ERROR_SUCCESS) + if (errorCode is Interop.Errors.ERROR_SUCCESS) { int result = 0; - if (Interop.Kernel32.GetOverlappedResult(handle, overlapped, ref result, bWait: false)) + if (Interop.Kernel32.GetOverlappedResult(fileHandle, &overlapped, ref result, bWait: false)) { Debug.Assert(result == buffer.Length, $"GetOverlappedResult returned {result} for {buffer.Length} bytes request"); return; } - errorCode = FileStreamHelpers.GetLastWin32ErrorAndDisposeHandleIfInvalid(handle); - } - else - { - // The initial errorCode was neither ERROR_IO_PENDING nor ERROR_SUCCESS, so the operation - // failed with an error and the callback won't be invoked. We thus need to decrement the - // ref count on the resetEvent that was initialized to a value under the expectation that - // the callback would be invoked and decrement it. - resetEvent.ReleaseRefCount(overlapped); + errorCode = FileStreamHelpers.GetLastWin32ErrorAndDisposeHandleIfInvalid(fileHandle); } throw errorCode switch @@ -213,18 +222,23 @@ private static unsafe void WriteSyncUsingAsyncHandle(SafeFileHandle handle, Read // to a handle opened asynchronously. Interop.Errors.ERROR_INVALID_PARAMETER => new IOException(SR.IO_FileTooLong), - _ => Win32Marshal.GetExceptionForWin32Error(errorCode, handle.Path), + _ => Win32Marshal.GetExceptionForWin32Error(errorCode, fileHandle.Path), }; } } finally { - if (overlapped != null) + if (releaseWaitHandle) + { + waitHandle.DangerousRelease(); + } + + if (releaseFileHandle) { - resetEvent.ReleaseRefCount(overlapped); + fileHandle.DangerousRelease(); } - resetEvent.Dispose(); + fileHandle.ReturnSyncWaitEvent(waitEvent); } } @@ -724,16 +738,15 @@ private static async ValueTask WriteGatherAtOffsetMultipleSyscallsAsync(SafeFile } } - private static unsafe NativeOverlapped* GetNativeOverlappedForAsyncHandle(SafeFileHandle handle, long fileOffset, CallbackResetEvent resetEvent) + private static NativeOverlapped GetNativeOverlappedForAsyncHandle(SafeFileHandle handle, long fileOffset, nint waitHandle) { - // After SafeFileHandle is bound to ThreadPool, we need to use ThreadPoolBinding - // to allocate a native overlapped and provide a valid callback. - NativeOverlapped* result = handle.ThreadPoolBinding!.UnsafeAllocateNativeOverlapped(s_callback, resetEvent, null); + Debug.Assert(handle.IsAsync); + NativeOverlapped result = default; if (handle.CanSeek) { - result->OffsetLow = unchecked((int)fileOffset); - result->OffsetHigh = (int)(fileOffset >> 32); + result.OffsetLow = unchecked((int)fileOffset); + result.OffsetHigh = (int)(fileOffset >> 32); } // From https://learn.microsoft.com/windows/win32/api/ioapiset/nf-ioapiset-getoverlappedresult: @@ -743,7 +756,11 @@ private static async ValueTask WriteGatherAtOffsetMultipleSyscallsAsync(SafeFile // are performed on the same file, named pipe, or communications device. // In this situation, there is no way to know which operation caused the object's state to be signaled." // Since we want RandomAccess APIs to be thread-safe, we provide a dedicated wait handle. - result->EventHandle = resetEvent.SafeWaitHandle.DangerousGetHandle(); + // From https://learn.microsoft.com/windows/win32/api/ioapiset/nf-ioapiset-getqueuedcompletionstatus: + // "If the file handle associated with the completion packet was previously associated with an I/O completion port + // [...] setting the low-order bit of hEvent in the OVERLAPPED structure prevents the I/O completion + // from being queued to a completion port." + result.EventHandle = waitHandle | 1; return result; } @@ -761,17 +778,6 @@ private static NativeOverlapped GetNativeOverlappedForSyncHandle(SafeFileHandle return result; } - private static unsafe IOCompletionCallback AllocateCallback() - { - return new IOCompletionCallback(Callback); - - static void Callback(uint errorCode, uint numBytes, NativeOverlapped* pOverlapped) - { - CallbackResetEvent state = (CallbackResetEvent)ThreadPoolBoundHandle.GetNativeOverlappedState(pOverlapped)!; - state.ReleaseRefCount(pOverlapped); - } - } - internal static bool IsEndOfFile(int errorCode, SafeFileHandle handle, long fileOffset) { switch (errorCode) @@ -798,32 +804,6 @@ internal static bool IsEndOfFile(int errorCode, SafeFileHandle handle, long file private static bool IsEndOfFileForNoBuffering(SafeFileHandle fileHandle, long fileOffset) => fileHandle.IsNoBuffering && fileHandle.CanSeek && fileOffset >= fileHandle.GetFileLength(); - // We need to store the reference count (see the comment in ReleaseRefCount) and an EventHandle to signal the completion. - // We could keep these two things separate, but since ManualResetEvent is sealed and we want to avoid any extra allocations, this type has been created. - // It's basically ManualResetEvent with reference count. - private sealed class CallbackResetEvent : EventWaitHandle - { - private readonly ThreadPoolBoundHandle _threadPoolBoundHandle; - private int _freeWhenZero = 2; // one for the callback and another for the method that calls GetOverlappedResult - - internal CallbackResetEvent(ThreadPoolBoundHandle threadPoolBoundHandle) : base(initialState: false, EventResetMode.ManualReset) - { - _threadPoolBoundHandle = threadPoolBoundHandle; - } - - internal unsafe void ReleaseRefCount(NativeOverlapped* pOverlapped) - { - // Each SafeFileHandle opened for async IO is bound to ThreadPool. - // It requires us to provide a callback even if we want to use EventHandle and use GetOverlappedResult to obtain the result. - // There can be a race condition between the call to GetOverlappedResult and the callback invocation, - // so we need to track the number of references, and when it drops to zero, then free the native overlapped. - if (Interlocked.Decrement(ref _freeWhenZero) == 0) - { - _threadPoolBoundHandle.FreeNativeOverlapped(pOverlapped); - } - } - } - // Abstracts away the type signature incompatibility between Memory and ReadOnlyMemory. private interface IMemoryHandler { diff --git a/src/libraries/System.Runtime/tests/System.IO.FileSystem.Tests/RandomAccess/Mixed.Windows.cs b/src/libraries/System.Runtime/tests/System.IO.FileSystem.Tests/RandomAccess/Mixed.Windows.cs index 4c6a8ddf4b22d4..c685a2a0b8c37b 100644 --- a/src/libraries/System.Runtime/tests/System.IO.FileSystem.Tests/RandomAccess/Mixed.Windows.cs +++ b/src/libraries/System.Runtime/tests/System.IO.FileSystem.Tests/RandomAccess/Mixed.Windows.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Runtime.InteropServices; +using System.Threading; using System.Threading.Tasks; using Microsoft.Win32.SafeHandles; using Xunit; @@ -143,5 +144,80 @@ static async Task Validate(SafeFileHandle handle, FileOptions options, bool[] sy } } } + + [Fact] + public void SyncIOOnAsyncHandle_DoesNotCorruptMemory_WhenSynchronizationContextThrows() + { + // This test verifies that when WaitOne() throws via SynchronizationContext, + // the pending IO is properly canceled before freeing the NativeOverlapped, + // preventing use-after-free / heap corruption. + byte[] expectedData = new byte[1024]; + Random.Shared.NextBytes(expectedData); + + SynchronizationContext previous = SynchronizationContext.Current; + try + { + ThrowingSynchronizationContext throwingContext = new(); + SynchronizationContext.SetSynchronizationContext(throwingContext); + + SafeFileHandle.CreateAnonymousPipe( + out SafeFileHandle readHandle, + out SafeFileHandle writeHandle, + asyncRead: true, + asyncWrite: false); + + using (readHandle) + using (writeHandle) + { + byte[] pendingReadBuffer = new byte[1]; + + // The ThrowingSynchronizationContext.Wait throws, which should be caught + // and the IO should be canceled gracefully. + Assert.Throws(() => RandomAccess.Read(readHandle, pendingReadBuffer, 0)); + + // Restore the previous context and verify the read handle is still usable. + SynchronizationContext.SetSynchronizationContext(previous); + + RandomAccess.Write(writeHandle, expectedData, 0); + + byte[] readBuffer = new byte[expectedData.Length]; + int totalRead = 0; + while (totalRead < readBuffer.Length) + { + int bytesRead = RandomAccess.Read(readHandle, readBuffer.AsSpan(totalRead), totalRead); + if (bytesRead == 0) + { + break; + } + + totalRead += bytesRead; + } + + Assert.Equal(expectedData.Length, totalRead); + Assert.Equal(expectedData, readBuffer); + } + } + finally + { + SynchronizationContext.SetSynchronizationContext(previous); + } + } + + /// + /// A SynchronizationContext that throws from Wait to simulate the scenario + /// where WaitOne() can throw arbitrary exceptions via user code. + /// + private sealed class ThrowingSynchronizationContext : SynchronizationContext + { + public ThrowingSynchronizationContext() + { + SetWaitNotificationRequired(); + } + + public override int Wait(IntPtr[] waitHandles, bool waitAll, int millisecondsTimeout) + { + throw new InvalidOperationException("SynchronizationContext.Wait threw an exception"); + } + } } }