diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 02e435153..015147773 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -9,5 +9,6 @@ ### Updated ### Fixed +- Fix: driver failing to authenticate on token update in U2M flow. --- *Note: When making changes, please add your change under the appropriate section with a brief description.* diff --git a/src/main/java/com/databricks/jdbc/common/util/DatabricksAuthUtil.java b/src/main/java/com/databricks/jdbc/common/util/DatabricksAuthUtil.java index 465fed6d2..003428646 100644 --- a/src/main/java/com/databricks/jdbc/common/util/DatabricksAuthUtil.java +++ b/src/main/java/com/databricks/jdbc/common/util/DatabricksAuthUtil.java @@ -4,6 +4,7 @@ import com.databricks.jdbc.common.DatabricksJdbcConstants; import com.databricks.jdbc.log.JdbcLogger; import com.databricks.jdbc.log.JdbcLoggerFactory; +import com.databricks.sdk.core.CredentialsProvider; import com.databricks.sdk.core.DatabricksConfig; import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.http.HttpClient; @@ -33,12 +34,19 @@ public static DatabricksConfig initializeConfigWithToken( String newAccessToken, DatabricksConfig config) { String hostUrl = config.getHost(); HttpClient httpClient = config.getHttpClient(); + CredentialsProvider credentialsProvider = config.getCredentialsProvider(); DatabricksConfig newConfig = new DatabricksConfig(); newConfig .setHost(hostUrl) .setHttpClient(httpClient) .setAuthType(DatabricksJdbcConstants.ACCESS_TOKEN_AUTH_TYPE) .setToken(newAccessToken); + + // Preserve and reconfigure the credentials provider if it exists + if (credentialsProvider != null) { + newConfig.setCredentialsProvider(credentialsProvider); + } + return newConfig; } diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransport.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransport.java index 7c00a8c6e..337ee8a28 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransport.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransport.java @@ -1,7 +1,5 @@ package com.databricks.jdbc.dbclient.impl.thrift; -import static com.databricks.jdbc.common.util.DatabricksAuthUtil.initializeConfigWithToken; - import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; import com.databricks.jdbc.common.util.ValidationUtil; import com.databricks.jdbc.dbclient.IDatabricksHttpClient; @@ -171,11 +169,6 @@ private void refreshHeadersIfRequired() { refreshedHeaders != null ? new HashMap<>(refreshedHeaders) : Collections.emptyMap(); } - void resetAccessToken(String newAccessToken) { - this.databricksConfig = initializeConfigWithToken(newAccessToken, databricksConfig); - this.databricksConfig.resolve(); - } - @VisibleForTesting void setResponseBuffer(ByteArrayInputStream responseBuffer) { this.responseBuffer = responseBuffer; diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java index eafb96804..5897ec5ac 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java @@ -47,7 +47,7 @@ final class DatabricksThriftAccessor { TExecuteStatementResp._Fields.OPERATION_HANDLE.getThriftFieldId(); private static final short statusFieldId = TExecuteStatementResp._Fields.STATUS.getThriftFieldId(); - private final DatabricksConfig databricksConfig; + private DatabricksConfig databricksConfig; private final boolean enableDirectResults; private final int asyncPollIntervalMillis; private final int maxRowsPerBlock; @@ -477,6 +477,10 @@ DatabricksConfig getDatabricksConfig() { return databricksConfig; } + void updateConfig(DatabricksConfig newConfig) { + this.databricksConfig = newConfig; + } + TFetchResultsResp getResultSetResp( TStatus responseStatus, TOperationHandle operationHandle, diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java index 37cfd738d..d85401e02 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClient.java @@ -2,6 +2,7 @@ import static com.databricks.jdbc.common.EnvironmentVariables.DEFAULT_STATEMENT_TIMEOUT_SECONDS; import static com.databricks.jdbc.common.EnvironmentVariables.JDBC_THRIFT_VERSION; +import static com.databricks.jdbc.common.util.DatabricksAuthUtil.initializeConfigWithToken; import static com.databricks.jdbc.common.util.DatabricksThriftUtil.*; import static com.databricks.jdbc.common.util.DatabricksTypeUtil.DECIMAL; import static com.databricks.jdbc.common.util.DatabricksTypeUtil.getDecimalTypeString; @@ -77,8 +78,11 @@ public IDatabricksConnectionContext getConnectionContext() { @Override public void resetAccessToken(String newAccessToken) { - ((DatabricksHttpTTransport) thriftAccessor.getThriftClient().getInputProtocol().getTransport()) - .resetAccessToken(newAccessToken); + // Update the config stored in the accessor so new transports use the new token + DatabricksConfig currentConfig = thriftAccessor.getDatabricksConfig(); + DatabricksConfig newConfig = initializeConfigWithToken(newAccessToken, currentConfig); + newConfig.resolve(); + thriftAccessor.updateConfig(newConfig); } @Override diff --git a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransportTest.java b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransportTest.java index b673e4463..0e624561f 100644 --- a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransportTest.java +++ b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransportTest.java @@ -5,7 +5,6 @@ import static org.mockito.Mockito.*; import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; -import com.databricks.jdbc.common.DatabricksJdbcConstants; import com.databricks.jdbc.dbclient.impl.common.TracingUtil; import com.databricks.jdbc.dbclient.impl.http.DatabricksHttpClient; import com.databricks.jdbc.exception.DatabricksHttpException; @@ -122,21 +121,4 @@ public void flush_SendsDataCorrectly_tracingEnabled() assertTrue(capturedRequest.containsHeader("Content-Type")); assertTrue(capturedRequest.containsHeader(TracingUtil.TRACE_HEADER)); } - - @Test - public void resetAccessToken_UpdatesConfigCorrectly() { - DatabricksHttpTTransport transport = - new DatabricksHttpTTransport( - mockedHttpClient, testUrl, mockDatabricksConfig, mockConnectionContext); - - when(mockDatabricksConfig.getHost()).thenReturn(testUrl); - - transport.resetAccessToken(NEW_ACCESS_TOKEN); - - assertEquals(NEW_ACCESS_TOKEN, transport.databricksConfig.getToken()); - assertEquals(testUrl, transport.databricksConfig.getHost()); - assertEquals( - DatabricksJdbcConstants.ACCESS_TOKEN_AUTH_TYPE, transport.databricksConfig.getAuthType()); - assertNotNull(transport.databricksConfig.getHttpClient()); - } } diff --git a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java index f0e71d9d5..27f4c04e5 100644 --- a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java +++ b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java @@ -33,7 +33,6 @@ import java.sql.SQLException; import java.util.*; import java.util.stream.Stream; -import org.apache.thrift.protocol.TProtocol; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; @@ -41,7 +40,6 @@ import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; @ExtendWith(MockitoExtension.class) @@ -904,15 +902,10 @@ void testGetDatabricksConfig() { void testResetAccessToken() throws DatabricksParsingException { DatabricksThriftServiceClient client = new DatabricksThriftServiceClient(thriftAccessor, connectionContext); - DatabricksHttpTTransport mockDatabricksHttpTTransport = - Mockito.mock(DatabricksHttpTTransport.class); - TCLIService.Client mockTCLIServiceClient = Mockito.mock(TCLIService.Client.class); - TProtocol mockProtocol = Mockito.mock(TProtocol.class); - when(thriftAccessor.getThriftClient()).thenReturn(mockTCLIServiceClient); - when(mockTCLIServiceClient.getInputProtocol()).thenReturn(mockProtocol); - when(mockProtocol.getTransport()).thenReturn(mockDatabricksHttpTTransport); + when(thriftAccessor.getDatabricksConfig()).thenReturn(databricksConfig); + when(databricksConfig.getHost()).thenReturn("test-host"); client.resetAccessToken(NEW_ACCESS_TOKEN); - verify(mockDatabricksHttpTTransport).resetAccessToken(NEW_ACCESS_TOKEN); + verify(thriftAccessor).updateConfig(any(DatabricksConfig.class)); } @Test