Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ public class DatabricksConfig {
@ConfigAttribute(env = "DATABRICKS_SCOPES", auth = "oauth")
private List<String> scopes;

@ConfigAttribute(env = "DATABRICKS_TOKEN_AUDIENCE", auth = "oauth")
private String audience;

@ConfigAttribute(env = "DATABRICKS_REDIRECT_URL", auth = "oauth")
private String redirectUrl;

Expand Down Expand Up @@ -309,6 +312,15 @@ public DatabricksConfig setClientSecret(String clientSecret) {
return this;
}

public String getAudience() {
return audience;
}

public DatabricksConfig setAudience(String audience) {
this.audience = audience;
return this;
}

public String getOAuthRedirectUrl() {
return redirectUrl;
}
Expand Down Expand Up @@ -381,13 +393,17 @@ public DatabricksConfig setAzureUseMsi(boolean azureUseMsi) {
return this;
}

/** @deprecated Use {@link #getAzureUseMsi()} instead. */
/**
* @deprecated Use {@link #getAzureUseMsi()} instead.
*/
@Deprecated()
public boolean getAzureUseMSI() {
return azureUseMsi;
}

/** @deprecated Use {@link #setAzureUseMsi(boolean)} instead. */
/**
* @deprecated Use {@link #setAzureUseMsi(boolean)} instead.
*/
@Deprecated
public DatabricksConfig setAzureUseMSI(boolean azureUseMsi) {
this.azureUseMsi = azureUseMsi;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,57 +1,39 @@
package com.databricks.sdk.core;

import com.databricks.sdk.core.oauth.*;
import com.databricks.sdk.core.oauth.AzureGithubOidcCredentialsProvider;
import com.databricks.sdk.core.oauth.AzureServicePrincipalCredentialsProvider;
import com.databricks.sdk.core.oauth.DatabricksOAuthTokenSource;
import com.databricks.sdk.core.oauth.ExternalBrowserCredentialsProvider;
import com.databricks.sdk.core.oauth.GithubIDTokenSource;
import com.databricks.sdk.core.oauth.IDTokenSource;
import com.databricks.sdk.core.oauth.OAuthM2MServicePrincipalCredentialsProvider;
import com.databricks.sdk.core.oauth.OpenIDConnectEndpoints;
import com.databricks.sdk.core.oauth.TokenSourceCredentialsProvider;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DefaultCredentialsProvider implements CredentialsProvider {
private static final Logger LOG = LoggerFactory.getLogger(DefaultCredentialsProvider.class);

private static final List<Class<?>> providerClasses =
Arrays.asList(
PatCredentialsProvider.class,
BasicCredentialsProvider.class,
OAuthM2MServicePrincipalCredentialsProvider.class,
GithubOidcCredentialsProvider.class,
AzureGithubOidcCredentialsProvider.class,
AzureServicePrincipalCredentialsProvider.class,
AzureCliCredentialsProvider.class,
ExternalBrowserCredentialsProvider.class,
DatabricksCliCredentialsProvider.class,
NotebookNativeCredentialsProvider.class,
GoogleCredentialsCredentialsProvider.class,
GoogleIdCredentialsProvider.class);

private final List<CredentialsProvider> providers;
private List<CredentialsProvider> providers = new ArrayList<>();

private String authType = "default";

public DefaultCredentialsProvider() {}

@Override
public String authType() {
return authType;
}

public DefaultCredentialsProvider() {
providers = new ArrayList<>();
for (Class<?> clazz : providerClasses) {
try {
providers.add((CredentialsProvider) clazz.newInstance());
} catch (NoClassDefFoundError | InstantiationException | IllegalAccessException e) {
LOG.warn(
"Failed to instantiate credentials provider: "
+ clazz.getName()
+ ", skipping. Cause: "
+ e.getClass().getCanonicalName()
+ ": "
+ e.getMessage());
}
}
}

@Override
public synchronized HeaderFactory configure(DatabricksConfig config) {
addDefaultCredentialsProviders(config);

for (CredentialsProvider provider : providers) {
if (config.getAuthType() != null
&& !config.getAuthType().isEmpty()
Expand Down Expand Up @@ -80,4 +62,57 @@ public synchronized HeaderFactory configure(DatabricksConfig config) {
+ authFlowUrl
+ " to configure credentials for your preferred authentication method");
}

private void addDefaultCredentialsProviders(DatabricksConfig config) {
providers.add(new PatCredentialsProvider());
providers.add(new BasicCredentialsProvider());
providers.add(new OAuthM2MServicePrincipalCredentialsProvider());

addOIDCTokenCredentialsProviders(config);

providers.add(new AzureGithubOidcCredentialsProvider());
providers.add(new AzureServicePrincipalCredentialsProvider());
providers.add(new AzureCliCredentialsProvider());
providers.add(new ExternalBrowserCredentialsProvider());
providers.add(new DatabricksCliCredentialsProvider());
providers.add(new NotebookNativeCredentialsProvider());
providers.add(new GoogleCredentialsCredentialsProvider());
providers.add(new GoogleIdCredentialsProvider());
}

private void addOIDCTokenCredentialsProviders(DatabricksConfig config) {
OpenIDConnectEndpoints endpoints = null;
try {
endpoints = config.getOidcEndpoints();
} catch (Exception e) {
LOG.error("Error getting OIDC endpoints", e);
}

Map<String, IDTokenSource> namedIdTokenSources = new HashMap<>();
namedIdTokenSources.put(
"github-oidc",
new GithubIDTokenSource(
config.getActionsIdTokenRequestUrl(),
config.getActionsIdTokenRequestToken(),
config.getHttpClient()));
// Add new providers to the map as needed

for (Map.Entry<String, IDTokenSource> entry : namedIdTokenSources.entrySet()) {
String name = entry.getKey();
IDTokenSource idTokenSource = entry.getValue();

DatabricksOAuthTokenSource oauthTokenSource =
new DatabricksOAuthTokenSource.Builder(
config.getClientId(),
config.getHost(),
endpoints,
idTokenSource,
config.getHttpClient())
.audience(config.getAudience())
.accountId(config.isAccountClient() ? config.getAccountId() : null)
.build();

providers.add(new TokenSourceCredentialsProvider(oauthTokenSource, name));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package com.databricks.sdk.core.oauth;

import com.databricks.sdk.core.DatabricksException;
import com.databricks.sdk.core.http.HttpClient;
import com.databricks.sdk.core.http.Request;
import com.databricks.sdk.core.http.Response;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.base.Strings;
import java.io.IOException;

/** GithubIDTokenSource retrieves JWT Tokens from GitHub Actions. */
public class GithubIDTokenSource implements IDTokenSource {
private final String actionsIDTokenRequestURL;
private final String actionsIDTokenRequestToken;
private final HttpClient httpClient;
private final ObjectMapper mapper = new ObjectMapper();

/**
* Constructs a new GithubIDTokenSource.
*
* @param actionsIDTokenRequestURL The URL to request the ID token from GitHub Actions.
* @param actionsIDTokenRequestToken The token used to authenticate the request.
* @param httpClient The HTTP client to use for making requests.
*/
public GithubIDTokenSource(
String actionsIDTokenRequestURL, String actionsIDTokenRequestToken, HttpClient httpClient) {
this.actionsIDTokenRequestURL = actionsIDTokenRequestURL;
this.actionsIDTokenRequestToken = actionsIDTokenRequestToken;
this.httpClient = httpClient;
}

@Override
public IDToken getIDToken(String audience) {
if (Strings.isNullOrEmpty(actionsIDTokenRequestURL)) {
throw new DatabricksException("Missing ActionsIDTokenRequestURL");
}
if (Strings.isNullOrEmpty(actionsIDTokenRequestToken)) {
throw new DatabricksException("Missing ActionsIDTokenRequestToken");
}
if (httpClient == null) {
throw new DatabricksException("HttpClient cannot be null");
}

String requestUrl = actionsIDTokenRequestURL;
if (!Strings.isNullOrEmpty(audience)) {
requestUrl = String.format("%s&audience=%s", requestUrl, audience);
}

Request req =
new Request("GET", requestUrl)
.withHeader("Authorization", "Bearer " + actionsIDTokenRequestToken);

Response resp;
try {
resp = httpClient.execute(req);
} catch (IOException e) {
throw new DatabricksException(
"Failed to request ID token from " + requestUrl + ": " + e.getMessage(), e);
}

if (resp.getStatusCode() != 200) {
throw new DatabricksException(
"Failed to request ID token: status code "
+ resp.getStatusCode()
+ ", response body: "
+ resp.getBody().toString());
}

ObjectNode jsonResp;
try {
jsonResp = mapper.readValue(resp.getBody(), ObjectNode.class);
} catch (IOException e) {
throw new DatabricksException(
"Failed to request ID token: corrupted token: " + e.getMessage());
}

if (!jsonResp.has("value")) {
throw new DatabricksException("ID token response missing 'value' field");
}

String tokenValue = jsonResp.get("value").textValue();
if (Strings.isNullOrEmpty(tokenValue)) {
throw new DatabricksException("Received empty ID token from GitHub Actions");
}

return new IDToken(tokenValue);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package com.databricks.sdk.core.oauth;

import com.databricks.sdk.core.CredentialsProvider;
import com.databricks.sdk.core.DatabricksConfig;
import com.databricks.sdk.core.HeaderFactory;
import java.util.HashMap;
import java.util.Map;

/** Base class for token-based credentials providers. */
public class TokenSourceCredentialsProvider implements CredentialsProvider {
private final TokenSource tokenSource;
private final String authType;

/**
* Creates a new TokenSourceCredentialsProvider with the specified token source and auth type.
*
* @param tokenSource The token source to use for token exchange
* @param authType The authentication type string
*/
public TokenSourceCredentialsProvider(TokenSource tokenSource, String authType) {
this.tokenSource = tokenSource;
this.authType = authType;
}

@Override
public HeaderFactory configure(DatabricksConfig config) {

return () -> {
Map<String, String> headers = new HashMap<>();
try {
String accessToken = tokenSource.getToken().getAccessToken();
headers.put("Authorization", "Bearer " + accessToken);
return headers;
} catch (Exception e) {
return null;
}
};
}

@Override
public String authType() {
return authType;
}
}
Loading
Loading