diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java index fcb79c87b..9e567f0e2 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java @@ -35,6 +35,9 @@ public class DatabricksConfig { @ConfigAttribute(env = "DATABRICKS_SCOPES", auth = "oauth") private List scopes; + @ConfigAttribute(env = "DATABRICKS_TOKEN_AUDIENCE", auth = "oauth") + private String audience; + @ConfigAttribute(env = "DATABRICKS_REDIRECT_URL", auth = "oauth") private String redirectUrl; @@ -309,6 +312,15 @@ public DatabricksConfig setClientSecret(String clientSecret) { return this; } + public String getAudience() { + return audience; + } + + public DatabricksConfig setAudience(String audience) { + this.audience = audience; + return this; + } + public String getOAuthRedirectUrl() { return redirectUrl; } @@ -381,13 +393,17 @@ public DatabricksConfig setAzureUseMsi(boolean azureUseMsi) { return this; } - /** @deprecated Use {@link #getAzureUseMsi()} instead. */ + /** + * @deprecated Use {@link #getAzureUseMsi()} instead. + */ @Deprecated() public boolean getAzureUseMSI() { return azureUseMsi; } - /** @deprecated Use {@link #setAzureUseMsi(boolean)} instead. */ + /** + * @deprecated Use {@link #setAzureUseMsi(boolean)} instead. + */ @Deprecated public DatabricksConfig setAzureUseMSI(boolean azureUseMsi) { this.azureUseMsi = azureUseMsi; diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java index b8f4d7867..19ebad345 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java @@ -1,57 +1,39 @@ package com.databricks.sdk.core; -import com.databricks.sdk.core.oauth.*; +import com.databricks.sdk.core.oauth.AzureGithubOidcCredentialsProvider; +import com.databricks.sdk.core.oauth.AzureServicePrincipalCredentialsProvider; +import com.databricks.sdk.core.oauth.DatabricksOAuthTokenSource; +import com.databricks.sdk.core.oauth.ExternalBrowserCredentialsProvider; +import com.databricks.sdk.core.oauth.GithubIDTokenSource; +import com.databricks.sdk.core.oauth.IDTokenSource; +import com.databricks.sdk.core.oauth.OAuthM2MServicePrincipalCredentialsProvider; +import com.databricks.sdk.core.oauth.OpenIDConnectEndpoints; +import com.databricks.sdk.core.oauth.TokenSourceCredentialsProvider; import java.util.ArrayList; -import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class DefaultCredentialsProvider implements CredentialsProvider { private static final Logger LOG = LoggerFactory.getLogger(DefaultCredentialsProvider.class); - private static final List> providerClasses = - Arrays.asList( - PatCredentialsProvider.class, - BasicCredentialsProvider.class, - OAuthM2MServicePrincipalCredentialsProvider.class, - GithubOidcCredentialsProvider.class, - AzureGithubOidcCredentialsProvider.class, - AzureServicePrincipalCredentialsProvider.class, - AzureCliCredentialsProvider.class, - ExternalBrowserCredentialsProvider.class, - DatabricksCliCredentialsProvider.class, - NotebookNativeCredentialsProvider.class, - GoogleCredentialsCredentialsProvider.class, - GoogleIdCredentialsProvider.class); - - private final List providers; + private List providers = new ArrayList<>(); private String authType = "default"; + public DefaultCredentialsProvider() {} + + @Override public String authType() { return authType; } - public DefaultCredentialsProvider() { - providers = new ArrayList<>(); - for (Class clazz : providerClasses) { - try { - providers.add((CredentialsProvider) clazz.newInstance()); - } catch (NoClassDefFoundError | InstantiationException | IllegalAccessException e) { - LOG.warn( - "Failed to instantiate credentials provider: " - + clazz.getName() - + ", skipping. Cause: " - + e.getClass().getCanonicalName() - + ": " - + e.getMessage()); - } - } - } - @Override public synchronized HeaderFactory configure(DatabricksConfig config) { + addDefaultCredentialsProviders(config); + for (CredentialsProvider provider : providers) { if (config.getAuthType() != null && !config.getAuthType().isEmpty() @@ -80,4 +62,57 @@ public synchronized HeaderFactory configure(DatabricksConfig config) { + authFlowUrl + " to configure credentials for your preferred authentication method"); } + + private void addDefaultCredentialsProviders(DatabricksConfig config) { + providers.add(new PatCredentialsProvider()); + providers.add(new BasicCredentialsProvider()); + providers.add(new OAuthM2MServicePrincipalCredentialsProvider()); + + addOIDCTokenCredentialsProviders(config); + + providers.add(new AzureGithubOidcCredentialsProvider()); + providers.add(new AzureServicePrincipalCredentialsProvider()); + providers.add(new AzureCliCredentialsProvider()); + providers.add(new ExternalBrowserCredentialsProvider()); + providers.add(new DatabricksCliCredentialsProvider()); + providers.add(new NotebookNativeCredentialsProvider()); + providers.add(new GoogleCredentialsCredentialsProvider()); + providers.add(new GoogleIdCredentialsProvider()); + } + + private void addOIDCTokenCredentialsProviders(DatabricksConfig config) { + OpenIDConnectEndpoints endpoints = null; + try { + endpoints = config.getOidcEndpoints(); + } catch (Exception e) { + LOG.error("Error getting OIDC endpoints", e); + } + + Map namedIdTokenSources = new HashMap<>(); + namedIdTokenSources.put( + "github-oidc", + new GithubIDTokenSource( + config.getActionsIdTokenRequestUrl(), + config.getActionsIdTokenRequestToken(), + config.getHttpClient())); + // Add new providers to the map as needed + + for (Map.Entry entry : namedIdTokenSources.entrySet()) { + String name = entry.getKey(); + IDTokenSource idTokenSource = entry.getValue(); + + DatabricksOAuthTokenSource oauthTokenSource = + new DatabricksOAuthTokenSource.Builder( + config.getClientId(), + config.getHost(), + endpoints, + idTokenSource, + config.getHttpClient()) + .audience(config.getAudience()) + .accountId(config.isAccountClient() ? config.getAccountId() : null) + .build(); + + providers.add(new TokenSourceCredentialsProvider(oauthTokenSource, name)); + } + } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubIDTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubIDTokenSource.java new file mode 100644 index 000000000..362719bda --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/GithubIDTokenSource.java @@ -0,0 +1,89 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Request; +import com.databricks.sdk.core.http.Response; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.common.base.Strings; +import java.io.IOException; + +/** GithubIDTokenSource retrieves JWT Tokens from GitHub Actions. */ +public class GithubIDTokenSource implements IDTokenSource { + private final String actionsIDTokenRequestURL; + private final String actionsIDTokenRequestToken; + private final HttpClient httpClient; + private final ObjectMapper mapper = new ObjectMapper(); + + /** + * Constructs a new GithubIDTokenSource. + * + * @param actionsIDTokenRequestURL The URL to request the ID token from GitHub Actions. + * @param actionsIDTokenRequestToken The token used to authenticate the request. + * @param httpClient The HTTP client to use for making requests. + */ + public GithubIDTokenSource( + String actionsIDTokenRequestURL, String actionsIDTokenRequestToken, HttpClient httpClient) { + this.actionsIDTokenRequestURL = actionsIDTokenRequestURL; + this.actionsIDTokenRequestToken = actionsIDTokenRequestToken; + this.httpClient = httpClient; + } + + @Override + public IDToken getIDToken(String audience) { + if (Strings.isNullOrEmpty(actionsIDTokenRequestURL)) { + throw new DatabricksException("Missing ActionsIDTokenRequestURL"); + } + if (Strings.isNullOrEmpty(actionsIDTokenRequestToken)) { + throw new DatabricksException("Missing ActionsIDTokenRequestToken"); + } + if (httpClient == null) { + throw new DatabricksException("HttpClient cannot be null"); + } + + String requestUrl = actionsIDTokenRequestURL; + if (!Strings.isNullOrEmpty(audience)) { + requestUrl = String.format("%s&audience=%s", requestUrl, audience); + } + + Request req = + new Request("GET", requestUrl) + .withHeader("Authorization", "Bearer " + actionsIDTokenRequestToken); + + Response resp; + try { + resp = httpClient.execute(req); + } catch (IOException e) { + throw new DatabricksException( + "Failed to request ID token from " + requestUrl + ": " + e.getMessage(), e); + } + + if (resp.getStatusCode() != 200) { + throw new DatabricksException( + "Failed to request ID token: status code " + + resp.getStatusCode() + + ", response body: " + + resp.getBody().toString()); + } + + ObjectNode jsonResp; + try { + jsonResp = mapper.readValue(resp.getBody(), ObjectNode.class); + } catch (IOException e) { + throw new DatabricksException( + "Failed to request ID token: corrupted token: " + e.getMessage()); + } + + if (!jsonResp.has("value")) { + throw new DatabricksException("ID token response missing 'value' field"); + } + + String tokenValue = jsonResp.get("value").textValue(); + if (Strings.isNullOrEmpty(tokenValue)) { + throw new DatabricksException("Received empty ID token from GitHub Actions"); + } + + return new IDToken(tokenValue); + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java new file mode 100644 index 000000000..233d839db --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java @@ -0,0 +1,44 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.CredentialsProvider; +import com.databricks.sdk.core.DatabricksConfig; +import com.databricks.sdk.core.HeaderFactory; +import java.util.HashMap; +import java.util.Map; + +/** Base class for token-based credentials providers. */ +public class TokenSourceCredentialsProvider implements CredentialsProvider { + private final TokenSource tokenSource; + private final String authType; + + /** + * Creates a new TokenSourceCredentialsProvider with the specified token source and auth type. + * + * @param tokenSource The token source to use for token exchange + * @param authType The authentication type string + */ + public TokenSourceCredentialsProvider(TokenSource tokenSource, String authType) { + this.tokenSource = tokenSource; + this.authType = authType; + } + + @Override + public HeaderFactory configure(DatabricksConfig config) { + + return () -> { + Map headers = new HashMap<>(); + try { + String accessToken = tokenSource.getToken().getAccessToken(); + headers.put("Authorization", "Bearer " + accessToken); + return headers; + } catch (Exception e) { + return null; + } + }; + } + + @Override + public String authType() { + return authType; + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/GithubIDTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/GithubIDTokenSourceTest.java new file mode 100644 index 000000000..e53aaefa8 --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/GithubIDTokenSourceTest.java @@ -0,0 +1,154 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Request; +import com.databricks.sdk.core.http.Response; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +public class GithubIDTokenSourceTest { + private static final String TEST_REQUEST_URL = "https://github.com/token"; + private static final String TEST_REQUEST_TOKEN = "test-request-token"; + private static final String TEST_ID_TOKEN = "test-id-token"; + private static final String TEST_AUDIENCE = "test-audience"; + + @Mock private static HttpClient mockHttpClient; + + private GithubIDTokenSource tokenSource; + private ObjectMapper mapper; + + @BeforeEach + void setUp() throws IOException { + MockitoAnnotations.openMocks(this); + mapper = new ObjectMapper(); + tokenSource = new GithubIDTokenSource(TEST_REQUEST_URL, TEST_REQUEST_TOKEN, mockHttpClient); + } + + @Test + void testSuccessfulTokenRetrieval() throws IOException { + // Prepare mock response + ObjectNode responseJson = mapper.createObjectNode(); + responseJson.put("value", TEST_ID_TOKEN); + Response mockResponse = makeResponse(responseJson.toString(), 200); + when(mockHttpClient.execute(any(Request.class))).thenReturn(mockResponse); + + // Test token retrieval + IDToken token = tokenSource.getIDToken(TEST_AUDIENCE); + + assertNotNull(token); + assertEquals(TEST_ID_TOKEN, token.getValue()); + + // Verify the request was made with correct parameters + verify(mockHttpClient) + .execute( + argThat( + request -> { + return request.getMethod().equals("GET") + && request.getUrl().startsWith(TEST_REQUEST_URL) + && request.getUrl().contains("audience=" + TEST_AUDIENCE) + && request + .getHeaders() + .get("Authorization") + .equals("Bearer " + TEST_REQUEST_TOKEN); + })); + } + + @Test + void testSuccessfulTokenRetrievalWithoutAudience() throws IOException { + // Prepare mock response + ObjectNode responseJson = mapper.createObjectNode(); + responseJson.put("value", TEST_ID_TOKEN); + Response mockResponse = makeResponse(responseJson.toString(), 200); + when(mockHttpClient.execute(any(Request.class))).thenReturn(mockResponse); + + // Test token retrieval without audience + IDToken token = tokenSource.getIDToken(""); + + assertNotNull(token); + assertEquals(TEST_ID_TOKEN, token.getValue()); + + // Verify the request was made with correct parameters + verify(mockHttpClient) + .execute( + argThat( + request -> { + return request.getMethod().equals("GET") + && request.getUrl().equals(TEST_REQUEST_URL) + && request + .getHeaders() + .get("Authorization") + .equals("Bearer " + TEST_REQUEST_TOKEN); + })); + } + + private static Stream provideInvalidConstructorParameters() { + return Stream.of( + Arguments.of("Missing Request URL", null, TEST_REQUEST_TOKEN, mockHttpClient), + Arguments.of("Missing Request Token", TEST_REQUEST_URL, null, mockHttpClient), + Arguments.of("Null HttpClient", TEST_REQUEST_URL, TEST_REQUEST_TOKEN, null)); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideInvalidConstructorParameters") + void testInvalidConstructorParameters( + String testName, String requestUrl, String requestToken, HttpClient httpClient) { + GithubIDTokenSource invalidSource = + new GithubIDTokenSource(requestUrl, requestToken, httpClient); + assertThrows(DatabricksException.class, () -> invalidSource.getIDToken(TEST_AUDIENCE)); + } + + private static Stream provideHttpErrorScenarios() throws IOException { + HttpClient httpClientError = mock(HttpClient.class); + when(httpClientError.execute(any(Request.class))).thenThrow(new IOException("Network error")); + + HttpClient nonSuccessClient = mock(HttpClient.class); + when(nonSuccessClient.execute(any(Request.class))) + .thenReturn(makeResponse("Error response", 400)); + + HttpClient invalidJsonClient = mock(HttpClient.class); + when(invalidJsonClient.execute(any(Request.class))) + .thenReturn(makeResponse("Invalid json", 200)); + + HttpClient missingTokenClient = mock(HttpClient.class); + when(missingTokenClient.execute(any(Request.class))).thenReturn(makeResponse("{}", 200)); + + HttpClient emptyTokenClient = mock(HttpClient.class); + when(emptyTokenClient.execute(any(Request.class))) + .thenReturn(makeResponse("{\"value\":\"\"}", 200)); + + return Stream.of( + Arguments.of("HTTP Client Error", httpClientError), + Arguments.of("Non-Success Status Code", nonSuccessClient), + Arguments.of("Invalid JSON Response", invalidJsonClient), + Arguments.of("Missing Token Value", missingTokenClient), + Arguments.of("Empty Token Value", emptyTokenClient)); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideHttpErrorScenarios") + void testHttpErrorScenarios(String testName, HttpClient httpClient) { + GithubIDTokenSource source = + new GithubIDTokenSource(TEST_REQUEST_URL, TEST_REQUEST_TOKEN, httpClient); + assertThrows(DatabricksException.class, () -> source.getIDToken(TEST_AUDIENCE)); + } + + private static Response makeResponse(String body, int status) throws MalformedURLException { + return new Response(body, status, "status", new URL("https://databricks.com/")); + } +}