From 9d560d9325c084cde611bc507df6d21c95387923 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Wed, 15 Jan 2025 10:03:52 +0100 Subject: [PATCH 01/28] databricks: enable local files --- dlt/destinations/impl/databricks/factory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/impl/databricks/factory.py b/dlt/destinations/impl/databricks/factory.py index a73a575901..da4eec5f20 100644 --- a/dlt/destinations/impl/databricks/factory.py +++ b/dlt/destinations/impl/databricks/factory.py @@ -107,8 +107,8 @@ class databricks(Destination[DatabricksClientConfiguration, "DatabricksClient"]) def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = None - caps.supported_loader_file_formats = [] + caps.preferred_loader_file_format = "parquet" + caps.supported_loader_file_formats = ["jsonl", "parquet"] caps.preferred_staging_file_format = "parquet" caps.supported_staging_file_formats = ["jsonl", "parquet"] caps.supported_table_formats = ["delta"] From 902c49df6b0c46b2ad22df33c0d1d7d136210432 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Wed, 15 Jan 2025 10:04:11 +0100 Subject: [PATCH 02/28] fix: databricks test config --- .../load/pipeline/test_databricks_pipeline.py | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index 078dce3a7f..7e07ba5277 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -159,14 +159,22 @@ def test_databricks_gcs_external_location(destination_config: DestinationTestCon ) def test_databricks_auth_oauth(destination_config: DestinationTestConfiguration) -> None: os.environ["DESTINATION__DATABRICKS__CREDENTIALS__ACCESS_TOKEN"] = "" - bricks = databricks() + + from dlt.destinations import databricks, filesystem + from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob + + abfss_bucket_url = DatabricksLoadJob.ensure_databricks_abfss_url(AZ_BUCKET, "dltdata") + stage = filesystem(abfss_bucket_url) + + bricks = databricks(is_staging_external_location=False) config = bricks.configuration(None, accept_partial=True) + assert config.credentials.client_id and config.credentials.client_secret assert not config.credentials.access_token dataset_name = "test_databricks_oauth" + uniq_id() pipeline = destination_config.setup_pipeline( - "test_databricks_oauth", dataset_name=dataset_name, destination=bricks + "test_databricks_oauth", dataset_name=dataset_name, destination=bricks, staging=stage ) info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) @@ -179,20 +187,29 @@ def test_databricks_auth_oauth(destination_config: DestinationTestConfiguration) @pytest.mark.parametrize( "destination_config", - destinations_configs(default_sql_configs=True, subset=("databricks",)), + destinations_configs( + default_sql_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) + ), ids=lambda x: x.name, ) def test_databricks_auth_token(destination_config: DestinationTestConfiguration) -> None: os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" - bricks = databricks() + + from dlt.destinations import databricks, filesystem + from 
dlt.destinations.impl.databricks.databricks import DatabricksLoadJob + + abfss_bucket_url = DatabricksLoadJob.ensure_databricks_abfss_url(AZ_BUCKET, "dltdata") + stage = filesystem(abfss_bucket_url) + + bricks = databricks(is_staging_external_location=False) config = bricks.configuration(None, accept_partial=True) assert config.credentials.access_token assert not (config.credentials.client_secret and config.credentials.client_id) dataset_name = "test_databricks_token" + uniq_id() pipeline = destination_config.setup_pipeline( - "test_databricks_token", dataset_name=dataset_name, destination=bricks + "test_databricks_token", dataset_name=dataset_name, destination=bricks, staging=stage ) info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) From 1efe56564907b475b69d92559bd96127348844f0 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:24:45 +0100 Subject: [PATCH 03/28] work in progress --- .../impl/databricks/configuration.py | 26 +- .../impl/databricks/databricks.py | 274 ++++++++++++------ .../load/pipeline/test_databricks_pipeline.py | 29 ++ 3 files changed, 242 insertions(+), 87 deletions(-) diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index 21338bd310..630f0bcc2a 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -5,7 +5,7 @@ from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration from dlt.common.configuration.exceptions import ConfigurationValueError - +from dlt.common import logger DATABRICKS_APPLICATION_ID = "dltHub_dlt" @@ -15,6 +15,7 @@ class DatabricksCredentials(CredentialsConfiguration): catalog: str = None server_hostname: str = None http_path: str = None + is_token_from_context: bool = False access_token: Optional[TSecretStrValue] = None client_id: Optional[TSecretStrValue] = None client_secret: Optional[TSecretStrValue] = None @@ -37,11 +38,28 @@ class DatabricksCredentials(CredentialsConfiguration): def on_resolved(self) -> None: if not ((self.client_id and self.client_secret) or self.access_token): - raise ConfigurationValueError( - "No valid authentication method detected. Provide either 'client_id' and" - " 'client_secret' for OAuth, or 'access_token' for token-based authentication." + # databricks authentication: attempt context token + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + dbutils = w.dbutils + self.access_token = ( + dbutils.notebook.entry_point.getDbutils() + .notebook() + .getContext() + .apiToken() + .getOrElse(None) ) + if not self.access_token: + raise ConfigurationValueError( + "No valid authentication method detected. Provide either 'client_id' and" + " 'client_secret' for OAuth, or 'access_token' for token-based authentication." 
+ ) + + self.is_token_from_context = True + logger.info("Authenticating to Databricks using the user's Notebook API token.") + def to_connector_params(self) -> Dict[str, Any]: conn_params = dict( catalog=self.catalog, diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index fd97ded510..1fb98c1563 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -35,7 +35,7 @@ from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.utils import is_compression_disabled - +from dlt.common import logger SUPPORTED_BLOB_STORAGE_PROTOCOLS = AZURE_BLOB_STORAGE_PROTOCOLS + S3_PROTOCOLS + GCS_PROTOCOLS @@ -50,126 +50,234 @@ def __init__( self._staging_config = staging_config self._job_client: "DatabricksClient" = None + self._sql_client = None + self._workspace_client = None + self._created_volume = None + def run(self) -> None: self._sql_client = self._job_client.sql_client qualified_table_name = self._sql_client.make_qualified_table_name(self.load_table_name) - staging_credentials = self._staging_config.credentials - # extract and prepare some vars + + # Decide if this is a local file or a staged file + is_local_file = not ReferenceFollowupJobRequest.is_reference_job(self._file_path) + if is_local_file and self._job_client.config.credentials.is_token_from_context: + # Handle local file by uploading to a temporary volume on Databricks + from_clause, file_name = self._handle_local_file_upload(self._file_path) + credentials_clause = "" + orig_bucket_path = None # not used for local file + else: + # Handle staged file + from_clause, credentials_clause, file_name, orig_bucket_path = ( + self._handle_staged_file() + ) + + # Determine the source format and any additional format options + source_format, format_options_clause, skip_load = self._determine_source_format( + file_name, orig_bucket_path + ) + + if skip_load: + # If the file is empty or otherwise un-loadable, exit early + self._cleanup_volume() # in case we created a volume + return + + # Build and execute the COPY INTO statement + statement = self._build_copy_into_statement( + qualified_table_name, + from_clause, + credentials_clause, + source_format, + format_options_clause, + ) + + self._sql_client.execute_sql(statement) + + self._cleanup_volume() + + def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str]: + from databricks.sdk import WorkspaceClient + from databricks.sdk.service import catalog + import time + import io + + w = WorkspaceClient( + host=self._job_client.config.credentials.server_hostname, + token=self._job_client.config.credentials.access_token, + ) + self._workspace_client = w + + # Create a temporary volume + volume_name = "_dlt_temp_load_volume" + # created_volume = w.volumes.create( + # catalog_name=self._sql_client.database_name, + # schema_name=self._sql_client.dataset_name, + # name=volume_name, + # volume_type=catalog.VolumeType.MANAGED, + # ) + # self._created_volume = created_volume # store to delete later + + qualified_volume_name = ( + f"{self._sql_client.database_name}.{self._sql_client.dataset_name}.{volume_name}" + ) + self._sql_client.execute_sql(f""" + CREATE VOLUME IF NOT EXISTS {qualified_volume_name} + """) + + logger.info(f"datrabricks volume created {qualified_volume_name}") + + # Compute volume paths + volume_path = 
f"/Volumes/{self._sql_client.database_name}/{self._sql_client.dataset_name}/{volume_name}" + volume_folder = f"file_{time.time_ns()}" + volume_folder_path = f"{volume_path}/{volume_folder}" + + file_name = FileStorage.get_file_name_from_file_path(local_file_path) + volume_file_path = f"{volume_folder_path}/{file_name}" + + # Upload the file + with open(local_file_path, "rb") as f: + file_bytes = f.read() + binary_data = io.BytesIO(file_bytes) + w.files.upload(volume_file_path, binary_data, overwrite=True) + + # Return the FROM clause and file name + from_clause = f"FROM '{volume_folder_path}'" + + return from_clause, file_name + + def _handle_staged_file(self) -> tuple[str, str, str, str]: bucket_path = orig_bucket_path = ( ReferenceFollowupJobRequest.resolve_reference(self._file_path) if ReferenceFollowupJobRequest.is_reference_job(self._file_path) else "" ) - file_name = ( - FileStorage.get_file_name_from_file_path(bucket_path) - if bucket_path - else self._file_name - ) - from_clause = "" - credentials_clause = "" - format_options_clause = "" - if bucket_path: - bucket_url = urlparse(bucket_path) - bucket_scheme = bucket_url.scheme + if not bucket_path: + raise LoadJobTerminalException( + self._file_path, + "Cannot load from local file. Databricks does not support loading from local files." + " Configure staging with an s3, azure or google storage bucket.", + ) - if bucket_scheme not in SUPPORTED_BLOB_STORAGE_PROTOCOLS: - raise LoadJobTerminalException( - self._file_path, - f"Databricks cannot load data from staging bucket {bucket_path}. Only s3, azure" - " and gcs buckets are supported. Please note that gcs buckets are supported" - " only via named credential", - ) + # Extract filename + file_name = FileStorage.get_file_name_from_file_path(bucket_path) - if self._job_client.config.is_staging_external_location: - # just skip the credentials clause for external location - # https://docs.databricks.com/en/sql/language-manual/sql-ref-external-locations.html#external-location - pass - elif self._job_client.config.staging_credentials_name: - # add named credentials - credentials_clause = ( - f"WITH(CREDENTIAL {self._job_client.config.staging_credentials_name} )" - ) - else: - # referencing an staged files via a bucket URL requires explicit AWS credentials - if bucket_scheme == "s3": - assert isinstance(staging_credentials, AwsCredentialsWithoutDefaults) - s3_creds = staging_credentials.to_session_credentials() - credentials_clause = f"""WITH(CREDENTIAL( + staging_credentials = self._staging_config.credentials + bucket_url = urlparse(bucket_path) + bucket_scheme = bucket_url.scheme + + # Validate the storage scheme + if bucket_scheme not in SUPPORTED_BLOB_STORAGE_PROTOCOLS: + raise LoadJobTerminalException( + self._file_path, + f"Databricks cannot load data from staging bucket {bucket_path}. " + "Only s3, azure and gcs buckets are supported. 
" + "Please note that gcs buckets are supported only via named credential.", + ) + + credentials_clause = "" + # External location vs named credentials vs explicit keys + if self._job_client.config.is_staging_external_location: + # Skip the credentials clause + pass + elif self._job_client.config.staging_credentials_name: + # Named credentials + credentials_clause = ( + f"WITH(CREDENTIAL {self._job_client.config.staging_credentials_name} )" + ) + else: + # Use explicit keys if needed + if bucket_scheme == "s3": + assert isinstance(staging_credentials, AwsCredentialsWithoutDefaults) + s3_creds = staging_credentials.to_session_credentials() + credentials_clause = f"""WITH(CREDENTIAL( AWS_ACCESS_KEY='{s3_creds["aws_access_key_id"]}', AWS_SECRET_KEY='{s3_creds["aws_secret_access_key"]}', - AWS_SESSION_TOKEN='{s3_creds["aws_session_token"]}' - )) - """ - elif bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS: - assert isinstance( - staging_credentials, AzureCredentialsWithoutDefaults - ), "AzureCredentialsWithoutDefaults required to pass explicit credential" - # Explicit azure credentials are needed to load from bucket without a named stage - credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))""" - bucket_path = self.ensure_databricks_abfss_url( - bucket_path, - staging_credentials.azure_storage_account_name, - staging_credentials.azure_account_host, - ) - else: - raise LoadJobTerminalException( - self._file_path, - "You need to use Databricks named credential to use google storage." - " Passing explicit Google credentials is not supported by Databricks.", - ) - - if bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS: + ))""" + elif bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS: assert isinstance( - staging_credentials, - ( - AzureCredentialsWithoutDefaults, - AzureServicePrincipalCredentialsWithoutDefaults, - ), - ) + staging_credentials, AzureCredentialsWithoutDefaults + ), "AzureCredentialsWithoutDefaults required to pass explicit credential" + credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))""" bucket_path = self.ensure_databricks_abfss_url( bucket_path, staging_credentials.azure_storage_account_name, staging_credentials.azure_account_host, ) + else: + raise LoadJobTerminalException( + self._file_path, + "You need to use Databricks named credential to use google storage." + " Passing explicit Google credentials is not supported by Databricks.", + ) - # always add FROM clause - from_clause = f"FROM '{bucket_path}'" - else: - raise LoadJobTerminalException( - self._file_path, - "Cannot load from local file. Databricks does not support loading from local files." 
- " Configure staging with an s3, azure or google storage bucket.", + if bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS: + assert isinstance( + staging_credentials, + (AzureCredentialsWithoutDefaults, AzureServicePrincipalCredentialsWithoutDefaults), + ) + bucket_path = self.ensure_databricks_abfss_url( + bucket_path, + staging_credentials.azure_storage_account_name, + staging_credentials.azure_account_host, ) - # decide on source format, stage_file_path will either be a local file or a bucket path + from_clause = f"FROM '{bucket_path}'" + + return from_clause, credentials_clause, file_name, orig_bucket_path + + def _determine_source_format( + self, file_name: str, orig_bucket_path: str + ) -> tuple[str, str, bool]: if file_name.endswith(".parquet"): - source_format = "PARQUET" # Only parquet is supported + # Only parquet is supported + return "PARQUET", "", False + elif file_name.endswith(".jsonl"): if not is_compression_disabled(): raise LoadJobTerminalException( self._file_path, - "Databricks loader does not support gzip compressed JSON files. Please disable" - " compression in the data writer configuration:" + "Databricks loader does not support gzip compressed JSON files. " + "Please disable compression in the data writer configuration:" " https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression", ) - source_format = "JSON" + # Databricks can load uncompressed JSON format_options_clause = "FORMAT_OPTIONS('inferTimestamp'='true')" - # Databricks fails when trying to load empty json files, so we have to check the file size + + # Check for an empty JSON file fs, _ = fsspec_from_config(self._staging_config) - file_size = fs.size(orig_bucket_path) - if file_size == 0: # Empty file, do nothing - return + if orig_bucket_path is not None: + file_size = fs.size(orig_bucket_path) + if file_size == 0: + return "JSON", format_options_clause, True - statement = f"""COPY INTO {qualified_table_name} + return "JSON", format_options_clause, False + + raise LoadJobTerminalException( + self._file_path, "Databricks loader only supports .parquet or .jsonl file extensions." 
+ ) + + def _build_copy_into_statement( + self, + qualified_table_name: str, + from_clause: str, + credentials_clause: str, + source_format: str, + format_options_clause: str, + ) -> str: + return f"""COPY INTO {qualified_table_name} {from_clause} {credentials_clause} FILEFORMAT = {source_format} {format_options_clause} - """ - self._sql_client.execute_sql(statement) + """ + + def _cleanup_volume(self) -> None: + print("lalal") + # if self._workspace_client and self._created_volume: + # self._workspace_client.volumes.delete(name=self._created_volume.full_name) + # logger.info(f"Deleted temporary volume [{self._created_volume.full_name}]") @staticmethod def ensure_databricks_abfss_url( diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index 7e07ba5277..f07b4d579f 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -218,3 +218,32 @@ def test_databricks_auth_token(destination_config: DestinationTestConfiguration) with pipeline.sql_client() as client: rows = client.execute_sql(f"select * from {dataset_name}.digits") assert len(rows) == 3 + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=("databricks",)), + ids=lambda x: x.name, +) +def test_databricks_direct_loading(destination_config: DestinationTestConfiguration) -> None: + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" + + # is_token_from_context + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__IS_TOKEN_FROM_CONTEXT"] = "True" + + bricks = databricks() + config = bricks.configuration(None, accept_partial=True) + assert config.credentials.access_token + + dataset_name = "test_databricks_token" + uniq_id() + pipeline = destination_config.setup_pipeline( + "test_databricks_token", dataset_name=dataset_name, destination=bricks + ) + + info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) + assert info.has_failed_jobs is False + + with pipeline.sql_client() as client: + rows = client.execute_sql(f"select * from {dataset_name}.digits") + assert len(rows) == 3 From b60b3d32f4ab4029e1365f4217e16b234708ec18 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Thu, 16 Jan 2025 13:22:39 +0100 Subject: [PATCH 04/28] added create and drop volume to interface --- dlt/destinations/impl/athena/athena.py | 6 ++++ dlt/destinations/impl/bigquery/sql_client.py | 6 ++++ .../impl/clickhouse/sql_client.py | 6 ++++ .../impl/databricks/databricks.py | 34 +------------------ .../impl/databricks/sql_client.py | 11 ++++++ dlt/destinations/impl/dremio/sql_client.py | 6 ++++ dlt/destinations/impl/duckdb/sql_client.py | 6 ++++ dlt/destinations/impl/mssql/sql_client.py | 6 ++++ dlt/destinations/impl/postgres/sql_client.py | 6 ++++ dlt/destinations/impl/snowflake/sql_client.py | 6 ++++ .../impl/sqlalchemy/db_api_client.py | 6 ++++ dlt/destinations/job_client_impl.py | 1 + dlt/destinations/sql_client.py | 8 +++++ dlt/pipeline/pipeline.py | 1 + 14 files changed, 76 insertions(+), 33 deletions(-) diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index c7e30aaf55..f47bce968d 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -190,6 +190,12 @@ def close_connection(self) -> None: self._conn.close() self._conn = None + def 
create_volume(self) -> None: + pass + + def drop_volume(self) -> None: + pass + @property def native_connection(self) -> Connection: return self._conn diff --git a/dlt/destinations/impl/bigquery/sql_client.py b/dlt/destinations/impl/bigquery/sql_client.py index 6911fa5c1c..194b1594ea 100644 --- a/dlt/destinations/impl/bigquery/sql_client.py +++ b/dlt/destinations/impl/bigquery/sql_client.py @@ -112,6 +112,12 @@ def close_connection(self) -> None: self._client.close() self._client = None + def create_volume(self) -> None: + pass + + def drop_volume(self) -> None: + pass + @contextmanager @raise_database_error def begin_transaction(self) -> Iterator[DBTransaction]: diff --git a/dlt/destinations/impl/clickhouse/sql_client.py b/dlt/destinations/impl/clickhouse/sql_client.py index a6c4ee0458..7c1847fa3c 100644 --- a/dlt/destinations/impl/clickhouse/sql_client.py +++ b/dlt/destinations/impl/clickhouse/sql_client.py @@ -99,6 +99,12 @@ def open_connection(self) -> clickhouse_driver.dbapi.connection.Connection: self._conn = clickhouse_driver.connect(dsn=self.credentials.to_native_representation()) return self._conn + def create_volume(self) -> None: + pass + + def drop_volume(self) -> None: + pass + @raise_open_connection_error def close_connection(self) -> None: if self._conn: diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 1fb98c1563..ce3e480b3e 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -51,8 +51,6 @@ def __init__( self._job_client: "DatabricksClient" = None self._sql_client = None - self._workspace_client = None - self._created_volume = None def run(self) -> None: self._sql_client = self._job_client.sql_client @@ -93,11 +91,8 @@ def run(self) -> None: self._sql_client.execute_sql(statement) - self._cleanup_volume() - def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str]: from databricks.sdk import WorkspaceClient - from databricks.sdk.service import catalog import time import io @@ -105,29 +100,8 @@ def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str]: host=self._job_client.config.credentials.server_hostname, token=self._job_client.config.credentials.access_token, ) - self._workspace_client = w - - # Create a temporary volume - volume_name = "_dlt_temp_load_volume" - # created_volume = w.volumes.create( - # catalog_name=self._sql_client.database_name, - # schema_name=self._sql_client.dataset_name, - # name=volume_name, - # volume_type=catalog.VolumeType.MANAGED, - # ) - # self._created_volume = created_volume # store to delete later - - qualified_volume_name = ( - f"{self._sql_client.database_name}.{self._sql_client.dataset_name}.{volume_name}" - ) - self._sql_client.execute_sql(f""" - CREATE VOLUME IF NOT EXISTS {qualified_volume_name} - """) - logger.info(f"datrabricks volume created {qualified_volume_name}") - - # Compute volume paths - volume_path = f"/Volumes/{self._sql_client.database_name}/{self._sql_client.dataset_name}/{volume_name}" + volume_path = f"/Volumes/{self._sql_client.database_name}/{self._sql_client.dataset_name}/{self._sql_client.volume_name}" volume_folder = f"file_{time.time_ns()}" volume_folder_path = f"{volume_path}/{volume_folder}" @@ -273,12 +247,6 @@ def _build_copy_into_statement( {format_options_clause} """ - def _cleanup_volume(self) -> None: - print("lalal") - # if self._workspace_client and self._created_volume: - # 
self._workspace_client.volumes.delete(name=self._created_volume.full_name) - # logger.info(f"Deleted temporary volume [{self._created_volume.full_name}]") - @staticmethod def ensure_databricks_abfss_url( bucket_path: str, azure_storage_account_name: str = None, account_host: str = None diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 9f695b9d6e..2fdf5fd968 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -63,6 +63,7 @@ def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]: class DatabricksSqlClient(SqlClientBase[DatabricksSqlConnection], DBTransaction): dbapi: ClassVar[DBApi] = databricks_lib + volume_name: str = "_dlt_temp_load_volume" def __init__( self, @@ -102,6 +103,16 @@ def close_connection(self) -> None: self._conn.close() self._conn = None + def create_volume(self) -> None: + self.execute_sql(f""" + CREATE VOLUME IF NOT EXISTS {self.fully_qualified_dataset_name()}.{self.volume_name} + """) + + def drop_volume(self) -> None: + self.execute_sql(f""" + DROP VOLUME IF EXISTS {self.fully_qualified_dataset_name()}.{self.volume_name} + """) + @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: # Databricks does not support transactions diff --git a/dlt/destinations/impl/dremio/sql_client.py b/dlt/destinations/impl/dremio/sql_client.py index 030009c74b..d8c509bf18 100644 --- a/dlt/destinations/impl/dremio/sql_client.py +++ b/dlt/destinations/impl/dremio/sql_client.py @@ -64,6 +64,12 @@ def close_connection(self) -> None: self._conn.close() self._conn = None + def create_volume(self) -> None: + pass + + def drop_volume(self) -> None: + pass + @contextmanager @raise_database_error def begin_transaction(self) -> Iterator[DBTransaction]: diff --git a/dlt/destinations/impl/duckdb/sql_client.py b/dlt/destinations/impl/duckdb/sql_client.py index ee73965df6..ba8572ede8 100644 --- a/dlt/destinations/impl/duckdb/sql_client.py +++ b/dlt/destinations/impl/duckdb/sql_client.py @@ -95,6 +95,12 @@ def close_connection(self) -> None: self.credentials.return_conn(self._conn) self._conn = None + def create_volume(self) -> None: + pass + + def drop_volume(self) -> None: + pass + @contextmanager @raise_database_error def begin_transaction(self) -> Iterator[DBTransaction]: diff --git a/dlt/destinations/impl/mssql/sql_client.py b/dlt/destinations/impl/mssql/sql_client.py index 9f05b88bb5..467f0c2b6c 100644 --- a/dlt/destinations/impl/mssql/sql_client.py +++ b/dlt/destinations/impl/mssql/sql_client.py @@ -70,6 +70,12 @@ def close_connection(self) -> None: self._conn.close() self._conn = None + def create_volume(self) -> None: + pass + + def drop_volume(self) -> None: + pass + @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: try: diff --git a/dlt/destinations/impl/postgres/sql_client.py b/dlt/destinations/impl/postgres/sql_client.py index a97c8511f1..b76ca92353 100644 --- a/dlt/destinations/impl/postgres/sql_client.py +++ b/dlt/destinations/impl/postgres/sql_client.py @@ -58,6 +58,12 @@ def close_connection(self) -> None: self._conn.close() self._conn = None + def create_volume(self) -> None: + pass + + def drop_volume(self) -> None: + pass + @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: try: diff --git a/dlt/destinations/impl/snowflake/sql_client.py b/dlt/destinations/impl/snowflake/sql_client.py index 22e27ea48b..56e939e456 100644 --- a/dlt/destinations/impl/snowflake/sql_client.py +++ 
b/dlt/destinations/impl/snowflake/sql_client.py @@ -63,6 +63,12 @@ def close_connection(self) -> None: self._conn.close() self._conn = None + def create_volume(self) -> None: + pass + + def drop_volume(self) -> None: + pass + @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: try: diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py index 27c4f2f1f9..915aee7eae 100644 --- a/dlt/destinations/impl/sqlalchemy/db_api_client.py +++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -171,6 +171,12 @@ def close_connection(self) -> None: self._current_connection = None self._current_transaction = None + def create_volume(self) -> None: + pass + + def drop_volume(self) -> None: + pass + @property def native_connection(self) -> Connection: if not self._current_connection: diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 888c80c006..234493104d 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -176,6 +176,7 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: self.sql_client.create_dataset() elif truncate_tables: self.sql_client.truncate_tables(*truncate_tables) + self.sql_client.create_volume() def is_storage_initialized(self) -> bool: return self.sql_client.has_dataset() diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index 345afff18e..56d11e143c 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -87,6 +87,14 @@ def close_connection(self) -> None: def begin_transaction(self) -> ContextManager[DBTransaction]: pass + @abstractmethod + def create_volume(self) -> None: + pass + + @abstractmethod + def drop_volume(self) -> None: + pass + def __getattr__(self, name: str) -> Any: # pass unresolved attrs to native connections if not self.native_connection: diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 74466a09e4..5c66a60498 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -608,6 +608,7 @@ def load( runner.run_pool(load_step.config, load_step) info: LoadInfo = self._get_step_info(load_step) + self.sql_client().drop_volume() self.first_run = False return info except Exception as l_ex: From e772d20492cbcc7ca01731df352c1c5d16117f6e Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Mon, 20 Jan 2025 16:25:34 +0100 Subject: [PATCH 05/28] refactor direct load authentication --- .../impl/databricks/configuration.py | 24 +++---- .../impl/databricks/databricks.py | 64 +++++++++++-------- .../impl/databricks/sql_client.py | 2 + tests/.dlt/config.toml | 2 + .../load/pipeline/test_databricks_pipeline.py | 6 +- 5 files changed, 54 insertions(+), 44 deletions(-) diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index 630f0bcc2a..ad1dc397a2 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -15,7 +15,7 @@ class DatabricksCredentials(CredentialsConfiguration): catalog: str = None server_hostname: str = None http_path: str = None - is_token_from_context: bool = False + direct_load: bool = False access_token: Optional[TSecretStrValue] = None client_id: Optional[TSecretStrValue] = None client_secret: Optional[TSecretStrValue] = None @@ -38,27 +38,23 @@ class DatabricksCredentials(CredentialsConfiguration): def on_resolved(self) -> None: if not 
((self.client_id and self.client_secret) or self.access_token): - # databricks authentication: attempt context token + # databricks authentication: get context config from databricks.sdk import WorkspaceClient w = WorkspaceClient() - dbutils = w.dbutils - self.access_token = ( - dbutils.notebook.entry_point.getDbutils() - .notebook() - .getContext() - .apiToken() - .getOrElse(None) - ) + notebook_context = w.dbutils.notebook.entry_point.getDbutils().notebook().getContext() + self.access_token = notebook_context.apiToken().getOrElse(None) + + self.server_hostname = notebook_context.browserHostName().getOrElse(None) - if not self.access_token: + if not self.access_token or not self.server_hostname: raise ConfigurationValueError( "No valid authentication method detected. Provide either 'client_id' and" - " 'client_secret' for OAuth, or 'access_token' for token-based authentication." + " 'client_secret' for OAuth, or 'access_token' for token-based authentication," + " and the server_hostname." ) - self.is_token_from_context = True - logger.info("Authenticating to Databricks using the user's Notebook API token.") + self.direct_load = True def to_connector_params(self) -> Dict[str, Any]: conn_params = dict( diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index ce3e480b3e..3afeecbb2a 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -35,7 +35,7 @@ from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.utils import is_compression_disabled -from dlt.common import logger +from dlt.common.utils import digest128 SUPPORTED_BLOB_STORAGE_PROTOCOLS = AZURE_BLOB_STORAGE_PROTOCOLS + S3_PROTOCOLS + GCS_PROTOCOLS @@ -57,15 +57,15 @@ def run(self) -> None: qualified_table_name = self._sql_client.make_qualified_table_name(self.load_table_name) - # Decide if this is a local file or a staged file + # decide if this is a local file or a staged file is_local_file = not ReferenceFollowupJobRequest.is_reference_job(self._file_path) - if is_local_file and self._job_client.config.credentials.is_token_from_context: - # Handle local file by uploading to a temporary volume on Databricks + if is_local_file and self._job_client.config.credentials.direct_load: + # local file by uploading to a temporary volume on Databricks from_clause, file_name = self._handle_local_file_upload(self._file_path) credentials_clause = "" orig_bucket_path = None # not used for local file else: - # Handle staged file + # staged file from_clause, credentials_clause, file_name, orig_bucket_path = ( self._handle_staged_file() ) @@ -77,10 +77,8 @@ def run(self) -> None: if skip_load: # If the file is empty or otherwise un-loadable, exit early - self._cleanup_volume() # in case we created a volume return - # Build and execute the COPY INTO statement statement = self._build_copy_into_statement( qualified_table_name, from_clause, @@ -96,26 +94,42 @@ def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str]: import time import io - w = WorkspaceClient( - host=self._job_client.config.credentials.server_hostname, - token=self._job_client.config.credentials.access_token, - ) + w: WorkspaceClient - volume_path = f"/Volumes/{self._sql_client.database_name}/{self._sql_client.dataset_name}/{self._sql_client.volume_name}" - volume_folder = f"file_{time.time_ns()}" - volume_folder_path = f"{volume_path}/{volume_folder}" + 
credentials = self._job_client.config.credentials + if credentials.client_id and credentials.client_secret: + # oauth authentication + w = WorkspaceClient( + host=credentials.server_hostname, + client_id=credentials.client_id, + client_secret=credentials.client_secret, + ) + elif credentials.access_token: + # token authentication + w = WorkspaceClient( + host=credentials.server_hostname, + token=credentials.access_token, + ) file_name = FileStorage.get_file_name_from_file_path(local_file_path) - volume_file_path = f"{volume_folder_path}/{file_name}" + file_format = "" + if file_name.endswith(".parquet"): + file_format = "parquet" + elif file_name.endswith(".jsonl"): + file_format = "jsonl" + else: + return "",file_name + + volume_path = f"/Volumes/{self._sql_client.database_name}/{self._sql_client.dataset_name}/{self._sql_client.volume_name}/{time.time_ns()}" + volume_file_name = f"{digest128(file_name)}.{file_format}" # file_name must be hashed - databricks fails with file name starting with - or . + volume_file_path = f"{volume_path}/{volume_file_name}" - # Upload the file with open(local_file_path, "rb") as f: file_bytes = f.read() binary_data = io.BytesIO(file_bytes) w.files.upload(volume_file_path, binary_data, overwrite=True) - # Return the FROM clause and file name - from_clause = f"FROM '{volume_folder_path}'" + from_clause = f"FROM '{volume_path}'" return from_clause, file_name @@ -133,14 +147,12 @@ def _handle_staged_file(self) -> tuple[str, str, str, str]: " Configure staging with an s3, azure or google storage bucket.", ) - # Extract filename file_name = FileStorage.get_file_name_from_file_path(bucket_path) staging_credentials = self._staging_config.credentials bucket_url = urlparse(bucket_path) bucket_scheme = bucket_url.scheme - # Validate the storage scheme if bucket_scheme not in SUPPORTED_BLOB_STORAGE_PROTOCOLS: raise LoadJobTerminalException( self._file_path, @@ -150,17 +162,16 @@ def _handle_staged_file(self) -> tuple[str, str, str, str]: ) credentials_clause = "" - # External location vs named credentials vs explicit keys + if self._job_client.config.is_staging_external_location: - # Skip the credentials clause + # skip the credentials clause pass elif self._job_client.config.staging_credentials_name: - # Named credentials + # named credentials credentials_clause = ( f"WITH(CREDENTIAL {self._job_client.config.staging_credentials_name} )" ) else: - # Use explicit keys if needed if bucket_scheme == "s3": assert isinstance(staging_credentials, AwsCredentialsWithoutDefaults) s3_creds = staging_credentials.to_session_credentials() @@ -205,7 +216,6 @@ def _determine_source_format( self, file_name: str, orig_bucket_path: str ) -> tuple[str, str, bool]: if file_name.endswith(".parquet"): - # Only parquet is supported return "PARQUET", "", False elif file_name.endswith(".jsonl"): @@ -216,10 +226,10 @@ def _determine_source_format( "Please disable compression in the data writer configuration:" " https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression", ) - # Databricks can load uncompressed JSON + format_options_clause = "FORMAT_OPTIONS('inferTimestamp'='true')" - # Check for an empty JSON file + # check for an empty JSON file fs, _ = fsspec_from_config(self._staging_config) if orig_bucket_path is not None: file_size = fs.size(orig_bucket_path) diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 2fdf5fd968..a9e880a56e 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ 
b/dlt/destinations/impl/databricks/sql_client.py @@ -109,6 +109,8 @@ def create_volume(self) -> None: """) def drop_volume(self) -> None: + if not self._conn: + self.open_connection() self.execute_sql(f""" DROP VOLUME IF EXISTS {self.fully_qualified_dataset_name()}.{self.volume_name} """) diff --git a/tests/.dlt/config.toml b/tests/.dlt/config.toml index f185a73865..983cec2b6b 100644 --- a/tests/.dlt/config.toml +++ b/tests/.dlt/config.toml @@ -1,4 +1,6 @@ [runtime] +log_level="DEBUG" +log_format="JSON" # sentry_dsn="https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752" [tests] diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index f07b4d579f..71c901fb16 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -225,12 +225,12 @@ def test_databricks_auth_token(destination_config: DestinationTestConfiguration) destinations_configs(default_sql_configs=True, subset=("databricks",)), ids=lambda x: x.name, ) -def test_databricks_direct_loading(destination_config: DestinationTestConfiguration) -> None: +def test_databricks_direct_load(destination_config: DestinationTestConfiguration) -> None: os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" - # is_token_from_context - os.environ["DESTINATION__DATABRICKS__CREDENTIALS__IS_TOKEN_FROM_CONTEXT"] = "True" + # direct_load + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__DIRECT_LOAD"] = "True" bricks = databricks() config = bricks.configuration(None, accept_partial=True) From 2bd0be005c95bce1ec1844c752ef934469d137db Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Mon, 20 Jan 2025 17:16:48 +0100 Subject: [PATCH 06/28] fix databricks volume file name --- dlt/destinations/impl/databricks/databricks.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 3afeecbb2a..be1697cee0 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -35,7 +35,7 @@ from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.utils import is_compression_disabled -from dlt.common.utils import digest128 +from dlt.common.utils import uniq_id SUPPORTED_BLOB_STORAGE_PROTOCOLS = AZURE_BLOB_STORAGE_PROTOCOLS + S3_PROTOCOLS + GCS_PROTOCOLS @@ -118,10 +118,12 @@ def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str]: elif file_name.endswith(".jsonl"): file_format = "jsonl" else: - return "",file_name - + return "", file_name + volume_path = f"/Volumes/{self._sql_client.database_name}/{self._sql_client.dataset_name}/{self._sql_client.volume_name}/{time.time_ns()}" - volume_file_name = f"{digest128(file_name)}.{file_format}" # file_name must be hashed - databricks fails with file name starting with - or . + volume_file_name = ( # replace file_name for random hex code - databricks loading fails when file_name starts with - or . 
+ f"{uniq_id()}.{file_format}" + ) volume_file_path = f"{volume_path}/{volume_file_name}" with open(local_file_path, "rb") as f: From 7641bcf458bfed44911021e3e4259a08fb63fbb3 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Wed, 22 Jan 2025 14:39:24 +0100 Subject: [PATCH 07/28] refactor databricks direct loading --- dlt/destinations/impl/athena/athena.py | 6 ---- dlt/destinations/impl/bigquery/sql_client.py | 6 ---- .../impl/clickhouse/sql_client.py | 6 ---- .../impl/databricks/configuration.py | 29 ++++++++++----- .../impl/databricks/databricks.py | 35 ++++++++++++------- .../impl/databricks/sql_client.py | 13 ------- dlt/destinations/impl/dremio/sql_client.py | 6 ---- dlt/destinations/impl/duckdb/sql_client.py | 6 ---- dlt/destinations/impl/mssql/sql_client.py | 6 ---- dlt/destinations/impl/postgres/sql_client.py | 6 ---- dlt/destinations/impl/snowflake/sql_client.py | 6 ---- .../impl/sqlalchemy/db_api_client.py | 6 ---- dlt/destinations/job_client_impl.py | 1 - dlt/destinations/sql_client.py | 8 ----- dlt/pipeline/pipeline.py | 1 - .../load/pipeline/test_databricks_pipeline.py | 8 +++-- 16 files changed, 47 insertions(+), 102 deletions(-) diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index f47bce968d..c7e30aaf55 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -190,12 +190,6 @@ def close_connection(self) -> None: self._conn.close() self._conn = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @property def native_connection(self) -> Connection: return self._conn diff --git a/dlt/destinations/impl/bigquery/sql_client.py b/dlt/destinations/impl/bigquery/sql_client.py index 194b1594ea..6911fa5c1c 100644 --- a/dlt/destinations/impl/bigquery/sql_client.py +++ b/dlt/destinations/impl/bigquery/sql_client.py @@ -112,12 +112,6 @@ def close_connection(self) -> None: self._client.close() self._client = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @contextmanager @raise_database_error def begin_transaction(self) -> Iterator[DBTransaction]: diff --git a/dlt/destinations/impl/clickhouse/sql_client.py b/dlt/destinations/impl/clickhouse/sql_client.py index 7c1847fa3c..a6c4ee0458 100644 --- a/dlt/destinations/impl/clickhouse/sql_client.py +++ b/dlt/destinations/impl/clickhouse/sql_client.py @@ -99,12 +99,6 @@ def open_connection(self) -> clickhouse_driver.dbapi.connection.Connection: self._conn = clickhouse_driver.connect(dsn=self.credentials.to_native_representation()) return self._conn - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @raise_open_connection_error def close_connection(self) -> None: if self._conn: diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index ad1dc397a2..85eaaa4097 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -15,7 +15,6 @@ class DatabricksCredentials(CredentialsConfiguration): catalog: str = None server_hostname: str = None http_path: str = None - direct_load: bool = False access_token: Optional[TSecretStrValue] = None client_id: Optional[TSecretStrValue] = None client_secret: Optional[TSecretStrValue] = None @@ -38,14 +37,19 @@ class DatabricksCredentials(CredentialsConfiguration): def on_resolved(self) -> None: if not ((self.client_id and self.client_secret) or 
self.access_token): - # databricks authentication: get context config - from databricks.sdk import WorkspaceClient + try: + # attempt notebook context authentication + from databricks.sdk import WorkspaceClient - w = WorkspaceClient() - notebook_context = w.dbutils.notebook.entry_point.getDbutils().notebook().getContext() - self.access_token = notebook_context.apiToken().getOrElse(None) + w = WorkspaceClient() + notebook_context = ( + w.dbutils.notebook.entry_point.getDbutils().notebook().getContext() + ) + self.access_token = notebook_context.apiToken().getOrElse(None) - self.server_hostname = notebook_context.browserHostName().getOrElse(None) + self.server_hostname = notebook_context.browserHostName().getOrElse(None) + except Exception: + pass if not self.access_token or not self.server_hostname: raise ConfigurationValueError( @@ -54,8 +58,6 @@ def on_resolved(self) -> None: " and the server_hostname." ) - self.direct_load = True - def to_connector_params(self) -> Dict[str, Any]: conn_params = dict( catalog=self.catalog, @@ -83,6 +85,15 @@ class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration "If set, credentials with given name will be used in copy command" is_staging_external_location: bool = False """If true, the temporary credentials are not propagated to the COPY command""" + staging_volume_name: Optional[str] = None + """Name of the Databricks managed volume for temporary storage, e.g., ... Defaults to '_dlt_temp_load_volume' if not set.""" + + def on_resolved(self): + if self.staging_volume_name and self.staging_volume_name.count(".") != 2: + raise ConfigurationValueError( + f"Invalid staging_volume_name format: {self.staging_volume_name}. Expected format" + " is '..'." + ) def __str__(self) -> str: """Return displayable destination location""" diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index be1697cee0..9e6e445a3f 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -59,7 +59,7 @@ def run(self) -> None: # decide if this is a local file or a staged file is_local_file = not ReferenceFollowupJobRequest.is_reference_job(self._file_path) - if is_local_file and self._job_client.config.credentials.direct_load: + if is_local_file: # local file by uploading to a temporary volume on Databricks from_clause, file_name = self._handle_local_file_upload(self._file_path) credentials_clause = "" @@ -112,18 +112,27 @@ def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str]: ) file_name = FileStorage.get_file_name_from_file_path(local_file_path) - file_format = "" - if file_name.endswith(".parquet"): - file_format = "parquet" - elif file_name.endswith(".jsonl"): - file_format = "jsonl" - else: - return "", file_name - - volume_path = f"/Volumes/{self._sql_client.database_name}/{self._sql_client.dataset_name}/{self._sql_client.volume_name}/{time.time_ns()}" - volume_file_name = ( # replace file_name for random hex code - databricks loading fails when file_name starts with - or . - f"{uniq_id()}.{file_format}" - ) + volume_file_name = file_name + if file_name.startswith(("_", ".")): + volume_file_name = ( + "valid" + file_name + ) # databricks loading fails when file_name starts with - or . 
+ + volume_catalog = self._sql_client.database_name + volume_database = self._sql_client.dataset_name + volume_name = "_dlt_staging_load_volume" + + # create staging volume name + fully_qualified_volume_name = f"{volume_catalog}.{volume_database}.{volume_name}" + if self._job_client.config.staging_volume_name: + fully_qualified_volume_name = self._job_client.config.staging_volume_name + volume_catalog, volume_database, volume_name = fully_qualified_volume_name.split(".") + + self._sql_client.execute_sql(f""" + CREATE VOLUME IF NOT EXISTS {fully_qualified_volume_name} + """) + + volume_path = f"/Volumes/{volume_catalog}/{volume_database}/{volume_name}/{time.time_ns()}" volume_file_path = f"{volume_path}/{volume_file_name}" with open(local_file_path, "rb") as f: diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index a9e880a56e..9f695b9d6e 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -63,7 +63,6 @@ def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]: class DatabricksSqlClient(SqlClientBase[DatabricksSqlConnection], DBTransaction): dbapi: ClassVar[DBApi] = databricks_lib - volume_name: str = "_dlt_temp_load_volume" def __init__( self, @@ -103,18 +102,6 @@ def close_connection(self) -> None: self._conn.close() self._conn = None - def create_volume(self) -> None: - self.execute_sql(f""" - CREATE VOLUME IF NOT EXISTS {self.fully_qualified_dataset_name()}.{self.volume_name} - """) - - def drop_volume(self) -> None: - if not self._conn: - self.open_connection() - self.execute_sql(f""" - DROP VOLUME IF EXISTS {self.fully_qualified_dataset_name()}.{self.volume_name} - """) - @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: # Databricks does not support transactions diff --git a/dlt/destinations/impl/dremio/sql_client.py b/dlt/destinations/impl/dremio/sql_client.py index d8c509bf18..030009c74b 100644 --- a/dlt/destinations/impl/dremio/sql_client.py +++ b/dlt/destinations/impl/dremio/sql_client.py @@ -64,12 +64,6 @@ def close_connection(self) -> None: self._conn.close() self._conn = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @contextmanager @raise_database_error def begin_transaction(self) -> Iterator[DBTransaction]: diff --git a/dlt/destinations/impl/duckdb/sql_client.py b/dlt/destinations/impl/duckdb/sql_client.py index ba8572ede8..ee73965df6 100644 --- a/dlt/destinations/impl/duckdb/sql_client.py +++ b/dlt/destinations/impl/duckdb/sql_client.py @@ -95,12 +95,6 @@ def close_connection(self) -> None: self.credentials.return_conn(self._conn) self._conn = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @contextmanager @raise_database_error def begin_transaction(self) -> Iterator[DBTransaction]: diff --git a/dlt/destinations/impl/mssql/sql_client.py b/dlt/destinations/impl/mssql/sql_client.py index 467f0c2b6c..9f05b88bb5 100644 --- a/dlt/destinations/impl/mssql/sql_client.py +++ b/dlt/destinations/impl/mssql/sql_client.py @@ -70,12 +70,6 @@ def close_connection(self) -> None: self._conn.close() self._conn = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: try: diff --git a/dlt/destinations/impl/postgres/sql_client.py b/dlt/destinations/impl/postgres/sql_client.py index b76ca92353..a97c8511f1 100644 --- 
a/dlt/destinations/impl/postgres/sql_client.py +++ b/dlt/destinations/impl/postgres/sql_client.py @@ -58,12 +58,6 @@ def close_connection(self) -> None: self._conn.close() self._conn = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: try: diff --git a/dlt/destinations/impl/snowflake/sql_client.py b/dlt/destinations/impl/snowflake/sql_client.py index 56e939e456..22e27ea48b 100644 --- a/dlt/destinations/impl/snowflake/sql_client.py +++ b/dlt/destinations/impl/snowflake/sql_client.py @@ -63,12 +63,6 @@ def close_connection(self) -> None: self._conn.close() self._conn = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: try: diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py index 915aee7eae..27c4f2f1f9 100644 --- a/dlt/destinations/impl/sqlalchemy/db_api_client.py +++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -171,12 +171,6 @@ def close_connection(self) -> None: self._current_connection = None self._current_transaction = None - def create_volume(self) -> None: - pass - - def drop_volume(self) -> None: - pass - @property def native_connection(self) -> Connection: if not self._current_connection: diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 234493104d..888c80c006 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -176,7 +176,6 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: self.sql_client.create_dataset() elif truncate_tables: self.sql_client.truncate_tables(*truncate_tables) - self.sql_client.create_volume() def is_storage_initialized(self) -> bool: return self.sql_client.has_dataset() diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index 56d11e143c..345afff18e 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -87,14 +87,6 @@ def close_connection(self) -> None: def begin_transaction(self) -> ContextManager[DBTransaction]: pass - @abstractmethod - def create_volume(self) -> None: - pass - - @abstractmethod - def drop_volume(self) -> None: - pass - def __getattr__(self, name: str) -> Any: # pass unresolved attrs to native connections if not self.native_connection: diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 5c66a60498..74466a09e4 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -608,7 +608,6 @@ def load( runner.run_pool(load_step.config, load_step) info: LoadInfo = self._get_step_info(load_step) - self.sql_client().drop_volume() self.first_run = False return info except Exception as l_ex: diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index 71c901fb16..5ff1cc2ca2 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -220,6 +220,11 @@ def test_databricks_auth_token(destination_config: DestinationTestConfiguration) assert len(rows) == 3 +# TODO: test config staging_volume_name on_resolve +# TODO: modify the DestinationTestConfiguration +# TODO: add test databricks credentials default auth error +# TODO: test on notebook +# TODO: check that volume doesn't block schema drop @pytest.mark.parametrize( "destination_config", 
destinations_configs(default_sql_configs=True, subset=("databricks",)), @@ -229,9 +234,6 @@ def test_databricks_direct_load(destination_config: DestinationTestConfiguration os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" - # direct_load - os.environ["DESTINATION__DATABRICKS__CREDENTIALS__DIRECT_LOAD"] = "True" - bricks = databricks() config = bricks.configuration(None, accept_partial=True) assert config.credentials.access_token From 627b9856ec92dde1c3b046282140e0d261a9dd58 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Wed, 22 Jan 2025 17:56:30 +0100 Subject: [PATCH 08/28] format and lint --- dlt/destinations/impl/databricks/configuration.py | 4 ++-- dlt/destinations/impl/databricks/databricks.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index 85eaaa4097..0f092af4ec 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -43,7 +43,7 @@ def on_resolved(self) -> None: w = WorkspaceClient() notebook_context = ( - w.dbutils.notebook.entry_point.getDbutils().notebook().getContext() + w.dbutils.notebook.entry_point.getDbutils().notebook().getContext() # type: ignore[union-attr] ) self.access_token = notebook_context.apiToken().getOrElse(None) @@ -88,7 +88,7 @@ class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration staging_volume_name: Optional[str] = None """Name of the Databricks managed volume for temporary storage, e.g., ... Defaults to '_dlt_temp_load_volume' if not set.""" - def on_resolved(self): + def on_resolved(self) -> None: if self.staging_volume_name and self.staging_volume_name.count(".") != 2: raise ConfigurationValueError( f"Invalid staging_volume_name format: {self.staging_volume_name}. 
Expected format" diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 9e6e445a3f..c7e8ce2455 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -50,8 +50,6 @@ def __init__( self._staging_config = staging_config self._job_client: "DatabricksClient" = None - self._sql_client = None - def run(self) -> None: self._sql_client = self._job_client.sql_client From 91c00280ea80d98ad518cf5f0a1a087b774fe5ed Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Wed, 22 Jan 2025 17:57:31 +0100 Subject: [PATCH 09/28] revert config.toml changes --- tests/.dlt/config.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/.dlt/config.toml b/tests/.dlt/config.toml index 983cec2b6b..f185a73865 100644 --- a/tests/.dlt/config.toml +++ b/tests/.dlt/config.toml @@ -1,6 +1,4 @@ [runtime] -log_level="DEBUG" -log_format="JSON" # sentry_dsn="https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752" [tests] From de291268e41608c2515dcef26b466e7f06bb6f0c Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Thu, 23 Jan 2025 12:41:06 +0100 Subject: [PATCH 10/28] force notebook auth --- dlt/destinations/impl/databricks/configuration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index 0f092af4ec..462cdf0efa 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -48,8 +48,8 @@ def on_resolved(self) -> None: self.access_token = notebook_context.apiToken().getOrElse(None) self.server_hostname = notebook_context.browserHostName().getOrElse(None) - except Exception: - pass + except Exception as e: + raise e if not self.access_token or not self.server_hostname: raise ConfigurationValueError( From d288f11e9ce1dbcc3ebe70009c9e43b8af1e0b49 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Thu, 23 Jan 2025 17:23:57 +0100 Subject: [PATCH 11/28] enhanced config validations --- .../impl/databricks/configuration.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index 462cdf0efa..629faabfbc 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -40,24 +40,31 @@ def on_resolved(self) -> None: try: # attempt notebook context authentication from databricks.sdk import WorkspaceClient + from databricks.sdk.service.sql import EndpointInfo w = WorkspaceClient() - notebook_context = ( - w.dbutils.notebook.entry_point.getDbutils().notebook().getContext() # type: ignore[union-attr] - ) - self.access_token = notebook_context.apiToken().getOrElse(None) + self.access_token = w.dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None) # type: ignore[union-attr] - self.server_hostname = notebook_context.browserHostName().getOrElse(None) + # pick the first warehouse on the list + warehouses: List[EndpointInfo] = list(w.warehouses.list()) + self.server_hostname = warehouses[0].odbc_params.hostname + self.http_path = warehouses[0].odbc_params.path except Exception as e: raise e - if not self.access_token or not self.server_hostname: + if 
not self.access_token: raise ConfigurationValueError( - "No valid authentication method detected. Provide either 'client_id' and" - " 'client_secret' for OAuth, or 'access_token' for token-based authentication," - " and the server_hostname." + "Databricks authentication failed: No valid authentication method detected." + " Please provide either 'client_id' and 'client_secret' for OAuth, or" + " 'access_token' for token-based authentication." ) + if not self.server_hostname or not self.http_path or not self.catalog: + raise ConfigurationValueError( + "Databricks authentication failed: 'server_hostname', 'http_path', and 'catalog'" + " are required parameters. Ensure all are provided." + ) + def to_connector_params(self) -> Dict[str, Any]: conn_params = dict( catalog=self.catalog, From 37ca7f461abbd728c6bea67f6495c8c05261f580 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Thu, 23 Jan 2025 17:56:14 +0100 Subject: [PATCH 12/28] force exception --- dlt/destinations/impl/databricks/configuration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index 629faabfbc..ecf21ab3a5 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -48,7 +48,7 @@ def on_resolved(self) -> None: # pick the first warehouse on the list warehouses: List[EndpointInfo] = list(w.warehouses.list()) self.server_hostname = warehouses[0].odbc_params.hostname - self.http_path = warehouses[0].odbc_params.path + #self.http_path = warehouses[0].odbc_params.path except Exception as e: raise e From e89548f511090e5e0ee7e99a37ea6e827ce1977c Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Fri, 24 Jan 2025 00:23:07 +0100 Subject: [PATCH 13/28] fix config resolve --- dlt/destinations/impl/databricks/configuration.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index ecf21ab3a5..7f48732c1c 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -13,8 +13,8 @@ @configspec class DatabricksCredentials(CredentialsConfiguration): catalog: str = None - server_hostname: str = None - http_path: str = None + server_hostname: Optional[str] = None + http_path: Optional[str] = None access_token: Optional[TSecretStrValue] = None client_id: Optional[TSecretStrValue] = None client_secret: Optional[TSecretStrValue] = None @@ -48,9 +48,9 @@ def on_resolved(self) -> None: # pick the first warehouse on the list warehouses: List[EndpointInfo] = list(w.warehouses.list()) self.server_hostname = warehouses[0].odbc_params.hostname - #self.http_path = warehouses[0].odbc_params.path - except Exception as e: - raise e + self.http_path = warehouses[0].odbc_params.path + except Exception: + pass if not self.access_token: raise ConfigurationValueError( From c4239299b9e5d458dabfc0ac1439f1350fc92804 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Fri, 24 Jan 2025 01:55:05 +0100 Subject: [PATCH 14/28] remove imports --- dlt/destinations/impl/databricks/configuration.py | 1 - dlt/destinations/impl/databricks/databricks.py | 1 - 2 files changed, 2 deletions(-) diff --git a/dlt/destinations/impl/databricks/configuration.py 
b/dlt/destinations/impl/databricks/configuration.py index 7f48732c1c..7a623d3f1f 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -5,7 +5,6 @@ from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration from dlt.common.configuration.exceptions import ConfigurationValueError -from dlt.common import logger DATABRICKS_APPLICATION_ID = "dltHub_dlt" diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index c7e8ce2455..009529e01c 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -35,7 +35,6 @@ from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.utils import is_compression_disabled -from dlt.common.utils import uniq_id SUPPORTED_BLOB_STORAGE_PROTOCOLS = AZURE_BLOB_STORAGE_PROTOCOLS + S3_PROTOCOLS + GCS_PROTOCOLS From f0c72088f1d901ae0a1ac8690629c59f76284910 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Fri, 24 Jan 2025 19:15:22 +0100 Subject: [PATCH 15/28] test: config exceptions --- .../impl/databricks/configuration.py | 17 +++----- dlt/destinations/impl/databricks/factory.py | 2 + .../test_databricks_configuration.py | 11 ++++- .../load/pipeline/test_databricks_pipeline.py | 43 +++++++++++-------- 4 files changed, 43 insertions(+), 30 deletions(-) diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index 7a623d3f1f..677f52716b 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -53,15 +53,15 @@ def on_resolved(self) -> None: if not self.access_token: raise ConfigurationValueError( - "Databricks authentication failed: No valid authentication method detected." - " Please provide either 'client_id' and 'client_secret' for OAuth, or" - " 'access_token' for token-based authentication." + "Authentication failed: No valid authentication method detected. " + "Provide either 'client_id' and 'client_secret' for OAuth authentication, " + "or 'access_token' for token-based authentication." ) if not self.server_hostname or not self.http_path or not self.catalog: raise ConfigurationValueError( - "Databricks authentication failed: 'server_hostname', 'http_path', and 'catalog'" - " are required parameters. Ensure all are provided." + "Configuration error: Missing required parameters. " + "Please provide 'server_hostname', 'http_path', and 'catalog' in the configuration." ) def to_connector_params(self) -> Dict[str, Any]: @@ -94,13 +94,6 @@ class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration staging_volume_name: Optional[str] = None """Name of the Databricks managed volume for temporary storage, e.g., ... Defaults to '_dlt_temp_load_volume' if not set.""" - def on_resolved(self) -> None: - if self.staging_volume_name and self.staging_volume_name.count(".") != 2: - raise ConfigurationValueError( - f"Invalid staging_volume_name format: {self.staging_volume_name}. Expected format" - " is '..'." 
- ) - def __str__(self) -> str: """Return displayable destination location""" if self.staging_config: diff --git a/dlt/destinations/impl/databricks/factory.py b/dlt/destinations/impl/databricks/factory.py index da4eec5f20..a513c14353 100644 --- a/dlt/destinations/impl/databricks/factory.py +++ b/dlt/destinations/impl/databricks/factory.py @@ -154,6 +154,7 @@ def __init__( staging_credentials_name: t.Optional[str] = None, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, + staging_volume_name: t.Optional[str] = None, **kwargs: t.Any, ) -> None: """Configure the Databricks destination to use in a pipeline. @@ -173,5 +174,6 @@ def __init__( staging_credentials_name=staging_credentials_name, destination_name=destination_name, environment=environment, + staging_volume_name=staging_volume_name, **kwargs, ) diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py index 8b3beed2b3..beaee66246 100644 --- a/tests/load/databricks/test_databricks_configuration.py +++ b/tests/load/databricks/test_databricks_configuration.py @@ -90,9 +90,18 @@ def test_databricks_abfss_converter() -> None: def test_databricks_auth_invalid() -> None: - with pytest.raises(ConfigurationValueError, match="No valid authentication method detected.*"): + with pytest.raises(ConfigurationValueError, match="Authentication failed:*"): os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" os.environ["DESTINATION__DATABRICKS__CREDENTIALS__ACCESS_TOKEN"] = "" bricks = databricks() bricks.configuration(None, accept_partial=True) + + +def test_databricks_missing_config() -> None: + with pytest.raises(ConfigurationValueError, match="Configuration error:*"): + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__SERVER_HOSTNAME"] = "" + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__HTTP_PATH"] = "" + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CATALOG"] = "" + bricks = databricks() + bricks.configuration(None, accept_partial=True) diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index 5ff1cc2ca2..8b419e2ce4 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -1,5 +1,6 @@ import pytest import os +import dlt from dlt.common.utils import uniq_id from dlt.destinations import databricks @@ -220,30 +221,38 @@ def test_databricks_auth_token(destination_config: DestinationTestConfiguration) assert len(rows) == 3 -# TODO: test config staging_volume_name on_resolve -# TODO: modify the DestinationTestConfiguration -# TODO: add test databricks credentials default auth error -# TODO: test on notebook -# TODO: check that volume doesn't block schema drop -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=("databricks",)), - ids=lambda x: x.name, -) -def test_databricks_direct_load(destination_config: DestinationTestConfiguration) -> None: +def test_databricks_direct_load() -> None: os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" bricks = databricks() - config = bricks.configuration(None, accept_partial=True) - assert config.credentials.access_token - dataset_name = "test_databricks_token" + uniq_id() - pipeline = destination_config.setup_pipeline( - "test_databricks_token", 
dataset_name=dataset_name, destination=bricks + dataset_name = "test_databricks_direct_load" + uniq_id() + + pipeline = dlt.pipeline( + "test_databricks_direct_load", dataset_name=dataset_name, destination=bricks ) + info = pipeline.run([1, 2, 3], table_name="digits") + assert info.has_failed_jobs is False - info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) + with pipeline.sql_client() as client: + rows = client.execute_sql(f"select * from {dataset_name}.digits") + assert len(rows) == 3 + + +def test_databricks_direct_load_with_custom_staging_volume_name() -> None: + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" + + custom_staging_volume_name = "dlt_ci.dlt_tests_shared.custom_volume" + bricks = databricks(staging_volume_name=custom_staging_volume_name) + + dataset_name = "test_databricks_direct_load" + uniq_id() + + pipeline = dlt.pipeline( + "test_databricks_direct_load", dataset_name=dataset_name, destination=bricks + ) + info = pipeline.run([1, 2, 3], table_name="digits") assert info.has_failed_jobs is False with pipeline.sql_client() as client: From 1271e221208093bc09941d47ada955d6c6728b54 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Fri, 24 Jan 2025 20:09:54 +0100 Subject: [PATCH 16/28] restore comments --- dlt/destinations/impl/databricks/databricks.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 009529e01c..8b9d40886d 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -67,7 +67,7 @@ def run(self) -> None: self._handle_staged_file() ) - # Determine the source format and any additional format options + # decide on source format, file_name will either be a local file or a bucket path source_format, format_options_clause, skip_load = self._determine_source_format( file_name, orig_bucket_path ) @@ -172,14 +172,16 @@ def _handle_staged_file(self) -> tuple[str, str, str, str]: credentials_clause = "" if self._job_client.config.is_staging_external_location: - # skip the credentials clause + # just skip the credentials clause for external location + # https://docs.databricks.com/en/sql/language-manual/sql-ref-external-locations.html#external-location pass elif self._job_client.config.staging_credentials_name: - # named credentials + # add named credentials credentials_clause = ( f"WITH(CREDENTIAL {self._job_client.config.staging_credentials_name} )" ) else: + # referencing an staged files via a bucket URL requires explicit AWS credentials if bucket_scheme == "s3": assert isinstance(staging_credentials, AwsCredentialsWithoutDefaults) s3_creds = staging_credentials.to_session_credentials() @@ -192,6 +194,7 @@ def _handle_staged_file(self) -> tuple[str, str, str, str]: assert isinstance( staging_credentials, AzureCredentialsWithoutDefaults ), "AzureCredentialsWithoutDefaults required to pass explicit credential" + # Explicit azure credentials are needed to load from bucket without a named stage credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))""" bucket_path = self.ensure_databricks_abfss_url( bucket_path, @@ -216,6 +219,7 @@ def _handle_staged_file(self) -> tuple[str, str, str, str]: staging_credentials.azure_account_host, ) + # always add FROM clause from_clause = 
f"FROM '{bucket_path}'" return from_clause, credentials_clause, file_name, orig_bucket_path From aec0d4525b36abc89c06e1c945e960e523b8fb59 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Fri, 24 Jan 2025 21:07:25 +0100 Subject: [PATCH 17/28] restored destination_config --- .../load/pipeline/test_databricks_pipeline.py | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index 8b419e2ce4..6eca9224d8 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -221,18 +221,25 @@ def test_databricks_auth_token(destination_config: DestinationTestConfiguration) assert len(rows) == 3 -def test_databricks_direct_load() -> None: +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + default_sql_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) + ), + ids=lambda x: x.name, +) +def test_databricks_direct_load(destination_config: DestinationTestConfiguration) -> None: os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" bricks = databricks() dataset_name = "test_databricks_direct_load" + uniq_id() - - pipeline = dlt.pipeline( + pipeline = destination_config.setup_pipeline( "test_databricks_direct_load", dataset_name=dataset_name, destination=bricks ) - info = pipeline.run([1, 2, 3], table_name="digits") + + info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) assert info.has_failed_jobs is False with pipeline.sql_client() as client: @@ -240,7 +247,16 @@ def test_databricks_direct_load() -> None: assert len(rows) == 3 -def test_databricks_direct_load_with_custom_staging_volume_name() -> None: +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + default_sql_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) + ), + ids=lambda x: x.name, +) +def test_databricks_direct_load_with_custom_staging_volume_name( + destination_config: DestinationTestConfiguration, +) -> None: os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" @@ -248,11 +264,11 @@ def test_databricks_direct_load_with_custom_staging_volume_name() -> None: bricks = databricks(staging_volume_name=custom_staging_volume_name) dataset_name = "test_databricks_direct_load" + uniq_id() - - pipeline = dlt.pipeline( + pipeline = destination_config.setup_pipeline( "test_databricks_direct_load", dataset_name=dataset_name, destination=bricks ) - info = pipeline.run([1, 2, 3], table_name="digits") + + info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) assert info.has_failed_jobs is False with pipeline.sql_client() as client: From 37008500ba663c1ce3e3cdec899a6b13a538c7e2 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Fri, 24 Jan 2025 22:01:45 +0100 Subject: [PATCH 18/28] fix pokema api values --- tests/load/sources/rest_api/test_rest_api_source.py | 4 ++-- tests/sources/rest_api/test_rest_api_source.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/load/sources/rest_api/test_rest_api_source.py b/tests/load/sources/rest_api/test_rest_api_source.py index 25a9952ba4..583a67e69a 100644 --- a/tests/load/sources/rest_api/test_rest_api_source.py +++ 
b/tests/load/sources/rest_api/test_rest_api_source.py @@ -56,9 +56,9 @@ def test_rest_api_source(destination_config: DestinationTestConfiguration, reque assert table_counts.keys() == {"pokemon_list", "berry", "location"} - assert table_counts["pokemon_list"] == 1302 + assert table_counts["pokemon_list"] == 1304 assert table_counts["berry"] == 64 - assert table_counts["location"] == 1036 + assert table_counts["location"] == 1039 @pytest.mark.parametrize( diff --git a/tests/sources/rest_api/test_rest_api_source.py b/tests/sources/rest_api/test_rest_api_source.py index 904bcaf159..fe6376141b 100644 --- a/tests/sources/rest_api/test_rest_api_source.py +++ b/tests/sources/rest_api/test_rest_api_source.py @@ -79,9 +79,9 @@ def test_rest_api_source(destination_name: str, invocation_type: str) -> None: assert table_counts.keys() == {"pokemon_list", "berry", "location"} - assert table_counts["pokemon_list"] == 1302 + assert table_counts["pokemon_list"] == 1304 assert table_counts["berry"] == 64 - assert table_counts["location"] == 1036 + assert table_counts["location"] == 1039 @pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) From 730ff47d2ab3b79a331982fe39bdc3b2be705356 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 29 Jan 2025 13:24:39 +0100 Subject: [PATCH 19/28] enables databricks no stage tests --- .../load/pipeline/test_databricks_pipeline.py | 15 +++------------ tests/load/utils.py | 18 ++++++++---------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index 6eca9224d8..c8e92b6744 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -229,15 +229,11 @@ def test_databricks_auth_token(destination_config: DestinationTestConfiguration) ids=lambda x: x.name, ) def test_databricks_direct_load(destination_config: DestinationTestConfiguration) -> None: - os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" - os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" - - bricks = databricks() - dataset_name = "test_databricks_direct_load" + uniq_id() pipeline = destination_config.setup_pipeline( - "test_databricks_direct_load", dataset_name=dataset_name, destination=bricks + "test_databricks_direct_load", dataset_name=dataset_name ) + assert pipeline.staging is None info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) assert info.has_failed_jobs is False @@ -249,17 +245,12 @@ def test_databricks_direct_load(destination_config: DestinationTestConfiguration @pytest.mark.parametrize( "destination_config", - destinations_configs( - default_sql_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) - ), + destinations_configs(default_sql_configs=True, subset=("databricks",)), ids=lambda x: x.name, ) def test_databricks_direct_load_with_custom_staging_volume_name( destination_config: DestinationTestConfiguration, ) -> None: - os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" - os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" - custom_staging_volume_name = "dlt_ci.dlt_tests_shared.custom_volume" bricks = databricks(staging_volume_name=custom_staging_volume_name) diff --git a/tests/load/utils.py b/tests/load/utils.py index 5660202ec3..cb634a4425 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -328,8 +328,7 @@ def destinations_configs( destination_configs += [ 
DestinationTestConfiguration(destination_type=destination) for destination in SQL_DESTINATIONS - if destination - not in ("athena", "synapse", "databricks", "dremio", "clickhouse", "sqlalchemy") + if destination not in ("athena", "synapse", "dremio", "clickhouse", "sqlalchemy") ] destination_configs += [ DestinationTestConfiguration(destination_type="duckdb", file_format="parquet"), @@ -365,14 +364,6 @@ def destinations_configs( destination_type="clickhouse", file_format="jsonl", supports_dbt=False ) ] - destination_configs += [ - DestinationTestConfiguration( - destination_type="databricks", - file_format="parquet", - bucket_url=AZ_BUCKET, - extra_info="az-authorization", - ) - ] destination_configs += [ DestinationTestConfiguration( @@ -464,6 +455,13 @@ def destinations_configs( bucket_url=AZ_BUCKET, extra_info="az-authorization", ), + DestinationTestConfiguration( + destination_type="databricks", + staging="filesystem", + file_format="parquet", + bucket_url=AZ_BUCKET, + extra_info="az-authorization", + ), DestinationTestConfiguration( destination_type="databricks", staging="filesystem", From acece599ace4297de5d79acdd286f6ff9ee3c77b Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Wed, 29 Jan 2025 17:34:31 +0100 Subject: [PATCH 20/28] fix databricks config on_resolved --- .../impl/databricks/configuration.py | 34 ++++++++++++------- .../test_databricks_configuration.py | 28 ++++++++++++--- 2 files changed, 45 insertions(+), 17 deletions(-) diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index 677f52716b..ac7038fc56 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -37,19 +37,13 @@ class DatabricksCredentials(CredentialsConfiguration): def on_resolved(self) -> None: if not ((self.client_id and self.client_secret) or self.access_token): try: - # attempt notebook context authentication + # attempt context authentication from databricks.sdk import WorkspaceClient - from databricks.sdk.service.sql import EndpointInfo w = WorkspaceClient() self.access_token = w.dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None) # type: ignore[union-attr] - - # pick the first warehouse on the list - warehouses: List[EndpointInfo] = list(w.warehouses.list()) - self.server_hostname = warehouses[0].odbc_params.hostname - self.http_path = warehouses[0].odbc_params.path except Exception: - pass + self.access_token = None if not self.access_token: raise ConfigurationValueError( @@ -58,11 +52,25 @@ def on_resolved(self) -> None: "or 'access_token' for token-based authentication." ) - if not self.server_hostname or not self.http_path or not self.catalog: - raise ConfigurationValueError( - "Configuration error: Missing required parameters. " - "Please provide 'server_hostname', 'http_path', and 'catalog' in the configuration." 
- ) + if not self.server_hostname or not self.http_path: + try: + # attempt to fetch warehouse details + from databricks.sdk import WorkspaceClient + from databricks.sdk.service.sql import EndpointInfo + + w = WorkspaceClient() + warehouses: List[EndpointInfo] = list(w.warehouses.list()) + self.server_hostname = self.server_hostname or warehouses[0].odbc_params.hostname + self.http_path = self.http_path or warehouses[0].odbc_params.path + except Exception: + pass + + for param in ("catalog", "server_hostname", "http_path"): + if not getattr(self, param): + raise ConfigurationValueError( + f"Configuration error: Missing required parameter '{param}'. " + "Please provide it in the configuration." + ) def to_connector_params(self) -> Dict[str, Any]: conn_params = dict( diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py index beaee66246..57e60ff0b9 100644 --- a/tests/load/databricks/test_databricks_configuration.py +++ b/tests/load/databricks/test_databricks_configuration.py @@ -98,10 +98,30 @@ def test_databricks_auth_invalid() -> None: bricks.configuration(None, accept_partial=True) -def test_databricks_missing_config() -> None: - with pytest.raises(ConfigurationValueError, match="Configuration error:*"): - os.environ["DESTINATION__DATABRICKS__CREDENTIALS__SERVER_HOSTNAME"] = "" - os.environ["DESTINATION__DATABRICKS__CREDENTIALS__HTTP_PATH"] = "" +def test_databricks_missing_config_catalog() -> None: + with pytest.raises( + ConfigurationValueError, match="Configuration error: Missing required parameter 'catalog'*" + ): os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CATALOG"] = "" bricks = databricks() bricks.configuration(None, accept_partial=True) + + +def test_databricks_missing_config_http_path() -> None: + with pytest.raises( + ConfigurationValueError, + match="Configuration error: Missing required parameter 'http_path'*", + ): + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__HTTP_PATH"] = "" + bricks = databricks() + bricks.configuration(None, accept_partial=True) + + +def test_databricks_missing_config_server_hostname() -> None: + with pytest.raises( + ConfigurationValueError, + match="Configuration error: Missing required parameter 'server_hostname'*", + ): + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__SERVER_HOSTNAME"] = "" + bricks = databricks() + bricks.configuration(None, accept_partial=True) From 7de861c379ab6445b991d1494797ea83c78e887e Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Wed, 29 Jan 2025 19:02:36 +0100 Subject: [PATCH 21/28] adjusted direct load file management --- .../impl/databricks/configuration.py | 7 +++++ .../impl/databricks/databricks.py | 26 ++++++++++++------- .../load/pipeline/test_databricks_pipeline.py | 2 +- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index ac7038fc56..d6b2ec8100 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -5,6 +5,7 @@ from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration from dlt.common.configuration.exceptions import ConfigurationValueError +from dlt.common.pipeline import get_dlt_pipelines_dir DATABRICKS_APPLICATION_ID = "dltHub_dlt" @@ -24,6 +25,7 @@ class 
DatabricksCredentials(CredentialsConfiguration): """Additional keyword arguments that are passed to `databricks.sql.connect`""" socket_timeout: Optional[int] = 180 user_agent_entry: Optional[str] = DATABRICKS_APPLICATION_ID + staging_allowed_local_path: Optional[str] = None __config_gen_annotations__: ClassVar[List[str]] = [ "server_hostname", @@ -35,6 +37,10 @@ class DatabricksCredentials(CredentialsConfiguration): ] def on_resolved(self) -> None: + # conn parameter staging_allowed_local_path must be set to use 'REMOVE volume_path' SQL statement + if not self.staging_allowed_local_path: + self.staging_allowed_local_path = get_dlt_pipelines_dir() + if not ((self.client_id and self.client_secret) or self.access_token): try: # attempt context authentication @@ -80,6 +86,7 @@ def to_connector_params(self) -> Dict[str, Any]: access_token=self.access_token, session_configuration=self.session_configuration or {}, _socket_timeout=self.socket_timeout, + staging_allowed_local_path=self.staging_allowed_local_path, **(self.connection_parameters or {}), ) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 8b9d40886d..c8089cc550 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -28,6 +28,7 @@ from dlt.common.schema import TColumnSchema, Schema from dlt.common.schema.typing import TColumnType from dlt.common.storages import FilesystemConfiguration, fsspec_from_config +from dlt.common.utils import uniq_id from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset from dlt.destinations.exceptions import LoadJobTerminalException from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration @@ -58,7 +59,9 @@ def run(self) -> None: is_local_file = not ReferenceFollowupJobRequest.is_reference_job(self._file_path) if is_local_file: # local file by uploading to a temporary volume on Databricks - from_clause, file_name = self._handle_local_file_upload(self._file_path) + from_clause, file_name, volume_path, volume_file_path = self._handle_local_file_upload( + self._file_path + ) credentials_clause = "" orig_bucket_path = None # not used for local file else: @@ -86,9 +89,12 @@ def run(self) -> None: self._sql_client.execute_sql(statement) - def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str]: + if is_local_file: + self._sql_client.execute_sql(f"REMOVE '{volume_file_path}'") + self._sql_client.execute_sql(f"REMOVE '{volume_path}'") + + def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str, str, str]: from databricks.sdk import WorkspaceClient - import time import io w: WorkspaceClient @@ -119,17 +125,17 @@ def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str]: volume_database = self._sql_client.dataset_name volume_name = "_dlt_staging_load_volume" - # create staging volume name fully_qualified_volume_name = f"{volume_catalog}.{volume_database}.{volume_name}" if self._job_client.config.staging_volume_name: fully_qualified_volume_name = self._job_client.config.staging_volume_name volume_catalog, volume_database, volume_name = fully_qualified_volume_name.split(".") + else: + # create staging volume name + self._sql_client.execute_sql(f""" + CREATE VOLUME IF NOT EXISTS {fully_qualified_volume_name} + """) - self._sql_client.execute_sql(f""" - CREATE VOLUME IF NOT EXISTS {fully_qualified_volume_name} - """) - - volume_path = 
f"/Volumes/{volume_catalog}/{volume_database}/{volume_name}/{time.time_ns()}" + volume_path = f"/Volumes/{volume_catalog}/{volume_database}/{volume_name}/{uniq_id()}" volume_file_path = f"{volume_path}/{volume_file_name}" with open(local_file_path, "rb") as f: @@ -139,7 +145,7 @@ def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str]: from_clause = f"FROM '{volume_path}'" - return from_clause, file_name + return from_clause, file_name, volume_path, volume_file_path def _handle_staged_file(self) -> tuple[str, str, str, str]: bucket_path = orig_bucket_path = ( diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index c8e92b6744..95b5fdf397 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -251,7 +251,7 @@ def test_databricks_direct_load(destination_config: DestinationTestConfiguration def test_databricks_direct_load_with_custom_staging_volume_name( destination_config: DestinationTestConfiguration, ) -> None: - custom_staging_volume_name = "dlt_ci.dlt_tests_shared.custom_volume" + custom_staging_volume_name = "dlt_ci.dlt_tests_shared.static_volume" bricks = databricks(staging_volume_name=custom_staging_volume_name) dataset_name = "test_databricks_direct_load" + uniq_id() From 0aab8d446c10ae1decaf5df52c5e7ae14ef7adef Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Wed, 29 Jan 2025 23:55:08 +0100 Subject: [PATCH 22/28] direct load docs --- .../dlt-ecosystem/destinations/databricks.md | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/docs/website/docs/dlt-ecosystem/destinations/databricks.md b/docs/website/docs/dlt-ecosystem/destinations/databricks.md index a28a42f761..f41cb6b851 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/databricks.md +++ b/docs/website/docs/dlt-ecosystem/destinations/databricks.md @@ -260,6 +260,48 @@ The JSONL format has some limitations when used with Databricks: 2. The following data types are not supported when using the JSONL format with `databricks`: `decimal`, `json`, `date`, `binary`. Use `parquet` if your data contains these types. 3. The `bigint` data type with precision is not supported with the JSONL format. +## Direct Load (Databricks Managed Volumes) + +`dlt` now supports **Direct Load**, enabling pipelines to run seamlessly from **Databricks Notebooks** without external staging. When executed in a Databricks Notebook, `dlt` uses the notebook context for configuration if not explicitly provided. + +Direct Load also works **outside Databricks**, requiring explicit configuration of `server_hostname`, `http_path`, `catalog`, and authentication (`client_id`/`client_secret` for OAuth or `access_token` for token-based authentication). + +The example below demonstrates how to load data directly from a **Databricks Notebook**. 
Simply specify the **Databricks catalog** and optionally a **fully qualified volume name** (recommended for production) – the remaining configuration comes from the notebook context: + +```py +import dlt +from dlt.destinations import databricks +from dlt.sources.rest_api import rest_api_source + +# Fully qualified Databricks managed volume (recommended for production) +# - dlt assumes the named volume already exists +staging_volume_name = "dlt_ci.dlt_tests_shared.static_volume" + +bricks = databricks(credentials={"catalog": "dlt_ci"}, staging_volume_name=staging_volume_name) + +pokemon_source = rest_api_source( + { + "client": {"base_url": "https://pokeapi.co/api/v2/"}, + "resource_defaults": {"endpoint": {"params": {"limit": 1000}}}, + "resources": ["pokemon"], + } +) + +pipeline = dlt.pipeline( + pipeline_name="rest_api_example", + dataset_name="rest_api_data", + destination=bricks, +) + +load_info = pipeline.run(pokemon_source) +print(load_info) +print(pipeline.dataset().pokemon.df()) +``` + +- If **no** *staging_volume_name* **is provided**, dlt creates a **default volume** automatically. +- **For production**, explicitly setting *staging_volume_name* is recommended. +- The volume is used as a **temporary location** to store files before loading. Files are **deleted immediately** after loading. + ## Staging support Databricks supports both Amazon S3, Azure Blob Storage, and Google Cloud Storage as staging locations. `dlt` will upload files in Parquet format to the staging location and will instruct Databricks to load data from there. From 4998163b052c2b4d329165b4d96caac492054e00 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Thu, 30 Jan 2025 11:34:16 +0100 Subject: [PATCH 23/28] filters by bucket when subset of destinations is set when creating test cases --- tests/load/pipeline/test_databricks_pipeline.py | 2 +- tests/load/utils.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index 95b5fdf397..b70779b1cf 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -20,7 +20,7 @@ @pytest.mark.parametrize( "destination_config", destinations_configs( - default_sql_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) + default_staging_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) ), ids=lambda x: x.name, ) diff --git a/tests/load/utils.py b/tests/load/utils.py index cb634a4425..b77121a3e4 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -659,7 +659,9 @@ def destinations_configs( destination_configs = [ conf for conf in destination_configs - if conf.destination_type != "filesystem" or conf.bucket_url in bucket_subset + # filter by bucket when (1) filesystem OR (2) specific set of destinations requested + if (conf.destination_type != "filesystem" and not subset) + or conf.bucket_url in bucket_subset ] if exclude: destination_configs = [ From 799c41d496cd22b0d48593de04c303d5bc12ffd6 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Thu, 30 Jan 2025 14:37:38 +0100 Subject: [PATCH 24/28] simpler file upload --- .../impl/databricks/databricks.py | 25 +------------------ .../load/pipeline/test_databricks_pipeline.py | 12 ++++----- 2 files changed, 7 insertions(+), 30 deletions(-) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index c8089cc550..49a2324232 100644 --- 
a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -94,26 +94,6 @@ def run(self) -> None: self._sql_client.execute_sql(f"REMOVE '{volume_path}'") def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str, str, str]: - from databricks.sdk import WorkspaceClient - import io - - w: WorkspaceClient - - credentials = self._job_client.config.credentials - if credentials.client_id and credentials.client_secret: - # oauth authentication - w = WorkspaceClient( - host=credentials.server_hostname, - client_id=credentials.client_id, - client_secret=credentials.client_secret, - ) - elif credentials.access_token: - # token authentication - w = WorkspaceClient( - host=credentials.server_hostname, - token=credentials.access_token, - ) - file_name = FileStorage.get_file_name_from_file_path(local_file_path) volume_file_name = file_name if file_name.startswith(("_", ".")): @@ -138,10 +118,7 @@ def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str, str volume_path = f"/Volumes/{volume_catalog}/{volume_database}/{volume_name}/{uniq_id()}" volume_file_path = f"{volume_path}/{volume_file_name}" - with open(local_file_path, "rb") as f: - file_bytes = f.read() - binary_data = io.BytesIO(file_bytes) - w.files.upload(volume_file_path, binary_data, overwrite=True) + self._sql_client.execute_sql(f"PUT '{local_file_path}' INTO '{volume_file_path}' OVERWRITE") from_clause = f"FROM '{volume_path}'" diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index b70779b1cf..9431db7bd0 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -106,7 +106,7 @@ def test_databricks_external_location(destination_config: DestinationTestConfigu @pytest.mark.parametrize( "destination_config", destinations_configs( - default_sql_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) + default_staging_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) ), ids=lambda x: x.name, ) @@ -155,7 +155,9 @@ def test_databricks_gcs_external_location(destination_config: DestinationTestCon @pytest.mark.parametrize( "destination_config", - destinations_configs(default_sql_configs=True, subset=("databricks",)), + destinations_configs( + default_staging_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) + ), ids=lambda x: x.name, ) def test_databricks_auth_oauth(destination_config: DestinationTestConfiguration) -> None: @@ -189,7 +191,7 @@ def test_databricks_auth_oauth(destination_config: DestinationTestConfiguration) @pytest.mark.parametrize( "destination_config", destinations_configs( - default_sql_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) + default_staging_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) ), ids=lambda x: x.name, ) @@ -223,9 +225,7 @@ def test_databricks_auth_token(destination_config: DestinationTestConfiguration) @pytest.mark.parametrize( "destination_config", - destinations_configs( - default_sql_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) - ), + destinations_configs(default_sql_configs=True, subset=("databricks",)), ids=lambda x: x.name, ) def test_databricks_direct_load(destination_config: DestinationTestConfiguration) -> None: From 9ba1801ea1c993d21f143c6847d1fbb7845c5d66 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Thu, 30 Jan 2025 14:40:55 +0100 Subject: [PATCH 25/28] fix 
comment --- dlt/destinations/impl/databricks/databricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 49a2324232..8680eecd2d 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -110,7 +110,7 @@ def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str, str fully_qualified_volume_name = self._job_client.config.staging_volume_name volume_catalog, volume_database, volume_name = fully_qualified_volume_name.split(".") else: - # create staging volume name + # create staging volume named _dlt_staging_load_volume self._sql_client.execute_sql(f""" CREATE VOLUME IF NOT EXISTS {fully_qualified_volume_name} """) From 2c54ed9b21ee0f56963e72b46e4afc82767e6060 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Thu, 30 Jan 2025 22:37:17 +0100 Subject: [PATCH 26/28] passes authentication directly from workspace, adds proper fingerprinting --- .../impl/databricks/configuration.py | 50 +++++++++---- .../impl/databricks/databricks.py | 14 +++- .../impl/databricks/sql_client.py | 4 ++ .../dlt-ecosystem/destinations/databricks.md | 17 ++++- .../test_databricks_configuration.py | 70 +++++++++++++++++++ .../load/pipeline/test_databricks_pipeline.py | 17 ++++- 6 files changed, 152 insertions(+), 20 deletions(-) diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index d6b2ec8100..3e3cf659a5 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -1,11 +1,12 @@ import dataclasses from typing import ClassVar, Final, Optional, Any, Dict, List +from dlt.common import logger from dlt.common.typing import TSecretStrValue from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration from dlt.common.configuration.exceptions import ConfigurationValueError -from dlt.common.pipeline import get_dlt_pipelines_dir +from dlt.common.utils import digest128 DATABRICKS_APPLICATION_ID = "dltHub_dlt" @@ -25,7 +26,6 @@ class DatabricksCredentials(CredentialsConfiguration): """Additional keyword arguments that are passed to `databricks.sql.connect`""" socket_timeout: Optional[int] = 180 user_agent_entry: Optional[str] = DATABRICKS_APPLICATION_ID - staging_allowed_local_path: Optional[str] = None __config_gen_annotations__: ClassVar[List[str]] = [ "server_hostname", @@ -37,10 +37,6 @@ class DatabricksCredentials(CredentialsConfiguration): ] def on_resolved(self) -> None: - # conn parameter staging_allowed_local_path must be set to use 'REMOVE volume_path' SQL statement - if not self.staging_allowed_local_path: - self.staging_allowed_local_path = get_dlt_pipelines_dir() - if not ((self.client_id and self.client_secret) or self.access_token): try: # attempt context authentication @@ -51,6 +47,15 @@ def on_resolved(self) -> None: except Exception: self.access_token = None + try: + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + self.access_token = w.config.authenticate # type: ignore[assignment] + logger.info(f"Will attempt to use default auth of type {w.config.auth_type}") + except Exception: + pass + if not self.access_token: raise ConfigurationValueError( "Authentication failed: No valid authentication method detected. 
" @@ -62,12 +67,19 @@ def on_resolved(self) -> None: try: # attempt to fetch warehouse details from databricks.sdk import WorkspaceClient - from databricks.sdk.service.sql import EndpointInfo w = WorkspaceClient() - warehouses: List[EndpointInfo] = list(w.warehouses.list()) - self.server_hostname = self.server_hostname or warehouses[0].odbc_params.hostname - self.http_path = self.http_path or warehouses[0].odbc_params.path + # warehouse ID may be present in an env variable + if w.config.warehouse_id: + warehouse = w.warehouses.get(w.config.warehouse_id) + else: + # for some reason list of warehouses has different type than a single one 🤯 + warehouse = list(w.warehouses.list())[0] # type: ignore[assignment] + logger.info( + f"Will attempt to use warehouse {warehouse.id} to get sql connection params" + ) + self.server_hostname = self.server_hostname or warehouse.odbc_params.hostname + self.http_path = self.http_path or warehouse.odbc_params.path except Exception: pass @@ -86,7 +98,6 @@ def to_connector_params(self) -> Dict[str, Any]: access_token=self.access_token, session_configuration=self.session_configuration or {}, _socket_timeout=self.socket_timeout, - staging_allowed_local_path=self.staging_allowed_local_path, **(self.connection_parameters or {}), ) @@ -97,6 +108,9 @@ def to_connector_params(self) -> Dict[str, Any]: return conn_params + def __str__(self) -> str: + return f"databricks://{self.server_hostname}{self.http_path}/{self.catalog}" + @configspec class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration): @@ -108,10 +122,18 @@ class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration """If true, the temporary credentials are not propagated to the COPY command""" staging_volume_name: Optional[str] = None """Name of the Databricks managed volume for temporary storage, e.g., ... 
Defaults to '_dlt_temp_load_volume' if not set.""" + keep_staged_files: Optional[bool] = True + """Tells if to keep the files in internal (volume) stage""" def __str__(self) -> str: """Return displayable destination location""" - if self.staging_config: - return str(self.staging_config.credentials) + if self.credentials: + return str(self.credentials) else: - return "[no staging set]" + return "" + + def fingerprint(self) -> str: + """Returns a fingerprint of host part of a connection string""" + if self.credentials and self.credentials.server_hostname: + return digest128(self.credentials.server_hostname) + return "" diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 8680eecd2d..3f13298fd4 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -1,3 +1,4 @@ +import os from typing import Optional, Sequence, List, cast from urllib.parse import urlparse, urlunparse @@ -58,6 +59,10 @@ def run(self) -> None: # decide if this is a local file or a staged file is_local_file = not ReferenceFollowupJobRequest.is_reference_job(self._file_path) if is_local_file: + # conn parameter staging_allowed_local_path must be set to use 'PUT/REMOVE volume_path' SQL statement + self._sql_client.native_connection.thrift_backend.staging_allowed_local_path = ( + os.path.dirname(self._file_path) + ) # local file by uploading to a temporary volume on Databricks from_clause, file_name, volume_path, volume_file_path = self._handle_local_file_upload( self._file_path @@ -89,9 +94,12 @@ def run(self) -> None: self._sql_client.execute_sql(statement) - if is_local_file: - self._sql_client.execute_sql(f"REMOVE '{volume_file_path}'") - self._sql_client.execute_sql(f"REMOVE '{volume_path}'") + if is_local_file and not self._job_client.config.keep_staged_files: + self._handle_staged_file_remove(volume_path, volume_file_path) + + def _handle_staged_file_remove(self, volume_path: str, volume_file_path: str) -> None: + self._sql_client.execute_sql(f"REMOVE '{volume_file_path}'") + self._sql_client.execute_sql(f"REMOVE '{volume_path}'") def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str, str, str]: file_name = FileStorage.get_file_name_from_file_path(local_file_path) diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 9f695b9d6e..7ef4b979cd 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -88,7 +88,11 @@ def open_connection(self) -> DatabricksSqlConnection: if self.credentials.client_id and self.credentials.client_secret: conn_params["credentials_provider"] = self._get_oauth_credentials + elif callable(self.credentials.access_token): + # this is w.config.authenticator + conn_params["credentials_provider"] = lambda: self.credentials.access_token else: + # this is access token conn_params["access_token"] = self.credentials.access_token self._conn = databricks_lib.connect( diff --git a/docs/website/docs/dlt-ecosystem/destinations/databricks.md b/docs/website/docs/dlt-ecosystem/destinations/databricks.md index f41cb6b851..d970378cce 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/databricks.md +++ b/docs/website/docs/dlt-ecosystem/destinations/databricks.md @@ -240,6 +240,13 @@ You can find other options for specifying credentials in the [Authentication sec See [Staging support](#staging-support) for authentication options when `dlt` copies files from 
buckets. +### Using default credentials +If none of the auth methods above is configured, `dlt` attempts to get authorization from the Databricks workspace context. The context may +come, for example, from a Notebook (runtime) or via a standard set of env variables that the Databricks Python SDK recognizes (i.e., **DATABRICKS_TOKEN** or **DATABRICKS_HOST**). + +`dlt` is able to set `server_hostname` and `http_path` from the available warehouses. We use the default warehouse id (**DATABRICKS_WAREHOUSE_ID**) +if it is set (via an env variable), or the first one on the warehouses list. + ## Write disposition All write dispositions are supported. @@ -300,7 +307,15 @@ print(pipeline.dataset().pokemon.df()) - If **no** *staging_volume_name* **is provided**, dlt creates a **default volume** automatically. - **For production**, explicitly setting *staging_volume_name* is recommended. -- The volume is used as a **temporary location** to store files before loading. Files are **deleted immediately** after loading. +- The volume is used as a **temporary location** to store files before loading. + +:::tip +You can delete staged files **immediately** after loading by setting the following config option: +```toml +[destination.databricks] +keep_staged_files = false +``` +::: ## Staging support diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py index 57e60ff0b9..bac0a985e8 100644 --- a/tests/load/databricks/test_databricks_configuration.py +++ b/tests/load/databricks/test_databricks_configuration.py @@ -1,8 +1,12 @@ import pytest import os +from dlt.common.schema.schema import Schema +from dlt.common.utils import digest128 + pytest.importorskip("databricks") +import dlt from dlt.common.exceptions import TerminalValueError from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob @@ -12,6 +16,7 @@ from dlt.destinations.impl.databricks.configuration import ( DatabricksClientConfiguration, DATABRICKS_APPLICATION_ID, + DatabricksCredentials, ) # mark all tests as essential, do not remove @@ -43,6 +48,11 @@ def test_databricks_credentials_to_connector_params(): assert params["_socket_timeout"] == credentials.socket_timeout assert params["_user_agent_entry"] == DATABRICKS_APPLICATION_ID + displayable_location = str(credentials) + assert displayable_location.startswith( + "databricks://my-databricks.example.com/sql/1.0/warehouses/asdfe/my-catalog" + ) + def test_databricks_configuration() -> None: bricks = databricks() @@ -125,3 +135,63 @@ def test_databricks_missing_config_server_hostname() -> None: os.environ["DESTINATION__DATABRICKS__CREDENTIALS__SERVER_HOSTNAME"] = "" bricks = databricks() bricks.configuration(None, accept_partial=True) + + +def test_default_credentials() -> None: + # from dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient + # create minimal default env + os.environ["DATABRICKS_TOKEN"] = dlt.secrets["destination.databricks.credentials.access_token"] + os.environ["DATABRICKS_HOST"] = dlt.secrets[ + "destination.databricks.credentials.server_hostname" + ] + + config = resolve_configuration( + DatabricksClientConfiguration( + credentials=DatabricksCredentials(catalog="dlt_ci") + )._bind_dataset_name(dataset_name="my-dataset-1234") + ) + # we pass authenticator that will be used to make connection, that's why callable + assert callable(config.credentials.access_token) + # taken from a warehouse + assert isinstance(config.credentials.http_path, str)
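# A minimal sketch of the default resolution that on_resolved() performs, assuming a workspace
# with at least one SQL warehouse; it reuses only the SDK calls this patch itself relies on
# (WorkspaceClient, config.authenticate, config.warehouse_id, warehouses.list, odbc_params):
#
#   from databricks.sdk import WorkspaceClient
#   w = WorkspaceClient()  # picks up DATABRICKS_HOST / DATABRICKS_TOKEN set above
#   token_provider = w.config.authenticate  # the callable stored as access_token
#   warehouse = (
#       w.warehouses.get(w.config.warehouse_id)
#       if w.config.warehouse_id
#       else list(w.warehouses.list())[0]
#   )
#   print(warehouse.odbc_params.hostname, warehouse.odbc_params.path)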
+
+    bricks = databricks(credentials=config.credentials)
+    # "my-dataset-1234" not present (we check SQL execution)
+    with bricks.client(Schema("schema"), config) as client:
+        assert not client.is_storage_initialized()
+
+    # check fingerprint not default
+    assert config.fingerprint() != digest128("")
+
+
+def test_oauth2_credentials() -> None:
+    config = resolve_configuration(
+        DatabricksClientConfiguration(
+            credentials=DatabricksCredentials(
+                catalog="dlt_ci", client_id="ctx-id", client_secret="xx0xx"
+            )
+        )._bind_dataset_name(dataset_name="my-dataset-1234"),
+        sections=("destination", "databricks"),
+    )
+    # will resolve to oauth token
+    bricks = databricks(credentials=config.credentials)
+    # "my-dataset-1234" not present (we check SQL execution)
+    with pytest.raises(Exception, match="Client authentication failed"):
+        with bricks.client(Schema("schema"), config):
+            pass
+
+
+def test_default_warehouse() -> None:
+    os.environ["DATABRICKS_TOKEN"] = dlt.secrets["destination.databricks.credentials.access_token"]
+    os.environ["DATABRICKS_HOST"] = dlt.secrets[
+        "destination.databricks.credentials.server_hostname"
+    ]
+    # will force this warehouse
+    os.environ["DATABRICKS_WAREHOUSE_ID"] = "588dbd71bd802f4d"
+
+    config = resolve_configuration(
+        DatabricksClientConfiguration(
+            credentials=DatabricksCredentials(catalog="dlt_ci")
+        )._bind_dataset_name(dataset_name="my-dataset-1234")
+    )
+    assert config.credentials.http_path == "/sql/1.0/warehouses/588dbd71bd802f4d"
diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py
index 9431db7bd0..41791059e5 100644
--- a/tests/load/pipeline/test_databricks_pipeline.py
+++ b/tests/load/pipeline/test_databricks_pipeline.py
@@ -1,5 +1,7 @@
 import pytest
 import os
+
+from pytest_mock import MockerFixture
 import dlt
 
 from dlt.common.utils import uniq_id
@@ -248,11 +250,19 @@ def test_databricks_direct_load(destination_config: DestinationTestConfiguration
     destinations_configs(default_sql_configs=True, subset=("databricks",)),
     ids=lambda x: x.name,
 )
-def test_databricks_direct_load_with_custom_staging_volume_name(
+@pytest.mark.parametrize("keep_staged_files", (True, False))
+def test_databricks_direct_load_with_custom_staging_volume_name_and_file_removal(
     destination_config: DestinationTestConfiguration,
+    keep_staged_files: bool,
+    mocker: MockerFixture,
 ) -> None:
+    from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob
+
+    remove_spy = mocker.spy(DatabricksLoadJob, "_handle_staged_file_remove")
     custom_staging_volume_name = "dlt_ci.dlt_tests_shared.static_volume"
-    bricks = databricks(staging_volume_name=custom_staging_volume_name)
+    bricks = databricks(
+        staging_volume_name=custom_staging_volume_name, keep_staged_files=keep_staged_files
+    )
 
     dataset_name = "test_databricks_direct_load" + uniq_id()
     pipeline = destination_config.setup_pipeline(
@@ -261,6 +271,9 @@ def test_databricks_direct_load_with_custom_staging_volume_name(
 
     info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs)
     assert info.has_failed_jobs is False
+    print(info)
+
+    assert remove_spy.call_count == (0 if keep_staged_files else 2)
 
     with pipeline.sql_client() as client:
         rows = client.execute_sql(f"select * from {dataset_name}.digits")

From 18b2bd8e572588173ec2a8c2a1a889fafea650e0 Mon Sep 17 00:00:00 2001
From: Marcin Rudolf
Date: Fri, 31 Jan 2025 11:56:12 +0100
Subject: [PATCH 27/28] use real client_id in tests

---
 .../test_databricks_configuration.py | 34 ++++++++++++-------
 1 file changed, 21 insertions(+), 13 deletions(-)

diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py
index bac0a985e8..08823f2454 100644
--- a/tests/load/databricks/test_databricks_configuration.py
+++ b/tests/load/databricks/test_databricks_configuration.py
@@ -137,14 +137,25 @@ def test_databricks_missing_config_server_hostname() -> None:
     bricks.configuration(None, accept_partial=True)
 
 
-def test_default_credentials() -> None:
-    # from dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient
+@pytest.mark.parametrize("auth_type", ("pat", "oauth2"))
+def test_default_credentials(auth_type: str) -> None:
     # create minimal default env
-    os.environ["DATABRICKS_TOKEN"] = dlt.secrets["destination.databricks.credentials.access_token"]
     os.environ["DATABRICKS_HOST"] = dlt.secrets[
         "destination.databricks.credentials.server_hostname"
     ]
-
+    if auth_type == "pat":
+        os.environ["DATABRICKS_TOKEN"] = dlt.secrets[
+            "destination.databricks.credentials.access_token"
+        ]
+    else:
+        os.environ["DATABRICKS_CLIENT_ID"] = dlt.secrets[
+            "destination.databricks.credentials.client_id"
+        ]
+        os.environ["DATABRICKS_CLIENT_SECRET"] = dlt.secrets[
+            "destination.databricks.credentials.client_secret"
+        ]
+
+    # will not pick up the credentials from "destination.databricks"
     config = resolve_configuration(
         DatabricksClientConfiguration(
             credentials=DatabricksCredentials(catalog="dlt_ci")
@@ -165,20 +176,17 @@ def test_default_credentials() -> None:
 
 
 def test_oauth2_credentials() -> None:
+    dlt.secrets["destination.databricks.credentials.access_token"] = ""
     config = resolve_configuration(
-        DatabricksClientConfiguration(
-            credentials=DatabricksCredentials(
-                catalog="dlt_ci", client_id="ctx-id", client_secret="xx0xx"
-            )
-        )._bind_dataset_name(dataset_name="my-dataset-1234"),
+        DatabricksClientConfiguration()._bind_dataset_name(dataset_name="my-dataset-1234-oauth"),
         sections=("destination", "databricks"),
     )
+    assert config.credentials.access_token == ""
     # will resolve to oauth token
     bricks = databricks(credentials=config.credentials)
-    # "my-dataset-1234" not present (we check SQL execution)
-    with pytest.raises(Exception, match="Client authentication failed"):
-        with bricks.client(Schema("schema"), config):
-            pass
+    # "my-dataset-1234-oauth" not present (we check SQL execution)
+    with bricks.client(Schema("schema"), config) as client:
+        assert not client.is_storage_initialized()
 
 
 def test_default_warehouse() -> None:

From 428c075644368682c61f98ccb6f7bb690d217f5b Mon Sep 17 00:00:00 2001
From: Marcin Rudolf
Date: Sat, 1 Feb 2025 12:39:34 +0100
Subject: [PATCH 28/28] fixes config resolver to not pass NotResolved hints to config providers

---
 dlt/common/configuration/providers/vault.py | 3 ++
 dlt/common/configuration/resolve.py | 40 ++++++++++---------
 .../configuration/test_configuration.py | 12 +++---
 .../test_databricks_configuration.py | 4 ++
 4 files changed, 35 insertions(+), 24 deletions(-)

diff --git a/dlt/common/configuration/providers/vault.py b/dlt/common/configuration/providers/vault.py
index 0ed8842d55..b8181a3e41 100644
--- a/dlt/common/configuration/providers/vault.py
+++ b/dlt/common/configuration/providers/vault.py
@@ -53,6 +53,9 @@ def get_value(
         value, _ = super().get_value(key, hint, pipeline_name, *sections)
         if value is None:
             # only secrets hints are handled
+            # TODO: we need to refine how we filter out non-secrets
+            # at the least we should load known fragments for fields
+            # that are part of a secret (ie.
coming from Credentials) if self.only_secrets and not is_secret_hint(hint) and hint is not AnyType: return None, full_key diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index 97bcfd315e..dd9502baa6 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -196,13 +196,25 @@ def _resolve_config_fields( if key in config.__hint_resolvers__: # Type hint for this field is created dynamically hint = config.__hint_resolvers__[key](config) + # check if hint optional + is_optional = is_optional_type(hint) # get default and explicit values default_value = getattr(config, key, None) explicit_none = False + explicit_value = None + current_value = None traces: List[LookupTrace] = [] + def _set_field() -> None: + # collect unresolved fields + # NOTE: we hide B023 here because the function is called only within a loop + if not is_optional and current_value is None: # noqa + unresolved_fields[key] = traces # noqa + # set resolved value in config + if default_value != current_value: # noqa + setattr(config, key, current_value) # noqa + if explicit_values: - explicit_value = None if key in explicit_values: # allow None to be passed in explicit values # so we are able to reset defaults like in regular function calls @@ -211,14 +223,15 @@ def _resolve_config_fields( # detect dlt.config and dlt.secrets and force injection if isinstance(explicit_value, ConfigValueSentinel): explicit_value = None - else: - if is_hint_not_resolvable(hint): - # for final fields default value is like explicit - explicit_value = default_value - else: - explicit_value = None - current_value = None + if is_hint_not_resolvable(hint): + # do not resolve not resolvable, but allow for explicit values to be passed + if not explicit_none: + current_value = default_value if explicit_value is None else explicit_value + traces = [LookupTrace("ExplicitValues", None, key, current_value)] + _set_field() + continue + # explicit none skips resolution if not explicit_none: # if hint is union of configurations, any of them must be resolved @@ -276,16 +289,7 @@ def _resolve_config_fields( # set the trace for explicit none traces = [LookupTrace("ExplicitValues", None, key, None)] - # check if hint optional - is_optional = is_optional_type(hint) - # collect unresolved fields - if not is_optional and current_value is None: - unresolved_fields[key] = traces - # set resolved value in config - if default_value != current_value: - if not is_hint_not_resolvable(hint) or explicit_value is not None or explicit_none: - # ignore final types - setattr(config, key, current_value) + _set_field() # Check for dynamic hint resolvers which have no corresponding fields unmatched_hint_resolvers: List[str] = [] diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index 00a28d652e..8e8618f90f 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -397,13 +397,13 @@ class FinalConfiguration(BaseConfiguration): class FinalConfiguration2(BaseConfiguration): pipeline_name: Final[str] = None - c2 = resolve.resolve_configuration(FinalConfiguration2()) - assert dict(c2) == {"pipeline_name": None} + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(FinalConfiguration2()) c2 = resolve.resolve_configuration( FinalConfiguration2(), explicit_value={"pipeline_name": "exp"} ) - assert c.pipeline_name == "exp" + assert c2.pipeline_name == "exp" with 
pytest.raises(ConfigFieldMissingException):
         resolve.resolve_configuration(FinalConfiguration2(), explicit_value={"pipeline_name": None})
@@ -435,13 +435,13 @@ class NotResolvedConfiguration(BaseConfiguration):
     class NotResolvedConfiguration2(BaseConfiguration):
         pipeline_name: Annotated[str, NotResolved()] = None
 
-    c2 = resolve.resolve_configuration(NotResolvedConfiguration2())
-    assert dict(c2) == {"pipeline_name": None}
+    with pytest.raises(ConfigFieldMissingException):
+        resolve.resolve_configuration(NotResolvedConfiguration2())
     c2 = resolve.resolve_configuration(
         NotResolvedConfiguration2(), explicit_value={"pipeline_name": "exp"}
     )
-    assert c.pipeline_name == "exp"
+    assert c2.pipeline_name == "exp"
     with pytest.raises(ConfigFieldMissingException):
         resolve.resolve_configuration(
             NotResolvedConfiguration2(), explicit_value={"pipeline_name": None}
diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py
index 08823f2454..cc98c47d33 100644
--- a/tests/load/databricks/test_databricks_configuration.py
+++ b/tests/load/databricks/test_databricks_configuration.py
@@ -177,6 +177,10 @@ def test_default_credentials(auth_type: str) -> None:
 
 def test_oauth2_credentials() -> None:
     dlt.secrets["destination.databricks.credentials.access_token"] = ""
+    # we must prime the "destinations" section for the Google Secret Manager config provider
+    # because it retrieves the catalog as the first element and the catalog is not a secret,
+    # while vault providers handle secrets only
+    dlt.secrets.get("destination.credentials")
     config = resolve_configuration(
         DatabricksClientConfiguration()._bind_dataset_name(dataset_name="my-dataset-1234-oauth"),
         sections=("destination", "databricks"),