diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md
index 0522e387f..0e935faf2 100644
--- a/NEXT_CHANGELOG.md
+++ b/NEXT_CHANGELOG.md
@@ -3,6 +3,7 @@
 ## Release v0.60.0
 
 ### New Features and Improvements
+- Azure Service Principal credential provider can now automatically discover tenant ID when not explicitly provided
 
 ### Bug Fixes
 
diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java
index 6e6df86eb..640244b9d 100644
--- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java
+++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java
@@ -5,13 +5,19 @@
 import com.fasterxml.jackson.databind.ObjectMapper;
 import java.util.HashMap;
 import java.util.Map;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request,
  * while automatically resolving different Azure environment endpoints.
  */
 public class AzureServicePrincipalCredentialsProvider implements CredentialsProvider {
+  private static final Logger logger =
+      LoggerFactory.getLogger(AzureServicePrincipalCredentialsProvider.class);
   private final ObjectMapper mapper = new ObjectMapper();
+  /** Tenant ID resolved by {@code configure}; read by {@code tokenSourceFor}. */
+  private String tenantId;
 
   @Override
   public String authType() {
@@ -22,12 +28,23 @@ public String authType() {
   public OAuthHeaderFactory configure(DatabricksConfig config) {
     if (!config.isAzure()
         || config.getAzureClientId() == null
-        || config.getAzureClientSecret() == null
-        || config.getAzureTenantId() == null) {
+        || config.getAzureClientSecret() == null) {
       return null;
     }
-    AzureUtils.ensureHostPresent(
-        config, mapper, AzureServicePrincipalCredentialsProvider::tokenSourceFor);
+
+    try {
+      this.tenantId =
+          config.getAzureTenantId() != null
+              ? config.getAzureTenantId()
+              : AzureUtils.inferTenantId(config);
+    } catch (Exception e) {
+      // Pass the exception as the last argument so SLF4J also logs the cause and stack trace.
+      logger.warn("Failed to infer Azure tenant ID: {}", e.getMessage(), e);
+      return null;
+    }
+
+    AzureUtils.ensureHostPresent(config, mapper, this::tokenSourceFor);
+
     CachedTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
     CachedTokenSource cloud =
         tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
@@ -55,9 +72,9 @@ public OAuthHeaderFactory configure(DatabricksConfig config) {
    * @return A CachedTokenSource instance capable of fetching OAuth tokens for the specified Azure
    *     resource.
    */
-  private static CachedTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
+  private CachedTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
     String aadEndpoint = config.getAzureEnvironment().getActiveDirectoryEndpoint();
-    String tokenUrl = aadEndpoint + config.getAzureTenantId() + "/oauth2/token";
+    String tokenUrl = aadEndpoint + this.tenantId + "/oauth2/token";
     Map<String, String> endpointParams = new HashMap<>();
     endpointParams.put("resource", resource);
 
diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java
index 09cea6e86..c5b0aaaa0 100644
--- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java
+++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/utils/AzureUtils.java
@@ -10,12 +10,16 @@
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.fasterxml.jackson.databind.node.ObjectNode;
 import java.io.IOException;
+import java.net.URL;
 import java.util.Map;
 import java.util.Optional;
 import java.util.function.BiFunction;
 
 public class AzureUtils {
 
+  /** Azure authentication endpoint for tenant ID discovery */
+  private static final String AZURE_AUTH_ENDPOINT = "/aad/auth";
+
   public static String getWorkspaceFromJsonResponse(ObjectNode jsonResponse) throws IOException {
     JsonNode properties = jsonResponse.get("properties");
     if (properties == null) {
@@ -95,4 +99,91 @@ public static Optional<String> getAzureWorkspaceResourceId(Workspace workspace)
             workspace.getWorkspaceName());
     return Optional.of(resourceId);
   }
+
+  /**
+   * Infers the Azure tenant ID from the Databricks workspace login page.
+   *
+   * @param config The DatabricksConfig instance
+   * @return the discovered tenant ID
+   * @throws DatabricksException if tenant ID discovery fails
+   */
+  public static String inferTenantId(DatabricksConfig config) throws DatabricksException {
+
+    if (config.getAzureTenantId() != null) {
+      return config.getAzureTenantId();
+    }
+
+    if (config.getHost() == null) {
+      throw new DatabricksException("Cannot infer tenant ID: host is missing");
+    }
+
+    if (!config.isAzure()) {
+      throw new DatabricksException("Cannot infer tenant ID: workspace is not Azure");
+    }
+
+    String loginUrl = config.getHost() + AZURE_AUTH_ENDPOINT;
+
+    try {
+      String redirectLocation = getRedirectLocation(config, loginUrl);
+      return extractTenantIdFromUrl(redirectLocation);
+
+    } catch (Exception e) {
+      throw new DatabricksException("Failed to infer Azure tenant ID from " + loginUrl, e);
+    }
+  }
+
+  /**
+   * Requests {@code loginUrl} without following redirects and returns the redirect target.
+   *
+   * @param config The DatabricksConfig instance whose HTTP client performs the request
+   * @param loginUrl The workspace URL expected to answer with a 302 redirect
+   * @return the value of the {@code Location} response header
+   * @throws IOException if the HTTP request fails
+   * @throws DatabricksException if the response is not a 302 or has no Location header
+   */
+  private static String getRedirectLocation(DatabricksConfig config, String loginUrl)
+      throws IOException {
+    Request request = new Request("GET", loginUrl);
+    request.setRedirectionBehavior(false);
+    Response response = config.getHttpClient().execute(request);
+
+    if (response.getStatusCode() != 302) {
+      throw new DatabricksException(
+          "Expected redirect (302) from "
+              + loginUrl
+              + ", got status code: "
+              + response.getStatusCode());
+    }
+
+    String location = response.getFirstHeader("Location");
+    if (location == null) {
+      throw new DatabricksException("No Location header in redirect response from " + loginUrl);
+    }
+
+    return location;
+  }
+
+  /**
+   * Extracts the tenant ID, i.e. the first path segment, from the Entra ID redirect URL.
+   *
+   * @param redirectUrl The URL taken from the Location header of the login redirect
+   * @return the first non-empty path segment of {@code redirectUrl}
+   * @throws DatabricksException if the URL is malformed or has no tenant segment
+   */
+  private static String extractTenantIdFromUrl(String redirectUrl) throws DatabricksException {
+    try {
+      // Parse: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
+      URL entraIdUrl = new URL(redirectUrl);
+      String[] pathSegments = entraIdUrl.getPath().split("/");
+
+      // Reject "//oauth2/authorize"-style paths: segment 1 exists but is empty.
+      if (pathSegments.length < 2 || pathSegments[1].isEmpty()) {
+        throw new DatabricksException("Invalid path in Location header: " + entraIdUrl.getPath());
+      }
+
+      return pathSegments[1];
+    } catch (Exception e) {
+      throw new DatabricksException("Failed to parse tenant ID from URL " + redirectUrl, e);
+    }
+  }
 }
diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/utils/AzureUtilsTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/utils/AzureUtilsTest.java
new file mode 100644
index 000000000..435471a32
--- /dev/null
+++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/utils/AzureUtilsTest.java
@@ -0,0 +1,193 @@
+package com.databricks.sdk.core.utils;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import com.databricks.sdk.core.DatabricksConfig;
+import com.databricks.sdk.core.DatabricksException;
+import com.databricks.sdk.core.FixtureServer;
+import com.databricks.sdk.core.commons.CommonsHttpClient;
+import java.io.IOException;
+import org.junit.jupiter.api.Test;
+
+public class AzureUtilsTest {
+
+  @Test
+  public void testInferTenantId404() throws IOException {
+    try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 404)) {
+      DatabricksConfig config = new DatabricksConfig();
+      config.setHost(server.getUrl());
+      config.setAzureWorkspaceResourceId(
+          "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
+      config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
+
+      DatabricksException exception =
+          assertThrows(
+              DatabricksException.class,
+              () -> {
+                AzureUtils.inferTenantId(config);
+              });
+      assertEquals(
+          "Failed to infer Azure tenant ID from " + server.getUrl() + "/aad/auth",
+          exception.getMessage());
+
+      assertNotNull(exception.getCause());
+      assertInstanceOf(DatabricksException.class, exception.getCause());
+      DatabricksException cause = (DatabricksException) exception.getCause();
+      assertEquals(
+          "Expected redirect (302) from " + server.getUrl() + "/aad/auth, got status code: 404",
+          cause.getMessage());
+
+      assertNull(config.getAzureTenantId());
+    }
+  }
+
+  @Test
+  public void testInferTenantIdNoLocationHeader() throws IOException {
+    try (FixtureServer server = new FixtureServer().with("GET", "/aad/auth", "", 302)) {
+      DatabricksConfig config = new DatabricksConfig();
+      config.setHost(server.getUrl());
+      config.setAzureWorkspaceResourceId(
+          "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
+      config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
+
+      DatabricksException exception =
+          assertThrows(
+              DatabricksException.class,
+              () -> {
+                AzureUtils.inferTenantId(config);
+              });
+      assertEquals(
+          "Failed to infer Azure tenant ID from " + server.getUrl() + "/aad/auth",
+          exception.getMessage());
+
+      assertNotNull(exception.getCause());
+      assertInstanceOf(DatabricksException.class, exception.getCause());
+      DatabricksException cause = (DatabricksException) exception.getCause();
+      assertEquals(
+          "No Location header in redirect response from " + server.getUrl() + "/aad/auth",
+          cause.getMessage());
+
+      assertNull(config.getAzureTenantId());
+    }
+  }
+
+  @Test
+  public void testInferTenantIdUnparsableLocationHeader() throws IOException {
+    FixtureServer.FixtureMapping fixture =
+        new FixtureServer.FixtureMapping.Builder()
+            .validateMethod("GET")
+            .validatePath("/aad/auth")
+            .withRedirect("https://unexpected-location", 302)
+            .build();
+
+    try (FixtureServer server = new FixtureServer().with(fixture)) {
+      DatabricksConfig config = new DatabricksConfig();
+      config.setHost(server.getUrl());
+      config.setAzureWorkspaceResourceId(
+          "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
+      config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
+
+      DatabricksException exception =
+          assertThrows(
+              DatabricksException.class,
+              () -> {
+                AzureUtils.inferTenantId(config);
+              });
+      assertEquals(
+          "Failed to infer Azure tenant ID from " + server.getUrl() + "/aad/auth",
+          exception.getMessage());
+
+      assertNotNull(exception.getCause());
+      assertInstanceOf(DatabricksException.class, exception.getCause());
+      DatabricksException cause = (DatabricksException) exception.getCause();
+      assertEquals(
+          "Failed to parse tenant ID from URL https://unexpected-location", cause.getMessage());
+
+      assertNull(config.getAzureTenantId());
+    }
+  }
+
+  @Test
+  public void testInferTenantIdEmptyTenantSegment() throws IOException {
+    FixtureServer.FixtureMapping fixture =
+        new FixtureServer.FixtureMapping.Builder()
+            .validateMethod("GET")
+            .validatePath("/aad/auth")
+            .withRedirect("https://login.microsoftonline.com//oauth2/authorize", 302)
+            .build();
+
+    try (FixtureServer server = new FixtureServer().with(fixture)) {
+      DatabricksConfig config = new DatabricksConfig();
+      config.setHost(server.getUrl());
+      config.setAzureWorkspaceResourceId(
+          "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
+      config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
+
+      assertThrows(DatabricksException.class, () -> AzureUtils.inferTenantId(config));
+      assertNull(config.getAzureTenantId());
+    }
+  }
+
+  @Test
+  public void testInferTenantIdHappyPath() throws IOException {
+    FixtureServer.FixtureMapping fixture =
+        new FixtureServer.FixtureMapping.Builder()
+            .validateMethod("GET")
+            .validatePath("/aad/auth")
+            .withRedirect("https://login.microsoftonline.com/test-tenant-id/oauth2/authorize", 302)
+            .build();
+
+    try (FixtureServer server = new FixtureServer().with(fixture)) {
+      DatabricksConfig config = new DatabricksConfig();
+      config.setHost(server.getUrl());
+      config.setAzureWorkspaceResourceId(
+          "/subscriptions/123/resourceGroups/rg/providers/Microsoft.Databricks/workspaces/ws");
+      config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
+      String result = AzureUtils.inferTenantId(config);
+      assertEquals("test-tenant-id", result);
+      assertNull(config.getAzureTenantId()); // Config should remain unchanged
+    }
+  }
+
+  @Test
+  public void testInferTenantIdSkipsWhenNotAzure() {
+    DatabricksConfig config = new DatabricksConfig();
+    config.setHost("https://my-workspace.cloud.databricks.com"); // non-azure host
+    config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
+
+    DatabricksException exception =
+        assertThrows(
+            DatabricksException.class,
+            () -> {
+              AzureUtils.inferTenantId(config);
+            });
+    assertEquals("Cannot infer tenant ID: workspace is not Azure", exception.getMessage());
+    assertNull(config.getAzureTenantId());
+  }
+
+  @Test
+  public void testInferTenantIdSkipsWhenAlreadySet() {
+    DatabricksConfig config = new DatabricksConfig();
+    config.setHost("https://adb-123.0.azuredatabricks.net");
+    config.setAzureTenantId("existing-tenant-id");
+    config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
+    String result = AzureUtils.inferTenantId(config);
+    assertEquals("existing-tenant-id", result);
+    assertEquals("existing-tenant-id", config.getAzureTenantId()); // Config should remain unchanged
+  }
+
+  @Test
+  public void testInferTenantIdSkipsWhenNoHost() {
+    DatabricksConfig config = new DatabricksConfig();
+    config.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
+
+    DatabricksException exception =
+        assertThrows(
+            DatabricksException.class,
+            () -> {
+              AzureUtils.inferTenantId(config);
+            });
+    assertEquals("Cannot infer tenant ID: host is missing", exception.getMessage());
+    assertNull(config.getAzureTenantId());
+  }
+}