Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions src/DotNetty.Handlers/Tls/TlsHandler.Reader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,11 @@ protected override void Decode(IChannelHandlerContext context, IByteBuffer input
// of the SSLException reported here.
WrapAndFlush(context);
}
// TODO revisit
//catch (IOException)
//{
// if (s_logger.DebugEnabled)
// {
// s_logger.Debug("SSLException during trying to call SSLEngine.wrap(...)" +
// " because of an previous SSLException, ignoring...", ex);
// }
//}
catch (Exception)
{
// Swallow any exception from WrapAndFlush so it does not mask the original cause.
// See https://github.com/maksimkim/SpanNetty/issues/60
}
finally
{
HandleFailure(cause);
Expand Down
30 changes: 30 additions & 0 deletions src/DotNetty.Handlers/Tls/TlsHandler.Writer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ private void Wrap(IChannelHandlerContext context)
buf = null;

var promise = _pendingUnencryptedWrites.Remove();
if (promise is null)
{
// Queue was drained externally (e.g., by re-entrant HandleFailure → RemoveAndFailAll).
// See https://github.com/maksimkim/SpanNetty/issues/60
break;
}
Task task = _lastContextWriteTask;
if (task is object)
{
Expand All @@ -172,6 +178,12 @@ private void Wrap(IChannelHandlerContext context)
#if NETCOREAPP || NETSTANDARD_2_0_GREATER
private void FinishWrap(in ReadOnlySpan<byte> buffer, IPromise promise)
{
if (_outboundClosed)
{
_ = promise.TryComplete();
return;
}

IByteBuffer output;
var capturedContext = CapturedContext;
if (buffer.IsEmpty)
Expand All @@ -192,6 +204,12 @@ private void FinishWrap(in ReadOnlySpan<byte> buffer, IPromise promise)

private void FinishWrap(byte[] buffer, int offset, int count, IPromise promise)
{
if (_outboundClosed)
{
_ = promise.TryComplete();
return;
}

IByteBuffer output;
var capturedContext = CapturedContext;
if (0u >= (uint)count)
Expand All @@ -210,6 +228,12 @@ private void FinishWrap(byte[] buffer, int offset, int count, IPromise promise)
#if NETCOREAPP || NETSTANDARD_2_0_GREATER
private Task FinishWrapNonAppDataAsync(in ReadOnlyMemory<byte> buffer, IPromise promise)
{
if (_outboundClosed)
{
_ = promise.TryComplete();
return TaskUtil.Completed;
}

var capturedContext = CapturedContext;
Task future;
if (MemoryMarshal.TryGetArray(buffer, out var seg))
Expand All @@ -227,6 +251,12 @@ private Task FinishWrapNonAppDataAsync(in ReadOnlyMemory<byte> buffer, IPromise

private Task FinishWrapNonAppDataAsync(byte[] buffer, int offset, int count, IPromise promise)
{
if (_outboundClosed)
{
_ = promise.TryComplete();
return TaskUtil.Completed;
}

var capturedContext = CapturedContext;
var future = capturedContext.WriteAndFlushAsync(Unpooled.WrappedBuffer(buffer, offset, count), promise);
this.ReadIfNeeded(capturedContext);
Expand Down
10 changes: 9 additions & 1 deletion src/DotNetty.Handlers/Tls/TlsHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,15 @@ public override void Close(IChannelHandlerContext context, IPromise promise)
{
//CloseOutboundAndChannel(context, promise, false);
_ = _closeFuture.TryComplete();
_sslStream.Dispose();
try
{
_sslStream.Dispose();
}
catch (Exception)
{
// Swallow dispose exceptions to prevent them from propagating during channel close.
// See https://github.com/maksimkim/SpanNetty/issues/60
}
base.Close(context, promise);
}

Expand Down
241 changes: 241 additions & 0 deletions test/DotNetty.Handlers.Tests/TlsHandlerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace DotNetty.Handlers.Tests
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net.Security;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -378,5 +379,245 @@ public override void ChannelActive(IChannelHandlerContext context)
}
}
}

/// <summary>
/// Regression test for https://github.com/maksimkim/SpanNetty/issues/60
/// Verifies that when the pending write queue is drained re-entrantly during
/// Wrap (between the Current check and Remove call), the null return from
/// Remove() is handled gracefully instead of throwing NullReferenceException.
///
/// Uses a custom Stream wrapper around MediationStream to simulate the
/// re-entrant HandleFailure → RemoveAndFailAll scenario: after SslStream
/// encrypts data and writes ciphertext to MediationStream (which sets
/// _lastContextWriteTask via FinishWrap), the wrapper drains the queue and
/// clears _lastContextWriteTask. When Wrap continues, Remove() returns null
/// and the unfixed code hits promise.TryComplete() on a null promise → NRE.
/// </summary>
[Fact]
public async Task WrapRemoveNull_ShouldNotThrowNullReferenceException()
{
var executor = new DefaultEventExecutor();
try
{
var writeTasks = new List<Task>();
var writeStrategy = new AsIsWriteStrategy();

X509Certificate2 tlsCertificate = TestResourceHelper.GetTestCertificate();
string targetHost = tlsCertificate.GetNameInfo(X509NameType.DnsName, false);

// Reflection fields for the re-entrant drain simulation
var queueField = typeof(TlsHandler).GetField("_pendingUnencryptedWrites",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
var lastTaskField = typeof(TlsHandler).GetField("_lastContextWriteTask",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);

// Create a TlsHandler with a custom SslStream that wraps MediationStream
// in a QueueDrainingStreamWrapper. The wrapper intercepts SslStream's
// ciphertext writes and, when enabled, drains the pending write queue
// to simulate re-entrant HandleFailure.
QueueDrainingStreamWrapper streamWrapper = null;
TlsHandler tlsHandler = new TlsHandler(
stream =>
{
streamWrapper = new QueueDrainingStreamWrapper(stream);
return new SslStream(streamWrapper, true, (sender, certificate, chain, errors) => true);
},
new ClientTlsSettings(SslProtocols.Tls12, false, new List<X509Certificate>(), targetHost));

// Wire up the reflection targets so the wrapper can drain the queue
streamWrapper.SetTarget(tlsHandler, queueField, lastTaskField);

var ch = new EmbeddedChannel(tlsHandler);

// -- Complete the TLS handshake --
IByteBuffer readResultBuffer = Unpooled.Buffer(4 * 1024);
Func<ArraySegment<byte>, Task<int>> readDataFunc = async output =>
{
if (writeTasks.Count > 0)
{
await Task.WhenAll(writeTasks).WithTimeout(TestTimeout);
writeTasks.Clear();
}
if (readResultBuffer.ReadableBytes < output.Count)
{
if (ch.IsActive)
{
#pragma warning disable CS1998
await ReadOutboundAsync(async () => ch.ReadOutbound<IByteBuffer>(), output.Count - readResultBuffer.ReadableBytes, readResultBuffer, TestTimeout, readResultBuffer.ReadableBytes != 0 ? 0 : 1);
#pragma warning restore CS1998
}
}
int read = Math.Min(output.Count, readResultBuffer.ReadableBytes);
readResultBuffer.ReadBytes(output.Array, output.Offset, read);
return read;
};
var mediationStream = new MediationStream(readDataFunc, input =>
{
Task task = executor.SubmitAsync(() => writeStrategy.WriteToChannelAsync(ch, input)).Unwrap();
writeTasks.Add(task);
return task;
}, () => { ch.CloseAsync(); });

var driverStream = new SslStream(mediationStream, true, (_1, _2, _3, _4) => true);
await Task.Run(() => driverStream.AuthenticateAsServerAsync(tlsCertificate, false, SslProtocols.Tls12, false))
.WithTimeout(TimeSpan.FromSeconds(10));
writeTasks.Clear();

// -- Handshake complete. Enable the re-entrant drain simulation. --
streamWrapper.ShouldDrain = true;

// Write + Flush triggers: TlsHandler.Write (adds to queue) →
// TlsHandler.Flush → WrapAndFlush → Wrap → buf.ReadBytes(_sslStream, ...) →
// SslStream encrypts → wrapper.Write → MediationStream.Write (FinishWrap sets
// _lastContextWriteTask) → wrapper drains queue & clears _lastContextWriteTask →
// back in Wrap: Remove() returns null, _lastContextWriteTask is null →
// Without fix: promise.TryComplete() where promise is null → NRE
// With fix: if (promise is null) { break; } → exits gracefully
try
{
ch.WriteOutbound(Unpooled.WrappedBuffer(new byte[] { 1, 2, 3 }));
}
catch (Exception ex)
{
Assert.False(
ContainsNullReferenceException(ex),
$"NRE from Wrap.Remove() should not occur: {ex}");
}

try
{
ch.CheckException();
}
catch (Exception ex)
{
Assert.False(
ContainsNullReferenceException(ex),
$"NRE stored in channel: {ex}");
}

Assert.True(streamWrapper.WasDrained,
"The queue should have been drained during the write");

driverStream.Dispose();
}
finally
{
await executor.ShutdownGracefullyAsync(TimeSpan.Zero, TimeSpan.Zero);
}
}

static bool ContainsNullReferenceException(Exception ex)
{
if (ex is NullReferenceException) return true;
if (ex is AggregateException agg)
{
foreach (var inner in agg.Flatten().InnerExceptions)
{
if (inner is NullReferenceException) return true;
}
}
return ex.InnerException is object && ContainsNullReferenceException(ex.InnerException);
}

/// <summary>
/// Wraps MediationStream to simulate re-entrant queue drain during SslStream write.
/// After forwarding the encrypted write to MediationStream (which calls FinishWrap
/// and sets _lastContextWriteTask), it drains the pending write queue and clears
/// _lastContextWriteTask — reproducing the effect of HandleFailure being called
/// re-entrantly during an outbound write.
/// </summary>
sealed class QueueDrainingStreamWrapper : Stream
{
readonly Stream _inner;
object _handler;
System.Reflection.FieldInfo _queueField;
System.Reflection.FieldInfo _lastTaskField;
bool _drained;

public bool ShouldDrain { get; set; }
public bool WasDrained => _drained;

public QueueDrainingStreamWrapper(Stream inner) { _inner = inner; }

public void SetTarget(object handler, System.Reflection.FieldInfo queueField, System.Reflection.FieldInfo lastTaskField)
{
_handler = handler;
_queueField = queueField;
_lastTaskField = lastTaskField;
}

public override void Write(byte[] buffer, int offset, int count)
{
_inner.Write(buffer, offset, count);
DrainIfNeeded();
}

private void DrainIfNeeded()
{
if (ShouldDrain && !_drained)
{
_drained = true;
// Clear _lastContextWriteTask so Remove()'s null hits the else branch
// (promise.TryComplete()) instead of the ContinueWith path in LinkOutcome
_lastTaskField.SetValue(_handler, null);
// Drain the queue to make Remove() return null
var queue = (BatchingPendingWriteQueue)_queueField.GetValue(_handler);
queue.RemoveAndFailAll(new IOException("simulated connection failure"));
}
}

// Required Stream overrides (forward to inner)
public override void Flush() => _inner.Flush();
public override int Read(byte[] buffer, int offset, int count) => _inner.Read(buffer, offset, count);
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) => _inner.ReadAsync(buffer, offset, count, cancellationToken);
public override long Seek(long offset, SeekOrigin origin) => _inner.Seek(offset, origin);
public override void SetLength(long value) => _inner.SetLength(value);
public override bool CanRead => _inner.CanRead;
public override bool CanSeek => _inner.CanSeek;
public override bool CanWrite => _inner.CanWrite;
public override long Length => _inner.Length;
public override long Position { get => _inner.Position; set => _inner.Position = value; }

#if NETCOREAPP || NETSTANDARD_2_0_GREATER
public override System.Threading.Tasks.ValueTask<int> ReadAsync(System.Memory<byte> buffer, System.Threading.CancellationToken cancellationToken = default)
=> _inner.ReadAsync(buffer, cancellationToken);

public override void Write(System.ReadOnlySpan<byte> buffer)
{
_inner.Write(buffer);
DrainIfNeeded();
}

public override System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory<byte> buffer, System.Threading.CancellationToken cancellationToken = default)
{
var result = _inner.WriteAsync(buffer, cancellationToken);
DrainIfNeeded();
return result;
}
#endif

public override Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken)
{
var task = _inner.WriteAsync(buffer, offset, count, cancellationToken);
DrainIfNeeded();
return task;
}

#if !NETCOREAPP1_1
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) => _inner.BeginRead(buffer, offset, count, callback, state);
public override int EndRead(IAsyncResult asyncResult) => _inner.EndRead(asyncResult);
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
// On .NET Framework, SslStream.Write uses BeginWrite/EndWrite internally
var result = _inner.BeginWrite(buffer, offset, count, callback, state);
DrainIfNeeded();
return result;
}
public override void EndWrite(IAsyncResult asyncResult) => _inner.EndWrite(asyncResult);
#endif

protected override void Dispose(bool disposing) { if (disposing) _inner.Dispose(); base.Dispose(disposing); }
}

}
}