diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 926ebf3580..c9de00df6d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -928,6 +928,9 @@ Interop\Windows\Sni\SniNativeWrapper.cs + + Interop\Windows\Sni\SniSslProtocols.cs + Interop\Windows\Sni\TransparentNetworkResolutionMode.cs @@ -973,12 +976,13 @@ Microsoft\Data\SqlClient\TdsParserStateObjectFactory.Windows.cs + + Microsoft\Data\SqlClient\TdsParserStateObjectNative.Windows.cs + Microsoft\Data\SqlTypes\SqlFileStream.Windows.cs - - ILLink.Substitutions.xml Resources\ILLink.Substitutions.Windows.xml diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs deleted file mode 100644 index a553b43dde..0000000000 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ /dev/null @@ -1,554 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Net; -using System.Runtime.InteropServices; -using System.Security.Authentication; -using System.Threading.Tasks; -using Interop.Windows.Sni; -using Microsoft.Data.Common; -using Microsoft.Data.ProviderBase; - -namespace Microsoft.Data.SqlClient -{ - internal class TdsParserStateObjectNative : TdsParserStateObject - { - // protocol versions from native sni - [Flags] - private enum NativeProtocols - { - SP_PROT_SSL2_SERVER = 0x00000004, - SP_PROT_SSL2_CLIENT = 0x00000008, - SP_PROT_SSL3_SERVER = 0x00000010, - SP_PROT_SSL3_CLIENT = 0x00000020, - SP_PROT_TLS1_0_SERVER = 0x00000040, - SP_PROT_TLS1_0_CLIENT = 0x00000080, - SP_PROT_TLS1_1_SERVER = 0x00000100, - SP_PROT_TLS1_1_CLIENT = 0x00000200, - SP_PROT_TLS1_2_SERVER = 0x00000400, - SP_PROT_TLS1_2_CLIENT = 0x00000800, - SP_PROT_TLS1_3_SERVER = 0x00001000, - SP_PROT_TLS1_3_CLIENT = 0x00002000, - SP_PROT_NONE = 0x0 - } - - private SNIHandle _sessionHandle = null; // the SNI handle we're to work on - - private SNIPacket _sniPacket = null; // Will have to re-vamp this for MARS - internal SNIPacket _sniAsyncAttnPacket = null; // Packet to use to send Attn - private readonly WritePacketCache _writePacketCache = new WritePacketCache(); // Store write packets that are ready to be re-used - - private GCHandle _gcHandle; // keeps this object alive until we're closed. - - private readonly Dictionary _pendingWritePackets = new Dictionary(); // Stores write packets that have been sent to SNI, but have not yet finished writing (i.e. we are waiting for SNI's callback) - - internal TdsParserStateObjectNative(TdsParser parser, TdsParserStateObject physicalConnection, bool async) - : base(parser, physicalConnection, async) - { - } - - internal TdsParserStateObjectNative(TdsParser parser) - : base(parser) - { - } - - #region Properties - - internal SNIHandle Handle => _sessionHandle; - - internal override uint Status => _sessionHandle != null ? _sessionHandle.Status : TdsEnums.SNI_UNINITIALIZED; - - internal override SessionHandle SessionHandle => SessionHandle.FromNativeHandle(_sessionHandle); - - protected override PacketHandle EmptyReadPacket => PacketHandle.FromNativePointer(default); - - internal override Guid? SessionId => default; - - #endregion - - protected override void CreateSessionHandle(TdsParserStateObject physicalConnection, bool async) - { - Debug.Assert(physicalConnection is TdsParserStateObjectNative, "Expected a stateObject of type " + this.GetType()); - TdsParserStateObjectNative nativeSNIObject = physicalConnection as TdsParserStateObjectNative; - ConsumerInfo myInfo = CreateConsumerInfo(async); - - SQLDNSInfo cachedDNSInfo; - bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(_parser.FQDNforDNSCache, out cachedDNSInfo); - - _sessionHandle = new SNIHandle(myInfo, nativeSNIObject.Handle, _parser.Connection.ConnectionOptions.IPAddressPreference, cachedDNSInfo); - } - - // Retrieve the IP and port number from native SNI for TCP protocol. The IP information is stored temporarily in the - // pendingSQLDNSObject but not in the DNS Cache at this point. We only add items to the DNS Cache after we receive the - // IsSupported flag as true in the feature ext ack from server. - internal override void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey, ref SQLDNSInfo pendingDNSInfo) - { - uint result; - ushort portFromSNI = 0; - string IPStringFromSNI = string.Empty; - IPAddress IPFromSNI; - _parser.isTcpProtocol = false; - Provider providerNumber = Provider.INVALID_PROV; - - if (string.IsNullOrEmpty(userProtocol)) - { - - result = SniNativeWrapper.SniGetProviderNumber(Handle, ref providerNumber); - Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetProviderNumber"); - _parser.isTcpProtocol = (providerNumber == Provider.TCP_PROV); - } - else if (userProtocol == TdsEnums.TCP) - { - _parser.isTcpProtocol = true; - } - - // serverInfo.UserProtocol could be empty - if (_parser.isTcpProtocol) - { - result = SniNativeWrapper.SniGetConnectionPort(Handle, ref portFromSNI); - Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionPort"); - - result = SniNativeWrapper.SniGetConnectionIpString(Handle, ref IPStringFromSNI); - Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionIPString"); - - pendingDNSInfo = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI.ToString()); - - if (IPAddress.TryParse(IPStringFromSNI, out IPFromSNI)) - { - if (System.Net.Sockets.AddressFamily.InterNetwork == IPFromSNI.AddressFamily) - { - pendingDNSInfo.AddrIPv4 = IPStringFromSNI; - } - else if (System.Net.Sockets.AddressFamily.InterNetworkV6 == IPFromSNI.AddressFamily) - { - pendingDNSInfo.AddrIPv6 = IPStringFromSNI; - } - } - } - else - { - pendingDNSInfo = null; - } - } - - private ConsumerInfo CreateConsumerInfo(bool async) - { - ConsumerInfo myInfo = new ConsumerInfo(); - - Debug.Assert(_outBuff.Length == _inBuff.Length, "Unexpected unequal buffers."); - - myInfo.defaultBufferSize = _outBuff.Length; // Obtain packet size from outBuff size. - - if (async) - { - myInfo.readDelegate = SNILoadHandle.SingletonInstance.ReadAsyncCallbackDispatcher; - myInfo.writeDelegate = SNILoadHandle.SingletonInstance.WriteAsyncCallbackDispatcher; - _gcHandle = GCHandle.Alloc(this, GCHandleType.Normal); - myInfo.key = (IntPtr)_gcHandle; - } - return myInfo; - } - - internal override void CreatePhysicalSNIHandle( - string serverName, - TimeoutTimer timeout, - out byte[] instanceName, - out ManagedSni.ResolvedServerSpn resolvedSpn, - bool flushCache, - bool async, - bool fParallel, - TransparentNetworkResolutionState transparentNetworkResolutionState, - int totalTimeout, - SqlConnectionIPAddressPreference iPAddressPreference, - string cachedFQDN, - ref SQLDNSInfo pendingDNSInfo, - string serverSPN, - bool isIntegratedSecurity, - bool tlsFirst, - string hostNameInCertificate, - string serverCertificateFilename) - { - if (isIntegratedSecurity) - { - // now allocate proper length of buffer - if (!string.IsNullOrEmpty(serverSPN)) - { - // Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code. - SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.", nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN); - } - else - { - // This will signal to the interop layer that we need to retrieve the SPN - serverSPN = string.Empty; - } - } - - ConsumerInfo myInfo = CreateConsumerInfo(async); - SQLDNSInfo cachedDNSInfo; - bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo); - - _sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName, - flushCache, !async, fParallel, iPAddressPreference, cachedDNSInfo, hostNameInCertificate); - resolvedSpn = new(serverSPN.TrimEnd()); - } - - protected override uint SniPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize) - { - Debug.Assert(packet.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer"); - return SniNativeWrapper.SniPacketGetData(packet.NativePointer, _inBuff, ref dataSize); - } - - protected override bool CheckPacket(PacketHandle packet, TaskCompletionSource source) - { - Debug.Assert(packet.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer"); - IntPtr ptr = packet.NativePointer; - return IntPtr.Zero == ptr || IntPtr.Zero != ptr && source != null; - } - - public void ReadAsyncCallback(IntPtr key, IntPtr packet, uint error) => ReadAsyncCallback(key, packet, error); - - public void WriteAsyncCallback(IntPtr key, IntPtr packet, uint sniError) => WriteAsyncCallback(key, packet, sniError); - - protected override void RemovePacketFromPendingList(PacketHandle ptr) - { - Debug.Assert(ptr.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer"); - IntPtr pointer = ptr.NativePointer; - - lock (_writePacketLockObject) - { - if (_pendingWritePackets.TryGetValue(pointer, out SNIPacket recoveredPacket)) - { - _pendingWritePackets.Remove(pointer); - _writePacketCache.Add(recoveredPacket); - } -#if DEBUG - else - { - Debug.Fail("Removing a packet from the pending list that was never added to it"); - } -#endif - } - } - - internal override void Dispose() - { - SafeHandle packetHandle = _sniPacket; - SafeHandle sessionHandle = _sessionHandle; - SafeHandle asyncAttnPacket = _sniAsyncAttnPacket; - - _sniPacket = null; - _sessionHandle = null; - _sniAsyncAttnPacket = null; - - DisposeCounters(); - - if (sessionHandle != null || packetHandle != null) - { - packetHandle?.Dispose(); - asyncAttnPacket?.Dispose(); - - if (sessionHandle != null) - { - sessionHandle.Dispose(); - DecrementPendingCallbacks(true); // Will dispose of GC handle. - } - } - - DisposePacketCache(); - } - - protected override void FreeGcHandle(int remaining, bool release) - { - if ((0 == remaining || release) && _gcHandle.IsAllocated) - { - _gcHandle.Free(); - } - } - - internal override bool IsFailedHandle() => _sessionHandle.Status != TdsEnums.SNI_SUCCESS; - - internal override bool IsPacketEmpty(PacketHandle readPacket) - { - Debug.Assert(readPacket.Type == PacketHandle.NativePointerType || readPacket.Type == 0, "unexpected packet type when requiring NativePointer"); - return IntPtr.Zero == readPacket.NativePointer; - } - - internal override void ReleasePacket(PacketHandle syncReadPacket) - { - Debug.Assert(syncReadPacket.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer"); - SniNativeWrapper.SniPacketRelease(syncReadPacket.NativePointer); - } - - internal override uint CheckConnection() - { - SNIHandle handle = Handle; - return handle == null ? TdsEnums.SNI_SUCCESS : SniNativeWrapper.SniCheckConnection(handle); - } - - internal override PacketHandle ReadAsync(SessionHandle handle, out uint error) - { - Debug.Assert(handle.Type == SessionHandle.NativeHandleType, "unexpected handle type when requiring NativePointer"); - IntPtr readPacketPtr = IntPtr.Zero; - error = SniNativeWrapper.SniReadAsync(handle.NativeHandle, ref readPacketPtr); - return PacketHandle.FromNativePointer(readPacketPtr); - } - - internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint error) - { - SNIHandle handle = Handle ?? throw ADP.ClosedConnectionError(); - IntPtr readPacketPtr = IntPtr.Zero; - error = SniNativeWrapper.SniReadSyncOverAsync(handle, ref readPacketPtr, GetTimeoutRemaining()); - return PacketHandle.FromNativePointer(readPacketPtr); - } - - internal override PacketHandle CreateAndSetAttentionPacket() - { - SNIPacket attnPacket = new SNIPacket(Handle); - _sniAsyncAttnPacket = attnPacket; - SniNativeWrapper.SniPacketSetData(attnPacket, SQL.AttentionHeader, TdsEnums.HEADER_LEN); - return PacketHandle.FromNativePacket(attnPacket); - } - - internal override uint WritePacket(PacketHandle packet, bool sync) - { - Debug.Assert(packet.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePacket"); - return SniNativeWrapper.SniWritePacket(Handle, packet.NativePacket, sync); - } - - internal override PacketHandle AddPacketToPendingList(PacketHandle packetToAdd) - { - Debug.Assert(packetToAdd.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePacket"); - SNIPacket packet = packetToAdd.NativePacket; - Debug.Assert(packet == _sniPacket, "Adding a packet other than the current packet to the pending list"); - _sniPacket = null; - IntPtr pointer = packet.DangerousGetHandle(); - - lock (_writePacketLockObject) - { - _pendingWritePackets.Add(pointer, packet); - } - - return PacketHandle.FromNativePointer(pointer); - } - - internal override bool IsValidPacket(PacketHandle packetPointer) - { - Debug.Assert(packetPointer.Type == PacketHandle.NativePointerType || packetPointer.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePointer"); - - return (packetPointer.Type == PacketHandle.NativePointerType && packetPointer.NativePointer != IntPtr.Zero) - || (packetPointer.Type == PacketHandle.NativePacketType && packetPointer.NativePacket != null); - } - - internal override PacketHandle GetResetWritePacket(int dataSize) - { - if (_sniPacket != null) - { - SniNativeWrapper.SniPacketReset(Handle, IoType.WRITE, _sniPacket, ConsumerNumber.SNI_Consumer_SNI); - } - else - { - lock (_writePacketLockObject) - { - _sniPacket = _writePacketCache.Take(Handle); - } - } - return PacketHandle.FromNativePacket(_sniPacket); - } - - internal override void ClearAllWritePackets() - { - if (_sniPacket != null) - { - _sniPacket.Dispose(); - _sniPacket = null; - } - lock (_writePacketLockObject) - { - Debug.Assert(_pendingWritePackets.Count == 0 && _asyncWriteCount == 0, "Should not clear all write packets if there are packets pending"); - _writePacketCache.Clear(); - } - } - - internal override void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed) - { - Debug.Assert(packet.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePacket"); - SniNativeWrapper.SniPacketSetData(packet.NativePacket, buffer, bytesUsed); - } - - internal override uint SniGetConnectionId(ref Guid clientConnectionId) - => SniNativeWrapper.SniGetConnectionId(Handle, ref clientConnectionId); - - internal override uint DisableSsl() - => SniNativeWrapper.SniRemoveProvider(Handle, Provider.SSL_PROV); - - internal override uint EnableMars(ref uint info) - => SniNativeWrapper.SniAddProvider(Handle, Provider.SMUX_PROV, ref info); - - internal override uint PostReadAsyncForMars(TdsParserStateObject physicalStateObject) - { - // HACK HACK HACK - for Async only - // Have to post read to initialize MARS - will get callback on this when connection goes - // down or is closed. - - PacketHandle temp = default; - uint error = TdsEnums.SNI_SUCCESS; - - IncrementPendingCallbacks(); - SessionHandle handle = SessionHandle; - // we do not need to consider partial packets when making this read because we - // expect this read to pend. a partial packet should not exist at setup of the - // parser - Debug.Assert(physicalStateObject.PartialPacket == null); - temp = ReadAsync(handle, out error); - - Debug.Assert(temp.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer"); - - if (temp.NativePointer != IntPtr.Zero) - { - // Be sure to release packet, otherwise it will be leaked by native. - ReleasePacket(temp); - } - - Debug.Assert(IntPtr.Zero == temp.NativePointer, "unexpected syncReadPacket without corresponding SNIPacketRelease"); - return error; - } - - internal override uint EnableSsl(ref uint info, bool tlsFirst, string serverCertificateFilename) - { - AuthProviderInfo authInfo = new AuthProviderInfo(); - authInfo.flags = info; - authInfo.tlsFirst = tlsFirst; - authInfo.serverCertFileName = serverCertificateFilename; - - // Add SSL (Encryption) SNI provider. - return SniNativeWrapper.SniAddProvider(Handle, Provider.SSL_PROV, ref authInfo); - } - - internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize) - => SniNativeWrapper.SniSetInfo(Handle, QueryType.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize); - - internal override uint WaitForSSLHandShakeToComplete(out SslProtocols protocolVersion) - { - uint returnValue = SniNativeWrapper.SniWaitForSslHandshakeToComplete(Handle, GetTimeoutRemaining(), out uint nativeProtocolVersion); - var nativeProtocol = (NativeProtocols)nativeProtocolVersion; - -#pragma warning disable CA5398 // Avoid hardcoded SslProtocols values - if (nativeProtocol.HasFlag(NativeProtocols.SP_PROT_TLS1_2_CLIENT) || nativeProtocol.HasFlag(NativeProtocols.SP_PROT_TLS1_2_SERVER)) - { - protocolVersion = SslProtocols.Tls12; - } - else if (nativeProtocol.HasFlag(NativeProtocols.SP_PROT_TLS1_3_CLIENT) || nativeProtocol.HasFlag(NativeProtocols.SP_PROT_TLS1_3_SERVER)) - { - /* The SslProtocols.Tls13 is supported by netcoreapp3.1 and later */ - protocolVersion = SslProtocols.Tls13; - } - else if (nativeProtocol.HasFlag(NativeProtocols.SP_PROT_TLS1_1_CLIENT) || nativeProtocol.HasFlag(NativeProtocols.SP_PROT_TLS1_1_SERVER)) - { - protocolVersion = SslProtocols.Tls11; - } - else if (nativeProtocol.HasFlag(NativeProtocols.SP_PROT_TLS1_0_CLIENT) || nativeProtocol.HasFlag(NativeProtocols.SP_PROT_TLS1_0_SERVER)) - { - protocolVersion = SslProtocols.Tls; - } - else if (nativeProtocol.HasFlag(NativeProtocols.SP_PROT_SSL3_CLIENT) || nativeProtocol.HasFlag(NativeProtocols.SP_PROT_SSL3_SERVER)) - { - // SSL 2.0 and 3.0 are only referenced to log a warning, not explicitly used for connections -#pragma warning disable CS0618, CA5397 - protocolVersion = SslProtocols.Ssl3; - } - else if (nativeProtocol.HasFlag(NativeProtocols.SP_PROT_SSL2_CLIENT) || nativeProtocol.HasFlag(NativeProtocols.SP_PROT_SSL2_SERVER)) - { - protocolVersion = SslProtocols.Ssl2; -#pragma warning restore CS0618, CA5397 - } - else //if (nativeProtocol.HasFlag(NativeProtocols.SP_PROT_NONE)) - { - protocolVersion = SslProtocols.None; - } -#pragma warning restore CA5398 // Avoid hardcoded SslProtocols values - return returnValue; - } - - internal override SniErrorDetails GetErrorDetails() - { - SniNativeWrapper.SniGetLastError(out SniError sniError); - - return new SniErrorDetails(sniError.errorMessage, sniError.nativeError, sniError.sniError, - (int)sniError.provider, sniError.lineNumber, sniError.function); - } - - internal override void DisposePacketCache() - { - lock (_writePacketLockObject) - { - _writePacketCache.Dispose(); - // Do not set _writePacketCache to null, just in case a WriteAsyncCallback completes after this point - } - } - - internal override SspiContextProvider CreateSspiContextProvider() => new NativeSspiContextProvider(); - - internal sealed class WritePacketCache : IDisposable - { - private bool _disposed; - private Stack _packets; - - public WritePacketCache() - { - _disposed = false; - _packets = new Stack(); - } - - public SNIPacket Take(SNIHandle sniHandle) - { - SNIPacket packet; - if (_packets.Count > 0) - { - // Success - reset the packet - packet = _packets.Pop(); - SniNativeWrapper.SniPacketReset(sniHandle, IoType.WRITE, packet, ConsumerNumber.SNI_Consumer_SNI); - } - else - { - // Failed to take a packet - create a new one - packet = new SNIPacket(sniHandle); - } - return packet; - } - - public void Add(SNIPacket packet) - { - if (!_disposed) - { - _packets.Push(packet); - } - else - { - // If we're disposed, then get rid of any packets added to us - packet.Dispose(); - } - } - - public void Clear() - { - while (_packets.Count > 0) - { - _packets.Pop().Dispose(); - } - } - - public void Dispose() - { - if (!_disposed) - { - _disposed = true; - Clear(); - } - } - } - } -} diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 41a3fe3f68..de1e9b7c1f 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -219,6 +219,9 @@ Interop\Windows\Sni\SqlDependencyProcessDispatcherStorage.netfx.cs + + Interop\Windows\Sni\SniSslProtocols.cs + Interop\Windows\Sni\TransparentNetworkResolutionMode.cs @@ -924,6 +927,9 @@ Microsoft\Data\SqlClient\TdsParserStateObjectFactory.Windows.cs + + Microsoft\Data\SqlClient\TdsParserStateObjectNative.Windows.cs + Microsoft\Data\SqlClient\TdsParserStaticMethods.cs @@ -996,7 +1002,6 @@ - diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/ISniNativeMethods.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/ISniNativeMethods.cs index 109b7c27ae..c05c76b188 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/ISniNativeMethods.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/ISniNativeMethods.cs @@ -92,7 +92,7 @@ unsafe uint SniSecGenClientContextWrapper( uint SniTerminate(); - uint SniWaitForSslHandshakeToComplete(SNIHandle pConn, int dwMilliseconds, out uint pProtocolVersion); + uint SniWaitForSslHandshakeToComplete(SNIHandle pConn, int dwMilliseconds, out SniSslProtocols pProtocolVersion); uint SniWriteAsyncWrapper(SNIHandle pConn, SNIPacket pPacket); diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethods.netcore.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethods.netcore.cs index a47af81e9f..b780f8f63c 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethods.netcore.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethods.netcore.cs @@ -148,7 +148,7 @@ public uint SniSetInfo(SNIHandle pConn, QueryType queryType, ref uint pbQueryInf public uint SniTerminate() => SNITerminate(); - public uint SniWaitForSslHandshakeToComplete(SNIHandle pConn, int dwMilliseconds, out uint pProtocolVersion) => + public uint SniWaitForSslHandshakeToComplete(SNIHandle pConn, int dwMilliseconds, out SniSslProtocols pProtocolVersion) => SNIWaitForSSLHandshakeToCompleteWrapper(pConn, dwMilliseconds, out pProtocolVersion); public uint SniWriteAsyncWrapper(SNIHandle pConn, SNIPacket pPacket) => @@ -299,7 +299,7 @@ private static extern int SNIServerEnumReadWrapper( private static extern uint SNIWaitForSSLHandshakeToCompleteWrapper( [In] SNIHandle pConn, int dwMilliseconds, - out uint pProtocolVersion); + out SniSslProtocols pProtocolVersion); [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)] private static extern uint SNIWriteAsyncWrapper(SNIHandle pConn, [In] SNIPacket pPacket); diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsArm64.netfx.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsArm64.netfx.cs index 1c60f92443..7839a047cd 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsArm64.netfx.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsArm64.netfx.cs @@ -148,7 +148,7 @@ public uint SniSetInfo(SNIHandle pConn, QueryType queryType, ref uint pbQueryInf public uint SniTerminate() => SNITerminate(); - public uint SniWaitForSslHandshakeToComplete(SNIHandle pConn, int dwMilliseconds, out uint pProtocolVersion) => + public uint SniWaitForSslHandshakeToComplete(SNIHandle pConn, int dwMilliseconds, out SniSslProtocols pProtocolVersion) => SNIWaitForSSLHandshakeToCompleteWrapper(pConn, dwMilliseconds, out pProtocolVersion); public uint SniWriteAsyncWrapper(SNIHandle pConn, SNIPacket pPacket) => @@ -299,7 +299,7 @@ private static extern int SNIServerEnumReadWrapper( private static extern uint SNIWaitForSSLHandshakeToCompleteWrapper( [In] SNIHandle pConn, int dwMilliseconds, - out uint pProtocolVersion); + out SniSslProtocols pProtocolVersion); [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)] private static extern uint SNIWriteAsyncWrapper(SNIHandle pConn, [In] SNIPacket pPacket); diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsNotSupported.netfx.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsNotSupported.netfx.cs index ba26fb0dc9..7139ee0b08 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsNotSupported.netfx.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsNotSupported.netfx.cs @@ -142,7 +142,7 @@ public uint SniSetInfo(SNIHandle pConn, QueryType queryType, ref uint pbQueryInf public uint SniTerminate() => throw ADP.SNIPlatformNotSupported(_architecture); - public uint SniWaitForSslHandshakeToComplete(SNIHandle pConn, int dwMilliseconds, out uint pProtocolVersion) => + public uint SniWaitForSslHandshakeToComplete(SNIHandle pConn, int dwMilliseconds, out SniSslProtocols pProtocolVersion) => throw ADP.SNIPlatformNotSupported(_architecture); public uint SniWriteAsyncWrapper(SNIHandle pConn, SNIPacket pPacket) => diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsX64.netfx.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsX64.netfx.cs index 1c6519327f..e899250786 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsX64.netfx.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsX64.netfx.cs @@ -148,7 +148,7 @@ public uint SniSetInfo(SNIHandle pConn, QueryType queryType, ref uint pbQueryInf public uint SniTerminate() => SNITerminate(); - public uint SniWaitForSslHandshakeToComplete(SNIHandle pConn, int dwMilliseconds, out uint pProtocolVersion) => + public uint SniWaitForSslHandshakeToComplete(SNIHandle pConn, int dwMilliseconds, out SniSslProtocols pProtocolVersion) => SNIWaitForSSLHandshakeToCompleteWrapper(pConn, dwMilliseconds, out pProtocolVersion); public uint SniWriteAsyncWrapper(SNIHandle pConn, SNIPacket pPacket) => @@ -299,7 +299,7 @@ private static extern int SNIServerEnumReadWrapper( private static extern uint SNIWaitForSSLHandshakeToCompleteWrapper( [In] SNIHandle pConn, int dwMilliseconds, - out uint pProtocolVersion); + out SniSslProtocols pProtocolVersion); [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)] private static extern uint SNIWriteAsyncWrapper(SNIHandle pConn, [In] SNIPacket pPacket); diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsX86.netfx.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsX86.netfx.cs index fd2ab4644b..4c022821dd 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsX86.netfx.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsX86.netfx.cs @@ -148,7 +148,7 @@ public uint SniSetInfo(SNIHandle pConn, QueryType queryType, ref uint pbQueryInf public uint SniTerminate() => SNITerminate(); - public uint SniWaitForSslHandshakeToComplete(SNIHandle pConn, int dwMilliseconds, out uint pProtocolVersion) => + public uint SniWaitForSslHandshakeToComplete(SNIHandle pConn, int dwMilliseconds, out SniSslProtocols pProtocolVersion) => SNIWaitForSSLHandshakeToCompleteWrapper(pConn, dwMilliseconds, out pProtocolVersion); public uint SniWriteAsyncWrapper(SNIHandle pConn, SNIPacket pPacket) => @@ -299,7 +299,7 @@ private static extern int SNIServerEnumReadWrapper( private static extern uint SNIWaitForSSLHandshakeToCompleteWrapper( [In] SNIHandle pConn, int dwMilliseconds, - out uint pProtocolVersion); + out SniSslProtocols pProtocolVersion); [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)] private static extern uint SNIWriteAsyncWrapper(SNIHandle pConn, [In] SNIPacket pPacket); diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs index 6dc01dc31e..60f5f98dcd 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs @@ -376,19 +376,67 @@ internal static uint SniSetInfo(SNIHandle pConn, QueryType qType, ref uint pbQIn internal static uint SniTerminate() => s_nativeMethods.SniTerminate(); - + internal static uint SniWaitForSslHandshakeToComplete( SNIHandle pConn, int dwMilliseconds, - out uint pProtocolVersion) => - s_nativeMethods.SniWaitForSslHandshakeToComplete(pConn, dwMilliseconds, out pProtocolVersion); + out System.Security.Authentication.SslProtocols pProtocolVersion) + { + uint returnValue = s_nativeMethods.SniWaitForSslHandshakeToComplete(pConn, dwMilliseconds, out SniSslProtocols nativeProtocolVersion); + +#pragma warning disable CA5398 // Avoid hardcoded SslProtocols values + if ((nativeProtocolVersion & SniSslProtocols.SP_PROT_TLS1_2) != 0) + { + pProtocolVersion = System.Security.Authentication.SslProtocols.Tls12; + } + else if ((nativeProtocolVersion & SniSslProtocols.SP_PROT_TLS1_3) != 0) + { +#if NET + pProtocolVersion = System.Security.Authentication.SslProtocols.Tls13; +#else + // Only .NET Core supports SslProtocols.Tls13 + pProtocolVersion = (System.Security.Authentication.SslProtocols)0x3000; +#endif + } + else if ((nativeProtocolVersion & SniSslProtocols.SP_PROT_TLS1_1) != 0) + { +#if NET8_0_OR_GREATER +#pragma warning disable SYSLIB0039 // Type or member is obsolete: TLS 1.0 & 1.1 are deprecated +#endif + pProtocolVersion = System.Security.Authentication.SslProtocols.Tls11; + } + else if ((nativeProtocolVersion & SniSslProtocols.SP_PROT_TLS1_0) != 0) + { + pProtocolVersion = System.Security.Authentication.SslProtocols.Tls; +#if NET8_0_OR_GREATER +#pragma warning restore SYSLIB0039 // Type or member is obsolete: SSL and TLS 1.0 & 1.1 is deprecated +#endif + } + else if ((nativeProtocolVersion & SniSslProtocols.SP_PROT_SSL3) != 0) + { + // SSL 2.0 and 3.0 are only referenced to log a warning, not explicitly used for connections +#pragma warning disable CS0618, CA5397 + pProtocolVersion = System.Security.Authentication.SslProtocols.Ssl3; + } + else if ((nativeProtocolVersion & SniSslProtocols.SP_PROT_SSL2) != 0) + { + pProtocolVersion = System.Security.Authentication.SslProtocols.Ssl2; +#pragma warning restore CS0618, CA5397 + } + else + { + pProtocolVersion = System.Security.Authentication.SslProtocols.None; + } +#pragma warning restore CA5398 // Avoid hardcoded SslProtocols values + return returnValue; + } internal static uint SniWritePacket(SNIHandle pConn, SNIPacket packet, bool sync) => sync ? s_nativeMethods.SniWriteSyncOverAsync(pConn, packet) : s_nativeMethods.SniWriteAsyncWrapper(pConn, packet); - #endregion +#endregion #region Private Methods diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniSslProtocols.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniSslProtocols.cs new file mode 100644 index 0000000000..cd9bf47ed3 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniSslProtocols.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Interop.Windows.Sni +{ + internal enum SniSslProtocols : uint + { + // Protocol versions from native SNI + SP_PROT_SSL2_SERVER = 0x00000004, + SP_PROT_SSL2_CLIENT = 0x00000008, + SP_PROT_SSL3_SERVER = 0x00000010, + SP_PROT_SSL3_CLIENT = 0x00000020, + SP_PROT_TLS1_0_SERVER = 0x00000040, + SP_PROT_TLS1_0_CLIENT = 0x00000080, + SP_PROT_TLS1_1_SERVER = 0x00000100, + SP_PROT_TLS1_1_CLIENT = 0x00000200, + SP_PROT_TLS1_2_SERVER = 0x00000400, + SP_PROT_TLS1_2_CLIENT = 0x00000800, + SP_PROT_TLS1_3_SERVER = 0x00001000, + SP_PROT_TLS1_3_CLIENT = 0x00002000, + SP_PROT_NONE = 0x0, + + // Combinations for easier use when mapping to SslProtocols + SP_PROT_SSL2 = SP_PROT_SSL2_SERVER | SP_PROT_SSL2_CLIENT, + SP_PROT_SSL3 = SP_PROT_SSL3_SERVER | SP_PROT_SSL3_CLIENT, + SP_PROT_TLS1_0 = SP_PROT_TLS1_0_SERVER | SP_PROT_TLS1_0_CLIENT, + SP_PROT_TLS1_1 = SP_PROT_TLS1_1_SERVER | SP_PROT_TLS1_1_CLIENT, + SP_PROT_TLS1_2 = SP_PROT_TLS1_2_SERVER | SP_PROT_TLS1_2_CLIENT, + SP_PROT_TLS1_3 = SP_PROT_TLS1_3_SERVER | SP_PROT_TLS1_3_CLIENT, + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs index ea4c46753e..749fe5a008 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.Windows.cs @@ -277,63 +277,4 @@ override protected bool ReleaseHandle() return true; } } - - internal sealed class WritePacketCache : IDisposable - { - private bool _disposed; - private Stack _packets; - - public WritePacketCache() - { - _disposed = false; - _packets = new Stack(); - } - - public SNIPacket Take(SNIHandle sniHandle) - { - SNIPacket packet; - if (_packets.Count > 0) - { - // Success - reset the packet - packet = _packets.Pop(); - SniNativeWrapper.SniPacketReset(sniHandle, IoType.WRITE, packet, ConsumerNumber.SNI_Consumer_SNI); - } - else - { - // Failed to take a packet - create a new one - packet = new SNIPacket(sniHandle); - } - return packet; - } - - public void Add(SNIPacket packet) - { - if (!_disposed) - { - _packets.Push(packet); - } - else - { - // If we're disposed, then get rid of any packets added to us - packet.Dispose(); - } - } - - public void Clear() - { - while (_packets.Count > 0) - { - _packets.Pop().Dispose(); - } - } - - public void Dispose() - { - if (!_disposed) - { - _disposed = true; - Clear(); - } - } - } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index cac03827ab..64052b16d8 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -527,10 +527,10 @@ internal abstract void CreatePhysicalSNIHandle( string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, string serverSPN, - bool isIntegratedSecurity = false, - bool tlsFirst = false, - string hostNameInCertificate = "", - string serverCertificateFilename = ""); + bool isIntegratedSecurity, + bool tlsFirst, + string hostNameInCertificate, + string serverCertificateFilename); internal abstract uint EnableSsl(ref uint info, bool tlsFirst, string serverCertificateFilename); diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.Windows.cs similarity index 87% rename from src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs rename to src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.Windows.cs index cb85e2b9b1..4efd94e112 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.Windows.cs @@ -148,15 +148,17 @@ internal override void CreatePhysicalSNIHandle( string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, string serverSPN, - bool isIntegratedSecurity = false, - bool tlsFirst = false, - string hostNameInCertificate = "", - string serverCertificateFilename = "") + bool isIntegratedSecurity, + bool tlsFirst, + string hostNameInCertificate, + string serverCertificateFilename) { if (isIntegratedSecurity) { + // now allocate proper length of buffer if (!string.IsNullOrEmpty(serverSPN)) { + // Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code. SqlClientEventSource.Log.TryTraceEvent(" Server SPN `{0}` from the connection string is used.", serverSPN); } else @@ -171,10 +173,13 @@ internal override void CreatePhysicalSNIHandle( // serverName : serverInfo.ExtendedServerName // may not use this serverName as key - _ = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out SQLDNSInfo cachedDNSInfo); + SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out SQLDNSInfo cachedDNSInfo); - _sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, - out instanceName, flushCache, !async, fParallel, transparentNetworkResolutionState, totalTimeout, + _sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName, + flushCache, !async, fParallel, +#if NETFRAMEWORK + transparentNetworkResolutionState, totalTimeout, +#endif iPAddressPreference, cachedDNSInfo, hostNameInCertificate); resolvedSpn = new(serverSPN.TrimEnd()); } @@ -218,6 +223,7 @@ internal override void Dispose() SafeHandle packetHandle = _sniPacket; SafeHandle sessionHandle = _sessionHandle; SafeHandle asyncAttnPacket = _sniAsyncAttnPacket; + _sniPacket = null; _sessionHandle = null; _sniAsyncAttnPacket = null; @@ -231,14 +237,9 @@ internal override void Dispose() // here for the callbacks!!! This only applies to async. Should be fixed by async fixes for // AD unload/exit. - if (packetHandle != null) - { - packetHandle.Dispose(); - } - if (asyncAttnPacket != null) - { - asyncAttnPacket.Dispose(); - } + packetHandle?.Dispose(); + asyncAttnPacket?.Dispose(); + if (sessionHandle != null) { sessionHandle.Dispose(); @@ -280,6 +281,9 @@ internal override uint CheckConnection() internal override PacketHandle ReadAsync(SessionHandle handle, out uint error) { +#if NET + Debug.Assert(handle.Type == SessionHandle.NativeHandleType, "unexpected handle type when requiring NativePointer"); +#endif IntPtr readPacketPtr = IntPtr.Zero; error = SniNativeWrapper.SniReadAsync(handle.NativeHandle, ref readPacketPtr); return PacketHandle.FromNativePointer(readPacketPtr); @@ -410,10 +414,6 @@ internal override uint EnableSsl(ref uint info, bool tlsFirst, string serverCert AuthProviderInfo authInfo = new AuthProviderInfo(); authInfo.flags = info; authInfo.tlsFirst = tlsFirst; - authInfo.certId = null; - authInfo.certHash = false; - authInfo.clientCertificateCallbackContext = IntPtr.Zero; - authInfo.clientCertificateCallback = null; authInfo.serverCertFileName = string.IsNullOrEmpty(serverCertificateFilename) ? null : serverCertificateFilename; // Add SSL (Encryption) SNI provider. @@ -423,19 +423,15 @@ internal override uint EnableSsl(ref uint info, bool tlsFirst, string serverCert internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize) => SniNativeWrapper.SniSetInfo(Handle, QueryType.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize); - internal override uint WaitForSSLHandShakeToComplete(out SslProtocols protocolVersion) - { - uint returnValue = SniNativeWrapper.SniWaitForSslHandshakeToComplete(Handle, GetTimeoutRemaining(), out uint nativeProtocolVersion); - - protocolVersion = (SslProtocols)nativeProtocolVersion; - return returnValue; - } + internal override uint WaitForSSLHandShakeToComplete(out SslProtocols protocolVersion) => + SniNativeWrapper.SniWaitForSslHandshakeToComplete(Handle, GetTimeoutRemaining(), out protocolVersion); internal override SniErrorDetails GetErrorDetails() { SniNativeWrapper.SniGetLastError(out SniError sniError); - return new SniErrorDetails(sniError.errorMessage, sniError.nativeError, sniError.sniError, (int)sniError.provider, sniError.lineNumber, sniError.function); + return new SniErrorDetails(sniError.errorMessage, sniError.nativeError, sniError.sniError, + (int)sniError.provider, sniError.lineNumber, sniError.function); } internal override void DisposePacketCache() @@ -448,5 +444,64 @@ internal override void DisposePacketCache() } internal override SspiContextProvider CreateSspiContextProvider() => new NativeSspiContextProvider(); + + private sealed class WritePacketCache : IDisposable + { + private bool _disposed; + private Stack _packets; + + public WritePacketCache() + { + _disposed = false; + _packets = new Stack(); + } + + public SNIPacket Take(SNIHandle sniHandle) + { + SNIPacket packet; + if (_packets.Count > 0) + { + // Success - reset the packet + packet = _packets.Pop(); + SniNativeWrapper.SniPacketReset(sniHandle, IoType.WRITE, packet, ConsumerNumber.SNI_Consumer_SNI); + } + else + { + // Failed to take a packet - create a new one + packet = new SNIPacket(sniHandle); + } + return packet; + } + + public void Add(SNIPacket packet) + { + if (!_disposed) + { + _packets.Push(packet); + } + else + { + // If we're disposed, then get rid of any packets added to us + packet.Dispose(); + } + } + + public void Clear() + { + while (_packets.Count > 0) + { + _packets.Pop().Dispose(); + } + } + + public void Dispose() + { + if (!_disposed) + { + _disposed = true; + Clear(); + } + } + } } }