Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
### Updated

### Fixed
- Fix: driver failing to authenticate on token update in U2M flow.
---
*Note: When making changes, please add your change under the appropriate section with a brief description.*
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.databricks.jdbc.common.DatabricksJdbcConstants;
import com.databricks.jdbc.log.JdbcLogger;
import com.databricks.jdbc.log.JdbcLoggerFactory;
import com.databricks.sdk.core.CredentialsProvider;
import com.databricks.sdk.core.DatabricksConfig;
import com.databricks.sdk.core.DatabricksException;
import com.databricks.sdk.core.http.HttpClient;
Expand Down Expand Up @@ -33,12 +34,19 @@ public static DatabricksConfig initializeConfigWithToken(
String newAccessToken, DatabricksConfig config) {
String hostUrl = config.getHost();
HttpClient httpClient = config.getHttpClient();
CredentialsProvider credentialsProvider = config.getCredentialsProvider();
DatabricksConfig newConfig = new DatabricksConfig();
newConfig
.setHost(hostUrl)
.setHttpClient(httpClient)
.setAuthType(DatabricksJdbcConstants.ACCESS_TOKEN_AUTH_TYPE)
.setToken(newAccessToken);

// Preserve and reconfigure the credentials provider if it exists
if (credentialsProvider != null) {
newConfig.setCredentialsProvider(credentialsProvider);
}

return newConfig;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package com.databricks.jdbc.dbclient.impl.thrift;

import static com.databricks.jdbc.common.util.DatabricksAuthUtil.initializeConfigWithToken;

import com.databricks.jdbc.api.internal.IDatabricksConnectionContext;
import com.databricks.jdbc.common.util.ValidationUtil;
import com.databricks.jdbc.dbclient.IDatabricksHttpClient;
Expand Down Expand Up @@ -171,11 +169,6 @@ private void refreshHeadersIfRequired() {
refreshedHeaders != null ? new HashMap<>(refreshedHeaders) : Collections.emptyMap();
}

void resetAccessToken(String newAccessToken) {
this.databricksConfig = initializeConfigWithToken(newAccessToken, databricksConfig);
this.databricksConfig.resolve();
}

@VisibleForTesting
void setResponseBuffer(ByteArrayInputStream responseBuffer) {
this.responseBuffer = responseBuffer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ final class DatabricksThriftAccessor {
TExecuteStatementResp._Fields.OPERATION_HANDLE.getThriftFieldId();
private static final short statusFieldId =
TExecuteStatementResp._Fields.STATUS.getThriftFieldId();
private final DatabricksConfig databricksConfig;
private DatabricksConfig databricksConfig;
private final boolean enableDirectResults;
private final int asyncPollIntervalMillis;
private final int maxRowsPerBlock;
Expand Down Expand Up @@ -477,6 +477,10 @@ DatabricksConfig getDatabricksConfig() {
return databricksConfig;
}

void updateConfig(DatabricksConfig newConfig) {
this.databricksConfig = newConfig;
}

TFetchResultsResp getResultSetResp(
TStatus responseStatus,
TOperationHandle operationHandle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static com.databricks.jdbc.common.EnvironmentVariables.DEFAULT_STATEMENT_TIMEOUT_SECONDS;
import static com.databricks.jdbc.common.EnvironmentVariables.JDBC_THRIFT_VERSION;
import static com.databricks.jdbc.common.util.DatabricksAuthUtil.initializeConfigWithToken;
import static com.databricks.jdbc.common.util.DatabricksThriftUtil.*;
import static com.databricks.jdbc.common.util.DatabricksTypeUtil.DECIMAL;
import static com.databricks.jdbc.common.util.DatabricksTypeUtil.getDecimalTypeString;
Expand Down Expand Up @@ -77,8 +78,11 @@ public IDatabricksConnectionContext getConnectionContext() {

@Override
public void resetAccessToken(String newAccessToken) {
((DatabricksHttpTTransport) thriftAccessor.getThriftClient().getInputProtocol().getTransport())
.resetAccessToken(newAccessToken);
// Update the config stored in the accessor so new transports use the new token
DatabricksConfig currentConfig = thriftAccessor.getDatabricksConfig();
DatabricksConfig newConfig = initializeConfigWithToken(newAccessToken, currentConfig);
newConfig.resolve();
thriftAccessor.updateConfig(newConfig);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import static org.mockito.Mockito.*;

import com.databricks.jdbc.api.internal.IDatabricksConnectionContext;
import com.databricks.jdbc.common.DatabricksJdbcConstants;
import com.databricks.jdbc.dbclient.impl.common.TracingUtil;
import com.databricks.jdbc.dbclient.impl.http.DatabricksHttpClient;
import com.databricks.jdbc.exception.DatabricksHttpException;
Expand Down Expand Up @@ -122,21 +121,4 @@ public void flush_SendsDataCorrectly_tracingEnabled()
assertTrue(capturedRequest.containsHeader("Content-Type"));
assertTrue(capturedRequest.containsHeader(TracingUtil.TRACE_HEADER));
}

@Test
public void resetAccessToken_UpdatesConfigCorrectly() {
DatabricksHttpTTransport transport =
new DatabricksHttpTTransport(
mockedHttpClient, testUrl, mockDatabricksConfig, mockConnectionContext);

when(mockDatabricksConfig.getHost()).thenReturn(testUrl);

transport.resetAccessToken(NEW_ACCESS_TOKEN);

assertEquals(NEW_ACCESS_TOKEN, transport.databricksConfig.getToken());
assertEquals(testUrl, transport.databricksConfig.getHost());
assertEquals(
DatabricksJdbcConstants.ACCESS_TOKEN_AUTH_TYPE, transport.databricksConfig.getAuthType());
assertNotNull(transport.databricksConfig.getHttpClient());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,13 @@
import java.sql.SQLException;
import java.util.*;
import java.util.stream.Stream;
import org.apache.thrift.protocol.TProtocol;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;

@ExtendWith(MockitoExtension.class)
Expand Down Expand Up @@ -904,15 +902,10 @@ void testGetDatabricksConfig() {
void testResetAccessToken() throws DatabricksParsingException {
DatabricksThriftServiceClient client =
new DatabricksThriftServiceClient(thriftAccessor, connectionContext);
DatabricksHttpTTransport mockDatabricksHttpTTransport =
Mockito.mock(DatabricksHttpTTransport.class);
TCLIService.Client mockTCLIServiceClient = Mockito.mock(TCLIService.Client.class);
TProtocol mockProtocol = Mockito.mock(TProtocol.class);
when(thriftAccessor.getThriftClient()).thenReturn(mockTCLIServiceClient);
when(mockTCLIServiceClient.getInputProtocol()).thenReturn(mockProtocol);
when(mockProtocol.getTransport()).thenReturn(mockDatabricksHttpTTransport);
when(thriftAccessor.getDatabricksConfig()).thenReturn(databricksConfig);
when(databricksConfig.getHost()).thenReturn("test-host");
client.resetAccessToken(NEW_ACCESS_TOKEN);
verify(mockDatabricksHttpTTransport).resetAccessToken(NEW_ACCESS_TOKEN);
verify(thriftAccessor).updateConfig(any(DatabricksConfig.class));
}

@Test
Expand Down
Loading