From cae2566e12dbc5a2f782972cdc45f6c49133357f Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Wed, 23 Oct 2024 01:38:52 +0100 Subject: [PATCH 01/17] reimplementation of experimental branch on main --- .../src/Microsoft.Data.SqlClient.csproj | 6 + .../Microsoft/Data/SqlClient/SqlDataReader.cs | 11 +- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 6 +- .../SqlClient/TdsParserStateObject.netcore.cs | 115 ++++++- .../netfx/src/Microsoft.Data.SqlClient.csproj | 6 + .../Microsoft/Data/SqlClient/SqlDataReader.cs | 2 +- .../SqlClient/TdsParserStateObject.netfx.cs | 115 ++++++- .../src/Microsoft/Data/SqlClient/Packet.cs | 122 +++++++ .../TdsParserStateObject.Multiplexer.cs | 298 ++++++++++++++++++ .../Data/SqlClient/TdsParserStateObject.cs | 297 ++++++++++++----- 10 files changed, 879 insertions(+), 99 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs create mode 100644 src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 08e07b1989..2b03ede595 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -185,6 +185,9 @@ Microsoft\Data\SqlClient\OnChangedEventHandler.cs + + Microsoft\Data\SqlClient\Packet.cs + Microsoft\Data\SqlClient\ParameterPeekAheadValue.cs @@ -578,6 +581,9 @@ Microsoft\Data\SqlClient\TdsParserStateObject.cs + + Microsoft\Data\SqlClient\TdsParserStateObject.Multiplexer.cs + Microsoft\Data\SqlClient\TdsParserStaticMethods.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index cc07cd03c3..85c4516f91 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -3566,7 +3566,7 @@ private TdsOperationStatus TryNextResult(out bool more) /// // user must call Read() to position on the first row - override public bool Read() + public override bool Read() { if (_currentTask != null) { @@ -4198,7 +4198,7 @@ private TdsOperationStatus TryResetBlobState() #if DEBUG else { - Debug.Assert((_sharedState._columnDataBytesRemaining == 0 || _sharedState._columnDataBytesRemaining == -1) && _stateObj._longlen == 0, "Haven't read header yet, but column is partially read?"); + //Debug.Assert((_sharedState._columnDataBytesRemaining == 0 || _sharedState._columnDataBytesRemaining == -1) && _stateObj._longlen == 0, "Haven't read header yet, but column is partially read?"); } #endif @@ -4349,9 +4349,10 @@ internal TdsOperationStatus TrySetMetaData(_SqlMetaDataSet metaData, bool moreIn _metaDataConsumed = true; if (_parser != null) - { // There is a valid case where parser is null - // Peek, and if row token present, set _hasRows true since there is a - // row in the result + { + // There is a valid case where parser is null + // Peek, and if row token present, set _hasRows true since there is a + // row in the result byte b; TdsOperationStatus result = _stateObj.TryPeekByte(out b); if (result != TdsOperationStatus.Done) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index d99ca11429..55e4a71788 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -2412,6 +2412,7 @@ internal TdsOperationStatus TryRun(RunBehavior runBehavior, SqlCommand cmdHandle } } + // temporarily cache next byte byte peekedToken; result = stateObj.TryPeekByte(out peekedToken); if (result != TdsOperationStatus.Done) @@ -4162,6 +4163,8 @@ internal TdsOperationStatus TryProcessReturnValue(int length, TdsParserStateObje { return result; } + + // Length of parameter name byte len; result = stateObj.TryReadByte(out len); if (result != TdsOperationStatus.Done) @@ -4560,7 +4563,6 @@ internal TdsOperationStatus TryProcessCollation(TdsParserStateObject stateObj, o collation = null; return result; } - if (SqlCollation.Equals(_cachedCollation, info, sortId)) { collation = _cachedCollation; @@ -5284,7 +5286,7 @@ private TdsOperationStatus TryCommonProcessMetaData(TdsParserStateObject stateOb { // If the column is encrypted, we should have a valid cipherTable if (cipherTable != null) - { + { result = TryProcessTceCryptoMetadata(stateObj, col, cipherTable, columnEncryptionSetting, isReturnValue: false); if (result != TdsOperationStatus.Done) { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs index 9eab4ae44f..f9ca2975fc 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs @@ -393,7 +393,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) AssertValidState(); } - public void ProcessSniPacket(PacketHandle packet, uint error) + public void ProcessSniPacket(PacketHandle packet, uint error, bool usePartialPacket = false) { if (error != 0) { @@ -410,8 +410,23 @@ public void ProcessSniPacket(PacketHandle packet, uint error) else { uint dataSize = 0; + bool usedPartialPacket = false; + uint getDataError = 0; - uint getDataError = SNIPacketGetData(packet, _inBuff, ref dataSize); + if (usePartialPacket && _snapshot == null && _partialPacket != null && _partialPacket.IsComplete) + { + //Debug.Assert(_snapshot == null, "_snapshot must be null when processing partial packet instead of network read"); + //Debug.Assert(_partialPacket != null, "_partialPacket must not be null when usePartialPacket is true"); + //Debug.Assert(_partialPacket.IsComplete, "_partialPacket.IsComplete must be true to use it in place of a real read"); + SetBuffer(_partialPacket.Buffer, 0, _partialPacket.CurrentLength); + ClearPartialPacket(); + getDataError = TdsEnums.SNI_SUCCESS; + usedPartialPacket = true; + } + else + { + getDataError = SNIPacketGetData(packet, _inBuff, ref dataSize); + } if (getDataError == TdsEnums.SNI_SUCCESS) { @@ -421,18 +436,100 @@ public void ProcessSniPacket(PacketHandle packet, uint error) throw SQL.InvalidInternalPacketSize(StringsHelper.GetString(Strings.SqlMisc_InvalidArraySizeMessage)); } - _lastSuccessfulIOTimer._value = DateTime.UtcNow.Ticks; - _inBytesRead = (int)dataSize; - _inBytesUsed = 0; + if (!usedPartialPacket) + { + _lastSuccessfulIOTimer._value = DateTime.UtcNow.Ticks; + + SetBuffer(_inBuff, 0, (int)dataSize); + } + + bool recurse; + bool appended = false; + do + { + MultiplexPackets( + _inBuff, _inBytesUsed, _inBytesRead, + _partialPacket, + out int newDataOffset, + out int newDataLength, + out Packet remainderPacket, + out bool consumeInputDirectly, + out bool consumePartialPacket, + out bool remainderPacketProduced, + out recurse + ); + bool bufferIsPartialCompleted = false; + + // if a partial packet was reconstructed it must be handled first + if (consumePartialPacket) + { + if (_snapshot != null) + { + _snapshot.AppendPacketData(_partialPacket.Buffer, _partialPacket.CurrentLength); + appended = true; + } + else + { + SetBuffer(_partialPacket.Buffer, 0, _partialPacket.CurrentLength); + bufferIsPartialCompleted = true; + } + ClearPartialPacket(); + } + + // if the remaining data can be processed directly it must be second + if (consumeInputDirectly) + { + // if some data was taken from the new packet adjust the counters + if (dataSize != newDataLength || 0 != newDataOffset) + { + SetBuffer(_inBuff, newDataOffset, newDataLength); + } + + if (_snapshot != null) + { + _snapshot.AppendPacketData(_inBuff, _inBytesRead); + appended = true; + } + else + { + SetBuffer(_inBuff, 0, _inBytesRead); + } + } + else + { + // whatever is in the input buffer should not be directly consumed + // and is contained in the partial or remainder packets so make sure + // we don't process it + if (!bufferIsPartialCompleted) + { + SetBuffer(_inBuff, 0, 0); + } + } + + // if there is a remainder it must be last + if (remainderPacketProduced) + { + SetPartialPacket(remainderPacket); + if (!bufferIsPartialCompleted) + { + // we are keeping the partial packet buffer so replace it with a new one + // unless we have already set the buffer to the partial packet buffer + SetBuffer(new byte[_inBuff.Length], 0, 0); + } + } + + } while (recurse && _snapshot != null); if (_snapshot != null) { - _snapshot.AppendPacketData(_inBuff, _inBytesRead); - if (_snapshotReplay) + if (_snapshotStatus != SnapshotStatus.NotActive && appended) { _snapshot.MoveNext(); #if DEBUG - _snapshot.AssertCurrent(); + // multiple packets can be appended by demuxing but we should only move + // forward by a single packet so we can no longer assert that we are on + // the last packet at this time + //_snapshot.AssertCurrent(); #endif } } @@ -1633,7 +1730,7 @@ internal void AssertStateIsClean() if ((parser != null) && (parser.State != TdsParserState.Closed) && (parser.State != TdsParserState.Broken)) { // Async reads - Debug.Assert(_snapshot == null && !_snapshotReplay, "StateObj has leftover snapshot state"); + Debug.Assert(_snapshot == null && _snapshotStatus == SnapshotStatus.NotActive); Debug.Assert(!_asyncReadWithoutSnapshot, "StateObj has AsyncReadWithoutSnapshot still enabled"); Debug.Assert(_executionContext == null, "StateObj has a stored execution context from an async read"); // Async writes diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index c42fc5b18e..7486817fbd 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -319,6 +319,9 @@ Microsoft\Data\SqlClient\OnChangedEventHandler.cs + + Microsoft\Data\SqlClient\Packet.cs + Microsoft\Data\SqlClient\ParameterPeekAheadValue.cs @@ -697,6 +700,9 @@ Microsoft\Data\SqlClient\TdsParserStateObject.cs + + Microsoft\Data\SqlClient\TdsParserStateObject.Multiplexer.cs + Microsoft\Data\SqlClient\TdsParserStaticMethods.cs diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs index e8f4938964..5a8254aa64 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -4745,7 +4745,7 @@ private TdsOperationStatus TryResetBlobState() #if DEBUG else { - Debug.Assert((_sharedState._columnDataBytesRemaining == 0 || _sharedState._columnDataBytesRemaining == -1) && _stateObj._longlen == 0, "Haven't read header yet, but column is partially read?"); + //Debug.Assert((_sharedState._columnDataBytesRemaining == 0 || _sharedState._columnDataBytesRemaining == -1) && _stateObj._longlen == 0, "Haven't read header yet, but column is partially read?"); } #endif diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs index b58c40a0a5..42c2908856 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs @@ -524,7 +524,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) AssertValidState(); } - public void ProcessSniPacket(PacketHandle packet, uint error) + public void ProcessSniPacket(PacketHandle packet, uint error, bool usePartialPacket = false) { if (error != 0) { @@ -541,8 +541,23 @@ public void ProcessSniPacket(PacketHandle packet, uint error) else { uint dataSize = 0; + bool usedPartialPacket = false; + uint getDataError = 0; - uint getDataError = SNINativeMethodWrapper.SNIPacketGetData(packet, _inBuff, ref dataSize); + if (usePartialPacket && _snapshot == null && _partialPacket != null && _partialPacket.IsComplete) + { + //Debug.Assert(_snapshot == null, "_snapshot must be null when processing partial packet instead of network read"); + //Debug.Assert(_partialPacket != null, "_partialPacket must not be null when usePartialPacket is true"); + //Debug.Assert(_partialPacket.IsComplete, "_partialPacket.IsComplete must be true to use it in place of a real read"); + SetBuffer(_partialPacket.Buffer, 0, _partialPacket.CurrentLength); + ClearPartialPacket(); + getDataError = TdsEnums.SNI_SUCCESS; + usedPartialPacket = true; + } + else + { + getDataError = SNINativeMethodWrapper.SNIPacketGetData(packet, _inBuff, ref dataSize); + } if (getDataError == TdsEnums.SNI_SUCCESS) { @@ -552,18 +567,100 @@ public void ProcessSniPacket(PacketHandle packet, uint error) throw SQL.InvalidInternalPacketSize(StringsHelper.GetString(Strings.SqlMisc_InvalidArraySizeMessage)); } - _lastSuccessfulIOTimer._value = DateTime.UtcNow.Ticks; - _inBytesRead = (int)dataSize; - _inBytesUsed = 0; + if (!usedPartialPacket) + { + _lastSuccessfulIOTimer._value = DateTime.UtcNow.Ticks; + + SetBuffer(_inBuff, 0, (int)dataSize); + } + + bool recurse; + bool appended = false; + do + { + MultiplexPackets( + _inBuff, _inBytesUsed, _inBytesRead, + _partialPacket, + out int newDataOffset, + out int newDataLength, + out Packet remainderPacket, + out bool consumeInputDirectly, + out bool consumePartialPacket, + out bool remainderPacketProduced, + out recurse + ); + bool bufferIsPartialCompleted = false; + + // if a partial packet was reconstructed it must be handled first + if (consumePartialPacket) + { + if (_snapshot != null) + { + _snapshot.AppendPacketData(_partialPacket.Buffer, _partialPacket.CurrentLength); + appended = true; + } + else + { + SetBuffer(_partialPacket.Buffer, 0, _partialPacket.CurrentLength); + bufferIsPartialCompleted = true; + } + ClearPartialPacket(); + } + + // if the remaining data can be processed directly it must be second + if (consumeInputDirectly) + { + // if some data was taken from the new packet adjust the counters + if (dataSize != newDataLength || 0 != newDataOffset) + { + SetBuffer(_inBuff, newDataOffset, newDataLength); + } + + if (_snapshot != null) + { + _snapshot.AppendPacketData(_inBuff, _inBytesRead); + appended = true; + } + else + { + SetBuffer(_inBuff, 0, _inBytesRead); + } + } + else + { + // whatever is in the input buffer should not be directly consumed + // and is contained in the partial or remainder packets so make sure + // we don't process it + if (!bufferIsPartialCompleted) + { + SetBuffer(_inBuff, 0, 0); + } + } + + // if there is a remainder it must be last + if (remainderPacketProduced) + { + SetPartialPacket(remainderPacket); + if (!bufferIsPartialCompleted) + { + // we are keeping the partial packet buffer so replace it with a new one + // unless we have already set the buffer to the partial packet buffer + SetBuffer(new byte[_inBuff.Length], 0, 0); + } + } + + } while (recurse && _snapshot != null); if (_snapshot != null) { - _snapshot.AppendPacketData(_inBuff, _inBytesRead); - if (_snapshotReplay) + if (_snapshotStatus != SnapshotStatus.NotActive && appended) { _snapshot.MoveNext(); #if DEBUG - _snapshot.AssertCurrent(); + // multiple packets can be appended by demuxing but we should only move + // forward by a single packet so we can no longer assert that we are on + // the last packet at this time + //_snapshot.AssertCurrent(); #endif } } @@ -1773,7 +1870,7 @@ internal void AssertStateIsClean() if ((parser != null) && (parser.State != TdsParserState.Closed) && (parser.State != TdsParserState.Broken)) { // Async reads - Debug.Assert(_snapshot == null && !_snapshotReplay, "StateObj has leftover snapshot state"); + Debug.Assert(_snapshot == null && _snapshotStatus == SnapshotStatus.NotActive, "StateObj has leftover snapshot state"); Debug.Assert(!_asyncReadWithoutSnapshot, "StateObj has AsyncReadWithoutSnapshot still enabled"); Debug.Assert(_executionContext == null, "StateObj has a stored execution context from an async read"); // Async writes diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs new file mode 100644 index 0000000000..5b0463008f --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs @@ -0,0 +1,122 @@ +using System; + +namespace Microsoft.Data.SqlClient +{ + internal sealed class Packet + { + public const int UnknownDataLength = -1; + + private bool _disposed; + private int _dataLength; + private int _totalLength; + private byte[] _buffer; + + public Packet() + { + _disposed = false; + _dataLength = UnknownDataLength; + } + + public int DataLength + { + get + { + CheckDisposed(); + return _dataLength; + } + set + { + CheckDisposed(); + //if (value > 7992) + //{ + // Debugger.Break(); + //} + _dataLength = value; + } + } + public byte[] Buffer + { + get + { + CheckDisposed(); + return _buffer; + } + set + { + CheckDisposed(); + _buffer = value; + } + } + public int CurrentLength + { + get + { + CheckDisposed(); + return _totalLength; + } + set + { + CheckDisposed(); + _totalLength = value; + } + } + + public int RequiredLength + { + get + { + CheckDisposed(); + if (!HasDataLength) + { + throw new InvalidOperationException($"cannot get {nameof(RequiredLength)} when {nameof(HasDataLength)} is false"); + } + return TdsEnums.HEADER_LEN + _dataLength; + } + } + + public bool HasHeader => _totalLength >= TdsEnums.HEADER_LEN; + + public bool HasDataLength => _dataLength >= 0; + + public bool IsComplete => _dataLength != UnknownDataLength && (TdsEnums.HEADER_LEN + _dataLength) == _totalLength; + + public ReadOnlySpan GetHeaderSpan() => _buffer.AsSpan(0, TdsEnums.HEADER_LEN); + + public void Dispose() + { + _disposed = true; + } + + public void CheckDisposed() + { + if (_disposed) + { + ThrowDisposed(); + } + } + + public static void ThrowDisposed() + { + throw new ObjectDisposedException(nameof(Packet)); + } + + internal static byte GetStatusFromHeader(ReadOnlySpan header) => header[1]; + + internal static int GetDataLengthFromHeader(ReadOnlySpan header) + { + return (header[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8 | header[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; + } + internal static int GetSpidFromHeader(ReadOnlySpan header) + { + return (header[TdsEnums.SPID_OFFSET] << 8 | header[TdsEnums.SPID_OFFSET + 1]); + } + internal static int GetIDFromHeader(ReadOnlySpan header) + { + return header[TdsEnums.HEADER_LEN_FIELD_OFFSET + 4]; + } + + internal static int GetDataLengthFromHeader(Packet packet) => GetDataLengthFromHeader(packet.GetHeaderSpan()); + + internal static bool GetIsEOMFromHeader(ReadOnlySpan header) => GetStatusFromHeader(header) == 1; + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs new file mode 100644 index 0000000000..fb610b607e --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs @@ -0,0 +1,298 @@ +using System; +using System.Diagnostics; + +namespace Microsoft.Data.SqlClient +{ + partial class TdsParserStateObject + { + private Packet __partialPacket; + private Packet _partialPacket => __partialPacket; + + private void SetPartialPacket(Packet packet/*, [CallerMemberName] string caller = null*/) + { + if (__partialPacket != null && packet != null) + { + throw new InvalidOperationException("partial packet cannot be non-null when setting to non=null"); + } + __partialPacket = packet; + } + + private void ClearPartialPacket(/*[CallerMemberName] string caller = null*/) + { + Packet partialPacket = __partialPacket; + __partialPacket = null; + if (partialPacket != null) + { + ReadOnlySpan header = partialPacket.GetHeaderSpan(); + int packetId = Packet.GetIDFromHeader(header); + bool isEOM = Packet.GetIsEOMFromHeader(header); + partialPacket.Dispose(); + } + } + + public static void MultiplexPackets( + byte[] dataBuffer, int dataOffset, int dataLength, + Packet partialPacket, + out int newDataOffset, + out int newDataLength, + out Packet remainderPacket, + out bool consumeInputDirectly, + out bool consumePartialPacket, + out bool consumeRemainderPacket, + out bool recurse + ) + { + ReadOnlySpan data = dataBuffer.AsSpan(dataOffset, dataLength); + remainderPacket = null; + consumeInputDirectly = false; + consumePartialPacket = false; + consumeRemainderPacket = false; + recurse = false; + + newDataLength = dataLength; + newDataOffset = dataOffset; + + int bytesConsumed = 0; + + if (partialPacket != null) + { + if (!partialPacket.HasDataLength) + { + // we need to get enough bytes to read the packet header + int headeBytesNeeded = Math.Max(0, TdsEnums.HEADER_LEN - partialPacket.CurrentLength); + if (headeBytesNeeded > 0) + { + int headerBytesAvailable = Math.Min(data.Length, headeBytesNeeded); + Span headerTarget = partialPacket.Buffer.AsSpan(partialPacket.CurrentLength, headerBytesAvailable); + ReadOnlySpan headerSource = data.Slice(0, headerBytesAvailable); + headerSource.CopyTo(headerTarget); + partialPacket.CurrentLength = partialPacket.CurrentLength + headerBytesAvailable; + data = data.Slice(headerBytesAvailable); + bytesConsumed += headerBytesAvailable; + } + if (partialPacket.HasHeader) + { + partialPacket.DataLength = Packet.GetDataLengthFromHeader(partialPacket); + //if (partialPacket.DataLength > dataBuffer.Length) + //{ + // Debugger.Break(); + //} + } + } + + if (partialPacket.HasDataLength) + { + if (partialPacket.CurrentLength < partialPacket.RequiredLength) + { + // the packet length is known so take as much data as possible from the incoming + // data to try and complete the packet + int payloadBytesNeeded = partialPacket.DataLength - (partialPacket.CurrentLength - TdsEnums.HEADER_LEN); + int payloadBytesAvailable = Math.Min(data.Length, payloadBytesNeeded); + Span payloadTarget = partialPacket.Buffer.AsSpan(partialPacket.CurrentLength, payloadBytesAvailable); + ReadOnlySpan payloadSource = data.Slice(0, payloadBytesAvailable); + payloadSource.CopyTo(payloadTarget); + partialPacket.CurrentLength = partialPacket.CurrentLength + payloadBytesAvailable; + bytesConsumed += payloadBytesAvailable; + data = data.Slice(payloadBytesAvailable); + } + else if (partialPacket.CurrentLength > partialPacket.RequiredLength) + { + // the packet contains an entire packet and more data after that so we need + // to extract the following data into a new packet with a new buffer and return + // it as the remainer packet + + int remainderLength = partialPacket.CurrentLength - partialPacket.RequiredLength; + remainderPacket = new Packet + { + Buffer = new byte[dataBuffer.Length], + CurrentLength = remainderLength, + }; + Buffer.BlockCopy( + partialPacket.Buffer, partialPacket.RequiredLength, // from + remainderPacket.Buffer, 0, // to + remainderPacket.CurrentLength // for + ); + partialPacket.CurrentLength = partialPacket.CurrentLength - remainderPacket.CurrentLength; + consumeRemainderPacket = true; + + if (remainderPacket.HasHeader) + { + remainderPacket.DataLength = Packet.GetDataLengthFromHeader(remainderPacket); + if (remainderPacket.HasDataLength && remainderPacket.CurrentLength >= remainderPacket.RequiredLength) + { + recurse = true; + } + } + } + + if (partialPacket.CurrentLength == partialPacket.RequiredLength) + { + // partial packet has been completed + consumePartialPacket = true; + } + } + + if (bytesConsumed > 0) + { + if (data.Length > 0) + { + //if (data[0] == 120) + //{ + // var d = Vizualize(dataBuffer, dataOffset, dataLength); + // Debugger.Break(); + //} + + // some data has been taken from the buffer, put into the partial + // packet buffer and we have data left so move the data we have + // left to the start of the buffer so we can pass the buffer back + // as zero based to the caller avoiding offset calculations everywhere + Buffer.BlockCopy( + dataBuffer, dataOffset + bytesConsumed, // from + dataBuffer, dataOffset, // to + dataLength - bytesConsumed // for + ); + + //// for debugging purposes fill the removed data area with an easily + //// recognisable pattern so we can see if it is misused + //Span removed = dataBuffer.AsSpan(dataOffset + (dataLength - bytesConsumed), (dataOffset + bytesConsumed)); + //removed.Fill(0xFF); + + // then recreate the data span so that we're looking at the data + // that has been moved + data = dataBuffer.AsSpan(dataOffset, dataLength - bytesConsumed); + //if (data[0] == 120) + //{ + // Debugger.Break(); + //} + } + + newDataLength = dataLength - bytesConsumed; + } + } + + if (data.Length > 0 && !consumeRemainderPacket) + { + if (data.Length >= TdsEnums.HEADER_LEN) + { + // we have enough bytes to read the packet header and see how + // much data we are expecting it to contain + int packetDataLength = Packet.GetDataLengthFromHeader(data); + //if (packetDataLength > dataBuffer.Length) + //{ + // Debugger.Break(); + //} + if (data.Length == TdsEnums.HEADER_LEN + packetDataLength) + { + if (!consumePartialPacket) + { + // we can tell the caller that they should directly consume the data in + // the input buffer, this is the happy path + consumeInputDirectly = true; + } + else + { + // we took some data from the input to reconstruct the partial packet + // so we can't tell the caller to directly consume the packet in the + // input buffer, we need to construct a new remainder packet and then + // tell them to consume it + remainderPacket = new Packet + { + Buffer = dataBuffer, + DataLength = packetDataLength, + CurrentLength = data.Length + }; + consumeRemainderPacket = true; + //Debug.Assert(remainderPacket.HasHeader); // precondition of entering this block + //Debug.Assert(remainderPacket.HasDataLength); // must have been set at construction + if (remainderPacket.CurrentLength >= remainderPacket.RequiredLength) + { + // the remainder packet contains more data than the packet so we need + // to tell the caller to recurse into this function again once they have + // consumed the first packet + recurse = true; + } + } + } + else if (data.Length < TdsEnums.HEADER_LEN + packetDataLength) + { + // another partial packet so produce one and tell the caller that they need + // consume it. + remainderPacket = new Packet + { + Buffer = dataBuffer, + DataLength = packetDataLength, + CurrentLength = data.Length + }; + consumeRemainderPacket = true; + Debug.Assert(remainderPacket.HasHeader); // precondition of entering this block + if (remainderPacket.HasDataLength && remainderPacket.CurrentLength >= remainderPacket.RequiredLength) + { + // the remainder packet contains more data than the packet so we need + // to tell the caller to recurse into this function again once they have + // consumed the first packet + recurse = true; + } + } + else // implied: current length > required length + { + // more data than required so need to split it out but we can't do that + // here so we need to tell the caller to take the remainer packet and then + // call this function again + remainderPacket = new Packet + { + Buffer = dataBuffer, + DataLength = packetDataLength, + CurrentLength = data.Length + }; + consumeRemainderPacket = true; + recurse = true; + } + } + else + { + // we don't have enough information to read the header + if (!consumePartialPacket) + { + // we can tell the caller that they should directly consume the data in + // the input buffer, this is the happy path + consumeInputDirectly = true; + } + else + { + // we took some data from the input to reconstruct the partial packet + // so we can't tell the caller to directly consume the packet in the + // input buffer, we need to construct a new remainder packet and then + // tell them to consume it + remainderPacket = new Packet + { + Buffer = dataBuffer, + CurrentLength = data.Length + }; + consumeRemainderPacket = true; + if (remainderPacket.HasHeader) + { + remainderPacket.DataLength = Packet.GetDataLengthFromHeader(remainderPacket.GetHeaderSpan()); + } + if (remainderPacket.HasDataLength && remainderPacket.CurrentLength >= remainderPacket.RequiredLength) + { + // the remainder packet contains more data than the packet so we need + // to tell the caller to recurse into this function again once they have + // consumed the first packet + recurse = true; + } + } + } + } + + if (remainderPacket != null && remainderPacket.HasHeader) + { + remainderPacket.Buffer[7] = 0xF; + } + + if (consumePartialPacket && consumeInputDirectly) + { + throw new InvalidOperationException($"AppendData cannot return both {nameof(consumePartialPacket)} and {nameof(consumeInputDirectly)}"); + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index eb8d4bfe68..004d78731f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -64,6 +64,13 @@ public TimeoutState(int value) public int IdentityValue => _value; } + private enum SnapshotStatus + { + NotActive, + ReplayStarting, + ReplayRunning + } + private const int AttentionTimeoutSeconds = 5; // Ticks to consider a connection "good" after a successful I/O (10,000 ticks = 1 ms) @@ -215,7 +222,7 @@ public TimeoutState(int value) internal TaskCompletionSource _networkPacketTaskSource; private Timer _networkPacketTimeout; internal bool _syncOverAsync = true; - private bool _snapshotReplay; + private SnapshotStatus _snapshotStatus; private StateSnapshot _snapshot; private StateSnapshot _cachedSnapshot; internal ExecutionContext _executionContext; @@ -939,13 +946,11 @@ internal TdsOperationStatus TryProcessHeader() if (_partialHeaderBytesRead == _inputHeaderLen) { // All read + ReadOnlySpan header = _partialHeaderBuffer.AsSpan(0, TdsEnums.HEADER_LEN); _partialHeaderBytesRead = 0; - _inBytesPacket = ((int)_partialHeaderBuffer[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8 | - (int)_partialHeaderBuffer[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - _inputHeaderLen; - - _messageStatus = _partialHeaderBuffer[1]; - _spid = _partialHeaderBuffer[TdsEnums.SPID_OFFSET] << 8 | - _partialHeaderBuffer[TdsEnums.SPID_OFFSET + 1]; + _messageStatus = Packet.GetStatusFromHeader(header); + _inBytesPacket = Packet.GetDataLengthFromHeader(header); + _spid = Packet.GetSpidFromHeader(header); SqlClientEventSource.Log.TryAdvancedTraceEvent("TdsParserStateObject.TryProcessHeader | ADV | State Object Id {0}, Client Connection Id {1}, Server process Id (SPID) {2}", _objectID, _parser?.Connection?.ClientConnectionId, _spid); } @@ -981,11 +986,10 @@ internal TdsOperationStatus TryProcessHeader() else { // normal header processing... - _messageStatus = _inBuff[_inBytesUsed + 1]; - _inBytesPacket = (_inBuff[_inBytesUsed + TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8 | - _inBuff[_inBytesUsed + TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - _inputHeaderLen; - _spid = _inBuff[_inBytesUsed + TdsEnums.SPID_OFFSET] << 8 | - _inBuff[_inBytesUsed + TdsEnums.SPID_OFFSET + 1]; + ReadOnlySpan header = _inBuff.AsSpan(_inBytesUsed, TdsEnums.HEADER_LEN); + _messageStatus = Packet.GetStatusFromHeader(header); + _inBytesPacket = Packet.GetDataLengthFromHeader(header); + _spid = Packet.GetSpidFromHeader(header); #if NET6_0_OR_GREATER SqlClientEventSource.Log.TryAdvancedTraceEvent("TdsParserStateObject.TryProcessHeader | ADV | State Object Id {0}, Client Connection Id {1}, Server process Id (SPID) {2}", _objectID, _parser?.Connection?.ClientConnectionId, _spid); #endif @@ -1127,9 +1131,7 @@ internal bool SetPacketSize(int size) // Allocate or re-allocate _inBuff. if (_inBuff == null) { - _inBuff = new byte[size]; - _inBytesRead = 0; - _inBytesUsed = 0; + SetBuffer(new byte[size], 0, 0); } else if (size != _inBuff.Length) { @@ -1139,28 +1141,24 @@ internal bool SetPacketSize(int size) // if we still have data left in the buffer we must keep that array reference and then copy into new one byte[] temp = _inBuff; - _inBuff = new byte[size]; - // copy remainder of unused data int remainingData = _inBytesRead - _inBytesUsed; - if ((temp.Length < _inBytesUsed + remainingData) || (_inBuff.Length < remainingData)) + if ((temp.Length < _inBytesUsed + remainingData) || (size < remainingData)) { - string errormessage = StringsHelper.GetString(Strings.SQL_InvalidInternalPacketSize) + ' ' + temp.Length + ", " + _inBytesUsed + ", " + remainingData + ", " + _inBuff.Length; + string errormessage = StringsHelper.GetString(Strings.SQL_InvalidInternalPacketSize) + ' ' + temp.Length + ", " + _inBytesUsed + ", " + remainingData + ", " + size; throw SQL.InvalidInternalPacketSize(errormessage); } - Buffer.BlockCopy(temp, _inBytesUsed, _inBuff, 0, remainingData); - _inBytesRead = _inBytesRead - _inBytesUsed; - _inBytesUsed = 0; + byte[] inBuff = new byte[size]; + Buffer.BlockCopy(temp, _inBytesUsed, inBuff, 0, remainingData); + SetBuffer(inBuff, 0, remainingData); AssertValidState(); } else { // buffer is empty - just create the new one that is double the size of the old one - _inBuff = new byte[size]; - _inBytesRead = 0; - _inBytesUsed = 0; + SetBuffer(new byte[size], 0, 0); } } @@ -1385,7 +1383,7 @@ internal TdsOperationStatus TryReadInt32(out int value) TdsOperationStatus result = TryReadByteArray(buffer, 4); if (result != TdsOperationStatus.Done) { - value = default; + value = 0; return result; } } @@ -1857,12 +1855,17 @@ internal TdsOperationStatus TryReadPlpBytes(ref byte[] buff, int offset, int len // If total length is known up front, allocate the whole buffer in one shot instead of realloc'ing and copying over each time if (buff == null && _longlen != TdsEnums.SQL_PLP_UNKNOWNLEN) { - if (_snapshot != null) + if (_snapshot != null && _snapshotStatus != SnapshotStatus.NotActive) { // if there is a snapshot and it contains a stored plp buffer take it // and try to use it if it is the right length buff = _snapshot._plpBuffer; _snapshot._plpBuffer = null; + if (_snapshot.ContinueEnabled) + { + offset = _snapshot.GetPacketDataOffset(); + totalBytesRead = offset; + } } if ((ulong)(buff?.Length ?? 0) != _longlen) @@ -1906,7 +1909,7 @@ internal TdsOperationStatus TryReadPlpBytes(ref byte[] buff, int offset, int len buff = newbuf; } - TdsOperationStatus result = TryReadByteArray(buff.AsSpan(offset), bytesToRead, out bytesRead); + bool result = TryReadByteArray(buff.AsSpan(offset), bytesToRead, out bytesRead); Debug.Assert(bytesRead <= bytesLeft, "Read more bytes than we needed"); Debug.Assert((ulong)bytesRead <= _longlenleft, "Read more bytes than is available"); @@ -1914,7 +1917,7 @@ internal TdsOperationStatus TryReadPlpBytes(ref byte[] buff, int offset, int len offset += bytesRead; totalBytesRead += bytesRead; _longlenleft -= (ulong)bytesRead; - if (result != TdsOperationStatus.Done) + if (!result) { if (_snapshot != null) { @@ -1928,8 +1931,7 @@ internal TdsOperationStatus TryReadPlpBytes(ref byte[] buff, int offset, int len if (_longlenleft == 0) { // Read the next chunk or cleanup state if hit the end - result = TryReadPlpLength(false, out _); - if (result != TdsOperationStatus.Done) + if (!TryReadPlpLength(false, out _)) { if (_snapshot != null) { @@ -1937,7 +1939,7 @@ internal TdsOperationStatus TryReadPlpBytes(ref byte[] buff, int offset, int len // so it can be re-used when another packet arrives and we read again _snapshot._plpBuffer = buff; } - return result; + return false; } } @@ -2000,7 +2002,7 @@ internal TdsOperationStatus TryReadNetworkPacket() TdsOperationStatus result = TdsOperationStatus.InvalidData; if (_snapshot != null) { - if (_snapshotReplay) + if (_snapshotStatus != SnapshotStatus.NotActive) { #if DEBUG // in debug builds stack traces contain line numbers so if we want to be @@ -2025,17 +2027,40 @@ internal TdsOperationStatus TryReadNetworkPacket() { _lastStack = stackTrace; } + + if (_bTmpRead == 0 && _partialHeaderBytesRead == 0 && _longlenleft == 0 && _snapshot.ContinueEnabled) + { + // no temp between packets + // mark this point as continue-able + _snapshot.CaptureAsContinue(this); + } } #endif } // previous buffer is in snapshot _inBuff = new byte[_inBuff.Length]; + result = TdsOperationStatus.NeedMoreData; + } + + if (result == TdsOperationStatus.InvalidData && _partialPacket != null && !_partialPacket.IsComplete) + { + result = TdsOperationStatus.NeedMoreData; } if (_syncOverAsync) { ReadSniSyncOverAsync(); + + while (_inBytesRead == 0) + { + // a partial packet must have taken the packet data so we + // need to read more data to complete the packet but we + // can't return NeedMoreData in sync mode so we have to + // spin fetching more data here until we have something + // that the caller can read + ReadSniSyncOverAsync(); + } return TdsOperationStatus.Done; } @@ -2059,7 +2084,10 @@ internal TdsOperationStatus TryReadNetworkPacket() internal void PrepareReplaySnapshot() { _networkPacketTaskSource = null; - _snapshot.MoveToStart(); + //if (!_snapshot.MoveToContinue()) + { + _snapshot.MoveToStart(); + } } internal void ReadSniSyncOverAsync() @@ -2070,6 +2098,11 @@ internal void ReadSniSyncOverAsync() } PacketHandle readPacket = default; + bool readFromNetwork = true; + if (_partialPacket != null && _partialPacket.IsComplete) + { + readFromNetwork = false; + } uint error; @@ -2082,7 +2115,14 @@ internal void ReadSniSyncOverAsync() Interlocked.Increment(ref _readingCount); shouldDecrement = true; - readPacket = ReadSyncOverAsync(GetTimeoutRemaining(), out error); + if (readFromNetwork) + { + readPacket = ReadSyncOverAsync(GetTimeoutRemaining(), out error); + } + else + { + error = TdsEnums.SNI_SUCCESS; + } Interlocked.Decrement(ref _readingCount); shouldDecrement = false; @@ -2093,11 +2133,15 @@ internal void ReadSniSyncOverAsync() } if (TdsEnums.SNI_SUCCESS == error) - { // Success - process results! + { + // Success - process results! - Debug.Assert(!IsPacketEmpty(readPacket), "ReadNetworkPacket cannot be null in synchronous operation!"); + if (readFromNetwork) + { + Debug.Assert(!IsPacketEmpty(readPacket), "ReadNetworkPacket cannot be null in synchronous operation!"); + } - ProcessSniPacket(readPacket, 0); + ProcessSniPacket(readPacket, TdsEnums.SNI_SUCCESS, usePartialPacket: !readFromNetwork); #if DEBUG if (s_forcePendingReadsToWaitForUser) { @@ -2109,9 +2153,12 @@ internal void ReadSniSyncOverAsync() #endif } else - { // Failure! - - Debug.Assert(!IsValidPacket(readPacket), "unexpected readPacket without corresponding SNIPacketRelease"); + { + // Failure! + if (readFromNetwork) + { + Debug.Assert(!IsValidPacket(readPacket), "unexpected readPacket without corresponding SNIPacketRelease"); + } ReadSniError(this, error); } @@ -2123,9 +2170,12 @@ internal void ReadSniSyncOverAsync() Interlocked.Decrement(ref _readingCount); } - if (!IsPacketEmpty(readPacket)) + if (readFromNetwork) { - ReleasePacket(readPacket); + if (!IsPacketEmpty(readPacket)) + { + ReleasePacket(readPacket); + } } AssertValidState(); @@ -2577,7 +2627,7 @@ internal void SetSnapshot() } _snapshot = snapshot; _snapshot.CaptureAsStart(this); - _snapshotReplay = false; + _snapshotStatus = SnapshotStatus.NotActive; } internal void ResetSnapshot() @@ -2589,7 +2639,7 @@ internal void ResetSnapshot() snapshot.Clear(); Interlocked.CompareExchange(ref _cachedSnapshot, snapshot, null); } - _snapshotReplay = false; + _snapshotStatus = SnapshotStatus.NotActive; } sealed partial class StateSnapshot @@ -2601,6 +2651,37 @@ private sealed partial class PacketData public PacketData NextPacket; public PacketData PrevPacket; + public int DataOffset; + public int DataLength; + + public int TotalSize; + + internal int GetPacketTotalSize() + { + if (TotalSize == 0) + { + int previous = 0; + if (PrevPacket != null) + { + previous = PrevPacket.TotalSize; + } + return previous; + } + else + { + return TotalSize; + } + } + internal int GetPacketDataOffset() + { + int previous = 0; + if (PrevPacket != null) + { + previous = PrevPacket.TotalSize; + } + return TotalSize - (TotalSize - previous); + } + internal void Clear() { Buffer = null; @@ -2611,6 +2692,8 @@ internal void Clear() PrevPacket.NextPacket = null; PrevPacket = null; } + DataLength = 0; + DataOffset = 0; SetDebugStackInternal(null); SetDebugPacketIdInternal(0); } @@ -2673,10 +2756,8 @@ public PacketData[] Items partial void SetDebugPacketIdInternal(int value) => PacketId = value; - public override string ToString() { - //return $"{PacketId}: [{Buffer.Length}] ( {GetPacketDataOffset():D4}, {GetPacketTotalSize():D4} ) {(NextPacket != null ? @"->" : string.Empty)}"; string byteString = null; if (Buffer != null && Buffer.Length >= 12) { @@ -2697,28 +2778,17 @@ public override string ToString() } #endif - private sealed class PLPData - { - public readonly ulong SnapshotLongLen; - public readonly ulong SnapshotLongLenLeft; - - public PLPData(ulong snapshotLongLen, ulong snapshotLongLenLeft) - { - SnapshotLongLen = snapshotLongLen; - SnapshotLongLenLeft = snapshotLongLenLeft; - } - } - private sealed class StateObjectData { private int _inBytesUsed; private int _inBytesPacket; - private PLPData _plpData; private byte _messageStatus; internal NullBitmap _nullBitmapInfo; private _SqlMetaDataSet _cleanupMetaData; internal _SqlMetaDataSetCollection _cleanupAltMetaDataSetArray; private SnapshottedStateFlags _state; + public ulong _longLen; + public ulong _longLenLeft; internal void Capture(TdsParserStateObject stateObj, bool trackStack = true) { @@ -2726,10 +2796,8 @@ internal void Capture(TdsParserStateObject stateObj, bool trackStack = true) _inBytesPacket = stateObj._inBytesPacket; _messageStatus = stateObj._messageStatus; _nullBitmapInfo = stateObj._nullBitmapInfo; // _nullBitmapInfo must be cloned before it is updated - if (stateObj._longlen != 0 || stateObj._longlenleft != 0) - { - _plpData = new PLPData(stateObj._longlen, stateObj._longlenleft); - } + _longLen = stateObj._longlen; + _longLenLeft = stateObj._longlenleft; _cleanupMetaData = stateObj._cleanupMetaData; _cleanupAltMetaDataSetArray = stateObj._cleanupAltMetaDataSetArray; // _cleanupAltMetaDataSetArray must be cloned before it is updated _state = stateObj._snapshottedState; @@ -2749,7 +2817,8 @@ internal void Clear(TdsParserStateObject stateObj, bool trackStack = true) _inBytesPacket = 0; _messageStatus = 0; _nullBitmapInfo = default; - _plpData = null; + _longLen = 0; + _longLenLeft = 0; _cleanupMetaData = null; _cleanupAltMetaDataSetArray = null; _state = SnapshottedStateFlags.None; @@ -2782,26 +2851,30 @@ internal void Restore(TdsParserStateObject stateObj) //else _stateObj._hasOpenResult is already == _snapshotHasOpenResult stateObj._snapshottedState = _state; + // reset plp state + stateObj._longlen = _longLen; + stateObj._longlenleft = _longLenLeft; + // Reset partially read state (these only need to be maintained if doing async without snapshot) stateObj._bTmpRead = 0; stateObj._partialHeaderBytesRead = 0; - - // reset plp state - stateObj._longlen = _plpData?.SnapshotLongLen ?? 0; - stateObj._longlenleft = _plpData?.SnapshotLongLenLeft ?? 0; } } private TdsParserStateObject _stateObj; private StateObjectData _replayStateData; + private StateObjectData _continueStateData; internal byte[] _plpBuffer; private PacketData _lastPacket; private PacketData _firstPacket; private PacketData _current; + private PacketData _continuePacket; private PacketData _sparePacket; + private bool? _continueSupported; + #if DEBUG private int _packetCounter; private int _rollingPend = 0; @@ -2843,6 +2916,20 @@ internal void CheckStack(string trace) } } #endif + + public bool ContinueEnabled + { + get + { + if (_continueSupported == null) + { + _continueSupported = AppContext.TryGetSwitch("Switch.Microsoft.Data.SqlClient.UseExperimentalAsyncContinue", out bool value) ? value : false; + //_continueSupported = false; + } + return _continueSupported.Value; + } + } + internal void CloneNullBitmapInfo() { if (_stateObj._nullBitmapInfo.ReferenceEquals(_replayStateData?._nullBitmapInfo ?? default)) @@ -2863,6 +2950,7 @@ internal void AppendPacketData(byte[] buffer, int read) { Debug.Assert(buffer != null, "packet data cannot be null"); Debug.Assert(read >= TdsEnums.HEADER_LEN, "minimum packet length is TdsEnums.HEADER_LEN"); + Debug.Assert(TdsEnums.HEADER_LEN + Packet.GetDataLengthFromHeader(buffer) == read, "partially read packets cannot be appended to the snapshot"); #if DEBUG for (PacketData current = _firstPacket; current != null; current = current.NextPacket) { @@ -2899,10 +2987,12 @@ internal void AppendPacketData(byte[] buffer, int read) internal bool MoveNext() { bool retval = false; + SnapshotStatus moveToMode = SnapshotStatus.ReplayRunning; bool moved = false; if (_current == null) { _current = _firstPacket; + moveToMode = SnapshotStatus.ReplayStarting; moved = true; } else if (_current.NextPacket != null) @@ -2913,10 +3003,8 @@ internal bool MoveNext() if (moved) { - _stateObj._inBuff = _current.Buffer; - _stateObj._inBytesUsed = 0; - _stateObj._inBytesRead = _current.Read; - _stateObj._snapshotReplay = true; + _stateObj.SetBuffer(_current.Buffer, 0, _current.Read); + _stateObj._snapshotStatus = moveToMode; retval = true; } @@ -2932,6 +3020,22 @@ internal void MoveToStart() _stateObj.AssertValidState(); } + internal bool MoveToContinue() + { + if (ContinueEnabled) + { + if (_continuePacket != null && _continuePacket != _current) + { + _continueStateData.Restore(_stateObj); + _stateObj.SetBuffer(_current.Buffer, 0, _current.Read); + _stateObj._snapshotStatus = SnapshotStatus.ReplayRunning; + _stateObj.AssertValidState(); + return true; + } + } + return false; + } + internal void CaptureAsStart(TdsParserStateObject stateObj) { _firstPacket = null; @@ -2953,6 +3057,50 @@ internal void CaptureAsStart(TdsParserStateObject stateObj) AppendPacketData(stateObj._inBuff, stateObj._inBytesRead); } + internal void CaptureAsContinue(TdsParserStateObject stateObj) + { + if (ContinueEnabled) + { + Debug.Assert(_stateObj == stateObj); + if (_current is not null) + { + _continueStateData ??= new StateObjectData(); + _continueStateData.Capture(stateObj, trackStack: false); + _continuePacket = _current; + } + } + } + + internal void SetPacketPayloadSize(int size) + { + if (_current == null) + { + throw new InvalidOperationException(); + } + int total = 0; + if (_current.PrevPacket != null) + { + total = _current.PrevPacket.TotalSize; + } + _current.TotalSize = total + size; + } + internal int GetPacketDataTotalSize() + { + if (_current == null) + { + throw new InvalidOperationException(); + } + return _current.GetPacketTotalSize(); + } + internal int GetPacketDataOffset() + { + if (_current == null) + { + throw new InvalidOperationException(); + } + return _current.GetPacketDataOffset(); + } + internal void Clear() { ClearState(); @@ -2964,6 +3112,7 @@ private void ClearPackets() PacketData packet = _firstPacket; _firstPacket = null; _lastPacket = null; + _continuePacket = null; _current = null; packet.Clear(); _sparePacket = packet; @@ -2972,10 +3121,12 @@ private void ClearPackets() private void ClearState() { _replayStateData.Clear(_stateObj); + _continueStateData?.Clear(_stateObj, trackStack: false); #if DEBUG _rollingPend = 0; _rollingPendCount = 0; _stateObj._lastStack = null; + _packetCounter = 0; #endif _stateObj = null; } From c1af5b5aea7b7b80d43e5dbcc2c1c0f0d8d51891 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Tue, 29 Oct 2024 02:33:10 +0000 Subject: [PATCH 02/17] integrated multiplexer tests and align multiplexer with dev version --- .../SqlClient/TdsParserStateObject.netcore.cs | 3 +- .../SqlClient/TdsParserStateObject.netfx.cs | 2 +- .../TdsParserStateObject.Multiplexer.cs | 332 ++++++++---- .../Data/SqlClient/TdsParserStateObject.cs | 41 +- .../Microsoft.Data.SqlClient.Tests.csproj | 4 + .../tests/FunctionalTests/MultiplexerTests.cs | 510 ++++++++++++++++++ .../TdsParserStateObject.TestHarness.cs | 275 ++++++++++ 7 files changed, 1029 insertions(+), 138 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs create mode 100644 src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs index f9ca2975fc..690f9d884d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs @@ -393,8 +393,9 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) AssertValidState(); } - public void ProcessSniPacket(PacketHandle packet, uint error, bool usePartialPacket = false) + private uint GetSniPacket(PacketHandle packet, ref uint dataSize) { + return SNIPacketGetData(packet, _inBuff, ref dataSize); if (error != 0) { if ((_parser.State == TdsParserState.Closed) || (_parser.State == TdsParserState.Broken)) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs index 42c2908856..d7b6db7fc0 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs @@ -524,7 +524,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) AssertValidState(); } - public void ProcessSniPacket(PacketHandle packet, uint error, bool usePartialPacket = false) + private uint GetSniPacket(PacketHandle packet, ref uint dataSize) { if (error != 0) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs index fb610b607e..8c3f587ca6 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs @@ -3,10 +3,170 @@ namespace Microsoft.Data.SqlClient { +#if NETFRAMEWORK + using PacketHandle = IntPtr; +#endif partial class TdsParserStateObject { private Packet __partialPacket; - private Packet _partialPacket => __partialPacket; + internal Packet _partialPacket => __partialPacket; + + public void ProcessSniPacket(PacketHandle packet, uint error, bool usePartialPacket = false) + { + if (error != 0) + { + if ((_parser.State == TdsParserState.Closed) || (_parser.State == TdsParserState.Broken)) + { + // Do nothing with callback if closed or broken and error not 0 - callback can occur + // after connection has been closed. PROBLEM IN NETLIB - DESIGN FLAW. + return; + } + + AddError(_parser.ProcessSNIError(this)); + AssertValidState(); + } + else + { + uint dataSize = 0; + bool usedPartialPacket = false; + uint getDataError = 0; + + if (usePartialPacket && _snapshot == null && _partialPacket != null && _partialPacket.IsComplete) + { + //Debug.Assert(_snapshot == null, "_snapshot must be null when processing partial packet instead of network read"); + //Debug.Assert(_partialPacket != null, "_partialPacket must not be null when usePartialPacket is true"); + //Debug.Assert(_partialPacket.IsComplete, "_partialPacket.IsComplete must be true to use it in place of a real read"); + SetBuffer(_partialPacket.Buffer, 0, _partialPacket.CurrentLength); + ClearPartialPacket(); + getDataError = TdsEnums.SNI_SUCCESS; + usedPartialPacket = true; + } + else + { + getDataError = GetSniPacket(packet, ref dataSize); + } + + if (getDataError == TdsEnums.SNI_SUCCESS) + { + if (_inBuff.Length < dataSize) + { + Debug.Assert(true, "Unexpected dataSize on Read"); + throw SQL.InvalidInternalPacketSize(StringsHelper.GetString(Strings.SqlMisc_InvalidArraySizeMessage)); + } + + if (!usedPartialPacket) + { + _lastSuccessfulIOTimer._value = DateTime.UtcNow.Ticks; + + SetBuffer(_inBuff, 0, (int)dataSize); + } + + bool recurse; + bool appended = false; + do + { + MultiplexPackets( + _inBuff, _inBytesUsed, _inBytesRead, + _partialPacket, + out int newDataOffset, + out int newDataLength, + out Packet remainderPacket, + out bool consumeInputDirectly, + out bool consumePartialPacket, + out bool remainderPacketProduced, + out recurse + ); + bool bufferIsPartialCompleted = false; + + // if a partial packet was reconstructed it must be handled first + if (consumePartialPacket) + { + if (_snapshot != null) + { + _snapshot.AppendPacketData(_partialPacket.Buffer, _partialPacket.CurrentLength); + SetBuffer(new byte[_inBuff.Length], 0, 0); + appended = true; + } + else + { + SetBuffer(_partialPacket.Buffer, 0, _partialPacket.CurrentLength); + + } + bufferIsPartialCompleted = true; + ClearPartialPacket(); + } + + // if the remaining data can be processed directly it must be second + if (consumeInputDirectly) + { + // if some data was taken from the new packet adjust the counters + if (dataSize != newDataLength || 0 != newDataOffset) + { + SetBuffer(_inBuff, newDataOffset, newDataLength); + } + + if (_snapshot != null) + { + _snapshot.AppendPacketData(_inBuff, _inBytesRead); + SetBuffer(new byte[_inBuff.Length], 0, 0); + appended = true; + } + else + { + SetBuffer(_inBuff, 0, _inBytesRead); + } + bufferIsPartialCompleted = true; + } + else + { + // whatever is in the input buffer should not be directly consumed + // and is contained in the partial or remainder packets so make sure + // we don't process it + if (!bufferIsPartialCompleted) + { + SetBuffer(_inBuff, 0, 0); + } + } + + // if there is a remainder it must be last + if (remainderPacketProduced) + { + SetPartialPacket(remainderPacket); + if (!bufferIsPartialCompleted) + { + // we are keeping the partial packet buffer so replace it with a new one + // unless we have already set the buffer to the partial packet buffer + SetBuffer(new byte[_inBuff.Length], 0, 0); + } + } + + } while (recurse && _snapshot != null); + + if (_snapshot != null) + { + if (_snapshotStatus != SnapshotStatus.NotActive && appended) + { + _snapshot.MoveNext(); +#if DEBUG + // multiple packets can be appended by demuxing but we should only move + // forward by a single packet so we can no longer assert that we are on + // the last packet at this time + //_snapshot.AssertCurrent(); +#endif + } + } + + SniReadStatisticsAndTracing(); + SqlClientEventSource.Log.TryAdvancedTraceBinEvent("TdsParser.ReadNetworkPacketAsyncCallback | INFO | ADV | State Object Id {0}, Packet read. In Buffer {1}, In Bytes Read: {2}", ObjectID, _inBuff, (ushort)_inBytesRead); + + AssertValidState(); + } + else + { + throw SQL.ParsingError(ParsingErrorState.ProcessSniPacketFailed); + } + } + } private void SetPartialPacket(Packet packet/*, [CallerMemberName] string caller = null*/) { @@ -23,14 +183,11 @@ private void ClearPartialPacket(/*[CallerMemberName] string caller = null*/) __partialPacket = null; if (partialPacket != null) { - ReadOnlySpan header = partialPacket.GetHeaderSpan(); - int packetId = Packet.GetIDFromHeader(header); - bool isEOM = Packet.GetIsEOMFromHeader(header); partialPacket.Dispose(); } } - public static void MultiplexPackets( + private static void MultiplexPackets( byte[] dataBuffer, int dataOffset, int dataLength, Packet partialPacket, out int newDataOffset, @@ -63,20 +220,18 @@ out bool recurse if (headeBytesNeeded > 0) { int headerBytesAvailable = Math.Min(data.Length, headeBytesNeeded); + Span headerTarget = partialPacket.Buffer.AsSpan(partialPacket.CurrentLength, headerBytesAvailable); ReadOnlySpan headerSource = data.Slice(0, headerBytesAvailable); headerSource.CopyTo(headerTarget); + partialPacket.CurrentLength = partialPacket.CurrentLength + headerBytesAvailable; - data = data.Slice(headerBytesAvailable); bytesConsumed += headerBytesAvailable; + data = data.Slice(headerBytesAvailable); } if (partialPacket.HasHeader) { partialPacket.DataLength = Packet.GetDataLengthFromHeader(partialPacket); - //if (partialPacket.DataLength > dataBuffer.Length) - //{ - // Debugger.Break(); - //} } } @@ -86,43 +241,32 @@ out bool recurse { // the packet length is known so take as much data as possible from the incoming // data to try and complete the packet + int payloadBytesNeeded = partialPacket.DataLength - (partialPacket.CurrentLength - TdsEnums.HEADER_LEN); int payloadBytesAvailable = Math.Min(data.Length, payloadBytesNeeded); - Span payloadTarget = partialPacket.Buffer.AsSpan(partialPacket.CurrentLength, payloadBytesAvailable); + ReadOnlySpan payloadSource = data.Slice(0, payloadBytesAvailable); + Span payloadTarget = partialPacket.Buffer.AsSpan(partialPacket.CurrentLength, payloadBytesAvailable); payloadSource.CopyTo(payloadTarget); + partialPacket.CurrentLength = partialPacket.CurrentLength + payloadBytesAvailable; bytesConsumed += payloadBytesAvailable; data = data.Slice(payloadBytesAvailable); } else if (partialPacket.CurrentLength > partialPacket.RequiredLength) { - // the packet contains an entire packet and more data after that so we need - // to extract the following data into a new packet with a new buffer and return - // it as the remainer packet + // the partial packet contains a complete packet of data and then and also contains + // data from a following packet - int remainderLength = partialPacket.CurrentLength - partialPacket.RequiredLength; - remainderPacket = new Packet - { - Buffer = new byte[dataBuffer.Length], - CurrentLength = remainderLength, - }; - Buffer.BlockCopy( - partialPacket.Buffer, partialPacket.RequiredLength, // from - remainderPacket.Buffer, 0, // to - remainderPacket.CurrentLength // for - ); - partialPacket.CurrentLength = partialPacket.CurrentLength - remainderPacket.CurrentLength; - consumeRemainderPacket = true; + // the TDS spec requires that all packets be of the defined packet size apart from + // the last packet of a response. This means that is is not possible to have more than + // 2 packet fragments in a single buffer like this: + // - first packet caused the partial + // - second packet is the one we have just unpacked + // - third packet is the extra data we have found - if (remainderPacket.HasHeader) - { - remainderPacket.DataLength = Packet.GetDataLengthFromHeader(remainderPacket); - if (remainderPacket.HasDataLength && remainderPacket.CurrentLength >= remainderPacket.RequiredLength) - { - recurse = true; - } - } + // we must throw an exception because we have encountered an invalid tds stream + throw SQL.ParsingError(ParsingErrorState.CorruptedTdsStream); } if (partialPacket.CurrentLength == partialPacket.RequiredLength) @@ -136,12 +280,6 @@ out bool recurse { if (data.Length > 0) { - //if (data[0] == 120) - //{ - // var d = Vizualize(dataBuffer, dataOffset, dataLength); - // Debugger.Break(); - //} - // some data has been taken from the buffer, put into the partial // packet buffer and we have data left so move the data we have // left to the start of the buffer so we can pass the buffer back @@ -151,19 +289,16 @@ out bool recurse dataBuffer, dataOffset, // to dataLength - bytesConsumed // for ); - - //// for debugging purposes fill the removed data area with an easily - //// recognisable pattern so we can see if it is misused - //Span removed = dataBuffer.AsSpan(dataOffset + (dataLength - bytesConsumed), (dataOffset + bytesConsumed)); - //removed.Fill(0xFF); +#if DEBUG + // for debugging purposes fill the removed data area with an easily + // recognisable pattern so we can see if it is misused + Span removed = dataBuffer.AsSpan(dataOffset + (dataLength - bytesConsumed), (dataOffset + bytesConsumed)); + removed.Fill(0xFF); +#endif // then recreate the data span so that we're looking at the data // that has been moved data = dataBuffer.AsSpan(dataOffset, dataLength - bytesConsumed); - //if (data[0] == 120) - //{ - // Debugger.Break(); - //} } newDataLength = dataLength - bytesConsumed; @@ -177,10 +312,7 @@ out bool recurse // we have enough bytes to read the packet header and see how // much data we are expecting it to contain int packetDataLength = Packet.GetDataLengthFromHeader(data); - //if (packetDataLength > dataBuffer.Length) - //{ - // Debugger.Break(); - //} + if (data.Length == TdsEnums.HEADER_LEN + packetDataLength) { if (!consumePartialPacket) @@ -224,74 +356,64 @@ out bool recurse CurrentLength = data.Length }; consumeRemainderPacket = true; - Debug.Assert(remainderPacket.HasHeader); // precondition of entering this block - if (remainderPacket.HasDataLength && remainderPacket.CurrentLength >= remainderPacket.RequiredLength) - { - // the remainder packet contains more data than the packet so we need - // to tell the caller to recurse into this function again once they have - // consumed the first packet - recurse = true; - } } else // implied: current length > required length { - // more data than required so need to split it out but we can't do that - // here so we need to tell the caller to take the remainer packet and then - // call this function again + //// more data than required so need to split it out but we can't do that + //// here so we need to tell the caller to take the remainer packet and then + //// call this function again + + int remainderLength = data.Length - (TdsEnums.HEADER_LEN + packetDataLength); remainderPacket = new Packet { - Buffer = dataBuffer, - DataLength = packetDataLength, - CurrentLength = data.Length + Buffer = new byte[dataBuffer.Length], + CurrentLength = remainderLength, }; - consumeRemainderPacket = true; - recurse = true; - } - } - else - { - // we don't have enough information to read the header - if (!consumePartialPacket) - { - // we can tell the caller that they should directly consume the data in - // the input buffer, this is the happy path + + ReadOnlySpan remainderSource = data.Slice(TdsEnums.HEADER_LEN + packetDataLength); + Span remainderTarget = remainderPacket.Buffer.AsSpan(0, remainderLength); + remainderSource.CopyTo(remainderTarget); + + newDataLength = TdsEnums.HEADER_LEN + packetDataLength; consumeInputDirectly = true; - } - else - { - // we took some data from the input to reconstruct the partial packet - // so we can't tell the caller to directly consume the packet in the - // input buffer, we need to construct a new remainder packet and then - // tell them to consume it - remainderPacket = new Packet - { - Buffer = dataBuffer, - CurrentLength = data.Length - }; consumeRemainderPacket = true; + if (remainderPacket.HasHeader) { - remainderPacket.DataLength = Packet.GetDataLengthFromHeader(remainderPacket.GetHeaderSpan()); - } - if (remainderPacket.HasDataLength && remainderPacket.CurrentLength >= remainderPacket.RequiredLength) - { - // the remainder packet contains more data than the packet so we need - // to tell the caller to recurse into this function again once they have - // consumed the first packet - recurse = true; + remainderPacket.DataLength = Packet.GetDataLengthFromHeader(remainderPacket); + if (remainderPacket.HasDataLength && remainderPacket.CurrentLength >= remainderPacket.RequiredLength) + { + recurse = true; + } } } } + else + { + // we took some data from the input to reconstruct the partial packet + // so we can't tell the caller to directly consume the packet in the + // input buffer, we need to construct a new remainder packet and then + // tell them to consume it + remainderPacket = new Packet + { + Buffer = dataBuffer, + CurrentLength = data.Length + }; + consumeRemainderPacket = true; + } } - - if (remainderPacket != null && remainderPacket.HasHeader) - { - remainderPacket.Buffer[7] = 0xF; - } +#if DEBUG + //// the Window field is unused by the spec so it can be used as a marker + //// to identify reconstructed packets while debugging + //if (remainderPacket != null && remainderPacket.HasHeader) + //{ + // remainderPacket.Buffer[7] = 0xF; + //} +#endif if (consumePartialPacket && consumeInputDirectly) { - throw new InvalidOperationException($"AppendData cannot return both {nameof(consumePartialPacket)} and {nameof(consumeInputDirectly)}"); + throw new InvalidOperationException($"MultiplexPackets cannot return both {nameof(consumePartialPacket)} and {nameof(consumeInputDirectly)}"); } } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 004d78731f..49d0c45f96 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -2084,7 +2084,7 @@ internal TdsOperationStatus TryReadNetworkPacket() internal void PrepareReplaySnapshot() { _networkPacketTaskSource = null; - //if (!_snapshot.MoveToContinue()) + if (!_snapshot.MoveToContinue()) { _snapshot.MoveToStart(); } @@ -2651,27 +2651,8 @@ private sealed partial class PacketData public PacketData NextPacket; public PacketData PrevPacket; - public int DataOffset; - public int DataLength; - public int TotalSize; - internal int GetPacketTotalSize() - { - if (TotalSize == 0) - { - int previous = 0; - if (PrevPacket != null) - { - previous = PrevPacket.TotalSize; - } - return previous; - } - else - { - return TotalSize; - } - } internal int GetPacketDataOffset() { int previous = 0; @@ -2692,8 +2673,6 @@ internal void Clear() PrevPacket.NextPacket = null; PrevPacket = null; } - DataLength = 0; - DataOffset = 0; SetDebugStackInternal(null); SetDebugPacketIdInternal(0); } @@ -2752,6 +2731,14 @@ public PacketData[] Items public int PacketId; public string Stack; + public int PacketID => Packet.GetIDFromHeader(Buffer.AsSpan(0, TdsEnums.HEADER_LEN)); + + public int SPID => Packet.GetSpidFromHeader(Buffer.AsSpan(0, TdsEnums.HEADER_LEN)); + + public bool IsEOM => Packet.GetIsEOMFromHeader(Buffer.AsSpan(0, TdsEnums.HEADER_LEN)); + + public int DataLength => Packet.GetDataLengthFromHeader(Buffer.AsSpan(0, TdsEnums.HEADER_LEN)); + partial void SetDebugStackInternal(string value) => Stack = value; partial void SetDebugPacketIdInternal(int value) => PacketId = value; @@ -2924,7 +2911,6 @@ public bool ContinueEnabled if (_continueSupported == null) { _continueSupported = AppContext.TryGetSwitch("Switch.Microsoft.Data.SqlClient.UseExperimentalAsyncContinue", out bool value) ? value : false; - //_continueSupported = false; } return _continueSupported.Value; } @@ -3084,14 +3070,7 @@ internal void SetPacketPayloadSize(int size) } _current.TotalSize = total + size; } - internal int GetPacketDataTotalSize() - { - if (_current == null) - { - throw new InvalidOperationException(); - } - return _current.GetPacketTotalSize(); - } + internal int GetPacketDataOffset() { if (_current == null) diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj index faa85ab5ce..5c8697cef8 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj @@ -31,6 +31,7 @@ + @@ -67,10 +68,12 @@ + + @@ -97,6 +100,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs new file mode 100644 index 0000000000..f42d95a38d --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs @@ -0,0 +1,510 @@ +using System; +using System.Buffers.Binary; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Data.SqlClient.Tests +{ + public class MultiplexerTests + { + public static IEnumerable IsAsync() { yield return new object[] { false }; yield return new object[] { true }; } + + [Theory, MemberData(nameof(IsAsync))] + public static void PassThroughSinglePacket(bool isAsync) + { + int dataSize = 20; + var a = CreatePacket(dataSize, 0xF); + List input = new List { a }; + List expected = new List { a }; + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [Theory, MemberData(nameof(IsAsync))] + public static void PassThroughMultiplePacket(bool isAsync) + { + int dataSize = 40; + List input = CreatePackets(dataSize, 5, 6, 7, 8); + List expected = input; + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [Theory, MemberData(nameof(IsAsync))] + public static void PassThroughMultiplePacketWithShortEnd(bool isAsync) + { + int dataSize = 40; + List input = CreatePackets((dataSize, 20), 5, 6, 7, 8); + List expected = input; + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [Theory, MemberData(nameof(IsAsync))] + public static void ReconstructSinglePacket(bool isAsync) + { + int dataSize = 4; + var a = CreatePacket(dataSize, 0xF); + List input = SplitPacket(a, 1); + List expected = new List { a }; + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [Theory, MemberData(nameof(IsAsync))] + public static void Reconstruct2Packets_Part_PartFull(bool isAsync) + { + int dataSize = 4; + var expected = CreatePackets(dataSize, 0xAA, 0xBB); + + var input = SplitPackets(dataSize, expected, + 6, // partial first packet + (6 + 6), // end of packet 0, start of packet 1 + 6 // end of packet 1 + ); + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [Theory, MemberData(nameof(IsAsync))] + public static void Reconstruct2Packets_Full_FullPart_Part(bool isAsync) + { + int dataSize = 30; + var expected = new List + { + CreatePacket(30, 5), + CreatePacket(10, 6), + CreatePacket(30, 7) + }; + + var input = SplitPackets(38, expected, + (8 + 30), // full + (8 + 10) + (8 + 12), // full, part next + 18 // part end + ); + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [Theory, MemberData(nameof(IsAsync))] + public static void ReconstructMultiplePacketSequence(bool isAsync) + { + int dataSize = 40; + List expected = CreatePackets(dataSize, 5, 6, 7, 8); + List input = SplitPackets(dataSize, expected, + (8 + 40), + (8 + 23), + (17) + (8 + 23), + (17) + (8 + 23), + (17) + ); + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [Theory, MemberData(nameof(IsAsync))] + public static void ReconstructMultiplePacketSequenceWithShortEnd(bool isAsync) + { + int dataSize = 40; + List expected = CreatePackets((dataSize, 20), 5, 6, 7, 8); + List input = SplitPackets(dataSize, expected, + (8 + 40), + (8 + 23), + (17) + (8 + 23), + (17) + (8 + 20) + ); + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + + [Theory, MemberData(nameof(IsAsync))] + public static void FailReconstruct2Packets_FullFullPart_Part(bool isAsync) + { + // illegal, cannot have multiple packets end in a single packet because all packets except an end packet must + // be be of max length, thus only max length packets can exist before a short packet. + int maxDataSize = 46; + + var expected = new List + { + CreatePacket(10, 5), + CreatePacket(10, 6), + CreatePacket(30, 7) + }; + + var input = SplitPackets(maxDataSize, expected, + (8 + 10) + (8 + 10) + (8 + 2), // full, full, part + 36 // part + ); + + Assert.Throws( + () => MultiplexPacketList(isAsync, maxDataSize, input) + ); + + } + + + private static List MultiplexPacketList(bool isAsync, int dataSize, List input) + { + var stateObject = new TdsParserStateObject(input, TdsEnums.HEADER_LEN + dataSize, isAsync); + var output = new List(); + + for (int index = 0; index < input.Count; index++) + { + stateObject.Current = input[index]; + + stateObject.ProcessSniPacket(default, 0, usePartialPacket: false); + + if (stateObject._inBytesRead > 0) + { + if ( + stateObject._inBytesRead < TdsEnums.HEADER_LEN + || + stateObject._inBytesRead != (TdsEnums.HEADER_LEN + Packet.GetDataLengthFromHeader(stateObject._inBuff.AsSpan(0, TdsEnums.HEADER_LEN))) + ) + { + Assert.Fail("incomplete packet exposed after call to ProcessSniPacket"); + } + if (!isAsync) + { + output.Add(PacketData.Copy(stateObject._inBuff, stateObject._inBytesUsed, stateObject._inBytesRead)); + } + } + } + + + if (!isAsync) + { + if (stateObject._partialPacket != null) + { + stateObject.Current = default; + + stateObject.ProcessSniPacket(default, 0, usePartialPacket: true); + + if (stateObject._inBytesRead > 0) + { + if ( + stateObject._inBytesRead < TdsEnums.HEADER_LEN + || + stateObject._inBytesRead != (TdsEnums.HEADER_LEN + Packet.GetDataLengthFromHeader(stateObject._inBuff.AsSpan(0, TdsEnums.HEADER_LEN))) + ) + { + Assert.Fail("incomplete packet exposed after call to ProcessSniPacket with usePartialPacket"); + } + if (!isAsync) + { + output.Add(PacketData.Copy(stateObject._inBuff, stateObject._inBytesUsed, stateObject._inBytesRead)); + } + } + } + + } + else + { + output = stateObject._snapshot.List; + } + + return output; + } + + private static void ComparePacketLists(int dataSize, List expected, List output) + { + Assert.NotNull(expected); + Assert.NotNull(output); + Assert.Equal(expected.Count, output.Count); + + for (int index = 0; index < expected.Count; index++) + { + var a = expected[index]; + var b = output[index]; + + var compare = a.AsSpan().SequenceCompareTo(b.AsSpan()); + + if (compare != 0) + { + Assert.Fail($"expected data does not match output data at packet index {index}"); + } + } + } + + public static PacketData CreatePacket(int dataSize, byte dataValue, int startOffset = 0, int endPadding = 0) + { + byte[] buffer = new byte[startOffset + TdsEnums.HEADER_LEN + dataSize + endPadding]; + Span packet = buffer.AsSpan(startOffset, TdsEnums.HEADER_LEN + dataSize); + WritePacket(packet, dataSize, dataValue, 1); + return new PacketData(buffer, startOffset, buffer.Length - endPadding); + } + + public static List CreatePackets(DataSize sizes, params byte[] dataValues) + { + int count = dataValues.Length; + List list = new List(count); + + for (byte index = 0; index < count; index++) + { + int dataSize = sizes.GetSize(index == dataValues.Length - 1); + int packetSize = TdsEnums.HEADER_LEN + dataSize; + byte[] array = new byte[packetSize]; + WritePacket(array, dataSize, dataValues[index], index); + list.Add(new PacketData(array, 0, packetSize)); + } + + return list; + } + + private static void WritePacket(Span buffer, int dataSize, byte dataValue, byte id) + { + Span header = buffer.Slice(0, TdsEnums.HEADER_LEN); + header[0] = 4; // Type, 4 - Raw Data + header[1] = 0; // Status, 0 - normal message + BinaryPrimitives.TryWriteInt16BigEndian(header.Slice(TdsEnums.HEADER_LEN_FIELD_OFFSET, 2), (short)(TdsEnums.HEADER_LEN + dataSize)); // total length + BinaryPrimitives.TryWriteInt16BigEndian(header.Slice(TdsEnums.SPID_OFFSET, 2), short.MaxValue); // SPID + header[TdsEnums.HEADER_LEN_FIELD_OFFSET + 4] = id; // PacketID + header[TdsEnums.HEADER_LEN_FIELD_OFFSET + 5] = 0; // Window + + Span data = buffer.Slice(TdsEnums.HEADER_LEN, dataSize); + data.Fill(dataValue); + } + + public static List SplitPacket(PacketData packet, int length) + { + List list = new List(2); + while (packet.Length > length) + { + list.Add(new PacketData(packet.Array, packet.Start, length)); + packet = new PacketData(packet.Array, packet.Start + length, packet.Length - length); + } + if (packet.Length > 0) + { + list.Add(packet); + } + return list; + } + + public static List SplitPackets(int dataSize, List packets, params int[] lengths) + { + List list = new List(lengths.Length); + int packetSize = TdsEnums.HEADER_LEN + dataSize; + byte[][] arrays = new byte[lengths.Length][]; + for (int index = 0; index < lengths.Length; index++) + { + if (lengths[index] > packetSize) + { + throw new ArgumentOutOfRangeException($"segment size of an individual part cannot exceed the packet buffer size of the state object, max packet size: {packetSize}, supplied length: {lengths[index]}, at index: {index}"); + } + arrays[index] = new byte[lengths[index]]; + } + + int targetOffset = 0; + int targetIndex = 0; + + int sourceOffset = 0; + int sourceIndex = 0; + + + do + { + Span targetSpan = Span.Empty; + if (targetOffset < arrays[targetIndex].Length) + { + targetSpan = arrays[targetIndex].AsSpan(targetOffset); + } + else + { + targetIndex += 1; + targetOffset = 0; + continue; + } + + Span sourceSpan = Span.Empty; + if (sourceOffset < packets[sourceIndex].Length) + { + sourceSpan = packets[sourceIndex].AsSpan(sourceOffset); + } + else + { + sourceIndex += 1; + sourceOffset = 0; + continue; + } + + int copy = Math.Min(targetSpan.Length, sourceSpan.Length); + if (copy > 0) + { + targetOffset += copy; + sourceOffset += copy; + sourceSpan.Slice(0, copy).CopyTo(targetSpan.Slice(0, copy)); + } + + + } while (sourceIndex < packets.Count && targetIndex < arrays.Length); + + foreach (var array in arrays) + { + list.Add(new PacketData(array, 0, array.Length)); + } + + return list; + } + + + public static int PacketSizeFromDataSize(int dataSize) => TdsEnums.HEADER_LEN + dataSize; + + public static int DataSizeFromPacketSize(int packetSize) => packetSize - TdsEnums.HEADER_LEN; + + public static int SumPacketLengths(List list) + { + int total = 0; + for (int index = 0; index < list.Count; index++) + { + total += list[index].Length; + } + return total; + } + } + + + [DebuggerDisplay("{ToDebugString(),nq}")] + public readonly struct PacketData + { + public readonly byte[] Array; + public readonly int Start; + public readonly int Length; + + public PacketData(byte[] array, int start, int length) + { + Array = array; + Start = start; + Length = length; + } + + public Span AsSpan() + { + return Array == null ? Span.Empty : Array.AsSpan(Start, Length); + } + + public Span AsSpan(int start) + { + Span span = AsSpan(); + return span.Slice(start); + } + + public static PacketData Copy(byte[] array, int start, int length) + { + byte[] newArray = null; + if (array != null) + { + newArray = new byte[array.Length]; + Buffer.BlockCopy(array, start, newArray, start, length); + } + return new PacketData(newArray, start, length); + } + + [ExcludeFromCodeCoverage] + public string ToDebugString() + { + StringBuilder buffer = new StringBuilder(128); + buffer.Append(Length); + + if (Array != null && Array.Length > 0) + { + if (Array.Length != Length) + { + buffer.AppendFormat(" (arr: {0})", Array.Length); + } + buffer.Append(": {"); + buffer.AppendFormat("{0:D2}", Array[0]); + + int max = Math.Min(32, Array.Length); + for (int index = 1; index < max; index++) + { + buffer.Append(','); + buffer.AppendFormat("{0:D2}", Array[index]); + } + if (Length > max) + { + buffer.Append(" ..."); + } + buffer.Append('}'); + } + return buffer.ToString(); + } + + } + + [DebuggerStepThrough] + public struct DataSize + { + public DataSize(int commonSize) + { + CommonSize = commonSize; + LastSize = commonSize; + } + public DataSize(int commonSize, int lastSize) + { + CommonSize = commonSize; + LastSize = lastSize; + } + + public int LastSize { get; set; } + public int CommonSize { get; set; } + + public int GetSize(bool isLast) + { + if (isLast) + { + return LastSize; + } + else + { + return CommonSize; + } + } + + public static implicit operator DataSize(int commonSize) + { + return new DataSize( commonSize, commonSize ); + } + + public static implicit operator DataSize((int commonSize, int lastSize) values) + { + return new DataSize( values.commonSize, values.lastSize ); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs new file mode 100644 index 0000000000..77d9319b00 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs @@ -0,0 +1,275 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using Microsoft.Data.SqlClient.Tests; + +namespace Microsoft.Data.SqlClient +{ +#if NETFRAMEWORK + using PacketHandle = IntPtr; +#elif NETCOREAPP + internal struct PacketHandle + { + } +#endif + internal partial class TdsParserStateObject + { + internal int ObjectID = 1; + + internal class SQL + { + internal static Exception InvalidInternalPacketSize(string v) => throw new Exception(v ?? nameof(InvalidInternalPacketSize)); + + internal static Exception ParsingError(ParsingErrorState state) => throw new Exception(state.ToString()); + } + + internal static class SqlClientEventSource + { + internal static class Log + { + internal static void TryAdvancedTraceBinEvent(string message, params object[] values) + { + } + } + } + + private enum SnapshotStatus + { + NotActive, + ReplayStarting, + ReplayRunning + } + + internal enum TdsParserState + { + Closed, + OpenNotLoggedIn, + OpenLoggedIn, + Broken, + } + + private uint GetSniPacket(PacketHandle packet, ref uint dataSize) + { + return SNIPacketGetData(packet, _inBuff, ref dataSize); + } + + private class StringsHelper + { + internal static string GetString(string sqlMisc_InvalidArraySizeMessage) => Strings.SqlMisc_InvalidArraySizeMessage; + } + + internal class Strings + { + internal static string SqlMisc_InvalidArraySizeMessage = nameof(SqlMisc_InvalidArraySizeMessage); + + } + + public class Parser + { + internal object ProcessSNIError(TdsParserStateObject tdsParserStateObject) => "ProcessSNIError"; + public TdsParserState State = TdsParserState.OpenLoggedIn; + } + + sealed internal class LastIOTimer + { + internal long _value; + } + + internal sealed class Snapshot + { + public List List; + + public Snapshot() => List = new List(); + [DebuggerStepThrough] + internal void AppendPacketData(byte[] buffer, int read) => List.Add(new PacketData(buffer, 0, read)); + [DebuggerStepThrough] + internal void MoveNext() + { + + } + } + + public List Input; + public PacketData Current; + public bool IsAsync { get => _snapshot != null; } + + public int _packetSize; + + internal Snapshot _snapshot; + public int _inBytesRead; + public int _inBytesUsed; + public byte[] _inBuff; + [DebuggerStepThrough] + public TdsParserStateObject(List input, int packetSize, bool isAsync) + { + _packetSize = packetSize; + _inBuff = new byte[_packetSize]; + Input = input; + if (isAsync) + { + _snapshot = new Snapshot(); + } + } + [DebuggerStepThrough] + private uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref uint dataSize) + { + Span target = inBuff.AsSpan(0, _packetSize); + Span source = Current.Array.AsSpan(Current.Start, Current.Length); + source.CopyTo(target); + dataSize = (uint)Current.Length; + return TdsEnums.SNI_SUCCESS; + } + + [DebuggerStepThrough] + void SetBuffer(byte[] buffer, int inBytesUsed, int inBytesRead) + { + _inBuff = buffer; + _inBytesUsed = inBytesUsed; + _inBytesRead = inBytesRead; + } + + + + // stubs + private LastIOTimer _lastSuccessfulIOTimer = new LastIOTimer(); + private Parser _parser = new Parser(); + private SnapshotStatus _snapshotStatus = SnapshotStatus.NotActive; + + + [DebuggerStepThrough] + private void SniReadStatisticsAndTracing() { } + [DebuggerStepThrough] + private void AssertValidState() { } + [DebuggerStepThrough] + private void AddError(object value) => throw new Exception(value as string ?? "AddError"); + } + + internal static class TdsEnums + { + public const uint SNI_SUCCESS = 0; // The operation completed successfully. + // header constants + public const int HEADER_LEN = 8; + public const int HEADER_LEN_FIELD_OFFSET = 2; + public const int SPID_OFFSET = 4; + } + + internal enum ParsingErrorState + { + CorruptedTdsStream = 18, + ProcessSniPacketFailed = 19, + } + + internal sealed class Packet + { + public const int UnknownDataLength = -1; + + private bool _disposed; + private int _dataLength; + private int _totalLength; + private byte[] _buffer; + + public Packet() + { + _disposed = false; + _dataLength = UnknownDataLength; + } + + public int DataLength + { + get + { + CheckDisposed(); + return _dataLength; + } + set + { + CheckDisposed(); + _dataLength = value; + } + } + public byte[] Buffer + { + get + { + CheckDisposed(); + return _buffer; + } + set + { + CheckDisposed(); + _buffer = value; + } + } + public int CurrentLength + { + get + { + CheckDisposed(); + return _totalLength; + } + set + { + CheckDisposed(); + _totalLength = value; + } + } + + public int RequiredLength + { + get + { + CheckDisposed(); + if (!HasDataLength) + { + throw new InvalidOperationException($"cannot get {nameof(RequiredLength)} when {nameof(HasDataLength)} is false"); + } + return TdsEnums.HEADER_LEN + _dataLength; + } + } + + public bool HasHeader => _totalLength >= TdsEnums.HEADER_LEN; + + public bool HasDataLength => _dataLength >= 0; + + public bool IsComplete => _dataLength != UnknownDataLength && (TdsEnums.HEADER_LEN + _dataLength) == _totalLength; + + public ReadOnlySpan GetHeaderSpan() => _buffer.AsSpan(0, TdsEnums.HEADER_LEN); + + public void Dispose() + { + _disposed = true; + } + + public void CheckDisposed() + { + if (_disposed) + { + ThrowDisposed(); + } + } + + public static void ThrowDisposed() + { + throw new ObjectDisposedException(nameof(Packet)); + } + + internal static byte GetStatusFromHeader(ReadOnlySpan header) => header[1]; + + internal static int GetDataLengthFromHeader(ReadOnlySpan header) + { + return (header[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8 | header[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; + } + internal static int GetSpidFromHeader(ReadOnlySpan header) + { + return (header[TdsEnums.SPID_OFFSET] << 8 | header[TdsEnums.SPID_OFFSET + 1]); + } + internal static int GetIDFromHeader(ReadOnlySpan header) + { + return header[TdsEnums.HEADER_LEN_FIELD_OFFSET + 4]; + } + + internal static int GetDataLengthFromHeader(Packet packet) => GetDataLengthFromHeader(packet.GetHeaderSpan()); + + internal static bool GetIsEOMFromHeader(ReadOnlySpan header) => GetStatusFromHeader(header) == 1; + } +} From 58759e3e09175ef9afd14d9ea4e646573f7075f4 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Wed, 24 Jul 2024 21:41:01 +0100 Subject: [PATCH 03/17] address feedback, minor fixes and tuning --- .../Microsoft/Data/SqlClient/SqlDataReader.cs | 2 +- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 2 - .../SqlClient/TdsParserStateObject.netcore.cs | 2 +- .../Microsoft/Data/SqlClient/SqlDataReader.cs | 2 +- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 5 +- .../src/Microsoft/Data/SqlClient/Packet.cs | 10 +- .../TdsParserStateObject.Multiplexer.cs | 33 ++--- .../Data/SqlClient/TdsParserStateObject.cs | 130 +++--------------- .../tests/FunctionalTests/MultiplexerTests.cs | 9 -- .../TdsParserStateObject.TestHarness.cs | 9 +- 10 files changed, 47 insertions(+), 157 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index 85c4516f91..c3923ce497 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -4198,7 +4198,7 @@ private TdsOperationStatus TryResetBlobState() #if DEBUG else { - //Debug.Assert((_sharedState._columnDataBytesRemaining == 0 || _sharedState._columnDataBytesRemaining == -1) && _stateObj._longlen == 0, "Haven't read header yet, but column is partially read?"); + Debug.Assert((_sharedState._columnDataBytesRemaining == 0 || _sharedState._columnDataBytesRemaining == -1) && _stateObj._longlen == 0, "Haven't read header yet, but column is partially read?"); } #endif diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index 55e4a71788..4945c0f58a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -2412,7 +2412,6 @@ internal TdsOperationStatus TryRun(RunBehavior runBehavior, SqlCommand cmdHandle } } - // temporarily cache next byte byte peekedToken; result = stateObj.TryPeekByte(out peekedToken); if (result != TdsOperationStatus.Done) @@ -4164,7 +4163,6 @@ internal TdsOperationStatus TryProcessReturnValue(int length, TdsParserStateObje return result; } - // Length of parameter name byte len; result = stateObj.TryReadByte(out len); if (result != TdsOperationStatus.Done) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs index 690f9d884d..07f10619fc 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs @@ -1731,7 +1731,7 @@ internal void AssertStateIsClean() if ((parser != null) && (parser.State != TdsParserState.Closed) && (parser.State != TdsParserState.Broken)) { // Async reads - Debug.Assert(_snapshot == null && _snapshotStatus == SnapshotStatus.NotActive); + Debug.Assert(_snapshot == null && _snapshotStatus == SnapshotStatus.NotActive, "StateObj has leftover snapshot state"); Debug.Assert(!_asyncReadWithoutSnapshot, "StateObj has AsyncReadWithoutSnapshot still enabled"); Debug.Assert(_executionContext == null, "StateObj has a stored execution context from an async read"); // Async writes diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs index 5a8254aa64..e8f4938964 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -4745,7 +4745,7 @@ private TdsOperationStatus TryResetBlobState() #if DEBUG else { - //Debug.Assert((_sharedState._columnDataBytesRemaining == 0 || _sharedState._columnDataBytesRemaining == -1) && _stateObj._longlen == 0, "Haven't read header yet, but column is partially read?"); + Debug.Assert((_sharedState._columnDataBytesRemaining == 0 || _sharedState._columnDataBytesRemaining == -1) && _stateObj._longlen == 0, "Haven't read header yet, but column is partially read?"); } #endif diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index c5500a4da3..a6853f40c9 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -2761,7 +2761,6 @@ internal TdsOperationStatus TryRun(RunBehavior runBehavior, SqlCommand cmdHandle result = stateObj.TryPeekByte(out peekedToken); if (result != TdsOperationStatus.Done) { - // temporarily cache next byte return result; } @@ -4614,12 +4613,14 @@ internal TdsOperationStatus TryProcessReturnValue(int length, return result; } } - byte len; // Length of parameter name + + byte len; result = stateObj.TryReadByte(out len); if (result != TdsOperationStatus.Done) { return result; } + rec.parameter = null; if (len > 0) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs index 5b0463008f..8d02163b10 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs @@ -1,4 +1,8 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; namespace Microsoft.Data.SqlClient { @@ -27,10 +31,6 @@ public int DataLength set { CheckDisposed(); - //if (value > 7992) - //{ - // Debugger.Break(); - //} _dataLength = value; } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs index 8c3f587ca6..55d6323247 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs @@ -1,4 +1,8 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Diagnostics; namespace Microsoft.Data.SqlClient @@ -33,9 +37,6 @@ public void ProcessSniPacket(PacketHandle packet, uint error, bool usePartialPac if (usePartialPacket && _snapshot == null && _partialPacket != null && _partialPacket.IsComplete) { - //Debug.Assert(_snapshot == null, "_snapshot must be null when processing partial packet instead of network read"); - //Debug.Assert(_partialPacket != null, "_partialPacket must not be null when usePartialPacket is true"); - //Debug.Assert(_partialPacket.IsComplete, "_partialPacket.IsComplete must be true to use it in place of a real read"); SetBuffer(_partialPacket.Buffer, 0, _partialPacket.CurrentLength); ClearPartialPacket(); getDataError = TdsEnums.SNI_SUCCESS; @@ -147,12 +148,6 @@ out recurse if (_snapshotStatus != SnapshotStatus.NotActive && appended) { _snapshot.MoveNext(); -#if DEBUG - // multiple packets can be appended by demuxing but we should only move - // forward by a single packet so we can no longer assert that we are on - // the last packet at this time - //_snapshot.AssertCurrent(); -#endif } } @@ -216,10 +211,10 @@ out bool recurse if (!partialPacket.HasDataLength) { // we need to get enough bytes to read the packet header - int headeBytesNeeded = Math.Max(0, TdsEnums.HEADER_LEN - partialPacket.CurrentLength); - if (headeBytesNeeded > 0) + int headerBytesNeeded = Math.Max(0, TdsEnums.HEADER_LEN - partialPacket.CurrentLength); + if (headerBytesNeeded > 0) { - int headerBytesAvailable = Math.Min(data.Length, headeBytesNeeded); + int headerBytesAvailable = Math.Min(data.Length, headerBytesNeeded); Span headerTarget = partialPacket.Buffer.AsSpan(partialPacket.CurrentLength, headerBytesAvailable); ReadOnlySpan headerSource = data.Slice(0, headerBytesAvailable); @@ -255,7 +250,7 @@ out bool recurse } else if (partialPacket.CurrentLength > partialPacket.RequiredLength) { - // the partial packet contains a complete packet of data and then and also contains + // the partial packet contains a complete packet of data and also contains // data from a following packet // the TDS spec requires that all packets be of the defined packet size apart from @@ -334,8 +329,8 @@ out bool recurse CurrentLength = data.Length }; consumeRemainderPacket = true; - //Debug.Assert(remainderPacket.HasHeader); // precondition of entering this block - //Debug.Assert(remainderPacket.HasDataLength); // must have been set at construction + Debug.Assert(remainderPacket.HasHeader); // precondition of entering this block + Debug.Assert(remainderPacket.HasDataLength); // must have been set at construction if (remainderPacket.CurrentLength >= remainderPacket.RequiredLength) { // the remainder packet contains more data than the packet so we need @@ -359,9 +354,9 @@ out bool recurse } else // implied: current length > required length { - //// more data than required so need to split it out but we can't do that - //// here so we need to tell the caller to take the remainer packet and then - //// call this function again + // more data than required so need to split it out but we can't do that + // here so we need to tell the caller to take the remainder packet and then + // call this function again int remainderLength = data.Length - (TdsEnums.HEADER_LEN + packetDataLength); remainderPacket = new Packet diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 49d0c45f96..834fa00f20 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -1830,6 +1830,7 @@ internal int ReadPlpBytesChunk(byte[] buff, int offset, int len) // Every time you call this method increment the offset and decrease len by the value of totalBytesRead internal TdsOperationStatus TryReadPlpBytes(ref byte[] buff, int offset, int len, out int totalBytesRead) { + totalBytesRead = 0; int bytesRead; int bytesLeft; byte[] newbuf; @@ -1861,11 +1862,6 @@ internal TdsOperationStatus TryReadPlpBytes(ref byte[] buff, int offset, int len // and try to use it if it is the right length buff = _snapshot._plpBuffer; _snapshot._plpBuffer = null; - if (_snapshot.ContinueEnabled) - { - offset = _snapshot.GetPacketDataOffset(); - totalBytesRead = offset; - } } if ((ulong)(buff?.Length ?? 0) != _longlen) @@ -1895,9 +1891,6 @@ internal TdsOperationStatus TryReadPlpBytes(ref byte[] buff, int offset, int len { buff = new byte[_longlenleft]; } - - totalBytesRead = 0; - while (bytesLeft > 0) { int bytesToRead = (int)Math.Min(_longlenleft, (ulong)bytesLeft); @@ -1909,7 +1902,7 @@ internal TdsOperationStatus TryReadPlpBytes(ref byte[] buff, int offset, int len buff = newbuf; } - bool result = TryReadByteArray(buff.AsSpan(offset), bytesToRead, out bytesRead); + TdsOperationStatus result = TryReadByteArray(buff.AsSpan(offset), bytesToRead, out bytesRead); Debug.Assert(bytesRead <= bytesLeft, "Read more bytes than we needed"); Debug.Assert((ulong)bytesRead <= _longlenleft, "Read more bytes than is available"); @@ -1917,7 +1910,7 @@ internal TdsOperationStatus TryReadPlpBytes(ref byte[] buff, int offset, int len offset += bytesRead; totalBytesRead += bytesRead; _longlenleft -= (ulong)bytesRead; - if (!result) + if (result != TdsOperationStatus.Done) { if (_snapshot != null) { @@ -1931,7 +1924,8 @@ internal TdsOperationStatus TryReadPlpBytes(ref byte[] buff, int offset, int len if (_longlenleft == 0) { // Read the next chunk or cleanup state if hit the end - if (!TryReadPlpLength(false, out _)) + result = TryReadPlpLength(false, out _); + if (result != TdsOperationStatus.Done) { if (_snapshot != null) { @@ -1939,7 +1933,7 @@ internal TdsOperationStatus TryReadPlpBytes(ref byte[] buff, int offset, int len // so it can be re-used when another packet arrives and we read again _snapshot._plpBuffer = buff; } - return false; + return result; } } @@ -2005,10 +1999,14 @@ internal TdsOperationStatus TryReadNetworkPacket() if (_snapshotStatus != SnapshotStatus.NotActive) { #if DEBUG - // in debug builds stack traces contain line numbers so if we want to be - // able to compare the stack traces they must all be created in the same - // location in the code - string stackTrace = Environment.StackTrace; + string stackTrace = null; + if (s_checkNetworkPacketRetryStacks) + { + // in debug builds stack traces contain line numbers so if we want to be + // able to compare the stack traces they must all be created in the same + // location in the code + stackTrace = Environment.StackTrace; + } #endif if (_snapshot.MoveNext()) { @@ -2020,22 +2018,15 @@ internal TdsOperationStatus TryReadNetworkPacket() #endif return TdsOperationStatus.Done; } -#if DEBUG else { +#if DEBUG if (s_checkNetworkPacketRetryStacks) { _lastStack = stackTrace; } - - if (_bTmpRead == 0 && _partialHeaderBytesRead == 0 && _longlenleft == 0 && _snapshot.ContinueEnabled) - { - // no temp between packets - // mark this point as continue-able - _snapshot.CaptureAsContinue(this); - } - } #endif + } } // previous buffer is in snapshot @@ -2084,10 +2075,7 @@ internal TdsOperationStatus TryReadNetworkPacket() internal void PrepareReplaySnapshot() { _networkPacketTaskSource = null; - if (!_snapshot.MoveToContinue()) - { - _snapshot.MoveToStart(); - } + _snapshot.MoveToStart(); } internal void ReadSniSyncOverAsync() @@ -2651,18 +2639,6 @@ private sealed partial class PacketData public PacketData NextPacket; public PacketData PrevPacket; - public int TotalSize; - - internal int GetPacketDataOffset() - { - int previous = 0; - if (PrevPacket != null) - { - previous = PrevPacket.TotalSize; - } - return TotalSize - (TotalSize - previous); - } - internal void Clear() { Buffer = null; @@ -2850,18 +2826,14 @@ internal void Restore(TdsParserStateObject stateObj) private TdsParserStateObject _stateObj; private StateObjectData _replayStateData; - private StateObjectData _continueStateData; internal byte[] _plpBuffer; private PacketData _lastPacket; private PacketData _firstPacket; private PacketData _current; - private PacketData _continuePacket; private PacketData _sparePacket; - private bool? _continueSupported; - #if DEBUG private int _packetCounter; private int _rollingPend = 0; @@ -2904,18 +2876,6 @@ internal void CheckStack(string trace) } #endif - public bool ContinueEnabled - { - get - { - if (_continueSupported == null) - { - _continueSupported = AppContext.TryGetSwitch("Switch.Microsoft.Data.SqlClient.UseExperimentalAsyncContinue", out bool value) ? value : false; - } - return _continueSupported.Value; - } - } - internal void CloneNullBitmapInfo() { if (_stateObj._nullBitmapInfo.ReferenceEquals(_replayStateData?._nullBitmapInfo ?? default)) @@ -3006,22 +2966,6 @@ internal void MoveToStart() _stateObj.AssertValidState(); } - internal bool MoveToContinue() - { - if (ContinueEnabled) - { - if (_continuePacket != null && _continuePacket != _current) - { - _continueStateData.Restore(_stateObj); - _stateObj.SetBuffer(_current.Buffer, 0, _current.Read); - _stateObj._snapshotStatus = SnapshotStatus.ReplayRunning; - _stateObj.AssertValidState(); - return true; - } - } - return false; - } - internal void CaptureAsStart(TdsParserStateObject stateObj) { _firstPacket = null; @@ -3031,7 +2975,6 @@ internal void CaptureAsStart(TdsParserStateObject stateObj) _stateObj = stateObj; _replayStateData ??= new StateObjectData(); _replayStateData.Capture(stateObj); - #if DEBUG _rollingPend = 0; _rollingPendCount = 0; @@ -3043,43 +2986,6 @@ internal void CaptureAsStart(TdsParserStateObject stateObj) AppendPacketData(stateObj._inBuff, stateObj._inBytesRead); } - internal void CaptureAsContinue(TdsParserStateObject stateObj) - { - if (ContinueEnabled) - { - Debug.Assert(_stateObj == stateObj); - if (_current is not null) - { - _continueStateData ??= new StateObjectData(); - _continueStateData.Capture(stateObj, trackStack: false); - _continuePacket = _current; - } - } - } - - internal void SetPacketPayloadSize(int size) - { - if (_current == null) - { - throw new InvalidOperationException(); - } - int total = 0; - if (_current.PrevPacket != null) - { - total = _current.PrevPacket.TotalSize; - } - _current.TotalSize = total + size; - } - - internal int GetPacketDataOffset() - { - if (_current == null) - { - throw new InvalidOperationException(); - } - return _current.GetPacketDataOffset(); - } - internal void Clear() { ClearState(); @@ -3091,7 +2997,6 @@ private void ClearPackets() PacketData packet = _firstPacket; _firstPacket = null; _lastPacket = null; - _continuePacket = null; _current = null; packet.Clear(); _sparePacket = packet; @@ -3100,7 +3005,6 @@ private void ClearPackets() private void ClearState() { _replayStateData.Clear(_stateObj); - _continueStateData?.Clear(_stateObj, trackStack: false); #if DEBUG _rollingPend = 0; _rollingPendCount = 0; diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs index f42d95a38d..ede0334acf 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs @@ -3,9 +3,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using System.Linq; using System.Text; -using System.Threading.Tasks; using Xunit; namespace Microsoft.Data.SqlClient.Tests @@ -176,7 +174,6 @@ public static void FailReconstruct2Packets_FullFullPart_Part(bool isAsync) Assert.Throws( () => MultiplexPacketList(isAsync, maxDataSize, input) ); - } @@ -208,7 +205,6 @@ private static List MultiplexPacketList(bool isAsync, int dataSize, } } - if (!isAsync) { if (stateObject._partialPacket != null) @@ -337,7 +333,6 @@ public static List SplitPackets(int dataSize, List packe int sourceOffset = 0; int sourceIndex = 0; - do { Span targetSpan = Span.Empty; @@ -371,8 +366,6 @@ public static List SplitPackets(int dataSize, List packe sourceOffset += copy; sourceSpan.Slice(0, copy).CopyTo(targetSpan.Slice(0, copy)); } - - } while (sourceIndex < packets.Count && targetIndex < arrays.Length); foreach (var array in arrays) @@ -383,7 +376,6 @@ public static List SplitPackets(int dataSize, List packe return list; } - public static int PacketSizeFromDataSize(int dataSize) => TdsEnums.HEADER_LEN + dataSize; public static int DataSizeFromPacketSize(int packetSize) => packetSize - TdsEnums.HEADER_LEN; @@ -465,7 +457,6 @@ public string ToDebugString() } return buffer.ToString(); } - } [DebuggerStepThrough] diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs index 77d9319b00..271502c4da 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs @@ -1,4 +1,8 @@ -using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; using System.Collections.Generic; using System.Diagnostics; using Microsoft.Data.SqlClient.Tests; @@ -128,14 +132,11 @@ void SetBuffer(byte[] buffer, int inBytesUsed, int inBytesRead) _inBytesRead = inBytesRead; } - - // stubs private LastIOTimer _lastSuccessfulIOTimer = new LastIOTimer(); private Parser _parser = new Parser(); private SnapshotStatus _snapshotStatus = SnapshotStatus.NotActive; - [DebuggerStepThrough] private void SniReadStatisticsAndTracing() { } [DebuggerStepThrough] From ff5a4da2190c165cab03628334dfa891dac3b0f6 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Thu, 25 Jul 2024 19:16:02 +0100 Subject: [PATCH 04/17] add Packet comments --- .../src/Microsoft/Data/SqlClient/Packet.cs | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs index 8d02163b10..d1c1f079eb 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs @@ -6,6 +6,12 @@ namespace Microsoft.Data.SqlClient { + /// + /// Contains a buffer for a partial or full packet and methods to get information about the status of + /// the packet that the buffer represents.
+ /// This class is used to contain partial packet data and helps ensure that the packet data is completely + /// received before a full packet is made available to the rest of the library + ///
internal sealed class Packet { public const int UnknownDataLength = -1; @@ -21,6 +27,11 @@ public Packet() _dataLength = UnknownDataLength; } + /// + /// If the packet data has enough bytes available to determine the length amount of data that should be present + /// in the packet then this property will be set to the count of data bytes in the packet.
+ /// Otherwise this will be -1 + ///
public int DataLength { get @@ -34,6 +45,10 @@ public int DataLength _dataLength = value; } } + + /// + /// A byte array containing bytes of data or + /// public byte[] Buffer { get @@ -47,6 +62,10 @@ public byte[] Buffer _buffer = value; } } + + /// + /// The total count of bytes currently in the array including the tds header bytes + /// public int CurrentLength { get @@ -61,6 +80,13 @@ public int CurrentLength } } + /// + /// If the packet data has enough bytes available to determine the length amount of data that should be present + /// in the packet then this property will return the count of data bytes that are expected to be in the packet.
+ /// If there are not enough bytes to determine the data byte count then this property will throw an exception.
+ ///
+ /// Call to check if there will be a value before using this property. + ///
public int RequiredLength { get @@ -74,12 +100,30 @@ public int RequiredLength } } + /// + /// returns a boolean value indicating if there are enough total bytes availble in the to read the tds header + /// public bool HasHeader => _totalLength >= TdsEnums.HEADER_LEN; + /// + /// returns a boolean value indicating if the value has been set. + /// public bool HasDataLength => _dataLength >= 0; + /// + /// returns a boolean value indicating whether the contains enough + /// data for a valid tds header, has a set and that the + /// is equal to the + tds header length.
+ ///
public bool IsComplete => _dataLength != UnknownDataLength && (TdsEnums.HEADER_LEN + _dataLength) == _totalLength; + /// + /// returns a containing the first 8 bytes of the array which will + /// contain the TDS header bytes. This can be passed to static functions on to extract information from the + /// tds packet header.
+ /// Call before using this function. + ///
+ /// public ReadOnlySpan GetHeaderSpan() => _buffer.AsSpan(0, TdsEnums.HEADER_LEN); public void Dispose() From 00927cf793f1dbaab0edd0a056568ce8288cecee Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Fri, 9 Aug 2024 23:33:16 +0100 Subject: [PATCH 05/17] Fix async cancellation and add test coverage for the scenario. Add debugging and tighten the exception detection in async cancel tests. --- .../TdsParserStateObject.Multiplexer.cs | 5 ++- .../tests/FunctionalTests/MultiplexerTests.cs | 27 ++++++++++++++ .../AsyncCancelledConnectionsTest.cs | 35 ++++++++++++------- 3 files changed, 53 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs index 55d6323247..32dab7f338 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs @@ -44,6 +44,10 @@ public void ProcessSniPacket(PacketHandle packet, uint error, bool usePartialPac } else { + if (_inBytesRead != 0) + { + SetBuffer(new byte[_inBuff.Length], 0, 0); + } getDataError = GetSniPacket(packet, ref dataSize); } @@ -109,7 +113,6 @@ out recurse if (_snapshot != null) { _snapshot.AppendPacketData(_inBuff, _inBytesRead); - SetBuffer(new byte[_inBuff.Length], 0, 0); appended = true; } else diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs index ede0334acf..8a7471d072 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs @@ -176,6 +176,33 @@ public static void FailReconstruct2Packets_FullFullPart_Part(bool isAsync) ); } + [Fact] + public static void BetweenAsyncAttentionPacket() + { + int dataSize = 120; + var normalPacket = CreatePacket(120, 5); + var attentionPacket = CreatePacket(13, 6); + var input = new List { normalPacket, attentionPacket }; + + var stateObject = new TdsParserStateObject(input, TdsEnums.HEADER_LEN + dataSize, true); + + for (int index = 0; index < input.Count; index++) + { + stateObject.Current = input[index]; + stateObject.ProcessSniPacket(default, 0, usePartialPacket: false); + } + + // attention packet should be in the current buffer because the snapshot is not active + Assert.NotNull(stateObject._inBuff); + Assert.Equal(21, stateObject._inBytesRead); + Assert.Equal(0, stateObject._inBytesUsed); + + // attention packet should be in the snapshot as well + Assert.NotNull(stateObject._snapshot); + Assert.NotNull(stateObject._snapshot.List); + Assert.Equal(2, stateObject._snapshot.List.Count); + } + private static List MultiplexPacketList(bool isAsync, int dataSize, List input) { diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs index e71d6d62f6..bbc1a13239 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.Text; @@ -36,9 +37,13 @@ public void CancelAsyncConnections() private void RunCancelAsyncConnections(SqlConnectionStringBuilder connectionStringBuilder) { SqlConnection.ClearAllPools(); - _watch = Stopwatch.StartNew(); - _random = new Random(4); // chosen via fair dice role. + ParallelLoopResult results = new ParallelLoopResult(); + ConcurrentDictionary tracker = new ConcurrentDictionary(); + + _random = new Random(4); // chosen via fair dice roll. + _watch = Stopwatch.StartNew(); + try { // Setup a timer so that we can see what is going on while our tasks run @@ -47,7 +52,7 @@ private void RunCancelAsyncConnections(SqlConnectionStringBuilder connectionStri results = Parallel.For( fromInclusive: 0, toExclusive: NumberOfTasks, - (int i) => DoManyAsync(connectionStringBuilder).GetAwaiter().GetResult()); + (int i) => DoManyAsync(i, tracker, connectionStringBuilder).GetAwaiter().GetResult()); } } catch (Exception ex) @@ -82,15 +87,15 @@ private void DisplaySummary() { count = _exceptionDetails.Count; } - _output.WriteLine($"{_watch.Elapsed} {_continue} Started:{_start} Done:{_done} InFlight:{_inFlight} RowsRead:{_rowsRead} ResultRead:{_resultRead} PoisonedEnded:{_poisonedEnded} nonPoisonedExceptions:{_nonPoisonedExceptions} PoisonedCleanupExceptions:{_poisonCleanUpExceptions} Count:{count} Found:{_found}"); } // This is the the main body that our Tasks run - private async Task DoManyAsync(SqlConnectionStringBuilder connectionStringBuilder) + private async Task DoManyAsync(int index, ConcurrentDictionary tracker, SqlConnectionStringBuilder connectionStringBuilder) { Interlocked.Increment(ref _start); Interlocked.Increment(ref _inFlight); + tracker[index] = true; using (SqlConnection marsConnection = new SqlConnection(connectionStringBuilder.ToString())) { @@ -100,15 +105,15 @@ private async Task DoManyAsync(SqlConnectionStringBuilder connectionStringBuilde } // First poison - await DoOneAsync(marsConnection, connectionStringBuilder.ToString(), poison: true); + await DoOneAsync(marsConnection, connectionStringBuilder.ToString(), poison: true, index); for (int i = 0; i < NumberOfNonPoisoned && _continue; i++) { // now run some without poisoning - await DoOneAsync(marsConnection, connectionStringBuilder.ToString()); + await DoOneAsync(marsConnection, connectionStringBuilder.ToString(),false,index); } } - + tracker.TryRemove(index, out var _); Interlocked.Decrement(ref _inFlight); Interlocked.Increment(ref _done); } @@ -117,7 +122,7 @@ private async Task DoManyAsync(SqlConnectionStringBuilder connectionStringBuilde // if we are poisoning we will // 1 - Interject some sleeps in the sql statement so that it will run long enough that we can cancel it // 2 - Setup a time bomb task that will cancel the command a random amount of time later - private async Task DoOneAsync(SqlConnection marsConnection, string connectionString, bool poison = false) + private async Task DoOneAsync(SqlConnection marsConnection, string connectionString, bool poison, int parent) { try { @@ -135,12 +140,12 @@ private async Task DoOneAsync(SqlConnection marsConnection, string connectionStr { if (marsConnection != null && marsConnection.State == System.Data.ConnectionState.Open) { - await RunCommand(marsConnection, builder.ToString(), poison); + await RunCommand(marsConnection, builder.ToString(), poison, parent); } else { await connection.OpenAsync(); - await RunCommand(connection, builder.ToString(), poison); + await RunCommand(connection, builder.ToString(), poison, parent); } } } @@ -176,7 +181,7 @@ private async Task DoOneAsync(SqlConnection marsConnection, string connectionStr } } - private async Task RunCommand(SqlConnection connection, string commandText, bool poison) + private async Task RunCommand(SqlConnection connection, string commandText, bool poison, int parent) { int rowsRead = 0; int resultRead = 0; @@ -211,7 +216,7 @@ private async Task RunCommand(SqlConnection connection, string commandText, bool } while (await reader.NextResultAsync() && _continue); } - catch when (poison) + catch (SqlException sqlException) when (poison && sqlException.Message.Contains("Operation cancelled by user.")) { // This looks a little strange, we failed to read above so this should fail too // But consider the case where this code is elsewhere (in the Dispose method of a class holding this logic) @@ -228,6 +233,10 @@ private async Task RunCommand(SqlConnection connection, string commandText, bool throw; } + catch (Exception ex) + { + Assert.Fail("unexpected exception: " + ex.GetType().Name + " " +ex.Message); + } } } finally From 9b5bd63f5c16d6c35c530d60b9b2cdf7af0fff3c Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Mon, 12 Aug 2024 18:43:08 +0100 Subject: [PATCH 06/17] reduce CancelAsyncConnections sensitivity to match main --- .../ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs index bbc1a13239..0ae12be917 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs @@ -216,7 +216,7 @@ private async Task RunCommand(SqlConnection connection, string commandText, bool } while (await reader.NextResultAsync() && _continue); } - catch (SqlException sqlException) when (poison && sqlException.Message.Contains("Operation cancelled by user.")) + catch (SqlException) when (poison) { // This looks a little strange, we failed to read above so this should fail too // But consider the case where this code is elsewhere (in the Dispose method of a class holding this logic) From fbfcb0dce03bfce91c6ae1007d7cc225283cc733 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Wed, 2 Oct 2024 02:32:25 +0100 Subject: [PATCH 07/17] make multiplexer not require snapshot to consume partial packets --- .../Data/SqlClient/TdsParserStateObject.Multiplexer.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs index 32dab7f338..b9baae024a 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs @@ -35,7 +35,7 @@ public void ProcessSniPacket(PacketHandle packet, uint error, bool usePartialPac bool usedPartialPacket = false; uint getDataError = 0; - if (usePartialPacket && _snapshot == null && _partialPacket != null && _partialPacket.IsComplete) + if (usePartialPacket && _partialPacket != null && _partialPacket.IsComplete) { SetBuffer(_partialPacket.Buffer, 0, _partialPacket.CurrentLength); ClearPartialPacket(); From bd34b4492fb7524ffe18e851ca8f77f3f27261bc Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Tue, 15 Oct 2024 22:57:32 +0100 Subject: [PATCH 08/17] refine AppendPacketData checks and fix behaviour that was causing it to assert --- .../TdsParserStateObject.Multiplexer.cs | 11 ++++++++- .../Data/SqlClient/TdsParserStateObject.cs | 10 +++++++- .../tests/FunctionalTests/MultiplexerTests.cs | 24 +++++++++++++++++++ 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs index b9baae024a..d128e8c0f2 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs @@ -66,10 +66,15 @@ public void ProcessSniPacket(PacketHandle packet, uint error, bool usePartialPac SetBuffer(_inBuff, 0, (int)dataSize); } - bool recurse; + bool recurse = false; bool appended = false; do { + if (recurse && appended) + { + SetBuffer(new byte[_inBuff.Length], 0, 0); + appended = false; + } MultiplexPackets( _inBuff, _inBytesUsed, _inBytesRead, _partialPacket, @@ -113,6 +118,10 @@ out recurse if (_snapshot != null) { _snapshot.AppendPacketData(_inBuff, _inBytesRead); + // if we SetBuffer here to clear the packet buffer we will break the attention handling which relies + // on the attention containing packet remaining in the active buffer even if we're appending to the + // snapshot so we will have to use the appended variable to prevent the same buffer being added again + //// SetBuffer(new byte[_inBuff.Length], 0, 0); appended = true; } else diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 834fa00f20..a8db09bd3e 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -2900,7 +2900,15 @@ internal void AppendPacketData(byte[] buffer, int read) #if DEBUG for (PacketData current = _firstPacket; current != null; current = current.NextPacket) { - Debug.Assert(!ReferenceEquals(current.Buffer, buffer)); + if (ReferenceEquals(current.Buffer, buffer)) + { + // multiple packets are permitted to be in the same buffer because of partial packets + // but their contents cannot overlap + if ((current.Read + current.DataLength) > read) + { + Debug.Fail("duplicate or overlapping packet appended to snapshot"); + } + } } #endif PacketData packetData = _sparePacket; diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs index 8a7471d072..4fb6988dee 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs @@ -176,6 +176,30 @@ public static void FailReconstruct2Packets_FullFullPart_Part(bool isAsync) ); } + [Fact] + public static void TrailingPartialPacketInSnapshotNotDuplicated() + { + int dataSize = 120; + + var expected = new List + { + CreatePacket(120, 5), + CreatePacket(90, 6), + CreatePacket(13, 7), + }; + + var input = SplitPackets(120, expected, + (8 + 120), + (8 + 90) + (8 + 13) + ); + + Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); + + var output = MultiplexPacketList(true, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + [Fact] public static void BetweenAsyncAttentionPacket() { From c28aba3450841fe74e5943c79c0bf9b3c9f06c85 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Fri, 18 Oct 2024 14:40:48 +0100 Subject: [PATCH 09/17] add more debugging --- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 11 +- .../Data/SqlClient/TdsParserStateObject.cs | 256 ++++++++++++++---- 2 files changed, 218 insertions(+), 49 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index 4945c0f58a..1e476f3944 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -2077,11 +2077,20 @@ internal TdsOperationStatus TryRun(RunBehavior runBehavior, SqlCommand cmdHandle if (!IsValidTdsToken(token)) { - Debug.Fail($"unexpected token; token = {token,-2:X2}"); +#if DEBUG + string message = stateObj.DumpBuffer(); + Debug.Fail(message); +#endif + //Debug.Fail($"unexpected token; token = {token,-2:X2}"); _state = TdsParserState.Broken; _connHandler.BreakConnection(); SqlClientEventSource.Log.TryTraceEvent(" Potential multi-threaded misuse of connection, unexpected TDS token found {0}", ObjectID); +#if DEBUG + throw new InvalidOperationException(message); +#else throw SQL.ParsingError(); +#endif + } int tokenLength; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index a8db09bd3e..2f9ae4551c 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -7,6 +7,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Security; +using System.Security.Cryptography; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -2564,43 +2565,45 @@ internal bool IsConnectionAlive(bool throwOnException) return isAlive; } - /* + // leave this in. comes handy if you have to do Console.WriteLine style debugging ;) - private void DumpBuffer() { - Console.WriteLine("dumping buffer"); - Console.WriteLine("_inBytesRead = {0}", _inBytesRead); - Console.WriteLine("_inBytesUsed = {0}", _inBytesUsed); + internal string DumpBuffer() { + StringBuilder buffer = new StringBuilder(128); + buffer.AppendLine("dumping buffer"); + buffer.AppendFormat("_inBytesRead = {0}", _inBytesRead).AppendLine(); + buffer.AppendFormat("_inBytesUsed = {0}", _inBytesUsed).AppendLine(); int cc = 0; // character counter int i; - Console.WriteLine("used buffer:"); + buffer.AppendLine("used buffer:"); for (i=0; i< _inBytesUsed; i++) { if (cc==16) { - Console.WriteLine(); + buffer.AppendLine(); cc = 0; } - Console.Write("{0,-2:X2} ", _inBuff[i]); + buffer.AppendFormat("{0,-2:X2} ", _inBuff[i]); cc++; } if (cc>0) { - Console.WriteLine(); + buffer.AppendLine(); } cc = 0; - Console.WriteLine("unused buffer:"); + buffer.AppendLine("unused buffer:"); for (i=_inBytesUsed; i<_inBytesRead; i++) { if (cc==16) { - Console.WriteLine(); + buffer.AppendLine(); cc = 0; } - Console.Write("{0,-2:X2} ", _inBuff[i]); + buffer.AppendFormat("{0,-2:X2} ", _inBuff[i]); cc++; } if (cc>0) { - Console.WriteLine(); + buffer.AppendLine(); } + return buffer.ToString(); } - */ + internal void SetSnapshot() { @@ -2649,15 +2652,21 @@ internal void Clear() PrevPacket.NextPacket = null; PrevPacket = null; } - SetDebugStackInternal(null); - SetDebugPacketIdInternal(0); + SetDebugStackImpl(null); + SetDebugPacketId(0); + SetDebugDataHash(); } - internal void SetDebugStack(string value) => SetDebugStackInternal(value); - internal void SetDebugPacketId(int value) => SetDebugPacketIdInternal(value); + internal void SetDebugStack(string value) => SetDebugStackImpl(value); + internal void SetDebugPacketId(int value) => SetDebugPacketIdImpl(value); + internal void SetDebugDataHash() => SetDebugDataHashImpl(); - partial void SetDebugStackInternal(string value); - partial void SetDebugPacketIdInternal(int value); + internal void CheckDebugDataHash() => CheckDebugDataHashImpl(); + + partial void SetDebugStackImpl(string value); + partial void SetDebugPacketIdImpl(int value); + partial void SetDebugDataHashImpl(); + partial void CheckDebugDataHashImpl(); } #if DEBUG @@ -2679,33 +2688,137 @@ public PacketDataDebugView(PacketData data) _data = data; } - [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] - public PacketData[] Items + //[DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] + //public PacketData[] Items + //{ + // get + // { + // PacketData[] items = Array.Empty(); + // if (_data != null) + // { + // int count = 0; + // for (PacketData current = _data; current != null; current = current?.NextPacket) + // { + // count++; + // } + // items = new PacketData[count]; + // int index = 0; + // for (PacketData current = _data; current != null; current = current?.NextPacket, index++) + // { + // items[index] = current; + // } + // } + // return items; + // } + //} + + public string Type { + + get + { + if (_data != null && _data.Buffer!=null) + { + switch (_data.Buffer[0]) + { + case 1: return nameof(TdsEnums.MT_SQL); + case 2: return nameof(TdsEnums.MT_LOGIN); + case 3: return nameof(TdsEnums.MT_RPC); + case 4: return nameof(TdsEnums.MT_TOKENS); + case 5: return nameof(TdsEnums.MT_BINARY); + case 6: return nameof(TdsEnums.MT_ATTN); + case 7: return nameof(TdsEnums.MT_BULK); + case 8: return nameof(TdsEnums.MT_FEDAUTH); + case 9: return nameof(TdsEnums.MT_CLOSE); + case 10: return nameof(TdsEnums.MT_ERROR); + case 11: return nameof(TdsEnums.MT_ACK); + case 12: return nameof(TdsEnums.MT_ECHO); + case 13: return nameof(TdsEnums.MT_LOGOUT); + case 14: return nameof(TdsEnums.MT_TRANS); + case 15: return nameof(TdsEnums.MT_OLEDB); + case 16: return nameof(TdsEnums.MT_LOGIN7); + case 17: return nameof(TdsEnums.MT_SSPI); + case 18: return nameof(TdsEnums.MT_PRELOGIN); + default: return _data.Buffer[0].ToString("X2"); + } + } + return ""; + } + } + + public string Status { get { - PacketData[] items = Array.Empty(); - if (_data != null) + if (_data != null && _data.Buffer != null && _data.Buffer.Length > 1) { - int count = 0; - for (PacketData current = _data; current != null; current = current?.NextPacket) + int status = Packet.GetStatusFromHeader(_data.Buffer); + StringBuilder buffer = new StringBuilder(10); + + if ((status & TdsEnums.ST_EOM) == TdsEnums.ST_EOM) { - count++; + if (buffer.Length > 0) + { + buffer.Append(','); + } + buffer.Append(nameof(TdsEnums.ST_EOM)); } - items = new PacketData[count]; - int index = 0; - for (PacketData current = _data; current != null; current = current?.NextPacket, index++) + if ((status & TdsEnums.ST_AACK) == TdsEnums.ST_AACK) { - items[index] = current; + if (buffer.Length > 0) + { + buffer.Append(','); + } + buffer.Append(nameof(TdsEnums.ST_AACK)); + } + if ((status & TdsEnums.ST_BATCH) == TdsEnums.ST_BATCH) + { + if (buffer.Length > 0) + { + buffer.Append(','); + } + buffer.Append(nameof(TdsEnums.ST_BATCH)); + } + if ((status & TdsEnums.ST_RESET_CONNECTION) == TdsEnums.ST_RESET_CONNECTION) + { + if (buffer.Length > 0) + { + buffer.Append(','); + } + buffer.Append(nameof(TdsEnums.ST_RESET_CONNECTION)); } + if ((status & TdsEnums.ST_RESET_CONNECTION_PRESERVE_TRANSACTION) == TdsEnums.ST_RESET_CONNECTION_PRESERVE_TRANSACTION) + { + if (buffer.Length > 0) + { + buffer.Append(','); + } + buffer.Append(nameof(TdsEnums.ST_RESET_CONNECTION_PRESERVE_TRANSACTION)); + } + + return buffer.ToString(); } - return items; + + return ""; } } + + public int Length => _data.DataLength; + + public int Spid => _data.SPID; + + public int PacketID => _data.PacketID; + + public ReadOnlySpan HeaderBytes => _data.GetHeaderSpan(); + + public ReadOnlySpan Data => _data.Buffer.AsSpan(TdsEnums.HEADER_LEN); + + public PacketData NextPacket => _data.NextPacket; + public PacketData PrevPacket => _data.PrevPacket; } - public int PacketId; + public int DebugPacketId; public string Stack; + public byte[] Hash; public int PacketID => Packet.GetIDFromHeader(Buffer.AsSpan(0, TdsEnums.HEADER_LEN)); @@ -2715,29 +2828,74 @@ public PacketData[] Items public int DataLength => Packet.GetDataLengthFromHeader(Buffer.AsSpan(0, TdsEnums.HEADER_LEN)); - partial void SetDebugStackInternal(string value) => Stack = value; + public ReadOnlySpan GetHeaderSpan() => Buffer.AsSpan(0, TdsEnums.HEADER_LEN); + + partial void SetDebugStackImpl(string value) => Stack = value; - partial void SetDebugPacketIdInternal(int value) => PacketId = value; + partial void SetDebugPacketIdImpl(int value) => DebugPacketId = value; - public override string ToString() + partial void SetDebugDataHashImpl() { - string byteString = null; - if (Buffer != null && Buffer.Length >= 12) + if (Buffer != null) { - ReadOnlySpan bytes = Buffer.AsSpan(0, 12); - StringBuilder buffer = new StringBuilder(12 * 3 + 10); - buffer.Append('{'); - for (int index = 0; index < bytes.Length; index++) + using (MD5 hasher = MD5.Create()) { - buffer.AppendFormat("{0:X2}", bytes[index]); - buffer.Append(", "); + Hash = hasher.ComputeHash(Buffer, 0, Read); } - buffer.Append("..."); - buffer.Append('}'); - byteString = buffer.ToString(); } - return $"{PacketId}: [{Read}] {byteString} {(NextPacket != null ? @"->" : string.Empty)}"; + else + { + Hash = null; + } + } + + partial void CheckDebugDataHashImpl() + { + if (Hash == null) + { + if (Buffer != null && Read > 0) + { + throw new InvalidOperationException("Packet modification detected. Hash is null but packet contains non-null buffer"); + } + } + else + { + byte[] checkHash = null; + using (MD5 hasher = MD5.Create()) + { + checkHash = hasher.ComputeHash(Buffer, 0, Read); + } + + for (int index = 0; index < Hash.Length; index++) + { + if (Hash[index] != checkHash[index]) + { + throw new InvalidOperationException("Packet modification detected. Hash from packet creation does not match hash from packet check"); + } + } + } + } + + //public override string ToString() + //{ + // string byteString = null; + // if (Buffer != null && Buffer.Length >= 12) + // { + // ReadOnlySpan bytes = Buffer.AsSpan(0, 12); + // StringBuilder buffer = new StringBuilder(12 * 3 + 10); + // buffer.Append('{'); + // for (int index = 0; index < bytes.Length; index++) + // { + // buffer.AppendFormat("{0:X2}", bytes[index]); + // buffer.Append(", "); + // } + // buffer.Append("..."); + // buffer.Append('}'); + // byteString = buffer.ToString(); + // } + // return $"{InternalPacketId}: [{Read}] {byteString} {(NextPacket != null ? @"->" : string.Empty)}"; + //} } #endif @@ -2925,6 +3083,7 @@ internal void AppendPacketData(byte[] buffer, int read) #if DEBUG packetData.SetDebugStack(_stateObj._lastStack); packetData.SetDebugPacketId(Interlocked.Increment(ref _packetCounter)); + packetData.SetDebugDataHash(); #endif if (_firstPacket is null) { @@ -2958,6 +3117,7 @@ internal bool MoveNext() if (moved) { _stateObj.SetBuffer(_current.Buffer, 0, _current.Read); + _current.CheckDebugDataHash(); _stateObj._snapshotStatus = moveToMode; retval = true; } From f71d6079bf3cee3ebbf01be9a54feebea454d912 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Wed, 23 Oct 2024 01:41:47 +0100 Subject: [PATCH 10/17] add debug fail to sanity check in multiplexer to make it clear that it has been hit --- .../Data/SqlClient/TdsParserStateObject.Multiplexer.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs index d128e8c0f2..d48993ca6d 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs @@ -420,6 +420,8 @@ out bool recurse if (consumePartialPacket && consumeInputDirectly) { + string message = $"MultiplexPackets cannot return both {nameof(consumePartialPacket)} and {nameof(consumeInputDirectly)}"; + System.Diagnostics.Debug.Fail(message); // fail is easier to debug because the exception can be swallowed by higher layers. throw new InvalidOperationException($"MultiplexPackets cannot return both {nameof(consumePartialPacket)} and {nameof(consumeInputDirectly)}"); } } From 04886d76d39f21c0934abceeb39e198dd3aa7fc7 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Thu, 24 Oct 2024 19:34:48 +0100 Subject: [PATCH 11/17] update multiplexer to deal with multiple sequential packets less than buffer size --- .../TdsParserStateObject.Multiplexer.cs | 132 +++++++++++++----- .../tests/FunctionalTests/MultiplexerTests.cs | 23 ++- 2 files changed, 109 insertions(+), 46 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs index d48993ca6d..5fa9f97e9e 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs @@ -202,15 +202,17 @@ private static void MultiplexPackets( out Packet remainderPacket, out bool consumeInputDirectly, out bool consumePartialPacket, - out bool consumeRemainderPacket, + out bool createdRemainderPacket, out bool recurse ) { + Debug.Assert(dataBuffer != null); + ReadOnlySpan data = dataBuffer.AsSpan(dataOffset, dataLength); remainderPacket = null; consumeInputDirectly = false; consumePartialPacket = false; - consumeRemainderPacket = false; + createdRemainderPacket = false; recurse = false; newDataLength = dataLength; @@ -266,14 +268,43 @@ out bool recurse // data from a following packet // the TDS spec requires that all packets be of the defined packet size apart from - // the last packet of a response. This means that is is not possible to have more than + // the last packet of a response. This means that it should not possible to have more than // 2 packet fragments in a single buffer like this: // - first packet caused the partial // - second packet is the one we have just unpacked // - third packet is the extra data we have found + // however, due to the timing of cancellation it is possible that a response token stream + // has ended before an attention message response is sent leaving us with a short final + // packet and an additional short cancel packet following it + + // this should only happen when the caller is trying to consume the partial packet + // and does not have new input data + Debug.Assert(newDataLength == 0); + + int remainderLength = partialPacket.CurrentLength - partialPacket.RequiredLength; + + partialPacket.CurrentLength -= remainderLength; + + remainderPacket = new Packet + { + Buffer = new byte[dataBuffer.Length], + CurrentLength = remainderLength, + }; + + ReadOnlySpan remainderSource = partialPacket.Buffer.AsSpan(TdsEnums.HEADER_LEN + partialPacket.DataLength, remainderLength); + Span remainderTarget = remainderPacket.Buffer.AsSpan(0, remainderLength); + remainderSource.CopyTo(remainderTarget); - // we must throw an exception because we have encountered an invalid tds stream - throw SQL.ParsingError(ParsingErrorState.CorruptedTdsStream); + createdRemainderPacket = true; + + if (remainderPacket.HasHeader) + { + remainderPacket.DataLength = Packet.GetDataLengthFromHeader(remainderPacket); + if (remainderPacket.HasDataLength && remainderPacket.CurrentLength >= remainderPacket.RequiredLength) + { + recurse = true; + } + } } if (partialPacket.CurrentLength == partialPacket.RequiredLength) @@ -290,7 +321,8 @@ out bool recurse // some data has been taken from the buffer, put into the partial // packet buffer and we have data left so move the data we have // left to the start of the buffer so we can pass the buffer back - // as zero based to the caller avoiding offset calculations everywhere + // as zero based to the caller avoiding offset calculations in the + // rest of this method Buffer.BlockCopy( dataBuffer, dataOffset + bytesConsumed, // from dataBuffer, dataOffset, // to @@ -299,7 +331,7 @@ out bool recurse #if DEBUG // for debugging purposes fill the removed data area with an easily // recognisable pattern so we can see if it is misused - Span removed = dataBuffer.AsSpan(dataOffset + (dataLength - bytesConsumed), (dataOffset + bytesConsumed)); + Span removed = dataBuffer.AsSpan(dataOffset + (dataLength - bytesConsumed), bytesConsumed); removed.Fill(0xFF); #endif @@ -312,7 +344,12 @@ out bool recurse } } - if (data.Length > 0 && !consumeRemainderPacket) + // partial packet handling should not make decisions about consuming the input buffer + Debug.Assert(!consumeInputDirectly); + // partial packet handling may only create a remainder packet when it is trying to consume the partial packet and has no incoming data + Debug.Assert(!createdRemainderPacket || data.Length == 0); + + if (data.Length > 0) { if (data.Length >= TdsEnums.HEADER_LEN) { @@ -340,7 +377,7 @@ out bool recurse DataLength = packetDataLength, CurrentLength = data.Length }; - consumeRemainderPacket = true; + createdRemainderPacket = true; Debug.Assert(remainderPacket.HasHeader); // precondition of entering this block Debug.Assert(remainderPacket.HasDataLength); // must have been set at construction if (remainderPacket.CurrentLength >= remainderPacket.RequiredLength) @@ -354,59 +391,88 @@ out bool recurse } else if (data.Length < TdsEnums.HEADER_LEN + packetDataLength) { - // another partial packet so produce one and tell the caller that they need - // consume it. + // an incomplete packet so create a remainder packet to pass back remainderPacket = new Packet { Buffer = dataBuffer, DataLength = packetDataLength, CurrentLength = data.Length }; - consumeRemainderPacket = true; + createdRemainderPacket = true; } else // implied: current length > required length { // more data than required so need to split it out but we can't do that // here so we need to tell the caller to take the remainder packet and then // call this function again + if (consumePartialPacket) + { + // we are already telling the caller to consume the partial packet so we + // can't tell them it to also consume the data in the buffer directly + // so create a remainder packet and pass it back. + remainderPacket = new Packet + { + Buffer = new byte[dataBuffer.Length], + CurrentLength = data.Length + }; - int remainderLength = data.Length - (TdsEnums.HEADER_LEN + packetDataLength); - remainderPacket = new Packet + ReadOnlySpan remainderSource = data; + Span remainderTarget = remainderPacket.Buffer.AsSpan(0, remainderPacket.CurrentLength); + remainderSource.CopyTo(remainderTarget); + + createdRemainderPacket = true; + } + else { - Buffer = new byte[dataBuffer.Length], - CurrentLength = remainderLength, - }; + int remainderLength = data.Length - (TdsEnums.HEADER_LEN + packetDataLength); + remainderPacket = new Packet + { + Buffer = new byte[dataBuffer.Length], + CurrentLength = remainderLength, + }; - ReadOnlySpan remainderSource = data.Slice(TdsEnums.HEADER_LEN + packetDataLength); - Span remainderTarget = remainderPacket.Buffer.AsSpan(0, remainderLength); - remainderSource.CopyTo(remainderTarget); + ReadOnlySpan remainderSource = data.Slice(TdsEnums.HEADER_LEN + packetDataLength); + Span remainderTarget = remainderPacket.Buffer.AsSpan(0, remainderLength); + remainderSource.CopyTo(remainderTarget); - newDataLength = TdsEnums.HEADER_LEN + packetDataLength; - consumeInputDirectly = true; - consumeRemainderPacket = true; +#if DEBUG + // for debugging purposes fill the removed data area with an easily + // recognisable pattern so we can see if it is misused + Span removed = dataBuffer.AsSpan(TdsEnums.HEADER_LEN + packetDataLength, remainderLength); + removed.Fill(0xFF); +#endif - if (remainderPacket.HasHeader) - { - remainderPacket.DataLength = Packet.GetDataLengthFromHeader(remainderPacket); - if (remainderPacket.HasDataLength && remainderPacket.CurrentLength >= remainderPacket.RequiredLength) + newDataLength = TdsEnums.HEADER_LEN + packetDataLength; + + consumeInputDirectly = true; + createdRemainderPacket = true; + + if (remainderPacket.HasHeader) { - recurse = true; + remainderPacket.DataLength = Packet.GetDataLengthFromHeader(remainderPacket); + if (remainderPacket.HasDataLength && remainderPacket.CurrentLength >= remainderPacket.RequiredLength) + { + recurse = true; + } } + } } } else { - // we took some data from the input to reconstruct the partial packet - // so we can't tell the caller to directly consume the packet in the - // input buffer, we need to construct a new remainder packet and then - // tell them to consume it + // either: + // 1) we took some data from the input to reconstruct the partial packet + // 2) there was less than a single packet header of data recieved + // in both cases we can't tell the caller to directly consume the packet + // in the input buffer, we need to construct a new remainder packet with + // the incomplete data and let the caller deal with it remainderPacket = new Packet { Buffer = dataBuffer, CurrentLength = data.Length }; - consumeRemainderPacket = true; + createdRemainderPacket = true; } } #if DEBUG diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs index 4fb6988dee..4840a8c889 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs @@ -153,27 +153,24 @@ public static void ReconstructMultiplePacketSequenceWithShortEnd(bool isAsync) } [Theory, MemberData(nameof(IsAsync))] - public static void FailReconstruct2Packets_FullFullPart_Part(bool isAsync) + public static void Reconstruct3Packets_PartPartPart(bool isAsync) { - // illegal, cannot have multiple packets end in a single packet because all packets except an end packet must - // be be of max length, thus only max length packets can exist before a short packet. - int maxDataSize = 46; + int dataSize = 62; var expected = new List { - CreatePacket(10, 5), + CreatePacket(26, 5), CreatePacket(10, 6), - CreatePacket(30, 7) + CreatePacket(10, 7) }; - var input = SplitPackets(maxDataSize, expected, - (8 + 10) + (8 + 10) + (8 + 2), // full, full, part - 36 // part + var input = SplitPackets(70, expected, + (8 + 26) + (8 + 10) + (8 + 10) // = 70: full, full, part ); - Assert.Throws( - () => MultiplexPacketList(isAsync, maxDataSize, input) - ); + var output = MultiplexPacketList(isAsync, dataSize, input); + + ComparePacketLists(dataSize, expected, output); } [Fact] @@ -258,7 +255,7 @@ private static List MultiplexPacketList(bool isAsync, int dataSize, if (!isAsync) { - if (stateObject._partialPacket != null) + while (stateObject._partialPacket != null) { stateObject.Current = default; From 7c7e36242ad41d5af2f89e2c54be7a21cc0d1e52 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Tue, 29 Oct 2024 02:11:10 +0000 Subject: [PATCH 12/17] change multiplexer to deal with trailing partial packets correctly when multiple can be present in a single partial buffer. --- .../src/Microsoft/Data/SqlClient/Packet.cs | 25 +- .../TdsParserStateObject.Multiplexer.cs | 131 +++++----- .../Data/SqlClient/TdsParserStateObject.cs | 14 +- .../Microsoft.Data.SqlClient.Tests.csproj | 1 + .../tests/FunctionalTests/MultiplexerTests.cs | 241 +++++++++++++++--- .../TdsParserStateObject.TestHarness.cs | 114 --------- 6 files changed, 287 insertions(+), 239 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs index d1c1f079eb..802eb2c936 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs @@ -12,7 +12,7 @@ namespace Microsoft.Data.SqlClient /// This class is used to contain partial packet data and helps ensure that the packet data is completely /// received before a full packet is made available to the rest of the library /// - internal sealed class Packet + internal sealed partial class Packet { public const int UnknownDataLength = -1; @@ -29,8 +29,8 @@ public Packet() /// /// If the packet data has enough bytes available to determine the length amount of data that should be present - /// in the packet then this property will be set to the count of data bytes in the packet.
- /// Otherwise this will be -1 + /// in the packet then this property will be set to the count of data bytes in the packet,
+ /// otherwise this will be -1 ///
public int DataLength { @@ -113,9 +113,9 @@ public int RequiredLength /// /// returns a boolean value indicating whether the contains enough /// data for a valid tds header, has a set and that the - /// is equal to the + tds header length.
+ /// is greater than or equal to the + tds header length.
///
- public bool IsComplete => _dataLength != UnknownDataLength && (TdsEnums.HEADER_LEN + _dataLength) == _totalLength; + public bool ContainsCompletePacket => _dataLength != UnknownDataLength && (TdsEnums.HEADER_LEN + _dataLength) <= _totalLength; /// /// returns a containing the first 8 bytes of the array which will @@ -139,6 +139,10 @@ public void CheckDisposed() } } + internal void SetCreatedBy(int creator) => SetCreatedByImpl(creator); + + partial void SetCreatedByImpl(int creator); + public static void ThrowDisposed() { throw new ObjectDisposedException(nameof(Packet)); @@ -163,4 +167,15 @@ internal static int GetIDFromHeader(ReadOnlySpan header) internal static bool GetIsEOMFromHeader(ReadOnlySpan header) => GetStatusFromHeader(header) == 1; } + +#if DEBUG + internal sealed partial class Packet + { + private int _createdBy; + + public int CreatedBy => _createdBy; + + partial void SetCreatedByImpl(int creator) => _createdBy = creator; + } +#endif } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs index 5fa9f97e9e..9e9e1cddd0 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.Multiplexer.cs @@ -12,10 +12,10 @@ namespace Microsoft.Data.SqlClient #endif partial class TdsParserStateObject { - private Packet __partialPacket; - internal Packet _partialPacket => __partialPacket; + private Packet _partialPacket; + internal Packet PartialPacket => _partialPacket; - public void ProcessSniPacket(PacketHandle packet, uint error, bool usePartialPacket = false) + public void ProcessSniPacket(PacketHandle packet, uint error) { if (error != 0) { @@ -34,10 +34,11 @@ public void ProcessSniPacket(PacketHandle packet, uint error, bool usePartialPac uint dataSize = 0; bool usedPartialPacket = false; uint getDataError = 0; - - if (usePartialPacket && _partialPacket != null && _partialPacket.IsComplete) + + if (PartialPacketContainsCompletePacket()) { - SetBuffer(_partialPacket.Buffer, 0, _partialPacket.CurrentLength); + Packet partialPacket = _partialPacket; + SetBuffer(partialPacket.Buffer, 0, partialPacket.CurrentLength); ClearPartialPacket(); getDataError = TdsEnums.SNI_SUCCESS; usedPartialPacket = true; @@ -77,7 +78,7 @@ public void ProcessSniPacket(PacketHandle packet, uint error, bool usePartialPac } MultiplexPackets( _inBuff, _inBytesUsed, _inBytesRead, - _partialPacket, + PartialPacket, out int newDataOffset, out int newDataLength, out Packet remainderPacket, @@ -93,13 +94,13 @@ out recurse { if (_snapshot != null) { - _snapshot.AppendPacketData(_partialPacket.Buffer, _partialPacket.CurrentLength); + _snapshot.AppendPacketData(PartialPacket.Buffer, PartialPacket.CurrentLength); SetBuffer(new byte[_inBuff.Length], 0, 0); appended = true; } else { - SetBuffer(_partialPacket.Buffer, 0, _partialPacket.CurrentLength); + SetBuffer(PartialPacket.Buffer, 0, PartialPacket.CurrentLength); } bufferIsPartialCompleted = true; @@ -175,25 +176,33 @@ out recurse } } - private void SetPartialPacket(Packet packet/*, [CallerMemberName] string caller = null*/) + private void SetPartialPacket(Packet packet) { - if (__partialPacket != null && packet != null) + if (_partialPacket != null && packet != null) { throw new InvalidOperationException("partial packet cannot be non-null when setting to non=null"); } - __partialPacket = packet; + _partialPacket = packet; } - private void ClearPartialPacket(/*[CallerMemberName] string caller = null*/) + private void ClearPartialPacket() { - Packet partialPacket = __partialPacket; - __partialPacket = null; + Packet partialPacket = _partialPacket; + _partialPacket = null; if (partialPacket != null) { partialPacket.Dispose(); } } + // this check is used in two places that must be identical so it is + // extracted into a method, do not inline this method + internal bool PartialPacketContainsCompletePacket() + { + Packet partialPacket = _partialPacket; + return partialPacket != null && partialPacket.ContainsCompletePacket; + } + private static void MultiplexPackets( byte[] dataBuffer, int dataOffset, int dataLength, Packet partialPacket, @@ -268,7 +277,7 @@ out bool recurse // data from a following packet // the TDS spec requires that all packets be of the defined packet size apart from - // the last packet of a response. This means that it should not possible to have more than + // the last packet of a response. This means that is should not possible to have more than // 2 packet fragments in a single buffer like this: // - first packet caused the partial // - second packet is the one we have just unpacked @@ -279,7 +288,6 @@ out bool recurse // this should only happen when the caller is trying to consume the partial packet // and does not have new input data - Debug.Assert(newDataLength == 0); int remainderLength = partialPacket.CurrentLength - partialPacket.RequiredLength; @@ -290,6 +298,7 @@ out bool recurse Buffer = new byte[dataBuffer.Length], CurrentLength = remainderLength, }; + remainderPacket.SetCreatedBy(1); ReadOnlySpan remainderSource = partialPacket.Buffer.AsSpan(TdsEnums.HEADER_LEN + partialPacket.DataLength, remainderLength); Span remainderTarget = remainderPacket.Buffer.AsSpan(0, remainderLength); @@ -297,14 +306,7 @@ out bool recurse createdRemainderPacket = true; - if (remainderPacket.HasHeader) - { - remainderPacket.DataLength = Packet.GetDataLengthFromHeader(remainderPacket); - if (remainderPacket.HasDataLength && remainderPacket.CurrentLength >= remainderPacket.RequiredLength) - { - recurse = true; - } - } + recurse = SetupRemainderPacket(remainderPacket); } if (partialPacket.CurrentLength == partialPacket.RequiredLength) @@ -318,11 +320,10 @@ out bool recurse { if (data.Length > 0) { - // some data has been taken from the buffer, put into the partial - // packet buffer and we have data left so move the data we have - // left to the start of the buffer so we can pass the buffer back - // as zero based to the caller avoiding offset calculations in the - // rest of this method + // some data has been taken from the buffer and put into the partial + // packet buffer. We have data left so move the data we have to the + // start of the buffer so we can pass the buffer back as zero based + // to the caller avoiding offset calculations in the rest of this method Buffer.BlockCopy( dataBuffer, dataOffset + bytesConsumed, // from dataBuffer, dataOffset, // to @@ -374,19 +375,11 @@ out bool recurse remainderPacket = new Packet { Buffer = dataBuffer, - DataLength = packetDataLength, CurrentLength = data.Length }; + remainderPacket.SetCreatedBy(2); createdRemainderPacket = true; - Debug.Assert(remainderPacket.HasHeader); // precondition of entering this block - Debug.Assert(remainderPacket.HasDataLength); // must have been set at construction - if (remainderPacket.CurrentLength >= remainderPacket.RequiredLength) - { - // the remainder packet contains more data than the packet so we need - // to tell the caller to recurse into this function again once they have - // consumed the first packet - recurse = true; - } + recurse = SetupRemainderPacket(remainderPacket); } } else if (data.Length < TdsEnums.HEADER_LEN + packetDataLength) @@ -398,11 +391,13 @@ out bool recurse DataLength = packetDataLength, CurrentLength = data.Length }; + remainderPacket.SetCreatedBy(3); createdRemainderPacket = true; + recurse = SetupRemainderPacket(remainderPacket); } else // implied: current length > required length { - // more data than required so need to split it out but we can't do that + // more data than required so need to split it out, but we can't do that // here so we need to tell the caller to take the remainder packet and then // call this function again if (consumePartialPacket) @@ -415,47 +410,39 @@ out bool recurse Buffer = new byte[dataBuffer.Length], CurrentLength = data.Length }; - + remainderPacket.SetCreatedBy(4); ReadOnlySpan remainderSource = data; Span remainderTarget = remainderPacket.Buffer.AsSpan(0, remainderPacket.CurrentLength); remainderSource.CopyTo(remainderTarget); createdRemainderPacket = true; + + recurse = SetupRemainderPacket(remainderPacket); } else { + newDataLength = TdsEnums.HEADER_LEN + packetDataLength; int remainderLength = data.Length - (TdsEnums.HEADER_LEN + packetDataLength); remainderPacket = new Packet { Buffer = new byte[dataBuffer.Length], - CurrentLength = remainderLength, + CurrentLength = remainderLength }; + remainderPacket.SetCreatedBy(5); ReadOnlySpan remainderSource = data.Slice(TdsEnums.HEADER_LEN + packetDataLength); Span remainderTarget = remainderPacket.Buffer.AsSpan(0, remainderLength); remainderSource.CopyTo(remainderTarget); - #if DEBUG // for debugging purposes fill the removed data area with an easily // recognisable pattern so we can see if it is misused Span removed = dataBuffer.AsSpan(TdsEnums.HEADER_LEN + packetDataLength, remainderLength); removed.Fill(0xFF); #endif - - newDataLength = TdsEnums.HEADER_LEN + packetDataLength; - - consumeInputDirectly = true; createdRemainderPacket = true; + recurse = SetupRemainderPacket(remainderPacket); - if (remainderPacket.HasHeader) - { - remainderPacket.DataLength = Packet.GetDataLengthFromHeader(remainderPacket); - if (remainderPacket.HasDataLength && remainderPacket.CurrentLength >= remainderPacket.RequiredLength) - { - recurse = true; - } - } - + consumeInputDirectly = true; } } } @@ -463,7 +450,7 @@ out bool recurse { // either: // 1) we took some data from the input to reconstruct the partial packet - // 2) there was less than a single packet header of data recieved + // 2) there was less than a single packet header of data received // in both cases we can't tell the caller to directly consume the packet // in the input buffer, we need to construct a new remainder packet with // the incomplete data and let the caller deal with it @@ -472,24 +459,32 @@ out bool recurse Buffer = dataBuffer, CurrentLength = data.Length }; + remainderPacket.SetCreatedBy(6); createdRemainderPacket = true; + recurse = SetupRemainderPacket(remainderPacket); } } -#if DEBUG - //// the Window field is unused by the spec so it can be used as a marker - //// to identify reconstructed packets while debugging - //if (remainderPacket != null && remainderPacket.HasHeader) - //{ - // remainderPacket.Buffer[7] = 0xF; - //} -#endif if (consumePartialPacket && consumeInputDirectly) { - string message = $"MultiplexPackets cannot return both {nameof(consumePartialPacket)} and {nameof(consumeInputDirectly)}"; - System.Diagnostics.Debug.Fail(message); // fail is easier to debug because the exception can be swallowed by higher layers. throw new InvalidOperationException($"MultiplexPackets cannot return both {nameof(consumePartialPacket)} and {nameof(consumeInputDirectly)}"); } } + + private static bool SetupRemainderPacket(Packet packet) + { + Debug.Assert(packet != null); + bool containsFullPacket = false; + if (packet.HasHeader) + { + packet.DataLength = Packet.GetDataLengthFromHeader(packet); + if (packet.HasDataLength && packet.CurrentLength >= packet.RequiredLength) + { + containsFullPacket = true; + } + } + + return containsFullPacket; + } } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 2f9ae4551c..18a12dd884 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -2035,7 +2035,7 @@ internal TdsOperationStatus TryReadNetworkPacket() result = TdsOperationStatus.NeedMoreData; } - if (result == TdsOperationStatus.InvalidData && _partialPacket != null && !_partialPacket.IsComplete) + if (result == TdsOperationStatus.InvalidData && PartialPacket != null && !PartialPacket.ContainsCompletePacket) { result = TdsOperationStatus.NeedMoreData; } @@ -2043,11 +2043,10 @@ internal TdsOperationStatus TryReadNetworkPacket() if (_syncOverAsync) { ReadSniSyncOverAsync(); - while (_inBytesRead == 0) { // a partial packet must have taken the packet data so we - // need to read more data to complete the packet but we + // need to read more data to complete the packet, but we // can't return NeedMoreData in sync mode so we have to // spin fetching more data here until we have something // that the caller can read @@ -2087,12 +2086,7 @@ internal void ReadSniSyncOverAsync() } PacketHandle readPacket = default; - bool readFromNetwork = true; - if (_partialPacket != null && _partialPacket.IsComplete) - { - readFromNetwork = false; - } - + bool readFromNetwork = PartialPacketContainsCompletePacket(); uint error; RuntimeHelpers.PrepareConstrainedRegions(); @@ -2130,7 +2124,7 @@ internal void ReadSniSyncOverAsync() Debug.Assert(!IsPacketEmpty(readPacket), "ReadNetworkPacket cannot be null in synchronous operation!"); } - ProcessSniPacket(readPacket, TdsEnums.SNI_SUCCESS, usePartialPacket: !readFromNetwork); + ProcessSniPacket(readPacket, TdsEnums.SNI_SUCCESS); #if DEBUG if (s_forcePendingReadsToWaitForUser) { diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj index 5c8697cef8..0e79129699 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.Tests.csproj @@ -74,6 +74,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs index 4840a8c889..01f9b3e218 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/MultiplexerTests.cs @@ -3,14 +3,25 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; using System.Text; +using System.Text.RegularExpressions; using Xunit; namespace Microsoft.Data.SqlClient.Tests { public class MultiplexerTests { - public static IEnumerable IsAsync() { yield return new object[] { false }; yield return new object[] { true }; } + [ExcludeFromCodeCoverage] + public static IEnumerable IsAsync() + { + yield return new object[] { false }; + yield return new object[] { true }; + } + + [ExcludeFromCodeCoverage] + public static IEnumerable OnlyAsync() { yield return new object[] { true }; } [Theory, MemberData(nameof(IsAsync))] public static void PassThroughSinglePacket(bool isAsync) @@ -60,7 +71,7 @@ public static void ReconstructSinglePacket(bool isAsync) { int dataSize = 4; var a = CreatePacket(dataSize, 0xF); - List input = SplitPacket(a, 1); + List input = SplitPacket(a, 6); List expected = new List { a }; Assert.Equal(SumPacketLengths(expected), SumPacketLengths(input)); @@ -93,12 +104,7 @@ public static void Reconstruct2Packets_Part_PartFull(bool isAsync) public static void Reconstruct2Packets_Full_FullPart_Part(bool isAsync) { int dataSize = 30; - var expected = new List - { - CreatePacket(30, 5), - CreatePacket(10, 6), - CreatePacket(30, 7) - }; + var expected = new List { CreatePacket(30, 5), CreatePacket(10, 6), CreatePacket(30, 7) }; var input = SplitPackets(38, expected, (8 + 30), // full @@ -157,12 +163,7 @@ public static void Reconstruct3Packets_PartPartPart(bool isAsync) { int dataSize = 62; - var expected = new List - { - CreatePacket(26, 5), - CreatePacket(10, 6), - CreatePacket(10, 7) - }; + var expected = new List { CreatePacket(26, 5), CreatePacket(10, 6), CreatePacket(10, 7) }; var input = SplitPackets(70, expected, (8 + 26) + (8 + 10) + (8 + 10) // = 70: full, full, part @@ -178,12 +179,7 @@ public static void TrailingPartialPacketInSnapshotNotDuplicated() { int dataSize = 120; - var expected = new List - { - CreatePacket(120, 5), - CreatePacket(90, 6), - CreatePacket(13, 7), - }; + var expected = new List { CreatePacket(120, 5), CreatePacket(90, 6), CreatePacket(13, 7), }; var input = SplitPackets(120, expected, (8 + 120), @@ -205,26 +201,43 @@ public static void BetweenAsyncAttentionPacket() var attentionPacket = CreatePacket(13, 6); var input = new List { normalPacket, attentionPacket }; - var stateObject = new TdsParserStateObject(input, TdsEnums.HEADER_LEN + dataSize, true); + var stateObject = new TdsParserStateObject(input, TdsEnums.HEADER_LEN + dataSize, isAsync: true); for (int index = 0; index < input.Count; index++) { stateObject.Current = input[index]; - stateObject.ProcessSniPacket(default, 0, usePartialPacket: false); + stateObject.ProcessSniPacket(default, 0); } - // attention packet should be in the current buffer because the snapshot is not active Assert.NotNull(stateObject._inBuff); Assert.Equal(21, stateObject._inBytesRead); Assert.Equal(0, stateObject._inBytesUsed); - - // attention packet should be in the snapshot as well Assert.NotNull(stateObject._snapshot); Assert.NotNull(stateObject._snapshot.List); - Assert.Equal(2, stateObject._snapshot.List.Count); + Assert.Equal(2, stateObject._snapshot.List.Count); + } + [Fact] + public static void MultipleFullPacketsInRemainderAreSplitCorrectly() + { + int dataSize = 800 - TdsEnums.HEADER_LEN; + List expected = new List + { + CreatePacket(dataSize, 5), CreatePacket(80, 6), CreatePacket(21, 7) + }; + + + List input = SplitPacket(CombinePackets(expected), 700); + + var stateObject = new TdsParserStateObject(input, dataSize, isAsync: false); + + var output = MultiplexPacketList(false, dataSize, input); + + ComparePacketLists(dataSize, expected, output); + } + [ExcludeFromCodeCoverage] private static List MultiplexPacketList(bool isAsync, int dataSize, List input) { var stateObject = new TdsParserStateObject(input, TdsEnums.HEADER_LEN + dataSize, isAsync); @@ -234,47 +247,54 @@ private static List MultiplexPacketList(bool isAsync, int dataSize, { stateObject.Current = input[index]; - stateObject.ProcessSniPacket(default, 0, usePartialPacket: false); + stateObject.ProcessSniPacket(default, 0); if (stateObject._inBytesRead > 0) { if ( stateObject._inBytesRead < TdsEnums.HEADER_LEN || - stateObject._inBytesRead != (TdsEnums.HEADER_LEN + Packet.GetDataLengthFromHeader(stateObject._inBuff.AsSpan(0, TdsEnums.HEADER_LEN))) + stateObject._inBytesRead != (TdsEnums.HEADER_LEN + + Packet.GetDataLengthFromHeader( + stateObject._inBuff.AsSpan(0, TdsEnums.HEADER_LEN))) ) { Assert.Fail("incomplete packet exposed after call to ProcessSniPacket"); } + if (!isAsync) { - output.Add(PacketData.Copy(stateObject._inBuff, stateObject._inBytesUsed, stateObject._inBytesRead)); + output.Add(PacketData.Copy(stateObject._inBuff, stateObject._inBytesUsed, + stateObject._inBytesRead)); } } } + if (!isAsync) { - while (stateObject._partialPacket != null) + while (stateObject.PartialPacket != null) { stateObject.Current = default; - stateObject.ProcessSniPacket(default, 0, usePartialPacket: true); + stateObject.ProcessSniPacket(default, 0); if (stateObject._inBytesRead > 0) { if ( stateObject._inBytesRead < TdsEnums.HEADER_LEN || - stateObject._inBytesRead != (TdsEnums.HEADER_LEN + Packet.GetDataLengthFromHeader(stateObject._inBuff.AsSpan(0, TdsEnums.HEADER_LEN))) + stateObject._inBytesRead != (TdsEnums.HEADER_LEN + + Packet.GetDataLengthFromHeader( + stateObject._inBuff.AsSpan(0, TdsEnums.HEADER_LEN))) ) { - Assert.Fail("incomplete packet exposed after call to ProcessSniPacket with usePartialPacket"); - } - if (!isAsync) - { - output.Add(PacketData.Copy(stateObject._inBuff, stateObject._inBytesUsed, stateObject._inBytesRead)); + Assert.Fail( + "incomplete packet exposed after call to ProcessSniPacket with usePartialPacket"); } + + output.Add(PacketData.Copy(stateObject._inBuff, stateObject._inBytesUsed, + stateObject._inBytesRead)); } } @@ -287,6 +307,7 @@ private static List MultiplexPacketList(bool isAsync, int dataSize, return output; } + [ExcludeFromCodeCoverage] private static void ComparePacketLists(int dataSize, List expected, List output) { Assert.NotNull(expected); @@ -307,6 +328,7 @@ private static void ComparePacketLists(int dataSize, List expected, } } + [ExcludeFromCodeCoverage] public static PacketData CreatePacket(int dataSize, byte dataValue, int startOffset = 0, int endPadding = 0) { byte[] buffer = new byte[startOffset + TdsEnums.HEADER_LEN + dataSize + endPadding]; @@ -315,6 +337,7 @@ public static PacketData CreatePacket(int dataSize, byte dataValue, int startOff return new PacketData(buffer, startOffset, buffer.Length - endPadding); } + [ExcludeFromCodeCoverage] public static List CreatePackets(DataSize sizes, params byte[] dataValues) { int count = dataValues.Length; @@ -332,12 +355,14 @@ public static List CreatePackets(DataSize sizes, params byte[] dataV return list; } + [ExcludeFromCodeCoverage] private static void WritePacket(Span buffer, int dataSize, byte dataValue, byte id) { Span header = buffer.Slice(0, TdsEnums.HEADER_LEN); header[0] = 4; // Type, 4 - Raw Data header[1] = 0; // Status, 0 - normal message - BinaryPrimitives.TryWriteInt16BigEndian(header.Slice(TdsEnums.HEADER_LEN_FIELD_OFFSET, 2), (short)(TdsEnums.HEADER_LEN + dataSize)); // total length + BinaryPrimitives.TryWriteInt16BigEndian(header.Slice(TdsEnums.HEADER_LEN_FIELD_OFFSET, 2), + (short)(TdsEnums.HEADER_LEN + dataSize)); // total length BinaryPrimitives.TryWriteInt16BigEndian(header.Slice(TdsEnums.SPID_OFFSET, 2), short.MaxValue); // SPID header[TdsEnums.HEADER_LEN_FIELD_OFFSET + 4] = id; // PacketID header[TdsEnums.HEADER_LEN_FIELD_OFFSET + 5] = 0; // Window @@ -346,6 +371,7 @@ private static void WritePacket(Span buffer, int dataSize, byte dataValue, data.Fill(dataValue); } + [ExcludeFromCodeCoverage] public static List SplitPacket(PacketData packet, int length) { List list = new List(2); @@ -354,13 +380,16 @@ public static List SplitPacket(PacketData packet, int length) list.Add(new PacketData(packet.Array, packet.Start, length)); packet = new PacketData(packet.Array, packet.Start + length, packet.Length - length); } + if (packet.Length > 0) { list.Add(packet); } + return list; } + [ExcludeFromCodeCoverage] public static List SplitPackets(int dataSize, List packets, params int[] lengths) { List list = new List(lengths.Length); @@ -370,8 +399,10 @@ public static List SplitPackets(int dataSize, List packe { if (lengths[index] > packetSize) { - throw new ArgumentOutOfRangeException($"segment size of an individual part cannot exceed the packet buffer size of the state object, max packet size: {packetSize}, supplied length: {lengths[index]}, at index: {index}"); + throw new ArgumentOutOfRangeException( + $"segment size of an individual part cannot exceed the packet buffer size of the state object, max packet size: {packetSize}, supplied length: {lengths[index]}, at index: {index}"); } + arrays[index] = new byte[lengths[index]]; } @@ -381,6 +412,7 @@ public static List SplitPackets(int dataSize, List packe int sourceOffset = 0; int sourceIndex = 0; + do { Span targetSpan = Span.Empty; @@ -424,10 +456,29 @@ public static List SplitPackets(int dataSize, List packe return list; } + [ExcludeFromCodeCoverage] + public static PacketData CombinePackets(List packets) + { + int totalLength = SumPacketLengths(packets); + byte[] buffer = new byte[totalLength]; + int offset = 0; + for (int index = 0; index < packets.Count; index++) + { + PacketData packet = packets[index]; + Array.Copy(packet.Array, packet.Start, buffer, offset, packet.Length); + offset += packet.Length; + } + + return new PacketData(buffer, 0, totalLength); + } + + [ExcludeFromCodeCoverage] public static int PacketSizeFromDataSize(int dataSize) => TdsEnums.HEADER_LEN + dataSize; + [ExcludeFromCodeCoverage] public static int DataSizeFromPacketSize(int packetSize) => packetSize - TdsEnums.HEADER_LEN; + [ExcludeFromCodeCoverage] public static int SumPacketLengths(List list) { int total = 0; @@ -437,9 +488,107 @@ public static int SumPacketLengths(List list) } return total; } - } + [ExcludeFromCodeCoverage] + public static List LoadPacketBinFiles(string directoryName) + { + // expects a set of files contained in a directory with the name + // formatted as packet_{number}_{dataSize}.bin each packet will be + // loaded into a byte[] + + string[] files = Directory.GetFiles(directoryName, "packet*.bin", SearchOption.TopDirectoryOnly); + SortedDictionary packets = new SortedDictionary(); + foreach (string file in files) + { + Match match = Regex.Match(file, @"packet_(?\d+)_(?\d+)\.bin"); + int number = int.Parse(match.Groups["number"].Value); + int size = int.Parse(match.Groups["size"].Value); + packets.Add( + number, + new PacketData( + System.IO.File.ReadAllBytes(file), + 0, + size + ) + ); + } + + return packets.Values.ToList(); + } + + [ExcludeFromCodeCoverage] + public static List NaiveReconstructPacketStream(List input) + { + int dataSize = input[0].Array.Length; + List output = new List(input.Count); + + byte[] currentBuffer = new byte[dataSize]; + int currentBufferOffset = 0; + + foreach (PacketData inputPacket in input) + { + int inputPacketOffset = 0; + while (inputPacketOffset < inputPacket.Length) + { + if (currentBufferOffset < dataSize) + { + int requiredCount = dataSize - currentBufferOffset; + int availableCount = inputPacket.Length - inputPacketOffset; + int copyCount = Math.Min(requiredCount, availableCount); + ReadOnlySpan copyFrom = inputPacket.Array.AsSpan(inputPacketOffset, copyCount); + Span copyTo = currentBuffer.AsSpan(currentBufferOffset, copyCount); + copyFrom.CopyTo(copyTo); + currentBufferOffset += copyCount; + inputPacketOffset += copyCount; + } + + if (currentBufferOffset == dataSize) + { + output.Add(new PacketData(currentBuffer, 0, dataSize)); + currentBufferOffset = 0; + currentBuffer = new byte[dataSize]; + } + } + } + + if (currentBufferOffset > 0) + { + output.Add(new PacketData(currentBuffer, 0, currentBufferOffset)); + } + + for (int index = 0; index < output.Count; index++) + { + PacketData packet = output[index]; + int expectedLength = 8 + Packet.GetDataLengthFromHeader(packet.Array); + if (expectedLength != packet.Length) + { + if (index != output.Count - 1) + { + throw new InvalidOperationException( + "non-terminal packet has a length mismatch between the packet header and amount of data available"); + } + else + { + byte[] remainder = new byte[dataSize]; + int remainderSize = packet.Length - expectedLength; + Span copyFrom = packet.Array.AsSpan(expectedLength, remainderSize); + Span copyTo = remainder.AsSpan(0, remainderSize); + copyFrom.CopyTo(copyTo); + copyFrom.Fill(0); + + PacketData replacementPacket = new PacketData(packet.Array, 0, expectedLength); + PacketData additionalPacket = new PacketData(remainder, 0, remainderSize); + output[index] = replacementPacket; + output.Add(additionalPacket); + } + } + } + + return output; + } + } + [ExcludeFromCodeCoverage] [DebuggerDisplay("{ToDebugString(),nq}")] public readonly struct PacketData { @@ -473,6 +622,7 @@ public static PacketData Copy(byte[] array, int start, int length) newArray = new byte[array.Length]; Buffer.BlockCopy(array, start, newArray, start, length); } + return new PacketData(newArray, start, length); } @@ -488,6 +638,7 @@ public string ToDebugString() { buffer.AppendFormat(" (arr: {0})", Array.Length); } + buffer.Append(": {"); buffer.AppendFormat("{0:D2}", Array[0]); @@ -497,16 +648,21 @@ public string ToDebugString() buffer.Append(','); buffer.AppendFormat("{0:D2}", Array[index]); } + if (Length > max) { buffer.Append(" ..."); } + buffer.Append('}'); } + return buffer.ToString(); } + } + [ExcludeFromCodeCoverage] [DebuggerStepThrough] public struct DataSize { @@ -515,6 +671,7 @@ public DataSize(int commonSize) CommonSize = commonSize; LastSize = commonSize; } + public DataSize(int commonSize, int lastSize) { CommonSize = commonSize; @@ -538,12 +695,12 @@ public int GetSize(bool isLast) public static implicit operator DataSize(int commonSize) { - return new DataSize( commonSize, commonSize ); + return new DataSize(commonSize, commonSize); } public static implicit operator DataSize((int commonSize, int lastSize) values) { - return new DataSize( values.commonSize, values.lastSize ); + return new DataSize(values.commonSize, values.lastSize); } } } diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs index 271502c4da..ac5c093b35 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TdsParserStateObject.TestHarness.cs @@ -159,118 +159,4 @@ internal enum ParsingErrorState CorruptedTdsStream = 18, ProcessSniPacketFailed = 19, } - - internal sealed class Packet - { - public const int UnknownDataLength = -1; - - private bool _disposed; - private int _dataLength; - private int _totalLength; - private byte[] _buffer; - - public Packet() - { - _disposed = false; - _dataLength = UnknownDataLength; - } - - public int DataLength - { - get - { - CheckDisposed(); - return _dataLength; - } - set - { - CheckDisposed(); - _dataLength = value; - } - } - public byte[] Buffer - { - get - { - CheckDisposed(); - return _buffer; - } - set - { - CheckDisposed(); - _buffer = value; - } - } - public int CurrentLength - { - get - { - CheckDisposed(); - return _totalLength; - } - set - { - CheckDisposed(); - _totalLength = value; - } - } - - public int RequiredLength - { - get - { - CheckDisposed(); - if (!HasDataLength) - { - throw new InvalidOperationException($"cannot get {nameof(RequiredLength)} when {nameof(HasDataLength)} is false"); - } - return TdsEnums.HEADER_LEN + _dataLength; - } - } - - public bool HasHeader => _totalLength >= TdsEnums.HEADER_LEN; - - public bool HasDataLength => _dataLength >= 0; - - public bool IsComplete => _dataLength != UnknownDataLength && (TdsEnums.HEADER_LEN + _dataLength) == _totalLength; - - public ReadOnlySpan GetHeaderSpan() => _buffer.AsSpan(0, TdsEnums.HEADER_LEN); - - public void Dispose() - { - _disposed = true; - } - - public void CheckDisposed() - { - if (_disposed) - { - ThrowDisposed(); - } - } - - public static void ThrowDisposed() - { - throw new ObjectDisposedException(nameof(Packet)); - } - - internal static byte GetStatusFromHeader(ReadOnlySpan header) => header[1]; - - internal static int GetDataLengthFromHeader(ReadOnlySpan header) - { - return (header[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8 | header[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN; - } - internal static int GetSpidFromHeader(ReadOnlySpan header) - { - return (header[TdsEnums.SPID_OFFSET] << 8 | header[TdsEnums.SPID_OFFSET + 1]); - } - internal static int GetIDFromHeader(ReadOnlySpan header) - { - return header[TdsEnums.HEADER_LEN_FIELD_OFFSET + 4]; - } - - internal static int GetDataLengthFromHeader(Packet packet) => GetDataLengthFromHeader(packet.GetHeaderSpan()); - - internal static bool GetIsEOMFromHeader(ReadOnlySpan header) => GetStatusFromHeader(header) == 1; - } } From 2e924313b539f9c0df94560bd8dd4c6f1e9f956d Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Tue, 29 Oct 2024 02:36:51 +0000 Subject: [PATCH 13/17] fix rebase conflicts --- .../SqlClient/TdsParserStateObject.netcore.cs | 149 ----------------- .../SqlClient/TdsParserStateObject.netfx.cs | 150 +----------------- 2 files changed, 1 insertion(+), 298 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs index 07f10619fc..b7886fe6d9 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs @@ -396,155 +396,6 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) private uint GetSniPacket(PacketHandle packet, ref uint dataSize) { return SNIPacketGetData(packet, _inBuff, ref dataSize); - if (error != 0) - { - if ((_parser.State == TdsParserState.Closed) || (_parser.State == TdsParserState.Broken)) - { - // Do nothing with callback if closed or broken and error not 0 - callback can occur - // after connection has been closed. PROBLEM IN NETLIB - DESIGN FLAW. - return; - } - - AddError(_parser.ProcessSNIError(this)); - AssertValidState(); - } - else - { - uint dataSize = 0; - bool usedPartialPacket = false; - uint getDataError = 0; - - if (usePartialPacket && _snapshot == null && _partialPacket != null && _partialPacket.IsComplete) - { - //Debug.Assert(_snapshot == null, "_snapshot must be null when processing partial packet instead of network read"); - //Debug.Assert(_partialPacket != null, "_partialPacket must not be null when usePartialPacket is true"); - //Debug.Assert(_partialPacket.IsComplete, "_partialPacket.IsComplete must be true to use it in place of a real read"); - SetBuffer(_partialPacket.Buffer, 0, _partialPacket.CurrentLength); - ClearPartialPacket(); - getDataError = TdsEnums.SNI_SUCCESS; - usedPartialPacket = true; - } - else - { - getDataError = SNIPacketGetData(packet, _inBuff, ref dataSize); - } - - if (getDataError == TdsEnums.SNI_SUCCESS) - { - if (_inBuff.Length < dataSize) - { - Debug.Assert(true, "Unexpected dataSize on Read"); - throw SQL.InvalidInternalPacketSize(StringsHelper.GetString(Strings.SqlMisc_InvalidArraySizeMessage)); - } - - if (!usedPartialPacket) - { - _lastSuccessfulIOTimer._value = DateTime.UtcNow.Ticks; - - SetBuffer(_inBuff, 0, (int)dataSize); - } - - bool recurse; - bool appended = false; - do - { - MultiplexPackets( - _inBuff, _inBytesUsed, _inBytesRead, - _partialPacket, - out int newDataOffset, - out int newDataLength, - out Packet remainderPacket, - out bool consumeInputDirectly, - out bool consumePartialPacket, - out bool remainderPacketProduced, - out recurse - ); - bool bufferIsPartialCompleted = false; - - // if a partial packet was reconstructed it must be handled first - if (consumePartialPacket) - { - if (_snapshot != null) - { - _snapshot.AppendPacketData(_partialPacket.Buffer, _partialPacket.CurrentLength); - appended = true; - } - else - { - SetBuffer(_partialPacket.Buffer, 0, _partialPacket.CurrentLength); - bufferIsPartialCompleted = true; - } - ClearPartialPacket(); - } - - // if the remaining data can be processed directly it must be second - if (consumeInputDirectly) - { - // if some data was taken from the new packet adjust the counters - if (dataSize != newDataLength || 0 != newDataOffset) - { - SetBuffer(_inBuff, newDataOffset, newDataLength); - } - - if (_snapshot != null) - { - _snapshot.AppendPacketData(_inBuff, _inBytesRead); - appended = true; - } - else - { - SetBuffer(_inBuff, 0, _inBytesRead); - } - } - else - { - // whatever is in the input buffer should not be directly consumed - // and is contained in the partial or remainder packets so make sure - // we don't process it - if (!bufferIsPartialCompleted) - { - SetBuffer(_inBuff, 0, 0); - } - } - - // if there is a remainder it must be last - if (remainderPacketProduced) - { - SetPartialPacket(remainderPacket); - if (!bufferIsPartialCompleted) - { - // we are keeping the partial packet buffer so replace it with a new one - // unless we have already set the buffer to the partial packet buffer - SetBuffer(new byte[_inBuff.Length], 0, 0); - } - } - - } while (recurse && _snapshot != null); - - if (_snapshot != null) - { - if (_snapshotStatus != SnapshotStatus.NotActive && appended) - { - _snapshot.MoveNext(); -#if DEBUG - // multiple packets can be appended by demuxing but we should only move - // forward by a single packet so we can no longer assert that we are on - // the last packet at this time - //_snapshot.AssertCurrent(); -#endif - } - } - - SniReadStatisticsAndTracing(); - SqlClientEventSource.Log.TryAdvancedTraceBinEvent("TdsParser.ReadNetworkPacketAsyncCallback | INFO | ADV | State Object Id {0}, Packet read. In Buffer: {1}, In Bytes Read: {2}", ObjectID, _inBuff, _inBytesRead); - - AssertValidState(); - } - else - { - throw SQL.ParsingError(ParsingErrorState.ProcessSniPacketFailed); - } - } } private void ChangeNetworkPacketTimeout(int dueTime, int period) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs index d7b6db7fc0..1316f700d1 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs @@ -526,155 +526,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) private uint GetSniPacket(PacketHandle packet, ref uint dataSize) { - if (error != 0) - { - if ((_parser.State == TdsParserState.Closed) || (_parser.State == TdsParserState.Broken)) - { - // Do nothing with callback if closed or broken and error not 0 - callback can occur - // after connection has been closed. PROBLEM IN NETLIB - DESIGN FLAW. - return; - } - - AddError(_parser.ProcessSNIError(this)); - AssertValidState(); - } - else - { - uint dataSize = 0; - bool usedPartialPacket = false; - uint getDataError = 0; - - if (usePartialPacket && _snapshot == null && _partialPacket != null && _partialPacket.IsComplete) - { - //Debug.Assert(_snapshot == null, "_snapshot must be null when processing partial packet instead of network read"); - //Debug.Assert(_partialPacket != null, "_partialPacket must not be null when usePartialPacket is true"); - //Debug.Assert(_partialPacket.IsComplete, "_partialPacket.IsComplete must be true to use it in place of a real read"); - SetBuffer(_partialPacket.Buffer, 0, _partialPacket.CurrentLength); - ClearPartialPacket(); - getDataError = TdsEnums.SNI_SUCCESS; - usedPartialPacket = true; - } - else - { - getDataError = SNINativeMethodWrapper.SNIPacketGetData(packet, _inBuff, ref dataSize); - } - - if (getDataError == TdsEnums.SNI_SUCCESS) - { - if (_inBuff.Length < dataSize) - { - Debug.Assert(true, "Unexpected dataSize on Read"); - throw SQL.InvalidInternalPacketSize(StringsHelper.GetString(Strings.SqlMisc_InvalidArraySizeMessage)); - } - - if (!usedPartialPacket) - { - _lastSuccessfulIOTimer._value = DateTime.UtcNow.Ticks; - - SetBuffer(_inBuff, 0, (int)dataSize); - } - - bool recurse; - bool appended = false; - do - { - MultiplexPackets( - _inBuff, _inBytesUsed, _inBytesRead, - _partialPacket, - out int newDataOffset, - out int newDataLength, - out Packet remainderPacket, - out bool consumeInputDirectly, - out bool consumePartialPacket, - out bool remainderPacketProduced, - out recurse - ); - bool bufferIsPartialCompleted = false; - - // if a partial packet was reconstructed it must be handled first - if (consumePartialPacket) - { - if (_snapshot != null) - { - _snapshot.AppendPacketData(_partialPacket.Buffer, _partialPacket.CurrentLength); - appended = true; - } - else - { - SetBuffer(_partialPacket.Buffer, 0, _partialPacket.CurrentLength); - bufferIsPartialCompleted = true; - } - ClearPartialPacket(); - } - - // if the remaining data can be processed directly it must be second - if (consumeInputDirectly) - { - // if some data was taken from the new packet adjust the counters - if (dataSize != newDataLength || 0 != newDataOffset) - { - SetBuffer(_inBuff, newDataOffset, newDataLength); - } - - if (_snapshot != null) - { - _snapshot.AppendPacketData(_inBuff, _inBytesRead); - appended = true; - } - else - { - SetBuffer(_inBuff, 0, _inBytesRead); - } - } - else - { - // whatever is in the input buffer should not be directly consumed - // and is contained in the partial or remainder packets so make sure - // we don't process it - if (!bufferIsPartialCompleted) - { - SetBuffer(_inBuff, 0, 0); - } - } - - // if there is a remainder it must be last - if (remainderPacketProduced) - { - SetPartialPacket(remainderPacket); - if (!bufferIsPartialCompleted) - { - // we are keeping the partial packet buffer so replace it with a new one - // unless we have already set the buffer to the partial packet buffer - SetBuffer(new byte[_inBuff.Length], 0, 0); - } - } - - } while (recurse && _snapshot != null); - - if (_snapshot != null) - { - if (_snapshotStatus != SnapshotStatus.NotActive && appended) - { - _snapshot.MoveNext(); -#if DEBUG - // multiple packets can be appended by demuxing but we should only move - // forward by a single packet so we can no longer assert that we are on - // the last packet at this time - //_snapshot.AssertCurrent(); -#endif - } - } - - SniReadStatisticsAndTracing(); - SqlClientEventSource.Log.TryAdvancedTraceBinEvent("TdsParser.ReadNetworkPacketAsyncCallback | INFO | ADV | State Object Id {0}, Packet read. In Buffer: {1}, In Bytes Read: {2}", ObjectID, _inBuff, _inBytesRead); - - AssertValidState(); - } - else - { - throw SQL.ParsingError(ParsingErrorState.ProcessSniPacketFailed); - } - } + return SNINativeMethodWrapper.SNIPacketGetData(packet, _inBuff, ref dataSize); } private void ChangeNetworkPacketTimeout(int dueTime, int period) From 69dc7d415d60ab1533ad6e87df565cc60bac2596 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Tue, 29 Oct 2024 09:05:32 +0000 Subject: [PATCH 14/17] review feedback and misc fix --- .../netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs | 2 +- .../src/Microsoft/Data/SqlClient/Packet.cs | 4 +++- .../src/Microsoft/Data/SqlClient/TdsParserStateObject.cs | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index c3923ce497..0937e8a36f 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -3533,7 +3533,7 @@ private TdsOperationStatus TryNextResult(out bool more) if (result != TdsOperationStatus.Done) { more = false; - return TdsOperationStatus.Done; + return result; } // In the case of not closing the reader, null out the metadata AFTER diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs index 802eb2c936..7c97edec57 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Packet.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Diagnostics; namespace Microsoft.Data.SqlClient { @@ -101,7 +102,7 @@ public int RequiredLength } /// - /// returns a boolean value indicating if there are enough total bytes availble in the to read the tds header + /// returns a boolean value indicating if there are enough total bytes available in the to read the tds header /// public bool HasHeader => _totalLength >= TdsEnums.HEADER_LEN; @@ -139,6 +140,7 @@ public void CheckDisposed() } } + [Conditional("DEBUG")] internal void SetCreatedBy(int creator) => SetCreatedByImpl(creator); partial void SetCreatedByImpl(int creator); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 18a12dd884..90ddc02569 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -2086,7 +2086,7 @@ internal void ReadSniSyncOverAsync() } PacketHandle readPacket = default; - bool readFromNetwork = PartialPacketContainsCompletePacket(); + bool readFromNetwork = !PartialPacketContainsCompletePacket(); uint error; RuntimeHelpers.PrepareConstrainedRegions(); From 1f57d73c9e34c3840d6c0964d5b4f54575686450 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Wed, 30 Oct 2024 22:57:58 +0000 Subject: [PATCH 15/17] protect missed locations where network reads can happen --- .../Data/SqlClient/TdsParser.Windows.cs | 4 +++ .../SqlClient/TdsParserStateObject.netcore.cs | 16 +++++++--- .../SqlClient/TdsParserStateObject.netfx.cs | 16 +++++++--- .../Data/SqlClient/TdsParserStateObject.cs | 29 +++++++++++++------ 4 files changed, 48 insertions(+), 17 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.Windows.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.Windows.cs index 7f15666951..ddbb38b9f8 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.Windows.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.Windows.cs @@ -24,6 +24,10 @@ internal void PostReadAsyncForMars() _pMarsPhysicalConObj.IncrementPendingCallbacks(); SessionHandle handle = _pMarsPhysicalConObj.SessionHandle; + // we do not need to consider partial packets when making this read because we + // expect this read to pend. a partial packet should not exist at setup of the + // parser + Debug.Assert(_physicalStateObj.PartialPacket==null); temp = _pMarsPhysicalConObj.ReadAsync(handle, out error); Debug.Assert(temp.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer"); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs index b7886fe6d9..aafcc4b182 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs @@ -322,14 +322,22 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) stateObj.SendAttention(mustTakeWriteLock: true); PacketHandle syncReadPacket = default; + bool readFromNetwork = true; RuntimeHelpers.PrepareConstrainedRegions(); bool shouldDecrement = false; try { Interlocked.Increment(ref _readingCount); shouldDecrement = true; - - syncReadPacket = ReadSyncOverAsync(stateObj.GetTimeoutRemaining(), out error); + readFromNetwork = !PartialPacketContainsCompletePacket(); + if (readFromNetwork) + { + syncReadPacket = ReadSyncOverAsync(stateObj.GetTimeoutRemaining(), out error); + } + else + { + error = TdsEnums.SNI_SUCCESS; + } Interlocked.Decrement(ref _readingCount); shouldDecrement = false; @@ -342,7 +350,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) } else { - Debug.Assert(!IsValidPacket(syncReadPacket), "unexpected syncReadPacket without corresponding SNIPacketRelease"); + Debug.Assert(!readFromNetwork || !IsValidPacket(syncReadPacket), "unexpected syncReadPacket without corresponding SNIPacketRelease"); fail = true; // Subsequent read failed, time to give up. } } @@ -353,7 +361,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) Interlocked.Decrement(ref _readingCount); } - if (!IsPacketEmpty(syncReadPacket)) + if (readFromNetwork && !IsPacketEmpty(syncReadPacket)) { ReleasePacket(syncReadPacket); } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs index 1316f700d1..5f7e87ca61 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs @@ -452,14 +452,22 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) stateObj.SendAttention(mustTakeWriteLock: true); PacketHandle syncReadPacket = default; + bool readFromNetwork = true; RuntimeHelpers.PrepareConstrainedRegions(); bool shouldDecrement = false; try { Interlocked.Increment(ref _readingCount); shouldDecrement = true; - - syncReadPacket = ReadSyncOverAsync(stateObj.GetTimeoutRemaining(), out error); + readFromNetwork = !PartialPacketContainsCompletePacket(); + if (readFromNetwork) + { + syncReadPacket = ReadSyncOverAsync(stateObj.GetTimeoutRemaining(), out error); + } + else + { + error = TdsEnums.SNI_SUCCESS; + } Interlocked.Decrement(ref _readingCount); shouldDecrement = false; @@ -472,7 +480,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) } else { - Debug.Assert(!IsValidPacket(syncReadPacket), "unexpected syncReadPacket without corresponding SNIPacketRelease"); + Debug.Assert(!readFromNetwork || !IsValidPacket(syncReadPacket), "unexpected syncReadPacket without corresponding SNIPacketRelease"); fail = true; // Subsequent read failed, time to give up. } } @@ -483,7 +491,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) Interlocked.Decrement(ref _readingCount); } - if (!IsPacketEmpty(syncReadPacket)) + if (readFromNetwork && !IsPacketEmpty(syncReadPacket)) { // Be sure to release packet, otherwise it will be leaked by native. ReleasePacket(syncReadPacket); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 90ddc02569..cac4abefd5 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -2382,6 +2382,7 @@ internal void ReadSni(TaskCompletionSource completion) PacketHandle readPacket = default; uint error = 0; + bool readFromNetwork = true; RuntimeHelpers.PrepareConstrainedRegions(); try @@ -2427,17 +2428,27 @@ internal void ReadSni(TaskCompletionSource completion) Interlocked.Increment(ref _readingCount); handle = SessionHandle; - if (!handle.IsNull) + + readFromNetwork = !PartialPacketContainsCompletePacket(); + if (readFromNetwork) { - IncrementPendingCallbacks(); + if (!handle.IsNull) + { + IncrementPendingCallbacks(); - readPacket = ReadAsync(handle, out error); + readPacket = ReadAsync(handle, out error); - if (!(TdsEnums.SNI_SUCCESS == error || TdsEnums.SNI_SUCCESS_IO_PENDING == error)) - { - DecrementPendingCallbacks(false); // Failure - we won't receive callback! + if (!(TdsEnums.SNI_SUCCESS == error || TdsEnums.SNI_SUCCESS_IO_PENDING == error)) + { + DecrementPendingCallbacks(false); // Failure - we won't receive callback! + } } } + else + { + readPacket = default; + error = TdsEnums.SNI_SUCCESS; + } Interlocked.Decrement(ref _readingCount); } @@ -2449,12 +2460,12 @@ internal void ReadSni(TaskCompletionSource completion) if (TdsEnums.SNI_SUCCESS == error) { // Success - process results! - Debug.Assert(IsValidPacket(readPacket), "ReadNetworkPacket should not have been null on this async operation!"); + Debug.Assert(!readFromNetwork || IsValidPacket(readPacket) , "ReadNetworkPacket should not have been null on this async operation!"); // Evaluate this condition for MANAGED_SNI. This may not be needed because the network call is happening Async and only the callback can receive a success. ReadAsyncCallback(IntPtr.Zero, readPacket, 0); // Only release packet for Managed SNI as for Native SNI packet is released in finally block. - if (TdsParserStateObjectFactory.UseManagedSNI && !IsPacketEmpty(readPacket)) + if (TdsParserStateObjectFactory.UseManagedSNI && readFromNetwork && !IsPacketEmpty(readPacket)) { ReleasePacket(readPacket); } @@ -2492,7 +2503,7 @@ internal void ReadSni(TaskCompletionSource completion) { if (!TdsParserStateObjectFactory.UseManagedSNI) { - if (!IsPacketEmpty(readPacket)) + if (readFromNetwork && !IsPacketEmpty(readPacket)) { // Be sure to release packet, otherwise it will be leaked by native. ReleasePacket(readPacket); From 00c242602c26c0432c990960a839768e677ca948 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Thu, 31 Oct 2024 14:43:21 +0000 Subject: [PATCH 16/17] fix CheckPacket assertion --- .../Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs index aafcc4b182..694b731bb7 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs @@ -492,7 +492,7 @@ public void ReadAsyncCallback(IntPtr key, PacketHandle packet, uint error) bool processFinallyBlock = true; try { - Debug.Assert(CheckPacket(packet, source) && source != null, "AsyncResult null on callback"); + Debug.Assert((packet.Type == 0 && PartialPacketContainsCompletePacket()) || (CheckPacket(packet, source) && source != null), "AsyncResult null on callback"); if (_parser.MARSOn) { @@ -504,7 +504,7 @@ public void ReadAsyncCallback(IntPtr key, PacketHandle packet, uint error) // The timer thread may be unreliable under high contention scenarios. It cannot be // assumed that the timeout has happened on the timer thread callback. Check the timeout - // synchrnously and then call OnTimeoutSync to force an atomic change of state. + // synchronously and then call OnTimeoutSync to force an atomic change of state. if (TimeoutHasExpired) { OnTimeoutSync(asyncClose: true); From f75993aa89e5cb9a72cc788b2765e42d07b7be12 Mon Sep 17 00:00:00 2001 From: Michel Zehnder Date: Mon, 4 Nov 2024 11:02:10 +0100 Subject: [PATCH 17/17] I've had issues with readingCount and traced it back to this Other decrements also seem to be using a try/finally pattern --- .../Data/SqlClient/TdsParserStateObject.cs | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index cac4abefd5..fb01ee081c 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -2427,30 +2427,35 @@ internal void ReadSni(TaskCompletionSource completion) { Interlocked.Increment(ref _readingCount); - handle = SessionHandle; - - readFromNetwork = !PartialPacketContainsCompletePacket(); - if (readFromNetwork) + try { - if (!handle.IsNull) + handle = SessionHandle; + + readFromNetwork = !PartialPacketContainsCompletePacket(); + if (readFromNetwork) { - IncrementPendingCallbacks(); + if (!handle.IsNull) + { + IncrementPendingCallbacks(); - readPacket = ReadAsync(handle, out error); + readPacket = ReadAsync(handle, out error); - if (!(TdsEnums.SNI_SUCCESS == error || TdsEnums.SNI_SUCCESS_IO_PENDING == error)) - { - DecrementPendingCallbacks(false); // Failure - we won't receive callback! + if (!(TdsEnums.SNI_SUCCESS == error || TdsEnums.SNI_SUCCESS_IO_PENDING == error)) + { + DecrementPendingCallbacks(false); // Failure - we won't receive callback! + } } } + else + { + readPacket = default; + error = TdsEnums.SNI_SUCCESS; + } } - else + finally { - readPacket = default; - error = TdsEnums.SNI_SUCCESS; + Interlocked.Decrement(ref _readingCount); } - - Interlocked.Decrement(ref _readingCount); } if (handle.IsNull)