Commit

passes authentication directly from workspace, adds proper fingerprinting
rudolfix committed Jan 30, 2025
1 parent 9ba1801 commit 2c54ed9
Showing 6 changed files with 152 additions and 20 deletions.
50 changes: 36 additions & 14 deletions dlt/destinations/impl/databricks/configuration.py
@@ -1,11 +1,12 @@
import dataclasses
from typing import ClassVar, Final, Optional, Any, Dict, List

from dlt.common import logger
from dlt.common.typing import TSecretStrValue
from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec
from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration
from dlt.common.configuration.exceptions import ConfigurationValueError
from dlt.common.pipeline import get_dlt_pipelines_dir
from dlt.common.utils import digest128

DATABRICKS_APPLICATION_ID = "dltHub_dlt"

@@ -25,7 +26,6 @@ class DatabricksCredentials(CredentialsConfiguration):
"""Additional keyword arguments that are passed to `databricks.sql.connect`"""
socket_timeout: Optional[int] = 180
user_agent_entry: Optional[str] = DATABRICKS_APPLICATION_ID
staging_allowed_local_path: Optional[str] = None

__config_gen_annotations__: ClassVar[List[str]] = [
"server_hostname",
@@ -37,10 +37,6 @@ class DatabricksCredentials(CredentialsConfiguration):
]

def on_resolved(self) -> None:
# conn parameter staging_allowed_local_path must be set to use 'REMOVE volume_path' SQL statement
if not self.staging_allowed_local_path:
self.staging_allowed_local_path = get_dlt_pipelines_dir()

if not ((self.client_id and self.client_secret) or self.access_token):
try:
# attempt context authentication
@@ -51,6 +47,15 @@ def on_resolved(self) -> None:
except Exception:
self.access_token = None

try:
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
self.access_token = w.config.authenticate # type: ignore[assignment]
logger.info(f"Will attempt to use default auth of type {w.config.auth_type}")
except Exception:
pass

if not self.access_token:
raise ConfigurationValueError(
"Authentication failed: No valid authentication method detected. "
@@ -62,12 +67,19 @@ def on_resolved(self) -> None:
try:
# attempt to fetch warehouse details
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.sql import EndpointInfo

w = WorkspaceClient()
warehouses: List[EndpointInfo] = list(w.warehouses.list())
self.server_hostname = self.server_hostname or warehouses[0].odbc_params.hostname
self.http_path = self.http_path or warehouses[0].odbc_params.path
# warehouse ID may be present in an env variable
if w.config.warehouse_id:
warehouse = w.warehouses.get(w.config.warehouse_id)
else:
# for some reason list of warehouses has different type than a single one 🤯
warehouse = list(w.warehouses.list())[0] # type: ignore[assignment]
logger.info(
f"Will attempt to use warehouse {warehouse.id} to get sql connection params"
)
self.server_hostname = self.server_hostname or warehouse.odbc_params.hostname
self.http_path = self.http_path or warehouse.odbc_params.path
except Exception:
pass

@@ -86,7 +98,6 @@ def to_connector_params(self) -> Dict[str, Any]:
access_token=self.access_token,
session_configuration=self.session_configuration or {},
_socket_timeout=self.socket_timeout,
staging_allowed_local_path=self.staging_allowed_local_path,
**(self.connection_parameters or {}),
)

@@ -97,6 +108,9 @@ def to_connector_params(self) -> Dict[str, Any]:

return conn_params

def __str__(self) -> str:
return f"databricks://{self.server_hostname}{self.http_path}/{self.catalog}"


@configspec
class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration):
@@ -108,10 +122,18 @@ class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration
"""If true, the temporary credentials are not propagated to the COPY command"""
staging_volume_name: Optional[str] = None
"""Name of the Databricks managed volume for temporary storage, e.g., <catalog_name>.<database_name>.<volume_name>. Defaults to '_dlt_temp_load_volume' if not set."""
keep_staged_files: Optional[bool] = True
"""Tells if to keep the files in internal (volume) stage"""

def __str__(self) -> str:
"""Return displayable destination location"""
if self.staging_config:
return str(self.staging_config.credentials)
if self.credentials:
return str(self.credentials)
else:
return "[no staging set]"
return ""

def fingerprint(self) -> str:
"""Returns a fingerprint of host part of a connection string"""
if self.credentials and self.credentials.server_hostname:
return digest128(self.credentials.server_hostname)
return ""
14 changes: 11 additions & 3 deletions dlt/destinations/impl/databricks/databricks.py
@@ -1,3 +1,4 @@
import os
from typing import Optional, Sequence, List, cast
from urllib.parse import urlparse, urlunparse

@@ -58,6 +59,10 @@ def run(self) -> None:
# decide if this is a local file or a staged file
is_local_file = not ReferenceFollowupJobRequest.is_reference_job(self._file_path)
if is_local_file:
# conn parameter staging_allowed_local_path must be set to use 'PUT/REMOVE volume_path' SQL statement
self._sql_client.native_connection.thrift_backend.staging_allowed_local_path = (
os.path.dirname(self._file_path)
)
# local file by uploading to a temporary volume on Databricks
from_clause, file_name, volume_path, volume_file_path = self._handle_local_file_upload(
self._file_path
@@ -89,9 +94,12 @@ def run(self) -> None:

self._sql_client.execute_sql(statement)

if is_local_file:
self._sql_client.execute_sql(f"REMOVE '{volume_file_path}'")
self._sql_client.execute_sql(f"REMOVE '{volume_path}'")
if is_local_file and not self._job_client.config.keep_staged_files:
self._handle_staged_file_remove(volume_path, volume_file_path)

def _handle_staged_file_remove(self, volume_path: str, volume_file_path: str) -> None:
self._sql_client.execute_sql(f"REMOVE '{volume_file_path}'")
self._sql_client.execute_sql(f"REMOVE '{volume_path}'")

def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str, str, str]:
file_name = FileStorage.get_file_name_from_file_path(local_file_path)
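
For context, a hedged sketch of the direct-load sequence this job performs, shown with the public `staging_allowed_local_path` connect parameter rather than setting it on the thrift backend; all hostnames, paths, and table names are placeholders:

```python
from databricks import sql

conn = sql.connect(
    server_hostname="adb-1234567890123456.7.azuredatabricks.net",  # placeholder
    http_path="/sql/1.0/warehouses/abc123",                        # placeholder
    access_token="...",                                            # placeholder
    # PUT/REMOVE refuse local files outside this directory
    staging_allowed_local_path="/tmp/load",
)
volume_file = "/Volumes/my_catalog/my_dataset/_dlt_temp_load_volume/digits.parquet"
with conn.cursor() as cur:
    # upload the local file to the managed volume
    cur.execute(f"PUT '/tmp/load/digits.parquet' INTO '{volume_file}' OVERWRITE")
    # load it into the destination table
    cur.execute(
        f"COPY INTO my_catalog.my_dataset.digits FROM '{volume_file}' FILEFORMAT = PARQUET"
    )
    # with keep_staged_files = false, dlt issues REMOVE right after the COPY
    cur.execute(f"REMOVE '{volume_file}'")
conn.close()
```
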
4 changes: 4 additions & 0 deletions dlt/destinations/impl/databricks/sql_client.py
@@ -88,7 +88,11 @@ def open_connection(self) -> DatabricksSqlConnection:

if self.credentials.client_id and self.credentials.client_secret:
conn_params["credentials_provider"] = self._get_oauth_credentials
elif callable(self.credentials.access_token):
# this is w.config.authenticate
conn_params["credentials_provider"] = lambda: self.credentials.access_token
else:
# this is access token
conn_params["access_token"] = self.credentials.access_token

self._conn = databricks_lib.connect(
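
The `elif` branch above forwards the workspace authenticator to the connector. A hedged, self-contained illustration of that wiring (hostname and HTTP path are placeholders):

```python
from databricks import sql
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()  # resolved from the workspace context / env variables

conn = sql.connect(
    server_hostname="adb-1234567890123456.7.azuredatabricks.net",  # placeholder
    http_path="/sql/1.0/warehouses/abc123",                        # placeholder
    # credentials_provider must be a callable that returns a header factory;
    # w.config.authenticate already is such a factory, hence the lambda wrapper
    credentials_provider=lambda: w.config.authenticate,
)
```
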
17 changes: 16 additions & 1 deletion docs/website/docs/dlt-ecosystem/destinations/databricks.md
@@ -240,6 +240,13 @@ You can find other options for specifying credentials in the [Authentication sec

See [Staging support](#staging-support) for authentication options when `dlt` copies files from buckets.

### Using default credentials
If none of the auth methods above is configured, `dlt` attempts to get authorization from the Databricks workspace context. That context may
come, for example, from a notebook runtime or from the standard set of env variables recognized by the Databricks Python SDK (e.g. **DATABRICKS_TOKEN** or **DATABRICKS_HOST**).

`dlt` can also fill in `server_hostname` and `http_path` from the available warehouses. It uses the default warehouse id (**DATABRICKS_WAREHOUSE_ID**)
if that env variable is set, or the first warehouse on the list otherwise.
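
For example, a pipeline can run with only the catalog set explicitly while everything else comes from the workspace context. A minimal sketch (catalog and names are illustrative):

```python
import dlt
from dlt.destinations import databricks
from dlt.destinations.impl.databricks.configuration import DatabricksCredentials

# only the catalog is set explicitly; token, host and http_path come from the
# workspace context (DATABRICKS_TOKEN / DATABRICKS_HOST / DATABRICKS_WAREHOUSE_ID)
bricks = databricks(credentials=DatabricksCredentials(catalog="my_catalog"))

pipeline = dlt.pipeline(
    pipeline_name="databricks_default_auth",
    destination=bricks,
    dataset_name="my_data",
)
print(pipeline.run([1, 2, 3], table_name="digits"))
```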

## Write disposition
All write dispositions are supported.

@@ -300,7 +307,15 @@ print(pipeline.dataset().pokemon.df())

- If **no** *staging_volume_name* **is provided**, dlt creates a **default volume** automatically.
- **For production**, explicitly setting *staging_volume_name* is recommended.
- The volume is used as a **temporary location** to store files before loading. Files are **deleted immediately** after loading.
- The volume is used as a **temporary location** to store files before loading.

:::tip
You can delete staged files **immediately** after loading by setting the following config option:
```toml
[destination.databricks]
keep_staged_files = false
```
:::
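
The same options can also be set on the destination factory in code, mirroring how the test suite configures it; the volume name below is illustrative:

```python
import dlt
from dlt.destinations import databricks

# pin the staging volume and drop staged files right after loading
bricks = databricks(
    staging_volume_name="my_catalog.my_schema.my_staging_volume",
    keep_staged_files=False,
)
pipeline = dlt.pipeline(
    pipeline_name="databricks_direct_load",
    destination=bricks,
    dataset_name="my_data",
)
```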

## Staging support

70 changes: 70 additions & 0 deletions tests/load/databricks/test_databricks_configuration.py
@@ -1,8 +1,12 @@
import pytest
import os

from dlt.common.schema.schema import Schema
from dlt.common.utils import digest128

pytest.importorskip("databricks")

import dlt
from dlt.common.exceptions import TerminalValueError
from dlt.common.configuration.exceptions import ConfigurationValueError
from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob
@@ -12,6 +16,7 @@
from dlt.destinations.impl.databricks.configuration import (
DatabricksClientConfiguration,
DATABRICKS_APPLICATION_ID,
DatabricksCredentials,
)

# mark all tests as essential, do not remove
@@ -43,6 +48,11 @@ def test_databricks_credentials_to_connector_params():
assert params["_socket_timeout"] == credentials.socket_timeout
assert params["_user_agent_entry"] == DATABRICKS_APPLICATION_ID

displayable_location = str(credentials)
assert displayable_location.startswith(
"databricks://my-databricks.example.com/sql/1.0/warehouses/asdfe/my-catalog"
)


def test_databricks_configuration() -> None:
bricks = databricks()
@@ -125,3 +135,63 @@ def test_databricks_missing_config_server_hostname() -> None:
os.environ["DESTINATION__DATABRICKS__CREDENTIALS__SERVER_HOSTNAME"] = ""
bricks = databricks()
bricks.configuration(None, accept_partial=True)


def test_default_credentials() -> None:
# from dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient
# create minimal default env
os.environ["DATABRICKS_TOKEN"] = dlt.secrets["destination.databricks.credentials.access_token"]
os.environ["DATABRICKS_HOST"] = dlt.secrets[
"destination.databricks.credentials.server_hostname"
]

config = resolve_configuration(
DatabricksClientConfiguration(
credentials=DatabricksCredentials(catalog="dlt_ci")
)._bind_dataset_name(dataset_name="my-dataset-1234")
)
# we pass an authenticator callable that is used to make the connection, hence the callable check
assert callable(config.credentials.access_token)
# taken from a warehouse
assert isinstance(config.credentials.http_path, str)

bricks = databricks(credentials=config.credentials)
# "my-dataset-1234" not present (we check SQL execution)
with bricks.client(Schema("schema"), config) as client:
assert not client.is_storage_initialized()

# check fingerprint not default
assert config.fingerprint() != digest128("")


def test_oauth2_credentials() -> None:
config = resolve_configuration(
DatabricksClientConfiguration(
credentials=DatabricksCredentials(
catalog="dlt_ci", client_id="ctx-id", client_secret="xx0xx"
)
)._bind_dataset_name(dataset_name="my-dataset-1234"),
sections=("destination", "databricks"),
)
# will resolve to oauth token
bricks = databricks(credentials=config.credentials)
# "my-dataset-1234" not present (we check SQL execution)
with pytest.raises(Exception, match="Client authentication failed"):
with bricks.client(Schema("schema"), config):
pass


def test_default_warehouse() -> None:
os.environ["DATABRICKS_TOKEN"] = dlt.secrets["destination.databricks.credentials.access_token"]
os.environ["DATABRICKS_HOST"] = dlt.secrets[
"destination.databricks.credentials.server_hostname"
]
# will force this warehouse
os.environ["DATABRICKS_WAREHOUSE_ID"] = "588dbd71bd802f4d"

config = resolve_configuration(
DatabricksClientConfiguration(
credentials=DatabricksCredentials(catalog="dlt_ci")
)._bind_dataset_name(dataset_name="my-dataset-1234")
)
assert config.credentials.http_path == "/sql/1.0/warehouses/588dbd71bd802f4d"
17 changes: 15 additions & 2 deletions tests/load/pipeline/test_databricks_pipeline.py
@@ -1,5 +1,7 @@
import pytest
import os

from pytest_mock import MockerFixture
import dlt

from dlt.common.utils import uniq_id
@@ -248,11 +250,19 @@ def test_databricks_direct_load(destination_config: DestinationTestConfiguration
destinations_configs(default_sql_configs=True, subset=("databricks",)),
ids=lambda x: x.name,
)
def test_databricks_direct_load_with_custom_staging_volume_name(
@pytest.mark.parametrize("keep_staged_files", (True, False))
def test_databricks_direct_load_with_custom_staging_volume_name_and_file_removal(
destination_config: DestinationTestConfiguration,
keep_staged_files: bool,
mocker: MockerFixture,
) -> None:
from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob

remove_spy = mocker.spy(DatabricksLoadJob, "_handle_staged_file_remove")
custom_staging_volume_name = "dlt_ci.dlt_tests_shared.static_volume"
bricks = databricks(staging_volume_name=custom_staging_volume_name)
bricks = databricks(
staging_volume_name=custom_staging_volume_name, keep_staged_files=keep_staged_files
)

dataset_name = "test_databricks_direct_load" + uniq_id()
pipeline = destination_config.setup_pipeline(
@@ -261,6 +271,9 @@ def test_databricks_direct_load_with_custom_staging_volume_name(

info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs)
assert info.has_failed_jobs is False
print(info)

assert remove_spy.call_count == (0 if keep_staged_files else 2)

with pipeline.sql_client() as client:
rows = client.execute_sql(f"select * from {dataset_name}.digits")
